diff --git a/be/src/exprs/function/function_ifnull.h b/be/src/exprs/function/function_ifnull.h deleted file mode 100644 index f6efb3451260e3..00000000000000 --- a/be/src/exprs/function/function_ifnull.h +++ /dev/null @@ -1,130 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. -// This file is copied from -// https://github.com/ClickHouse/ClickHouse/blob/master/src/Functions/Ifnull.h -// and modified by Doris - -#pragma once - -#include - -#include -#include -#include - -#include "common/status.h" -#include "core/block/block.h" -#include "core/block/column_numbers.h" -#include "core/block/column_with_type_and_name.h" -#include "core/block/columns_with_type_and_name.h" -#include "core/column/column.h" -#include "core/column/column_nullable.h" -#include "core/data_type/data_type.h" -#include "core/data_type/data_type_nullable.h" -#include "core/data_type/data_type_number.h" -#include "core/types.h" -#include "exprs/aggregate/aggregate_function.h" -#include "exprs/function/function.h" -#include "exprs/function/simple_function_factory.h" -#include "runtime/runtime_state.h" - -namespace doris { -class FunctionContext; -} // namespace doris - -namespace doris { -class FunctionIfNull : public IFunction { -public: - static constexpr auto name = "ifnull"; - - static FunctionPtr create() { return std::make_shared(); } - - String get_name() const override { return name; } - - size_t get_number_of_arguments() const override { return 2; } - - // be compatible with fe code - /* - if (fn.functionName().equalsIgnoreCase("ifnull") || fn.functionName().equalsIgnoreCase("nvl")) { - Preconditions.checkState(children.size() == 2); - if (children.get(0).isNullable()) { - return children.get(1).isNullable(); - } - return false; - } - */ - DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { - if (arguments[0]->is_nullable()) { - return arguments[1]; - } - return arguments[0]; - } - - bool use_default_implementation_for_nulls() const override { return false; } - - // ifnull(col_left, col_right) == if(isnull(col_left), col_right, col_left) - Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, - uint32_t result, size_t input_rows_count) const override { - ColumnWithTypeAndName& col_left = block.get_by_position(arguments[0]); - if (col_left.column->only_null()) { - // Here we need to use convert_to_full_column_if_const because only_null() is a runtime function. - // If the second parameter is constant, it will cause the execution to rely on runtime information to determine whether it is constant. - block.get_by_position(result).column = - block.get_by_position(arguments[1]).column->convert_to_full_column_if_const(); - return Status::OK(); - } - - ColumnWithTypeAndName null_column_arg0 {nullptr, std::make_shared(), ""}; - ColumnWithTypeAndName nested_column_arg0 {nullptr, col_left.type, ""}; - - col_left.column = col_left.column->convert_to_full_column_if_const(); - - /// implement isnull(col_left) logic - if (auto* nullable = check_and_get_column(*col_left.column)) { - null_column_arg0.column = nullable->get_null_map_column_ptr(); - nested_column_arg0.column = nullable->get_nested_column_ptr(); - nested_column_arg0.type = - reinterpret_cast(nested_column_arg0.type.get()) - ->get_nested_type(); - } else { - block.get_by_position(result).column = col_left.column; - return Status::OK(); - } - const ColumnsWithTypeAndName if_columns { - null_column_arg0, block.get_by_position(arguments[1]), nested_column_arg0}; - - // see get_return_type_impl - // if result is nullable, means both then and else column are nullable, we use original col_left to keep nullable info - // if result is not nullable, means both then and else column are not nullable, we use nested_column_arg0 to remove nullable info - bool result_nullable = block.get_by_position(result).type->is_nullable(); - Block temporary_block({ - null_column_arg0, - block.get_by_position(arguments[1]), - result_nullable - ? col_left - : nested_column_arg0, // if result is nullable, we need pass the original col_left else pass nested_column_arg0 - block.get_by_position(result), - }); - - auto func_if = SimpleFunctionFactory::instance().get_function( - "if", if_columns, block.get_by_position(result).type, {}); - RETURN_IF_ERROR(func_if->execute(context, temporary_block, {0, 1, 2}, 3, input_rows_count)); - block.get_by_position(result).column = temporary_block.get_by_position(3).column; - return Status::OK(); - } -}; -} // namespace doris diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java index 7cc9498160a349..dd37af56c0e9fd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java @@ -3493,6 +3493,15 @@ public Expression visitFunctionCallExpression(DorisParser.FunctionCallExpression private Expression processUnboundFunction(ParserRuleContext ctx, String dbName, String functionName, boolean isDistinct, List params, WindowSpecContext windowContext, IdentifierContext hintContext) { + if (dbName == null && "nullif".equalsIgnoreCase(functionName) && !isDistinct + && windowContext == null && hintContext == null && params.size() == 2) { + Expression first = params.get(0); + Expression second = params.get(1); + return new UnboundFunction("if", ImmutableList.of( + new EqualTo(first, second), + NullLiteral.INSTANCE, + first)); + } List unboundStars = ExpressionUtils.collectAll(params, UnboundStar.class::isInstance); if (!unboundStars.isEmpty()) { if (dbName != null diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunction.java index 52e75b65534c57..4eba83cfcfec79 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunction.java @@ -108,8 +108,9 @@ private static Expression rewriteNvl(ExpressionMatchingContext ctx) { /* * nullif(null, R) => Null - * nullif(L, null) => Null + * nullif(L, null) => nullable(L) * nullif(null, null) => Null + * nullif(L, R) => if(L = R, null, L) */ private static Expression rewriteNullIf(ExpressionMatchingContext ctx) { NullIf nullIf = ctx.expr; @@ -120,7 +121,7 @@ private static Expression rewriteNullIf(ExpressionMatchingContext ctx) { nullIf, new Nullable(nullIf.child(0)), ctx.rewriteContext ); } else { - return nullIf; + return NullIfToIf.rewrite(nullIf, ctx.rewriteContext); } } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullIfToIfTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullIfToIfTest.java new file mode 100644 index 00000000000000..f90ad288c26f72 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullIfToIfTest.java @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; +import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.functions.scalar.NullIf; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.types.StringType; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Test; + +class NullIfToIfTest extends ExpressionRewriteTestHelper { + + @Test + void testRewriteConstructedNullIfToIf() { + executor = new ExpressionRuleExecutor(ImmutableList.of(NullIfToIf.INSTANCE)); + + SlotReference first = new SlotReference("a", StringType.INSTANCE, true); + SlotReference second = new SlotReference("b", StringType.INSTANCE, true); + assertRewrite( + new NullIf(first, second), + new If(new EqualTo(first, second), new NullLiteral(StringType.INSTANCE), first)); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/PushIntoCaseWhenBranchTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/PushIntoCaseWhenBranchTest.java index 3cef0f3e02d461..a5d326b4c5d899 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/PushIntoCaseWhenBranchTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/PushIntoCaseWhenBranchTest.java @@ -84,9 +84,10 @@ void testPushIntoNullIf() { )); assertRewriteAfterTypeCoercion("cast(nullif(TA, TB) as bigint)", "if(TA = TB, NULL, cast(TA as bigint))"); assertRewriteAfterTypeCoercion("cast(nullif(TA, 1) as bigint)", "if(TA = 1, null, cast(TA as bigint))"); - assertRewriteAfterTypeCoercion("a > nullif(b, c)", "a > nullif(b, c)"); + assertRewriteAfterTypeCoercion("a > nullif(b, c)", "if(b = c, null, a > b)"); assertRewriteAfterTypeCoercion("2 > nullif(b, c)", "if(b = c, null, 2 > b)"); - assertRewriteAfterTypeCoercion("2 > nullif(b + random(1, 10), c)", "2 > nullif(b + random(1, 10), c)"); + assertRewriteAfterTypeCoercion("2 > nullif(b + random(1, 10), c)", + "if(b + random(1, 10) = c, null, 2 > b + random(1, 10))"); assertRewriteAfterTypeCoercion("2 > nullif(b, c + random(1, 10))", "if(b = c + random(1, 10), null, 2 > b)"); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunctionTest.java index c75f3febbd7011..474b6f32056234 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunctionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunctionTest.java @@ -178,6 +178,9 @@ public void testNullIf() { ), new NullLiteral(DateTimeV2Type.of(6)) ); + + // nullif(L, R) -> if(L = R, null, L) + assertRewriteAfterTypeCoercion("nullif(a, b)", "if(a = b, null, a)"); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java index 41017520046151..9681d35c9bad99 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java @@ -17,14 +17,17 @@ package org.apache.doris.nereids.trees.expressions; +import org.apache.doris.nereids.analyzer.UnboundFunction; import org.apache.doris.nereids.analyzer.UnboundSlot; import org.apache.doris.nereids.exceptions.SyntaxParseException; import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.parser.ParserTestBase; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; import org.apache.doris.qe.ConnectContext; import org.apache.doris.qe.SqlModeHelper; +import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Test; public class ExpressionParserTest extends ParserTestBase { @@ -92,6 +95,15 @@ public void testExprBetweenPredicate() { ); } + @Test + public void testNullIfRewriteInPlanBuilder() { + parseExpression("nullif(a, b)") + .assertEquals(new UnboundFunction("if", ImmutableList.of( + new EqualTo(new UnboundSlot("a"), new UnboundSlot("b")), + NullLiteral.INSTANCE, + new UnboundSlot("a")))); + } + @Test public void testInPredicate() { String in = "select * from test1 where d1 in (1, 2, 3)"; diff --git a/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out b/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out index 84590c64610e4d..1bfb0bdf6ba36b 100644 --- a/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out +++ b/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out @@ -275,7 +275,7 @@ PhysicalResultSink -- !nullif_one_side_1 -- PhysicalResultSink --PhysicalProject[t1.a, t2.x] -----NestedLoopJoin[INNER_JOIN](nullif(a, x) = (a + b)) +----NestedLoopJoin[INNER_JOIN](if((a = x), NULL, a) = (a + b)) ------PhysicalProject[(a + b) AS `(a + b)`, t1.a] --------filter((t1.a = (t1.a + t1.b))) ----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] @@ -285,7 +285,7 @@ PhysicalResultSink -- !nullif_two_side_1 -- PhysicalResultSink --PhysicalProject[t1.a, t2.x] -----hashJoin[INNER_JOIN] hashCondition=((t1.a = expr_(x + y))) otherCondition=((nullif(a, x) = (t2.x + t2.y))) +----hashJoin[INNER_JOIN] hashCondition=((t1.a = expr_(x + y))) otherCondition=((if((a = x), NULL, a) = (t2.x + t2.y))) ------PhysicalProject[t1.a] --------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1(t1)] ------PhysicalProject[(x + y) AS `expr_(x + y)`, tbl_join_extract_or_from_case_when_2.x, tbl_join_extract_or_from_case_when_2.y]