diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapper.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapper.java index b3795a8dc617..1bc602c7443a 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapper.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapper.java @@ -66,7 +66,8 @@ public ScalarTransformFunctionWrapper(FunctionInfo functionInfo) { ColumnDataType resultType = FunctionUtils.getColumnDataType(resultClass); if (resultType != null) { _resultType = resultType; - _resultMetadata = new TransformResultMetadata(resultType.toDataType(), !_resultType.isArray(), false); + DataType dataType = resultType == ColumnDataType.OBJECT ? DataType.STRING : resultType.toDataType(); + _resultMetadata = new TransformResultMetadata(dataType, !_resultType.isArray(), false); } else { // Handle unrecognized result class with STRING _resultType = ColumnDataType.STRING; @@ -264,7 +265,7 @@ public String[] transformToStringValuesSV(ValueBlock valueBlock) { } Object value = _functionInvoker.invoke(_scalarArguments); _stringValuesSV[i] = - value != null ? (String) _resultType.toInternal(value) : NullValuePlaceHolder.STRING; + value != null ? String.valueOf(_resultType.toInternal(value)) : NullValuePlaceHolder.STRING; } return _stringValuesSV; } @@ -424,6 +425,9 @@ private void getNonLiteralValues(ValueBlock valueBlock) { case BYTES: _nonLiteralValues[i] = transformFunction.transformToBytesValuesSV(valueBlock); break; + case OBJECT: + _nonLiteralValues[i] = transformToObjectValues(transformFunction, valueBlock); + break; case PRIMITIVE_INT_ARRAY: _nonLiteralValues[i] = transformFunction.transformToIntValuesMV(valueBlock); break; @@ -482,4 +486,65 @@ private void getNonLiteralValues(ValueBlock valueBlock) { } } } + + private static Object[] transformToObjectValues(TransformFunction transformFunction, ValueBlock valueBlock) { + TransformResultMetadata resultMetadata = transformFunction.getResultMetadata(); + DataType dataType = resultMetadata.getDataType(); + if (resultMetadata.isSingleValue()) { + switch (dataType) { + case BOOLEAN: { + int[] intValues = transformFunction.transformToIntValuesSV(valueBlock); + int numValues = intValues.length; + Boolean[] booleanValues = new Boolean[numValues]; + for (int i = 0; i < numValues; i++) { + booleanValues[i] = intValues[i] == 1; + } + return booleanValues; + } + case INT: + return ArrayUtils.toObject(transformFunction.transformToIntValuesSV(valueBlock)); + case LONG: + return ArrayUtils.toObject(transformFunction.transformToLongValuesSV(valueBlock)); + case FLOAT: + return ArrayUtils.toObject(transformFunction.transformToFloatValuesSV(valueBlock)); + case DOUBLE: + return ArrayUtils.toObject(transformFunction.transformToDoubleValuesSV(valueBlock)); + case BIG_DECIMAL: + return transformFunction.transformToBigDecimalValuesSV(valueBlock); + case TIMESTAMP: { + long[] longValues = transformFunction.transformToLongValuesSV(valueBlock); + int numValues = longValues.length; + Timestamp[] timestampValues = new Timestamp[numValues]; + for (int i = 0; i < numValues; i++) { + timestampValues[i] = new Timestamp(longValues[i]); + } + return timestampValues; + } + case STRING: + return transformFunction.transformToStringValuesSV(valueBlock); + case BYTES: + return transformFunction.transformToBytesValuesSV(valueBlock); + default: + throw new IllegalStateException("Unsupported data type: " + dataType); + } + } + switch (dataType) { + case INT: + return transformFunction.transformToIntValuesMV(valueBlock); + case LONG: + return transformFunction.transformToLongValuesMV(valueBlock); + case FLOAT: + return transformFunction.transformToFloatValuesMV(valueBlock); + case DOUBLE: + return transformFunction.transformToDoubleValuesMV(valueBlock); + case BIG_DECIMAL: + return transformFunction.transformToBigDecimalValuesMV(valueBlock); + case STRING: + return transformFunction.transformToStringValuesMV(valueBlock); + case BYTES: + return transformFunction.transformToBytesValuesMV(valueBlock); + default: + throw new IllegalStateException("Unsupported multi-value data type: " + dataType); + } + } } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java index e44a74f46478..5752481c2fc7 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java @@ -50,6 +50,30 @@ public class ScalarTransformFunctionWrapperTest extends BaseTransformFunctionTest { + @Test + public void testObjectResultScalarTransformFunction() { + ExpressionContext expression = RequestContextUtils.getExpression("nullIf(1, 2)"); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertTrue(transformFunction.getResultMetadata().isSingleValue()); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.STRING); + + String[] expectedValues = new String[NUM_ROWS]; + Arrays.fill(expectedValues, "1"); + testTransformFunction(transformFunction, expectedValues); + } + + @Test + public void testObjectParameterScalarTransformFunction() { + ExpressionContext expression = + RequestContextUtils.getExpression(String.format("nullIf(%s, '__pinot_nullif_miss__')", STRING_SV_COLUMN)); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertTrue(transformFunction.getResultMetadata().isSingleValue()); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.STRING); + testTransformFunction(transformFunction, _stringSVValues); + } + @Test public void testStringLowerTransformFunction() { ExpressionContext expression =