Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions rapina/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,47 @@ impl Rapina {
self
}

/// Adds a pre-existing `Arc<T>` as shared state.
///
/// This is the preferred way to register trait objects for dependency
/// injection. Without this method, passing an `Arc<dyn MyTrait>` to
/// [`.state()`](Self::state) requires a newtype wrapper; with `state_arc`
/// you can register the arc directly.
///
/// With `state_arc` the `Arc` is registered under its own [`TypeId`], so
/// no newtype wrapper is required.
///
/// # Example
///
/// ```rust,no_run
/// use rapina::prelude::*;
/// use std::sync::Arc;
///
/// trait MyRepo: Send + Sync {
/// fn find_all(&self) -> Vec<String>;
/// }
///
/// struct PgRepo;
/// impl MyRepo for PgRepo {
/// fn find_all(&self) -> Vec<String> { vec![] }
/// }
///
/// #[tokio::main]
/// async fn main() -> std::io::Result<()> {
/// Rapina::new()
/// .state_arc(Arc::new(PgRepo) as Arc<dyn MyRepo>)
/// .listen("127.0.0.1:3000")
/// .await
/// }
/// ```
pub fn state_arc<T: ?Sized + Send + Sync + 'static>(
mut self,
value: std::sync::Arc<T>,
) -> Self {
self.state = self.state.with_arc(value);
self
}

/// Adds a middleware to the application.
pub fn middleware<M: Middleware>(mut self, middleware: M) -> Self {
self.middlewares.add(middleware);
Expand Down
31 changes: 30 additions & 1 deletion rapina/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ impl<T: Send + Sync + 'static> FromRequestParts for State<T> {
) -> Result<Self, Error> {
let arc = state.get_arc::<T>().ok_or_else(|| {
Error::internal(format!(
"State not registered for type '{}'. Did you forget to call .state()?",
"State not registered for type '{}'. Did you forget to call .state() or .state_arc()?",
std::any::type_name::<T>()
))
})?;
Expand Down Expand Up @@ -1239,6 +1239,35 @@ mod tests {
assert_eq!(result.unwrap_err().status(), 500);
}

#[tokio::test]
async fn test_state_extractor_arc_trait_object() {
trait Greeter: Send + Sync {
fn greet(&self) -> &'static str;
}

struct Hello;
impl Greeter for Hello {
fn greet(&self) -> &'static str {
"hello"
}
}

let greeter: std::sync::Arc<dyn Greeter> = std::sync::Arc::new(Hello);
let state = std::sync::Arc::new(crate::state::AppState::new().with_arc(greeter));
let (parts, _) = TestRequest::get("/").into_parts();

let result = State::<std::sync::Arc<dyn Greeter>>::from_request_parts(
&parts,
&empty_params(),
&state,
)
.await;

assert!(result.is_ok());
// Deref chain: State<Arc<dyn Greeter>> -> Arc<dyn Greeter> -> dyn Greeter
assert_eq!(result.unwrap().greet(), "hello");
}

// into_inner tests
#[test]
fn test_json_into_inner() {
Expand Down
119 changes: 119 additions & 0 deletions rapina/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,42 @@ impl AppState {
.and_then(|arc| arc.downcast_ref::<T>())
}

/// Registers a pre-existing `Arc<T>` as shared state.
///
/// Use this when `T` is a trait object (e.g. `Arc<dyn MyTrait>`) and you
/// want to access it via [`State<Arc<dyn MyTrait>>`](crate::extract::State)
/// in handlers without needing a newtype wrapper.
///
/// Internally the value is stored under `TypeId::of::<Arc<T>>()` wrapped in
/// one additional `Arc` (as required by the state map). Handlers receive
/// `State<Arc<dyn MyTrait>>` and can call methods directly via auto-deref,
/// or clone the inner arc with `(*state).clone()`.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be clearer written as Arc::clone(&*state) instead of (*state).clone() since that makes the intent explicit

///
/// # Examples
///
/// ```
/// use rapina::state::AppState;
/// use std::sync::Arc;
///
/// trait Greeter: Send + Sync {
/// fn greet(&self) -> String;
/// }
///
/// struct HelloGreeter;
/// impl Greeter for HelloGreeter {
/// fn greet(&self) -> String { "hello".to_string() }
/// }
///
/// let greeter: Arc<dyn Greeter> = Arc::new(HelloGreeter);
/// let state = AppState::new().with_arc(greeter);
///
/// // Access via State<Arc<dyn Greeter>> in handlers; deref gives Arc<dyn Greeter>
/// ```
pub fn with_arc<T: ?Sized + Send + Sync + 'static>(mut self, value: Arc<T>) -> Self {
self.inner.insert(TypeId::of::<Arc<T>>(), Arc::new(value));
self
}

/// Retrieves a shared `Arc<T>` for a value of type `T`, if registered.
///
/// This is useful when you want to share state without cloning the
Expand Down Expand Up @@ -211,4 +247,87 @@ mod tests {
assert_eq!(state.get::<f64>(), Some(&3.0));
assert_eq!(state.get::<String>(), Some(&"test".to_string()));
}

#[test]
fn test_with_arc_concrete_type() {
#[derive(Debug, PartialEq)]
struct Repo {
name: &'static str,
}

let arc = Arc::new(Repo { name: "pg" });
let state = AppState::new().with_arc(Arc::clone(&arc));

// Extracted via get_arc::<Arc<Repo>>()
let extracted = state.get_arc::<Arc<Repo>>().unwrap();
assert_eq!(extracted.name, "pg");
}

#[test]
fn test_with_arc_trait_object() {
trait Greeter: Send + Sync {
fn greet(&self) -> &'static str;
}

struct Hello;
impl Greeter for Hello {
fn greet(&self) -> &'static str {
"hello"
}
}

let greeter: Arc<dyn Greeter> = Arc::new(Hello);
let state = AppState::new().with_arc(greeter);

let extracted = state.get_arc::<Arc<dyn Greeter>>().unwrap();
assert_eq!(extracted.greet(), "hello");
}

#[test]
fn test_with_arc_does_not_conflict_with_with() {
// with() and with_arc() on same logical type use different TypeIds
// (TypeId::of::<T>() vs TypeId::of::<Arc<T>>()) so they coexist.
#[derive(Debug, PartialEq)]
struct Config {
val: i32,
}

let concrete = Config { val: 1 };
let arc = Arc::new(Config { val: 2 });

let state = AppState::new().with(concrete).with_arc(Arc::clone(&arc));

assert_eq!(state.get::<Config>().unwrap().val, 1);
assert_eq!(state.get_arc::<Arc<Config>>().unwrap().val, 2);
}

#[test]
fn test_with_arc_missing_returns_none() {
trait Repo: Send + Sync {}

let state = AppState::new();
assert!(state.get_arc::<Arc<dyn Repo>>().is_none());
}

#[test]
fn test_with_arc_overwrites_same_arc_type() {
trait Counter: Send + Sync {
fn count(&self) -> u32;
}

struct CounterImpl(u32);
impl Counter for CounterImpl {
fn count(&self) -> u32 {
self.0
}
}

let first: Arc<dyn Counter> = Arc::new(CounterImpl(1));
let second: Arc<dyn Counter> = Arc::new(CounterImpl(2));

let state = AppState::new().with_arc(first).with_arc(second);

let extracted = state.get_arc::<Arc<dyn Counter>>().unwrap();
assert_eq!(extracted.count(), 2);
}
}
Loading