Skip to content

Commit b5490f7

Browse files
committed
Add request modifier to direct twirp calls
1 parent 2ab449f commit b5490f7

3 files changed

Lines changed: 161 additions & 33 deletions

File tree

crates/twirp/CHANGELOG.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
12+
- `ClientBuilder::with_extension` attaches a typed extension that is inserted onto every
13+
outbound request before middleware and handlers run. Per-call extensions set on the
14+
inbound `http::Request` take precedence.
15+
16+
### Fixed
17+
18+
- `Client::request` now propagates per-call request extensions from the inbound
19+
`http::Request` through to the `reqwest::Request`, and `details::decode_request`
20+
preserves them when decoding back into `http::Request`, so extensions are visible to
21+
`DirectHandler`-routed API trait methods.
22+
1023
## [0.11.0](https://github.com/github/twirp-rs/compare/twirp-v0.10.1...twirp-v0.11.0) - 2026-04-10
1124

1225
### Other

crates/twirp/src/client.rs

Lines changed: 133 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@ use url::Host;
1212
use url::Url;
1313

1414
use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF};
15-
use crate::{serialize_proto_message, Result, TwirpErrorResponse};
15+
use crate::{malformed, serialize_proto_message, Result, TwirpErrorResponse};
1616

1717
/// Builder to easily create twirp clients.
1818
pub struct ClientBuilder {
1919
base_url: Url,
2020
http_client: Option<reqwest::Client>,
2121
handlers: Option<RequestHandlers>,
2222
middleware: Vec<Box<dyn Middleware>>,
23+
extensions: http::Extensions,
2324
}
2425

2526
impl ClientBuilder {
@@ -30,6 +31,7 @@ impl ClientBuilder {
3031
http_client: None,
3132
middleware: vec![],
3233
handlers: None,
34+
extensions: http::Extensions::new(),
3335
}
3436
}
3537

@@ -44,6 +46,7 @@ impl ClientBuilder {
4446
http_client: None,
4547
middleware: vec![],
4648
handlers: Some(RequestHandlers::new()),
49+
extensions: http::Extensions::new(),
4750
}
4851
}
4952

@@ -100,13 +103,46 @@ impl ClientBuilder {
100103
self
101104
}
102105

106+
/// Attach a typed extension that will be inserted onto every outbound request before it
107+
/// reaches middleware and handlers. Per-call extensions set on the inbound `http::Request`
108+
/// take precedence and override these values. Calling this with a value of the same type
109+
/// replaces the previous value.
110+
pub fn with_extension<T>(mut self, value: T) -> Self
111+
where
112+
T: Clone + Send + Sync + 'static,
113+
{
114+
self.extensions.insert(value);
115+
self
116+
}
117+
103118
/// Creates a `twirp::Client`.
104119
///
105120
/// The underlying `reqwest::Client` holds a connection pool internally, so it is advised that
106121
/// you create one and **reuse** it.
107122
pub fn build(self) -> Client {
108123
let http_client = self.http_client.unwrap_or_default();
109-
Client::new(self.base_url, http_client, self.middleware, self.handlers)
124+
Client {
125+
http_client,
126+
inner: Arc::new(ClientRef {
127+
base_url: normalize_base_url(self.base_url),
128+
middlewares: self.middleware,
129+
handlers: self.handlers,
130+
extensions: self.extensions,
131+
}),
132+
host: None,
133+
}
134+
}
135+
}
136+
137+
fn normalize_base_url(base_url: Url) -> Url {
138+
if base_url.path().ends_with('/') {
139+
base_url
140+
} else {
141+
let mut base_url = base_url;
142+
let mut path = base_url.path().to_string();
143+
path.push('/');
144+
base_url.set_path(&path);
145+
base_url
110146
}
111147
}
112148

@@ -126,6 +162,7 @@ struct ClientRef {
126162
base_url: Url,
127163
middlewares: Vec<Box<dyn Middleware>>,
128164
handlers: Option<RequestHandlers>,
165+
extensions: http::Extensions,
129166
}
130167

