Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RUST-911: Add srvServiceName URI option #1235

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/client/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -910,11 +910,16 @@ impl Default for HostInfo {
}

impl HostInfo {
async fn resolve(self, resolver_config: Option<ResolverConfig>) -> Result<ResolvedHostInfo> {
async fn resolve(
self,
resolver_config: Option<ResolverConfig>,
srv_service_name: Option<String>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to bundle srv_service_name into resolver_config? Mechanically, they're always passed around as a pair but I could see going either way on whether it's the right thing conceptually.

) -> Result<ResolvedHostInfo> {
Ok(match self {
Self::HostIdentifiers(hosts) => ResolvedHostInfo::HostIdentifiers(hosts),
Self::DnsRecord(hostname) => {
let mut resolver = SrvResolver::new(resolver_config.clone(), None).await?;
let mut resolver =
SrvResolver::new(resolver_config.clone(), srv_service_name).await?;
let config = resolver.resolve_client_options(&hostname).await?;
ResolvedHostInfo::DnsRecord { hostname, config }
}
Expand Down
4 changes: 3 additions & 1 deletion src/client/options/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ impl Action for ParseConnectionString {
options.resolver_config.clone_from(&self.resolver_config);
}

let resolved = host_info.resolve(self.resolver_config).await?;
let resolved = host_info
.resolve(self.resolver_config, options.srv_service_name.clone())
.await?;
options.hosts = match resolved {
ResolvedHostInfo::HostIdentifiers(hosts) => hosts,
ResolvedHostInfo::DnsRecord {
Expand Down
2 changes: 0 additions & 2 deletions src/client/options/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ static SKIPPED_TESTS: Lazy<Vec<&'static str>> = Lazy::new(|| {
"maxPoolSize=0 does not error",
// TODO RUST-226: unskip this test
"Valid tlsCertificateKeyFilePassword is parsed correctly",
// TODO RUST-911: unskip this test
"SRV URI with custom srvServiceName",
// TODO RUST-229: unskip the following tests
"Single IP literal host without port",
"Single IP literal host with port",
Expand Down
2 changes: 1 addition & 1 deletion src/sdam/srv_polling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl SrvPollingMonitor {

let resolver = SrvResolver::new(
self.client_options.resolver_config().cloned(),
Option::Some(self.client_options.clone()),
self.client_options.srv_service_name.clone(),
)
.await?;

Expand Down
27 changes: 27 additions & 0 deletions src/sdam/srv_polling/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,30 @@ async fn srv_max_hosts_random() {
assert_eq!(2, actual.len());
assert!(actual.contains(&localhost_test_build_10gen(27017)));
}

#[tokio::test]
async fn srv_service_name() {
if get_client_options().await.srv_service_name.is_none() {
log_uncaptured("skipping srv_service_name due to no custom srvServiceName");
return;
}
let mut options = ClientOptions::new_srv();
let hosts = vec![
localhost_test_build_10gen(27019),
localhost_test_build_10gen(27020),
];
let rescan_interval = options.original_srv_info.as_ref().cloned().unwrap().min_ttl;
options.hosts.clone_from(&hosts);
options.srv_service_name = Some("customname".to_string());
options.test_options_mut().mock_lookup_hosts = Some(make_lookup_hosts(vec![
localhost_test_build_10gen(27019),
localhost_test_build_10gen(27020),
]));
let mut topology = Topology::new(options).unwrap();
topology.watch().wait_until_initialized().await;
tokio::time::sleep(rescan_interval * 2).await;
assert_eq!(
hosts.into_iter().collect::<HashSet<_>>(),
topology.server_addresses()
);
}
26 changes: 10 additions & 16 deletions src/srv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@ use std::time::Duration;

#[cfg(feature = "dns-resolver")]
use crate::error::ErrorKind;
use crate::{
client::options::ResolverConfig,
error::Result,
options::{ClientOptions, ServerAddress},
};
use crate::{client::options::ResolverConfig, error::Result, options::ServerAddress};

#[derive(Debug)]
pub(crate) struct ResolvedConfig {
Expand Down Expand Up @@ -94,20 +90,20 @@ pub(crate) enum DomainMismatch {
#[cfg(feature = "dns-resolver")]
pub(crate) struct SrvResolver {
resolver: crate::runtime::AsyncResolver,
client_options: Option<ClientOptions>,
srv_service_name: Option<String>,
}

#[cfg(feature = "dns-resolver")]
impl SrvResolver {
pub(crate) async fn new(
config: Option<ResolverConfig>,
client_options: Option<ClientOptions>,
srv_service_name: Option<String>,
) -> Result<Self> {
let resolver = crate::runtime::AsyncResolver::new(config.map(|c| c.inner)).await?;

Ok(Self {
resolver,
client_options,
srv_service_name,
})
}

Expand Down Expand Up @@ -160,15 +156,13 @@ impl SrvResolver {
original_hostname: &str,
dm: DomainMismatch,
) -> Result<LookupHosts> {
let default_service_name = "mongodb".to_string();
let service_name = match &self.client_options {
None => default_service_name,
Some(opts) => opts
.srv_service_name
let lookup_hostname = format!(
"_{}._tcp.{}",
self.srv_service_name
.clone()
.unwrap_or(default_service_name),
};
let lookup_hostname = format!("_{}._tcp.{}", service_name, original_hostname);
.unwrap_or("mongodb".to_string()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tiny style nit: we can do the following to avoid allocating an extra string here

self.srv_service_name
    .as_deref() // converts into an Option<&str> to get a reference to the inner string
    .unwrap_or("mongodb")

original_hostname
);
self.get_srv_hosts_unvalidated(&lookup_hostname)
.await?
.validate(original_hostname, dm)
Expand Down
18 changes: 0 additions & 18 deletions src/test/spec/initial_dns_seedlist_discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,24 +62,6 @@ struct ParsedOptions {
}

async fn run_test(mut test_file: TestFile) {
if let Some(ref options) = test_file.options {
// TODO RUST-933: Remove this skip.
let skip = if options.srv_service_name.is_some() {
Some("srvServiceName")
} else {
None
};

if let Some(skip) = skip {
log_uncaptured(format!(
"skipping initial_dns_seedlist_discovery test case due to unsupported connection \
string option: {}",
skip,
));
return;
}
}

// "encoded-userinfo-and-db.json" specifies a database name with a question mark which is
// disallowed on Windows. See
// <https://www.mongodb.com/docs/manual/reference/limits/#restrictions-on-db-names>
Expand Down