Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RUST-911: Add srvServiceName URI option #1235

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions src/client/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ const URI_OPTIONS: &[&str] = &[
"waitqueuetimeoutms",
"wtimeoutms",
"zlibcompressionlevel",
"srvservicename",
];

/// Reserved characters as defined by [Section 2.2 of RFC-3986](https://tools.ietf.org/html/rfc3986#section-2.2).
Expand Down Expand Up @@ -521,6 +522,9 @@ pub struct ClientOptions {
/// By default, no default database is specified.
pub default_database: Option<String>,

/// Overrides the default "mongodb" service name for SRV lookup in both discovery and polling
pub srv_service_name: Option<String>,

#[builder(setter(skip))]
#[derivative(Debug = "ignore")]
pub(crate) socket_timeout: Option<Duration>,
Expand Down Expand Up @@ -676,6 +680,8 @@ impl Serialize for ClientOptions {
loadbalanced: &'a Option<bool>,

srvmaxhosts: Option<i32>,

srvservicename: &'a Option<String>,
}

let client_options = ClientOptionsHelper {
Expand Down Expand Up @@ -709,6 +715,7 @@ impl Serialize for ClientOptions {
.map(|v| v.try_into())
.transpose()
.map_err(serde::ser::Error::custom)?,
srvservicename: &self.srv_service_name,
};

client_options.serialize(serializer)
Expand Down Expand Up @@ -865,6 +872,9 @@ pub struct ConnectionString {
/// Limit on the number of mongos connections that may be created for sharded topologies.
pub srv_max_hosts: Option<u32>,

/// Overrides the default "mongodb" service name for SRV lookup in both discovery and polling
pub srv_service_name: Option<String>,

wait_queue_timeout: Option<Duration>,
tls_insecure: Option<bool>,

Expand Down Expand Up @@ -900,11 +910,16 @@ impl Default for HostInfo {
}

impl HostInfo {
async fn resolve(self, resolver_config: Option<ResolverConfig>) -> Result<ResolvedHostInfo> {
async fn resolve(
self,
resolver_config: Option<ResolverConfig>,
srv_service_name: Option<String>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to bundle srv_service_name into resolver_config? Mechanically, they're always passed around as a pair but I could see going either way on whether it's the right thing conceptually.

) -> Result<ResolvedHostInfo> {
Ok(match self {
Self::HostIdentifiers(hosts) => ResolvedHostInfo::HostIdentifiers(hosts),
Self::DnsRecord(hostname) => {
let mut resolver = SrvResolver::new(resolver_config.clone()).await?;
let mut resolver =
SrvResolver::new(resolver_config.clone(), srv_service_name).await?;
let config = resolver.resolve_client_options(&hostname).await?;
ResolvedHostInfo::DnsRecord { hostname, config }
}
Expand Down Expand Up @@ -1486,6 +1501,12 @@ impl ConnectionString {
ConnectionStringParts::default()
};

if conn_str.srv_service_name.is_some() && !srv {
return Err(Error::invalid_argument(
"srvServiceName cannot be specified with a non-SRV URI",
));
}

if let Some(srv_max_hosts) = conn_str.srv_max_hosts {
if !srv {
return Err(Error::invalid_argument(
Expand Down Expand Up @@ -1976,6 +1997,9 @@ impl ConnectionString {
k @ "srvmaxhosts" => {
self.srv_max_hosts = Some(get_u32!(value, k));
}
"srvservicename" => {
self.srv_service_name = Some(value.to_string());
}
k @ "tls" | k @ "ssl" => {
let tls = get_bool!(value, k);

Expand Down
5 changes: 4 additions & 1 deletion src/client/options/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ impl Action for ParseConnectionString {
options.resolver_config.clone_from(&self.resolver_config);
}

let resolved = host_info.resolve(self.resolver_config).await?;
let resolved = host_info
.resolve(self.resolver_config, options.srv_service_name.clone())
.await?;
options.hosts = match resolved {
ResolvedHostInfo::HostIdentifiers(hosts) => hosts,
ResolvedHostInfo::DnsRecord {
Expand Down Expand Up @@ -159,6 +161,7 @@ impl ClientOptions {
#[cfg(feature = "tracing-unstable")]
tracing_max_document_length_bytes: None,
srv_max_hosts: conn_str.srv_max_hosts,
srv_service_name: conn_str.srv_service_name,
}
}
}
2 changes: 0 additions & 2 deletions src/client/options/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ static SKIPPED_TESTS: Lazy<Vec<&'static str>> = Lazy::new(|| {
"maxPoolSize=0 does not error",
// TODO RUST-226: unskip this test
"Valid tlsCertificateKeyFilePassword is parsed correctly",
// TODO RUST-911: unskip this test
"SRV URI with custom srvServiceName",
// TODO RUST-229: unskip the following tests
"Single IP literal host without port",
"Single IP literal host with port",
Expand Down
12 changes: 10 additions & 2 deletions src/sdam/srv_polling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ impl SrvPollingMonitor {
}

fn rescan_interval(&self) -> Duration {
std::cmp::max(self.rescan_interval, MIN_RESCAN_SRV_INTERVAL)
if cfg!(test) {
self.rescan_interval
} else {
std::cmp::max(self.rescan_interval, MIN_RESCAN_SRV_INTERVAL)
}
}

async fn execute(mut self) {
Expand Down Expand Up @@ -130,7 +134,11 @@ impl SrvPollingMonitor {
return Ok(resolver);
}

let resolver = SrvResolver::new(self.client_options.resolver_config().cloned()).await?;
let resolver = SrvResolver::new(
self.client_options.resolver_config().cloned(),
self.client_options.srv_service_name.clone(),
)
.await?;

// Since the connection was not `Some` above, this will always insert the new connection and
// return a reference to it.
Expand Down
24 changes: 24 additions & 0 deletions src/sdam/srv_polling/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,27 @@ async fn srv_max_hosts_random() {
assert_eq!(2, actual.len());
assert!(actual.contains(&localhost_test_build_10gen(27017)));
}

#[tokio::test]
async fn srv_service_name() {
let rescan_interval = Duration::from_secs(1);
let new_hosts = vec![
ServerAddress::Tcp {
host: "localhost.test.build.10gen.cc".to_string(),
port: Some(27019),
},
ServerAddress::Tcp {
host: "localhost.test.build.10gen.cc".to_string(),
port: Some(27020),
},
];
let uri = "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname";
let mut options = ClientOptions::parse(uri).await.unwrap();
// override the min_ttl to speed up lookup interval
options.original_srv_info.as_mut().unwrap().min_ttl = rescan_interval;
options.test_options_mut().mock_lookup_hosts = Some(make_lookup_hosts(new_hosts.clone()));
let mut topology = Topology::new(options).unwrap();
topology.watch().wait_until_initialized().await;
tokio::time::sleep(rescan_interval * 2).await;
assert_eq!(topology.server_addresses(), new_hosts.into_iter().collect());
}
17 changes: 14 additions & 3 deletions src/srv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,21 @@ pub(crate) enum DomainMismatch {
#[cfg(feature = "dns-resolver")]
pub(crate) struct SrvResolver {
resolver: crate::runtime::AsyncResolver,
srv_service_name: Option<String>,
}

#[cfg(feature = "dns-resolver")]
impl SrvResolver {
pub(crate) async fn new(config: Option<ResolverConfig>) -> Result<Self> {
pub(crate) async fn new(
config: Option<ResolverConfig>,
srv_service_name: Option<String>,
) -> Result<Self> {
let resolver = crate::runtime::AsyncResolver::new(config.map(|c| c.inner)).await?;

Ok(Self { resolver })
Ok(Self {
resolver,
srv_service_name,
})
}

pub(crate) async fn resolve_client_options(
Expand Down Expand Up @@ -149,7 +156,11 @@ impl SrvResolver {
original_hostname: &str,
dm: DomainMismatch,
) -> Result<LookupHosts> {
let lookup_hostname = format!("_mongodb._tcp.{}", original_hostname);
let lookup_hostname = format!(
"_{}._tcp.{}",
self.srv_service_name.as_deref().unwrap_or("mongodb"),
original_hostname
);
self.get_srv_hosts_unvalidated(&lookup_hostname)
.await?
.validate(original_hostname, dm)
Expand Down
18 changes: 0 additions & 18 deletions src/test/spec/initial_dns_seedlist_discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,24 +62,6 @@ struct ParsedOptions {
}

async fn run_test(mut test_file: TestFile) {
if let Some(ref options) = test_file.options {
// TODO RUST-933: Remove this skip.
let skip = if options.srv_service_name.is_some() {
Some("srvServiceName")
} else {
None
};

if let Some(skip) = skip {
log_uncaptured(format!(
"skipping initial_dns_seedlist_discovery test case due to unsupported connection \
string option: {}",
skip,
));
return;
}
}

// "encoded-userinfo-and-db.json" specifies a database name with a question mark which is
// disallowed on Windows. See
// <https://www.mongodb.com/docs/manual/reference/limits/#restrictions-on-db-names>
Expand Down
Loading