@@ -11,7 +11,7 @@ use sqlparser::ast::helpers::attached_token::AttachedToken;
1111use sqlparser:: ast:: {
1212 CastKind , DataType , Expr , Function , FunctionArg , FunctionArgExpr , FunctionArgumentList ,
1313 FunctionArguments , Ident , ObjectName , ObjectNamePart , SelectFlavor , SelectItem , Set , SetExpr ,
14- Spanned , Statement , Value , ValueWithSpan ,
14+ Spanned , Statement , Value , ValueWithSpan , VisitMut , VisitorMut ,
1515} ;
1616use sqlparser:: dialect:: {
1717 Dialect , DuckDbDialect , GenericDialect , MsSqlDialect , MySqlDialect , OracleDialect ,
@@ -22,6 +22,7 @@ use sqlparser::tokenizer::Token::{self, EOF, SemiColon};
2222use sqlparser:: tokenizer:: { Location , Span , TokenWithSpan , Tokenizer } ;
2323use sqlx:: any:: AnyKind ;
2424use std:: fmt:: Write ;
25+ use std:: ops:: ControlFlow ;
2526use std:: path:: { Path , PathBuf } ;
2627use std:: str:: FromStr ;
2728
@@ -234,6 +235,7 @@ fn parse_single_statement(
234235 return Some ( ParsedStatement :: Error ( err) ) ;
235236 }
236237 let json_columns = extract_json_columns ( & stmt, dbms) ;
238+ escape_national_string_literals_for_display ( & mut stmt) ;
237239 let query = format ! (
238240 "{stmt}{semicolon}" ,
239241 semicolon = if semicolon { ";" } else { "" }
@@ -250,6 +252,24 @@ fn parse_single_statement(
250252 Some ( ParsedStatement :: StmtWithParams ( stmt_with_params) )
251253}
252254
255+ fn escape_national_string_literals_for_display ( stmt : & mut Statement ) {
256+ struct NationalStringEscaper ;
257+
258+ impl VisitorMut for NationalStringEscaper {
259+ type Break = std:: convert:: Infallible ;
260+
261+ fn pre_visit_value ( & mut self , value : & mut ValueWithSpan ) -> ControlFlow < Self :: Break > {
262+ if let Value :: NationalStringLiteral ( s) = & mut value. value {
263+ // sqlparser 0.62 does not escape NationalStringLiteral when formatting the AST.
264+ * s = s. replace ( '\'' , "''" ) ;
265+ }
266+ ControlFlow :: Continue ( ( ) )
267+ }
268+ }
269+
270+ let _ = stmt. visit ( & mut NationalStringEscaper ) ;
271+ }
272+
253273fn extract_query_start ( stmt : & impl Spanned ) -> SourceSpan {
254274 let location = stmt. span ( ) ;
255275 SourceSpan {
@@ -1020,6 +1040,19 @@ mod test {
10201040 assert_eq ! ( parameters, [ StmtParam :: PostOrGet ( "1" . to_string( ) ) , ] ) ;
10211041 }
10221042
1043+ #[ test]
1044+ fn test_mssql_national_string_literal_rewrite ( ) {
1045+ let sql = r#"select N'Tu geres '';'' et ''"'' ?' as msg"# ;
1046+ let db_info = create_test_db_info ( SupportedDatabase :: Mssql ) ;
1047+ let mut parsed = parse_sql ( & db_info, & MsSqlDialect { } , sql) . unwrap ( ) ;
1048+ match parsed. next ( ) . expect ( "expected one statement" ) {
1049+ ParsedStatement :: StmtWithParams ( stmt) => {
1050+ assert_eq ! ( stmt. query, r#"SELECT N'Tu geres '';'' et ''"'' ?' AS msg"# ) ;
1051+ }
1052+ other => panic ! ( "expected a database statement: {other:?}" ) ,
1053+ }
1054+ }
1055+
10231056 #[ test]
10241057 fn test_static_extract ( ) {
10251058 use SimpleSelectValue :: Static ;
0 commit comments