Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Per-request redirect policy override #2440

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1061,7 +1061,7 @@ impl ClientBuilder {
/// Set a `RedirectPolicy` for this client.
///
/// Default will follow redirects up to a maximum of 10.
pub fn redirect(mut self, policy: redirect::Policy) -> ClientBuilder {
pub fn redirect_policy(mut self, policy: redirect::Policy) -> ClientBuilder {
self.config.redirect_policy = policy;
self
}
Expand Down Expand Up @@ -2007,7 +2007,7 @@ impl Client {
}

pub(super) fn execute_request(&self, req: Request) -> Pending {
let (method, url, mut headers, body, timeout, version) = req.pieces();
let (method, url, mut headers, body, timeout, redirect_policy, version) = req.pieces();
if url.scheme() != "http" && url.scheme() != "https" {
return Pending::new_err(error::url_bad_scheme(url));
}
Expand Down Expand Up @@ -2098,6 +2098,7 @@ impl Client {
urls: Vec::new(),

retry_count: 0,
redirect_policy,

client: self.inner.clone(),

Expand Down Expand Up @@ -2376,6 +2377,7 @@ pin_project! {
urls: Vec<Url>,

retry_count: usize,
redirect_policy: Option<redirect::Policy>,

client: Arc<ClientRef>,

Expand Down Expand Up @@ -2660,9 +2662,11 @@ impl Future for PendingRequest {
}
let url = self.url.clone();
self.as_mut().urls().push(url);
// This request's redirect policy overrides the client's redirect policy
let action = self
.client
.redirect_policy
.as_ref()
.unwrap_or(&self.client.redirect_policy)
.check(res.status(), &loc, &self.urls);

match action {
Expand Down
27 changes: 26 additions & 1 deletion src/async_impl/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use super::response::Response;
#[cfg(feature = "multipart")]
use crate::header::CONTENT_LENGTH;
use crate::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE};
use crate::{Method, Url};
use crate::{redirect, Method, Url};
use http::{request::Parts, Request as HttpRequest, Version};

/// A request which can be executed with `Client::execute()`.
Expand All @@ -25,6 +25,7 @@ pub struct Request {
headers: HeaderMap,
body: Option<Body>,
timeout: Option<Duration>,
redirect_policy: Option<redirect::Policy>,
version: Version,
}

Expand All @@ -47,6 +48,7 @@ impl Request {
headers: HeaderMap::new(),
body: None,
timeout: None,
redirect_policy: None,
version: Version::default(),
}
}
Expand Down Expand Up @@ -111,6 +113,18 @@ impl Request {
&mut self.timeout
}

/// Get a this request's redirect policy.
#[inline]
pub fn redirect_policy(&self) -> Option<&redirect::Policy> {
self.redirect_policy.as_ref()
}

/// Get a mutable reference to the redirect policy.
#[inline]
pub fn redirect_policy_mut(&mut self) -> &mut Option<redirect::Policy> {
&mut self.redirect_policy
}

/// Get the http version.
#[inline]
pub fn version(&self) -> Version {
Expand Down Expand Up @@ -147,6 +161,7 @@ impl Request {
HeaderMap,
Option<Body>,
Option<Duration>,
Option<redirect::Policy>,
Version,
) {
(
Expand All @@ -155,6 +170,7 @@ impl Request {
self.headers,
self.body,
self.timeout,
self.redirect_policy,
self.version,
)
}
Expand Down Expand Up @@ -290,6 +306,14 @@ impl RequestBuilder {
self
}

/// Overrides the client's redirect policy for this request
pub fn redirect_policy(mut self, policy: redirect::Policy) -> RequestBuilder {
if let Ok(ref mut req) = self.request {
*req.redirect_policy_mut() = Some(policy)
}
self
}

/// Sends a multipart/form-data body.
///
/// ```
Expand Down Expand Up @@ -620,6 +644,7 @@ where
headers,
body: Some(body.into()),
timeout: None,
redirect_policy: None,
version,
})
}
Expand Down
63 changes: 62 additions & 1 deletion tests/redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ async fn test_redirect_policy_can_stop_redirects_without_an_error() {
let url = format!("http://{}/no-redirect", server.addr());

let res = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.redirect_policy(reqwest::redirect::Policy::none())
.build()
.unwrap()
.get(&url)
Expand Down Expand Up @@ -376,3 +376,64 @@ async fn test_redirect_https_only_enforced_gh1312() {
let err = res.unwrap_err();
assert!(err.is_redirect());
}

#[tokio::test]
async fn test_request_redirect() {
let code = 301u16;

let redirect = server::http(move |req| async move {
if req.method() == "POST" {
assert_eq!(req.uri(), &*format!("/{}", code));
http::Response::builder()
.status(code)
.header("location", "/dst")
.header("server", "test-redirect")
.body(Default::default())
.unwrap()
} else {
assert_eq!(req.method(), "GET");

http::Response::builder()
.header("server", "test-dst")
.body(Default::default())
.unwrap()
}
});

let url = format!("http://{}/{}", redirect.addr(), code);
let dst = format!("http://{}/{}", redirect.addr(), "dst");

let default_redirect_client = reqwest::Client::new();
let res = default_redirect_client
.request(reqwest::Method::POST, &url)
.redirect(reqwest::redirect::Policy::none())
.send()
.await
.unwrap();

assert_eq!(res.url().as_str(), url);
assert_eq!(res.status(), reqwest::StatusCode::MOVED_PERMANENTLY);
assert_eq!(
res.headers().get(reqwest::header::SERVER).unwrap(),
&"test-redirect"
);

let no_redirect_client = reqwest::Client::builder()
.redirect_policy(reqwest::redirect::Policy::none())
.build()
.unwrap();

let res = no_redirect_client
.request(reqwest::Method::POST, &url)
.redirect(reqwest::redirect::Policy::limited(2))
.send()
.await
.unwrap();

assert_eq!(res.url().as_str(), dst);
assert_eq!(res.status(), reqwest::StatusCode::OK);
assert_eq!(
res.headers().get(reqwest::header::SERVER).unwrap(),
&"test-dst"
);
}