Skip to content

Commit c355cd2

Browse files
committed
Preserve MSSQL national string quotes
1 parent 7ddcf79 commit c355cd2

2 files changed

Lines changed: 35 additions & 2 deletions

File tree

src/webserver/database/sql.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use sqlparser::ast::helpers::attached_token::AttachedToken;
1111
use sqlparser::ast::{
1212
CastKind, DataType, Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgumentList,
1313
FunctionArguments, Ident, ObjectName, ObjectNamePart, SelectFlavor, SelectItem, Set, SetExpr,
14-
Spanned, Statement, Value, ValueWithSpan,
14+
Spanned, Statement, Value, ValueWithSpan, VisitMut, VisitorMut,
1515
};
1616
use sqlparser::dialect::{
1717
Dialect, DuckDbDialect, GenericDialect, MsSqlDialect, MySqlDialect, OracleDialect,
@@ -22,6 +22,7 @@ use sqlparser::tokenizer::Token::{self, EOF, SemiColon};
2222
use sqlparser::tokenizer::{Location, Span, TokenWithSpan, Tokenizer};
2323
use sqlx::any::AnyKind;
2424
use std::fmt::Write;
25+
use std::ops::ControlFlow;
2526
use std::path::{Path, PathBuf};
2627
use std::str::FromStr;
2728

@@ -234,6 +235,7 @@ fn parse_single_statement(
234235
return Some(ParsedStatement::Error(err));
235236
}
236237
let json_columns = extract_json_columns(&stmt, dbms);
238+
escape_national_string_literals_for_display(&mut stmt);
237239
let query = format!(
238240
"{stmt}{semicolon}",
239241
semicolon = if semicolon { ";" } else { "" }
@@ -250,6 +252,24 @@ fn parse_single_statement(
250252
Some(ParsedStatement::StmtWithParams(stmt_with_params))
251253
}
252254

255+
fn escape_national_string_literals_for_display(stmt: &mut Statement) {
256+
struct NationalStringEscaper;
257+
258+
impl VisitorMut for NationalStringEscaper {
259+
type Break = std::convert::Infallible;
260+
261+
fn pre_visit_value(&mut self, value: &mut ValueWithSpan) -> ControlFlow<Self::Break> {
262+
if let Value::NationalStringLiteral(s) = &mut value.value {
263+
// sqlparser 0.62 does not escape NationalStringLiteral when formatting the AST.
264+
*s = s.replace('\'', "''");
265+
}
266+
ControlFlow::Continue(())
267+
}
268+
}
269+
270+
let _ = stmt.visit(&mut NationalStringEscaper);
271+
}
272+
253273
fn extract_query_start(stmt: &impl Spanned) -> SourceSpan {
254274
let location = stmt.span();
255275
SourceSpan {
@@ -1020,6 +1040,19 @@ mod test {
10201040
assert_eq!(parameters, [StmtParam::PostOrGet("1".to_string()),]);
10211041
}
10221042

1043+
#[test]
1044+
fn test_mssql_national_string_literal_rewrite() {
1045+
let sql = r#"select N'Tu geres '';'' et ''"'' ?' as msg"#;
1046+
let db_info = create_test_db_info(SupportedDatabase::Mssql);
1047+
let mut parsed = parse_sql(&db_info, &MsSqlDialect {}, sql).unwrap();
1048+
match parsed.next().expect("expected one statement") {
1049+
ParsedStatement::StmtWithParams(stmt) => {
1050+
assert_eq!(stmt.query, r#"SELECT N'Tu geres '';'' et ''"'' ?' AS msg"#);
1051+
}
1052+
other => panic!("expected a database statement: {other:?}"),
1053+
}
1054+
}
1055+
10231056
#[test]
10241057
fn test_static_extract() {
10251058
use SimpleSelectValue::Static;

tests/data_formats/csv_data_mssql.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ select
88
union all
99
select
1010
1 as id,
11-
CONCAT(N'Tu gères ', NCHAR(39), NCHAR(59), NCHAR(39), N' et ', NCHAR(39), NCHAR(34), NCHAR(39), N' ?') as msg;
11+
N'Tu gères '';'' et ''"'' ?' as msg;

0 commit comments

Comments
 (0)