Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: request, response and response body hooks #2432

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use std::task::{Context, Poll};
use tokio::time::Sleep;

use super::decoder::Accepts;
use super::hooks::{RequestHook, ResponseBodyHook, ResponseHook};
use super::request::{Request, RequestBuilder};
use super::response::Response;
use super::Body;
Expand Down Expand Up @@ -167,6 +168,9 @@ struct Config {
quic_send_window: Option<u64>,
dns_overrides: HashMap<String, Vec<SocketAddr>>,
dns_resolver: Option<Arc<dyn Resolve>>,
request_hook: Option<Arc<dyn RequestHook>>,
response_hook: Option<Arc<dyn ResponseHook>>,
response_body_hook: Option<Arc<dyn ResponseBodyHook>>,
}

impl Default for ClientBuilder {
Expand Down Expand Up @@ -265,6 +269,9 @@ impl ClientBuilder {
#[cfg(feature = "http3")]
quic_send_window: None,
dns_resolver: None,
request_hook: None,
response_hook: None,
response_body_hook: None,
},
}
}
Expand Down Expand Up @@ -777,6 +784,9 @@ impl ClientBuilder {
proxies,
proxies_maybe_http_auth,
https_only: config.https_only,
request_hook: config.request_hook,
response_hook: config.response_hook,
response_body_hook: config.response_body_hook,
}),
})
}
Expand Down Expand Up @@ -1889,6 +1899,24 @@ impl ClientBuilder {
self.config.quic_send_window = Some(value);
self
}

/// Set request hook
pub fn request_hook(mut self, hook: Arc<dyn RequestHook>) -> ClientBuilder {
self.config.request_hook = Some(hook);
self
}

/// Set response hook
pub fn response_hook(mut self, hook: Arc<dyn ResponseHook>) -> ClientBuilder {
self.config.response_hook = Some(hook);
self
}

/// Set response body hook
pub fn response_body_hook(mut self, hook: Arc<dyn ResponseBodyHook>) -> ClientBuilder {
self.config.response_body_hook = Some(hook);
self
}
}

type HyperClient = hyper_util::client::legacy::Client<Connector, super::Body>;
Expand Down Expand Up @@ -2007,6 +2035,11 @@ impl Client {
}

pub(super) fn execute_request(&self, req: Request) -> Pending {
let req = if let Some(req_hook) = self.inner.request_hook.as_ref() {
req_hook.intercept(req)
} else {
req
};
let (method, url, mut headers, body, timeout, version) = req.pieces();
if url.scheme() != "http" && url.scheme() != "https" {
return Pending::new_err(error::url_bad_scheme(url));
Expand Down Expand Up @@ -2105,6 +2138,8 @@ impl Client {
total_timeout,
read_timeout_fut,
read_timeout: self.inner.read_timeout,
response_hook: self.inner.response_hook.clone(),
response_body_hook: self.inner.response_body_hook.clone(),
}),
}
}
Expand Down Expand Up @@ -2314,6 +2349,9 @@ struct ClientRef {
proxies: Arc<Vec<Proxy>>,
proxies_maybe_http_auth: bool,
https_only: bool,
request_hook: Option<Arc<dyn RequestHook>>,
response_hook: Option<Arc<dyn ResponseHook>>,
response_body_hook: Option<Arc<dyn ResponseBodyHook>>,
}

impl ClientRef {
Expand Down Expand Up @@ -2379,6 +2417,9 @@ pin_project! {

client: Arc<ClientRef>,

response_hook: Option<Arc<dyn ResponseHook>>,
response_body_hook: Option<Arc<dyn ResponseBodyHook>>,

#[pin]
in_flight: ResponseFuture,
#[pin]
Expand Down Expand Up @@ -2745,7 +2786,13 @@ impl Future for PendingRequest {
self.client.accepts,
self.total_timeout.take(),
self.read_timeout,
self.response_body_hook.clone(),
);
let res = if let Some(res_hook) = self.response_hook.as_ref() {
res_hook.intercept(res)
} else {
res
};
return Poll::Ready(Ok(res));
}
}
Expand Down
23 changes: 23 additions & 0 deletions src/async_impl/hooks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//! Hooks to intercept the request, response and response body

use bytes::Bytes;

use super::{Request, Response};

/// Hook that gets called before sending the request, right after it's constructed
pub trait RequestHook: Send + Sync {
/// Intercept the request and return it with or without changes
fn intercept(&self, req: Request) -> Request;
}

/// Hook that gets called once the request is completed and headers have been received
pub trait ResponseHook: Send + Sync {
/// Intercept the response and return it with or without changes
fn intercept(&self, res: Response) -> Response;
}

