Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions crates/squawk_ide/src/binder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ fn bind_stmt(b: &mut Binder, stmt: ast::Stmt) {
ast::Stmt::CreateTable(create_table) => bind_create_table(b, create_table),
ast::Stmt::CreateIndex(create_index) => bind_create_index(b, create_index),
ast::Stmt::CreateFunction(create_function) => bind_create_function(b, create_function),
ast::Stmt::CreateAggregate(create_aggregate) => bind_create_aggregate(b, create_aggregate),
ast::Stmt::CreateSchema(create_schema) => bind_create_schema(b, create_schema),
ast::Stmt::Set(set) => bind_set(b, set),
_ => {}
Expand All @@ -103,6 +104,7 @@ fn bind_create_table(b: &mut Binder, create_table: ast::CreateTable) {
kind: SymbolKind::Table,
ptr: name_ptr,
schema,
params: None,
});

let root = b.root_scope();
Expand All @@ -125,6 +127,7 @@ fn bind_create_index(b: &mut Binder, create_index: ast::CreateIndex) {
kind: SymbolKind::Index,
ptr: name_ptr,
schema,
params: None,
});

let root = b.root_scope();
Expand All @@ -146,16 +149,47 @@ fn bind_create_function(b: &mut Binder, create_function: ast::CreateFunction) {
return;
};

let params = extract_param_signature(create_function.param_list());

let function_id = b.symbols.alloc(Symbol {
kind: SymbolKind::Function,
ptr: name_ptr,
schema,
params,
});

let root = b.root_scope();
b.scopes[root].insert(function_name, function_id);
}

fn bind_create_aggregate(b: &mut Binder, create_aggregate: ast::CreateAggregate) {
let Some(path) = create_aggregate.path() else {
return;
};

let Some(aggregate_name) = item_name(&path) else {
return;
};

let name_ptr = path_to_ptr(&path);

let Some(schema) = schema_name(b, &path, false) else {
return;
};

let params = extract_param_signature(create_aggregate.param_list());

let aggregate_id = b.symbols.alloc(Symbol {
kind: SymbolKind::Aggregate,
ptr: name_ptr,
schema,
params,
});

let root = b.root_scope();
b.scopes[root].insert(aggregate_name, aggregate_id);
}

fn bind_create_schema(b: &mut Binder, create_schema: ast::CreateSchema) {
let Some(schema_name_node) = create_schema.name() else {
return;
Expand All @@ -168,6 +202,7 @@ fn bind_create_schema(b: &mut Binder, create_schema: ast::CreateSchema) {
kind: SymbolKind::Schema,
ptr: name_ptr,
schema: Schema(schema_name.clone()),
params: None,
});

let root = b.root_scope();
Expand Down Expand Up @@ -293,3 +328,19 @@ fn extract_string_literal(literal: &ast::Literal) -> Option<String> {
None
}
}

fn extract_param_signature(param_list: Option<ast::ParamList>) -> Option<Vec<Name>> {
let param_list = param_list?;
let mut params = vec![];
for param in param_list.params() {
if let Some(ty) = param.ty()
&& let ast::Type::PathType(path_type) = ty
&& let Some(path) = path_type.path()
&& let Some(segment) = path.segment()
&& let Some(name_ref) = segment.name_ref()
{
params.push(Name::new(name_ref.syntax().text().to_string()));
}
}
(!params.is_empty()).then_some(params)
}
152 changes: 152 additions & 0 deletions crates/squawk_ide/src/goto_definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -999,6 +999,37 @@ drop function foo(), bar$0();
");
}

