diff --git a/crates/squawk_ide/src/binder.rs b/crates/squawk_ide/src/binder.rs index ccf273b8..dde81b6b 100644 --- a/crates/squawk_ide/src/binder.rs +++ b/crates/squawk_ide/src/binder.rs @@ -80,6 +80,7 @@ fn bind_stmt(b: &mut Binder, stmt: ast::Stmt) { ast::Stmt::CreateTable(create_table) => bind_create_table(b, create_table), ast::Stmt::CreateIndex(create_index) => bind_create_index(b, create_index), ast::Stmt::CreateFunction(create_function) => bind_create_function(b, create_function), + ast::Stmt::CreateAggregate(create_aggregate) => bind_create_aggregate(b, create_aggregate), ast::Stmt::CreateSchema(create_schema) => bind_create_schema(b, create_schema), ast::Stmt::Set(set) => bind_set(b, set), _ => {} @@ -103,6 +104,7 @@ fn bind_create_table(b: &mut Binder, create_table: ast::CreateTable) { kind: SymbolKind::Table, ptr: name_ptr, schema, + params: None, }); let root = b.root_scope(); @@ -125,6 +127,7 @@ fn bind_create_index(b: &mut Binder, create_index: ast::CreateIndex) { kind: SymbolKind::Index, ptr: name_ptr, schema, + params: None, }); let root = b.root_scope(); @@ -146,16 +149,47 @@ fn bind_create_function(b: &mut Binder, create_function: ast::CreateFunction) { return; }; + let params = extract_param_signature(create_function.param_list()); + let function_id = b.symbols.alloc(Symbol { kind: SymbolKind::Function, ptr: name_ptr, schema, + params, }); let root = b.root_scope(); b.scopes[root].insert(function_name, function_id); } +fn bind_create_aggregate(b: &mut Binder, create_aggregate: ast::CreateAggregate) { + let Some(path) = create_aggregate.path() else { + return; + }; + + let Some(aggregate_name) = item_name(&path) else { + return; + }; + + let name_ptr = path_to_ptr(&path); + + let Some(schema) = schema_name(b, &path, false) else { + return; + }; + + let params = extract_param_signature(create_aggregate.param_list()); + + let aggregate_id = b.symbols.alloc(Symbol { + kind: SymbolKind::Aggregate, + ptr: name_ptr, + schema, + params, + }); + + let root = b.root_scope(); + b.scopes[root].insert(aggregate_name, aggregate_id); +} + fn bind_create_schema(b: &mut Binder, create_schema: ast::CreateSchema) { let Some(schema_name_node) = create_schema.name() else { return; @@ -168,6 +202,7 @@ fn bind_create_schema(b: &mut Binder, create_schema: ast::CreateSchema) { kind: SymbolKind::Schema, ptr: name_ptr, schema: Schema(schema_name.clone()), + params: None, }); let root = b.root_scope(); @@ -293,3 +328,19 @@ fn extract_string_literal(literal: &ast::Literal) -> Option { None } } + +fn extract_param_signature(param_list: Option) -> Option> { + let param_list = param_list?; + let mut params = vec![]; + for param in param_list.params() { + if let Some(ty) = param.ty() + && let ast::Type::PathType(path_type) = ty + && let Some(path) = path_type.path() + && let Some(segment) = path.segment() + && let Some(name_ref) = segment.name_ref() + { + params.push(Name::new(name_ref.syntax().text().to_string())); + } + } + (!params.is_empty()).then_some(params) +} diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index 1dffc694..2cfe008e 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -999,6 +999,37 @@ drop function foo(), bar$0(); "); } + #[test] + fn goto_drop_function_overloaded() { + assert_snapshot!(goto(" +create function add(complex) returns complex as $$ select null $$ language sql; +create function add(bigint) returns bigint as $$ select 1 $$ language sql; +drop function add$0(complex); +"), @r" + ╭▸ + 2 │ create function add(complex) returns complex as $$ select null $$ language sql; + │ ─── 2. destination + 3 │ create function add(bigint) returns bigint as $$ select 1 $$ language sql; + 4 │ drop function add(complex); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_drop_function_second_overload() { + assert_snapshot!(goto(" +create function add(complex) returns complex as $$ select null $$ language sql; +create function add(bigint) returns bigint as $$ select 1 $$ language sql; +drop function add$0(bigint); +"), @r" + ╭▸ + 3 │ create function add(bigint) returns bigint as $$ select 1 $$ language sql; + │ ─── 2. destination + 4 │ drop function add(bigint); + ╰╴ ─ 1. source + "); + } + #[test] fn goto_select_function_call() { assert_snapshot!(goto(" @@ -2300,4 +2331,125 @@ select a$0 from t; ╰╴ ─ 1. source "); } + + #[test] + fn goto_drop_aggregate() { + assert_snapshot!(goto(" +create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8); +drop aggregate myavg$0(int); +"), @r" + ╭▸ + 2 │ create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8); + │ ───── 2. destination + 3 │ drop aggregate myavg(int); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_drop_aggregate_with_schema() { + assert_snapshot!(goto(" +set search_path to public; +create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8); +drop aggregate public.myavg$0(int); +"), @r" + ╭▸ + 3 │ create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8); + │ ───── 2. destination + 4 │ drop aggregate public.myavg(int); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_drop_aggregate_defined_after() { + assert_snapshot!(goto(" +drop aggregate myavg$0(int); +create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8); +"), @r" + ╭▸ + 2 │ drop aggregate myavg(int); + │ ─ 1. source + 3 │ create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8); + ╰╴ ───── 2. destination + "); + } + + #[test] + fn goto_aggregate_definition_returns_self() { + assert_snapshot!(goto(" +create aggregate myavg$0(int) (sfunc = int4_avg_accum, stype = _int8); +"), @r" + ╭▸ + 2 │ create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8); + │ ┬───┬ + │ │ │ + │ │ 1. source + ╰╴ 2. destination + "); + } + + #[test] + fn goto_drop_aggregate_with_search_path() { + assert_snapshot!(goto(" +create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8); +set search_path to bar; +create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8); +set search_path to default; +drop aggregate myavg$0(int); +"), @r" + ╭▸ + 2 │ create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8); + │ ───── 2. destination + ‡ + 6 │ drop aggregate myavg(int); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_drop_aggregate_multiple() { + assert_snapshot!(goto(" +create aggregate avg1(int) (sfunc = int4_avg_accum, stype = _int8); +create aggregate avg2(int) (sfunc = int4_avg_accum, stype = _int8); +drop aggregate avg1(int), avg2$0(int); +"), @r" + ╭▸ + 3 │ create aggregate avg2(int) (sfunc = int4_avg_accum, stype = _int8); + │ ──── 2. destination + 4 │ drop aggregate avg1(int), avg2(int); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_drop_aggregate_overloaded() { + assert_snapshot!(goto(" +create aggregate sum(complex) (sfunc = complex_add, stype = complex, initcond = '(0,0)'); +create aggregate sum(bigint) (sfunc = bigint_add, stype = bigint, initcond = '0'); +drop aggregate sum$0(complex); +"), @r" + ╭▸ + 2 │ create aggregate sum(complex) (sfunc = complex_add, stype = complex, initcond = '(0,0)'); + │ ─── 2. destination + 3 │ create aggregate sum(bigint) (sfunc = bigint_add, stype = bigint, initcond = '0'); + 4 │ drop aggregate sum(complex); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_drop_aggregate_second_overload() { + assert_snapshot!(goto(" +create aggregate sum(complex) (sfunc = complex_add, stype = complex, initcond = '(0,0)'); +create aggregate sum(bigint) (sfunc = bigint_add, stype = bigint, initcond = '0'); +drop aggregate sum$0(bigint); +"), @r" + ╭▸ + 3 │ create aggregate sum(bigint) (sfunc = bigint_add, stype = bigint, initcond = '0'); + │ ─── 2. destination + 4 │ drop aggregate sum(bigint); + ╰╴ ─ 1. source + "); + } } diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index 2e8074c3..d3f020dc 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -15,6 +15,7 @@ enum NameRefContext { Table, DropIndex, DropFunction, + DropAggregate, DropSchema, CreateIndex, CreateIndexColumn, @@ -75,11 +76,34 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti resolve_index(binder, &index_name, &schema, position) } NameRefContext::DropFunction => { - let path = find_containing_path(name_ref)?; + let function_sig = name_ref + .syntax() + .ancestors() + .find_map(ast::FunctionSig::cast)?; + let path = function_sig.path()?; let function_name = extract_table_name(&path)?; let schema = extract_schema_name(&path); + let params = extract_param_signature(&function_sig); let position = name_ref.syntax().text_range().start(); - resolve_function(binder, &function_name, &schema, position) + resolve_function(binder, &function_name, &schema, params.as_deref(), position) + } + NameRefContext::DropAggregate => { + let aggregate = name_ref + .syntax() + .ancestors() + .find_map(ast::Aggregate::cast)?; + let path = aggregate.path()?; + let aggregate_name = extract_table_name(&path)?; + let schema = extract_schema_name(&path); + let params = extract_param_signature(&aggregate); + let position = name_ref.syntax().text_range().start(); + resolve_aggregate( + binder, + &aggregate_name, + &schema, + params.as_deref(), + position, + ) } NameRefContext::DropSchema | NameRefContext::SchemaQualifier => { let schema_name = Name::new(name_ref.syntax().text().to_string()); @@ -100,7 +124,7 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti let position = name_ref.syntax().text_range().start(); // functions take precedence - if let Some(ptr) = resolve_function(binder, &function_name, &schema, position) { + if let Some(ptr) = resolve_function(binder, &function_name, &schema, None, position) { return Some(ptr); } @@ -236,6 +260,9 @@ fn classify_name_ref_context(name_ref: &ast::NameRef) -> Option if ast::DropFunction::can_cast(ancestor.kind()) { return Some(NameRefContext::DropFunction); } + if ast::DropAggregate::can_cast(ancestor.kind()) { + return Some(NameRefContext::DropAggregate); + } if ast::DropSchema::can_cast(ancestor.kind()) { return Some(NameRefContext::DropSchema); } @@ -338,21 +365,82 @@ fn resolve_for_kind( None } +fn resolve_for_kind_with_params( + binder: &Binder, + name: &Name, + schema: &Option, + params: Option<&[Name]>, + position: TextSize, + kind: SymbolKind, +) -> Option { + let symbols = binder.scopes[binder.root_scope()].get(name)?; + + if let Some(schema) = schema { + let symbol_id = symbols.iter().copied().find(|id| { + let symbol = &binder.symbols[*id]; + let params_match = match (&symbol.params, params) { + (Some(sym_params), Some(req_params)) => sym_params.as_slice() == req_params, + (None, None) => true, + (_, None) => true, + _ => false, + }; + symbol.kind == kind && &symbol.schema == schema && params_match + })?; + return Some(binder.symbols[symbol_id].ptr); + } else { + let search_path = binder.search_path_at(position); + for search_schema in search_path { + if let Some(symbol_id) = symbols.iter().copied().find(|id| { + let symbol = &binder.symbols[*id]; + let params_match = match (&symbol.params, params) { + (Some(sym_params), Some(req_params)) => sym_params.as_slice() == req_params, + (None, None) => true, + (_, None) => true, + _ => false, + }; + symbol.kind == kind && &symbol.schema == search_schema && params_match + }) { + return Some(binder.symbols[symbol_id].ptr); + } + } + } + None +} + fn resolve_function( binder: &Binder, function_name: &Name, schema: &Option, + params: Option<&[Name]>, position: TextSize, ) -> Option { - resolve_for_kind( + resolve_for_kind_with_params( binder, function_name, schema, + params, position, SymbolKind::Function, ) } +fn resolve_aggregate( + binder: &Binder, + aggregate_name: &Name, + schema: &Option, + params: Option<&[Name]>, + position: TextSize, +) -> Option { + resolve_for_kind_with_params( + binder, + aggregate_name, + schema, + params, + position, + SymbolKind::Aggregate, + ) +} + fn resolve_schema(binder: &Binder, schema_name: &Name) -> Option { let symbols = binder.scopes[binder.root_scope()].get(schema_name)?; let symbol_id = symbols.iter().copied().find(|id| { @@ -579,7 +667,7 @@ fn resolve_select_qualified_column( // 2. No column found, check for field-style function call // e.g., select t.b from t where b is a function that takes t as an argument - resolve_function(binder, &column_name, &schema, position) + resolve_function(binder, &column_name, &schema, None, position) } fn resolve_select_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { @@ -1040,3 +1128,19 @@ fn extract_schema_from_path(path: &ast::Path) -> Option { .and_then(|s| s.name_ref()) .map(|name_ref| name_ref.syntax().text().to_string()) } + +fn extract_param_signature(node: &impl ast::HasParamList) -> Option> { + let param_list = node.param_list()?; + let mut params = vec![]; + for param in param_list.params() { + if let Some(ty) = param.ty() + && let ast::Type::PathType(path_type) = ty + && let Some(path) = path_type.path() + && let Some(segment) = path.segment() + && let Some(name_ref) = segment.name_ref() + { + params.push(Name::new(name_ref.syntax().text().to_string())); + } + } + (!params.is_empty()).then_some(params) +} diff --git a/crates/squawk_ide/src/symbols.rs b/crates/squawk_ide/src/symbols.rs index 7f98f054..e4b400e3 100644 --- a/crates/squawk_ide/src/symbols.rs +++ b/crates/squawk_ide/src/symbols.rs @@ -42,6 +42,7 @@ pub(crate) enum SymbolKind { Table, Index, Function, + Aggregate, Schema, } @@ -50,6 +51,7 @@ pub(crate) struct Symbol { pub(crate) kind: SymbolKind, pub(crate) ptr: SyntaxNodePtr, pub(crate) schema: Schema, + pub(crate) params: Option>, } pub(crate) type SymbolId = Idx; diff --git a/crates/squawk_syntax/src/ast.rs b/crates/squawk_syntax/src/ast.rs index f73094cf..2c5e059e 100644 --- a/crates/squawk_syntax/src/ast.rs +++ b/crates/squawk_syntax/src/ast.rs @@ -54,6 +54,7 @@ pub use self::{ // HasVisibility, // HasGenericParams, HasLoopBody, HasName, + HasParamList, }, }; diff --git a/crates/squawk_syntax/src/ast/node_ext.rs b/crates/squawk_syntax/src/ast/node_ext.rs index bb587e2b..400bb148 100644 --- a/crates/squawk_syntax/src/ast/node_ext.rs +++ b/crates/squawk_syntax/src/ast/node_ext.rs @@ -229,6 +229,9 @@ pub(crate) fn text_of_first_token(node: &SyntaxNode) -> TokenText<'_> { } } +impl ast::HasParamList for ast::FunctionSig {} +impl ast::HasParamList for ast::Aggregate {} + #[test] fn index_expr() { let source_code = " diff --git a/crates/squawk_syntax/src/ast/traits.rs b/crates/squawk_syntax/src/ast/traits.rs index 35f10f86..7011d72c 100644 --- a/crates/squawk_syntax/src/ast/traits.rs +++ b/crates/squawk_syntax/src/ast/traits.rs @@ -15,6 +15,12 @@ pub trait HasArgList: AstNode { } } +pub trait HasParamList: AstNode { + fn param_list(&self) -> Option { + support::child(self.syntax()) + } +} + pub trait HasIfExists: AstNode { fn if_exists(&self) -> Option { support::child(self.syntax())