Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import java.lang.{StringBuilder => JStringBuilder}
import java.util.Locale
import java.util.regex.{Matcher, MatchResult, Pattern, PatternSyntaxException}
import java.util.regex.{Matcher, Pattern, PatternSyntaxException}

import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
Expand Down Expand Up @@ -844,6 +844,37 @@ object RegExpExtractBase {
prettyName, groupCount, groupIndex)
}
}

// Extracts group `idx` of the first match, shared by RegExpExtract's eval and codegen so the
// generated Java is a single call rather than an inline match/group block.
def extract(matcher: Matcher, idx: Int, prettyName: String): UTF8String = {
if (matcher.find()) {
val mr = matcher.toMatchResult
checkGroupIndex(prettyName, mr.groupCount, idx)
val group = mr.group(idx)
// Pattern matched, but it's an optional group
if (group == null) UTF8String.EMPTY_UTF8 else UTF8String.fromString(group)
} else {
UTF8String.EMPTY_UTF8
}
}

// Extracts group `idx` of every match, shared by RegExpExtractAll's eval and codegen.
def extractAll(matcher: Matcher, idx: Int, prettyName: String): GenericArrayData = {
val matchResults = new ArrayBuffer[UTF8String]()
while (matcher.find()) {
val mr = matcher.toMatchResult
checkGroupIndex(prettyName, mr.groupCount, idx)
val group = mr.group(idx)
// Pattern matched, but it's an optional group
if (group == null) {
matchResults += UTF8String.EMPTY_UTF8
} else {
matchResults += UTF8String.fromString(group)
}
}
new GenericArrayData(matchResults.toArray.asInstanceOf[Array[Any]])
}
}

abstract class RegExpExtractBase
Expand Down Expand Up @@ -924,20 +955,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
def this(s: Expression, r: Expression) = this(s, r, Literal(1))

override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
val m = getLastMatcher(s, p)
if (m.find) {
val mr: MatchResult = m.toMatchResult
val index = r.asInstanceOf[Int]
RegExpExtractBase.checkGroupIndex(prettyName, mr.groupCount, index)
val group = mr.group(index)
if (group == null) { // Pattern matched, but it's an optional group
UTF8String.EMPTY_UTF8
} else {
UTF8String.fromString(group)
}
} else {
UTF8String.EMPTY_UTF8
}
RegExpExtractBase.extract(getLastMatcher(s, p), r.asInstanceOf[Int], prettyName)
}

override def dataType: DataType = subject.dataType
Expand All @@ -946,7 +964,6 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val classNameRegExpExtractBase = classOf[RegExpExtractBase].getCanonicalName
val matcher = ctx.freshName("matcher")
val matchResult = ctx.freshName("matchResult")
val setEvNotNull = if (nullable) {
s"${ev.isNull} = false;"
} else {
Expand All @@ -956,19 +973,8 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
s"""
${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName, collationId)}
if ($matcher.find()) {
java.util.regex.MatchResult $matchResult = $matcher.toMatchResult();
$classNameRegExpExtractBase.checkGroupIndex("$prettyName", $matchResult.groupCount(), $idx);
if ($matchResult.group($idx) == null) {
${ev.value} = UTF8String.EMPTY_UTF8;
} else {
${ev.value} = UTF8String.fromString($matchResult.group($idx));
}
$setEvNotNull
} else {
${ev.value} = UTF8String.EMPTY_UTF8;
$setEvNotNull
}"""
${ev.value} = $classNameRegExpExtractBase.extract($matcher, $idx, "$prettyName");
$setEvNotNull"""
})
}

Expand Down Expand Up @@ -1022,32 +1028,15 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres
def this(s: Expression, r: Expression) = this(s, r, Literal(1))

override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
val m = getLastMatcher(s, p)
val matchResults = new ArrayBuffer[UTF8String]()
while (m.find) {
val mr: MatchResult = m.toMatchResult
val index = r.asInstanceOf[Int]
RegExpExtractBase.checkGroupIndex(prettyName, mr.groupCount, index)
val group = mr.group(index)
if (group == null) { // Pattern matched, but it's an optional group
matchResults += UTF8String.EMPTY_UTF8
} else {
matchResults += UTF8String.fromString(group)
}
}

new GenericArrayData(matchResults.toArray.asInstanceOf[Array[Any]])
RegExpExtractBase.extractAll(getLastMatcher(s, p), r.asInstanceOf[Int], prettyName)
}

override def dataType: DataType = ArrayType(subject.dataType)
override def prettyName: String = "regexp_extract_all"

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val classNameRegExpExtractBase = classOf[RegExpExtractBase].getCanonicalName
val arrayClass = classOf[GenericArrayData].getName
val matcher = ctx.freshName("matcher")
val matchResult = ctx.freshName("matchResult")
val matchResults = ctx.freshName("matchResults")
val setEvNotNull = if (nullable) {
s"${ev.isNull} = false;"
} else {
Expand All @@ -1057,21 +1046,8 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres
s"""
| ${RegExpUtils.initLastMatcherCode(ctx, subject, regexp, matcher, prettyName,
collationId)}
| java.util.ArrayList $matchResults = new java.util.ArrayList<UTF8String>();
| while ($matcher.find()) {
| java.util.regex.MatchResult $matchResult = $matcher.toMatchResult();
| $classNameRegExpExtractBase.checkGroupIndex(
| "$prettyName",
| $matchResult.groupCount(),
| $idx);
| if ($matchResult.group($idx) == null) {
| $matchResults.add(UTF8String.EMPTY_UTF8);
| } else {
| $matchResults.add(UTF8String.fromString($matchResult.group($idx)));
| }
| }
| ${ev.value} =
| new $arrayClass($matchResults.toArray(new UTF8String[$matchResults.size()]));
| $classNameRegExpExtractBase.extractAll($matcher, $idx, "$prettyName");
| $setEvNotNull
"""
})
Expand Down