#[test]
fn goto_drop_function_overloaded() {
assert_snapshot!(goto("
create function add(complex) returns complex as $$ select null $$ language sql;
create function add(bigint) returns bigint as $$ select 1 $$ language sql;
drop function add$0(complex);
"), @r"
╭▸
2 │ create function add(complex) returns complex as $$ select null $$ language sql;
│ ─── 2. destination
3 │ create function add(bigint) returns bigint as $$ select 1 $$ language sql;
4 │ drop function add(complex);
╰╴ ─ 1. source
");
}

#[test]
fn goto_drop_function_second_overload() {
assert_snapshot!(goto("
create function add(complex) returns complex as $$ select null $$ language sql;
create function add(bigint) returns bigint as $$ select 1 $$ language sql;
drop function add$0(bigint);
"), @r"
╭▸
3 │ create function add(bigint) returns bigint as $$ select 1 $$ language sql;
│ ─── 2. destination
4 │ drop function add(bigint);
╰╴ ─ 1. source
");
}

#[test]
fn goto_select_function_call() {
assert_snapshot!(goto("
Expand Down Expand Up @@ -2300,4 +2331,125 @@ select a$0 from t;
╰╴ ─ 1. source
");
}

#[test]
fn goto_drop_aggregate() {
assert_snapshot!(goto("
create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8);
drop aggregate myavg$0(int);
"), @r"
╭▸
2 │ create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8);
│ ───── 2. destination
3 │ drop aggregate myavg(int);
╰╴ ─ 1. source
");
}

#[test]
fn goto_drop_aggregate_with_schema() {
assert_snapshot!(goto("
set search_path to public;
create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8);
drop aggregate public.myavg$0(int);
"), @r"
╭▸
3 │ create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8);
│ ───── 2. destination
4 │ drop aggregate public.myavg(int);
╰╴ ─ 1. source
");
}

#[test]
fn goto_drop_aggregate_defined_after() {
assert_snapshot!(goto("
drop aggregate myavg$0(int);
create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8);
"), @r"
╭▸
2 │ drop aggregate myavg(int);
│ ─ 1. source
3 │ create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8);
╰╴ ───── 2. destination
");
}

#[test]
fn goto_aggregate_definition_returns_self() {
assert_snapshot!(goto("
create aggregate myavg$0(int) (sfunc = int4_avg_accum, stype = _int8);
"), @r"
╭▸
2 │ create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8);
│ ┬───┬
│ │ │
│ │ 1. source
╰╴ 2. destination
");
}

#[test]
fn goto_drop_aggregate_with_search_path() {
assert_snapshot!(goto("
create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8);
set search_path to bar;
create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8);
set search_path to default;
drop aggregate myavg$0(int);
"), @r"
╭▸
2 │ create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8);
│ ───── 2. destination
6 │ drop aggregate myavg(int);
╰╴ ─ 1. source
");
}

#[test]
fn goto_drop_aggregate_multiple() {
assert_snapshot!(goto("
create aggregate avg1(int) (sfunc = int4_avg_accum, stype = _int8);
create aggregate avg2(int) (sfunc = int4_avg_accum, stype = _int8);
drop aggregate avg1(int), avg2$0(int);
"), @r"
╭▸
3 │ create aggregate avg2(int) (sfunc = int4_avg_accum, stype = _int8);
│ ──── 2. destination
4 │ drop aggregate avg1(int), avg2(int);
╰╴ ─ 1. source
");
}

#[test]
fn goto_drop_aggregate_overloaded() {
assert_snapshot!(goto("
create aggregate sum(complex) (sfunc = complex_add, stype = complex, initcond = '(0,0)');
create aggregate sum(bigint) (sfunc = bigint_add, stype = bigint, initcond = '0');
drop aggregate sum$0(complex);
"), @r"
╭▸
2 │ create aggregate sum(complex) (sfunc = complex_add, stype = complex, initcond = '(0,0)');
│ ─── 2. destination
3 │ create aggregate sum(bigint) (sfunc = bigint_add, stype = bigint, initcond = '0');
4 │ drop aggregate sum(complex);
╰╴ ─ 1. source
");
}

#[test]
fn goto_drop_aggregate_second_overload() {
assert_snapshot!(goto("
create aggregate sum(complex) (sfunc = complex_add, stype = complex, initcond = '(0,0)');
create aggregate sum(bigint) (sfunc = bigint_add, stype = bigint, initcond = '0');
drop aggregate sum$0(bigint);
"), @r"
╭▸
3 │ create aggregate sum(bigint) (sfunc = bigint_add, stype = bigint, initcond = '0');
│ ─── 2. destination
4 │ drop aggregate sum(bigint);
╰╴ ─ 1. source
");
}
}
Loading
Loading