From 12ace9adc83af9319ba2d4333980e1f63a0a2d8d Mon Sep 17 00:00:00 2001 From: Mryange Date: Wed, 20 May 2026 14:51:53 +0800 Subject: [PATCH 01/10] [fix](uniform function) fix constant argument handling and use ColumnView (#63076) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit What problem does this PR solve? Issue Number: N/A Problem Summary: The uniform function takes three arguments: min, max, and seed. Only the first two (min, max) are truly "always constant" — the seed column should be treated as a regular column, not a constant. Without overriding get_arguments_that_are_always_constant(), when a user passes a constant value as the third argument (seed), the default framework logic treats it as a constant column, leading to incorrect results. Root cause: the base class default get_arguments_that_are_always_constant() does not distinguish between the seed argument and the min/max arguments, so a constant seed would be folded by the constant-handling path rather than being treated as a per-row value. Fix: - Override get_arguments_that_are_always_constant() to return {0, 1}, explicitly marking only min and max as always-constant arguments. - Refactor seed column access to use ColumnView for safer and more idiomatic typed column iteration. (cherry picked from commit a70c212956cad9a3614602e935dc1868bac10e92) --- be/src/exprs/function/uniform.cpp | 14 +++++++++----- .../nereids_function_p0/scalar_function/U.groovy | 2 ++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/be/src/exprs/function/uniform.cpp b/be/src/exprs/function/uniform.cpp index 3bd1e139e1528f..f8bbc8870c4292 100644 --- a/be/src/exprs/function/uniform.cpp +++ b/be/src/exprs/function/uniform.cpp @@ -30,6 +30,7 @@ #include "core/block/block.h" #include "core/block/column_numbers.h" #include "core/column/column.h" +#include "core/column/column_execute_util.h" #include "core/column/column_vector.h" #include "core/data_type/data_type_number.h" // IWYU pragma: keep #include "core/data_type/primitive_type.h" @@ -74,12 +75,12 @@ struct UniformIntImpl { "uniform's min should be less than max, but got [{}, {})", min, max); } - // Get gen column (seed values) - const auto& gen_column = block.get_by_position(arguments[2]).column; + auto gen_column = + ColumnView::create(block.get_by_position(arguments[2]).column); for (int i = 0; i < input_rows_count; i++) { // Use gen value as seed for each row - auto seed = (*gen_column)[i].get(); + auto seed = gen_column.value_at(i); std::mt19937_64 generator(seed); std::uniform_int_distribution distribution(min, max); res_data[i] = distribution(generator); @@ -123,11 +124,12 @@ struct UniformDoubleImpl { } // Get gen column (seed values) - const auto& gen_column = block.get_by_position(arguments[2]).column; + auto gen_column = + ColumnView::create(block.get_by_position(arguments[2]).column); for (int i = 0; i < input_rows_count; i++) { // Use gen value as seed for each row - auto seed = (*gen_column)[i].get(); + auto seed = gen_column.value_at(i); std::mt19937_64 generator(seed); std::uniform_real_distribution distribution(min, max); res_data[i] = distribution(generator); @@ -158,6 +160,8 @@ class FunctionUniform : public IFunction { return Impl::get_variadic_argument_types(); } + ColumnNumbers get_arguments_that_are_always_constant() const override { return {0, 1}; } + Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) override { // init_function_context do set_constant_cols for FRAGMENT_LOCAL scope if (scope == FunctionContext::FRAGMENT_LOCAL) { diff --git a/regression-test/suites/nereids_function_p0/scalar_function/U.groovy b/regression-test/suites/nereids_function_p0/scalar_function/U.groovy index 68642fa31ec91b..f43bc4cb6eea97 100644 --- a/regression-test/suites/nereids_function_p0/scalar_function/U.groovy +++ b/regression-test/suites/nereids_function_p0/scalar_function/U.groovy @@ -62,6 +62,8 @@ suite("nereids_scalar_fn_U") { def result = sql """select uniform(1, 100, random()*10000) from numbers("number" = "10");""" assertTrue(result.size() == 10) + def doubleResult = sql """select uniform(1.23, 100.100, random()*10000) from numbers("number" = "10");""" + assertTrue(doubleResult.size() == 10) test { sql """select uniform(100, 1, random()*10000) from numbers("number" = "10");""" exception "uniform's min should be less than max" From d2052bb6a111e1a5b376256312ebba0f47786e8e Mon Sep 17 00:00:00 2001 From: Mryange Date: Mon, 1 Jun 2026 12:18:26 +0800 Subject: [PATCH 02/10] [fix](expr) fix mixed const probe constant handling regressions (#63810) The mixed const execution probe exposed several constant-handling problems in BE vectorized functions. - ColumnConst::clone_resized reused the original nested column, so cloned const columns could still alias the source data. - quantile_percent requires its percentile argument to stay constant, but the all-const probe path unpacked it and triggered a false constant-check failure. - regexp_count accessed string columns directly and did not handle mixed const inputs correctly. - uniform still went through the default constant implementation even though its result depends on per-row seed values. This change fixes those behaviors and adds focused unit tests for the uncovered cases. (cherry picked from commit 905c80433b1714027bc853b870de77eb415732e7) --- be/src/core/column/column_const.h | 3 +- .../function/function_quantile_state.cpp | 2 + be/src/exprs/function/function_regexp.cpp | 20 ++++--- be/src/exprs/function/uniform.cpp | 2 + be/test/core/column/column_const_test.cpp | 13 +++++ be/test/exprs/function/function_math_test.cpp | 58 +++++++++++++++++++ .../function/function_quantile_state_test.cpp | 17 ++++++ .../exprs/function/function_string_test.cpp | 16 +++++ 8 files changed, 122 insertions(+), 9 deletions(-) diff --git a/be/src/core/column/column_const.h b/be/src/core/column/column_const.h index 1d0a0d7e596d59..cf26588a6a5d84 100644 --- a/be/src/core/column/column_const.h +++ b/be/src/core/column/column_const.h @@ -126,7 +126,8 @@ class ColumnConst final : public COWHelper { void resize(size_t new_size) override { s = new_size; } MutableColumnPtr clone_resized(size_t new_size) const override { - return ColumnConst::create(data, new_size, false, false); + auto cloned_data = data->clone_resized(data->size()); + return ColumnConst::create(std::move(cloned_data), new_size, false, false); } size_t size() const override { return s; } diff --git a/be/src/exprs/function/function_quantile_state.cpp b/be/src/exprs/function/function_quantile_state.cpp index a0edb82dbf9450..af1b80822f0007 100644 --- a/be/src/exprs/function/function_quantile_state.cpp +++ b/be/src/exprs/function/function_quantile_state.cpp @@ -169,6 +169,8 @@ class FunctionQuantileStatePercent : public IFunction { bool use_default_implementation_for_nulls() const override { return false; } + ColumnNumbers get_arguments_that_are_always_constant() const override { return {1}; } + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, uint32_t result, size_t input_rows_count) const override { auto res_data_column = ColumnFloat64::create(); diff --git a/be/src/exprs/function/function_regexp.cpp b/be/src/exprs/function/function_regexp.cpp index 8a9871c8eb3d6c..fd642e8abf4395 100644 --- a/be/src/exprs/function/function_regexp.cpp +++ b/be/src/exprs/function/function_regexp.cpp @@ -34,6 +34,7 @@ #include "core/block/column_with_type_and_name.h" #include "core/column/column.h" #include "core/column/column_const.h" +#include "core/column/column_execute_util.h" #include "core/column/column_nullable.h" #include "core/column/column_string.h" #include "core/column/column_vector.h" @@ -189,23 +190,26 @@ struct RegexpExtractEngine { }; struct RegexpCountImpl { + using StringColumnView = ColumnView; + static void execute_impl(FunctionContext* context, ColumnPtr argument_columns[], size_t input_rows_count, ColumnInt32::Container& result_data) { - const auto* str_col = check_and_get_column(argument_columns[0].get()); - const auto* pattern_col = check_and_get_column(argument_columns[1].get()); - for (int i = 0; i < input_rows_count; ++i) { + auto str_col = StringColumnView::create(argument_columns[0]); + auto pattern_col = StringColumnView::create(argument_columns[1]); + for (size_t i = 0; i < input_rows_count; ++i) { + DCHECK(!str_col.is_null_at(i)); + DCHECK(!pattern_col.is_null_at(i)); result_data[i] = _execute_inner_loop(context, str_col, pattern_col, i); } } - static int _execute_inner_loop(FunctionContext* context, const ColumnString* str_col, - const ColumnString* pattern_col, const size_t index_now) { + static int _execute_inner_loop(FunctionContext* context, const StringColumnView& str_col, + const StringColumnView& pattern_col, const size_t index_now) { re2::RE2* re = reinterpret_cast( context->get_function_state(FunctionContext::THREAD_LOCAL)); std::unique_ptr scoped_re; if (re == nullptr) { std::string error_str; - DCHECK(pattern_col); - const auto& pattern = pattern_col->get_data_at(index_check_const(index_now, false)); + const auto pattern = pattern_col.value_at(index_now); bool st = StringFunctions::compile_regex(pattern, &error_str, StringRef(), StringRef(), scoped_re); if (!st) { @@ -216,7 +220,7 @@ struct RegexpCountImpl { re = scoped_re.get(); } - const auto& str = str_col->get_data_at(index_now); + const auto str = str_col.value_at(index_now); int count = 0; size_t pos = 0; while (pos < str.size) { diff --git a/be/src/exprs/function/uniform.cpp b/be/src/exprs/function/uniform.cpp index f8bbc8870c4292..e639df7a2958bb 100644 --- a/be/src/exprs/function/uniform.cpp +++ b/be/src/exprs/function/uniform.cpp @@ -148,6 +148,8 @@ class FunctionUniform : public IFunction { static FunctionPtr create() { return std::make_shared>(); } String get_name() const override { return name; } + bool use_default_implementation_for_constants() const override { return false; } + size_t get_number_of_arguments() const override { return get_variadic_argument_types_impl().size(); } diff --git a/be/test/core/column/column_const_test.cpp b/be/test/core/column/column_const_test.cpp index f6f81ec3aaba4f..e9f57df213bce3 100644 --- a/be/test/core/column/column_const_test.cpp +++ b/be/test/core/column/column_const_test.cpp @@ -41,6 +41,19 @@ TEST(ColumnConstTest, TestCreate) { EXPECT_TRUE(!is_column_const(column_const2->get_data_column())); } +TEST(ColumnConstTest, clone_resized_clones_nested_data) { + auto column_data = ColumnHelper::create_column({7}); + auto column_const = ColumnConst::create(column_data, 3); + + auto cloned = column_const->clone_resized(5); + const auto& cloned_const = assert_cast(*cloned); + + EXPECT_EQ(cloned_const.size(), 5); + EXPECT_EQ(cloned_const.get_data_column_ptr()->size(), 1); + EXPECT_EQ(cloned_const.get_data_column().get_int(0), 7); + EXPECT_NE(column_const->get_data_column_ptr().get(), cloned_const.get_data_column_ptr().get()); +} + TEST(ColumnConstTest, TestFilter) { { auto column_data = ColumnHelper::create_column({7}); diff --git a/be/test/exprs/function/function_math_test.cpp b/be/test/exprs/function/function_math_test.cpp index 4e51a5dc3e700b..cf1b3a442ea686 100644 --- a/be/test/exprs/function/function_math_test.cpp +++ b/be/test/exprs/function/function_math_test.cpp @@ -18,14 +18,17 @@ #include #include #include +#include #include +#include "core/column/column_const.h" #include "core/data_type/data_type_decimal.h" #include "core/data_type/data_type_number.h" #include "core/data_type/data_type_string.h" #include "core/types.h" #include "exprs/function/function_test_util.h" #include "testutil/any_type.h" +#include "testutil/column_helper.h" namespace doris { @@ -532,6 +535,11 @@ TEST(MathFunctionTest, hex_test) { } TEST(MathFunctionTest, random_test) { +#ifndef NDEBUG + GTEST_SKIP() << "random(seed) exact-value assertions are release-only; debug builds run " + "mock_const_execute before the real call."; +#endif + std::string func_name = "random"; // random(x) InputTypeSet input_types = {Consted {PrimitiveType::TYPE_BIGINT}}; DataSet data_set = {{{Null()}, Null()}, @@ -547,6 +555,56 @@ TEST(MathFunctionTest, random_test) { } } +TEST(MathFunctionTest, uniform_mixed_const_probe_test) { + auto input_type = std::make_shared(); + auto return_type = std::make_shared(); + + Block block; + auto min_data = ColumnHelper::create_column({1}); + auto max_data = ColumnHelper::create_column({10}); + auto seed_column = ColumnHelper::create_column({101, 202, 303}); + + block.insert({ColumnConst::create(min_data, 3), input_type, "min"}); + block.insert({ColumnConst::create(max_data, 3), input_type, "max"}); + block.insert({seed_column, input_type, "seed"}); + + FunctionBasePtr function = SimpleFunctionFactory::instance().get_function( + "uniform", block.get_columns_with_type_and_name(), return_type); + ASSERT_TRUE(function != nullptr); + + block.insert({nullptr, return_type, "result"}); + + FunctionUtils fn_utils(return_type, {input_type, input_type, input_type}, false); + auto* fn_ctx = fn_utils.get_fn_ctx(); + std::vector> constant_cols { + std::make_shared(block.get_by_position(0).column), + std::make_shared(block.get_by_position(1).column), + nullptr, + }; + fn_ctx->set_constant_cols(constant_cols); + + ASSERT_TRUE(function->open(fn_ctx, FunctionContext::FRAGMENT_LOCAL).ok()); + ASSERT_TRUE(function->open(fn_ctx, FunctionContext::THREAD_LOCAL).ok()); + + auto exec_status = function->execute(fn_ctx, block, {0, 1, 2}, 3, 3); + + static_cast(function->close(fn_ctx, FunctionContext::THREAD_LOCAL)); + static_cast(function->close(fn_ctx, FunctionContext::FRAGMENT_LOCAL)); + + ASSERT_TRUE(exec_status.ok()) << exec_status.to_string(); + + const auto& result_column = assert_cast(*block.get_by_position(3).column); + auto expected_uniform = [](int64_t seed) { + std::mt19937_64 generator(seed); + std::uniform_int_distribution distribution(1, 10); + return distribution(generator); + }; + + EXPECT_EQ(result_column.get_element(0), expected_uniform(101)); + EXPECT_EQ(result_column.get_element(1), expected_uniform(202)); + EXPECT_EQ(result_column.get_element(2), expected_uniform(303)); +} + TEST(MathFunctionTest, conv_test) { std::string func_name = "conv"; diff --git a/be/test/exprs/function/function_quantile_state_test.cpp b/be/test/exprs/function/function_quantile_state_test.cpp index 1cb1ced1dae561..e8f2fca702895f 100644 --- a/be/test/exprs/function/function_quantile_state_test.cpp +++ b/be/test/exprs/function/function_quantile_state_test.cpp @@ -213,4 +213,21 @@ TEST(function_quantile_state_test, function_quantile_state_roundtrip) { 0.01); } +TEST(function_quantile_state_test, function_quantile_percent_mixed_const_test) { + std::string func_name = "quantile_percent"; + InputTypeSet input_types = {PrimitiveType::TYPE_QUANTILE_STATE, + ConstedNotnull {PrimitiveType::TYPE_FLOAT}}; + + QuantileState quantile_state; + quantile_state.add_value(1.0); + quantile_state.add_value(2.0); + quantile_state.add_value(3.0); + quantile_state.add_value(4.0); + quantile_state.add_value(5.0); + + DataSet data_set = {{{&quantile_state, 0.5F}, 3.0}}; + + static_cast(check_function(func_name, input_types, data_set)); +} + } // namespace doris diff --git a/be/test/exprs/function/function_string_test.cpp b/be/test/exprs/function/function_string_test.cpp index edf888f2c8f1b3..90456da258a960 100644 --- a/be/test/exprs/function/function_string_test.cpp +++ b/be/test/exprs/function/function_string_test.cpp @@ -3854,4 +3854,20 @@ TEST(function_string_test, function_unicode_normalize_invalid_mode) { EXPECT_NE(Status::OK(), st); } +TEST(function_string_test, function_regexp_count_mixed_const_test) { + std::string func_name = "regexp_count"; + + InputTypeSet input_types = {PrimitiveType::TYPE_VARCHAR, PrimitiveType::TYPE_VARCHAR}; + DataSet data_set = { + {{std::string("a.b:c;d"), std::string("[.:;]")}, std::int32_t(3)}, + {{std::string("a1b2346c3d"), std::string("\\d+")}, std::int32_t(3)}, + {{std::string("abcd"), std::string("")}, std::int32_t(0)}, + {{std::string("book keeper"), std::string("oo|ee")}, std::int32_t(2)}, + {{Null(), std::string("\\d+")}, Null()}, + {{std::string("abcd"), Null()}, Null()}, + }; + + check_function_all_arg_comb(func_name, input_types, data_set); +} + } // namespace doris From 635bd8f57b3a7d8f2c1bd1d3aa367e92665aeb8e Mon Sep 17 00:00:00 2001 From: Mryange Date: Tue, 19 May 2026 14:38:08 +0800 Subject: [PATCH 03/10] [fix](be) Clean up aggregate states and use Doris hash containers (#63174) ### What problem does this PR solve? Issue Number: N/A Problem Summary: Aggregate batch deserialization creates aggregate states with placement new before deserializing or merging serialized input. If deserialization or merge throws after `create()` succeeds, the previous cleanup only destroyed states from earlier rows and skipped the current already-created state. This can leak resources owned by aggregate state objects, such as hash sets or bitmap internals. Root cause: the exception cleanup destroyed only states from previous rows. If the current row's state was created successfully and deserialization failed afterward, that current state was excluded from cleanup. This PR tracks the number of successfully created aggregate states and destroys exactly that range on exception. It preserves the successful-path ownership model: `deserialize_vec()` leaves created states to its caller, while merge helpers still release temporary rhs states with `destroy_vec()` after successful merge. This PR also switches aggregate-local `phmap::flat_hash_map` and `phmap::flat_hash_set` usages to Doris wrapper aliases so they use Doris' default equality and allocator definitions consistently. (cherry picked from commit 4483daf9f03b93b9ef4fbb168d86219749ca0181) --- be/src/exprs/aggregate/aggregate_function.h | 54 +++--- .../aggregate/aggregate_function_collect.h | 4 +- .../aggregate/aggregate_function_distinct.h | 8 +- .../exprs/aggregate/aggregate_function_map.h | 2 +- .../aggregate/aggregate_function_map_v2.h | 2 +- .../aggregate_function_exception_test.cpp | 162 ++++++++++++++++++ 6 files changed, 200 insertions(+), 32 deletions(-) create mode 100644 be/test/exprs/aggregate/aggregate_function_exception_test.cpp diff --git a/be/src/exprs/aggregate/aggregate_function.h b/be/src/exprs/aggregate/aggregate_function.h index 0e07f74c1aeab1..94eaaa9ad72403 100644 --- a/be/src/exprs/aggregate/aggregate_function.h +++ b/be/src/exprs/aggregate/aggregate_function.h @@ -479,19 +479,21 @@ class IAggregateFunctionHelper : public IAggregateFunction { size_t num_rows) const override { const Derived* derived = assert_cast(this); const auto size_of_data = derived->size_of_data(); - for (size_t i = 0; i != num_rows; ++i) { - try { + size_t created_count = 0; + try { + for (size_t i = 0; i != num_rows; ++i) { auto place = places + size_of_data * i; VectorBufferReader buffer_reader(column->get_data_at(i)); derived->create(place); + ++created_count; derived->deserialize(place, buffer_reader, arena); - } catch (...) { - for (int j = 0; j < i; ++j) { - auto place = places + size_of_data * j; - derived->destroy(place); - } - throw; } + } catch (...) { + for (size_t j = 0; j < created_count; ++j) { + auto place = places + size_of_data * j; + derived->destroy(place); + } + throw; } } @@ -502,19 +504,21 @@ class IAggregateFunctionHelper : public IAggregateFunction { const auto size_of_data = derived->size_of_data(); const auto* column_string = assert_cast(column); - for (size_t i = 0; i != num_rows; ++i) { - try { + size_t created_count = 0; + try { + for (size_t i = 0; i != num_rows; ++i) { auto rhs_place = rhs + size_of_data * i; VectorBufferReader buffer_reader(column_string->get_data_at(i)); derived->create(rhs_place); + ++created_count; derived->deserialize_and_merge(places[i] + offset, rhs_place, buffer_reader, arena); - } catch (...) { - for (int j = 0; j < i; ++j) { - auto place = rhs + size_of_data * j; - derived->destroy(place); - } - throw; } + } catch (...) { + for (size_t j = 0; j < created_count; ++j) { + auto place = rhs + size_of_data * j; + derived->destroy(place); + } + throw; } derived->destroy_vec(rhs, num_rows); @@ -526,22 +530,24 @@ class IAggregateFunctionHelper : public IAggregateFunction { const auto* derived = assert_cast(this); const auto size_of_data = derived->size_of_data(); const auto* column_string = assert_cast(column); - for (size_t i = 0; i != num_rows; ++i) { - try { + size_t created_count = 0; + try { + for (size_t i = 0; i != num_rows; ++i) { auto rhs_place = rhs + size_of_data * i; VectorBufferReader buffer_reader(column_string->get_data_at(i)); derived->create(rhs_place); + ++created_count; if (places[i]) { derived->deserialize_and_merge(places[i] + offset, rhs_place, buffer_reader, arena); } - } catch (...) { - for (int j = 0; j < i; ++j) { - auto place = rhs + size_of_data * j; - derived->destroy(place); - } - throw; } + } catch (...) { + for (size_t j = 0; j < created_count; ++j) { + auto place = rhs + size_of_data * j; + derived->destroy(place); + } + throw; } derived->destroy_vec(rhs, num_rows); } diff --git a/be/src/exprs/aggregate/aggregate_function_collect.h b/be/src/exprs/aggregate/aggregate_function_collect.h index 3f9c84f7dea373..70b6e5d8d61bc4 100644 --- a/be/src/exprs/aggregate/aggregate_function_collect.h +++ b/be/src/exprs/aggregate/aggregate_function_collect.h @@ -50,7 +50,7 @@ struct AggregateFunctionCollectSetData { using ElementType = typename PrimitiveTypeTraits::CppType; using ColVecType = typename PrimitiveTypeTraits::ColumnType; using SelfType = AggregateFunctionCollectSetData; - using Set = phmap::flat_hash_set; + using Set = doris::flat_hash_set; Set data_set; Int64 max_size = -1; @@ -119,7 +119,7 @@ struct AggregateFunctionCollectSetData { using ElementType = StringRef; using ColVecType = ColumnString; using SelfType = AggregateFunctionCollectSetData; - using Set = phmap::flat_hash_set; + using Set = doris::flat_hash_set; Set data_set; Int64 max_size = -1; diff --git a/be/src/exprs/aggregate/aggregate_function_distinct.h b/be/src/exprs/aggregate/aggregate_function_distinct.h index 618d9b46f41996..825e782f0cba53 100644 --- a/be/src/exprs/aggregate/aggregate_function_distinct.h +++ b/be/src/exprs/aggregate/aggregate_function_distinct.h @@ -52,8 +52,8 @@ template struct AggregateFunctionDistinctSingleNumericData { /// When creating, the hash table must be small. using Container = std::conditional_t< - stable, phmap::flat_hash_map::CppType, uint32_t>, - phmap::flat_hash_set::CppType>>; + stable, doris::flat_hash_map::CppType, uint32_t>, + doris::flat_hash_set::CppType>>; using Self = AggregateFunctionDistinctSingleNumericData; Container data; @@ -126,8 +126,8 @@ struct AggregateFunctionDistinctSingleNumericData { template struct AggregateFunctionDistinctGenericData { /// When creating, the hash table must be small. - using Container = std::conditional_t, - phmap::flat_hash_set>; + using Container = std::conditional_t, + doris::flat_hash_set>; using Self = AggregateFunctionDistinctGenericData; Container data; diff --git a/be/src/exprs/aggregate/aggregate_function_map.h b/be/src/exprs/aggregate/aggregate_function_map.h index f9aff592503cc0..a16bea7867c91a 100644 --- a/be/src/exprs/aggregate/aggregate_function_map.h +++ b/be/src/exprs/aggregate/aggregate_function_map.h @@ -35,7 +35,7 @@ namespace doris { template struct AggregateFunctionMapAggData { using KeyType = typename PrimitiveTypeTraits::CppType; - using Map = phmap::flat_hash_map; + using Map = doris::flat_hash_map; AggregateFunctionMapAggData() { throw Exception(Status::FatalError("__builtin_unreachable")); } diff --git a/be/src/exprs/aggregate/aggregate_function_map_v2.h b/be/src/exprs/aggregate/aggregate_function_map_v2.h index 3181b1ad4261d0..1d821c486c6b0a 100644 --- a/be/src/exprs/aggregate/aggregate_function_map_v2.h +++ b/be/src/exprs/aggregate/aggregate_function_map_v2.h @@ -33,7 +33,7 @@ namespace doris { #include "common/compile_check_begin.h" struct AggregateFunctionMapAggDataV2 { - using Map = phmap::flat_hash_map; + using Map = doris::flat_hash_map; AggregateFunctionMapAggDataV2() { throw Exception(Status::FatalError("__builtin_unreachable")); diff --git a/be/test/exprs/aggregate/aggregate_function_exception_test.cpp b/be/test/exprs/aggregate/aggregate_function_exception_test.cpp new file mode 100644 index 00000000000000..21ee64dba4aef1 --- /dev/null +++ b/be/test/exprs/aggregate/aggregate_function_exception_test.cpp @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include + +#include "core/arena.h" +#include "exprs/aggregate/aggregate_function.h" + +namespace doris { + +struct TrackingAggregateState { + TrackingAggregateState() { ++construct_count; } + ~TrackingAggregateState() { ++destroy_count; } + + static void reset_counters() { + construct_count = 0; + destroy_count = 0; + } + + static int construct_count; + static int destroy_count; +}; + +int TrackingAggregateState::construct_count = 0; +int TrackingAggregateState::destroy_count = 0; + +class ThrowOnDeserializeAggregateFunction final + : public IAggregateFunctionDataHelper { +public: + ThrowOnDeserializeAggregateFunction() + : IAggregateFunctionDataHelper( + DataTypes {std::make_shared()}) {} + + String get_name() const override { return "throw_on_deserialize"; } + + DataTypePtr get_return_type() const override { return std::make_shared(); } + + void add(AggregateDataPtr, const IColumn**, ssize_t, Arena&) const override {} + + void merge(AggregateDataPtr, ConstAggregateDataPtr, Arena&) const override {} + + void serialize(ConstAggregateDataPtr, BufferWritable& buf) const override { + String payload; + buf.write_binary(payload); + } + + void deserialize(AggregateDataPtr, BufferReadable& buf, Arena&) const override { + String payload; + buf.read_binary(payload); + if (payload == "throw") { + throw Exception(ErrorCode::INTERNAL_ERROR, "mock deserialize failure"); + } + } + + void insert_result_into(ConstAggregateDataPtr, IColumn&) const override {} +}; + +class AggregateFunctionExceptionTest : public testing::Test { +protected: + void SetUp() override { TrackingAggregateState::reset_counters(); } + + MutableColumnPtr make_column(std::initializer_list payloads) { + auto column = ColumnString::create(); + VectorBufferWriter writer(*column); + for (const auto& payload : payloads) { + writer.write_binary(payload); + writer.commit(); + } + return column; + } + + ThrowOnDeserializeAggregateFunction function; + Arena arena; +}; + +TEST_F(AggregateFunctionExceptionTest, DeserializeVecDestroysCurrentStateOnFailure) { + auto column = make_column({"ok", "throw"}); + std::vector states(function.size_of_data() * 2); + + bool thrown = false; + try { + function.deserialize_vec(states.data(), static_cast(column.get()), arena, 2); + } catch (const Exception&) { + thrown = true; + } + + EXPECT_TRUE(thrown); + if (!thrown) { + function.destroy_vec(states.data(), 2); + } + EXPECT_EQ(TrackingAggregateState::construct_count, 2); + EXPECT_EQ(TrackingAggregateState::destroy_count, 2); +} + +TEST_F(AggregateFunctionExceptionTest, DeserializeAndMergeVecDestroysRhsStateOnFailure) { + auto column = make_column({"throw"}); + std::vector place_storage(function.size_of_data()); + std::vector rhs_storage(function.size_of_data()); + auto* place = place_storage.data(); + function.create(place); + + std::array places {place}; + const auto destroy_count_before_call = TrackingAggregateState::destroy_count; + bool thrown = false; + try { + function.deserialize_and_merge_vec(places.data(), 0, rhs_storage.data(), column.get(), + arena, 1); + } catch (const Exception&) { + thrown = true; + } + + EXPECT_TRUE(thrown); + EXPECT_EQ(TrackingAggregateState::destroy_count - destroy_count_before_call, 1); + + function.destroy(place); + EXPECT_EQ(TrackingAggregateState::construct_count, TrackingAggregateState::destroy_count); +} + +TEST_F(AggregateFunctionExceptionTest, + DeserializeAndMergeVecSelectedDestroysAllCreatedRhsStatesOnFailure) { + auto column = make_column({"skip", "throw"}); + std::vector place_storage(function.size_of_data()); + std::vector rhs_storage(function.size_of_data() * 2); + auto* place = place_storage.data(); + function.create(place); + + std::array places {nullptr, place}; + const auto destroy_count_before_call = TrackingAggregateState::destroy_count; + bool thrown = false; + try { + function.deserialize_and_merge_vec_selected(places.data(), 0, rhs_storage.data(), + column.get(), arena, 2); + } catch (const Exception&) { + thrown = true; + } + + EXPECT_TRUE(thrown); + EXPECT_EQ(TrackingAggregateState::destroy_count - destroy_count_before_call, 2); + + function.destroy(place); + EXPECT_EQ(TrackingAggregateState::construct_count, TrackingAggregateState::destroy_count); +} + +} // namespace doris \ No newline at end of file From 507754e999251c89823c95acb90a8b57f006a99e Mon Sep 17 00:00:00 2001 From: Mryange Date: Fri, 29 May 2026 10:08:17 +0800 Subject: [PATCH 04/10] [fix](be) Fix timestamptz group_array state serde (#63827) Fix collect_list/group_array on nested TIMESTAMPTZ values when complex aggregate state is serialized through JSON. This keeps the existing state format for compatibility, provides a UTC timezone during serde, and adds regression coverage for the nested group_array case. (cherry picked from commit 87316004891ab8d32f107b353c48bc2b65625425) --- .../aggregate/aggregate_function_collect.h | 7 ++++ .../test_timestamptz_agg_functions.out | 3 ++ .../test_timestamptz_agg_functions.groovy | 37 +++++++++++++++++++ 3 files changed, 47 insertions(+) diff --git a/be/src/exprs/aggregate/aggregate_function_collect.h b/be/src/exprs/aggregate/aggregate_function_collect.h index 70b6e5d8d61bc4..63a3c6348225ad 100644 --- a/be/src/exprs/aggregate/aggregate_function_collect.h +++ b/be/src/exprs/aggregate/aggregate_function_collect.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include @@ -343,6 +344,10 @@ struct AggregateFunctionCollectListData { buf.write_binary(size); DataTypeSerDe::FormatOptions opt; + auto timezone = cctz::utc_time_zone(); + opt.timezone = &timezone; + // TODO: Refactor this aggregate state serialization to avoid + // round-tripping through a human-readable string format. auto tmp_str = ColumnString::create(); VectorBufferWriter tmp_buf(*tmp_str.get()); @@ -368,6 +373,8 @@ struct AggregateFunctionCollectListData { StringRef s; DataTypeSerDe::FormatOptions opt; + auto timezone = cctz::utc_time_zone(); + opt.timezone = &timezone; for (size_t i = 0; i < size; i++) { buf.read_binary(s); Slice slice(s.data, s.size); diff --git a/regression-test/data/datatype_p0/timestamptz/test_timestamptz_agg_functions.out b/regression-test/data/datatype_p0/timestamptz/test_timestamptz_agg_functions.out index 850cbe14a980d5..f7ff2eb36d0dbe 100644 --- a/regression-test/data/datatype_p0/timestamptz/test_timestamptz_agg_functions.out +++ b/regression-test/data/datatype_p0/timestamptz/test_timestamptz_agg_functions.out @@ -11,3 +11,6 @@ true -- !group_array_union -- 3 +-- !group_array_nested_timestamptz -- +[["2024-01-01 00:00:00.000000+00:00", "2024-01-01 00:00:00.000000+00:00", "2024-01-02 00:00:00.000000+00:00"], ["2024-01-01 00:00:00.000000+00:00", "2024-01-02 00:00:00.000000+00:00", "2024-01-03 00:00:00.000000+00:00"]] + diff --git a/regression-test/suites/datatype_p0/timestamptz/test_timestamptz_agg_functions.groovy b/regression-test/suites/datatype_p0/timestamptz/test_timestamptz_agg_functions.groovy index 89126b5a284772..e5bf945225ef45 100644 --- a/regression-test/suites/datatype_p0/timestamptz/test_timestamptz_agg_functions.groovy +++ b/regression-test/suites/datatype_p0/timestamptz/test_timestamptz_agg_functions.groovy @@ -56,4 +56,41 @@ suite("test_timestamptz_agg_functions", "datatype_p0") { qt_group_array_union "SELECT size(group_array_union(arr)) FROM test_tz_agg" sql "DROP TABLE IF EXISTS test_tz_agg" + + sql "DROP TABLE IF EXISTS tz_group_array_crash" + sql """ + CREATE TABLE tz_group_array_crash ( + grp INT, + arr ARRAY + ) + DUPLICATE KEY(grp) + DISTRIBUTED BY HASH(grp) BUCKETS 1 + PROPERTIES('replication_num' = '1') + """ + + sql """ + INSERT INTO tz_group_array_crash VALUES + ( + 1, + ARRAY( + CAST('2024-01-01 00:00:00 +00:00' AS TIMESTAMPTZ(6)), + CAST('2024-01-01 08:00:00 +08:00' AS TIMESTAMPTZ(6)), + CAST('2024-01-02 00:00:00 +00:00' AS TIMESTAMPTZ(6)) + ) + ), + ( + 1, + ARRAY( + CAST('2024-01-01 00:00:00 +00:00' AS TIMESTAMPTZ(6)), + CAST('2024-01-02 08:00:00 +08:00' AS TIMESTAMPTZ(6)), + CAST('2024-01-03 00:00:00 +00:00' AS TIMESTAMPTZ(6)) + ) + ) + """ + + qt_group_array_nested_timestamptz """ + SELECT CAST(array_sort(group_array(arr)) AS STRING) + FROM tz_group_array_crash + GROUP BY grp + """ } From 3ce541d1c932667ac02986488f69ac61f4f55d1a Mon Sep 17 00:00:00 2001 From: Mryange Date: Mon, 18 May 2026 16:20:49 +0800 Subject: [PATCH 05/10] [refine](exec) replace std::shared_mutex/std::shared_lock with annotated wrappers for thread safety analysis (#63109) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After #63070 replaced `std::mutex`/`std::lock_guard` with annotated wrappers (`AnnotatedMutex`/`LockGuard`), shared mutex usage still relied on raw `std::shared_mutex` and `std::shared_lock`/`std::unique_lock`, which are invisible to Clang's thread safety analysis. This leaves shared-lock sites unverified — `GUARDED_BY`, `REQUIRES_SHARED`, and other annotations cannot be enforced. - Add `AnnotatedSharedMutex` (wrapping `std::shared_mutex` with `CAPABILITY`/`ACQUIRE`/`RELEASE`/`ACQUIRE_SHARED`/`RELEASE_SHARED` annotations) and `SharedLockGuard` (RAII `SCOPED_CAPABILITY` with `ACQUIRE_SHARED`/`RELEASE`) in `thread_safety_annotations.h`. - Migrate `VDataStreamMgr` and `RuntimeFilterMergeControllerEntity` from `std::shared_mutex` to `AnnotatedSharedMutex`, and from `std::unique_lock`/`std::shared_lock` to `LockGuard`/`SharedLockGuard`. - Add `GUARDED_BY` annotations to the protected maps. - Extract `_find_recvr` as a private helper annotated with `REQUIRES_SHARED(_lock)`, eliminating the `bool acquire_lock` parameter that previously bypassed lock tracking. (cherry picked from commit ab1a4ddb59263c9f48d60c35e016212858a9547a) --- be/src/common/thread_safety_annotations.h | 48 +++++++++++++++++++ be/src/exec/exchange/vdata_stream_mgr.cpp | 28 ++++++----- be/src/exec/exchange/vdata_stream_mgr.h | 14 ++++-- .../runtime_filter/runtime_filter_mgr.cpp | 20 ++++---- .../exec/runtime_filter/runtime_filter_mgr.h | 9 ++-- .../exec/pipeline/vdata_stream_recvr_test.cpp | 2 +- 6 files changed, 87 insertions(+), 34 deletions(-) diff --git a/be/src/common/thread_safety_annotations.h b/be/src/common/thread_safety_annotations.h index 6cd8d4b0cae45c..6bbdb8ce6546ad 100644 --- a/be/src/common/thread_safety_annotations.h +++ b/be/src/common/thread_safety_annotations.h @@ -22,6 +22,7 @@ #pragma once #include +#include #ifdef BE_TEST namespace doris { @@ -93,6 +94,27 @@ class CAPABILITY("mutex") AnnotatedMutex { std::mutex _mutex; }; +// Annotated shared mutex wrapper for use with Clang thread safety analysis. +// Wraps std::shared_mutex and provides both exclusive and shared capability +// operations so GUARDED_BY / REQUIRES_SHARED / etc. can reference it. +class CAPABILITY("mutex") AnnotatedSharedMutex { +public: + void lock() ACQUIRE() { _mutex.lock(); } + void unlock() RELEASE() { _mutex.unlock(); } + bool try_lock() TRY_ACQUIRE(true) { return _mutex.try_lock(); } + + void lock_shared() ACQUIRE_SHARED() { _mutex.lock_shared(); } + void unlock_shared() RELEASE_SHARED() { _mutex.unlock_shared(); } + bool try_lock_shared() TRY_ACQUIRE_SHARED(true) { return _mutex.try_lock_shared(); } + + // Access the underlying std::shared_mutex (e.g., for std::condition_variable_any). + // Use with care — this bypasses thread safety annotations. + std::shared_mutex& native_handle() { return _mutex; } + +private: + std::shared_mutex _mutex; +}; + // RAII scoped lock guard annotated for thread safety analysis. // In BE_TEST builds, injects a random sleep before acquiring and after // releasing the lock to exercise concurrent code paths. @@ -119,6 +141,32 @@ class SCOPED_CAPABILITY LockGuard { MutexType& _mu; }; +// RAII scoped shared lock guard annotated for thread safety analysis. +// In BE_TEST builds, injects a random sleep before acquiring and after +// releasing the lock to exercise concurrent code paths. +template +class SCOPED_CAPABILITY SharedLockGuard { +public: + explicit SharedLockGuard(MutexType& mu) ACQUIRE_SHARED(mu) : _mu(mu) { +#ifdef BE_TEST + doris::mock_random_sleep(); +#endif + _mu.lock_shared(); + } + ~SharedLockGuard() RELEASE() { + _mu.unlock_shared(); +#ifdef BE_TEST + doris::mock_random_sleep(); +#endif + } + + SharedLockGuard(const SharedLockGuard&) = delete; + SharedLockGuard& operator=(const SharedLockGuard&) = delete; + +private: + MutexType& _mu; +}; + // RAII unique lock annotated for thread safety analysis. // Supports manual lock/unlock while preserving capability tracking. template diff --git a/be/src/exec/exchange/vdata_stream_mgr.cpp b/be/src/exec/exchange/vdata_stream_mgr.cpp index 17bab298c432c8..b3357f8d0b6006 100644 --- a/be/src/exec/exchange/vdata_stream_mgr.cpp +++ b/be/src/exec/exchange/vdata_stream_mgr.cpp @@ -44,7 +44,7 @@ VDataStreamMgr::~VDataStreamMgr() { // It will core during graceful stop. auto receivers = std::vector>(); { - std::shared_lock l(_lock); + SharedLockGuard l(_lock); auto receiver_iterator = _receiver_map.begin(); while (receiver_iterator != _receiver_map.end()) { // Could not call close directly, because during close method, it will remove itself @@ -77,22 +77,16 @@ std::shared_ptr VDataStreamMgr::create_recvr( this, memory_used_counter, state, fragment_instance_id, dest_node_id, num_senders, is_merging, profile, data_queue_capacity)); uint32_t hash_value = get_hash_value(fragment_instance_id, dest_node_id); - std::unique_lock l(_lock); + LockGuard l(_lock); _fragment_stream_set.insert(std::make_pair(fragment_instance_id, dest_node_id)); _receiver_map.insert(std::make_pair(hash_value, recvr)); return recvr; } -Status VDataStreamMgr::find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id, - std::shared_ptr* res, bool acquire_lock) { +Status VDataStreamMgr::_find_recvr(uint32_t hash_value, const TUniqueId& fragment_instance_id, + PlanNodeId node_id, std::shared_ptr* res) { VLOG_ROW << "looking up fragment_instance_id=" << print_id(fragment_instance_id) << ", node=" << node_id; - uint32_t hash_value = get_hash_value(fragment_instance_id, node_id); - // Create lock guard and not own lock currently and will lock conditionally - std::shared_lock recvr_lock(_lock, std::defer_lock); - if (acquire_lock) { - recvr_lock.lock(); - } std::pair range = _receiver_map.equal_range(hash_value); while (range.first != range.second) { @@ -108,6 +102,13 @@ Status VDataStreamMgr::find_recvr(const TUniqueId& fragment_instance_id, PlanNod node_id, print_id(fragment_instance_id)); } +Status VDataStreamMgr::find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id, + std::shared_ptr* res) { + SharedLockGuard recvr_lock(_lock); + uint32_t hash_value = get_hash_value(fragment_instance_id, node_id); + return _find_recvr(hash_value, fragment_instance_id, node_id, res); +} + Status VDataStreamMgr::transmit_block(const PTransmitDataParams* request, ::google::protobuf::Closure** done, const int64_t wait_for_worker) { @@ -199,7 +200,7 @@ Status VDataStreamMgr::deregister_recvr(const TUniqueId& fragment_instance_id, P << ", node=" << node_id; uint32_t hash_value = get_hash_value(fragment_instance_id, node_id); { - std::unique_lock l(_lock); + LockGuard l(_lock); auto range = _receiver_map.equal_range(hash_value); while (range.first != range.second) { const std::shared_ptr& recvr = range.first->second; @@ -230,12 +231,13 @@ void VDataStreamMgr::cancel(const TUniqueId& fragment_instance_id, Status exec_s VLOG_QUERY << "cancelling all streams for fragment=" << print_id(fragment_instance_id); std::vector> recvrs; { - std::shared_lock l(_lock); + SharedLockGuard l(_lock); FragmentStreamSet::iterator i = _fragment_stream_set.lower_bound(std::make_pair(fragment_instance_id, 0)); while (i != _fragment_stream_set.end() && i->first == fragment_instance_id) { std::shared_ptr recvr; - WARN_IF_ERROR(find_recvr(i->first, i->second, &recvr, false), ""); + uint32_t hash_value = get_hash_value(i->first, i->second); + WARN_IF_ERROR(_find_recvr(hash_value, i->first, i->second, &recvr), ""); if (recvr == nullptr) { // keep going but at least log it std::stringstream err; diff --git a/be/src/exec/exchange/vdata_stream_mgr.h b/be/src/exec/exchange/vdata_stream_mgr.h index 7bde8f3b4c0c9b..7f35d62720278c 100644 --- a/be/src/exec/exchange/vdata_stream_mgr.h +++ b/be/src/exec/exchange/vdata_stream_mgr.h @@ -30,6 +30,7 @@ #include "common/be_mock_util.h" #include "common/global_types.h" #include "common/status.h" +#include "common/thread_safety_annotations.h" #include "runtime/runtime_profile.h" namespace google { @@ -58,8 +59,7 @@ class VDataStreamMgr { RuntimeProfile* profile, bool is_merging, size_t data_queue_capacity); MOCK_FUNCTION Status find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id, - std::shared_ptr* res, - bool acquire_lock = true); + std::shared_ptr* res); Status deregister_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id); @@ -69,9 +69,9 @@ class VDataStreamMgr { void cancel(const TUniqueId& fragment_instance_id, Status exec_status); private: - std::shared_mutex _lock; + AnnotatedSharedMutex _lock; using StreamMap = std::unordered_multimap>; - StreamMap _receiver_map; + StreamMap _receiver_map GUARDED_BY(_lock); struct ComparisonOp { bool operator()(const std::pair& a, @@ -89,7 +89,11 @@ class VDataStreamMgr { } }; using FragmentStreamSet = std::set, ComparisonOp>; - FragmentStreamSet _fragment_stream_set; + FragmentStreamSet _fragment_stream_set GUARDED_BY(_lock); + + Status _find_recvr(uint32_t hash_value, const TUniqueId& fragment_instance_id, + PlanNodeId node_id, std::shared_ptr* res) + REQUIRES_SHARED(_lock); uint32_t get_hash_value(const TUniqueId& fragment_instance_id, PlanNodeId node_id); }; diff --git a/be/src/exec/runtime_filter/runtime_filter_mgr.cpp b/be/src/exec/runtime_filter/runtime_filter_mgr.cpp index 3615299fa38148..49007d5c73534b 100644 --- a/be/src/exec/runtime_filter/runtime_filter_mgr.cpp +++ b/be/src/exec/runtime_filter/runtime_filter_mgr.cpp @@ -185,7 +185,7 @@ Status RuntimeFilterMergeControllerEntity::_init_with_desc( auto filter_id = runtime_filter_desc->filter_id; GlobalMergeContext* cnt_val; { - std::unique_lock guard(_filter_map_mutex); + LockGuard guard(_filter_map_mutex); cnt_val = &_filter_map[filter_id]; // may inplace construct default object } @@ -235,7 +235,7 @@ Status RuntimeFilterMergeControllerEntity::send_filter_size(std::shared_ptrfilter_id(); std::map::iterator iter; { - std::shared_lock guard(_filter_map_mutex); + SharedLockGuard guard(_filter_map_mutex); iter = _filter_map.find(filter_id); if (iter == _filter_map.end()) { return Status::InvalidArgument("unknown filter id {}", @@ -243,12 +243,12 @@ Status RuntimeFilterMergeControllerEntity::send_filter_size(std::shared_ptrsecond; - std::unique_lock l(iter->second.mtx); + std::unique_lock l(cnt_val.mtx); // Discard stale-stage runtime filter size requests from old recursive CTE rounds. // Each round increments the stage counter; only messages matching the current stage // should be processed. This prevents old PFC's runtime filters from corrupting // the merge state of the new round's filters. - if (request->stage() != iter->second.stage) { + if (request->stage() != cnt_val.stage) { return Status::OK(); } cnt_val.source_addrs.push_back(request->source_addr()); @@ -269,7 +269,7 @@ Status RuntimeFilterMergeControllerEntity::send_filter_size(std::shared_ptr(); - sync_request->set_stage(iter->second.stage); + sync_request->set_stage(cnt_val.stage); auto closure = AutoReleaseClosure>:: @@ -336,7 +336,7 @@ Status RuntimeFilterMergeControllerEntity::merge(std::shared_ptr q auto filter_id = request->filter_id(); std::map::iterator iter; { - std::shared_lock guard(_filter_map_mutex); + SharedLockGuard guard(_filter_map_mutex); iter = _filter_map.find(filter_id); VLOG_ROW << "recv filter id:" << request->filter_id() << " " << request->ShortDebugString(); if (iter == _filter_map.end()) { @@ -347,9 +347,9 @@ Status RuntimeFilterMergeControllerEntity::merge(std::shared_ptr q auto& cnt_val = iter->second; bool is_ready = false; { - std::lock_guard l(iter->second.mtx); + std::lock_guard l(cnt_val.mtx); // Discard stale-stage merge requests from old recursive CTE rounds. - if (request->stage() != iter->second.stage) { + if (request->stage() != cnt_val.stage) { return Status::OK(); } if (cnt_val.merger == nullptr) { @@ -492,7 +492,7 @@ Status RuntimeFilterMergeControllerEntity::reset_global_rf( for (const auto& filter_id : filter_ids) { GlobalMergeContext* cnt_val; { - std::unique_lock guard(_filter_map_mutex); + LockGuard guard(_filter_map_mutex); cnt_val = &_filter_map[filter_id]; // may inplace construct default object } RETURN_IF_ERROR(cnt_val->reset(query_ctx)); @@ -502,7 +502,7 @@ Status RuntimeFilterMergeControllerEntity::reset_global_rf( std::string RuntimeFilterMergeControllerEntity::debug_string() { std::string result = "RuntimeFilterMergeControllerEntity Info:\n"; - std::shared_lock guard(_filter_map_mutex); + SharedLockGuard guard(_filter_map_mutex); for (const auto& [filter_id, ctx] : _filter_map) { result += fmt::format("filter_id: {}, stage: {}, {}\n", filter_id, ctx.stage, ctx.merger->debug_string()); diff --git a/be/src/exec/runtime_filter/runtime_filter_mgr.h b/be/src/exec/runtime_filter/runtime_filter_mgr.h index f822e01196f853..418f9aa41b7414 100644 --- a/be/src/exec/runtime_filter/runtime_filter_mgr.h +++ b/be/src/exec/runtime_filter/runtime_filter_mgr.h @@ -27,12 +27,11 @@ #include #include #include -#include #include -#include #include #include "common/status.h" +#include "common/thread_safety_annotations.h" #include "util/uid_util.h" namespace butil { @@ -168,7 +167,7 @@ class RuntimeFilterMergeControllerEntity { std::string debug_string(); bool empty() { - std::shared_lock read_lock(_filter_map_mutex); + SharedLockGuard read_lock(_filter_map_mutex); return _filter_map.empty(); } @@ -185,10 +184,10 @@ class RuntimeFilterMergeControllerEntity { int64_t merge_time, PUniqueId query_id, int execution_timeout); // protect _filter_map - std::shared_mutex _filter_map_mutex; + AnnotatedSharedMutex _filter_map_mutex; std::shared_ptr _mem_tracker; - std::map _filter_map; + std::map _filter_map GUARDED_BY(_filter_map_mutex); }; #include "common/compile_check_end.h" } // namespace doris diff --git a/be/test/exec/pipeline/vdata_stream_recvr_test.cpp b/be/test/exec/pipeline/vdata_stream_recvr_test.cpp index ab6b03b13c5572..f0c4e05c7e6528 100644 --- a/be/test/exec/pipeline/vdata_stream_recvr_test.cpp +++ b/be/test/exec/pipeline/vdata_stream_recvr_test.cpp @@ -577,7 +577,7 @@ TEST_F(DataStreamRecvrTest, TestRemoteLocalMultiSender) { struct MockVDataStreamMgr : public VDataStreamMgr { ~MockVDataStreamMgr() override = default; Status find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id, - std::shared_ptr* res, bool acquire_lock = true) override { + std::shared_ptr* res) override { *res = recvr; return Status::OK(); } From 39cbffb4183c1817dd31d3a61c61d791dd9f8873 Mon Sep 17 00:00:00 2001 From: Mryange Date: Fri, 22 May 2026 18:25:11 +0800 Subject: [PATCH 06/10] [opt](exec) skip result serialization for dry run queries (#63356) ### What problem does this PR solve? Issue Number: N/A Related PR: None Problem Summary: When dry_run_query is enabled, FE only needs the returned row count, but BE still spends most of PhysicalResultSink time serializing MySQL result rows In a local dry-run case against numbers("number"="1000000"), the profile showed AppendBatchTime = 77.689ms, TupleConvertTime = 68.650ms, and ResultSendTime = 2.702us, which means the dry-run path was still paying almost the full result sink conversion cost. This change keeps output expr evaluation intact, but returns early in the MySQL result writers once the output block is produced in dry-run mode. That preserves returned row accounting while skipping result serialization, block copy, and sink enqueue work that dry-run queries never consume. (cherry picked from commit 4938d638e3f22ecbb17deecd169453598655d08d) --- .../exec/sink/writer/vmysql_result_writer.cpp | 6 +++++ .../data_type_serde_mysql_test.cpp | 23 ++++++++++++++++--- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/be/src/exec/sink/writer/vmysql_result_writer.cpp b/be/src/exec/sink/writer/vmysql_result_writer.cpp index 4101f18db3c457..7f98f626f3f2fc 100644 --- a/be/src/exec/sink/writer/vmysql_result_writer.cpp +++ b/be/src/exec/sink/writer/vmysql_result_writer.cpp @@ -297,6 +297,12 @@ Status VMysqlResultWriter::write(RuntimeState* state, Block& input_block) { Block block; RETURN_IF_ERROR(VExprContext::get_output_block_after_execute_exprs(_output_vexpr_ctxs, input_block, &block)); + + if (_is_dry_run) { + _written_rows += cast_set(block.rows()); + return Status::OK(); + } + const auto total_bytes = block.bytes(); if (total_bytes > config::thrift_max_message_size) [[unlikely]] { diff --git a/be/test/core/data_type_serde/data_type_serde_mysql_test.cpp b/be/test/core/data_type_serde/data_type_serde_mysql_test.cpp index d0a6cbdbbaef5f..e8f289bbf54055 100644 --- a/be/test/core/data_type_serde/data_type_serde_mysql_test.cpp +++ b/be/test/core/data_type_serde/data_type_serde_mysql_test.cpp @@ -77,6 +77,10 @@ class TestBlockSerializer final : public MySQLResultBlockBuffer { public: TestBlockSerializer(RuntimeState* state) : MySQLResultBlockBuffer(state) {} ~TestBlockSerializer() override = default; + size_t queue_size() { + std::lock_guard l(_lock); + return _result_batch_queue.size(); + } std::shared_ptr get_block() { std::lock_guard l(_lock); DCHECK_EQ(_result_batch_queue.size(), 1); @@ -86,7 +90,7 @@ class TestBlockSerializer final : public MySQLResultBlockBuffer { } }; -void serialize_and_deserialize_mysql_test() { +void serialize_and_deserialize_mysql_test(bool dry_run) { Block block; // create_descriptor_tablet(); std::vector> cols { @@ -317,12 +321,25 @@ void serialize_and_deserialize_mysql_test() { auto serializer = std::make_shared(&state); VMysqlResultWriter mysql_writer(serializer, _output_vexpr_ctxs, nullptr, false); - Status st = mysql_writer.write(&runtime_stat, block); + TQueryOptions query_options; + query_options.__set_dry_run_query(dry_run); + runtime_stat.set_query_options(query_options); + + Status st = mysql_writer.init(&runtime_stat); EXPECT_TRUE(st.ok()); + + st = mysql_writer.write(&runtime_stat, block); + EXPECT_TRUE(st.ok()); + EXPECT_EQ(mysql_writer.get_written_rows(), row_num); + EXPECT_EQ(serializer->queue_size(), dry_run ? 0 : 1); } TEST(DataTypeSerDeMysqlTest, ScalaSerDeTest) { - serialize_and_deserialize_mysql_test(); + serialize_and_deserialize_mysql_test(false); +} + +TEST(DataTypeSerDeMysqlTest, DryRunSkipsSerialization) { + serialize_and_deserialize_mysql_test(true); } } // namespace doris From ea00d2f39d63fd32e8725ae1251b12d636a598af Mon Sep 17 00:00:00 2001 From: Mryange Date: Thu, 21 May 2026 12:08:17 +0800 Subject: [PATCH 07/10] [opt](sort) avoid copying whole blocks during merge (#63429) Problem Summary: `MergeSorterState` used the generic copy-based merge path even when the current top sorted run could return its whole remaining block before any other run. This adds a direct whole-block fast path guarded by a total-order check, avoiding unnecessary `insert_range_from` work in inner merge. ### What is changed? - Add `MergeSortCursor::totally_less_or_equals()` to detect when the current run is wholly before the next child. - Return the current block directly from `MergeSorterState::_merge_sort_read_impl()` when the whole-block condition is satisfied. - Add focused BE unit tests for exact-batch and smaller-than-batch whole-block fast-path cases. (cherry picked from commit 974f9bdc89bfc7258ad124dcc6f1fecf3f4be0f8) --- be/src/exec/sort/sort_cursor.h | 5 ++ be/src/exec/sort/sorter.cpp | 18 ++++++ be/test/exec/operator/sort_operator_test.cpp | 19 +++--- be/test/exec/sort/heap_sorter_test.cpp | 16 ++--- be/test/exec/sort/merge_sorter_state.cpp | 67 ++++++++++++++++++++ 5 files changed, 107 insertions(+), 18 deletions(-) diff --git a/be/src/exec/sort/sort_cursor.h b/be/src/exec/sort/sort_cursor.h index d5b4a14e46158f..dae751258a5e20 100644 --- a/be/src/exec/sort/sort_cursor.h +++ b/be/src/exec/sort/sort_cursor.h @@ -205,6 +205,11 @@ struct MergeSortCursor { return !impl->empty() && greater_at(rhs, impl->pos, rhs.impl->pos) > 0; } + bool totally_less_or_equals(const MergeSortCursor& rhs) const { + return !impl->empty() && !rhs.impl->empty() && + greater_at(rhs, impl->rows - 1, rhs.impl->pos) <= 0; + } + /// Inverted so that the priority queue elements are removed in ascending order. bool operator<(const MergeSortCursor& rhs) const { return greater(rhs); } diff --git a/be/src/exec/sort/sorter.cpp b/be/src/exec/sort/sorter.cpp index 88160819328ce0..616cc2145a2d16 100644 --- a/be/src/exec/sort/sorter.cpp +++ b/be/src/exec/sort/sorter.cpp @@ -94,6 +94,24 @@ Status MergeSorterState::merge_sort_read(doris::Block* block, int batch_size, bo } void MergeSorterState::_merge_sort_read_impl(int batch_size, doris::Block* block, bool* eos) { + if (_queue.is_valid() && batch_size > 0) { + auto [current, current_rows] = _queue.current(); + current_rows = std::min(current_rows, static_cast(batch_size)); + const size_t step = std::min(_offset, current_rows); + + // Fast path when the current top run can contribute its whole remaining block + // before any other run. The returned block stays within batch_size because + // is_last(current_rows) can only hold after the min(batch_size, queue_batch_size) + // clamp above. + if (step == 0 && current->impl->is_first() && current->impl->is_last(current_rows) && + (_queue.size() == 1 || (*current).totally_less_or_equals(_queue.next_child()))) { + current->impl->block->swap(*block); + _queue.remove_top(); + *eos = false; + return; + } + } + size_t num_columns = unsorted_block()->columns(); MutableBlock m_block = VectorizedUtils::build_mutable_mem_reuse_block(block, *unsorted_block()); diff --git a/be/test/exec/operator/sort_operator_test.cpp b/be/test/exec/operator/sort_operator_test.cpp index 23fa37e57b01ef..bd6c0ee68c32a4 100644 --- a/be/test/exec/operator/sort_operator_test.cpp +++ b/be/test/exec/operator/sort_operator_test.cpp @@ -192,21 +192,20 @@ TEST_F(SortOperatorTest, test_dep) { EXPECT_TRUE(is_ready(source_local_state->dependencies())); { - Block block = ColumnHelper::create_block({}); + MutableBlock merged_block = ColumnHelper::create_block({}); bool eos = false; - auto st = source->get_block(state.get(), &block, &eos); - EXPECT_TRUE(st.ok()) << st.msg(); - EXPECT_FALSE(eos); + while (!eos) { + Block block; + auto st = source->get_block(state.get(), &block, &eos); + EXPECT_TRUE(st.ok()) << st.msg(); + EXPECT_TRUE(merged_block.merge(block)); + } + + auto block = merged_block.to_block(); EXPECT_EQ(block.rows(), 6); std::cout << block.dump_data() << std::endl; EXPECT_TRUE(ColumnHelper::block_equal( block, ColumnHelper::create_block({1, 2, 3, 4, 5, 6}))); - - block.clear(); - st = source->get_block(state.get(), &block, &eos); - EXPECT_TRUE(st.ok()) << st.msg(); - EXPECT_TRUE(eos); - EXPECT_EQ(block.rows(), 0); } } diff --git a/be/test/exec/sort/heap_sorter_test.cpp b/be/test/exec/sort/heap_sorter_test.cpp index 90b06764175f1e..9c91db2e5e3833 100644 --- a/be/test/exec/sort/heap_sorter_test.cpp +++ b/be/test/exec/sort/heap_sorter_test.cpp @@ -100,20 +100,20 @@ TEST_F(HeapSorterTest, test_topn_sorter1) { EXPECT_TRUE(sorter->prepare_for_read(false)); { - Block block; + MutableBlock merged_block = ColumnHelper::create_block({}, {}); bool eos = false; - EXPECT_TRUE(sorter->get_next(&_state, &block, &eos)); + while (!eos) { + Block block; + EXPECT_TRUE(sorter->get_next(&_state, &block, &eos)); + EXPECT_TRUE(merged_block.merge(block)); + } + + auto block = merged_block.to_block(); EXPECT_EQ(block.rows(), 6); EXPECT_TRUE(ColumnHelper::block_equal( block, Block {ColumnHelper::create_column_with_name({1, 2, 3, 4, 5, 6}), ColumnHelper::create_column_with_name({1, 2, 3, 4, 5, 6})})); - - block.clear_column_data(); - - EXPECT_TRUE(sorter->get_next(&_state, &block, &eos)); - EXPECT_EQ(block.rows(), 0); - EXPECT_EQ(eos, true); } } diff --git a/be/test/exec/sort/merge_sorter_state.cpp b/be/test/exec/sort/merge_sorter_state.cpp index 0dc8a1a8937164..7af89e7cbdf70b 100644 --- a/be/test/exec/sort/merge_sorter_state.cpp +++ b/be/test/exec/sort/merge_sorter_state.cpp @@ -101,4 +101,71 @@ TEST_F(MergeSorterStateTest, test1) { ColumnHelper::create_block({5, 6}))); } } + +TEST_F(MergeSorterStateTest, whole_block_fast_path_swaps_block) { + state.reset(new MergeSorterState(*row_desc, 0)); + auto first_block = create_block({1, 2, 3}); + auto second_block = create_block({4, 5, 6}); + auto first_column = first_block->get_by_position(0).column; + + state->add_sorted_block(first_block); + state->add_sorted_block(second_block); + + SortDescription desc {SortColumnDescription {0, 1, -1}}; + ASSERT_TRUE(state->build_merge_tree(desc)); + + Block block; + bool eos = false; + Status status = state->merge_sort_read(&block, 3, &eos); + ASSERT_TRUE(status.ok()); + EXPECT_FALSE(eos); + EXPECT_TRUE( + ColumnHelper::block_equal(block, ColumnHelper::create_block({1, 2, 3}))); + EXPECT_EQ(block.get_by_position(0).column.get(), first_column.get()); +} + +TEST_F(MergeSorterStateTest, whole_block_fast_path_allows_smaller_than_batch) { + state.reset(new MergeSorterState(*row_desc, 0)); + auto first_block = create_block({1, 2, 3}); + auto second_block = create_block({4, 5, 6}); + auto first_column = first_block->get_by_position(0).column; + auto second_column = second_block->get_by_position(0).column; + + state->add_sorted_block(first_block); + state->add_sorted_block(second_block); + + SortDescription desc {SortColumnDescription {0, 1, -1}}; + ASSERT_TRUE(state->build_merge_tree(desc)); + + { + Block block; + bool eos = false; + Status status = state->merge_sort_read(&block, 4, &eos); + ASSERT_TRUE(status.ok()); + EXPECT_FALSE(eos); + EXPECT_TRUE(ColumnHelper::block_equal( + block, ColumnHelper::create_block({1, 2, 3}))); + EXPECT_EQ(block.get_by_position(0).column.get(), first_column.get()); + } + + { + Block block; + bool eos = false; + Status status = state->merge_sort_read(&block, 4, &eos); + ASSERT_TRUE(status.ok()); + EXPECT_FALSE(eos); + EXPECT_TRUE(ColumnHelper::block_equal( + block, ColumnHelper::create_block({4, 5, 6}))); + EXPECT_EQ(block.get_by_position(0).column.get(), second_column.get()); + } + + { + Block block; + bool eos = false; + Status status = state->merge_sort_read(&block, 4, &eos); + ASSERT_TRUE(status.ok()); + EXPECT_TRUE(eos); + EXPECT_EQ(block.rows(), 0); + } +} } // namespace doris \ No newline at end of file From 5b6a9af49a9b9ad6159d6e1d45b5936a961b0bf0 Mon Sep 17 00:00:00 2001 From: Mryange Date: Fri, 15 May 2026 16:08:54 +0800 Subject: [PATCH 08/10] [fix](function) fix map_contains_entry runtime error when TIMESTAMPTZ is map key or value (#63124) ### What problem does this PR solve? Issue Number: N/A Problem Summary: `map_contains_entry` throws a `RUNTIME_ERROR` at BE execution time when the MAP column has `TIMESTAMPTZ` as its key or value type. Root cause: `FunctionMapContainsEntry::is_equality_comparison_supported` hard-coded a list of accepted primitive types (`is_date_type`, `is_time_type`, `is_number`, `is_string_type`, `is_ip`) but omitted `TYPE_TIMESTAMPTZ`. As a result, the pre-execution type guard always rejected TIMESTAMPTZ even though the underlying `dispatch_switch_all` + `ColumnVector::compare_at` path supports it correctly. The fix replaces the hand-maintained list with a direct call to `dispatch_switch_all`, which already covers TIMESTAMPTZ in its `DATETIME` branch, making the guard consistent with the actual dispatch layer. (cherry picked from commit 60b0d46dbe3f875939a7132bd354c22fcfb3133f) --- be/src/exprs/function/function_map.cpp | 7 +- .../test_timestamptz_map_contains_entry.out | 43 +++++ ...test_timestamptz_map_contains_entry.groovy | 155 ++++++++++++++++++ 3 files changed, 202 insertions(+), 3 deletions(-) create mode 100644 regression-test/data/datatype_p0/timestamptz/test_timestamptz_map_contains_entry.out create mode 100644 regression-test/suites/datatype_p0/timestamptz/test_timestamptz_map_contains_entry.groovy diff --git a/be/src/exprs/function/function_map.cpp b/be/src/exprs/function/function_map.cpp index ffe3e773b9f609..d0a0ab639a94cc 100644 --- a/be/src/exprs/function/function_map.cpp +++ b/be/src/exprs/function/function_map.cpp @@ -867,10 +867,11 @@ class FunctionMapContainsEntry : public IFunction { /*nan_direction_hint=*/1) == 0; } - // whether this function supports equality comparison for the given primitive type + // whether this function supports equality comparison for the given primitive type. + // Uses dispatch_switch_all as the single source of truth so any type supported + // by the dispatch layer is automatically accepted here. bool is_equality_comparison_supported(PrimitiveType type) const { - return is_string_type(type) || is_number(type) || is_date_type(type) || - is_time_type(type) || is_ip(type); + return dispatch_switch_all(type, [](const auto&) { return true; }); } }; diff --git a/regression-test/data/datatype_p0/timestamptz/test_timestamptz_map_contains_entry.out b/regression-test/data/datatype_p0/timestamptz/test_timestamptz_map_contains_entry.out new file mode 100644 index 00000000000000..43746eee1800bd --- /dev/null +++ b/regression-test/data/datatype_p0/timestamptz/test_timestamptz_map_contains_entry.out @@ -0,0 +1,43 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !value_hit -- +true + +-- !value_miss -- +false + +-- !value_miss_key -- +false + +-- !key_hit -- +true + +-- !key_miss_value -- +false + +-- !key_miss_key -- +false + +-- !table_value_hit -- +1 true +2 false + +-- !table_value_miss -- +1 false +2 false + +-- !table_key_hit -- +1 true +2 false + +-- !table_key_miss -- +1 false +2 false + +-- !null_search_key -- +1 false +2 false + +-- !null_search_value -- +1 false +2 false + diff --git a/regression-test/suites/datatype_p0/timestamptz/test_timestamptz_map_contains_entry.groovy b/regression-test/suites/datatype_p0/timestamptz/test_timestamptz_map_contains_entry.groovy new file mode 100644 index 00000000000000..2b814ef8b6e539 --- /dev/null +++ b/regression-test/suites/datatype_p0/timestamptz/test_timestamptz_map_contains_entry.groovy @@ -0,0 +1,155 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_timestamptz_map_contains_entry") { + + sql "set time_zone = '+08:00';" + sql "set enable_nereids_planner = true;" + sql "set enable_fallback_to_original_planner = false;" + + // --- inline literal tests (no table needed) --- + + // TIMESTAMPTZ as map value: hit + qt_value_hit """ + SELECT map_contains_entry( + map('a', cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), + 'b', cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6))), + 'a', + cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)) + ); + """ + + // TIMESTAMPTZ as map value: miss (wrong value) + qt_value_miss """ + SELECT map_contains_entry( + map('a', cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), + 'b', cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6))), + 'a', + cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6)) + ); + """ + + // TIMESTAMPTZ as map value: miss (wrong key) + qt_value_miss_key """ + SELECT map_contains_entry( + map('a', cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), + 'b', cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6))), + 'c', + cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)) + ); + """ + + // TIMESTAMPTZ as map key: hit + qt_key_hit """ + SELECT map_contains_entry( + map(cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), 'a', + cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6)), 'b'), + cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), + 'a' + ); + """ + + // TIMESTAMPTZ as map key: miss (wrong value) + qt_key_miss_value """ + SELECT map_contains_entry( + map(cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), 'a', + cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6)), 'b'), + cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), + 'b' + ); + """ + + // TIMESTAMPTZ as map key: miss (wrong key) + qt_key_miss_key """ + SELECT map_contains_entry( + map(cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), 'a', + cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6)), 'b'), + cast('2024-01-03 00:00:00.000000 +00:00' as timestamptz(6)), + 'a' + ); + """ + + // --- table-based tests --- + + sql "DROP TABLE IF EXISTS test_timestamptz_map_contains_entry_t;" + sql """ + CREATE TABLE test_timestamptz_map_contains_entry_t ( + id INT, + map_s_tz MAP, + map_tz_s MAP + ) + DUPLICATE KEY(id) + DISTRIBUTED BY HASH(id) BUCKETS 1 + PROPERTIES("replication_num" = "1"); + """ + + sql """ + INSERT INTO test_timestamptz_map_contains_entry_t VALUES ( + 1, + map('a', cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), + 'b', cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6))), + map(cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), 'a', + cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6)), 'b') + ), ( + 2, + map('x', cast('2024-06-15 12:00:00.000000 +05:30' as timestamptz(6))), + map(cast('2024-06-15 12:00:00.000000 +05:30' as timestamptz(6)), 'x') + ); + """ + + // TIMESTAMPTZ as map value, hit + qt_table_value_hit """ + SELECT id, map_contains_entry(map_s_tz, 'a', cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6))) + FROM test_timestamptz_map_contains_entry_t + ORDER BY id; + """ + + // TIMESTAMPTZ as map value, miss + qt_table_value_miss """ + SELECT id, map_contains_entry(map_s_tz, 'a', cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6))) + FROM test_timestamptz_map_contains_entry_t + ORDER BY id; + """ + + // TIMESTAMPTZ as map key, hit + qt_table_key_hit """ + SELECT id, map_contains_entry(map_tz_s, cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), 'a') + FROM test_timestamptz_map_contains_entry_t + ORDER BY id; + """ + + // TIMESTAMPTZ as map key, miss + qt_table_key_miss """ + SELECT id, map_contains_entry(map_tz_s, cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), 'b') + FROM test_timestamptz_map_contains_entry_t + ORDER BY id; + """ + + // NULL search key + qt_null_search_key """ + SELECT id, map_contains_entry(map_s_tz, NULL, cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6))) + FROM test_timestamptz_map_contains_entry_t + ORDER BY id; + """ + + // NULL search value + qt_null_search_value """ + SELECT id, map_contains_entry(map_s_tz, 'a', cast(NULL as timestamptz(6))) + FROM test_timestamptz_map_contains_entry_t + ORDER BY id; + """ +} From a6f53a194f5946c00f8e3f250e47103de93c8df7 Mon Sep 17 00:00:00 2001 From: Mryange Date: Fri, 29 May 2026 10:39:25 +0800 Subject: [PATCH 09/10] [refine](function) use typed ANN query vector (#63834) ANN query vector extraction returned a generic `IColumn::Ptr`, so the TopN and range search paths had to downcast the column again before reading float data. This made the code more indirect and delayed type validation. This PR changes the helper and runtime state to keep the query vector as `ColumnFloat32::Ptr`, validates the concrete type at extraction time, and removes redundant casts from the ANN execution path. (cherry picked from commit 99691f6895d6d3f527ced114fb225a742888e4b4) --- .../index/ann/ann_range_search_runtime.cpp | 3 +-- .../index/ann/ann_range_search_runtime.h | 2 +- be/src/storage/index/ann/ann_topn_runtime.cpp | 19 +++++++++++++------ be/src/storage/index/ann/ann_topn_runtime.h | 5 +++-- .../index/ann/ann_range_search_test.cpp | 3 ++- .../index/ann/ann_topn_descriptor_test.cpp | 3 +-- .../index/ann/extract_query_vector_test.cpp | 18 +++++++++++++++++- 7 files changed, 38 insertions(+), 15 deletions(-) diff --git a/be/src/storage/index/ann/ann_range_search_runtime.cpp b/be/src/storage/index/ann/ann_range_search_runtime.cpp index a223c96e6c6be8..b38576469e8a3a 100644 --- a/be/src/storage/index/ann/ann_range_search_runtime.cpp +++ b/be/src/storage/index/ann/ann_range_search_runtime.cpp @@ -35,8 +35,7 @@ namespace doris::segment_v2 { */ AnnRangeSearchParams AnnRangeSearchRuntime::to_range_search_params() const { AnnRangeSearchParams params; - const auto* query = assert_cast(query_value.get()); - params.query_value = query->get_data().data(); + params.query_value = query_value->get_data().data(); params.radius = static_cast(radius); params.roaring = nullptr; params.is_le_or_lt = is_le_or_lt; diff --git a/be/src/storage/index/ann/ann_range_search_runtime.h b/be/src/storage/index/ann/ann_range_search_runtime.h index c1063404f60466..7ca0a830d8b68f 100644 --- a/be/src/storage/index/ann/ann_range_search_runtime.h +++ b/be/src/storage/index/ann/ann_range_search_runtime.h @@ -133,7 +133,7 @@ struct AnnRangeSearchRuntime { double radius = 0.0; ///< Search radius/distance threshold AnnIndexMetric metric_type; ///< Distance metric (L2, Inner Product, etc.) doris::VectorSearchUserParams user_params; ///< User-defined search parameters - IColumn::Ptr query_value; ///< Query vector data (deep copied) + ColumnFloat32::Ptr query_value; ///< Query vector data }; #include "common/compile_check_end.h" } // namespace doris::segment_v2 \ No newline at end of file diff --git a/be/src/storage/index/ann/ann_topn_runtime.cpp b/be/src/storage/index/ann/ann_topn_runtime.cpp index 4ac4042395fed4..1742b65065ec63 100644 --- a/be/src/storage/index/ann/ann_topn_runtime.cpp +++ b/be/src/storage/index/ann/ann_topn_runtime.cpp @@ -29,6 +29,7 @@ #include "core/column/column_array.h" #include "core/column/column_const.h" #include "core/column/column_nullable.h" +#include "core/column/column_vector.h" #include "core/data_type/primitive_type.h" #include "exprs/function/array/function_array_distance.h" #include "exprs/vexpr_context.h" @@ -43,7 +44,7 @@ namespace doris::segment_v2 { #include "common/compile_check_begin.h" -Result extract_query_vector(std::shared_ptr arg_expr) { +Result extract_query_vector(std::shared_ptr arg_expr) { if (arg_expr->is_constant() == false) { return ResultError(Status::InvalidArgument("Ann topn expr must be constant, got\n{}", arg_expr->debug_string())); @@ -99,7 +100,14 @@ Result extract_query_vector(std::shared_ptr arg_expr) { values_holder_col = value_nullable_col->get_nested_column_ptr(); } - return values_holder_col; + auto float_col = check_and_get_column_ptr(values_holder_col); + if (float_col.get() == nullptr) { + return ResultError(Status::InvalidArgument( + "Ann topn query vector elements must be Float32, got column: {}", + values_holder_col->get_name())); + } + + return float_col; } Status AnnTopNRuntime::prepare(RuntimeState* state, const RowDescriptor& row_desc) { @@ -188,10 +196,10 @@ Status AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::AnnIndexIterator* DCHECK(ann_index_iterator != nullptr); DCHECK(_order_by_expr_ctx != nullptr); DCHECK(_order_by_expr_ctx->root() != nullptr); - size_t query_array_size = _query_array->size(); - if (_query_array.get() == nullptr || query_array_size == 0) { + if (_query_array.get() == nullptr || _query_array->size() == 0) { return Status::InternalError("Ann topn query vector is not initialized"); } + size_t query_array_size = _query_array->size(); // TODO:(zhiqiang) Maybe we can move this dimension check to prepare phase. @@ -203,9 +211,8 @@ Status AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::AnnIndexIterator* "Ann topn query vector dimension {} does not match index dimension {}", query_array_size, ann_index_reader->get_dimension()); } - const ColumnFloat32* query = assert_cast(_query_array.get()); segment_v2::AnnTopNParam ann_query_params { - .query_value = query->get_data().data(), + .query_value = _query_array->get_data().data(), .query_value_size = query_array_size, .limit = _limit, ._user_params = _user_params, diff --git a/be/src/storage/index/ann/ann_topn_runtime.h b/be/src/storage/index/ann/ann_topn_runtime.h index 63e04cc30b6256..9ad2cd7df0ba17 100644 --- a/be/src/storage/index/ann/ann_topn_runtime.h +++ b/be/src/storage/index/ann/ann_topn_runtime.h @@ -36,6 +36,7 @@ #pragma once #include "core/column/column.h" +#include "core/column/column_vector.h" #include "core/data_type/primitive_type.h" #include "exprs/vectorized_fn_call.h" #include "exprs/vexpr.h" @@ -49,7 +50,7 @@ namespace doris::segment_v2 { struct AnnIndexStats; class AnnIndexIterator; -Result extract_query_vector(std::shared_ptr arg_expr); +Result extract_query_vector(std::shared_ptr arg_expr); /** * @brief Runtime execution engine for ANN (Approximate Nearest Neighbor) Top-N queries. @@ -162,7 +163,7 @@ class AnnTopNRuntime { size_t _src_column_idx = -1; ///< Source vector column index size_t _dest_column_idx = -1; ///< Destination distance column index segment_v2::AnnIndexMetric _metric_type; ///< Distance metric type - IColumn::Ptr _query_array; ///< Query vector data (contiguous float buffer) + ColumnFloat32::Ptr _query_array; ///< Query vector data (contiguous float buffer) doris::VectorSearchUserParams _user_params; ///< User-defined search parameters }; #include "common/compile_check_end.h" diff --git a/be/test/storage/index/ann/ann_range_search_test.cpp b/be/test/storage/index/ann/ann_range_search_test.cpp index 400e822695ca32..890856b8b0a925 100644 --- a/be/test/storage/index/ann/ann_range_search_test.cpp +++ b/be/test/storage/index/ann/ann_range_search_test.cpp @@ -100,8 +100,9 @@ TEST_F(VectorSearchTest, TestPrepareAnnRangeSearch) { EXPECT_EQ(ann_range_search_runtime.radius, 10.0f); std::vector query_array_groud_truth = {1, 2, 3, 4, 5, 6, 7, 20}; std::vector query_array_f32; + const auto& query_value = range_search_ctx->_ann_range_search_runtime.query_value; for (int i = 0; i < query_array_groud_truth.size(); ++i) { - query_array_f32.push_back(static_cast(ann_range_search_runtime.query_value[i])); + query_array_f32.push_back(static_cast(query_value->get_data()[i])); } for (int i = 0; i < query_array_f32.size(); ++i) { EXPECT_EQ(query_array_f32[i], query_array_groud_truth[i]); diff --git a/be/test/storage/index/ann/ann_topn_descriptor_test.cpp b/be/test/storage/index/ann/ann_topn_descriptor_test.cpp index 2cb9f293ee583b..880f42f6a9dd86 100644 --- a/be/test/storage/index/ann/ann_topn_descriptor_test.cpp +++ b/be/test/storage/index/ann/ann_topn_descriptor_test.cpp @@ -116,8 +116,7 @@ TEST_F(VectorSearchTest, AnnTopNRuntimeEvaluateTopN) { ASSERT_TRUE(st.ok()) << fmt::format("st: {}, expr {}", st.to_string(), predicate->get_order_by_expr_ctx()->root()->debug_string()); - const ColumnFloat32* query_column = - assert_cast(predicate->_query_array.get()); + const auto& query_column = predicate->_query_array; const float* query_value = query_column->get_data().data(); const size_t query_value_size = predicate->_query_array->size(); ASSERT_EQ(query_value_size, 8); diff --git a/be/test/storage/index/ann/extract_query_vector_test.cpp b/be/test/storage/index/ann/extract_query_vector_test.cpp index 8fd6850218ec10..22ab34ab32bc9a 100644 --- a/be/test/storage/index/ann/extract_query_vector_test.cpp +++ b/be/test/storage/index/ann/extract_query_vector_test.cpp @@ -178,7 +178,7 @@ TEST_F(ExtractQueryVectorTest, ValuesMatchInput) { auto result = extract_query_vector(mock); ASSERT_TRUE(result.has_value()); - auto* float_col = assert_cast(result.value().get()); + const auto& float_col = result.value(); ASSERT_EQ(float_col->size(), 4u); for (size_t i = 0; i < input.size(); ++i) { EXPECT_FLOAT_EQ(float_col->get_data()[i], input[i]); @@ -240,4 +240,20 @@ TEST_F(ExtractQueryVectorTest, NonArrayColumnFails) { EXPECT_TRUE(result.error().to_string().find("Array literal") != std::string::npos); } +TEST_F(ExtractQueryVectorTest, NonFloatArrayFails) { + auto int_col = ColumnInt32::create(); + int_col->insert_value(1); + int_col->insert_value(2); + auto offsets = ColumnArray::ColumnOffsets::create(); + offsets->insert_value(2); + auto array_col = ColumnArray::create(std::move(int_col), std::move(offsets)); + + auto mock = std::make_shared(); + mock->set_column(std::move(array_col)); + + auto result = extract_query_vector(mock); + ASSERT_FALSE(result.has_value()); + EXPECT_TRUE(result.error().to_string().find("must be Float32") != std::string::npos); +} + } // namespace doris::segment_v2 From ff44136c1f5ce1f6e07a780a9a867b986148e881 Mon Sep 17 00:00:00 2001 From: Mryange Date: Thu, 21 May 2026 12:07:33 +0800 Subject: [PATCH 10/10] [fix](function) support TIMESTAMPDIFF MICROSECOND in nereids (#63365) ### What problem does this PR solve? Nereids rejects `TIMESTAMPDIFF(MICROSECOND, ...)` during analysis. The executable path already exists through `MicroSecondsDiff`, but the FE binder cannot reach it because: - `Interval.TimeUnit` does not define a standalone `MICROSECOND` - `DatetimeFunctionBinder` only routes `TIMESTAMPDIFF` up to `SECOND` As a result, queries such as `TIMESTAMPDIFF(MICROSECOND, MIN(t), MAX(t))` fail even for valid `DATETIMEV2(6)` inputs. (cherry picked from commit 4ab7cc024611e53d9317c2a1e64f265af46e2913) --- .../analysis/DatetimeFunctionBinder.java | 5 +++- .../trees/expressions/literal/Interval.java | 1 + .../analysis/DatetimeFunctionBinderTest.java | 10 +++++++ .../nereids_syntax_p0/test_timestampdiff.out | 6 ++++ .../test_timestampdiff.groovy | 28 +++++++++++++++++++ 5 files changed, 49 insertions(+), 1 deletion(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/DatetimeFunctionBinder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/DatetimeFunctionBinder.java index c93f151cf0c092..4e1a768bd97fe9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/DatetimeFunctionBinder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/DatetimeFunctionBinder.java @@ -55,6 +55,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursAdd; import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursDiff; import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursSub; +import org.apache.doris.nereids.trees.expressions.functions.scalar.MicroSecondsDiff; import org.apache.doris.nereids.trees.expressions.functions.scalar.MinuteCeil; import org.apache.doris.nereids.trees.expressions.functions.scalar.MinuteFloor; import org.apache.doris.nereids.trees.expressions.functions.scalar.MinuteMicrosecondAdd; @@ -301,9 +302,11 @@ private Expression processTimestampDiff(TimeUnit unit, Expression start, Express return new MinutesDiff(end, start); case SECOND: return new SecondsDiff(end, start); + case MICROSECOND: + return new MicroSecondsDiff(end, start); default: throw new AnalysisException("Unsupported time stamp diff time unit: " + unit - + ", supported time unit: YEAR/QUARTER/MONTH/WEEK/DAY/HOUR/MINUTE/SECOND"); + + ", supported time unit: YEAR/QUARTER/MONTH/WEEK/DAY/HOUR/MINUTE/SECOND/MICROSECOND"); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Interval.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Interval.java index 275e0f74fe1bc9..f490c225c444c0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Interval.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Interval.java @@ -106,6 +106,7 @@ public enum TimeUnit { MINUTE_SECOND("MINUTE_SECOND", false, 200), MINUTE_MICROSECOND("MINUTE_MICROSECOND", false, 200), SECOND("SECOND", true, 100), + MICROSECOND("MICROSECOND", true, 0), SECOND_MICROSECOND("SECOND_MICROSECOND", true, 100); private final String description; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/DatetimeFunctionBinderTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/DatetimeFunctionBinderTest.java index 81f24ed878bd4d..a63e4a3e6282a2 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/DatetimeFunctionBinderTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/DatetimeFunctionBinderTest.java @@ -42,6 +42,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursAdd; import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursDiff; import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursSub; +import org.apache.doris.nereids.trees.expressions.functions.scalar.MicroSecondsDiff; import org.apache.doris.nereids.trees.expressions.functions.scalar.MinuteCeil; import org.apache.doris.nereids.trees.expressions.functions.scalar.MinuteFloor; import org.apache.doris.nereids.trees.expressions.functions.scalar.MinutesAdd; @@ -110,6 +111,8 @@ public class DatetimeFunctionBinderTest { TinyIntType.INSTANCE, false, ImmutableList.of()); private final SlotReference secondUnit = new SlotReference(new ExprId(-1), "SECOND", TinyIntType.INSTANCE, false, ImmutableList.of()); + private final SlotReference microsecondUnit = new SlotReference(new ExprId(-1), "MICROSECOND", + TinyIntType.INSTANCE, false, ImmutableList.of()); private final SlotReference invalidUnit = new SlotReference(new ExprId(-1), "INVALID", TinyIntType.INSTANCE, false, ImmutableList.of()); @@ -172,6 +175,13 @@ void testTimestampDiff() { Assertions.assertEquals(dateTimeV2Literal2, result.child(0)); Assertions.assertEquals(dateTimeV2Literal1, result.child(1)); + timeDiff = new UnboundFunction(functionName, ImmutableList.of( + microsecondUnit, dateTimeV2Literal1, dateTimeV2Literal2)); + result = DatetimeFunctionBinder.INSTANCE.bind(timeDiff); + Assertions.assertInstanceOf(MicroSecondsDiff.class, result); + Assertions.assertEquals(dateTimeV2Literal2, result.child(0)); + Assertions.assertEquals(dateTimeV2Literal1, result.child(1)); + Assertions.assertThrowsExactly(AnalysisException.class, () -> DatetimeFunctionBinder.INSTANCE.bind( new UnboundFunction(functionName, ImmutableList.of(invalidUnit, diff --git a/regression-test/data/nereids_syntax_p0/test_timestampdiff.out b/regression-test/data/nereids_syntax_p0/test_timestampdiff.out index 0e2dd6a537559e..15623515ed485f 100644 --- a/regression-test/data/nereids_syntax_p0/test_timestampdiff.out +++ b/regression-test/data/nereids_syntax_p0/test_timestampdiff.out @@ -17,3 +17,9 @@ -- !select -- 40 +-- !select -- +876543 + +-- !select -- +2024-01-01T10:00:00.999999 2024-01-01T10:00:00.123456 876543 + diff --git a/regression-test/suites/nereids_syntax_p0/test_timestampdiff.groovy b/regression-test/suites/nereids_syntax_p0/test_timestampdiff.groovy index 34500732e22920..0a3e563bd7f866 100644 --- a/regression-test/suites/nereids_syntax_p0/test_timestampdiff.groovy +++ b/regression-test/suites/nereids_syntax_p0/test_timestampdiff.groovy @@ -37,4 +37,32 @@ suite("test_timestampdiff") { qt_select """ SELECT TIMESTAMPDIFF(second,'2003-02-03 11:00:00','2003-02-03 11:00:40'); """ + + qt_select """ + SELECT TIMESTAMPDIFF(microsecond, + CAST('2024-01-01 10:00:00.123456' AS DATETIMEV2(6)), + CAST('2024-01-01 10:00:00.999999' AS DATETIMEV2(6))); + """ + + sql """drop table if exists test_timestampdiff_microsecond""" + sql """ + create table test_timestampdiff_microsecond ( + id int, + t datetimev2(6) + ) + duplicate key(id) + distributed by hash(id) buckets 1 + properties("replication_num" = "1"); + """ + + sql """ + insert into test_timestampdiff_microsecond values + (1, '2024-01-01 10:00:00.123456'), + (2, '2024-01-01 10:00:00.999999'); + """ + + qt_select """ + SELECT MAX(t), MIN(t), TIMESTAMPDIFF(MICROSECOND, MIN(t), MAX(t)) + FROM test_timestampdiff_microsecond; + """ } \ No newline at end of file