diff --git a/Cargo.lock b/Cargo.lock index 94701a5..1ce5f26 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,7 @@ dependencies = [ "heck", "proc-macro2", "quote", + "serde", "syn", ] diff --git a/benzina-derive/Cargo.toml b/benzina-derive/Cargo.toml index 66f558b..e7f6d95 100644 --- a/benzina-derive/Cargo.toml +++ b/benzina-derive/Cargo.toml @@ -13,7 +13,7 @@ rust-version.workspace = true proc-macro = true [package.metadata.docs.rs] -features = ["postgres", "mysql"] +features = ["postgres", "mysql", "json"] [dependencies] proc-macro2 = "1.0.94" @@ -23,10 +23,12 @@ heck = "0.5.0" [dev-dependencies] diesel = { version = "2", default-features = false, features = ["postgres", "mysql_backend"] } +serde = { version = "1", features = ["derive"] } [features] postgres = [] mysql = [] +json = [] [lints] workspace = true diff --git a/benzina-derive/src/enum_derive.rs b/benzina-derive/src/enum_derive.rs index 30d4e79..1037f6d 100644 --- a/benzina-derive/src/enum_derive.rs +++ b/benzina-derive/src/enum_derive.rs @@ -25,79 +25,147 @@ pub(crate) struct Enum { ident: Ident, sql_type: Type, rename_all: RenameRule, - crate_name: Option, variants: Vec, + + #[cfg(all(feature = "postgres", feature = "json"))] + table: Option, + #[cfg(all(feature = "postgres", feature = "json"))] + column: Option, + #[cfg(all(feature = "postgres", feature = "json"))] + data_column: Option, + + crate_name: Option, } struct EnumVariant { original_name: String, + original_name_span: Span, rename: Option, + #[cfg(all(feature = "postgres", feature = "json"))] + has_payload: bool, + crate_name: Option, - span: Span, } impl Enum { + #[expect(clippy::too_many_lines)] pub(crate) fn parse(input: DeriveInput) -> Result { let Data::Enum(e) = input.data else { fail!(input, "`benzina::Enum` macro available only for enums"); }; - let (rename_all, sql_type, crate_name) = { - let mut first_attr = None; - let mut sql_type = None; - let mut rename_all = None; - let mut crate_name = None; - - for attr in input - .attrs - .iter() - .filter(|attr| attr.path().is_ident("benzina")) - { - first_attr.get_or_insert(attr); - - attr.parse_nested_meta(|meta| { - if meta.path.is_ident("sql_type") { + let mut first_attr = None; + let mut sql_type = None; + let mut rename_all = None; + #[cfg(all(feature = "postgres", feature = "json"))] + let mut table = None; + #[cfg(all(feature = "postgres", feature = "json"))] + let mut column = None; + #[cfg(all(feature = "postgres", feature = "json"))] + let mut data_column = None; + let mut crate_name = None; + + for attr in input + .attrs + .iter() + .filter(|attr| attr.path().is_ident("benzina")) + { + first_attr.get_or_insert(attr); + + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("sql_type") { + meta.input.parse::()?; + let val: Type = meta.input.parse()?; + try_set!(sql_type, val, val); + } else if meta.path.is_ident("rename_all") { + meta.input.parse::()?; + let val: LitStr = meta.input.parse()?; + try_set!( + rename_all, + val.value() + .parse() + .map_err(|err| syn::Error::new_spanned(val, err))?, + val + ); + } else if meta.path.is_ident("table") { + #[cfg(all(feature = "postgres", feature = "json"))] + { meta.input.parse::()?; - let val: Type = meta.input.parse()?; - try_set!(sql_type, val, val); - } else if meta.path.is_ident("rename_all") { + let val: Path = meta.input.parse()?; + try_set!(table, val, val); + } + #[cfg(not(all(feature = "postgres", feature = "json")))] + { + let _ = meta.input.parse::()?; + let _: Path = meta.input.parse()?; + } + } else if meta.path.is_ident("column") { + #[cfg(all(feature = "postgres", feature = "json"))] + { meta.input.parse::()?; - let val: LitStr = meta.input.parse()?; - try_set!( - rename_all, - val.value() - .parse() - .map_err(|err| syn::Error::new_spanned(val, err))?, - val - ); - } else if meta.path.is_ident("crate") { + let val: Ident = meta.input.parse()?; + try_set!(column, val, val); + } + #[cfg(not(all(feature = "postgres", feature = "json")))] + { + let _ = meta.input.parse::()?; + let _: Ident = meta.input.parse()?; + } + } else if meta.path.is_ident("data_column") { + #[cfg(all(feature = "postgres", feature = "json"))] + { meta.input.parse::()?; - let val: Path = meta.input.parse()?; - try_set!(crate_name, val, val); + let val: Ident = meta.input.parse()?; + try_set!(data_column, val, val); } + #[cfg(not(all(feature = "postgres", feature = "json")))] + { + let _ = meta.input.parse::()?; + let _: Ident = meta.input.parse()?; + } + } else if meta.path.is_ident("crate") { + meta.input.parse::()?; + let val: Path = meta.input.parse()?; + try_set!(crate_name, val, val); + } - Ok(()) - })?; - } - - let Some(first_attr) = first_attr else { - fail!(e.enum_token, "expected #[benzina(...)] attribute"); - }; + Ok(()) + })?; + } - let Some(sql_type) = sql_type else { - fail!(first_attr, "expected `sql_type`"); - }; + let Some(first_attr) = first_attr else { + fail!(e.enum_token, "expected #[benzina(...)] attribute"); + }; - (rename_all.unwrap_or(RenameRule::None), sql_type, crate_name) + let Some(sql_type) = sql_type else { + fail!(first_attr, "expected `sql_type`"); }; + let rename_all = rename_all.unwrap_or(RenameRule::None); + let variants = e .variants .into_iter() .map(|variant| { - if !matches!(variant.fields, Fields::Unit) { - fail!(variant, "only unit variants are supported"); - } + let has_payload = match &variant.fields { + Fields::Unit => false, + #[cfg(all(feature = "postgres", feature = "json"))] + Fields::Unnamed(fields) => { + let mut fields = fields.unnamed.iter(); + if !matches!((fields.next(), fields.next()), (Some(_),None)){ + fail!(variant, "only single-item variants are supported"); + } + + true + } + #[cfg(not(all(feature = "postgres", feature = "json")))] + Fields::Unnamed(_fields) => { + fail!(variant, "fields require both the `postgres` and the `json` feature to be enabled"); + } + Fields::Named(_fields) => { + fail!(variant, "only unit an unnamed variants are supported"); + } + }; let name = variant.ident.to_string(); let mut rename = None; @@ -118,12 +186,20 @@ impl Enum { })?; } - let span = variant.span(); + // Suppress build breakage when building without the + // PostgreSQL JSON feature. + #[cfg(not(all(feature = "postgres", feature = "json")))] + let _ = has_payload; + + let original_name_span = variant.span(); Ok(EnumVariant { original_name: name, + original_name_span, rename, + #[cfg(all(feature = "postgres", feature = "json"))] + has_payload, + crate_name: crate_name.clone(), - span, }) }) .collect::, syn::Error>>()?; @@ -131,173 +207,452 @@ impl Enum { ident: input.ident, sql_type, rename_all, - crate_name, variants, + + #[cfg(all(feature = "postgres", feature = "json"))] + table, + #[cfg(all(feature = "postgres", feature = "json"))] + column, + #[cfg(all(feature = "postgres", feature = "json"))] + data_column, + + crate_name, }) } + + #[cfg(all(feature = "postgres", feature = "json"))] + fn has_json_fields(&self) -> bool { + self.variants.iter().any(|variant| variant.has_payload) + } + + #[cfg(not(all(feature = "postgres", feature = "json")))] + #[expect( + clippy::unused_self, + reason = "kept for compatibility with the above implementation" + )] + fn has_json_fields(&self) -> bool { + false + } } impl ToTokens for Enum { + #[expect(clippy::too_many_lines)] fn to_tokens(&self, tokens: &mut TokenStream) { let Self { ident, sql_type, rename_all, - crate_name, variants, + + #[cfg(all(feature = "postgres", feature = "json"))] + table: _, + #[cfg(all(feature = "postgres", feature = "json"))] + column: _, + #[cfg(all(feature = "postgres", feature = "json"))] + data_column: _, + + crate_name, } = &self; let crate_name = crate::crate_name(crate_name); + let has_json_fields = self.has_json_fields(); + let impls_ident = Ident::new(&format!("{ident}Kind"), ident.span()); + + let as_expression = quote! { + #[automatically_derived] + impl #crate_name::__private::diesel::expression::AsExpression<#sql_type> for #ident { + type Expression = #crate_name::__private::diesel::internal::derives::as_expression::Bound< + #sql_type, + Self, + >; + + fn as_expression(self) -> Self::Expression { + #crate_name::__private::diesel::internal::derives::as_expression::Bound::new(self) + } + } + + #[automatically_derived] + impl<'__expr> #crate_name::__private::diesel::expression::AsExpression<#sql_type> for &'__expr #ident { + type Expression = #crate_name::__private::diesel::internal::derives::as_expression::Bound< + #sql_type, + Self, + >; + + fn as_expression(self) -> Self::Expression { + #crate_name::__private::diesel::internal::derives::as_expression::Bound::new(self) + } + } + + #[automatically_derived] + impl<'__expr, '__expr2> #crate_name::__private::diesel::expression::AsExpression<#sql_type> for &'__expr2 &'__expr #ident { + type Expression = #crate_name::__private::diesel::internal::derives::as_expression::Bound< + #sql_type, + Self, + >; + + fn as_expression(self) -> Self::Expression { + #crate_name::__private::diesel::internal::derives::as_expression::Bound::new(self) + } + } + }; + let from_bytes_arms = variants .iter() - .map(|variant| variant.gen_from_bytes(*rename_all)); - let to_str_arms = variants + .map(|variant| variant.gen_from_bytes(has_json_fields, *rename_all)) + .collect::>(); + #[cfg(feature = "postgres")] + let to_byte_str_arms = variants .iter() - .map(|variant| variant.gen_to_str(*rename_all)); + .map(|variant| variant.gen_to_byte_str(has_json_fields, *rename_all)) + .collect::>(); - tokens.append_all(quote! { - impl #ident { - #[doc(hidden)] - fn __benzina04_from_bytes(val: &[u8]) -> #crate_name::__private::std::option::Option { - match val { - #(#from_bytes_arms)* - _ => #crate_name::__private::std::option::Option::None, + #[cfg(feature = "postgres")] + let (queryable_sql_type, queryable_row_type, queryable_impl) = if self.has_json_fields() { + #[cfg(all(feature = "postgres", feature = "json"))] + { + let from_queryable_arms = variants + .iter() + .map(|variant| variant.gen_from_queryable(&impls_ident)); + + ( + quote! { (#sql_type, #crate_name::__private::diesel::pg::sql_types::Jsonb) }, + quote! { (#impls_ident, #crate_name::__private::json::RawJsonb) }, + quote! { + match row.0 { + #(#from_queryable_arms)* + } + }, + ) + } + + #[cfg(not(all(feature = "postgres", feature = "json")))] + unreachable!() + } else { + ( + quote! { #sql_type }, + quote! { Self }, + quote! { #crate_name::__private::std::result::Result::Ok(row) }, + ) + }; + + #[cfg(feature = "postgres")] + let postgres_from_to_sql = if has_json_fields { + quote! {} + } else { + quote! { + #[automatically_derived] + impl #crate_name::__private::diesel::deserialize::FromSql<#sql_type, #crate_name::__private::diesel::pg::Pg> for #ident { + fn from_sql(bytes: #crate_name::__private::diesel::pg::PgValue<'_>) -> #crate_name::__private::diesel::deserialize::Result { + match bytes.as_bytes() { + #(#from_bytes_arms)* + _ => { + #crate_name::__private::std::result::Result::Err( + #crate_name::__private::std::convert::Into::into( + "Unrecognized enum variant" + ) + ) + }, + } } } - #[doc(hidden)] - fn __benzina04_as_str(&self) -> &'static str { - match self { - #(#to_str_arms)* + #[automatically_derived] + impl #crate_name::__private::diesel::serialize::ToSql<#sql_type, #crate_name::__private::diesel::pg::Pg> for #ident { + fn to_sql<'b>(&'b self, out: &mut #crate_name::__private::diesel::serialize::Output<'b, '_, #crate_name::__private::diesel::pg::Pg>) -> #crate_name::__private::diesel::serialize::Result { + let s: &[u8] = match self { + #(#to_byte_str_arms)* + }; + #crate_name::__private::std::io::Write::write_all(out, s)?; + + #crate_name::__private::std::result::Result::Ok( + #crate_name::__private::diesel::serialize::IsNull::No + ) } } } - }); + }; #[cfg(feature = "postgres")] - tokens.append_all(quote! { + let postgres = quote! { #[automatically_derived] - impl #crate_name::__private::diesel::deserialize::FromSql<#sql_type, #crate_name::__private::diesel::pg::Pg> for #ident { - fn from_sql(bytes: #crate_name::__private::diesel::pg::PgValue<'_>) -> #crate_name::__private::diesel::deserialize::Result { - match Self::__benzina04_from_bytes(bytes.as_bytes()) { - #crate_name::__private::std::option::Option::Some(this) => { - #crate_name::__private::std::result::Result::Ok(this) - }, - #crate_name::__private::std::option::Option::None => { - #crate_name::__private::std::result::Result::Err( - #crate_name::__private::std::convert::Into::into( - "Unrecognized enum variant" - ) - ) - }, - } + impl #crate_name::__private::diesel::deserialize::Queryable<#queryable_sql_type, #crate_name::__private::diesel::pg::Pg> for #ident { + type Row = #queryable_row_type; + + fn build(row: Self::Row) -> #crate_name::__private::diesel::deserialize::Result { + #queryable_impl } } - #[automatically_derived] - impl #crate_name::__private::diesel::serialize::ToSql<#sql_type, #crate_name::__private::diesel::pg::Pg> for #ident { - fn to_sql<'b>(&'b self, out: &mut #crate_name::__private::diesel::serialize::Output<'b, '_, #crate_name::__private::diesel::pg::Pg>) -> #crate_name::__private::diesel::serialize::Result { - let sql_val = self.__benzina04_as_str(); - #crate_name::__private::std::io::Write::write_all(out, sql_val.as_bytes())?; - - #crate_name::__private::std::result::Result::Ok( - #crate_name::__private::diesel::serialize::IsNull::No + #postgres_from_to_sql + }; + #[cfg(not(feature = "postgres"))] + let postgres = quote! {}; + + #[cfg(all(feature = "postgres", feature = "json"))] + let postgres_extra = if self.has_json_fields() { + let entries = self.variants.iter().map(|variant| { + let original_name_ident = variant.original_name(); + quote! { + #original_name_ident, + } + }); + + let impls_enum = Self { + ident: impls_ident.clone(), + sql_type: self.sql_type.clone(), + rename_all: self.rename_all, + variants: self + .variants + .iter() + .map( + |EnumVariant { + original_name, + original_name_span, + rename, + has_payload: _, + crate_name, + }| EnumVariant { + original_name: original_name.clone(), + original_name_span: *original_name_span, + rename: rename.clone(), + has_payload: false, + crate_name: crate_name.clone(), + }, ) + .collect(), + table: None, + column: None, + data_column: None, + crate_name: self.crate_name.clone(), + }; + let selectable_insertable_impl = if let (Some(table), Some(column), Some(data_column)) = + (&self.table, &self.column, &self.data_column) + { + let to_insertable_arms = self + .variants + .iter() + .map(|variant| variant.gen_to_insertable(ident, &impls_ident)); + + quote! { + #[automatically_derived] + impl #crate_name::__private::diesel::expression::Selectable<#crate_name::__private::diesel::pg::Pg> for #ident { + type SelectExpression = (#table::#column, #table::#data_column); + + fn construct_selection() -> Self::SelectExpression { + (#table::#column, #table::#data_column) + } + } + + #[automatically_derived] + impl<'__ins> #crate_name::__private::diesel::Insertable<#table::table> for &'__ins #ident { + type Values = <( + #crate_name::__private::diesel::dsl::Eq<#table::#column, #impls_ident>, + #crate_name::__private::diesel::dsl::Eq<#table::#data_column, #crate_name::__private::json::RawJsonb>, + ) as #crate_name::__private::diesel::Insertable<#table::table>>::Values; + + fn values(self) -> Self::Values { + use #crate_name::__private::diesel::ExpressionMethods; + let (kind, data) = match self { + #(#to_insertable_arms)* + }; + #crate_name::__private::diesel::Insertable::values(( + #table::#column.eq(kind), + #table::#data_column.eq(data), + )) + } + } + + #[automatically_derived] + impl #crate_name::__private::diesel::Insertable<#table::table> for #ident { + type Values = <( + #crate_name::__private::diesel::dsl::Eq<#table::#column, #impls_ident>, + #crate_name::__private::diesel::dsl::Eq<#table::#data_column, #crate_name::__private::json::RawJsonb>, + ) as #crate_name::__private::diesel::Insertable<#table::table>>::Values; + + fn values(self) -> Self::Values { + #crate_name::__private::diesel::Insertable::values(&self) + } + } + } + } else { + quote! {} + }; + + quote! { + #[derive(Debug, Copy, Clone, PartialEq, Eq)] + pub enum #impls_ident { + #(#entries)* } + + #impls_enum + + #selectable_insertable_impl } - }); + } else { + quote! {} + }; + #[cfg(not(all(feature = "postgres", feature = "json")))] + let postgres_extra = quote! {}; #[cfg(feature = "mysql")] - tokens.append_all(quote! { - #[automatically_derived] - impl #crate_name::__private::diesel::deserialize::FromSql<#sql_type, #crate_name::__private::diesel::mysql::Mysql> for #ident { - fn from_sql(bytes: #crate_name::__private::diesel::mysql::MysqlValue<'_>) -> #crate_name::__private::diesel::deserialize::Result { - match Self::__benzina04_from_bytes(bytes.as_bytes()) { - #crate_name::__private::std::option::Option::Some(this) => { - #crate_name::__private::std::result::Result::Ok(this) - }, - #crate_name::__private::std::option::Option::None => { - #crate_name::__private::std::result::Result::Err( - #crate_name::__private::std::convert::Into::into( - "Unrecognized enum variant" + let mysql = if self.has_json_fields() { + quote! {} + } else { + let from_bytes_arms = variants + .iter() + .map(|variant| variant.gen_from_bytes(false, *rename_all)) + .collect::>(); + let to_byte_str_arms = variants + .iter() + .map(|variant| variant.gen_to_byte_str(false, *rename_all)) + .collect::>(); + + quote! { + #[automatically_derived] + impl #crate_name::__private::diesel::deserialize::Queryable<#sql_type, #crate_name::__private::diesel::mysql::Mysql> for #ident { + type Row = Self; + + fn build(row: Self::Row) -> #crate_name::__private::diesel::deserialize::Result { + #crate_name::__private::std::result::Result::Ok(row) + } + } + + #[automatically_derived] + impl #crate_name::__private::diesel::deserialize::FromSql<#sql_type, #crate_name::__private::diesel::mysql::Mysql> for #ident { + fn from_sql(bytes: #crate_name::__private::diesel::mysql::MysqlValue<'_>) -> #crate_name::__private::diesel::deserialize::Result { + match bytes.as_bytes() { + #(#from_bytes_arms)* + _ => { + #crate_name::__private::std::result::Result::Err( + #crate_name::__private::std::convert::Into::into( + "Unrecognized enum variant" + ) ) - ) - }, + }, + } } } - } - #[automatically_derived] - impl #crate_name::__private::diesel::serialize::ToSql<#sql_type, #crate_name::__private::diesel::mysql::Mysql> for #ident { - fn to_sql<'b>(&'b self, out: &mut #crate_name::__private::diesel::serialize::Output<'b, '_, #crate_name::__private::diesel::mysql::Mysql>) -> #crate_name::__private::diesel::serialize::Result { - let sql_val = self.__benzina04_as_str(); - #crate_name::__private::std::io::Write::write_all(out, sql_val.as_bytes())?; + #[automatically_derived] + impl #crate_name::__private::diesel::serialize::ToSql<#sql_type, #crate_name::__private::diesel::mysql::Mysql> for #ident { + fn to_sql<'b>(&'b self, out: &mut #crate_name::__private::diesel::serialize::Output<'b, '_, #crate_name::__private::diesel::mysql::Mysql>) -> #crate_name::__private::diesel::serialize::Result { + let s: &[u8] = match self { + #(#to_byte_str_arms)* + }; + #crate_name::__private::std::io::Write::write_all(out, s)?; - #crate_name::__private::std::result::Result::Ok(#crate_name::__private::diesel::serialize::IsNull::No) + #crate_name::__private::std::result::Result::Ok(#crate_name::__private::diesel::serialize::IsNull::No) + } } } + }; + #[cfg(not(feature = "mysql"))] + let mysql = quote! {}; + + tokens.append_all(quote! { + #as_expression + #postgres + #postgres_extra + #mysql }); } } impl EnumVariant { - fn gen_from_bytes(&self, rename_rule: RenameRule) -> impl ToTokens + use<'_> { - struct EnumVariantFromBytes<'a>(&'a EnumVariant, RenameRule); - - impl ToTokens for EnumVariantFromBytes<'_> { - fn to_tokens(&self, tokens: &mut TokenStream) { - let Self( - EnumVariant { - original_name, - rename, - crate_name, - span, - }, - rename_rule, - ) = self; - let crate_name = crate::crate_name(crate_name); - - let rename = rename - .clone() - .unwrap_or_else(|| rename_rule.format(original_name)); - - let original_name_ident = Ident::new(original_name, *span); - let rename_bytes = LitByteStr::new(rename.as_bytes(), *span); - tokens.append_all(quote! { - #rename_bytes => #crate_name::__private::std::option::Option::Some(Self::#original_name_ident), - }); - } - } + fn original_name(&self) -> Ident { + Ident::new(&self.original_name, self.original_name_span) + } - EnumVariantFromBytes(self, rename_rule) + fn gen_from_bytes(&self, _has_fields: bool, rename_rule: RenameRule) -> impl ToTokens { + let Self { + original_name, + original_name_span, + rename, + #[cfg(all(feature = "postgres", feature = "json"))] + has_payload: _, + + crate_name, + } = self; + let crate_name = crate::crate_name(crate_name); + + let rename = rename + .clone() + .unwrap_or_else(|| rename_rule.format(original_name)); + + let original_name_ident = self.original_name(); + let rename_bytes = LitByteStr::new(rename.as_bytes(), *original_name_span); + quote! { + #rename_bytes => #crate_name::__private::std::result::Result::Ok(Self::#original_name_ident), + } } - fn gen_to_str(&self, rename_rule: RenameRule) -> impl ToTokens + use<'_> { - struct EnumVariantToStr<'a>(&'a EnumVariant, RenameRule); - - impl ToTokens for EnumVariantToStr<'_> { - fn to_tokens(&self, tokens: &mut TokenStream) { - let Self( - EnumVariant { - original_name, - rename, - crate_name: _, - span, - }, - rename_rule, - ) = self; + #[cfg(all(feature = "postgres", feature = "json"))] + fn gen_from_queryable(&self, impls_ident: &Ident) -> impl ToTokens { + let crate_name = crate::crate_name(&self.crate_name); - let rename = rename - .clone() - .unwrap_or_else(|| rename_rule.format(original_name)); + let original_name_ident = self.original_name(); - let original_name_ident = Ident::new(original_name, *span); - tokens.append_all(quote! { - Self::#original_name_ident => #rename, - }); + let inner = if self.has_payload { + quote! { + #crate_name::__private::std::result::Result::map( + #crate_name::__private::json::RawJsonb::deserialize(&row.1), + Self::#original_name_ident + ) + } + } else { + quote! { + #crate_name::__private::std::result::Result::Ok(Self::#original_name_ident) } + }; + quote! { + #impls_ident::#original_name_ident => { + #inner + }, } + } - EnumVariantToStr(self, rename_rule) + fn gen_to_byte_str(&self, _has_fields: bool, rename_rule: RenameRule) -> impl ToTokens { + let Self { + original_name, + original_name_span, + rename, + #[cfg(all(feature = "postgres", feature = "json"))] + has_payload: _, + + crate_name: _, + } = self; + + let rename = rename + .clone() + .unwrap_or_else(|| rename_rule.format(original_name)); + + let original_name_ident = self.original_name(); + let rename_bytes = LitByteStr::new(rename.as_bytes(), *original_name_span); + quote! { + Self::#original_name_ident => #rename_bytes, + } + } + + #[cfg(all(feature = "postgres", feature = "json"))] + fn gen_to_insertable(&self, ident: &Ident, impls_ident: &Ident) -> impl ToTokens { + let crate_name = crate::crate_name(&self.crate_name); + let original_name_ident = self.original_name(); + + if self.has_payload { + quote! { + #ident::#original_name_ident(payload) => ( + #impls_ident::#original_name_ident, + #crate_name::__private::json::RawJsonb::serialize(payload) + .expect("failed to serialize enum payload"), + ), + } + } else { + quote! { + #ident::#original_name_ident => ( + #impls_ident::#original_name_ident, + #crate_name::__private::json::RawJsonb::EMPTY, + ), + } + } } } diff --git a/benzina-derive/src/lib.rs b/benzina-derive/src/lib.rs index aa15bf7..c81f5f5 100644 --- a/benzina-derive/src/lib.rs +++ b/benzina-derive/src/lib.rs @@ -35,13 +35,7 @@ mod rename_rule; /// # use benzina_derive as benzina; /// # fn main() {} /// -/// use diesel::{ -/// deserialize::FromSqlRow, -/// expression::AsExpression, -/// }; -/// -/// #[derive(Debug, Copy, Clone, AsExpression, FromSqlRow, benzina::Enum)] -/// #[diesel(sql_type = crate::schema::sql_types::Animal)] +/// #[derive(Debug, Copy, Clone, benzina::Enum)] /// #[benzina( /// sql_type = crate::schema::sql_types::Animal, /// rename_all = "snake_case" @@ -73,6 +67,153 @@ mod rename_rule; /// # } /// ``` /// +/// ## Enums with variant-specific data in separate JSONB column +/// +/// You can also use `benzina::Enum` for enums where each variant holds +/// associated data. This is useful when you have a PostgreSQL ENUM for the +/// discriminator and a JSONB column for the variant-specific payload. +/// +/// ### migration +/// +/// ```sql +/// CREATE TYPE animal AS ENUM ('chicken', 'duck', 'oca', 'rabbit'); +/// +/// CREATE TABLE pets ( +/// id SERIAL PRIMARY KEY, +/// name TEXT NOT NULL, +/// animal animal NOT NULL, +/// animal_data JSONB NOT NULL +/// ); +/// ``` +/// +/// ### Rust enum +/// +/// ```rust +/// # use benzina_derive as benzina; +/// # fn main() {} +/// use diesel::pg::Pg; +/// use diesel::{Identifiable, Insertable, Queryable, Selectable}; +/// use serde::{Deserialize, Serialize}; +/// +/// #[derive(Debug, Queryable, Identifiable, Insertable, Selectable)] +/// #[diesel(table_name = schema::pets, check_for_backend(Pg))] +/// pub struct Pet { +/// pub id: i32, +/// pub name: String, +/// #[diesel(embed)] +/// pub animal: Animal, +/// } +/// +/// #[derive(Debug, Clone, benzina::Enum)] +/// #[benzina( +/// sql_type = schema::sql_types::Animal, +/// rename_all = "snake_case", +/// table = schema::pets, +/// column = animal, +/// data_column = animal_data +/// )] +/// # #[benzina(crate = fake_benzina)] +/// pub enum Animal { +/// Chicken(ChickenData), +/// Duck(DuckData), +/// #[benzina(rename = "oca")] +/// Goose(GooseData), +/// Rabbit(RabbitData), +/// } +/// +/// #[derive(Debug, Clone, Serialize, Deserialize)] +/// pub struct ChickenData { +/// pub likes_cuddles: bool, +/// pub breed: String, +/// } +/// +/// #[derive(Debug, Clone, Serialize, Deserialize)] +/// pub struct DuckData { +/// pub favorite_treat: String, +/// pub feather_color: String, +/// } +/// +/// #[derive(Debug, Clone, Serialize, Deserialize)] +/// pub struct GooseData { +/// pub weight_kg: f64, +/// pub honks_at_strangers: bool, +/// } +/// +/// #[derive(Debug, Clone, Serialize, Deserialize)] +/// pub struct RabbitData { +/// pub fur_color: String, +/// pub litter_trained: bool, +/// } +/// +/// pub mod schema { +/// // @generated automatically by Diesel CLI. +/// +/// pub mod sql_types { +/// #[derive(diesel::query_builder::QueryId, Clone, diesel::sql_types::SqlType)] +/// #[diesel(postgres_type(name = "animal"))] +/// pub struct Animal; +/// } +/// +/// diesel::table! { +/// use diesel::sql_types::*; +/// use super::sql_types::Animal; +/// +/// pets (id) { +/// id -> Int4, +/// name -> Text, +/// animal -> Animal, +/// animal_data -> Jsonb, +/// } +/// } +/// } +/// # +/// # mod fake_benzina { +/// # pub mod __private { +/// # pub use std; +/// # pub use diesel; +/// # +/// # pub mod json { +/// # use diesel::{ +/// # deserialize::{FromSql, FromSqlRow}, +/// # expression::AsExpression, +/// # pg::{Pg, PgValue}, +/// # serialize::ToSql, +/// # sql_types, +/// # }; +/// # use serde::{Deserialize, Serialize}; +/// # +/// # #[derive(Debug, FromSqlRow, AsExpression)] +/// # #[diesel(sql_type = sql_types::Jsonb)] +/// # pub struct RawJsonb; +/// # +/// # impl RawJsonb { +/// # pub const EMPTY: Self = Self; +/// # +/// # pub fn serialize(value: &impl Serialize) -> diesel::deserialize::Result { +/// # unimplemented!() +/// # } +/// # +/// # pub fn deserialize Deserialize<'a>>(&self) -> diesel::deserialize::Result { +/// # unimplemented!() +/// # } +/// # } +/// # +/// # impl FromSql for RawJsonb { +/// # fn from_sql(value: PgValue) -> diesel::deserialize::Result { +/// # unimplemented!() +/// # } +/// # } +/// # +/// # impl ToSql for RawJsonb { +/// # fn to_sql(&self, out: &mut diesel::serialize::Output) -> diesel::serialize::Result { +/// # unimplemented!() +/// # } +/// # } +/// # } +/// # } +/// # } +/// ``` +/// /// [`FromSql`]: https://docs.rs/diesel/latest/diesel/deserialize/trait.FromSql.html /// [`ToSql`]: https://docs.rs/diesel/latest/diesel/serialize/trait.ToSql.html #[proc_macro_derive(Enum, attributes(benzina))] diff --git a/benzina/Cargo.toml b/benzina/Cargo.toml index bf7b39c..e5acee0 100644 --- a/benzina/Cargo.toml +++ b/benzina/Cargo.toml @@ -53,7 +53,7 @@ example-generated = ["typed-uuid"] dangerous-construction = ["typed-uuid"] array = ["postgres"] -json = ["postgres", "dep:serde_core", "dep:serde_json", "diesel/serde_json"] +json = ["postgres", "benzina-derive?/json", "dep:serde_core", "dep:serde_json", "diesel/serde_json"] ctid = ["postgres", "diesel/i-implement-a-third-party-backend-and-opt-into-breaking-changes"] [lints] diff --git a/benzina/src/__private.rs b/benzina/src/__private.rs index e53890e..a81a72a 100644 --- a/benzina/src/__private.rs +++ b/benzina/src/__private.rs @@ -23,6 +23,56 @@ pub fn new_indexmap() -> IndexMap { IndexMap::with_hasher(Hasher::default()) } +#[cfg(all(feature = "postgres", feature = "json"))] +pub mod json { + use std::borrow::Cow; + + use diesel::{ + deserialize::{FromSql, FromSqlRow}, + expression::AsExpression, + pg::{Pg, PgValue}, + serialize::ToSql, + sql_types, + }; + use serde_core::{Deserialize, Serialize}; + + use crate::json::convert::{sql_deserialize_binary_raw, sql_serialize_binary_raw}; + + #[derive(Debug, FromSqlRow, AsExpression)] + #[diesel(sql_type = sql_types::Jsonb)] + pub struct RawJsonb(Cow<'static, [u8]>); + + impl RawJsonb { + pub const EMPTY: Self = Self(Cow::Borrowed(b"{}")); + + pub fn serialize(value: &impl Serialize) -> diesel::deserialize::Result { + serde_json::to_vec(value) + .map(Cow::Owned) + .map(Self) + .map_err(Into::into) + } + + pub fn deserialize Deserialize<'a>>(&self) -> diesel::deserialize::Result { + serde_json::from_slice(&self.0).map_err(Into::into) + } + } + + impl FromSql for RawJsonb { + fn from_sql(value: PgValue) -> diesel::deserialize::Result { + sql_deserialize_binary_raw(&value) + .map(ToOwned::to_owned) + .map(Cow::Owned) + .map(Self) + } + } + + impl ToSql for RawJsonb { + fn to_sql(&self, out: &mut diesel::serialize::Output) -> diesel::serialize::Result { + sql_serialize_binary_raw(&self.0, out) + } + } +} + pub mod deep_clone { pub trait DeepClone { type Output; diff --git a/benzina/src/json/convert.rs b/benzina/src/json/convert.rs index 8a4ff22..6195d18 100644 --- a/benzina/src/json/convert.rs +++ b/benzina/src/json/convert.rs @@ -55,6 +55,15 @@ where sql_serialize(value, out) } +pub(crate) fn sql_serialize_binary_raw( + value: &[u8], + out: &mut diesel::serialize::Output<'_, '_, Pg>, +) -> diesel::serialize::Result { + out.write_all(&[1])?; + out.write_all(value)?; + Ok(IsNull::No) +} + pub(super) fn sql_deserialize(value: PgValue<'_>) -> diesel::deserialize::Result where T: DeserializeOwned, @@ -66,6 +75,13 @@ pub(super) fn sql_deserialize_binary(value: PgValue<'_>) -> diesel::deseriali where T: DeserializeOwned, { + let bytes = sql_deserialize_binary_raw(&value)?; + serde_json::from_slice(bytes).map_err(Into::into) +} + +pub(crate) fn sql_deserialize_binary_raw<'a>( + value: &'a PgValue<'_>, +) -> diesel::deserialize::Result<&'a [u8]> { let (version, bytes) = value .as_bytes() .split_first() @@ -75,5 +91,5 @@ where return Err("Unsupported JSONB encoding version".into()); } - serde_json::from_slice(bytes).map_err(Into::into) + Ok(bytes) }