Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 73 additions & 26 deletions derive/src/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,24 +120,20 @@ impl Parse for DeriveDialectInput {

/// Entry point for the `derive_dialect!` macro
pub(crate) fn derive_dialect(input: DeriveDialectInput) -> proc_macro::TokenStream {
let err = |msg: String| {
Error::new(proc_macro2::Span::call_site(), msg)
.to_compile_error()
.into()
};
match derive_dialect_inner(input) {
Ok(tokens) => tokens.into(),
Err(e) => e.to_compile_error().into(),
}
}

let source = match read_dialect_mod_file() {
Ok(s) => s,
Err(e) => return err(format!("Failed to read dialect/mod.rs: {e}")),
};
let file: File = match syn::parse_str(&source) {
Ok(f) => f,
Err(e) => return err(format!("Failed to parse source: {e}")),
};
let methods = match extract_dialect_methods(&file) {
Ok(m) => m,
Err(e) => return e.to_compile_error().into(),
};
fn derive_dialect_inner(input: DeriveDialectInput) -> syn::Result<TokenStream> {
let call_site = proc_macro2::Span::call_site();

let source = read_dialect_mod_file()
.map_err(|e| Error::new(call_site, format!("Failed to read dialect/mod.rs: {e}")))?;
let file: File = syn::parse_str::<File>(&source)
.map_err(|e| Error::new(call_site, format!("Failed to parse source: {e}")))?;
let methods = extract_dialect_methods(&file)?;

// Validate overrides
let bool_names: HashSet<_> = methods
Expand All @@ -147,20 +143,23 @@ pub(crate) fn derive_dialect(input: DeriveDialectInput) -> proc_macro::TokenStre
.collect();
for (key, value) in &input.overrides {
let key_str = key.to_string();
let err = |msg| Error::new(key.span(), msg).to_compile_error().into();
match value {
Override::Bool(_) if !bool_names.contains(&key_str) => {
return err(format!("Unknown boolean method `{key_str}`"));
return Err(Error::new(
key.span(),
format!("Unknown boolean method `{key_str}`"),
));
}
Override::Char(_) | Override::None if key_str != "identifier_quote_style" => {
return err(format!(
"Char/None only valid for `identifier_quote_style`, not `{key_str}`"
return Err(Error::new(
key.span(),
format!("Char/None only valid for `identifier_quote_style`, not `{key_str}`"),
));
}
_ => {}
}
}
generate_derived_dialect(&input, &methods).into()
Ok(generate_derived_dialect(&input, &methods))
}

/// Generate the complete derived `Dialect` implementation
Expand Down Expand Up @@ -258,11 +257,59 @@ fn extract_param_names(sig: &Signature) -> Vec<&Ident> {
}

/// Read the `dialect/mod.rs` file that contains the Dialect trait.
///
/// Searches for the file in the following order:
/// 1. `$CARGO_MANIFEST_DIR/src/dialect/mod.rs` - works when the macro is
/// invoked from within the `sqlparser` crate itself (e.g. in tests).
/// 2. `<sqlparser_derive dir>/../src/dialect/mod.rs` - works when
/// `sqlparser_derive` lives in a workspace alongside the main crate
/// (the standard `derive/` layout).
/// 3. Sibling directories of the compiled `sqlparser_derive` crate in the
/// Cargo registry - works when an external crate uses `derive_dialect!`
/// via a registry dependency.
fn read_dialect_mod_file() -> Result<String, String> {
let manifest_dir =
std::env::var("CARGO_MANIFEST_DIR").map_err(|_| "CARGO_MANIFEST_DIR not set")?;
let path = std::path::Path::new(&manifest_dir).join("src/dialect/mod.rs");
std::fs::read_to_string(&path).map_err(|e| format!("Failed to read {}: {e}", path.display()))
use std::path::{Path, PathBuf};

const DERIVE_CRATE_DIR: &str = env!("CARGO_MANIFEST_DIR");
let derive_dir = Path::new(DERIVE_CRATE_DIR);
let mut candidates: Vec<PathBuf> = Vec::new();

// The crate being compiled (eg: within sqlparser).
if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
candidates.push(Path::new(&manifest_dir).join("src/dialect/mod.rs"));
}
// Workspace layout: the main crate is the parent of `derive/`.
candidates.push(derive_dir.join("../src/dialect/mod.rs"));

// Cargo registry: look for sibling `sqlparser-*` directories (prefer newest).
if let Some(parent) = derive_dir.parent() {
if let Ok(entries) = std::fs::read_dir(parent) {
let mut siblings: Vec<_> = entries
.filter_map(|e| e.ok())
.filter(|e| {
let name = e.file_name();
let name = name.to_string_lossy();
name.starts_with("sqlparser-") && !name.starts_with("sqlparser-derive")
})
.collect();
siblings.sort_by(|a, b| b.file_name().cmp(&a.file_name()));
candidates.extend(
siblings
.into_iter()
.map(|e| e.path().join("src/dialect/mod.rs")),
);
}
}
for path in &candidates {
if let Ok(content) = std::fs::read_to_string(path) {
return Ok(content);
}
}
Err(format!(
"Could not find `sqlparser` dialect/mod.rs file. \
Searched in $CARGO_MANIFEST_DIR/src/dialect/mod.rs and \
the `sqlparser_derive` crate at {DERIVE_CRATE_DIR}"
))
}

/// Extract all methods from the `Dialect` trait (excluding `dialect` for TypeId)
Expand Down
Loading