From 89dc7ba2aaec0c7b16a4b9386b4f83c4166fd07c Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Thu, 4 Jun 2026 13:25:14 +0800 Subject: [PATCH 1/4] [SPARK-57255][SQL] Simplify RegExpReplace codegen by extracting the match/replace loop into a shared helper RegExpReplace inlined the match/replace loop (matcher.region, the find/appendReplacement/appendTail loop, and the invalidRegexpReplaceError construction in a try/catch) in both nullSafeEval and doGenCode, so every generated whole-stage class carried that block plus the error-construction constant-pool entries. Move the loop into a shared `RegExpUtils.replace(matcher, subject, replacement, regexp, rep, pos)` that both eval and codegen call; doGenCode now emits a single call. The eval-only cached `result` StringBuilder field is dropped (the helper allocates its own; codegen always allocated per call anyway). For a single regexp_replace whole-stage stage, this drops maxMethodCodeSize from 551 to 435 (-21%) and maxConstantPoolSize from 277 to 244 (-12%), with the loop body compiled once instead of per stage. Behavior is unchanged. Adds a throwing-path test (a replacement referencing a non-existent group -> INVALID_REGEXP_REPLACE), which the tree previously did not cover. --- .../expressions/regexpExpressions.scala | 93 ++++++++----------- .../expressions/RegexpExpressionsSuite.scala | 11 +++ 2 files changed, 52 insertions(+), 52 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index c2c01d2c78159..729200a5ed829 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -723,8 +723,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio // last replacement string, we don't want to convert a UTF8String => java.langString every time. @transient private var lastReplacement: String = _ @transient private var lastReplacementInUTF8: UTF8String = _ - // result buffer write by Matcher - @transient private lazy val result: JStringBuilder = new JStringBuilder final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_REPLACE) override def nullSafeEval(s: Any, p: Any, r: Any, i: Any): Any = { @@ -738,26 +736,10 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio lastReplacementInUTF8 = r.asInstanceOf[UTF8String].clone() lastReplacement = lastReplacementInUTF8.toString } - val source = s.toString() - val position = i.asInstanceOf[Int] - 1 - if (position == 0 || position < source.length) { - val m = pattern.matcher(source) - m.region(position, source.length) - result.delete(0, result.length()) - while (m.find) { - try { - m.appendReplacement(result, lastReplacement) - } catch { - case NonFatal(e) => - throw QueryExecutionErrors.invalidRegexpReplaceError(s.toString, - p.toString, r.toString, i.asInstanceOf[Int], e) - } - } - m.appendTail(result) - UTF8String.fromString(result.toString) - } else { - s - } + val subject = s.asInstanceOf[UTF8String] + val m = pattern.matcher(subject.toString) + RegExpUtils.replace(m, subject, lastReplacement, + p.asInstanceOf[UTF8String], r.asInstanceOf[UTF8String], i.asInstanceOf[Int]) } override def dataType: DataType = subject.dataType @@ -768,13 +750,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio override def prettyName: String = "regexp_replace" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val termResult = ctx.freshName("termResult") - - val classNameStringBuilder = classOf[JStringBuilder].getCanonicalName - val matcher = ctx.freshName("matcher") - val source = ctx.freshName("source") - val position = ctx.freshName("position") val termLastReplacement = ctx.addMutableState("String", "lastReplacement") val termLastReplacementInUTF8 = ctx.addMutableState("UTF8String", "lastReplacementInUTF8") @@ -784,6 +760,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio } else { "" } + val regExpUtils = RegExpUtils.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (subject, regexp, rep, pos) => { s""" @@ -793,30 +770,8 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio $termLastReplacementInUTF8 = $rep.clone(); $termLastReplacement = $termLastReplacementInUTF8.toString(); } - String $source = $subject.toString(); - int $position = $pos - 1; - if ($position == 0 || $position < $source.length()) { - $classNameStringBuilder $termResult = new $classNameStringBuilder(); - $matcher.region($position, $source.length()); - - while ($matcher.find()) { - try { - $matcher.appendReplacement($termResult, $termLastReplacement); - } catch (Throwable e) { - if (scala.util.control.NonFatal.apply(e)) { - throw QueryExecutionErrors.invalidRegexpReplaceError($source, $regexp.toString(), - $rep.toString(), $pos, e); - } else { - throw e; - } - } - } - $matcher.appendTail($termResult); - ${ev.value} = UTF8String.fromString($termResult.toString()); - $termResult = null; - } else { - ${ev.value} = $subject; - } + ${ev.value} = $regExpUtils.replace( + $matcher, $subject, $termLastReplacement, $regexp, $rep, $pos); $setEvNotNull """ }) @@ -1274,4 +1229,38 @@ object RegExpUtils { r.toString, CollationSupport.collationAwareRegexFlags(collationId), prettyName) (pattern, r) } + + /** + * Runs the regexp_replace loop shared by RegExpReplace's eval and codegen, so the generated + * Java is a single call rather than an inline match/replace loop plus the error construction. + * `matcher` must already wrap `subject.toString`. `regexp`/`rep`/`pos` are only used to build + * the error message on a failed replacement. + */ + def replace( + matcher: Matcher, + subject: UTF8String, + replacement: String, + regexp: UTF8String, + rep: UTF8String, + pos: Int): UTF8String = { + val source = subject.toString + val position = pos - 1 + if (position == 0 || position < source.length) { + matcher.region(position, source.length) + val result = new JStringBuilder + while (matcher.find()) { + try { + matcher.appendReplacement(result, replacement) + } catch { + case NonFatal(e) => + throw QueryExecutionErrors.invalidRegexpReplaceError( + source, regexp.toString, rep.toString, pos, e) + } + } + matcher.appendTail(result) + UTF8String.fromString(result.toString) + } else { + subject + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 9b5a8f2fce066..19c8c8f341f0a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -361,6 +361,17 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val nonNullExpr = RegExpReplace(Literal("100-200"), Literal("(\\d+)"), Literal("num")) checkEvaluation(nonNullExpr, "num-num", row1) + // A replacement that references a non-existent group makes the replace fail. + checkErrorInExpression[SparkRuntimeException]( + expr, + create_row("100-200", "(\\d+)", "$2"), + "INVALID_REGEXP_REPLACE", + Map( + "source" -> "100-200", + "pattern" -> "(\\d+)", + "replacement" -> "$2", + "position" -> "1")) + // Test escaping of arguments GenerateUnsafeProjection.generate( RegExpReplace(Literal("\"quote"), Literal("\"quote"), Literal("\"quote")) :: Nil) From a63dfedb5b917758f864127b442fc8c28398fb56 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Thu, 4 Jun 2026 14:33:27 +0800 Subject: [PATCH 2/4] [SPARK-57255][SQL][FOLLOWUP] Build the matcher inside RegExpUtils.replace so subject.toString is computed once The initial commit had RegExpUtils.replace take a pre-built Matcher, so the caller built the matcher from subject.toString and the helper recomputed subject.toString for the region/error - a duplicate decode per row in the eval path (the codegen path already decoded twice). Make replace take the cached Pattern and build the matcher itself, computing source = subject.toString once. To give codegen the cached pattern term without building a matcher, split initLastPatternCode out of initLastMatcherCode (the latter now composes the former + the matcher line); the other callers (RLike/RegExpExtract/RegExpExtractAll) are unchanged. This also removes the matcher build from every generated class. For a single regexp_replace whole-stage stage the numbers improve to maxMethodCodeSize 551 -> 423 (-23%) and maxConstantPoolSize 277 -> 236 (-15%). --- .../expressions/regexpExpressions.scala | 78 +++++++++++-------- 1 file changed, 47 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 729200a5ed829..3c162df9c0c87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -736,9 +736,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio lastReplacementInUTF8 = r.asInstanceOf[UTF8String].clone() lastReplacement = lastReplacementInUTF8.toString } - val subject = s.asInstanceOf[UTF8String] - val m = pattern.matcher(subject.toString) - RegExpUtils.replace(m, subject, lastReplacement, + RegExpUtils.replace(pattern, s.asInstanceOf[UTF8String], lastReplacement, p.asInstanceOf[UTF8String], r.asInstanceOf[UTF8String], i.asInstanceOf[Int]) } @@ -750,8 +748,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio override def prettyName: String = "regexp_replace" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val matcher = ctx.freshName("matcher") - val termLastReplacement = ctx.addMutableState("String", "lastReplacement") val termLastReplacementInUTF8 = ctx.addMutableState("UTF8String", "lastReplacementInUTF8") @@ -763,17 +759,19 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio val regExpUtils = RegExpUtils.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (subject, regexp, rep, pos) => { - s""" - ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName, collationId)} - if (!$rep.equals($termLastReplacementInUTF8)) { - // replacement string changed - $termLastReplacementInUTF8 = $rep.clone(); - $termLastReplacement = $termLastReplacementInUTF8.toString(); - } - ${ev.value} = $regExpUtils.replace( - $matcher, $subject, $termLastReplacement, $regexp, $rep, $pos); - $setEvNotNull - """ + val (patternCode, termPattern) = + RegExpUtils.initLastPatternCode(ctx, regexp, prettyName, collationId) + s""" + $patternCode + if (!$rep.equals($termLastReplacementInUTF8)) { + // replacement string changed + $termLastReplacementInUTF8 = $rep.clone(); + $termLastReplacement = $termLastReplacementInUTF8.toString(); + } + ${ev.value} = $regExpUtils.replace( + $termPattern, $subject, $termLastReplacement, $regexp, $rep, $pos); + $setEvNotNull + """ }) } @@ -1197,27 +1195,43 @@ case class RegExpInStr(subject: Expression, regexp: Expression, idx: Expression) } object RegExpUtils { - def initLastMatcherCode( + // Emits the regex-pattern caching block (recompile only when the regexp value changes) and + // returns (code, patternTermName). The caller can build a Matcher from the returned term, or + // pass it to `replace`. + def initLastPatternCode( ctx: CodegenContext, - subject: String, regexp: String, - matcher: String, prettyName: String, - collationId: Int): String = { + collationId: Int): (String, String) = { val classNamePattern = classOf[Pattern].getCanonicalName val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex") val termPattern = ctx.addMutableState(classNamePattern, "pattern") val collationRegexFlags = CollationSupport.collationAwareRegexFlags(collationId) val utils = classOf[ExpressionImplUtils].getName + val code = + s""" + |if (!$regexp.equals($termLastRegex)) { + | // regex value changed + | UTF8String r = $regexp.clone(); + | $termPattern = + | $utils.compileRegexPattern(r.toString(), $collationRegexFlags, "$prettyName"); + | $termLastRegex = r; + |} + |""".stripMargin + (code, termPattern) + } + + def initLastMatcherCode( + ctx: CodegenContext, + subject: String, + regexp: String, + matcher: String, + prettyName: String, + collationId: Int): String = { + val (patternCode, termPattern) = initLastPatternCode(ctx, regexp, prettyName, collationId) s""" - |if (!$regexp.equals($termLastRegex)) { - | // regex value changed - | UTF8String r = $regexp.clone(); - | $termPattern = - | $utils.compileRegexPattern(r.toString(), $collationRegexFlags, "$prettyName"); - | $termLastRegex = r; - |} + |$patternCode |java.util.regex.Matcher $matcher = $termPattern.matcher($subject.toString()); |""".stripMargin } @@ -1232,12 +1246,13 @@ object RegExpUtils { /** * Runs the regexp_replace loop shared by RegExpReplace's eval and codegen, so the generated - * Java is a single call rather than an inline match/replace loop plus the error construction. - * `matcher` must already wrap `subject.toString`. `regexp`/`rep`/`pos` are only used to build - * the error message on a failed replacement. + * Java is a single call rather than an inline matcher build + match/replace loop + error + * construction. The matcher is built here from `pattern` so `subject.toString` is computed once. + * `subject` is returned unchanged when the start position is out of range; `regexp`/`rep`/`pos` + * are only used to build the error message on a failed replacement. */ def replace( - matcher: Matcher, + pattern: Pattern, subject: UTF8String, replacement: String, regexp: UTF8String, @@ -1246,6 +1261,7 @@ object RegExpUtils { val source = subject.toString val position = pos - 1 if (position == 0 || position < source.length) { + val matcher = pattern.matcher(source) matcher.region(position, source.length) val result = new JStringBuilder while (matcher.find()) { From 8c24f32f73a72f67f85289a4592c5b4a016193e4 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 5 Jun 2026 10:36:39 +0800 Subject: [PATCH 3/4] [SPARK-57255][SQL][FOLLOWUP] Use Java String parameters in RegExpUtils.replace to drop UTF8String casts --- .../expressions/regexpExpressions.scala | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 3c162df9c0c87..9156f69ef013b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -736,8 +736,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio lastReplacementInUTF8 = r.asInstanceOf[UTF8String].clone() lastReplacement = lastReplacementInUTF8.toString } - RegExpUtils.replace(pattern, s.asInstanceOf[UTF8String], lastReplacement, - p.asInstanceOf[UTF8String], r.asInstanceOf[UTF8String], i.asInstanceOf[Int]) + RegExpUtils.replace(pattern, s.toString, lastReplacement, p.toString, i.asInstanceOf[Int]) } override def dataType: DataType = subject.dataType @@ -769,7 +768,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio $termLastReplacement = $termLastReplacementInUTF8.toString(); } ${ev.value} = $regExpUtils.replace( - $termPattern, $subject, $termLastReplacement, $regexp, $rep, $pos); + $termPattern, $subject.toString(), $termLastReplacement, $regexp.toString(), $pos); $setEvNotNull """ }) @@ -1247,18 +1246,16 @@ object RegExpUtils { /** * Runs the regexp_replace loop shared by RegExpReplace's eval and codegen, so the generated * Java is a single call rather than an inline matcher build + match/replace loop + error - * construction. The matcher is built here from `pattern` so `subject.toString` is computed once. - * `subject` is returned unchanged when the start position is out of range; `regexp`/`rep`/`pos` - * are only used to build the error message on a failed replacement. + * construction. The matcher is built here from `pattern`. `source` is returned (as a new + * UTF8String) when the start position is out of range; `regexp`/`pos` are only used to build + * the error message on a failed replacement. */ def replace( pattern: Pattern, - subject: UTF8String, + source: String, replacement: String, - regexp: UTF8String, - rep: UTF8String, + regexp: String, pos: Int): UTF8String = { - val source = subject.toString val position = pos - 1 if (position == 0 || position < source.length) { val matcher = pattern.matcher(source) @@ -1270,13 +1267,13 @@ object RegExpUtils { } catch { case NonFatal(e) => throw QueryExecutionErrors.invalidRegexpReplaceError( - source, regexp.toString, rep.toString, pos, e) + source, regexp, replacement, pos, e) } } matcher.appendTail(result) UTF8String.fromString(result.toString) } else { - subject + UTF8String.fromString(source) } } } From edcd5c622fa01e9437debb40034c1e14af8dc222 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 5 Jun 2026 11:09:38 +0800 Subject: [PATCH 4/4] [SPARK-57255][SQL][FOLLOWUP] Avoid eager per-row regexp.toString in RegExpUtils.replace by using pattern.pattern() for the error message --- .../catalyst/expressions/regexpExpressions.scala | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 9156f69ef013b..e6afb2da8db81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -736,7 +736,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio lastReplacementInUTF8 = r.asInstanceOf[UTF8String].clone() lastReplacement = lastReplacementInUTF8.toString } - RegExpUtils.replace(pattern, s.toString, lastReplacement, p.toString, i.asInstanceOf[Int]) + RegExpUtils.replace(pattern, s.toString, lastReplacement, i.asInstanceOf[Int]) } override def dataType: DataType = subject.dataType @@ -768,7 +768,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio $termLastReplacement = $termLastReplacementInUTF8.toString(); } ${ev.value} = $regExpUtils.replace( - $termPattern, $subject.toString(), $termLastReplacement, $regexp.toString(), $pos); + $termPattern, $subject.toString(), $termLastReplacement, $pos); $setEvNotNull """ }) @@ -1247,14 +1247,13 @@ object RegExpUtils { * Runs the regexp_replace loop shared by RegExpReplace's eval and codegen, so the generated * Java is a single call rather than an inline matcher build + match/replace loop + error * construction. The matcher is built here from `pattern`. `source` is returned (as a new - * UTF8String) when the start position is out of range; `regexp`/`pos` are only used to build - * the error message on a failed replacement. + * UTF8String) when the start position is out of range; `pos` and `pattern.pattern()` (the + * original regex string) are only used to build the error message on a failed replacement. */ def replace( pattern: Pattern, source: String, replacement: String, - regexp: String, pos: Int): UTF8String = { val position = pos - 1 if (position == 0 || position < source.length) { @@ -1266,8 +1265,11 @@ object RegExpUtils { matcher.appendReplacement(result, replacement) } catch { case NonFatal(e) => + // pattern.pattern() is the original regexp string: the pattern is compiled from the + // raw regexp without escaping (see ExpressionImplUtils.compileRegexPattern), so it + // round-trips exactly into the error message. throw QueryExecutionErrors.invalidRegexpReplaceError( - source, regexp, replacement, pos, e) + source, pattern.pattern(), replacement, pos, e) } } matcher.appendTail(result)