diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index 19e28795..ba7ba1f2 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -3077,4 +3077,179 @@ delete from users where id in (select id$0 from old_data); ╰╴ ─ 1. source "); } + + #[test] + fn goto_update_table() { + assert_snapshot!(goto(" +create table users(id int, email text); +update users$0 set email = 'new@example.com'; +"), @r" + ╭▸ + 2 │ create table users(id int, email text); + │ ───── 2. destination + 3 │ update users set email = 'new@example.com'; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_update_table_with_schema() { + assert_snapshot!(goto(" +create table public.users(id int, email text); +update public.users$0 set email = 'new@example.com'; +"), @r" + ╭▸ + 2 │ create table public.users(id int, email text); + │ ───── 2. destination + 3 │ update public.users set email = 'new@example.com'; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_update_table_with_search_path() { + assert_snapshot!(goto(" +set search_path to foo; +create table foo.users(id int, email text); +update users$0 set email = 'new@example.com'; +"), @r" + ╭▸ + 3 │ create table foo.users(id int, email text); + │ ───── 2. destination + 4 │ update users set email = 'new@example.com'; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_update_where_column() { + assert_snapshot!(goto(" +create table users(id int, email text); +update users set email = 'new@example.com' where id$0 = 1; +"), @r" + ╭▸ + 2 │ create table users(id int, email text); + │ ── 2. destination + 3 │ update users set email = 'new@example.com' where id = 1; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_update_where_column_with_schema() { + assert_snapshot!(goto(" +create table public.users(id int, email text); +update public.users set email = 'new@example.com' where id$0 = 1; +"), @r" + ╭▸ + 2 │ create table public.users(id int, email text); + │ ── 2. destination + 3 │ update public.users set email = 'new@example.com' where id = 1; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_update_where_column_with_search_path() { + assert_snapshot!(goto(" +set search_path to foo; +create table foo.users(id int, email text); +update users set email = 'new@example.com' where id$0 = 1; +"), @r" + ╭▸ + 3 │ create table foo.users(id int, email text); + │ ── 2. destination + 4 │ update users set email = 'new@example.com' where id = 1; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_update_set_column() { + assert_snapshot!(goto(" +create table users(id int, email text); +update users set email$0 = 'new@example.com' where id = 1; +"), @r" + ╭▸ + 2 │ create table users(id int, email text); + │ ───── 2. destination + 3 │ update users set email = 'new@example.com' where id = 1; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_update_set_column_with_schema() { + assert_snapshot!(goto(" +create table public.users(id int, email text); +update public.users set email$0 = 'new@example.com' where id = 1; +"), @r" + ╭▸ + 2 │ create table public.users(id int, email text); + │ ───── 2. destination + 3 │ update public.users set email = 'new@example.com' where id = 1; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_update_set_column_with_search_path() { + assert_snapshot!(goto(" +set search_path to foo; +create table foo.users(id int, email text); +update users set email$0 = 'new@example.com' where id = 1; +"), @r" + ╭▸ + 3 │ create table foo.users(id int, email text); + │ ───── 2. destination + 4 │ update users set email = 'new@example.com' where id = 1; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_update_from_table() { + assert_snapshot!(goto(" +create table users(id int, email text); +create table messages(id int, user_id int, email text); +update users set email = messages.email from messages$0 where users.id = messages.user_id; +"), @r" + ╭▸ + 3 │ create table messages(id int, user_id int, email text); + │ ──────── 2. destination + 4 │ update users set email = messages.email from messages where users.id = messages.user_id; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_update_from_table_with_schema() { + assert_snapshot!(goto(" +create table users(id int, email text); +create table public.messages(id int, user_id int, email text); +update users set email = messages.email from public.messages$0 where users.id = messages.user_id; +"), @r" + ╭▸ + 3 │ create table public.messages(id int, user_id int, email text); + │ ──────── 2. destination + 4 │ update users set email = messages.email from public.messages where users.id = messages.user_id; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_update_from_table_with_search_path() { + assert_snapshot!(goto(" +set search_path to foo; +create table users(id int, email text); +create table foo.messages(id int, user_id int, email text); +update users set email = messages.email from messages$0 where users.id = messages.user_id; +"), @r" + ╭▸ + 4 │ create table foo.messages(id int, user_id int, email text); + │ ──────── 2. destination + 5 │ update users set email = messages.email from messages where users.id = messages.user_id; + ╰╴ ─ 1. source + "); + } } diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index 7ca4f246..f076bbb3 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -31,6 +31,10 @@ enum NameRefContext { InsertColumn, DeleteTable, DeleteWhereColumn, + UpdateTable, + UpdateWhereColumn, + UpdateSetColumn, + UpdateFromTable, SchemaQualifier, } @@ -42,7 +46,8 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti | NameRefContext::Table | NameRefContext::CreateIndex | NameRefContext::InsertTable - | NameRefContext::DeleteTable => { + | NameRefContext::DeleteTable + | NameRefContext::UpdateTable => { let path = find_containing_path(name_ref)?; let table_name = extract_table_name(&path)?; let schema = extract_schema_name(&path); @@ -201,6 +206,29 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti NameRefContext::SelectQualifiedColumn => resolve_select_qualified_column(binder, name_ref), NameRefContext::InsertColumn => resolve_insert_column(binder, name_ref), NameRefContext::DeleteWhereColumn => resolve_delete_where_column(binder, name_ref), + NameRefContext::UpdateWhereColumn => resolve_update_where_column(binder, name_ref), + NameRefContext::UpdateSetColumn => resolve_update_set_column(binder, name_ref), + NameRefContext::UpdateFromTable => { + let table_name = Name::from_node(name_ref); + let schema = if let Some(parent) = name_ref.syntax().parent() + && let Some(field_expr) = ast::FieldExpr::cast(parent) + && let Some(base) = field_expr.base() + && let Some(schema_name_ref) = ast::NameRef::cast(base.syntax().clone()) + { + Some(Schema(Name::from_node(&schema_name_ref))) + } else { + None + }; + + if schema.is_none() + && let Some(cte_ptr) = resolve_cte_table(name_ref, &table_name) + { + return Some(cte_ptr); + } + + let position = name_ref.syntax().text_range().start(); + resolve_table(binder, &table_name, &schema, position) + } } } @@ -211,6 +239,7 @@ fn classify_name_ref_context(name_ref: &ast::NameRef) -> Option let mut in_column_list = false; let mut in_where_clause = false; let mut in_from_clause = false; + let mut in_set_clause = false; // TODO: can we combine this if and the one that follows? if let Some(parent) = name_ref.syntax().parent() @@ -368,12 +397,27 @@ fn classify_name_ref_context(name_ref: &ast::NameRef) -> Option if ast::WhereClause::can_cast(ancestor.kind()) { in_where_clause = true; } + if ast::SetClause::can_cast(ancestor.kind()) { + in_set_clause = true; + } if ast::Delete::can_cast(ancestor.kind()) { if in_where_clause { return Some(NameRefContext::DeleteWhereColumn); } return Some(NameRefContext::DeleteTable); } + if ast::Update::can_cast(ancestor.kind()) { + if in_where_clause { + return Some(NameRefContext::UpdateWhereColumn); + } + if in_set_clause { + return Some(NameRefContext::UpdateSetColumn); + } + if in_from_clause { + return Some(NameRefContext::UpdateFromTable); + } + return Some(NameRefContext::UpdateTable); + } } None @@ -930,6 +974,70 @@ fn resolve_delete_where_column(binder: &Binder, name_ref: &ast::NameRef) -> Opti None } +fn resolve_update_where_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { + let column_name = Name::from_node(name_ref); + + let update = name_ref.syntax().ancestors().find_map(ast::Update::cast)?; + let relation_name = update.relation_name()?; + let path = relation_name.path()?; + + let table_name = extract_table_name(&path)?; + let schema = extract_schema_name(&path); + let position = name_ref.syntax().text_range().start(); + + let table_ptr = resolve_table(binder, &table_name, &schema, position)?; + + let root = &name_ref.syntax().ancestors().last()?; + let table_name_node = table_ptr.to_node(root); + + let create_table = table_name_node + .ancestors() + .find_map(ast::CreateTable::cast)?; + + for arg in create_table.table_arg_list()?.args() { + if let ast::TableArg::Column(column) = arg + && let Some(col_name) = column.name() + && Name::from_node(&col_name) == column_name + { + return Some(SyntaxNodePtr::new(col_name.syntax())); + } + } + + None +} + +fn resolve_update_set_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { + let column_name = Name::from_node(name_ref); + + let update = name_ref.syntax().ancestors().find_map(ast::Update::cast)?; + let relation_name = update.relation_name()?; + let path = relation_name.path()?; + + let table_name = extract_table_name(&path)?; + let schema = extract_schema_name(&path); + let position = name_ref.syntax().text_range().start(); + + let table_ptr = resolve_table(binder, &table_name, &schema, position)?; + + let root = &name_ref.syntax().ancestors().last()?; + let table_name_node = table_ptr.to_node(root); + + let create_table = table_name_node + .ancestors() + .find_map(ast::CreateTable::cast)?; + + for arg in create_table.table_arg_list()?.args() { + if let ast::TableArg::Column(column) = arg + && let Some(col_name) = column.name() + && Name::from_node(&col_name) == column_name + { + return Some(SyntaxNodePtr::new(col_name.syntax())); + } + } + + None +} + fn resolve_fn_call_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { let column_name = Name::from_node(name_ref);