From d5442ae32c5b6bd6d78e8cd426f14c481a45306b Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Thu, 4 Jun 2026 17:04:58 +0800 Subject: [PATCH] [SPARK-57258][SQL] Reduce regexp_extract/regexp_extract_all generated code size via shared extract helpers --- .../expressions/regexpExpressions.scala | 98 +++++++------------ 1 file changed, 37 insertions(+), 61 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index c2c01d2c78159..6ddb37ca6e36f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.lang.{StringBuilder => JStringBuilder} import java.util.Locale -import java.util.regex.{Matcher, MatchResult, Pattern, PatternSyntaxException} +import java.util.regex.{Matcher, Pattern, PatternSyntaxException} import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ @@ -844,6 +844,37 @@ object RegExpExtractBase { prettyName, groupCount, groupIndex) } } + + // Returns group `idx` of the first match, or the empty string when nothing matches. Shared by + // RegExpExtract's eval and codegen so the generated Java is a single call, not an inline block. + def extract(matcher: Matcher, idx: Int, prettyName: String): UTF8String = { + if (matcher.find()) { + val mr = matcher.toMatchResult + checkGroupIndex(prettyName, mr.groupCount, idx) + val group = mr.group(idx) + // Pattern matched but group `idx` is null (an optional group): use the empty string + if (group == null) UTF8String.EMPTY_UTF8 else UTF8String.fromString(group) + } else { + UTF8String.EMPTY_UTF8 + } + } + + // Returns group `idx` of every match, in order; shared by RegExpExtractAll's eval and codegen. 
+ def extractAll(matcher: Matcher, idx: Int, prettyName: String): GenericArrayData = { + val matchResults = new ArrayBuffer[UTF8String]() + while (matcher.find()) { + val mr = matcher.toMatchResult + checkGroupIndex(prettyName, mr.groupCount, idx) + val group = mr.group(idx) + // Pattern matched but group `idx` is null (an optional group): store the empty string + if (group == null) { + matchResults += UTF8String.EMPTY_UTF8 + } else { + matchResults += UTF8String.fromString(group) + } + } + new GenericArrayData(matchResults.toArray.asInstanceOf[Array[Any]]) + } } abstract class RegExpExtractBase @@ -924,20 +955,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio def this(s: Expression, r: Expression) = this(s, r, Literal(1)) override def nullSafeEval(s: Any, p: Any, r: Any): Any = { - val m = getLastMatcher(s, p) - if (m.find) { - val mr: MatchResult = m.toMatchResult - val index = r.asInstanceOf[Int] - RegExpExtractBase.checkGroupIndex(prettyName, mr.groupCount, index) - val group = mr.group(index) - if (group == null) { // Pattern matched, but it's an optional group - UTF8String.EMPTY_UTF8 - } else { - UTF8String.fromString(group) - } - } else { - UTF8String.EMPTY_UTF8 - } + RegExpExtractBase.extract(getLastMatcher(s, p), r.asInstanceOf[Int], prettyName) } override def dataType: DataType = subject.dataType @@ -946,7 +964,6 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val classNameRegExpExtractBase = classOf[RegExpExtractBase].getCanonicalName val matcher = ctx.freshName("matcher") - val matchResult = ctx.freshName("matchResult") val setEvNotNull = if (nullable) { s"${ev.isNull} = false;" } else { @@ -956,19 +973,8 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { s""" ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName, collationId)} - if 
($matcher.find()) { - java.util.regex.MatchResult $matchResult = $matcher.toMatchResult(); - $classNameRegExpExtractBase.checkGroupIndex("$prettyName", $matchResult.groupCount(), $idx); - if ($matchResult.group($idx) == null) { - ${ev.value} = UTF8String.EMPTY_UTF8; - } else { - ${ev.value} = UTF8String.fromString($matchResult.group($idx)); - } - $setEvNotNull - } else { - ${ev.value} = UTF8String.EMPTY_UTF8; - $setEvNotNull - }""" + ${ev.value} = $classNameRegExpExtractBase.extract($matcher, $idx, "$prettyName"); + $setEvNotNull""" }) } @@ -1022,21 +1028,7 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres def this(s: Expression, r: Expression) = this(s, r, Literal(1)) override def nullSafeEval(s: Any, p: Any, r: Any): Any = { - val m = getLastMatcher(s, p) - val matchResults = new ArrayBuffer[UTF8String]() - while (m.find) { - val mr: MatchResult = m.toMatchResult - val index = r.asInstanceOf[Int] - RegExpExtractBase.checkGroupIndex(prettyName, mr.groupCount, index) - val group = mr.group(index) - if (group == null) { // Pattern matched, but it's an optional group - matchResults += UTF8String.EMPTY_UTF8 - } else { - matchResults += UTF8String.fromString(group) - } - } - - new GenericArrayData(matchResults.toArray.asInstanceOf[Array[Any]]) + RegExpExtractBase.extractAll(getLastMatcher(s, p), r.asInstanceOf[Int], prettyName) } override def dataType: DataType = ArrayType(subject.dataType) @@ -1044,10 +1036,7 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val classNameRegExpExtractBase = classOf[RegExpExtractBase].getCanonicalName - val arrayClass = classOf[GenericArrayData].getName val matcher = ctx.freshName("matcher") - val matchResult = ctx.freshName("matchResult") - val matchResults = ctx.freshName("matchResults") val setEvNotNull = if (nullable) { s"${ev.isNull} = false;" } else { @@ -1057,21 +1046,8 @@ case 
class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres s""" | ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName, collationId)} - | java.util.ArrayList $matchResults = new java.util.ArrayList(); - | while ($matcher.find()) { - | java.util.regex.MatchResult $matchResult = $matcher.toMatchResult(); - | $classNameRegExpExtractBase.checkGroupIndex( - | "$prettyName", - | $matchResult.groupCount(), - | $idx); - | if ($matchResult.group($idx) == null) { - | $matchResults.add(UTF8String.EMPTY_UTF8); - | } else { - | $matchResults.add(UTF8String.fromString($matchResult.group($idx))); - | } - | } | ${ev.value} = - | new $arrayClass($matchResults.toArray(new UTF8String[$matchResults.size()])); + | $classNameRegExpExtractBase.extractAll($matcher, $idx, "$prettyName"); | $setEvNotNull """ })