From 0648e6d44f177f6468ba2d41b01c085cac16c420 Mon Sep 17 00:00:00 2001 From: Gabriele Musco Date: Wed, 25 Sep 2024 15:32:37 +0200 Subject: [PATCH] feat: request, response and response body hooks --- src/async_impl/client.rs | 47 ++++++++++++++++++++++++++ src/async_impl/hooks.rs | 23 +++++++++++++ src/async_impl/mod.rs | 2 ++ src/async_impl/response.rs | 20 +++++++++-- src/lib.rs | 1 + tests/hooks.rs | 69 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 160 insertions(+), 2 deletions(-) create mode 100644 src/async_impl/hooks.rs create mode 100644 tests/hooks.rs diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 095adf4d8..6dba9aefc 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -23,6 +23,7 @@ use std::task::{Context, Poll}; use tokio::time::Sleep; use super::decoder::Accepts; +use super::hooks::{RequestHook, ResponseBodyHook, ResponseHook}; use super::request::{Request, RequestBuilder}; use super::response::Response; use super::Body; @@ -167,6 +168,9 @@ struct Config { quic_send_window: Option, dns_overrides: HashMap>, dns_resolver: Option>, + request_hook: Option>, + response_hook: Option>, + response_body_hook: Option>, } impl Default for ClientBuilder { @@ -265,6 +269,9 @@ impl ClientBuilder { #[cfg(feature = "http3")] quic_send_window: None, dns_resolver: None, + request_hook: None, + response_hook: None, + response_body_hook: None, }, } } @@ -777,6 +784,9 @@ impl ClientBuilder { proxies, proxies_maybe_http_auth, https_only: config.https_only, + request_hook: config.request_hook, + response_hook: config.response_hook, + response_body_hook: config.response_body_hook, }), }) } @@ -1889,6 +1899,24 @@ impl ClientBuilder { self.config.quic_send_window = Some(value); self } + + /// Set request hook + pub fn request_hook(mut self, hook: Arc) -> ClientBuilder { + self.config.request_hook = Some(hook); + self + } + + /// Set response hook + pub fn response_hook(mut self, hook: Arc) -> ClientBuilder { + self.config.response_hook = Some(hook); + self + } + + /// Set response body hook + pub fn response_body_hook(mut self, hook: Arc) -> ClientBuilder { + self.config.response_body_hook = Some(hook); + self + } } type HyperClient = hyper_util::client::legacy::Client; @@ -2007,6 +2035,11 @@ impl Client { } pub(super) fn execute_request(&self, req: Request) -> Pending { + let req = if let Some(req_hook) = self.inner.request_hook.as_ref() { + req_hook.intercept(req) + } else { + req + }; let (method, url, mut headers, body, timeout, version) = req.pieces(); if url.scheme() != "http" && url.scheme() != "https" { return Pending::new_err(error::url_bad_scheme(url)); @@ -2105,6 +2138,8 @@ impl Client { total_timeout, read_timeout_fut, read_timeout: self.inner.read_timeout, + response_hook: self.inner.response_hook.clone(), + response_body_hook: self.inner.response_body_hook.clone(), }), } } @@ -2314,6 +2349,9 @@ struct ClientRef { proxies: Arc>, proxies_maybe_http_auth: bool, https_only: bool, + request_hook: Option>, + response_hook: Option>, + response_body_hook: Option>, } impl ClientRef { @@ -2379,6 +2417,9 @@ pin_project! { client: Arc, + response_hook: Option>, + response_body_hook: Option>, + #[pin] in_flight: ResponseFuture, #[pin] @@ -2745,7 +2786,13 @@ impl Future for PendingRequest { self.client.accepts, self.total_timeout.take(), self.read_timeout, + self.response_body_hook.clone(), ); + let res = if let Some(res_hook) = self.response_hook.as_ref() { + res_hook.intercept(res) + } else { + res + }; return Poll::Ready(Ok(res)); } } diff --git a/src/async_impl/hooks.rs b/src/async_impl/hooks.rs new file mode 100644 index 000000000..5cbd108df --- /dev/null +++ b/src/async_impl/hooks.rs @@ -0,0 +1,23 @@ +//! Hooks to intercept the request, response and response body + +use bytes::Bytes; + +use super::{Request, Response}; + +/// Hook that gets called before sending the request, right after it's constructed +pub trait RequestHook: Send + Sync { + /// Intercept the request and return it with or without changes + fn intercept(&self, req: Request) -> Request; +} + +/// Hook that gets called once the request is completed and headers have been received +pub trait ResponseHook: Send + Sync { + /// Intercept the response and return it with or without changes + fn intercept(&self, res: Response) -> Response; +} + +/// Hook that gets called once the request is completed and the full body has been received +pub trait ResponseBodyHook: Send + Sync { + /// Intercept the response body and return it with or without changes + fn intercept(&self, body: Bytes) -> Bytes; +} diff --git a/src/async_impl/mod.rs b/src/async_impl/mod.rs index 5d99ef027..5553ba0a3 100644 --- a/src/async_impl/mod.rs +++ b/src/async_impl/mod.rs @@ -10,6 +10,8 @@ pub(crate) use self::decoder::Decoder; pub mod body; pub mod client; pub mod decoder; +pub mod hooks; + pub mod h3_client; #[cfg(feature = "multipart")] pub mod multipart; diff --git a/src/async_impl/response.rs b/src/async_impl/response.rs index 23e30d3ed..83f07a6a2 100644 --- a/src/async_impl/response.rs +++ b/src/async_impl/response.rs @@ -1,6 +1,7 @@ use std::fmt; use std::net::SocketAddr; use std::pin::Pin; +use std::sync::Arc; use std::time::Duration; use bytes::Bytes; @@ -16,6 +17,7 @@ use url::Url; use super::body::Body; use super::decoder::{Accepts, Decoder}; +use super::hooks::ResponseBodyHook; use crate::async_impl::body::ResponseBody; #[cfg(feature = "cookies")] use crate::cookie; @@ -31,6 +33,7 @@ pub struct Response { // Boxed to save space (11 words to 1 word), and it's not accessed // frequently internally. url: Box, + response_body_hook: Option>, } impl Response { @@ -40,6 +43,7 @@ impl Response { accepts: Accepts, total_timeout: Option>>, read_timeout: Option, + response_body_hook: Option>, ) -> Response { let (mut parts, body) = res.into_parts(); let decoder = Decoder::detect( @@ -52,6 +56,7 @@ impl Response { Response { res, url: Box::new(url), + response_body_hook, } } @@ -218,6 +223,7 @@ impl Response { let full = self.bytes().await?; let (text, _, _) = encoding.decode(&full); + Ok(text.into_owned()) } @@ -288,9 +294,18 @@ impl Response { pub async fn bytes(self) -> crate::Result { use http_body_util::BodyExt; - BodyExt::collect(self.res.into_body()) + let bytes = BodyExt::collect(self.res.into_body()) .await - .map(|buf| buf.to_bytes()) + .map(|buf| buf.to_bytes())?; + + let res_body_hook = self.response_body_hook.clone(); + let bytes = if let Some(res_body_hook) = res_body_hook.as_ref() { + res_body_hook.intercept(bytes) + } else { + bytes + }; + + Ok(bytes) } /// Stream a chunk of the response body. @@ -468,6 +483,7 @@ impl> From> for Response { Response { res, url: Box::new(url), + response_body_hook: None, } } } diff --git a/src/lib.rs b/src/lib.rs index cf3d39d0f..4d328be5d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -352,6 +352,7 @@ if_hyper! { #[cfg(feature = "multipart")] pub use self::async_impl::multipart; + pub use self::async_impl::hooks; mod async_impl; #[cfg(feature = "blocking")] diff --git a/tests/hooks.rs b/tests/hooks.rs new file mode 100644 index 000000000..a8732bc08 --- /dev/null +++ b/tests/hooks.rs @@ -0,0 +1,69 @@ +#![cfg(not(target_arch = "wasm32"))] +#![cfg(not(feature = "rustls-tls-manual-roots-no-provider"))] +mod support; + +use std::sync::{Arc, Mutex}; + +use bytes::Bytes; +use support::server; + +use reqwest::{ + hooks::{RequestHook, ResponseBodyHook, ResponseHook}, + Client, +}; + +#[derive(Default)] +struct MyHook { + pub req_visited: Arc>, + pub res_visited: Arc>, + pub body_visited: Arc>, +} + +impl RequestHook for MyHook { + fn intercept(&self, req: reqwest::Request) -> reqwest::Request { + *self.req_visited.lock().unwrap() = true; + req + } +} + +impl ResponseHook for MyHook { + fn intercept(&self, res: reqwest::Response) -> reqwest::Response { + *self.res_visited.lock().unwrap() = true; + res + } +} + +impl ResponseBodyHook for MyHook { + fn intercept(&self, body: Bytes) -> Bytes { + *self.body_visited.lock().unwrap() = true; + body + } +} + +#[tokio::test] +async fn full_hook_chain() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::new("Hello".into()) }); + + let hook = Arc::new(MyHook::default()); + + let client = Client::builder() + .request_hook(hook.clone()) + .response_hook(hook.clone()) + .response_body_hook(hook.clone()) + .build() + .unwrap(); + + let res = client + .get(&format!("http://{}/text", server.addr())) + .send() + .await + .expect("Failed to get"); + assert_eq!(res.content_length(), Some(5)); + let text = res.text().await.expect("Failed to get text"); + assert_eq!("Hello", text); + assert!(*hook.req_visited.lock().unwrap()); + assert!(*hook.res_visited.lock().unwrap()); + assert!(*hook.body_visited.lock().unwrap()); +}