diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index c2c01d2c7815..e6afb2da8db8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -723,8 +723,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio // last replacement string, we don't want to convert a UTF8String => java.langString every time. @transient private var lastReplacement: String = _ @transient private var lastReplacementInUTF8: UTF8String = _ - // result buffer write by Matcher - @transient private lazy val result: JStringBuilder = new JStringBuilder final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_REPLACE) override def nullSafeEval(s: Any, p: Any, r: Any, i: Any): Any = { @@ -738,26 +736,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio lastReplacementInUTF8 = r.asInstanceOf[UTF8String].clone() lastReplacement = lastReplacementInUTF8.toString } - val source = s.toString() - val position = i.asInstanceOf[Int] - 1 - if (position == 0 || position < source.length) { - val m = pattern.matcher(source) - m.region(position, source.length) - result.delete(0, result.length()) - while (m.find) { - try { - m.appendReplacement(result, lastReplacement) - } catch { - case NonFatal(e) => - throw QueryExecutionErrors.invalidRegexpReplaceError(s.toString, - p.toString, r.toString, i.asInstanceOf[Int], e) - } - } - m.appendTail(result) - UTF8String.fromString(result.toString) - } else { - s - } + RegExpUtils.replace(pattern, s.toString, lastReplacement, i.asInstanceOf[Int]) } override def dataType: DataType = subject.dataType @@ -768,14 +747,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio 
override def prettyName: String = "regexp_replace" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val termResult = ctx.freshName("termResult") - - val classNameStringBuilder = classOf[JStringBuilder].getCanonicalName - - val matcher = ctx.freshName("matcher") - val source = ctx.freshName("source") - val position = ctx.freshName("position") - val termLastReplacement = ctx.addMutableState("String", "lastReplacement") val termLastReplacementInUTF8 = ctx.addMutableState("UTF8String", "lastReplacementInUTF8") @@ -784,41 +755,22 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio } else { "" } + val regExpUtils = RegExpUtils.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (subject, regexp, rep, pos) => { - s""" - ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName, collationId)} - if (!$rep.equals($termLastReplacementInUTF8)) { - // replacement string changed - $termLastReplacementInUTF8 = $rep.clone(); - $termLastReplacement = $termLastReplacementInUTF8.toString(); - } - String $source = $subject.toString(); - int $position = $pos - 1; - if ($position == 0 || $position < $source.length()) { - $classNameStringBuilder $termResult = new $classNameStringBuilder(); - $matcher.region($position, $source.length()); - - while ($matcher.find()) { - try { - $matcher.appendReplacement($termResult, $termLastReplacement); - } catch (Throwable e) { - if (scala.util.control.NonFatal.apply(e)) { - throw QueryExecutionErrors.invalidRegexpReplaceError($source, $regexp.toString(), - $rep.toString(), $pos, e); - } else { - throw e; - } - } + val (patternCode, termPattern) = + RegExpUtils.initLastPatternCode(ctx, regexp, prettyName, collationId) + s""" + $patternCode + if (!$rep.equals($termLastReplacementInUTF8)) { + // replacement string changed + $termLastReplacementInUTF8 = $rep.clone(); + $termLastReplacement = $termLastReplacementInUTF8.toString(); } - 
$matcher.appendTail($termResult); - ${ev.value} = UTF8String.fromString($termResult.toString()); - $termResult = null; - } else { - ${ev.value} = $subject; - } - $setEvNotNull - """ + ${ev.value} = $regExpUtils.replace( + $termPattern, $subject.toString(), $termLastReplacement, $pos); + $setEvNotNull + """ }) } @@ -1242,27 +1194,43 @@ case class RegExpInStr(subject: Expression, regexp: Expression, idx: Expression) } object RegExpUtils { - def initLastMatcherCode( + // Emits the regex-pattern caching block (recompile only when the regexp value changes) and + // returns (code, patternTermName). The caller can build a Matcher from the returned term, or + // pass it to `replace`. + def initLastPatternCode( ctx: CodegenContext, - subject: String, regexp: String, - matcher: String, prettyName: String, - collationId: Int): String = { + collationId: Int): (String, String) = { val classNamePattern = classOf[Pattern].getCanonicalName val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex") val termPattern = ctx.addMutableState(classNamePattern, "pattern") val collationRegexFlags = CollationSupport.collationAwareRegexFlags(collationId) val utils = classOf[ExpressionImplUtils].getName + val code = + s""" + |if (!$regexp.equals($termLastRegex)) { + | // regex value changed + | UTF8String r = $regexp.clone(); + | $termPattern = + | $utils.compileRegexPattern(r.toString(), $collationRegexFlags, "$prettyName"); + | $termLastRegex = r; + |} + |""".stripMargin + (code, termPattern) + } + + def initLastMatcherCode( + ctx: CodegenContext, + subject: String, + regexp: String, + matcher: String, + prettyName: String, + collationId: Int): String = { + val (patternCode, termPattern) = initLastPatternCode(ctx, regexp, prettyName, collationId) s""" - |if (!$regexp.equals($termLastRegex)) { - | // regex value changed - | UTF8String r = $regexp.clone(); - | $termPattern = - | $utils.compileRegexPattern(r.toString(), $collationRegexFlags, "$prettyName"); - | $termLastRegex = r; - |} 
+ |$patternCode |java.util.regex.Matcher $matcher = $termPattern.matcher($subject.toString()); |""".stripMargin } @@ -1274,4 +1242,40 @@ object RegExpUtils { r.toString, CollationSupport.collationAwareRegexFlags(collationId), prettyName) (pattern, r) } + + /** + * Runs the regexp_replace loop shared by RegExpReplace's eval and codegen, so the generated + * Java is a single call rather than an inline matcher build + match/replace loop + error + * construction. The matcher is built here from `pattern`. `source` is returned (as a new + * UTF8String) when the 1-based start position `pos` is out of range. On a failed replacement, + * `pos` and `pattern.pattern()` (the original regex string) are reported in the error message. + */ + def replace( + pattern: Pattern, + source: String, + replacement: String, + pos: Int): UTF8String = { + val position = pos - 1 + if (position == 0 || position < source.length) { + val matcher = pattern.matcher(source) + matcher.region(position, source.length) + val result = new JStringBuilder + while (matcher.find()) { + try { + matcher.appendReplacement(result, replacement) + } catch { + case NonFatal(e) => + // pattern.pattern() is the original regexp string: the pattern is compiled from the + // raw regexp without escaping (see ExpressionImplUtils.compileRegexPattern), so it + // round-trips exactly into the error message. 
+ throw QueryExecutionErrors.invalidRegexpReplaceError( + source, pattern.pattern(), replacement, pos, e) + } + } + matcher.appendTail(result) + UTF8String.fromString(result.toString) + } else { + UTF8String.fromString(source) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 9b5a8f2fce06..19c8c8f341f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -361,6 +361,17 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val nonNullExpr = RegExpReplace(Literal("100-200"), Literal("(\\d+)"), Literal("num")) checkEvaluation(nonNullExpr, "num-num", row1) + // A replacement that references a non-existent group makes the replace fail. + checkErrorInExpression[SparkRuntimeException]( + expr, + create_row("100-200", "(\\d+)", "$2"), + "INVALID_REGEXP_REPLACE", + Map( + "source" -> "100-200", + "pattern" -> "(\\d+)", + "replacement" -> "$2", + "position" -> "1")) + // Test escaping of arguments GenerateUnsafeProjection.generate( RegExpReplace(Literal("\"quote"), Literal("\"quote"), Literal("\"quote")) :: Nil)