/// Hook that gets called once the request is completed and the full body has been received
pub trait ResponseBodyHook: Send + Sync {
/// Intercept the response body and return it with or without changes
fn intercept(&self, body: Bytes) -> Bytes;
}
2 changes: 2 additions & 0 deletions src/async_impl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ pub(crate) use self::decoder::Decoder;
pub mod body;
pub mod client;
pub mod decoder;
pub mod hooks;

pub mod h3_client;
#[cfg(feature = "multipart")]
pub mod multipart;
Expand Down
20 changes: 18 additions & 2 deletions src/async_impl/response.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::fmt;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;

use bytes::Bytes;
Expand All @@ -16,6 +17,7 @@ use url::Url;

use super::body::Body;
use super::decoder::{Accepts, Decoder};
use super::hooks::ResponseBodyHook;
use crate::async_impl::body::ResponseBody;
#[cfg(feature = "cookies")]
use crate::cookie;
Expand All @@ -31,6 +33,7 @@ pub struct Response {
// Boxed to save space (11 words to 1 word), and it's not accessed
// frequently internally.
url: Box<Url>,
response_body_hook: Option<Arc<dyn ResponseBodyHook>>,
}

impl Response {
Expand All @@ -40,6 +43,7 @@ impl Response {
accepts: Accepts,
total_timeout: Option<Pin<Box<Sleep>>>,
read_timeout: Option<Duration>,
response_body_hook: Option<Arc<dyn ResponseBodyHook>>,
) -> Response {
let (mut parts, body) = res.into_parts();
let decoder = Decoder::detect(
Expand All @@ -52,6 +56,7 @@ impl Response {
Response {
res,
url: Box::new(url),
response_body_hook,
}
}

Expand Down Expand Up @@ -218,6 +223,7 @@ impl Response {
let full = self.bytes().await?;

let (text, _, _) = encoding.decode(&full);

Ok(text.into_owned())
}

Expand Down Expand Up @@ -288,9 +294,18 @@ impl Response {
pub async fn bytes(self) -> crate::Result<Bytes> {
use http_body_util::BodyExt;

BodyExt::collect(self.res.into_body())
let bytes = BodyExt::collect(self.res.into_body())
.await
.map(|buf| buf.to_bytes())
.map(|buf| buf.to_bytes())?;

let res_body_hook = self.response_body_hook.clone();
let bytes = if let Some(res_body_hook) = res_body_hook.as_ref() {
res_body_hook.intercept(bytes)
} else {
bytes
};

Ok(bytes)
}

/// Stream a chunk of the response body.
Expand Down Expand Up @@ -468,6 +483,7 @@ impl<T: Into<Body>> From<http::Response<T>> for Response {
Response {
res,
url: Box::new(url),
response_body_hook: None,
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ if_hyper! {
#[cfg(feature = "multipart")]
pub use self::async_impl::multipart;

pub use self::async_impl::hooks;

mod async_impl;
#[cfg(feature = "blocking")]
Expand Down
69 changes: 69 additions & 0 deletions tests/hooks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#![cfg(not(target_arch = "wasm32"))]
#![cfg(not(feature = "rustls-tls-manual-roots-no-provider"))]
mod support;

use std::sync::{Arc, Mutex};

use bytes::Bytes;
use support::server;

use reqwest::{
hooks::{RequestHook, ResponseBodyHook, ResponseHook},
Client,
};

#[derive(Default)]
struct MyHook {
pub req_visited: Arc<Mutex<bool>>,
pub res_visited: Arc<Mutex<bool>>,
pub body_visited: Arc<Mutex<bool>>,
}

impl RequestHook for MyHook {
fn intercept(&self, req: reqwest::Request) -> reqwest::Request {
*self.req_visited.lock().unwrap() = true;
req
}
}

impl ResponseHook for MyHook {
fn intercept(&self, res: reqwest::Response) -> reqwest::Response {
*self.res_visited.lock().unwrap() = true;
res
}
}

impl ResponseBodyHook for MyHook {
fn intercept(&self, body: Bytes) -> Bytes {
*self.body_visited.lock().unwrap() = true;
body
}
}

#[tokio::test]
async fn full_hook_chain() {
let _ = env_logger::try_init();

let server = server::http(move |_req| async { http::Response::new("Hello".into()) });

let hook = Arc::new(MyHook::default());

let client = Client::builder()
.request_hook(hook.clone())
.response_hook(hook.clone())
.response_body_hook(hook.clone())
.build()
.unwrap();

let res = client
.get(&format!("http://{}/text", server.addr()))
.send()
.await
.expect("Failed to get");
assert_eq!(res.content_length(), Some(5));
let text = res.text().await.expect("Failed to get text");
assert_eq!("Hello", text);
assert!(*hook.req_visited.lock().unwrap());
assert!(*hook.res_visited.lock().unwrap());
assert!(*hook.body_visited.lock().unwrap());
}