diff --git a/crates/protoc-gen-protovalidate-buffa/src/emit/cel.rs b/crates/protoc-gen-protovalidate-buffa/src/emit/cel.rs index a1cfa89..16f369f 100644 --- a/crates/protoc-gen-protovalidate-buffa/src/emit/cel.rs +++ b/crates/protoc-gen-protovalidate-buffa/src/emit/cel.rs @@ -43,9 +43,7 @@ pub fn emit_message_level(msg: &MessageValidators) -> (Vec, Vec (Vec, Vec { + crate::scan::FieldKind::Message { .. } => { quote! { if let Some(inner) = self.#field_ident.as_option() { if let Err(v) = #ident.#method( @@ -115,6 +113,19 @@ pub fn emit_message_level(msg: &MessageValidators) -> (Vec, Vec { + quote! { + if let Some(inner) = self.#field_ident.as_option() { + if let Err(v) = #ident.#method( + ::protovalidate_buffa::cel::to_cel_value(&inner.value), + #fp, + #idx_lit, + ) { + violations.push(v); + } + } + } + } crate::scan::FieldKind::Optional(_) => quote! { if let Some(v) = &self.#field_ident { if let Err(viol) = #ident.#method( @@ -469,24 +480,24 @@ pub(crate) fn predef_family_for( /// message-typed fields to avoid requiring downstream `ToCelValue` bounds on /// proto-generated structs. /// -/// WKT fields (`google.protobuf.*`) are skipped — they do not implement `AsCelValue`. -/// Repeated WKT fields are also skipped. +/// WKT fields are included when the runtime provides a CEL representation. /// /// # Errors /// /// Returns an error if a field's Rust path cannot be parsed from the proto type name. pub fn emit_as_cel_value(msg: &MessageValidators, rust_path: &Path) -> Result { + let parent_msg_name = msg.proto_name.rsplit('.').next().unwrap_or(&msg.proto_name); let inserts: Vec = msg .field_rules .iter() // Skip inner-only synthetic validators (field_number == -1 means no real field). .filter(|f| f.field_number != -1) - // Skip oneof-member fields: they're represented as `Option` enums - // in buffa, not as flat struct fields, so `self.` would not compile. - .filter(|f| f.oneof_name.is_none()) - // Skip WKT message fields — no AsCelValue impl for google.protobuf.* - .filter(|f| !is_wkt_field(&f.field_type)) + // Skip WKT message fields without a runtime CEL representation. + .filter(|f| !is_unsupported_wkt_field_for_cel(&f.field_type)) .map(|f| { + if let Some(oneof_name) = &f.oneof_name { + return oneof_field_to_cel_insert(f, parent_msg_name, oneof_name); + } let field_ident = format_ident!("{}", f.field_name); let field_name = &f.field_name; // For fields with explicit presence (Optional scalar / Message), @@ -508,6 +519,14 @@ pub fn emit_as_cel_value(msg: &MessageValidators, rust_path: &Path) -> Result quote! { + if let Some(v) = self.#field_ident.as_option() { + map.insert( + ::std::string::String::from(#field_name), + ::protovalidate_buffa::cel::to_cel_value(&v.value), + ); + } + }, _ => { let insert_val = field_to_cel_value_expr(f, &field_ident); quote! { @@ -543,16 +562,31 @@ pub fn emit_as_cel_value(msg: &MessageValidators, rust_path: &Path) -> Result bool { +fn is_unsupported_wkt_field_for_cel(kind: &FieldKind) -> bool { match kind { - FieldKind::Message { full_name } => full_name.starts_with("google.protobuf."), - FieldKind::Repeated(inner) | FieldKind::Optional(inner) => is_wkt_field(inner), + FieldKind::Message { full_name } => is_unsupported_wkt_for_cel(full_name), + FieldKind::Repeated(inner) | FieldKind::Optional(inner) => { + is_unsupported_wkt_field_for_cel(inner) + } _ => false, } } +pub(crate) fn is_unsupported_wkt_for_cel(full_name: &str) -> bool { + full_name.starts_with("google.protobuf.") && !supports_wkt_as_cel_value(full_name) +} + +pub(crate) const fn supports_wkt_as_cel_value(full_name: &str) -> bool { + matches!( + full_name.as_bytes(), + b"google.protobuf.Any" + | b"google.protobuf.Empty" + | b"google.protobuf.FieldMask" + | b"google.protobuf.Timestamp" + | b"google.protobuf.Duration" + ) +} + /// Generate the expression that converts a field to a CEL Value for insertion /// into the `AsCelValue` map. /// @@ -580,6 +614,14 @@ fn field_to_cel_value_expr(f: &FieldValidator, field_ident: &syn::Ident) -> Toke } } } + FieldKind::Wrapper(_) => { + quote! { + match self.#field_ident.as_option() { + Some(v) => ::protovalidate_buffa::cel::to_cel_value(&v.value), + None => ::protovalidate_buffa::cel_core::Value::Null, + } + } + } FieldKind::Repeated(inner) => { match inner.as_ref() { FieldKind::Message { .. } => { @@ -611,6 +653,69 @@ fn field_to_cel_value_expr(f: &FieldValidator, field_ident: &syn::Ident) -> Toke } } +fn oneof_field_to_cel_insert( + f: &FieldValidator, + parent_msg_name: &str, + oneof_name: &str, +) -> TokenStream { + let oneof_ident = format_ident!("{}", oneof_name); + let module_ident = format_ident!("{}", to_snake_case(parent_msg_name)); + let oneof_enum_ident = format_ident!("{}", to_pascal_case(oneof_name)); + let variant_ident = format_ident!("{}", to_pascal_case(&f.field_name)); + let field_name = &f.field_name; + let value_expr = quote! { v }; + let value = oneof_value_to_cel_expr(&f.field_type, &value_expr); + quote! { + if let Some(__buffa::oneof::#module_ident::#oneof_enum_ident::#variant_ident(v)) = &self.#oneof_ident { + map.insert( + ::std::string::String::from(#field_name), + #value, + ); + } + } +} + +fn oneof_value_to_cel_expr(kind: &FieldKind, value: &TokenStream) -> TokenStream { + match kind { + FieldKind::Message { .. } => quote! { + ::protovalidate_buffa::cel::AsCelValue::as_cel_value(#value.as_ref()) + }, + FieldKind::Wrapper(_) => quote! { + ::protovalidate_buffa::cel::to_cel_value(&#value.value) + }, + _ => quote! { + ::protovalidate_buffa::cel::to_cel_value(#value) + }, + } +} + +fn to_snake_case(s: &str) -> String { + let chars: Vec = s.chars().collect(); + let mut out = String::with_capacity(s.len() + 2); + for (i, &c) in chars.iter().enumerate() { + if c.is_uppercase() && i > 0 { + let prev = chars[i - 1]; + let next_is_lower = chars.get(i + 1).is_some_and(|n| n.is_lowercase()); + if prev.is_lowercase() || (prev.is_uppercase() && next_is_lower) { + out.push('_'); + } + } + out.push(c.to_ascii_lowercase()); + } + out +} + +fn to_pascal_case(s: &str) -> String { + s.split('_') + .map(|part| { + let mut chars = part.chars(); + chars.next().map_or_else(String::new, |c| { + c.to_uppercase().collect::() + chars.as_str() + }) + }) + .collect() +} + /// Build the identifier for a static CEL constraint. /// /// e.g. `proto_name` `"test.v1.UpdatePomRequest"`, id `"update_pom.pom.id_required"` → diff --git a/crates/protoc-gen-protovalidate-buffa/src/emit/field.rs b/crates/protoc-gen-protovalidate-buffa/src/emit/field.rs index ffdcc84..6cdcf15 100644 --- a/crates/protoc-gen-protovalidate-buffa/src/emit/field.rs +++ b/crates/protoc-gen-protovalidate-buffa/src/emit/field.rs @@ -474,7 +474,8 @@ pub fn emit(field: &FieldValidator) -> Result { } FieldKind::Bool => { if let Some(b) = &field.standard.bool_rules { - blocks.extend(emit_bool(&accessor, name_lit, field.field_number, b)); + let value = quote! { self.#accessor }; + blocks.extend(emit_bool(&value, name_lit, field.field_number, b)); } } FieldKind::Wrapper(inner) => { @@ -948,25 +949,12 @@ fn emit_optional_inner( } FieldKind::Bytes => { if let Some(b) = &field.standard.bytes { - out.extend(emit_bytes_on(&v, name_lit, b)); + out.extend(emit_bytes_checks_on(&v, name_lit, field.field_number, b)); } } FieldKind::Bool => { - if let Some(b) = &field.standard.bool_rules - && let Some(c) = b.r#const - { - let fp = field_path_scalar(name_lit, field.field_number, "Bool"); - let rp = rule_path_scalar("bool", 13, "const", 1, "Bool"); - out.push(quote! { - if #v != #c { - violations.push(::protovalidate_buffa::Violation { - field: #fp, rule: #rp, - rule_id: ::std::borrow::Cow::Borrowed("bool.const"), - message: ::std::borrow::Cow::Borrowed(""), - for_key: false, - }); - } - }); + if let Some(b) = &field.standard.bool_rules { + out.extend(emit_bool_checks_on(&v, name_lit, field.field_number, b)); } } FieldKind::Enum { .. } @@ -981,37 +969,35 @@ fn emit_optional_inner( out } -fn emit_bytes_on(val: &syn::Ident, _name_lit: &str, b: &BytesStandard) -> Vec { - let mut out: Vec = Vec::new(); - if let Some(n) = b.min_len { - let n_usize = usize::try_from(n).expect("proto length bound fits in usize"); - out.push(quote! { - if #val.len() < #n_usize { - violations.push(::protovalidate_buffa::Violation { - field: ::protovalidate_buffa::FieldPath::default(), - rule: ::protovalidate_buffa::FieldPath::default(), - rule_id: ::std::borrow::Cow::Borrowed("bytes.min_len"), - message: ::std::borrow::Cow::Borrowed(""), - for_key: false, - }); - } - }); - } - if let Some(n) = b.max_len { - let n_usize = usize::try_from(n).expect("proto length bound fits in usize"); - out.push(quote! { - if #val.len() > #n_usize { - violations.push(::protovalidate_buffa::Violation { - field: ::protovalidate_buffa::FieldPath::default(), - rule: ::protovalidate_buffa::FieldPath::default(), - rule_id: ::std::borrow::Cow::Borrowed("bytes.max_len"), - message: ::std::borrow::Cow::Borrowed(""), - for_key: false, - }); - } - }); - } - out +pub(crate) fn emit_bytes_checks_on( + val: &syn::Ident, + name_lit: &str, + field_number: i32, + b: &BytesStandard, +) -> Vec { + let value = quote! { #val }; + emit_bytes_value(&value, name_lit, field_number, b) +} + +pub(crate) fn emit_bool_checks_on( + val: &syn::Ident, + name_lit: &str, + field_number: i32, + b: &BoolStandard, +) -> Vec { + let value = quote! { #val }; + emit_bool(&value, name_lit, field_number, b) +} + +pub(crate) fn emit_enum_checks_on( + val: &syn::Ident, + name_lit: &str, + field_number: i32, + e: &EnumStandard, + full_name: &str, +) -> Result> { + let value = quote! { #val }; + emit_enum_value(&value, name_lit, field_number, e, full_name) } // Variants of numeric/float emitters that take an explicit `value_ident` @@ -1983,16 +1969,28 @@ fn emit_bytes( name_lit: &str, field_number: i32, b: &BytesStandard, +) -> Vec { + let value = quote! { self.#accessor }; + emit_bytes_value(&value, name_lit, field_number, b) +} + +fn emit_bytes_value( + value: &TokenStream, + name_lit: &str, + field_number: i32, + b: &BytesStandard, ) -> Vec { let mut out: Vec = Vec::new(); let fp = || bytes_field_path(name_lit, field_number); + let value_len = quote! { #value.len() }; + let value_slice = quote! { #value.as_slice() }; // bytes.ip = 4 or 16 bytes; bytes.ipv4 = 4 bytes; bytes.ipv6 = 16 bytes. if b.ip == Some(true) { let field = fp(); let rule = bytes_rule_path_ty("ip", 10, "Bool"); out.push(quote! { - if self.#accessor.len() != 4 && self.#accessor.len() != 16 { + if #value_len != 4 && #value_len != 16 { violations.push(::protovalidate_buffa::Violation { field: #field, rule: #rule, rule_id: ::std::borrow::Cow::Borrowed("bytes.ip"), @@ -2006,7 +2004,7 @@ fn emit_bytes( let field = fp(); let rule = bytes_rule_path_ty("ipv4", 11, "Bool"); out.push(quote! { - if self.#accessor.len() != 4 { + if #value_len != 4 { violations.push(::protovalidate_buffa::Violation { field: #field, rule: #rule, rule_id: ::std::borrow::Cow::Borrowed("bytes.ipv4"), @@ -2020,7 +2018,7 @@ fn emit_bytes( let field = fp(); let rule = bytes_rule_path_ty("ipv6", 12, "Bool"); out.push(quote! { - if self.#accessor.len() != 16 { + if #value_len != 16 { violations.push(::protovalidate_buffa::Violation { field: #field, rule: #rule, rule_id: ::std::borrow::Cow::Borrowed("bytes.ipv6"), @@ -2037,14 +2035,14 @@ fn emit_bytes( out.push(quote! { { // Empty is a special-cased _empty rule. - if self.#accessor.is_empty() { + if #value_slice.is_empty() { violations.push(::protovalidate_buffa::Violation { field: #field.clone(), rule: #rule.clone(), rule_id: ::std::borrow::Cow::Borrowed("bytes.uuid_empty"), message: ::std::borrow::Cow::Borrowed(""), for_key: false, }); - } else if self.#accessor.len() != 16 { + } else if #value_len != 16 { violations.push(::protovalidate_buffa::Violation { field: #field, rule: #rule, rule_id: ::std::borrow::Cow::Borrowed("bytes.uuid"), @@ -2063,7 +2061,7 @@ fn emit_bytes( out.push(quote! { { const EXPECTED: &[u8] = &[ #( #set ),* ]; - if self.#accessor.as_slice() != EXPECTED { + if #value_slice != EXPECTED { violations.push(::protovalidate_buffa::Violation { field: #field, rule: #rule, rule_id: ::std::borrow::Cow::Borrowed("bytes.const"), @@ -2080,13 +2078,13 @@ fn emit_bytes( let field = fp(); let rule = bytes_rule_path_ty("min_len", 2, "Uint64"); out.push(quote! { - if self.#accessor.len() < #n_usize { + if #value_len < #n_usize { violations.push(::protovalidate_buffa::Violation { field: #field, rule: #rule, rule_id: ::std::borrow::Cow::Borrowed("bytes.min_len"), message: ::std::borrow::Cow::Owned(::std::format!( "value length must be at least {} bytes (got {})", - #n_usize, self.#accessor.len() + #n_usize, #value_len )), for_key: false, }); @@ -2099,13 +2097,13 @@ fn emit_bytes( let field = fp(); let rule = bytes_rule_path_ty("len", 13, "Uint64"); out.push(quote! { - if self.#accessor.len() != #n_usize { + if #value_len != #n_usize { violations.push(::protovalidate_buffa::Violation { field: #field, rule: #rule, rule_id: ::std::borrow::Cow::Borrowed("bytes.len"), message: ::std::borrow::Cow::Owned(::std::format!( "value must be exactly {} bytes (got {})", - #n_usize, self.#accessor.len() + #n_usize, #value_len )), for_key: false, }); @@ -2118,13 +2116,13 @@ fn emit_bytes( let field = fp(); let rule = bytes_rule_path_ty("max_len", 3, "Uint64"); out.push(quote! { - if self.#accessor.len() > #n_usize { + if #value_len > #n_usize { violations.push(::protovalidate_buffa::Violation { field: #field, rule: #rule, rule_id: ::std::borrow::Cow::Borrowed("bytes.max_len"), message: ::std::borrow::Cow::Owned(::std::format!( "value length must be at most {} bytes (got {})", - #n_usize, self.#accessor.len() + #n_usize, #value_len )), for_key: false, }); @@ -2134,7 +2132,7 @@ fn emit_bytes( if let Some(pat) = &b.pattern { let pat_str = pat.as_str(); - let cache_ident = format_ident!("RE_BYTES_{}", accessor.to_string().to_uppercase()); + let cache_ident = format_ident!("RE_BYTES_{}", name_lit.to_uppercase()); let field = fp(); let rule = bytes_rule_path_ty("pattern", 4, "String"); out.push(quote! { @@ -2145,7 +2143,7 @@ fn emit_bytes( ::protovalidate_buffa::regex::Regex::new(#pat_str) .expect("pattern regex compiled at code-gen time") }); - match ::std::str::from_utf8(&self.#accessor) { + match ::std::str::from_utf8(#value_slice) { Ok(s) => { if !re.is_match(s) { violations.push(::protovalidate_buffa::Violation { @@ -2180,7 +2178,7 @@ fn emit_bytes( out.push(quote! { { const PREFIX: &[u8] = &[ #( #p ),* ]; - if !self.#accessor.starts_with(PREFIX) { + if !#value_slice.starts_with(PREFIX) { violations.push(::protovalidate_buffa::Violation { field: #field, rule: #rule, rule_id: ::std::borrow::Cow::Borrowed("bytes.prefix"), @@ -2199,7 +2197,7 @@ fn emit_bytes( out.push(quote! { { const SUFFIX: &[u8] = &[ #( #p ),* ]; - if !self.#accessor.ends_with(SUFFIX) { + if !#value_slice.ends_with(SUFFIX) { violations.push(::protovalidate_buffa::Violation { field: #field, rule: #rule, rule_id: ::std::borrow::Cow::Borrowed("bytes.suffix"), @@ -2218,7 +2216,7 @@ fn emit_bytes( out.push(quote! { { const NEEDLE: &[u8] = &[ #( #p ),* ]; - if !self.#accessor.windows(NEEDLE.len()).any(|w| w == NEEDLE) { + if !#value_slice.windows(NEEDLE.len()).any(|w| w == NEEDLE) { violations.push(::protovalidate_buffa::Violation { field: #field, rule: #rule, rule_id: ::std::borrow::Cow::Borrowed("bytes.contains"), @@ -2245,7 +2243,7 @@ fn emit_bytes( out.push(quote! { { let allowed: &[&[u8]] = &[ #( #bytes_lits ),* ]; - if !allowed.iter().any(|a| *a == self.#accessor.as_slice()) { + if !allowed.iter().any(|a| *a == #value_slice) { violations.push(::protovalidate_buffa::Violation { field: #field, rule: #rule, rule_id: ::std::borrow::Cow::Borrowed("bytes.in"), @@ -2271,7 +2269,7 @@ fn emit_bytes( out.push(quote! { { let disallowed: &[&[u8]] = &[ #( #bytes_lits ),* ]; - if disallowed.iter().any(|a| *a == self.#accessor.as_slice()) { + if disallowed.iter().any(|a| *a == #value_slice) { violations.push(::protovalidate_buffa::Violation { field: #field, rule: #rule, rule_id: ::std::borrow::Cow::Borrowed("bytes.not_in"), @@ -3429,6 +3427,17 @@ fn emit_enum( field_number: i32, e: &EnumStandard, full_name: &str, +) -> Result> { + let value = quote! { self.#accessor }; + emit_enum_value(&value, name_lit, field_number, e, full_name) +} + +fn emit_enum_value( + value: &TokenStream, + name_lit: &str, + field_number: i32, + e: &EnumStandard, + full_name: &str, ) -> Result> { // EnumRules outer field number = 16; inner: const=1 (TYPE_INT32), // defined_only=2 (TYPE_BOOL), in=3 (TYPE_INT32), not_in=4 (TYPE_INT32). @@ -3442,7 +3451,7 @@ fn emit_enum( let rule = rule_path("const", 1, "Int32"); out.push(quote! { { - let raw_val: i32 = self.#accessor.to_i32(); + let raw_val: i32 = #value.to_i32(); if raw_val != #c { violations.push(::protovalidate_buffa::Violation { field: #field, rule: #rule, @@ -3463,7 +3472,7 @@ fn emit_enum( let rule = rule_path("defined_only", 2, "Bool"); out.push(quote! { { - let raw_val: i32 = self.#accessor.to_i32(); + let raw_val: i32 = #value.to_i32(); if <#enum_type as ::buffa::Enumeration>::from_i32(raw_val).is_none() { violations.push(::protovalidate_buffa::Violation { field: #field, rule: #rule, @@ -3485,7 +3494,7 @@ fn emit_enum( out.push(quote! { { const ALLOWED: &[i32] = &[ #( #set ),* ]; - let raw_val: i32 = self.#accessor.to_i32(); + let raw_val: i32 = #value.to_i32(); if !ALLOWED.contains(&raw_val) { violations.push(::protovalidate_buffa::Violation { field: #field, rule: #rule, @@ -3507,7 +3516,7 @@ fn emit_enum( out.push(quote! { { const DISALLOWED: &[i32] = &[ #( #set ),* ]; - let raw_val: i32 = self.#accessor.to_i32(); + let raw_val: i32 = #value.to_i32(); if DISALLOWED.contains(&raw_val) { violations.push(::protovalidate_buffa::Violation { field: #field, rule: #rule, @@ -3634,7 +3643,7 @@ fn to_snake_case(s: &str) -> String { /// entries where the inner scalar type is known. /// Wrapper-specific: the outer field (in the rule path) is TYPE_MESSAGE /// (the wrapper), not the inner scalar. -fn emit_wrapper_inner( +pub(crate) fn emit_wrapper_inner( name_lit: &str, field_number: i32, inner: &FieldKind, @@ -4155,6 +4164,7 @@ pub(crate) fn emit_numeric_checks_on( _ => return Vec::new(), }; let field_path_expr = field_path_scalar(name_lit, field_number, fam.scalar_ty); + let value = quote! { #v }; let mut out: Vec = Vec::new(); let push_cmp = |out: &mut Vec, inner_name: &str, @@ -4180,10 +4190,33 @@ pub(crate) fn emit_numeric_checks_on( } }); }; - let is_float = matches!(kind, FieldKind::Float | FieldKind::Double); match kind { FieldKind::Int32 | FieldKind::Sint32 | FieldKind::Sfixed32 => { if let Some(n) = &std.int32 { + if let (Some(lo), Some(hi)) = (n.gt, n.lt) { + out.push(range_check( + &field_path_expr, + fam, + false, + "e! { #lo }, + "e! { #hi }, + &value, + hi < lo, + )); + return out; + } + if let (Some(lo), Some(hi)) = (n.gte, n.lte) { + out.push(range_check( + &field_path_expr, + fam, + true, + "e! { #lo }, + "e! { #hi }, + &value, + hi < lo, + )); + return out; + } if let Some(c) = n.r#const { push_cmp( &mut out, @@ -4229,10 +4262,84 @@ pub(crate) fn emit_numeric_checks_on( quote! { #v > #hi }, ); } + if !n.in_set.is_empty() { + let set = &n.in_set; + let field = field_path_expr.clone(); + let rule = rule_path_scalar( + fam.family, + fam.outer_number, + "in", + INNER_IN, + fam.scalar_ty, + ); + let rule_id = format!("{}.in", fam.family); + out.push(quote! { + { + const ALLOWED: &[i32] = &[ #( #set ),* ]; + if !ALLOWED.contains(&#v) { + violations.push(::protovalidate_buffa::Violation { + field: #field, rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed(#rule_id), + message: ::std::borrow::Cow::Borrowed(""), + for_key: false, + }); + } + } + }); + } + if !n.not_in.is_empty() { + let set = &n.not_in; + let field = field_path_expr.clone(); + let rule = rule_path_scalar( + fam.family, + fam.outer_number, + "not_in", + INNER_NOT_IN, + fam.scalar_ty, + ); + let rule_id = format!("{}.not_in", fam.family); + out.push(quote! { + { + const DISALLOWED: &[i32] = &[ #( #set ),* ]; + if DISALLOWED.contains(&#v) { + violations.push(::protovalidate_buffa::Violation { + field: #field, rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed(#rule_id), + message: ::std::borrow::Cow::Borrowed(""), + for_key: false, + }); + } + } + }); + } } } FieldKind::Int64 | FieldKind::Sint64 | FieldKind::Sfixed64 => { if let Some(n) = &std.int64 { + if let (Some(lo), Some(hi)) = (n.gt, n.lt) { + out.push(range_check( + &field_path_expr, + fam, + false, + "e! { #lo }, + "e! { #hi }, + &value, + hi < lo, + )); + return out; + } + if let (Some(lo), Some(hi)) = (n.gte, n.lte) { + out.push(range_check( + &field_path_expr, + fam, + true, + "e! { #lo }, + "e! { #hi }, + &value, + hi < lo, + )); + return out; + } if let Some(c) = n.r#const { push_cmp( &mut out, @@ -4278,10 +4385,84 @@ pub(crate) fn emit_numeric_checks_on( quote! { #v > #hi }, ); } + if !n.in_set.is_empty() { + let set = &n.in_set; + let field = field_path_expr.clone(); + let rule = rule_path_scalar( + fam.family, + fam.outer_number, + "in", + INNER_IN, + fam.scalar_ty, + ); + let rule_id = format!("{}.in", fam.family); + out.push(quote! { + { + const ALLOWED: &[i64] = &[ #( #set ),* ]; + if !ALLOWED.contains(&#v) { + violations.push(::protovalidate_buffa::Violation { + field: #field, rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed(#rule_id), + message: ::std::borrow::Cow::Borrowed(""), + for_key: false, + }); + } + } + }); + } + if !n.not_in.is_empty() { + let set = &n.not_in; + let field = field_path_expr.clone(); + let rule = rule_path_scalar( + fam.family, + fam.outer_number, + "not_in", + INNER_NOT_IN, + fam.scalar_ty, + ); + let rule_id = format!("{}.not_in", fam.family); + out.push(quote! { + { + const DISALLOWED: &[i64] = &[ #( #set ),* ]; + if DISALLOWED.contains(&#v) { + violations.push(::protovalidate_buffa::Violation { + field: #field, rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed(#rule_id), + message: ::std::borrow::Cow::Borrowed(""), + for_key: false, + }); + } + } + }); + } } } FieldKind::Uint32 | FieldKind::Fixed32 => { if let Some(n) = &std.uint32 { + if let (Some(lo), Some(hi)) = (n.gt, n.lt) { + out.push(range_check( + &field_path_expr, + fam, + false, + "e! { #lo }, + "e! { #hi }, + &value, + hi < lo, + )); + return out; + } + if let (Some(lo), Some(hi)) = (n.gte, n.lte) { + out.push(range_check( + &field_path_expr, + fam, + true, + "e! { #lo }, + "e! { #hi }, + &value, + hi < lo, + )); + return out; + } if let Some(c) = n.r#const { push_cmp( &mut out, @@ -4327,10 +4508,84 @@ pub(crate) fn emit_numeric_checks_on( quote! { #v > #hi }, ); } + if !n.in_set.is_empty() { + let set = &n.in_set; + let field = field_path_expr.clone(); + let rule = rule_path_scalar( + fam.family, + fam.outer_number, + "in", + INNER_IN, + fam.scalar_ty, + ); + let rule_id = format!("{}.in", fam.family); + out.push(quote! { + { + const ALLOWED: &[u32] = &[ #( #set ),* ]; + if !ALLOWED.contains(&#v) { + violations.push(::protovalidate_buffa::Violation { + field: #field, rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed(#rule_id), + message: ::std::borrow::Cow::Borrowed(""), + for_key: false, + }); + } + } + }); + } + if !n.not_in.is_empty() { + let set = &n.not_in; + let field = field_path_expr.clone(); + let rule = rule_path_scalar( + fam.family, + fam.outer_number, + "not_in", + INNER_NOT_IN, + fam.scalar_ty, + ); + let rule_id = format!("{}.not_in", fam.family); + out.push(quote! { + { + const DISALLOWED: &[u32] = &[ #( #set ),* ]; + if DISALLOWED.contains(&#v) { + violations.push(::protovalidate_buffa::Violation { + field: #field, rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed(#rule_id), + message: ::std::borrow::Cow::Borrowed(""), + for_key: false, + }); + } + } + }); + } } } FieldKind::Uint64 | FieldKind::Fixed64 => { if let Some(n) = &std.uint64 { + if let (Some(lo), Some(hi)) = (n.gt, n.lt) { + out.push(range_check( + &field_path_expr, + fam, + false, + "e! { #lo }, + "e! { #hi }, + &value, + hi < lo, + )); + return out; + } + if let (Some(lo), Some(hi)) = (n.gte, n.lte) { + out.push(range_check( + &field_path_expr, + fam, + true, + "e! { #lo }, + "e! { #hi }, + &value, + hi < lo, + )); + return out; + } if let Some(c) = n.r#const { push_cmp( &mut out, @@ -4376,10 +4631,84 @@ pub(crate) fn emit_numeric_checks_on( quote! { #v > #hi }, ); } + if !n.in_set.is_empty() { + let set = &n.in_set; + let field = field_path_expr.clone(); + let rule = rule_path_scalar( + fam.family, + fam.outer_number, + "in", + INNER_IN, + fam.scalar_ty, + ); + let rule_id = format!("{}.in", fam.family); + out.push(quote! { + { + const ALLOWED: &[u64] = &[ #( #set ),* ]; + if !ALLOWED.contains(&#v) { + violations.push(::protovalidate_buffa::Violation { + field: #field, rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed(#rule_id), + message: ::std::borrow::Cow::Borrowed(""), + for_key: false, + }); + } + } + }); + } + if !n.not_in.is_empty() { + let set = &n.not_in; + let field = field_path_expr.clone(); + let rule = rule_path_scalar( + fam.family, + fam.outer_number, + "not_in", + INNER_NOT_IN, + fam.scalar_ty, + ); + let rule_id = format!("{}.not_in", fam.family); + out.push(quote! { + { + const DISALLOWED: &[u64] = &[ #( #set ),* ]; + if DISALLOWED.contains(&#v) { + violations.push(::protovalidate_buffa::Violation { + field: #field, rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed(#rule_id), + message: ::std::borrow::Cow::Borrowed(""), + for_key: false, + }); + } + } + }); + } } } FieldKind::Float => { if let Some(f) = &std.float { + if let (Some(lo), Some(hi)) = (f.gt, f.lt) { + out.push(range_check_fp( + &field_path_expr, + fam, + false, + "e! { #lo }, + "e! { #hi }, + &value, + hi < lo, + )); + return out; + } + if let (Some(lo), Some(hi)) = (f.gte, f.lte) { + out.push(range_check_fp( + &field_path_expr, + fam, + true, + "e! { #lo }, + "e! { #hi }, + &value, + hi < lo, + )); + return out; + } if let Some(c) = f.r#const { push_cmp( &mut out, @@ -4425,10 +4754,93 @@ pub(crate) fn emit_numeric_checks_on( quote! { !(#v <= #hi) }, ); } + if !f.in_set.is_empty() { + let set = &f.in_set; + let field = field_path_expr.clone(); + let rule = rule_path_scalar( + fam.family, + fam.outer_number, + "in", + INNER_IN, + fam.scalar_ty, + ); + let rule_id = format!("{}.in", fam.family); + out.push(quote! { + { + const ALLOWED: &[f32] = &[ #( #set ),* ]; + if !ALLOWED.iter().any(|c| *c == #v) { + violations.push(::protovalidate_buffa::Violation { + field: #field, rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed(#rule_id), + message: ::std::borrow::Cow::Borrowed(""), + for_key: false, + }); + } + } + }); + } + if !f.not_in.is_empty() { + let set = &f.not_in; + let field = field_path_expr.clone(); + let rule = rule_path_scalar( + fam.family, + fam.outer_number, + "not_in", + INNER_NOT_IN, + fam.scalar_ty, + ); + let rule_id = format!("{}.not_in", fam.family); + out.push(quote! { + { + const DISALLOWED: &[f32] = &[ #( #set ),* ]; + if DISALLOWED.iter().any(|c| *c == #v) { + violations.push(::protovalidate_buffa::Violation { + field: #field, rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed(#rule_id), + message: ::std::borrow::Cow::Borrowed(""), + for_key: false, + }); + } + } + }); + } + if f.finite { + push_cmp( + &mut out, + "finite", + INNER_FINITE, + format!("{}.finite", fam.family), + quote! { !::protovalidate_buffa::rules::float::is_finite_f32(#v) }, + ); + } } } FieldKind::Double => { if let Some(d) = &std.double { + if let (Some(lo), Some(hi)) = (d.gt, d.lt) { + out.push(range_check_fp( + &field_path_expr, + fam, + false, + "e! { #lo }, + "e! { #hi }, + &value, + hi < lo, + )); + return out; + } + if let (Some(lo), Some(hi)) = (d.gte, d.lte) { + out.push(range_check_fp( + &field_path_expr, + fam, + true, + "e! { #lo }, + "e! { #hi }, + &value, + hi < lo, + )); + return out; + } if let Some(c) = d.r#const { push_cmp( &mut out, @@ -4474,11 +4886,69 @@ pub(crate) fn emit_numeric_checks_on( quote! { !(#v <= #hi) }, ); } + if !d.in_set.is_empty() { + let set = &d.in_set; + let field = field_path_expr.clone(); + let rule = rule_path_scalar( + fam.family, + fam.outer_number, + "in", + INNER_IN, + fam.scalar_ty, + ); + let rule_id = format!("{}.in", fam.family); + out.push(quote! { + { + const ALLOWED: &[f64] = &[ #( #set ),* ]; + if !ALLOWED.iter().any(|c| *c == #v) { + violations.push(::protovalidate_buffa::Violation { + field: #field, rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed(#rule_id), + message: ::std::borrow::Cow::Borrowed(""), + for_key: false, + }); + } + } + }); + } + if !d.not_in.is_empty() { + let set = &d.not_in; + let field = field_path_expr.clone(); + let rule = rule_path_scalar( + fam.family, + fam.outer_number, + "not_in", + INNER_NOT_IN, + fam.scalar_ty, + ); + let rule_id = format!("{}.not_in", fam.family); + out.push(quote! { + { + const DISALLOWED: &[f64] = &[ #( #set ),* ]; + if DISALLOWED.iter().any(|c| *c == #v) { + violations.push(::protovalidate_buffa::Violation { + field: #field, rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed(#rule_id), + message: ::std::borrow::Cow::Borrowed(""), + for_key: false, + }); + } + } + }); + } + if d.finite { + push_cmp( + &mut out, + "finite", + INNER_FINITE, + format!("{}.finite", fam.family), + quote! { !::protovalidate_buffa::rules::float::is_finite_f64(#v) }, + ); + } } } _ => {} } - let _ = is_float; out } @@ -4586,7 +5056,7 @@ const INNER_IN: i32 = 6; const INNER_NOT_IN: i32 = 7; fn emit_bool( - accessor: &syn::Ident, + value: &TokenStream, name_lit: &str, field_number: i32, b: &BoolStandard, @@ -4596,13 +5066,13 @@ fn emit_bool( let field_path = field_path_scalar(name_lit, field_number, "Bool"); let rule_path = rule_path_scalar("bool", 13, "const", 1, "Bool"); out.push(quote! { - if self.#accessor != #c { + if #value != #c { violations.push(::protovalidate_buffa::Violation { field: #field_path, rule: #rule_path, rule_id: ::std::borrow::Cow::Borrowed("bool.const"), message: ::std::borrow::Cow::Owned(::std::format!( - "value must equal {} (got {})", #c, self.#accessor + "value must equal {} (got {})", #c, #value )), for_key: false, }); diff --git a/crates/protoc-gen-protovalidate-buffa/src/emit/mod.rs b/crates/protoc-gen-protovalidate-buffa/src/emit/mod.rs index c5053b9..318e81e 100644 --- a/crates/protoc-gen-protovalidate-buffa/src/emit/mod.rs +++ b/crates/protoc-gen-protovalidate-buffa/src/emit/mod.rs @@ -25,6 +25,7 @@ pub mod repeated; pub fn render(messages: &[MessageValidators]) -> Result> { use std::collections::BTreeMap; + let cel_set = cel::cel_value_set(messages); let mut by_file: BTreeMap> = BTreeMap::new(); for m in messages { by_file.entry(m.source_file.clone()).or_default().push(m); @@ -32,7 +33,7 @@ pub fn render(messages: &[MessageValidators]) -> Result> { let mut files = Vec::new(); for (source_file, msgs) in by_file { - let body = render_file(&msgs)?; + let body = render_file(&msgs, &cel_set)?; let stem = source_file.trim_end_matches(".proto").replace('/', "."); let path = format!("{stem}.rs"); let body_str = body.to_string(); @@ -50,13 +51,13 @@ pub fn render(messages: &[MessageValidators]) -> Result> { Ok(files) } -fn render_file(msgs: &[&MessageValidators]) -> Result { - // Build the set of proto names that need AsCelValue impls. - let cel_set = cel::cel_value_set(msgs.iter().copied()); - +fn render_file( + msgs: &[&MessageValidators], + cel_set: &std::collections::HashSet, +) -> Result { let impls: Vec = msgs .iter() - .map(|m| render_message(m, &cel_set)) + .map(|m| render_message(m, cel_set)) .collect::>()?; Ok(quote! { use super::*; diff --git a/crates/protoc-gen-protovalidate-buffa/src/emit/oneof.rs b/crates/protoc-gen-protovalidate-buffa/src/emit/oneof.rs index fd1f70f..8c17282 100644 --- a/crates/protoc-gen-protovalidate-buffa/src/emit/oneof.rs +++ b/crates/protoc-gen-protovalidate-buffa/src/emit/oneof.rs @@ -213,12 +213,19 @@ fn has_field_rules(f: &FieldValidator) -> bool { f.required || f.standard.string.is_some() || f.standard.bytes.is_some() + || f.standard.bool_rules.is_some() + || f.standard.enum_rules.is_some() || f.standard.int32.is_some() || f.standard.int64.is_some() || f.standard.uint32.is_some() || f.standard.uint64.is_some() || f.standard.float.is_some() || f.standard.double.is_some() + || f.standard.any_rules.is_some() + || f.standard.duration.is_some() + || f.standard.timestamp.is_some() + || f.standard.field_mask.is_some() + || !f.standard.predefined.is_empty() || !f.cel.is_empty() || matches!(f.field_type, FieldKind::Message { ref full_name } if !full_name.starts_with("google.protobuf.")) } @@ -253,6 +260,16 @@ fn emit_variant_arm(v: &OneofValidator, f: &FieldValidator) -> Result { + if let Some(b) = &f.standard.bytes { + checks.extend(crate::emit::field::emit_bytes_checks_on( + &val_ident, + name_lit, + f.field_number, + b, + )); + } + } FieldKind::Int32 | FieldKind::Sint32 | FieldKind::Sfixed32 @@ -273,6 +290,27 @@ fn emit_variant_arm(v: &OneofValidator, f: &FieldValidator) -> Result { + if let Some(b) = &f.standard.bool_rules { + checks.extend(crate::emit::field::emit_bool_checks_on( + &val_ident, + name_lit, + f.field_number, + b, + )); + } + } + FieldKind::Enum { full_name } => { + if let Some(e) = &f.standard.enum_rules { + checks.extend(crate::emit::field::emit_enum_checks_on( + &val_ident, + name_lit, + f.field_number, + e, + full_name, + )?); + } + } FieldKind::Message { full_name } if !full_name.starts_with("google.protobuf.") => { let fnum = f.field_number; let nlit = name_lit.clone(); @@ -292,9 +330,27 @@ fn emit_variant_arm(v: &OneofValidator, f: &FieldValidator) -> Result { + checks.extend(emit_oneof_wkt_checks(f, full_name, &val_ident)); + } + FieldKind::Wrapper(inner) => { + let inner_checks = + crate::emit::field::emit_wrapper_inner(name_lit, f.field_number, inner, f); + if !inner_checks.is_empty() { + checks.push(quote! { + { + let v = v.value.clone(); + #( #inner_checks )* + } + }); + } + } _ => {} } + checks.extend(emit_oneof_field_cel(v, f, &val_ident)); + checks.extend(emit_oneof_predefined(v, f, &val_ident)); + if checks.is_empty() { // Arm with no checks — return a wildcard so caller can decide. return Ok(quote! {}); @@ -336,6 +392,678 @@ fn emit_variant_arm(v: &OneofValidator, f: &FieldValidator) -> Result`. +fn emit_oneof_field_cel( + oneof: &OneofValidator, + f: &FieldValidator, + val_ident: &syn::Ident, +) -> Vec { + if f.cel.is_empty() { + return Vec::new(); + } + if let FieldKind::Message { full_name } = &f.field_type + && crate::emit::cel::is_unsupported_wkt_for_cel(full_name) + { + return Vec::new(); + } + + let field_path = oneof_field_path(f); + let this_expr = oneof_value_to_cel_expr(&f.field_type, val_ident); + let mut out = Vec::new(); + let mut cel_idx: u64 = 0; + let mut expr_idx: u64 = 0; + + for rule in &f.cel { + let ident = crate::emit::cel::const_ident( + &format!("{}_{}", oneof.parent_msg_name, f.field_name), + &rule.id, + ); + let id = &rule.id; + let message = &rule.message; + let expr = &rule.expression; + let (idx_lit, method) = if rule.is_cel_expression { + let i = expr_idx; + expr_idx += 1; + (i, format_ident!("eval_expr_value_at")) + } else { + let i = cel_idx; + cel_idx += 1; + (i, format_ident!("eval_value_at")) + }; + let fp = field_path.clone(); + let this = this_expr.clone(); + out.push(quote! { + static #ident: ::protovalidate_buffa::cel::CelConstraint = + ::protovalidate_buffa::cel::CelConstraint::new(#id, #message, #expr); + if let Err(viol) = #ident.#method(#this, #fp, #idx_lit) { + violations.push(viol); + } + }); + } + + out +} + +fn emit_oneof_predefined( + oneof: &OneofValidator, + f: &FieldValidator, + val_ident: &syn::Ident, +) -> Vec { + if f.standard.predefined.is_empty() { + return Vec::new(); + } + let default_family = crate::emit::cel::predef_family_for(&f.field_type, &f.standard); + let field_path = oneof_field_path(f); + let this_expr = oneof_value_to_cel_expr(&f.field_type, val_ident); + let mut out = Vec::new(); + + for (pi, rule) in f.standard.predefined.iter().enumerate() { + let family = match rule.family_override { + Some((name, number)) => crate::emit::cel::Family { name, number }, + None => match default_family { + Some(family) => family, + None => continue, + }, + }; + let ident = format_ident!( + "{}", + format!( + "CEL_{}_{}_PRED{}_{}_{}", + oneof + .parent_msg_name + .replace(|c: char| !c.is_ascii_alphanumeric(), "_") + .to_uppercase(), + f.field_name.to_uppercase(), + pi, + rule.ext_number, + rule.id + .replace(|c: char| !c.is_ascii_alphanumeric(), "_") + .to_uppercase(), + ) + ); + let id = &rule.id; + let message = &rule.message; + let expr = &rule.expression; + let family_name = family.name; + let family_num = family.number; + let ext_num = rule.ext_number; + let ext_fty = format_ident!("{}", rule.ext_field_type); + let ext_bracketed = format!("[buf.validate.conformance.cases.{}]", rule.ext_name); + let rule_value: TokenStream = syn::parse_str(&rule.rule_value_expr) + .unwrap_or_else(|_| quote! { ::protovalidate_buffa::cel_core::Value::Null }); + let rule_path = quote! { + ::protovalidate_buffa::FieldPath { + elements: ::std::vec![ + ::protovalidate_buffa::FieldPathElement { + field_number: Some(#family_num), + field_name: Some(::std::borrow::Cow::Borrowed(#family_name)), + field_type: Some(::protovalidate_buffa::FieldType::Message), + key_type: None, + value_type: None, + subscript: None, + }, + ::protovalidate_buffa::FieldPathElement { + field_number: Some(#ext_num), + field_name: Some(::std::borrow::Cow::Borrowed(#ext_bracketed)), + field_type: Some(::protovalidate_buffa::FieldType::#ext_fty), + key_type: None, + value_type: None, + subscript: None, + }, + ], + } + }; + let fp = field_path.clone(); + let this = this_expr.clone(); + out.push(quote! { + static #ident: ::protovalidate_buffa::cel::CelConstraint = + ::protovalidate_buffa::cel::CelConstraint::new(#id, #message, #expr); + if let Err(viol) = #ident.eval_predefined(#this, #rule_value, #fp, #rule_path) { + violations.push(viol); + } + }); + } + + out +} + +#[expect( + clippy::too_many_lines, + reason = "codegen helper mirrors the WKT rule families oneof members can carry" +)] +fn emit_oneof_wkt_checks( + f: &FieldValidator, + full_name: &str, + val_ident: &syn::Ident, +) -> Vec { + let mut out = Vec::new(); + let field_path = oneof_field_path(f); + + if full_name == "google.protobuf.Any" + && let Some(any) = &f.standard.any_rules + { + if !any.in_set.is_empty() { + let set = &any.in_set; + let field = &field_path; + let rule = oneof_rule_path("any", 20, "in", 2, "String"); + out.push(quote! { + { + const ALLOWED: &[&str] = &[ #( #set ),* ]; + if !ALLOWED.iter().any(|s| *s == #val_ident.type_url.as_str()) { + violations.push(::protovalidate_buffa::Violation { + field: #field, + rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed("any.in"), + message: ::std::borrow::Cow::Borrowed(""), + for_key: false, + }); + } + } + }); + } + if !any.not_in.is_empty() { + let set = &any.not_in; + let field = &field_path; + let rule = oneof_rule_path("any", 20, "not_in", 3, "String"); + out.push(quote! { + { + const DISALLOWED: &[&str] = &[ #( #set ),* ]; + if DISALLOWED.iter().any(|s| *s == #val_ident.type_url.as_str()) { + violations.push(::protovalidate_buffa::Violation { + field: #field, + rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed("any.not_in"), + message: ::std::borrow::Cow::Borrowed(""), + for_key: false, + }); + } + } + }); + } + } + + if full_name == "google.protobuf.FieldMask" + && let Some(field_mask) = &f.standard.field_mask + { + if let Some(expected) = &field_mask.r#const { + let expected_lits = expected.iter().map(String::as_str); + let message = format!("must equal paths [{}]", expected.join(", ")); + let field = &field_path; + let rule = oneof_rule_path("field_mask", 28, "const", 1, "Message"); + out.push(quote! { + { + const EXPECTED: &[&str] = &[ #( #expected_lits ),* ]; + let actual: ::std::vec::Vec<&str> = #val_ident.paths.iter().map(|s| s.as_str()).collect(); + let eq = actual.len() == EXPECTED.len() + && actual.iter().zip(EXPECTED.iter()).all(|(a, b)| a == b); + if !eq { + violations.push(::protovalidate_buffa::Violation { + field: #field, + rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed("field_mask.const"), + message: ::std::borrow::Cow::Borrowed(#message), + for_key: false, + }); + } + } + }); + } + if !field_mask.in_set.is_empty() { + let allowed = field_mask.in_set.iter().map(String::as_str); + let field = &field_path; + let rule = oneof_rule_path("field_mask", 28, "in", 2, "String"); + out.push(quote! { + { + const ALLOWED: &[&str] = &[ #( #allowed ),* ]; + let ok = #val_ident.paths.iter().all(|p| { + ALLOWED.iter().any(|c| ::protovalidate_buffa::rules::string::fieldmask_covers(c, p.as_str())) + }); + if !ok { + violations.push(::protovalidate_buffa::Violation { + field: #field, + rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed("field_mask.in"), + message: ::std::borrow::Cow::Borrowed(""), + for_key: false, + }); + } + } + }); + } + if !field_mask.not_in.is_empty() { + let denied = field_mask.not_in.iter().map(String::as_str); + let field = &field_path; + let rule = oneof_rule_path("field_mask", 28, "not_in", 3, "String"); + out.push(quote! { + { + const DENIED: &[&str] = &[ #( #denied ),* ]; + let bad = #val_ident.paths.iter().any(|p| { + DENIED.iter().any(|c| ::protovalidate_buffa::rules::string::fieldmask_covers(c, p.as_str()) + || ::protovalidate_buffa::rules::string::fieldmask_covers(p.as_str(), c)) + }); + if bad { + violations.push(::protovalidate_buffa::Violation { + field: #field, + rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed("field_mask.not_in"), + message: ::std::borrow::Cow::Borrowed(""), + for_key: false, + }); + } + } + }); + } + } + + if full_name == "google.protobuf.Duration" + && let Some(duration) = &f.standard.duration + { + out.extend(emit_oneof_duration_checks(duration, val_ident, &field_path)); + } + + if full_name == "google.protobuf.Timestamp" + && let Some(timestamp) = &f.standard.timestamp + { + out.extend(emit_oneof_timestamp_checks( + timestamp, + val_ident, + &field_path, + )); + } + + out +} + +fn oneof_rule_path( + outer_name: &str, + outer_number: i32, + inner_name: &str, + inner_number: i32, + inner_type_variant: &str, +) -> TokenStream { + let inner_ty = format_ident!("{}", inner_type_variant); + quote! { + ::protovalidate_buffa::FieldPath { + elements: ::std::vec![ + ::protovalidate_buffa::FieldPathElement { + field_number: Some(#outer_number), + field_name: Some(::std::borrow::Cow::Borrowed(#outer_name)), + field_type: Some(::protovalidate_buffa::FieldType::Message), + key_type: None, + value_type: None, + subscript: None, + }, + ::protovalidate_buffa::FieldPathElement { + field_number: Some(#inner_number), + field_name: Some(::std::borrow::Cow::Borrowed(#inner_name)), + field_type: Some(::protovalidate_buffa::FieldType::#inner_ty), + key_type: None, + value_type: None, + subscript: None, + }, + ], + } + } +} + +fn duration_nanos_literal(value: (i64, i32)) -> TokenStream { + let total = (value.0 as i128) * 1_000_000_000 + (value.1 as i128); + let total_i128 = proc_macro2::Literal::i128_suffixed(total); + quote! { #total_i128 } +} + +const fn duration_nanos_lt(left: (i64, i32), right: (i64, i32)) -> bool { + let left_ns = (left.0 as i128) * 1_000_000_000 + (left.1 as i128); + let right_ns = (right.0 as i128) * 1_000_000_000 + (right.1 as i128); + left_ns < right_ns +} + +fn emit_oneof_duration_checks( + duration: &crate::scan::DurationStandard, + val_ident: &syn::Ident, + field_path: &TokenStream, +) -> Vec { + let actual = + quote! { ((#val_ident.seconds as i128) * 1_000_000_000 + (#val_ident.nanos as i128)) }; + emit_oneof_time_checks( + "duration", + 21, + duration.r#const, + duration.lt, + duration.lte, + duration.gt, + duration.gte, + &duration.in_set, + &duration.not_in, + None, + None, + None, + &actual, + field_path, + ) +} + +fn emit_oneof_timestamp_checks( + timestamp: &crate::scan::TimestampStandard, + val_ident: &syn::Ident, + field_path: &TokenStream, +) -> Vec { + let actual = + quote! { ((#val_ident.seconds as i128) * 1_000_000_000 + (#val_ident.nanos as i128)) }; + emit_oneof_time_checks( + "timestamp", + 22, + timestamp.r#const, + timestamp.lt, + timestamp.lte, + timestamp.gt, + timestamp.gte, + &[], + &[], + Some(timestamp.lt_now), + Some(timestamp.gt_now), + timestamp.within, + &actual, + field_path, + ) +} + +#[expect( + clippy::too_many_arguments, + reason = "shared codegen helper for DurationRules and TimestampRules" +)] +fn emit_oneof_time_checks( + family: &'static str, + family_number: i32, + const_value: Option<(i64, i32)>, + lt: Option<(i64, i32)>, + lte: Option<(i64, i32)>, + gt: Option<(i64, i32)>, + gte: Option<(i64, i32)>, + in_set: &[(i64, i32)], + not_in: &[(i64, i32)], + lt_now: Option, + gt_now: Option, + within: Option<(i64, i32)>, + actual: &TokenStream, + field_path: &TokenStream, +) -> Vec { + let mut out = Vec::new(); + let push = |out: &mut Vec, + inner: &str, + inner_num: i32, + inner_ty: &str, + rule_id: &'static str, + cond: TokenStream| { + let field = field_path.clone(); + let rule = oneof_rule_path(family, family_number, inner, inner_num, inner_ty); + out.push(quote! { + if #cond { + violations.push(::protovalidate_buffa::Violation { + field: #field, + rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed(#rule_id), + message: ::std::borrow::Cow::Borrowed(""), + for_key: false, + }); + } + }); + }; + + if let (Some(lo), Some(hi)) = (gt, lt) { + let lo_ns = duration_nanos_literal(lo); + let hi_ns = duration_nanos_literal(hi); + let is_exclusive = duration_nanos_lt(hi, lo); + let rule_id = if family == "duration" { + if is_exclusive { + "duration.gt_lt_exclusive" + } else { + "duration.gt_lt" + } + } else if is_exclusive { + "timestamp.gt_lt_exclusive" + } else { + "timestamp.gt_lt" + }; + let cond = if is_exclusive { + quote! { #actual >= #hi_ns && #actual <= #lo_ns } + } else { + quote! { #actual <= #lo_ns || #actual >= #hi_ns } + }; + push(&mut out, "gt", 5, "Message", rule_id, cond); + return out; + } + if let (Some(lo), Some(hi)) = (gte, lte) { + let lo_ns = duration_nanos_literal(lo); + let hi_ns = duration_nanos_literal(hi); + let is_exclusive = duration_nanos_lt(hi, lo); + let rule_id = if family == "duration" { + if is_exclusive { + "duration.gte_lte_exclusive" + } else { + "duration.gte_lte" + } + } else if is_exclusive { + "timestamp.gte_lte_exclusive" + } else { + "timestamp.gte_lte" + }; + let cond = if is_exclusive { + quote! { #actual > #hi_ns && #actual < #lo_ns } + } else { + quote! { #actual < #lo_ns || #actual > #hi_ns } + }; + push(&mut out, "gte", 6, "Message", rule_id, cond); + return out; + } + + if let Some(value) = const_value { + let expected = duration_nanos_literal(value); + let rule_id = if family == "duration" { + "duration.const" + } else { + "timestamp.const" + }; + push( + &mut out, + "const", + 2, + "Message", + rule_id, + quote! { #actual != #expected }, + ); + } + if let Some(value) = lt { + let bound = duration_nanos_literal(value); + let rule_id = if family == "duration" { + "duration.lt" + } else { + "timestamp.lt" + }; + push( + &mut out, + "lt", + 3, + "Message", + rule_id, + quote! { #actual >= #bound }, + ); + } + if let Some(value) = lte { + let bound = duration_nanos_literal(value); + let rule_id = if family == "duration" { + "duration.lte" + } else { + "timestamp.lte" + }; + push( + &mut out, + "lte", + 4, + "Message", + rule_id, + quote! { #actual > #bound }, + ); + } + if let Some(value) = gt { + let bound = duration_nanos_literal(value); + let rule_id = if family == "duration" { + "duration.gt" + } else { + "timestamp.gt" + }; + push( + &mut out, + "gt", + 5, + "Message", + rule_id, + quote! { #actual <= #bound }, + ); + } + if let Some(value) = gte { + let bound = duration_nanos_literal(value); + let rule_id = if family == "duration" { + "duration.gte" + } else { + "timestamp.gte" + }; + push( + &mut out, + "gte", + 6, + "Message", + rule_id, + quote! { #actual < #bound }, + ); + } + if !in_set.is_empty() { + let values: Vec = in_set.iter().copied().map(duration_nanos_literal).collect(); + let field = field_path.clone(); + let rule = oneof_rule_path(family, family_number, "in", 7, "Message"); + out.push(quote! { + { + const ALLOWED: &[i128] = &[ #( #values ),* ]; + if !ALLOWED.iter().any(|x| *x == #actual) { + violations.push(::protovalidate_buffa::Violation { + field: #field, + rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed("duration.in"), + message: ::std::borrow::Cow::Borrowed(""), + for_key: false, + }); + } + } + }); + } + if !not_in.is_empty() { + let values: Vec = not_in.iter().copied().map(duration_nanos_literal).collect(); + let field = field_path.clone(); + let rule = oneof_rule_path(family, family_number, "not_in", 8, "Message"); + out.push(quote! { + { + const DISALLOWED: &[i128] = &[ #( #values ),* ]; + if DISALLOWED.iter().any(|x| *x == #actual) { + violations.push(::protovalidate_buffa::Violation { + field: #field, + rule: #rule, + rule_id: ::std::borrow::Cow::Borrowed("duration.not_in"), + message: ::std::borrow::Cow::Borrowed(""), + for_key: false, + }); + } + } + }); + } + if lt_now == Some(true) { + push( + &mut out, + "lt_now", + 7, + "Bool", + "timestamp.lt_now", + quote! { + #actual >= ::std::time::SystemTime::now() + .duration_since(::std::time::UNIX_EPOCH) + .map_or(0i128, |d| d.as_nanos() as i128) + }, + ); + } + if gt_now == Some(true) { + push( + &mut out, + "gt_now", + 8, + "Bool", + "timestamp.gt_now", + quote! { + #actual <= ::std::time::SystemTime::now() + .duration_since(::std::time::UNIX_EPOCH) + .map_or(0i128, |d| d.as_nanos() as i128) + }, + ); + } + if let Some(value) = within { + let bound = duration_nanos_literal(value); + push( + &mut out, + "within", + 9, + "Message", + "timestamp.within", + quote! { + { + let now_ns = ::std::time::SystemTime::now() + .duration_since(::std::time::UNIX_EPOCH) + .map_or(0i128, |d| d.as_nanos() as i128); + (#actual - now_ns).abs() > #bound + } + }, + ); + } + + out +} + +fn oneof_field_path(f: &FieldValidator) -> TokenStream { + let field_name = &f.field_name; + let field_number = f.field_number; + let field_type = if f.is_group { + "Group" + } else { + crate::emit::field::kind_to_field_type(&f.field_type) + }; + let field_type_ident = format_ident!("{}", field_type); + quote! { + ::protovalidate_buffa::FieldPath { + elements: ::std::vec![ + ::protovalidate_buffa::FieldPathElement { + field_number: Some(#field_number), + field_name: Some(::std::borrow::Cow::Borrowed(#field_name)), + field_type: Some(::protovalidate_buffa::FieldType::#field_type_ident), + key_type: None, + value_type: None, + subscript: None, + }, + ], + } + } +} + +fn oneof_value_to_cel_expr(kind: &FieldKind, val_ident: &syn::Ident) -> TokenStream { + match kind { + FieldKind::String | FieldKind::Bytes | FieldKind::Enum { .. } => { + quote! { ::protovalidate_buffa::cel::to_cel_value(#val_ident) } + } + FieldKind::Message { .. } => { + quote! { ::protovalidate_buffa::cel::AsCelValue::as_cel_value(#val_ident.as_ref()) } + } + FieldKind::Wrapper(_) => { + quote! { ::protovalidate_buffa::cel::to_cel_value(&#val_ident.value) } + } + _ => quote! { ::protovalidate_buffa::cel::to_cel_value(&#val_ident) }, + } +} + fn to_snake_case(s: &str) -> String { let chars: Vec = s.chars().collect(); let mut out = String::with_capacity(s.len() + 2); diff --git a/crates/protoc-gen-protovalidate-buffa/src/scan.rs b/crates/protoc-gen-protovalidate-buffa/src/scan.rs index 0cd2985..ec51b76 100644 --- a/crates/protoc-gen-protovalidate-buffa/src/scan.rs +++ b/crates/protoc-gen-protovalidate-buffa/src/scan.rs @@ -1270,7 +1270,8 @@ fn gather_message( .iter() .find_map(|f| check_rule_field_mismatch(&f.field_type, &f.standard)) .or_else(|| check_message_oneof_specs(&qualified_name, &message_oneofs, &field_rules_out)) - .or_else(|| check_message_cel_missing_fields(&message_cel, &field_rules_out)); + .or_else(|| check_message_cel_missing_fields(&message_cel, &field_rules_out)) + .or_else(|| check_message_cel_type_errors(&message_cel, &field_rules_out)); out.push(MessageValidators { proto_name: qualified_name.clone(), @@ -1414,6 +1415,63 @@ fn check_message_cel_missing_fields(cels: &[CelRule], fields: &[FieldValidator]) None } +/// Detect simple schema/type CEL compile errors that `cel-interpreter` only +/// reports when the expression is evaluated against a concrete runtime value. +fn check_message_cel_type_errors(cels: &[CelRule], fields: &[FieldValidator]) -> Option { + let kinds: std::collections::HashMap<&str, &FieldKind> = fields + .iter() + .map(|f| (f.field_name.as_str(), &f.field_type)) + .collect(); + for rule in cels { + let bytes = rule.expression.as_bytes(); + let mut i = 0; + while i + 5 <= bytes.len() { + if &bytes[i..i + 5] == b"this." { + if i > 0 && (bytes[i - 1].is_ascii_alphanumeric() || bytes[i - 1] == b'_') { + i += 1; + continue; + } + let mut j = i + 5; + let start = j; + while j < bytes.len() && (bytes[j].is_ascii_alphanumeric() || bytes[j] == b'_') { + j += 1; + } + if j > start { + let ident = std::str::from_utf8(&bytes[start..j]).unwrap_or(""); + if has_string_method_call(&bytes[j..]) + && let Some(kind) = kinds.get(ident) + { + let underlying = cel_underlying_kind(kind); + if !matches!(underlying, FieldKind::String) { + return Some(format!( + "expression incorrectly treats an {} field as a string", + kind_family_name(underlying) + )); + } + } + } + i = j; + } else { + i += 1; + } + } + } + None +} + +fn has_string_method_call(bytes: &[u8]) -> bool { + bytes.starts_with(b".startsWith(") + || bytes.starts_with(b".endsWith(") + || bytes.starts_with(b".matches(") +} + +fn cel_underlying_kind(kind: &FieldKind) -> &FieldKind { + match kind { + FieldKind::Optional(inner) | FieldKind::Wrapper(inner) => cel_underlying_kind(inner), + other => other, + } +} + fn check_message_oneof_specs( msg_fqn: &str, specs: &[MessageOneofSpec], diff --git a/crates/protovalidate-buffa/src/cel.rs b/crates/protovalidate-buffa/src/cel.rs index dd683a8..8738a53 100644 --- a/crates/protovalidate-buffa/src/cel.rs +++ b/crates/protovalidate-buffa/src/cel.rs @@ -87,6 +87,12 @@ impl ToCelValue for Vec { } } +impl ToCelValue for buffa::bytes::Bytes { + fn to_cel_value(&self) -> Value { + Value::Bytes(self.to_vec().into()) + } +} + impl ToCelValue for Option { fn to_cel_value(&self) -> Value { self.as_ref().map_or(Value::Null, AsCelValue::as_cel_value) @@ -120,10 +126,25 @@ impl ToCelValue for buffa::MessageField { // `google.protobuf.Duration` — including the predefined-rule path, which binds // `this` to the field value via `AsCelValue`. // -// `FieldMask` is exposed as a CEL map with one entry, `paths`, so expressions -// like `this.paths.all(p, ...)` work. `Timestamp` and `Duration` are exposed -// as native CEL `Timestamp` / `Duration` values, which carry their own -// comparison and arithmetic operators in CEL. +// `Any`, `Empty`, and `FieldMask` are exposed as CEL maps. `Timestamp` and +// `Duration` are exposed as native CEL `Timestamp` / `Duration` values, which +// carry their own comparison and arithmetic operators in CEL. + +impl AsCelValue for buffa_types::google::protobuf::Any { + fn as_cel_value(&self) -> Value { + let mut map: std::collections::HashMap = + std::collections::HashMap::with_capacity(2); + map.insert("type_url".to_string(), self.type_url.to_cel_value()); + map.insert("value".to_string(), self.value.to_cel_value()); + Value::Map(map.into()) + } +} + +impl AsCelValue for buffa_types::google::protobuf::Empty { + fn as_cel_value(&self) -> Value { + Value::Map(std::collections::HashMap::::new().into()) + } +} impl AsCelValue for buffa_types::google::protobuf::FieldMask { fn as_cel_value(&self) -> Value { @@ -154,6 +175,26 @@ impl AsCelValue for buffa_types::google::protobuf::Duration { } } +macro_rules! impl_to_cel_for_as_cel_wkt { + ($($ty:path),* $(,)?) => { + $( + impl ToCelValue for $ty { + fn to_cel_value(&self) -> Value { + self.as_cel_value() + } + } + )* + }; +} + +impl_to_cel_for_as_cel_wkt!( + buffa_types::google::protobuf::Any, + buffa_types::google::protobuf::Empty, + buffa_types::google::protobuf::FieldMask, + buffa_types::google::protobuf::Timestamp, + buffa_types::google::protobuf::Duration, +); + impl ToCelValue for buffa::EnumValue { fn to_cel_value(&self) -> Value { Value::Int(i64::from(self.to_i32())) diff --git a/crates/protovalidate-buffa/tests/wkt_cel.rs b/crates/protovalidate-buffa/tests/wkt_cel.rs index eaa037c..55df96e 100644 --- a/crates/protovalidate-buffa/tests/wkt_cel.rs +++ b/crates/protovalidate-buffa/tests/wkt_cel.rs @@ -1,11 +1,36 @@ -//! `AsCelValue` impls for `google.protobuf.FieldMask`, `Timestamp`, and -//! `Duration`. Each test compiles and runs a real CEL expression against the -//! WKT, exercising the same `CelConstraint::eval` path used by -//! plugin-emitted field-level predefined rules. +//! `AsCelValue` impls for supported `google.protobuf.*` well-known types. +//! Each test compiles and runs a real CEL expression against the WKT, +//! exercising the same `CelConstraint::eval` path used by plugin-emitted CEL +//! rules. -use buffa_types::google::protobuf::{Duration, FieldMask, Timestamp}; +use buffa_types::google::protobuf::{Any, Duration, Empty, FieldMask, Timestamp}; use protovalidate_buffa::cel::CelConstraint; +#[test] +fn any_exposes_type_url() { + static RULE: CelConstraint = CelConstraint::new( + "test.any.type_url", + "Any type URL must match", + "this.type_url == 'type.googleapis.com/example.Widget'", + ); + + let any = Any { + type_url: "type.googleapis.com/example.Widget".to_string(), + ..Any::default() + }; + + RULE.eval(&any).expect("matching type_url should pass"); +} + +#[test] +fn empty_is_empty_map() { + static RULE: CelConstraint = + CelConstraint::new("test.empty.map", "Empty has no fields", "size(this) == 0"); + + RULE.eval(&Empty::default()) + .expect("Empty should expose an empty CEL map"); +} + #[test] fn field_mask_paths_all_true() { static RULE: CelConstraint = CelConstraint::new(