Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions source/compiler/qsc_frontend/src/typeck/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1775,6 +1775,32 @@ fn adj_requires_unit_return() {
);
}

#[test]
fn should_have_been_array() {
check(
indoc! {"
namespace A {
operation Foo(qs: Qubit[]) : Unit is Adj {
Foo(qs[0])
}
}
"},
"",
&expect![[r##"
#6 31-44 "(qs: Qubit[])" : Qubit[]
#7 32-43 "qs: Qubit[]" : Qubit[]
#17 59-85 "{\n Foo(qs[0])\n }" : Unit
#19 69-79 "Foo(qs[0])" : Unit
#20 69-72 "Foo" : (Qubit[] => Unit is Adj)
#23 72-79 "(qs[0])" : Qubit
#24 73-78 "qs[0]" : Qubit
#25 73-75 "qs" : Qubit[]
#28 76-77 "0" : Int
Error(Type(Error(TyMismatch("Qubit[]", "Qubit", Span { lo: 69, hi: 79 }))))
"##]],
);
}

#[test]
fn ctl_requires_unit_return() {
check(
Expand Down
7 changes: 7 additions & 0 deletions source/language_service/src/code_action.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

mod wrap_in_array;
mod wrapper_refactor;

use miette::Diagnostic;
Expand Down Expand Up @@ -33,6 +34,12 @@ pub(crate) fn get_code_actions(
span,
position_encoding,
));
actions.extend(wrap_in_array::wrap_in_array_fixes(
compilation,
source_name,
span,
position_encoding,
));
actions
}

Expand Down
139 changes: 139 additions & 0 deletions source/language_service/src/code_action/wrap_in_array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

// Code action: "Convert to single-element array"
// Detects when a value is passed where an array of that type is expected,
// and offers to wrap it in `[...]`.

#[cfg(test)]
mod tests;

use qsc::{
Span,
ast::{self, Expr, ExprKind, NodeId},
display::Lookup,
hir::ty::Ty,
line_column::{Encoding, Range},
};

use crate::{
compilation::Compilation,
protocol::{CodeAction, CodeActionKind, TextEdit, WorkspaceEdit},
};

pub(crate) fn wrap_in_array_fixes(
compilation: &Compilation,
source_name: &str,
span: Span,
encoding: Encoding,
) -> Vec<CodeAction> {
let mut code_actions = Vec::new();

let unit = compilation.user_unit();
let package = &unit.ast.package;
let source = unit
.sources
.find_by_name(source_name)
.expect("source should exist");

// Find all call expressions overlapping the requested span.
let mut finder = CallFinder {
target_span: span,
found: Vec::new(),
};
ast::visit::Visitor::visit_package(&mut finder, package);

for (callee_id, args) in finder.found {
// Look up the callee type to get the expected parameter types.
let Some(callee_ty) = compilation.get_ty(callee_id) else {
continue;
};
let Ty::Arrow(arrow) = callee_ty else {
continue;
};

let expected_input = arrow.input.borrow();
let param_tys: Vec<&Ty> = match &*expected_input {
Ty::Tuple(tys) => tys.iter().collect(),
other => vec![other],
};

if args.len() != param_tys.len() {
continue;
}

// Match arguments against parameters.
for (arg, param_ty) in args.iter().zip(param_tys.iter()) {
let Some(arg_ty) = compilation.get_ty(arg.id) else {
continue;
};
// Check if expected is Array(T) and actual is T.
if let Ty::Array(item_ty) = param_ty
&& item_ty.as_ref() == arg_ty
{
// Generate the fix: wrap arg in [...]
let lo = (arg.span.lo - source.offset) as usize;
let hi = (arg.span.hi - source.offset) as usize;
let arg_text = &source.contents[lo..hi];
let new_text = format!("[{arg_text}]");
let range =
Range::from_span(encoding, &source.contents, &(arg.span - source.offset));
code_actions.push(CodeAction {
title: "Convert to single-element array".to_string(),
edit: Some(WorkspaceEdit {
changes: vec![(
source_name.to_string(),
vec![TextEdit { new_text, range }],
)],
}),
kind: Some(CodeActionKind::QuickFix),
is_preferred: None,
});
}
}
}

code_actions
}

/// AST visitor that finds Call expressions overlapping the target span and extracts
/// the callee node id and individual argument expressions.
struct CallFinder<'a> {
target_span: Span,
found: Vec<(NodeId, Vec<&'a Expr>)>,
}

impl<'a> ast::visit::Visitor<'a> for CallFinder<'a> {
fn visit_namespace(&mut self, namespace: &'a ast::Namespace) {
if self.target_span.intersection(&namespace.span).is_some() {
ast::visit::walk_namespace(self, namespace);
}
}

fn visit_stmt(&mut self, stmt: &'a ast::Stmt) {
if self.target_span.intersection(&stmt.span).is_some() {
ast::visit::walk_stmt(self, stmt);
}
}

fn visit_expr(&mut self, expr: &'a Expr) {
if self.target_span.intersection(&expr.span).is_some() {
if let ExprKind::Call(callee, arg) = &*expr.kind {
let args = extract_args(arg);
self.found.push((callee.id, args));
}
ast::visit::walk_expr(self, expr);
}
}
}

/// Given a call argument expression, extract the individual argument expressions.
/// If the argument is a tuple, returns each element. If it's a paren-wrapped
/// single expression, returns the inner expression. Otherwise returns the expression itself.
fn extract_args(arg: &Expr) -> Vec<&Expr> {
match &*arg.kind {
ExprKind::Tuple(items) => items.iter().map(AsRef::as_ref).collect(),
ExprKind::Paren(inner) => vec![inner.as_ref()],
_ => vec![arg],
}
}
94 changes: 94 additions & 0 deletions source/language_service/src/code_action/wrap_in_array/tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

use crate::{code_action, test_utils::compile_project_with_markers_no_cursor};
use qsc::line_column::{Encoding, Position, Range};

fn get_wrap_in_array_actions(source: &str) -> Vec<crate::protocol::CodeAction> {
let (compilation, _targets) =
compile_project_with_markers_no_cursor(&[("<source>", source)], false);
let newline_count = u32::try_from(source.matches('\n').count()).expect("count fits");
let end = if newline_count == 0 {
Position {
line: 0,
column: u32::try_from(source.len()).expect("len fits"),
}
} else {
Position {
line: newline_count,
column: 0,
}
};
let range = Range {
start: Position { line: 0, column: 0 },
end,
};
let actions = code_action::get_code_actions(&compilation, "<source>", range, Encoding::Utf8);
actions
.into_iter()
.filter(|a| a.title == "Convert to single-element array")
.collect()
}

#[test]
fn single_arg_qubit_to_qubit_array() {
let source = "namespace A {
operation Foo(qs: Qubit[]) : Unit is Adj {
use q = Qubit();
Foo(q);
}
}
";
let actions = get_wrap_in_array_actions(source);
assert_eq!(actions.len(), 1, "Expected 1 action, got: {actions:?}");
let action = &actions[0];
let edit = action.edit.as_ref().expect("expected edit");
let (_, text_edits) = &edit.changes[0];
assert_eq!(text_edits.len(), 1);
assert_eq!(text_edits[0].new_text, "[q]");
}

#[test]
fn multi_arg_second_param_is_array() {
let source = "namespace A {
operation Bar(x: Int, qs: Qubit[]) : Unit {
use q = Qubit();
Bar(1, q);
}
}
";
let actions = get_wrap_in_array_actions(source);
assert_eq!(actions.len(), 1, "Expected 1 action, got: {actions:?}");
let action = &actions[0];
let edit = action.edit.as_ref().expect("expected edit");
let (_, text_edits) = &edit.changes[0];
assert_eq!(text_edits.len(), 1);
assert_eq!(text_edits[0].new_text, "[q]");
}

#[test]
fn no_action_when_types_already_match() {
let source = "namespace A {
operation Foo(qs: Qubit[]) : Unit is Adj {
use q = Qubit();
Foo([q]);
}
}
";
let actions = get_wrap_in_array_actions(source);
assert!(actions.is_empty(), "Expected no actions, got: {actions:?}");
}

#[test]
fn no_action_for_unrelated_mismatch() {
// Int passed where String expected - should NOT offer wrap in array.
let source = "namespace A {
function Foo(s: String) : Unit {}
function Bar() : Unit {
Foo(42);
}
}
";
let actions = get_wrap_in_array_actions(source);
assert!(actions.is_empty(), "Expected no actions, got: {actions:?}");
}
Loading