Skip to content

Commit

Permalink
feat: request, response and response body hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
Gabriele Musco committed Sep 26, 2024
1 parent d85f44b commit 1f5e562
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 1 deletion.
47 changes: 47 additions & 0 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use std::task::{Context, Poll};
use tokio::time::Sleep;

use super::decoder::Accepts;
use super::hooks::{RequestHook, ResponseBodyHook, ResponseHook};
use super::request::{Request, RequestBuilder};
use super::response::Response;
use super::Body;
Expand Down Expand Up @@ -167,6 +168,9 @@ struct Config {
quic_send_window: Option<u64>,
dns_overrides: HashMap<String, Vec<SocketAddr>>,
dns_resolver: Option<Arc<dyn Resolve>>,
request_hook: Option<Arc<dyn RequestHook>>,
response_hook: Option<Arc<dyn ResponseHook>>,
response_body_hook: Option<Arc<dyn ResponseBodyHook>>,
}

impl Default for ClientBuilder {
Expand Down Expand Up @@ -265,6 +269,9 @@ impl ClientBuilder {
#[cfg(feature = "http3")]
quic_send_window: None,
dns_resolver: None,
request_hook: None,
response_hook: None,
response_body_hook: None,
},
}
}
Expand Down Expand Up @@ -777,6 +784,9 @@ impl ClientBuilder {
proxies,
proxies_maybe_http_auth,
https_only: config.https_only,
request_hook: config.request_hook,
response_hook: config.response_hook,
response_body_hook: config.response_body_hook,
}),
})
}
Expand Down Expand Up @@ -1889,6 +1899,24 @@ impl ClientBuilder {
self.config.quic_send_window = Some(value);
self
}

/// Set request hook
pub fn request_hook(mut self, hook: Arc<dyn RequestHook>) -> ClientBuilder {
self.config.request_hook = Some(hook);
self
}

/// Set response hook
pub fn response_hook(mut self, hook: Arc<dyn ResponseHook>) -> ClientBuilder {
self.config.response_hook = Some(hook);
self
}

/// Set response body hook
pub fn response_body_hook(mut self, hook: Arc<dyn ResponseBodyHook>) -> ClientBuilder {
self.config.response_body_hook = Some(hook);
self
}
}

type HyperClient = hyper_util::client::legacy::Client<Connector, super::Body>;
Expand Down Expand Up @@ -2007,6 +2035,11 @@ impl Client {
}

pub(super) fn execute_request(&self, req: Request) -> Pending {
let req = if let Some(req_hook) = self.inner.request_hook.as_ref() {
req_hook.intercept(req)
} else {
req
};
let (method, url, mut headers, body, timeout, version) = req.pieces();
if url.scheme() != "http" && url.scheme() != "https" {
return Pending::new_err(error::url_bad_scheme(url));
Expand Down Expand Up @@ -2105,6 +2138,8 @@ impl Client {
total_timeout,
read_timeout_fut,
read_timeout: self.inner.read_timeout,
response_hook: self.inner.response_hook.clone(),
response_body_hook: self.inner.response_body_hook.clone(),
}),
}
}
Expand Down Expand Up @@ -2314,6 +2349,9 @@ struct ClientRef {
proxies: Arc<Vec<Proxy>>,
proxies_maybe_http_auth: bool,
https_only: bool,
request_hook: Option<Arc<dyn RequestHook>>,
response_hook: Option<Arc<dyn ResponseHook>>,
response_body_hook: Option<Arc<dyn ResponseBodyHook>>,
}

impl ClientRef {
Expand Down Expand Up @@ -2379,6 +2417,9 @@ pin_project! {

client: Arc<ClientRef>,

response_hook: Option<Arc<dyn ResponseHook>>,
response_body_hook: Option<Arc<dyn ResponseBodyHook>>,

#[pin]
in_flight: ResponseFuture,
#[pin]
Expand Down Expand Up @@ -2745,7 +2786,13 @@ impl Future for PendingRequest {
self.client.accepts,
self.total_timeout.take(),
self.read_timeout,
self.response_body_hook.clone(),
);
let res = if let Some(res_hook) = self.response_hook.as_ref() {
res_hook.intercept(res)
} else {
res
};
return Poll::Ready(Ok(res));
}
}
Expand Down
21 changes: 21 additions & 0 deletions src/async_impl/hooks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//! Hooks to intercept the request, response and response body

use super::{Request, Response};

/// Hook that gets called before sending the request, right after it's constructed
pub trait RequestHook: Send + Sync {
/// Intercept the request and return it with or without changes
fn intercept(&self, req: Request) -> Request;
}

/// Hook that gets called once the request is completed and headers have been received
pub trait ResponseHook: Send + Sync {
/// Intercept the response and return it with or without changes
fn intercept(&self, res: Response) -> Response;
}