131168
impl std::fmt::Debug for Client {
@@ -143,6 +180,7 @@ impl std::fmt::Debug for Client {
143180
.map(|x| x.len())
144181
.unwrap_or_default(),
145182
)
183+
.field("extensions", &self.inner.extensions.len())
146184
.finish()
147185
}
148186
}
@@ -158,21 +196,13 @@ impl Client {
158196
middlewares: Vec<Box<dyn Middleware>>,
159197
handlers: Option<RequestHandlers>,
160198
) -> Self {
161-
let base_url = if base_url.path().ends_with('/') {
162-
base_url
163-
} else {
164-
let mut base_url = base_url;
165-
let mut path = base_url.path().to_string();
166-
path.push('/');
167-
base_url.set_path(&path);
168-
base_url
169-
};
170199
Client {
171200
http_client,
172201
inner: Arc::new(ClientRef {
173-
base_url,
202+
base_url: normalize_base_url(base_url),
174203
middlewares,
175204
handlers,
205+
extensions: http::Extensions::new(),
176206
}),
177207
host: None,
178208
}
@@ -216,13 +246,25 @@ impl Client {
216246
url.set_host(Some(host))?
217247
};
218248
let (parts, body) = req.into_parts();
219-
let request = self
220-
.http_client
221-
.post(url)
222-
.headers(parts.headers)
223-
.header(CONTENT_TYPE, CONTENT_TYPE_PROTOBUF)
224-
.body(serialize_proto_message(body))
225-
.build()?;
249+
// Build as `http::Request<reqwest::Body>` so that extensions propagate through
250+
// the public `TryFrom<http::Request<T>> for reqwest::Request` impl (the inherent
251+
// `extensions_mut` on `reqwest::Request` is `pub(crate)`).
252+
let mut http_req = http::Request::builder()
253+
.method(http::Method::POST)
254+
.uri(url.to_string())
255+
.body(reqwest::Body::from(serialize_proto_message(body)))
256+
.map_err(|e| malformed(format!("failed to build the request: {e}")))?;
257+
*http_req.headers_mut() = parts.headers;
258+
http_req.headers_mut().insert(
259+
CONTENT_TYPE,
260+
HeaderValue::from_bytes(CONTENT_TYPE_PROTOBUF)
261+
.expect("CONTENT_TYPE_PROTOBUF is always a valid header value"),
262+
);
263+
// Apply per-client extensions first, then let per-call extensions override.
264+
let request_extensions = http_req.extensions_mut();
265+
request_extensions.extend(self.inner.extensions.clone());
266+
request_extensions.extend(parts.extensions);
267+
let request = reqwest::Request::try_from(http_req)?;
226268

227269
// Create and execute the middleware handlers
228270
let next = Next::new(
@@ -464,4 +506,76 @@ mod tests {
464506
assert_eq!(data.name, "hi");
465507
h.abort()
466508
}
509+
510+
struct RecordingHandler {
511+
observed_request_id: Arc<std::sync::Mutex<Option<RequestId>>>,
512+
}
513+
514+
#[async_trait]
515+
impl DirectHandler for RecordingHandler {
516+
fn service(&self) -> &str {
517+
"test.TestAPI"
518+
}
519+
520+
async fn handle(&self, _method: &str, req: Request) -> Result<Response> {
521+
let decoded: http::Request<PingRequest> = crate::details::decode_request(req).await?;
522+
*self.observed_request_id.lock().unwrap() =
523+
decoded.extensions().get::<RequestId>().cloned();
524+
let body = serialize_proto_message(PingResponse {
525+
name: "pong".to_string(),
526+
});
527+
let response = http::Response::builder()
528+
.status(200)
529+
.header(CONTENT_TYPE, CONTENT_TYPE_PROTOBUF)
530+
.body(body)
531+
.expect("valid response");
532+
Ok(Response::from(response))
533+
}
534+
}
535+
536+
#[tokio::test]
537+
async fn test_with_extension_propagates_to_handler() {
538+
let observed = Arc::new(std::sync::Mutex::new(None));
539+
let client = ClientBuilder::direct()
540+
.with_handler(RecordingHandler {
541+
observed_request_id: observed.clone(),
542+
})
543+
.with_extension(RequestId("req-from-builder".to_string()))
544+
.build();
545+
546+
let resp = client
547+
.ping(http::Request::new(PingRequest {
548+
name: "hi".to_string(),
549+
}))
550+
.await
551+
.expect("request succeeds");
552+
assert_eq!(resp.into_body().name, "pong");
553+
assert_eq!(
554+
observed.lock().unwrap().clone(),
555+
Some(RequestId("req-from-builder".to_string()))
556+
);
557+
}
558+
559+
#[tokio::test]
560+
async fn test_per_call_extension_overrides_builder() {
561+
let observed = Arc::new(std::sync::Mutex::new(None));
562+
let client = ClientBuilder::direct()
563+
.with_handler(RecordingHandler {
564+
observed_request_id: observed.clone(),
565+
})
566+
.with_extension(RequestId("builder".to_string()))
567+
.build();
568+
569+
let mut req = http::Request::new(PingRequest {
570+
name: "hi".to_string(),
571+
});
572+
req.extensions_mut()
573+
.insert(RequestId("per-call".to_string()));
574+
let resp = client.ping(req).await.expect("request succeeds");
575+
assert_eq!(resp.into_body().name, "pong");
576+
assert_eq!(
577+
observed.lock().unwrap().clone(),
578+
Some(RequestId("per-call".to_string()))
579+
);
580+
}
467581
}

crates/twirp/src/details.rs

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -65,25 +65,26 @@ where
6565
}
6666

6767
/// Decode a `reqwest::Request` into a `http::Request<I>`.
68-
pub async fn decode_request<I>(mut req: reqwest::Request) -> Result<http::Request<I>>
68+
pub async fn decode_request<I>(req: reqwest::Request) -> Result<http::Request<I>>
6969
where
7070
I: prost::Message + Default,
7171
{
72-
let url = req.url().clone();
73-
let headers = req.headers().clone();
74-
let body = std::mem::take(req.body_mut())
75-
.ok_or_else(|| malformed("failed to read the request body"))?
72+
let http_req: http::Request<reqwest::Body> = req
73+
.try_into()
74+
.map_err(|e| malformed(format!("failed to convert request: {e}")))?;
75+
let (parts, body) = http_req.into_parts();
76+
let bytes = body
7677
.collect()
77-
.await?
78+
.await
79+
.map_err(|e| malformed(format!("failed to read the request body: {e}")))?
7880
.to_bytes();
79-
let data = I::decode(body).map_err(|e| malformed(format!("failed to decode request: {e}")))?;
80-
let mut req = Request::builder().method("POST").uri(url.to_string());
81-
req.headers_mut()
82-
.expect("failed to get headers")
83-
.extend(headers);
84-
let req = req
85-
.body(data)
86-
.map_err(|e| malformed(format!("failed to build the request: {e}")))?;
81+
let data = I::decode(bytes).map_err(|e| malformed(format!("failed to decode request: {e}")))?;
82+
let mut req = http::Request::new(data);
83+
*req.method_mut() = parts.method;
84+
*req.uri_mut() = parts.uri;
85+
*req.version_mut() = parts.version;
86+
*req.headers_mut() = parts.headers;
87+
*req.extensions_mut() = parts.extensions;
8788
Ok(req)
8889
}
8990

0 commit comments

Comments
 (0)