diff --git a/Cargo.lock b/Cargo.lock index 0df8f1c..a2ed002 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -15,6 +15,7 @@ dependencies = [ "hyper-rustls", "hyper-util", "openssl", + "parking_lot", "pem", "ring", "rustls", @@ -340,6 +341,16 @@ version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.21" @@ -451,6 +462,29 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "parking_lot" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e4af0ca4f6caed20e900d564c242b8e5d4903fdacf31d3daf527b66fe6f42fb" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.52.4", +] + [[package]] name = "pem" version = "3.0.3" @@ -517,6 +551,15 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "redox_syscall" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e" +dependencies = [ + "bitflags", +] + [[package]] name = "ring" version = "0.17.8" @@ -585,6 +628,12 @@ version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "serde" version = "1.0.197" diff --git a/Cargo.toml b/Cargo.toml index 47ed37c..7740ddf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,7 @@ ring = { version = "0.17", features = ["std"], optional = true } hyper-rustls = { version = "0.26.0", default-features = false, features = ["http2", "webpki-roots", "ring"] } rustls-pemfile = "2.1.1" rustls = "0.22.0" +parking_lot = "0.12" [dev-dependencies] argparse = "0.2" diff --git a/src/client.rs b/src/client.rs index 1ed5ad2..e331f95 100644 --- a/src/client.rs +++ b/src/client.rs @@ -123,7 +123,7 @@ impl Client { /// See [ErrorReason](enum.ErrorReason.html) for possible errors. #[cfg_attr(feature = "tracing", ::tracing::instrument)] pub async fn send(&self, payload: T) -> Result { - let request = self.build_request(payload); + let request = self.build_request(payload)?; let requesting = self.http_client.request(request); let response = requesting.await?; @@ -152,7 +152,7 @@ impl Client { } } - fn build_request(&self, payload: T) -> hyper::Request> { + fn build_request(&self, payload: T) -> Result>, Error> { let path = format!("https://{}/3/device/{}", self.endpoint, payload.get_device_token()); let mut builder = hyper::Request::builder() @@ -180,18 +180,16 @@ impl Client { builder = builder.header("apns-topic", apns_topic.as_bytes()); } if let Some(ref signer) = self.signer { - let auth = signer - .with_signature(|signature| format!("Bearer {}", signature)) - .unwrap(); + let auth = signer.with_signature(|signature| format!("Bearer {}", signature))?; builder = builder.header(AUTHORIZATION, auth.as_bytes()); } - let payload_json = payload.to_json_string().unwrap(); + let payload_json = payload.to_json_string()?; builder = builder.header(CONTENT_LENGTH, format!("{}", payload_json.len()).as_bytes()); let request_body = Full::from(payload_json.into_bytes()).boxed(); - builder.body(request_body).unwrap() + builder.body(request_body).map_err(Error::BuildRequestError) } } @@ -247,7 +245,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); let client = Client::new(default_connector(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let uri = format!("{}", request.uri()); assert_eq!("https://api.push.apple.com/3/device/a_test_id", &uri); @@ -258,7 +256,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); let client = Client::new(default_connector(), None, Endpoint::Sandbox); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let uri = format!("{}", request.uri()); assert_eq!("https://api.development.push.apple.com/3/device/a_test_id", &uri); @@ -269,17 +267,27 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); let client = Client::new(default_connector(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); assert_eq!(&Method::POST, request.method()); } + #[test] + fn test_request_invalid() { + let builder = DefaultNotificationBuilder::new(); + let payload = builder.build("\r\n", Default::default()); + let client = Client::new(default_connector(), None, Endpoint::Production); + let request = client.build_request(payload); + + assert!(matches!(request, Err(Error::BuildRequestError(_)))); + } + #[test] fn test_request_content_type() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); let client = Client::new(default_connector(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); assert_eq!("application/json", request.headers().get(CONTENT_TYPE).unwrap()); } @@ -289,7 +297,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); let client = Client::new(default_connector(), None, Endpoint::Production); - let request = client.build_request(payload.clone()); + let request = client.build_request(payload.clone()).unwrap(); let payload_json = payload.to_json_string().unwrap(); let content_length = request.headers().get(CONTENT_LENGTH).unwrap().to_str().unwrap(); @@ -301,7 +309,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); let client = Client::new(default_connector(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); assert_eq!(None, request.headers().get(AUTHORIZATION)); } @@ -319,7 +327,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); let client = Client::new(default_connector(), Some(signer), Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); assert_ne!(None, request.headers().get(AUTHORIZATION)); } @@ -333,7 +341,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }; let payload = builder.build("a_test_id", options); let client = Client::new(default_connector(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_push_type = request.headers().get("apns-push-type").unwrap(); assert_eq!("background", apns_push_type); @@ -344,7 +352,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); let client = Client::new(default_connector(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority"); assert_eq!(None, apns_priority); @@ -363,7 +371,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ ); let client = Client::new(default_connector(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority").unwrap(); assert_eq!("5", apns_priority); @@ -382,7 +390,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ ); let client = Client::new(default_connector(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority").unwrap(); assert_eq!("10", apns_priority); @@ -395,7 +403,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); let client = Client::new(default_connector(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_id = request.headers().get("apns-id"); assert_eq!(None, apns_id); @@ -414,7 +422,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ ); let client = Client::new(default_connector(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_id = request.headers().get("apns-id").unwrap(); assert_eq!("a-test-apns-id", apns_id); @@ -427,7 +435,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); let client = Client::new(default_connector(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_expiration = request.headers().get("apns-expiration"); assert_eq!(None, apns_expiration); @@ -446,7 +454,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ ); let client = Client::new(default_connector(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_expiration = request.headers().get("apns-expiration").unwrap(); assert_eq!("420", apns_expiration); @@ -459,7 +467,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); let client = Client::new(default_connector(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_collapse_id = request.headers().get("apns-collapse-id"); assert_eq!(None, apns_collapse_id); @@ -478,7 +486,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ ); let client = Client::new(default_connector(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_collapse_id = request.headers().get("apns-collapse-id").unwrap(); assert_eq!("a_collapse_id", apns_collapse_id); @@ -491,7 +499,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); let client = Client::new(default_connector(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_topic = request.headers().get("apns-topic"); assert_eq!(None, apns_topic); @@ -510,7 +518,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ ); let client = Client::new(default_connector(), None, Endpoint::Production); - let request = client.build_request(payload); + let request = client.build_request(payload).unwrap(); let apns_topic = request.headers().get("apns-topic").unwrap(); assert_eq!("a_topic", apns_topic); @@ -521,7 +529,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); let client = Client::new(default_connector(), None, Endpoint::Production); - let request = client.build_request(payload.clone()); + let request = client.build_request(payload.clone()).unwrap(); let body = request.into_body().collect().await.unwrap().to_bytes(); let body_str = String::from_utf8(body.to_vec()).unwrap(); diff --git a/src/error.rs b/src/error.rs index 4dfbe12..b818f74 100644 --- a/src/error.rs +++ b/src/error.rs @@ -44,6 +44,10 @@ pub enum Error { #[error("Error building TLS config: {0}")] Tls(#[from] rustls::Error), + /// Error while creating the HTTP request + #[error("Failed to construct HTTP request: {0}")] + BuildRequestError(#[source] http::Error), + /// Unexpected private key (only EC keys are supported). #[cfg(all(not(feature = "openssl"), feature = "ring"))] #[error("Unexpected private key: {0}")] diff --git a/src/lib.rs b/src/lib.rs index cc316ad..a59beb5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -106,6 +106,8 @@ //! } //! # } //! ``` +#![warn(clippy::unwrap_used)] + #[cfg(not(any(feature = "openssl", feature = "ring")))] compile_error!("either feature \"openssl\" or feature \"ring\" has to be enabled"); diff --git a/src/signer.rs b/src/signer.rs index c6dfd90..2d411c6 100644 --- a/src/signer.rs +++ b/src/signer.rs @@ -1,10 +1,8 @@ use crate::error::Error; +use parking_lot::RwLock; use std::io::Read; use std::sync::Arc; -use std::{ - sync::RwLock, - time::{Duration, SystemTime, UNIX_EPOCH}, -}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; use base64::prelude::*; #[cfg(feature = "openssl")] @@ -138,7 +136,7 @@ impl Signer { self.renew()?; } - let signature = self.signature.read().unwrap(); + let signature = self.signature.read(); #[cfg(feature = "tracing")] { @@ -191,7 +189,7 @@ impl Signer { ); } - let mut signature = self.signature.write().unwrap(); + let mut signature = self.signature.write(); *signature = Signature { key: Self::create_signature(&self.secret, &self.key_id, &self.team_id, issued_at)?, @@ -202,7 +200,7 @@ impl Signer { } fn is_expired(&self) -> bool { - let sig = self.signature.read().unwrap(); + let sig = self.signature.read(); let expiry = get_time() - sig.issued_at; expiry >= self.expire_after_s.as_secs() as i64 }