/// Hook that gets called once the request is completed and the full body has been received
pub trait ResponseBodyHook: Send + Sync {
/// Intercept the response body and return it with or without changes
fn intercept(&self, body: String) -> String;
}
2 changes: 2 additions & 0 deletions src/async_impl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ pub(crate) use self::decoder::Decoder;
pub mod body;
pub mod client;
pub mod decoder;
pub mod hooks;

pub mod h3_client;
#[cfg(feature = "multipart")]
pub mod multipart;
Expand Down
14 changes: 13 additions & 1 deletion src/async_impl/response.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::fmt;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;

use bytes::Bytes;
Expand All @@ -16,6 +17,7 @@ use url::Url;

use super::body::Body;
use super::decoder::{Accepts, Decoder};
use super::hooks::ResponseBodyHook;
use crate::async_impl::body::ResponseBody;
#[cfg(feature = "cookies")]
use crate::cookie;
Expand All @@ -31,6 +33,7 @@ pub struct Response {
// Boxed to save space (11 words to 1 word), and it's not accessed
// frequently internally.
url: Box<Url>,
response_body_hook: Option<Arc<dyn ResponseBodyHook>>,
}

impl Response {
Expand All @@ -40,6 +43,7 @@ impl Response {
accepts: Accepts,
total_timeout: Option<Pin<Box<Sleep>>>,
read_timeout: Option<Duration>,
response_body_hook: Option<Arc<dyn ResponseBodyHook>>,
) -> Response {
let (mut parts, body) = res.into_parts();
let decoder = Decoder::detect(
Expand All @@ -52,6 +56,7 @@ impl Response {
Response {
res,
url: Box::new(url),
response_body_hook,
}
}

Expand Down Expand Up @@ -215,10 +220,16 @@ impl Response {
.unwrap_or(default_encoding);
let encoding = Encoding::for_label(encoding_name.as_bytes()).unwrap_or(UTF_8);

let res_body_hook = self.response_body_hook.clone();
let full = self.bytes().await?;

let (text, _, _) = encoding.decode(&full);
Ok(text.into_owned())
let text = if let Some(res_body_hook) = res_body_hook.as_ref() {
res_body_hook.intercept(text.into_owned())
} else {
text.into_owned()
};
Ok(text)
}

/// Try to deserialize the response body as JSON.
Expand Down Expand Up @@ -468,6 +479,7 @@ impl<T: Into<Body>> From<http::Response<T>> for Response {
Response {
res,
url: Box::new(url),
response_body_hook: None,
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ if_hyper! {
#[cfg(feature = "multipart")]
pub use self::async_impl::multipart;

pub use self::async_impl::hooks;

mod async_impl;
#[cfg(feature = "blocking")]
Expand Down
68 changes: 68 additions & 0 deletions tests/hooks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#![cfg(not(target_arch = "wasm32"))]
#![cfg(not(feature = "rustls-tls-manual-roots-no-provider"))]
mod support;

use std::sync::{Arc, Mutex};

use support::server;

use reqwest::{
hooks::{RequestHook, ResponseBodyHook, ResponseHook},
Client,
};

#[derive(Default)]
struct MyHook {
pub req_visited: Arc<Mutex<bool>>,
pub res_visited: Arc<Mutex<bool>>,
pub body_visited: Arc<Mutex<bool>>,
}

impl RequestHook for MyHook {
fn intercept(&self, req: reqwest::Request) -> reqwest::Request {
*self.req_visited.lock().unwrap() = true;
req
}
}

impl ResponseHook for MyHook {
fn intercept(&self, res: reqwest::Response) -> reqwest::Response {
*self.res_visited.lock().unwrap() = true;
res
}
}

impl ResponseBodyHook for MyHook {
fn intercept(&self, body: String) -> String {
*self.body_visited.lock().unwrap() = true;
body
}
}

#[tokio::test]
async fn full_hook_chain() {
let _ = env_logger::try_init();

let server = server::http(move |_req| async { http::Response::new("Hello".into()) });

let hook = Arc::new(MyHook::default());

let client = Client::builder()
.request_hook(hook.clone())
.response_hook(hook.clone())
.response_body_hook(hook.clone())
.build()
.unwrap();

let res = client
.get(&format!("http://{}/text", server.addr()))
.send()
.await
.expect("Failed to get");
assert_eq!(res.content_length(), Some(5));
let text = res.text().await.expect("Failed to get text");
assert_eq!("Hello", text);
assert!(*hook.req_visited.lock().unwrap());
assert!(*hook.res_visited.lock().unwrap());
assert!(*hook.body_visited.lock().unwrap());
}

0 comments on commit 1f5e562

Please sign in to comment.