diff --git a/crates/rustapi-core/Cargo.toml b/crates/rustapi-core/Cargo.toml index adb4c79..1680c57 100644 --- a/crates/rustapi-core/Cargo.toml +++ b/crates/rustapi-core/Cargo.toml @@ -90,6 +90,7 @@ proptest = "1.4" rustapi-testing = { workspace = true } reqwest = { version = "0.12", features = ["json", "stream"] } async-stream = "0.3" +async-trait = { workspace = true } [features] default = ["swagger-ui", "tracing"] swagger-ui = ["rustapi-openapi/swagger-ui"] diff --git a/crates/rustapi-core/src/extract.rs b/crates/rustapi-core/src/extract.rs index a720409..6051994 100644 --- a/crates/rustapi-core/src/extract.rs +++ b/crates/rustapi-core/src/extract.rs @@ -378,8 +378,12 @@ impl FromRequest for AsyncVal let value: T = json::from_slice(&body)?; // Create validation context from request - // TODO: Extract validators from App State - let ctx = ValidationContext::default(); + // Check if validators are configured in App State + let ctx = if let Some(ctx) = req.state().get::() { + ctx.clone() + } else { + ValidationContext::default() + }; // Perform full validation (sync + async) if let Err(errors) = value.validate_full(&ctx).await { @@ -1715,4 +1719,178 @@ mod tests { assert_eq!(cookies.get("token").unwrap().value(), "xyz789"); } } + + #[tokio::test] + async fn test_async_validated_json_with_state_context() { + use async_trait::async_trait; + use rustapi_validate::prelude::*; + use rustapi_validate::v2::{ + AsyncValidationRule, DatabaseValidator, ValidationContextBuilder, + }; + use serde::{Deserialize, Serialize}; + + struct MockDbValidator { + unique_values: Vec, + } + + #[async_trait] + impl DatabaseValidator for MockDbValidator { + async fn exists( + &self, + _table: &str, + _column: &str, + _value: &str, + ) -> Result { + Ok(true) + } + async fn is_unique( + &self, + _table: &str, + _column: &str, + value: &str, + ) -> Result { + Ok(!self.unique_values.contains(&value.to_string())) + } + async fn is_unique_except( + &self, + _table: &str, + _column: &str, + value: &str, + _except_id: &str, + ) -> Result { + Ok(!self.unique_values.contains(&value.to_string())) + } + } + + #[derive(Debug, Deserialize, Serialize)] + struct TestUser { + email: String, + } + + impl Validate for TestUser { + fn validate_with_group( + &self, + _group: rustapi_validate::v2::ValidationGroup, + ) -> Result<(), rustapi_validate::v2::ValidationErrors> { + Ok(()) + } + } + + #[async_trait] + impl AsyncValidate for TestUser { + async fn validate_async_with_group( + &self, + ctx: &ValidationContext, + _group: rustapi_validate::v2::ValidationGroup, + ) -> Result<(), rustapi_validate::v2::ValidationErrors> { + let mut errors = rustapi_validate::v2::ValidationErrors::new(); + + let rule = AsyncUniqueRule::new("users", "email"); + if let Err(e) = rule.validate_async(&self.email, ctx).await { + errors.add("email", e); + } + + errors.into_result() + } + } + + // Test 1: Without context in state (should fail due to missing validator) + let uri: http::Uri = "/test".parse().unwrap(); + let user = TestUser { + email: "new@example.com".to_string(), + }; + let body_bytes = serde_json::to_vec(&user).unwrap(); + + let builder = http::Request::builder() + .method(Method::POST) + .uri(uri.clone()) + .header("content-type", "application/json"); + let req = builder.body(()).unwrap(); + let (parts, _) = req.into_parts(); + + // Construct Request with BodyVariant::Buffered + let mut request = Request::new( + parts, + crate::request::BodyVariant::Buffered(Bytes::from(body_bytes.clone())), + Arc::new(Extensions::new()), + PathParams::new(), + ); + + let result = AsyncValidatedJson::::from_request(&mut request).await; + + assert!(result.is_err(), "Expected error when validator is missing"); + let err = result.unwrap_err(); + let err_str = format!("{:?}", err); + assert!( + err_str.contains("Database validator not configured") + || err_str.contains("async_unique"), + "Error should mention missing configuration or rule: {:?}", + err_str + ); + + // Test 2: With context in state (should succeed) + let db_validator = MockDbValidator { + unique_values: vec!["taken@example.com".to_string()], + }; + let ctx = ValidationContextBuilder::new() + .database(db_validator) + .build(); + + let mut extensions = Extensions::new(); + extensions.insert(ctx); + + let builder = http::Request::builder() + .method(Method::POST) + .uri(uri.clone()) + .header("content-type", "application/json"); + let req = builder.body(()).unwrap(); + let (parts, _) = req.into_parts(); + + let mut request = Request::new( + parts, + crate::request::BodyVariant::Buffered(Bytes::from(body_bytes.clone())), + Arc::new(extensions), + PathParams::new(), + ); + + let result = AsyncValidatedJson::::from_request(&mut request).await; + assert!( + result.is_ok(), + "Expected success when validator is present and value is unique. Error: {:?}", + result.err() + ); + + // Test 3: With context in state (should fail validation logic) + let user_taken = TestUser { + email: "taken@example.com".to_string(), + }; + let body_taken = serde_json::to_vec(&user_taken).unwrap(); + + let db_validator = MockDbValidator { + unique_values: vec!["taken@example.com".to_string()], + }; + let ctx = ValidationContextBuilder::new() + .database(db_validator) + .build(); + + let mut extensions = Extensions::new(); + extensions.insert(ctx); + + let builder = http::Request::builder() + .method(Method::POST) + .uri("/test") + .header("content-type", "application/json"); + let req = builder.body(()).unwrap(); + let (parts, _) = req.into_parts(); + + let mut request = Request::new( + parts, + crate::request::BodyVariant::Buffered(Bytes::from(body_taken)), + Arc::new(extensions), + PathParams::new(), + ); + + let result = AsyncValidatedJson::::from_request(&mut request).await; + assert!(result.is_err(), "Expected validation error for taken email"); + } } diff --git a/crates/rustapi-validate/src/v2/context.rs b/crates/rustapi-validate/src/v2/context.rs index 91c2698..923b210 100644 --- a/crates/rustapi-validate/src/v2/context.rs +++ b/crates/rustapi-validate/src/v2/context.rs @@ -53,7 +53,7 @@ pub trait CustomValidator: Send + Sync { /// /// user.validate_async(&ctx).await?; /// ``` -#[derive(Default)] +#[derive(Clone, Default)] pub struct ValidationContext { database: Option>, http: Option>, @@ -114,7 +114,7 @@ impl std::fmt::Debug for ValidationContext { } /// Builder for constructing a `ValidationContext`. -#[derive(Default)] +#[derive(Clone, Default)] pub struct ValidationContextBuilder { database: Option>, http: Option>,