@@ -12,14 +12,15 @@ use url::Host;
1212use url:: Url ;
1313
1414use crate :: headers:: { CONTENT_TYPE_JSON , CONTENT_TYPE_PROTOBUF } ;
15- use crate :: { serialize_proto_message, Result , TwirpErrorResponse } ;
15+ use crate :: { malformed , serialize_proto_message, Result , TwirpErrorResponse } ;
1616
1717/// Builder to easily create twirp clients.
1818pub struct ClientBuilder {
1919 base_url : Url ,
2020 http_client : Option < reqwest:: Client > ,
2121 handlers : Option < RequestHandlers > ,
2222 middleware : Vec < Box < dyn Middleware > > ,
23+ extensions : http:: Extensions ,
2324}
2425
2526impl ClientBuilder {
@@ -30,6 +31,7 @@ impl ClientBuilder {
3031 http_client : None ,
3132 middleware : vec ! [ ] ,
3233 handlers : None ,
34+ extensions : http:: Extensions :: new ( ) ,
3335 }
3436 }
3537
@@ -44,6 +46,7 @@ impl ClientBuilder {
4446 http_client : None ,
4547 middleware : vec ! [ ] ,
4648 handlers : Some ( RequestHandlers :: new ( ) ) ,
49+ extensions : http:: Extensions :: new ( ) ,
4750 }
4851 }
4952
@@ -100,13 +103,46 @@ impl ClientBuilder {
100103 self
101104 }
102105
106+ /// Attach a typed extension that will be inserted onto every outbound request before it
107+ /// reaches middleware and handlers. Per-call extensions set on the inbound `http::Request`
108+ /// take precedence and override these values. Calling this with a value of the same type
109+ /// replaces the previous value.
110+ pub fn with_extension < T > ( mut self , value : T ) -> Self
111+ where
112+ T : Clone + Send + Sync + ' static ,
113+ {
114+ self . extensions . insert ( value) ;
115+ self
116+ }
117+
103118 /// Creates a `twirp::Client`.
104119 ///
105120 /// The underlying `reqwest::Client` holds a connection pool internally, so it is advised that
106121 /// you create one and **reuse** it.
107122 pub fn build ( self ) -> Client {
108123 let http_client = self . http_client . unwrap_or_default ( ) ;
109- Client :: new ( self . base_url , http_client, self . middleware , self . handlers )
124+ Client {
125+ http_client,
126+ inner : Arc :: new ( ClientRef {
127+ base_url : normalize_base_url ( self . base_url ) ,
128+ middlewares : self . middleware ,
129+ handlers : self . handlers ,
130+ extensions : self . extensions ,
131+ } ) ,
132+ host : None ,
133+ }
134+ }
135+ }
136+
137+ fn normalize_base_url ( base_url : Url ) -> Url {
138+ if base_url. path ( ) . ends_with ( '/' ) {
139+ base_url
140+ } else {
141+ let mut base_url = base_url;
142+ let mut path = base_url. path ( ) . to_string ( ) ;
143+ path. push ( '/' ) ;
144+ base_url. set_path ( & path) ;
145+ base_url
110146 }
111147}
112148
@@ -126,6 +162,7 @@ struct ClientRef {
126162 base_url : Url ,
127163 middlewares : Vec < Box < dyn Middleware > > ,
128164 handlers : Option < RequestHandlers > ,
165+ extensions : http:: Extensions ,
129166}
130167
131168impl std:: fmt:: Debug for Client {
@@ -143,6 +180,7 @@ impl std::fmt::Debug for Client {
143180 . map ( |x| x. len ( ) )
144181 . unwrap_or_default ( ) ,
145182 )
183+ . field ( "extensions" , & self . inner . extensions . len ( ) )
146184 . finish ( )
147185 }
148186}
@@ -158,21 +196,13 @@ impl Client {
158196 middlewares : Vec < Box < dyn Middleware > > ,
159197 handlers : Option < RequestHandlers > ,
160198 ) -> Self {
161- let base_url = if base_url. path ( ) . ends_with ( '/' ) {
162- base_url
163- } else {
164- let mut base_url = base_url;
165- let mut path = base_url. path ( ) . to_string ( ) ;
166- path. push ( '/' ) ;
167- base_url. set_path ( & path) ;
168- base_url
169- } ;
170199 Client {
171200 http_client,
172201 inner : Arc :: new ( ClientRef {
173- base_url,
202+ base_url : normalize_base_url ( base_url ) ,
174203 middlewares,
175204 handlers,
205+ extensions : http:: Extensions :: new ( ) ,
176206 } ) ,
177207 host : None ,
178208 }
@@ -216,13 +246,25 @@ impl Client {
216246 url. set_host ( Some ( host) ) ?
217247 } ;
218248 let ( parts, body) = req. into_parts ( ) ;
219- let request = self
220- . http_client
221- . post ( url)
222- . headers ( parts. headers )
223- . header ( CONTENT_TYPE , CONTENT_TYPE_PROTOBUF )
224- . body ( serialize_proto_message ( body) )
225- . build ( ) ?;
249+ // Build as `http::Request<reqwest::Body>` so that extensions propagate through
250+ // the public `TryFrom<http::Request<T>> for reqwest::Request` impl (the inherent
251+ // `extensions_mut` on `reqwest::Request` is `pub(crate)`).
252+ let mut http_req = http:: Request :: builder ( )
253+ . method ( http:: Method :: POST )
254+ . uri ( url. to_string ( ) )
255+ . body ( reqwest:: Body :: from ( serialize_proto_message ( body) ) )
256+ . map_err ( |e| malformed ( format ! ( "failed to build the request: {e}" ) ) ) ?;
257+ * http_req. headers_mut ( ) = parts. headers ;
258+ http_req. headers_mut ( ) . insert (
259+ CONTENT_TYPE ,
260+ HeaderValue :: from_bytes ( CONTENT_TYPE_PROTOBUF )
261+ . expect ( "CONTENT_TYPE_PROTOBUF is always a valid header value" ) ,
262+ ) ;
263+ // Apply per-client extensions first, then let per-call extensions override.
264+ let request_extensions = http_req. extensions_mut ( ) ;
265+ request_extensions. extend ( self . inner . extensions . clone ( ) ) ;
266+ request_extensions. extend ( parts. extensions ) ;
267+ let request = reqwest:: Request :: try_from ( http_req) ?;
226268
227269 // Create and execute the middleware handlers
228270 let next = Next :: new (
@@ -464,4 +506,76 @@ mod tests {
464506 assert_eq ! ( data. name, "hi" ) ;
465507 h. abort ( )
466508 }
509+
510+ struct RecordingHandler {
511+ observed_request_id : Arc < std:: sync:: Mutex < Option < RequestId > > > ,
512+ }
513+
514+ #[ async_trait]
515+ impl DirectHandler for RecordingHandler {
516+ fn service ( & self ) -> & str {
517+ "test.TestAPI"
518+ }
519+
520+ async fn handle ( & self , _method : & str , req : Request ) -> Result < Response > {
521+ let decoded: http:: Request < PingRequest > = crate :: details:: decode_request ( req) . await ?;
522+ * self . observed_request_id . lock ( ) . unwrap ( ) =
523+ decoded. extensions ( ) . get :: < RequestId > ( ) . cloned ( ) ;
524+ let body = serialize_proto_message ( PingResponse {
525+ name : "pong" . to_string ( ) ,
526+ } ) ;
527+ let response = http:: Response :: builder ( )
528+ . status ( 200 )
529+ . header ( CONTENT_TYPE , CONTENT_TYPE_PROTOBUF )
530+ . body ( body)
531+ . expect ( "valid response" ) ;
532+ Ok ( Response :: from ( response) )
533+ }
534+ }
535+
536+ #[ tokio:: test]
537+ async fn test_with_extension_propagates_to_handler ( ) {
538+ let observed = Arc :: new ( std:: sync:: Mutex :: new ( None ) ) ;
539+ let client = ClientBuilder :: direct ( )
540+ . with_handler ( RecordingHandler {
541+ observed_request_id : observed. clone ( ) ,
542+ } )
543+ . with_extension ( RequestId ( "req-from-builder" . to_string ( ) ) )
544+ . build ( ) ;
545+
546+ let resp = client
547+ . ping ( http:: Request :: new ( PingRequest {
548+ name : "hi" . to_string ( ) ,
549+ } ) )
550+ . await
551+ . expect ( "request succeeds" ) ;
552+ assert_eq ! ( resp. into_body( ) . name, "pong" ) ;
553+ assert_eq ! (
554+ observed. lock( ) . unwrap( ) . clone( ) ,
555+ Some ( RequestId ( "req-from-builder" . to_string( ) ) )
556+ ) ;
557+ }
558+
559+ #[ tokio:: test]
560+ async fn test_per_call_extension_overrides_builder ( ) {
561+ let observed = Arc :: new ( std:: sync:: Mutex :: new ( None ) ) ;
562+ let client = ClientBuilder :: direct ( )
563+ . with_handler ( RecordingHandler {
564+ observed_request_id : observed. clone ( ) ,
565+ } )
566+ . with_extension ( RequestId ( "builder" . to_string ( ) ) )
567+ . build ( ) ;
568+
569+ let mut req = http:: Request :: new ( PingRequest {
570+ name : "hi" . to_string ( ) ,
571+ } ) ;
572+ req. extensions_mut ( )
573+ . insert ( RequestId ( "per-call" . to_string ( ) ) ) ;
574+ let resp = client. ping ( req) . await . expect ( "request succeeds" ) ;
575+ assert_eq ! ( resp. into_body( ) . name, "pong" ) ;
576+ assert_eq ! (
577+ observed. lock( ) . unwrap( ) . clone( ) ,
578+ Some ( RequestId ( "per-call" . to_string( ) ) )
579+ ) ;
580+ }
467581}
0 commit comments