diff --git a/be/src/common/thread_safety_annotations.h b/be/src/common/thread_safety_annotations.h index 6cd8d4b0cae45c..6bbdb8ce6546ad 100644 --- a/be/src/common/thread_safety_annotations.h +++ b/be/src/common/thread_safety_annotations.h @@ -22,6 +22,7 @@ #pragma once #include +#include #ifdef BE_TEST namespace doris { @@ -93,6 +94,27 @@ class CAPABILITY("mutex") AnnotatedMutex { std::mutex _mutex; }; +// Annotated shared mutex wrapper for use with Clang thread safety analysis. +// Wraps std::shared_mutex and provides both exclusive and shared capability +// operations so GUARDED_BY / REQUIRES_SHARED / etc. can reference it. +class CAPABILITY("mutex") AnnotatedSharedMutex { +public: + void lock() ACQUIRE() { _mutex.lock(); } + void unlock() RELEASE() { _mutex.unlock(); } + bool try_lock() TRY_ACQUIRE(true) { return _mutex.try_lock(); } + + void lock_shared() ACQUIRE_SHARED() { _mutex.lock_shared(); } + void unlock_shared() RELEASE_SHARED() { _mutex.unlock_shared(); } + bool try_lock_shared() TRY_ACQUIRE_SHARED(true) { return _mutex.try_lock_shared(); } + + // Access the underlying std::shared_mutex (e.g., for std::condition_variable_any). + // Use with care — this bypasses thread safety annotations. + std::shared_mutex& native_handle() { return _mutex; } + +private: + std::shared_mutex _mutex; +}; + // RAII scoped lock guard annotated for thread safety analysis. // In BE_TEST builds, injects a random sleep before acquiring and after // releasing the lock to exercise concurrent code paths. @@ -119,6 +141,32 @@ class SCOPED_CAPABILITY LockGuard { MutexType& _mu; }; +// RAII scoped shared lock guard annotated for thread safety analysis. +// In BE_TEST builds, injects a random sleep before acquiring and after +// releasing the lock to exercise concurrent code paths. +template +class SCOPED_CAPABILITY SharedLockGuard { +public: + explicit SharedLockGuard(MutexType& mu) ACQUIRE_SHARED(mu) : _mu(mu) { +#ifdef BE_TEST + doris::mock_random_sleep(); +#endif + _mu.lock_shared(); + } + ~SharedLockGuard() RELEASE() { + _mu.unlock_shared(); +#ifdef BE_TEST + doris::mock_random_sleep(); +#endif + } + + SharedLockGuard(const SharedLockGuard&) = delete; + SharedLockGuard& operator=(const SharedLockGuard&) = delete; + +private: + MutexType& _mu; +}; + // RAII unique lock annotated for thread safety analysis. // Supports manual lock/unlock while preserving capability tracking. template diff --git a/be/src/core/column/column_const.h b/be/src/core/column/column_const.h index 1d0a0d7e596d59..cf26588a6a5d84 100644 --- a/be/src/core/column/column_const.h +++ b/be/src/core/column/column_const.h @@ -126,7 +126,8 @@ class ColumnConst final : public COWHelper { void resize(size_t new_size) override { s = new_size; } MutableColumnPtr clone_resized(size_t new_size) const override { - return ColumnConst::create(data, new_size, false, false); + auto cloned_data = data->clone_resized(data->size()); + return ColumnConst::create(std::move(cloned_data), new_size, false, false); } size_t size() const override { return s; } diff --git a/be/src/exec/exchange/vdata_stream_mgr.cpp b/be/src/exec/exchange/vdata_stream_mgr.cpp index 17bab298c432c8..b3357f8d0b6006 100644 --- a/be/src/exec/exchange/vdata_stream_mgr.cpp +++ b/be/src/exec/exchange/vdata_stream_mgr.cpp @@ -44,7 +44,7 @@ VDataStreamMgr::~VDataStreamMgr() { // It will core during graceful stop. auto receivers = std::vector>(); { - std::shared_lock l(_lock); + SharedLockGuard l(_lock); auto receiver_iterator = _receiver_map.begin(); while (receiver_iterator != _receiver_map.end()) { // Could not call close directly, because during close method, it will remove itself @@ -77,22 +77,16 @@ std::shared_ptr VDataStreamMgr::create_recvr( this, memory_used_counter, state, fragment_instance_id, dest_node_id, num_senders, is_merging, profile, data_queue_capacity)); uint32_t hash_value = get_hash_value(fragment_instance_id, dest_node_id); - std::unique_lock l(_lock); + LockGuard l(_lock); _fragment_stream_set.insert(std::make_pair(fragment_instance_id, dest_node_id)); _receiver_map.insert(std::make_pair(hash_value, recvr)); return recvr; } -Status VDataStreamMgr::find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id, - std::shared_ptr* res, bool acquire_lock) { +Status VDataStreamMgr::_find_recvr(uint32_t hash_value, const TUniqueId& fragment_instance_id, + PlanNodeId node_id, std::shared_ptr* res) { VLOG_ROW << "looking up fragment_instance_id=" << print_id(fragment_instance_id) << ", node=" << node_id; - uint32_t hash_value = get_hash_value(fragment_instance_id, node_id); - // Create lock guard and not own lock currently and will lock conditionally - std::shared_lock recvr_lock(_lock, std::defer_lock); - if (acquire_lock) { - recvr_lock.lock(); - } std::pair range = _receiver_map.equal_range(hash_value); while (range.first != range.second) { @@ -108,6 +102,13 @@ Status VDataStreamMgr::find_recvr(const TUniqueId& fragment_instance_id, PlanNod node_id, print_id(fragment_instance_id)); } +Status VDataStreamMgr::find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id, + std::shared_ptr* res) { + SharedLockGuard recvr_lock(_lock); + uint32_t hash_value = get_hash_value(fragment_instance_id, node_id); + return _find_recvr(hash_value, fragment_instance_id, node_id, res); +} + Status VDataStreamMgr::transmit_block(const PTransmitDataParams* request, ::google::protobuf::Closure** done, const int64_t wait_for_worker) { @@ -199,7 +200,7 @@ Status VDataStreamMgr::deregister_recvr(const TUniqueId& fragment_instance_id, P << ", node=" << node_id; uint32_t hash_value = get_hash_value(fragment_instance_id, node_id); { - std::unique_lock l(_lock); + LockGuard l(_lock); auto range = _receiver_map.equal_range(hash_value); while (range.first != range.second) { const std::shared_ptr& recvr = range.first->second; @@ -230,12 +231,13 @@ void VDataStreamMgr::cancel(const TUniqueId& fragment_instance_id, Status exec_s VLOG_QUERY << "cancelling all streams for fragment=" << print_id(fragment_instance_id); std::vector> recvrs; { - std::shared_lock l(_lock); + SharedLockGuard l(_lock); FragmentStreamSet::iterator i = _fragment_stream_set.lower_bound(std::make_pair(fragment_instance_id, 0)); while (i != _fragment_stream_set.end() && i->first == fragment_instance_id) { std::shared_ptr recvr; - WARN_IF_ERROR(find_recvr(i->first, i->second, &recvr, false), ""); + uint32_t hash_value = get_hash_value(i->first, i->second); + WARN_IF_ERROR(_find_recvr(hash_value, i->first, i->second, &recvr), ""); if (recvr == nullptr) { // keep going but at least log it std::stringstream err; diff --git a/be/src/exec/exchange/vdata_stream_mgr.h b/be/src/exec/exchange/vdata_stream_mgr.h index 7bde8f3b4c0c9b..7f35d62720278c 100644 --- a/be/src/exec/exchange/vdata_stream_mgr.h +++ b/be/src/exec/exchange/vdata_stream_mgr.h @@ -30,6 +30,7 @@ #include "common/be_mock_util.h" #include "common/global_types.h" #include "common/status.h" +#include "common/thread_safety_annotations.h" #include "runtime/runtime_profile.h" namespace google { @@ -58,8 +59,7 @@ class VDataStreamMgr { RuntimeProfile* profile, bool is_merging, size_t data_queue_capacity); MOCK_FUNCTION Status find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id, - std::shared_ptr* res, - bool acquire_lock = true); + std::shared_ptr* res); Status deregister_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id); @@ -69,9 +69,9 @@ class VDataStreamMgr { void cancel(const TUniqueId& fragment_instance_id, Status exec_status); private: - std::shared_mutex _lock; + AnnotatedSharedMutex _lock; using StreamMap = std::unordered_multimap>; - StreamMap _receiver_map; + StreamMap _receiver_map GUARDED_BY(_lock); struct ComparisonOp { bool operator()(const std::pair& a, @@ -89,7 +89,11 @@ class VDataStreamMgr { } }; using FragmentStreamSet = std::set, ComparisonOp>; - FragmentStreamSet _fragment_stream_set; + FragmentStreamSet _fragment_stream_set GUARDED_BY(_lock); + + Status _find_recvr(uint32_t hash_value, const TUniqueId& fragment_instance_id, + PlanNodeId node_id, std::shared_ptr* res) + REQUIRES_SHARED(_lock); uint32_t get_hash_value(const TUniqueId& fragment_instance_id, PlanNodeId node_id); }; diff --git a/be/src/exec/runtime_filter/runtime_filter_mgr.cpp b/be/src/exec/runtime_filter/runtime_filter_mgr.cpp index 3615299fa38148..49007d5c73534b 100644 --- a/be/src/exec/runtime_filter/runtime_filter_mgr.cpp +++ b/be/src/exec/runtime_filter/runtime_filter_mgr.cpp @@ -185,7 +185,7 @@ Status RuntimeFilterMergeControllerEntity::_init_with_desc( auto filter_id = runtime_filter_desc->filter_id; GlobalMergeContext* cnt_val; { - std::unique_lock guard(_filter_map_mutex); + LockGuard guard(_filter_map_mutex); cnt_val = &_filter_map[filter_id]; // may inplace construct default object } @@ -235,7 +235,7 @@ Status RuntimeFilterMergeControllerEntity::send_filter_size(std::shared_ptrfilter_id(); std::map::iterator iter; { - std::shared_lock guard(_filter_map_mutex); + SharedLockGuard guard(_filter_map_mutex); iter = _filter_map.find(filter_id); if (iter == _filter_map.end()) { return Status::InvalidArgument("unknown filter id {}", @@ -243,12 +243,12 @@ Status RuntimeFilterMergeControllerEntity::send_filter_size(std::shared_ptrsecond; - std::unique_lock l(iter->second.mtx); + std::unique_lock l(cnt_val.mtx); // Discard stale-stage runtime filter size requests from old recursive CTE rounds. // Each round increments the stage counter; only messages matching the current stage // should be processed. This prevents old PFC's runtime filters from corrupting // the merge state of the new round's filters. - if (request->stage() != iter->second.stage) { + if (request->stage() != cnt_val.stage) { return Status::OK(); } cnt_val.source_addrs.push_back(request->source_addr()); @@ -269,7 +269,7 @@ Status RuntimeFilterMergeControllerEntity::send_filter_size(std::shared_ptr(); - sync_request->set_stage(iter->second.stage); + sync_request->set_stage(cnt_val.stage); auto closure = AutoReleaseClosure>:: @@ -336,7 +336,7 @@ Status RuntimeFilterMergeControllerEntity::merge(std::shared_ptr q auto filter_id = request->filter_id(); std::map::iterator iter; { - std::shared_lock guard(_filter_map_mutex); + SharedLockGuard guard(_filter_map_mutex); iter = _filter_map.find(filter_id); VLOG_ROW << "recv filter id:" << request->filter_id() << " " << request->ShortDebugString(); if (iter == _filter_map.end()) { @@ -347,9 +347,9 @@ Status RuntimeFilterMergeControllerEntity::merge(std::shared_ptr q auto& cnt_val = iter->second; bool is_ready = false; { - std::lock_guard l(iter->second.mtx); + std::lock_guard l(cnt_val.mtx); // Discard stale-stage merge requests from old recursive CTE rounds. - if (request->stage() != iter->second.stage) { + if (request->stage() != cnt_val.stage) { return Status::OK(); } if (cnt_val.merger == nullptr) { @@ -492,7 +492,7 @@ Status RuntimeFilterMergeControllerEntity::reset_global_rf( for (const auto& filter_id : filter_ids) { GlobalMergeContext* cnt_val; { - std::unique_lock guard(_filter_map_mutex); + LockGuard guard(_filter_map_mutex); cnt_val = &_filter_map[filter_id]; // may inplace construct default object } RETURN_IF_ERROR(cnt_val->reset(query_ctx)); @@ -502,7 +502,7 @@ Status RuntimeFilterMergeControllerEntity::reset_global_rf( std::string RuntimeFilterMergeControllerEntity::debug_string() { std::string result = "RuntimeFilterMergeControllerEntity Info:\n"; - std::shared_lock guard(_filter_map_mutex); + SharedLockGuard guard(_filter_map_mutex); for (const auto& [filter_id, ctx] : _filter_map) { result += fmt::format("filter_id: {}, stage: {}, {}\n", filter_id, ctx.stage, ctx.merger->debug_string()); diff --git a/be/src/exec/runtime_filter/runtime_filter_mgr.h b/be/src/exec/runtime_filter/runtime_filter_mgr.h index f822e01196f853..418f9aa41b7414 100644 --- a/be/src/exec/runtime_filter/runtime_filter_mgr.h +++ b/be/src/exec/runtime_filter/runtime_filter_mgr.h @@ -27,12 +27,11 @@ #include #include #include -#include #include -#include #include #include "common/status.h" +#include "common/thread_safety_annotations.h" #include "util/uid_util.h" namespace butil { @@ -168,7 +167,7 @@ class RuntimeFilterMergeControllerEntity { std::string debug_string(); bool empty() { - std::shared_lock read_lock(_filter_map_mutex); + SharedLockGuard read_lock(_filter_map_mutex); return _filter_map.empty(); } @@ -185,10 +184,10 @@ class RuntimeFilterMergeControllerEntity { int64_t merge_time, PUniqueId query_id, int execution_timeout); // protect _filter_map - std::shared_mutex _filter_map_mutex; + AnnotatedSharedMutex _filter_map_mutex; std::shared_ptr _mem_tracker; - std::map _filter_map; + std::map _filter_map GUARDED_BY(_filter_map_mutex); }; #include "common/compile_check_end.h" } // namespace doris diff --git a/be/src/exec/sink/writer/vmysql_result_writer.cpp b/be/src/exec/sink/writer/vmysql_result_writer.cpp index 4101f18db3c457..7f98f626f3f2fc 100644 --- a/be/src/exec/sink/writer/vmysql_result_writer.cpp +++ b/be/src/exec/sink/writer/vmysql_result_writer.cpp @@ -297,6 +297,12 @@ Status VMysqlResultWriter::write(RuntimeState* state, Block& input_block) { Block block; RETURN_IF_ERROR(VExprContext::get_output_block_after_execute_exprs(_output_vexpr_ctxs, input_block, &block)); + + if (_is_dry_run) { + _written_rows += cast_set(block.rows()); + return Status::OK(); + } + const auto total_bytes = block.bytes(); if (total_bytes > config::thrift_max_message_size) [[unlikely]] { diff --git a/be/src/exec/sort/sort_cursor.h b/be/src/exec/sort/sort_cursor.h index d5b4a14e46158f..dae751258a5e20 100644 --- a/be/src/exec/sort/sort_cursor.h +++ b/be/src/exec/sort/sort_cursor.h @@ -205,6 +205,11 @@ struct MergeSortCursor { return !impl->empty() && greater_at(rhs, impl->pos, rhs.impl->pos) > 0; } + bool totally_less_or_equals(const MergeSortCursor& rhs) const { + return !impl->empty() && !rhs.impl->empty() && + greater_at(rhs, impl->rows - 1, rhs.impl->pos) <= 0; + } + /// Inverted so that the priority queue elements are removed in ascending order. bool operator<(const MergeSortCursor& rhs) const { return greater(rhs); } diff --git a/be/src/exec/sort/sorter.cpp b/be/src/exec/sort/sorter.cpp index 88160819328ce0..616cc2145a2d16 100644 --- a/be/src/exec/sort/sorter.cpp +++ b/be/src/exec/sort/sorter.cpp @@ -94,6 +94,24 @@ Status MergeSorterState::merge_sort_read(doris::Block* block, int batch_size, bo } void MergeSorterState::_merge_sort_read_impl(int batch_size, doris::Block* block, bool* eos) { + if (_queue.is_valid() && batch_size > 0) { + auto [current, current_rows] = _queue.current(); + current_rows = std::min(current_rows, static_cast(batch_size)); + const size_t step = std::min(_offset, current_rows); + + // Fast path when the current top run can contribute its whole remaining block + // before any other run. The returned block stays within batch_size because + // is_last(current_rows) can only hold after the min(batch_size, queue_batch_size) + // clamp above. + if (step == 0 && current->impl->is_first() && current->impl->is_last(current_rows) && + (_queue.size() == 1 || (*current).totally_less_or_equals(_queue.next_child()))) { + current->impl->block->swap(*block); + _queue.remove_top(); + *eos = false; + return; + } + } + size_t num_columns = unsorted_block()->columns(); MutableBlock m_block = VectorizedUtils::build_mutable_mem_reuse_block(block, *unsorted_block()); diff --git a/be/src/exprs/aggregate/aggregate_function.h b/be/src/exprs/aggregate/aggregate_function.h index 0e07f74c1aeab1..94eaaa9ad72403 100644 --- a/be/src/exprs/aggregate/aggregate_function.h +++ b/be/src/exprs/aggregate/aggregate_function.h @@ -479,19 +479,21 @@ class IAggregateFunctionHelper : public IAggregateFunction { size_t num_rows) const override { const Derived* derived = assert_cast(this); const auto size_of_data = derived->size_of_data(); - for (size_t i = 0; i != num_rows; ++i) { - try { + size_t created_count = 0; + try { + for (size_t i = 0; i != num_rows; ++i) { auto place = places + size_of_data * i; VectorBufferReader buffer_reader(column->get_data_at(i)); derived->create(place); + ++created_count; derived->deserialize(place, buffer_reader, arena); - } catch (...) { - for (int j = 0; j < i; ++j) { - auto place = places + size_of_data * j; - derived->destroy(place); - } - throw; } + } catch (...) { + for (size_t j = 0; j < created_count; ++j) { + auto place = places + size_of_data * j; + derived->destroy(place); + } + throw; } } @@ -502,19 +504,21 @@ class IAggregateFunctionHelper : public IAggregateFunction { const auto size_of_data = derived->size_of_data(); const auto* column_string = assert_cast(column); - for (size_t i = 0; i != num_rows; ++i) { - try { + size_t created_count = 0; + try { + for (size_t i = 0; i != num_rows; ++i) { auto rhs_place = rhs + size_of_data * i; VectorBufferReader buffer_reader(column_string->get_data_at(i)); derived->create(rhs_place); + ++created_count; derived->deserialize_and_merge(places[i] + offset, rhs_place, buffer_reader, arena); - } catch (...) { - for (int j = 0; j < i; ++j) { - auto place = rhs + size_of_data * j; - derived->destroy(place); - } - throw; } + } catch (...) { + for (size_t j = 0; j < created_count; ++j) { + auto place = rhs + size_of_data * j; + derived->destroy(place); + } + throw; } derived->destroy_vec(rhs, num_rows); @@ -526,22 +530,24 @@ class IAggregateFunctionHelper : public IAggregateFunction { const auto* derived = assert_cast(this); const auto size_of_data = derived->size_of_data(); const auto* column_string = assert_cast(column); - for (size_t i = 0; i != num_rows; ++i) { - try { + size_t created_count = 0; + try { + for (size_t i = 0; i != num_rows; ++i) { auto rhs_place = rhs + size_of_data * i; VectorBufferReader buffer_reader(column_string->get_data_at(i)); derived->create(rhs_place); + ++created_count; if (places[i]) { derived->deserialize_and_merge(places[i] + offset, rhs_place, buffer_reader, arena); } - } catch (...) { - for (int j = 0; j < i; ++j) { - auto place = rhs + size_of_data * j; - derived->destroy(place); - } - throw; } + } catch (...) { + for (size_t j = 0; j < created_count; ++j) { + auto place = rhs + size_of_data * j; + derived->destroy(place); + } + throw; } derived->destroy_vec(rhs, num_rows); } diff --git a/be/src/exprs/aggregate/aggregate_function_collect.h b/be/src/exprs/aggregate/aggregate_function_collect.h index 3f9c84f7dea373..63a3c6348225ad 100644 --- a/be/src/exprs/aggregate/aggregate_function_collect.h +++ b/be/src/exprs/aggregate/aggregate_function_collect.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include @@ -50,7 +51,7 @@ struct AggregateFunctionCollectSetData { using ElementType = typename PrimitiveTypeTraits::CppType; using ColVecType = typename PrimitiveTypeTraits::ColumnType; using SelfType = AggregateFunctionCollectSetData; - using Set = phmap::flat_hash_set; + using Set = doris::flat_hash_set; Set data_set; Int64 max_size = -1; @@ -119,7 +120,7 @@ struct AggregateFunctionCollectSetData { using ElementType = StringRef; using ColVecType = ColumnString; using SelfType = AggregateFunctionCollectSetData; - using Set = phmap::flat_hash_set; + using Set = doris::flat_hash_set; Set data_set; Int64 max_size = -1; @@ -343,6 +344,10 @@ struct AggregateFunctionCollectListData { buf.write_binary(size); DataTypeSerDe::FormatOptions opt; + auto timezone = cctz::utc_time_zone(); + opt.timezone = &timezone; + // TODO: Refactor this aggregate state serialization to avoid + // round-tripping through a human-readable string format. auto tmp_str = ColumnString::create(); VectorBufferWriter tmp_buf(*tmp_str.get()); @@ -368,6 +373,8 @@ struct AggregateFunctionCollectListData { StringRef s; DataTypeSerDe::FormatOptions opt; + auto timezone = cctz::utc_time_zone(); + opt.timezone = &timezone; for (size_t i = 0; i < size; i++) { buf.read_binary(s); Slice slice(s.data, s.size); diff --git a/be/src/exprs/aggregate/aggregate_function_distinct.h b/be/src/exprs/aggregate/aggregate_function_distinct.h index 618d9b46f41996..825e782f0cba53 100644 --- a/be/src/exprs/aggregate/aggregate_function_distinct.h +++ b/be/src/exprs/aggregate/aggregate_function_distinct.h @@ -52,8 +52,8 @@ template struct AggregateFunctionDistinctSingleNumericData { /// When creating, the hash table must be small. using Container = std::conditional_t< - stable, phmap::flat_hash_map::CppType, uint32_t>, - phmap::flat_hash_set::CppType>>; + stable, doris::flat_hash_map::CppType, uint32_t>, + doris::flat_hash_set::CppType>>; using Self = AggregateFunctionDistinctSingleNumericData; Container data; @@ -126,8 +126,8 @@ struct AggregateFunctionDistinctSingleNumericData { template struct AggregateFunctionDistinctGenericData { /// When creating, the hash table must be small. - using Container = std::conditional_t, - phmap::flat_hash_set>; + using Container = std::conditional_t, + doris::flat_hash_set>; using Self = AggregateFunctionDistinctGenericData; Container data; diff --git a/be/src/exprs/aggregate/aggregate_function_map.h b/be/src/exprs/aggregate/aggregate_function_map.h index f9aff592503cc0..a16bea7867c91a 100644 --- a/be/src/exprs/aggregate/aggregate_function_map.h +++ b/be/src/exprs/aggregate/aggregate_function_map.h @@ -35,7 +35,7 @@ namespace doris { template struct AggregateFunctionMapAggData { using KeyType = typename PrimitiveTypeTraits::CppType; - using Map = phmap::flat_hash_map; + using Map = doris::flat_hash_map; AggregateFunctionMapAggData() { throw Exception(Status::FatalError("__builtin_unreachable")); } diff --git a/be/src/exprs/aggregate/aggregate_function_map_v2.h b/be/src/exprs/aggregate/aggregate_function_map_v2.h index 3181b1ad4261d0..1d821c486c6b0a 100644 --- a/be/src/exprs/aggregate/aggregate_function_map_v2.h +++ b/be/src/exprs/aggregate/aggregate_function_map_v2.h @@ -33,7 +33,7 @@ namespace doris { #include "common/compile_check_begin.h" struct AggregateFunctionMapAggDataV2 { - using Map = phmap::flat_hash_map; + using Map = doris::flat_hash_map; AggregateFunctionMapAggDataV2() { throw Exception(Status::FatalError("__builtin_unreachable")); diff --git a/be/src/exprs/function/function_map.cpp b/be/src/exprs/function/function_map.cpp index ffe3e773b9f609..d0a0ab639a94cc 100644 --- a/be/src/exprs/function/function_map.cpp +++ b/be/src/exprs/function/function_map.cpp @@ -867,10 +867,11 @@ class FunctionMapContainsEntry : public IFunction { /*nan_direction_hint=*/1) == 0; } - // whether this function supports equality comparison for the given primitive type + // whether this function supports equality comparison for the given primitive type. + // Uses dispatch_switch_all as the single source of truth so any type supported + // by the dispatch layer is automatically accepted here. bool is_equality_comparison_supported(PrimitiveType type) const { - return is_string_type(type) || is_number(type) || is_date_type(type) || - is_time_type(type) || is_ip(type); + return dispatch_switch_all(type, [](const auto&) { return true; }); } }; diff --git a/be/src/exprs/function/function_quantile_state.cpp b/be/src/exprs/function/function_quantile_state.cpp index a0edb82dbf9450..af1b80822f0007 100644 --- a/be/src/exprs/function/function_quantile_state.cpp +++ b/be/src/exprs/function/function_quantile_state.cpp @@ -169,6 +169,8 @@ class FunctionQuantileStatePercent : public IFunction { bool use_default_implementation_for_nulls() const override { return false; } + ColumnNumbers get_arguments_that_are_always_constant() const override { return {1}; } + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, uint32_t result, size_t input_rows_count) const override { auto res_data_column = ColumnFloat64::create(); diff --git a/be/src/exprs/function/function_regexp.cpp b/be/src/exprs/function/function_regexp.cpp index 8a9871c8eb3d6c..fd642e8abf4395 100644 --- a/be/src/exprs/function/function_regexp.cpp +++ b/be/src/exprs/function/function_regexp.cpp @@ -34,6 +34,7 @@ #include "core/block/column_with_type_and_name.h" #include "core/column/column.h" #include "core/column/column_const.h" +#include "core/column/column_execute_util.h" #include "core/column/column_nullable.h" #include "core/column/column_string.h" #include "core/column/column_vector.h" @@ -189,23 +190,26 @@ struct RegexpExtractEngine { }; struct RegexpCountImpl { + using StringColumnView = ColumnView; + static void execute_impl(FunctionContext* context, ColumnPtr argument_columns[], size_t input_rows_count, ColumnInt32::Container& result_data) { - const auto* str_col = check_and_get_column(argument_columns[0].get()); - const auto* pattern_col = check_and_get_column(argument_columns[1].get()); - for (int i = 0; i < input_rows_count; ++i) { + auto str_col = StringColumnView::create(argument_columns[0]); + auto pattern_col = StringColumnView::create(argument_columns[1]); + for (size_t i = 0; i < input_rows_count; ++i) { + DCHECK(!str_col.is_null_at(i)); + DCHECK(!pattern_col.is_null_at(i)); result_data[i] = _execute_inner_loop(context, str_col, pattern_col, i); } } - static int _execute_inner_loop(FunctionContext* context, const ColumnString* str_col, - const ColumnString* pattern_col, const size_t index_now) { + static int _execute_inner_loop(FunctionContext* context, const StringColumnView& str_col, + const StringColumnView& pattern_col, const size_t index_now) { re2::RE2* re = reinterpret_cast( context->get_function_state(FunctionContext::THREAD_LOCAL)); std::unique_ptr scoped_re; if (re == nullptr) { std::string error_str; - DCHECK(pattern_col); - const auto& pattern = pattern_col->get_data_at(index_check_const(index_now, false)); + const auto pattern = pattern_col.value_at(index_now); bool st = StringFunctions::compile_regex(pattern, &error_str, StringRef(), StringRef(), scoped_re); if (!st) { @@ -216,7 +220,7 @@ struct RegexpCountImpl { re = scoped_re.get(); } - const auto& str = str_col->get_data_at(index_now); + const auto str = str_col.value_at(index_now); int count = 0; size_t pos = 0; while (pos < str.size) { diff --git a/be/src/exprs/function/uniform.cpp b/be/src/exprs/function/uniform.cpp index 3bd1e139e1528f..e639df7a2958bb 100644 --- a/be/src/exprs/function/uniform.cpp +++ b/be/src/exprs/function/uniform.cpp @@ -30,6 +30,7 @@ #include "core/block/block.h" #include "core/block/column_numbers.h" #include "core/column/column.h" +#include "core/column/column_execute_util.h" #include "core/column/column_vector.h" #include "core/data_type/data_type_number.h" // IWYU pragma: keep #include "core/data_type/primitive_type.h" @@ -74,12 +75,12 @@ struct UniformIntImpl { "uniform's min should be less than max, but got [{}, {})", min, max); } - // Get gen column (seed values) - const auto& gen_column = block.get_by_position(arguments[2]).column; + auto gen_column = + ColumnView::create(block.get_by_position(arguments[2]).column); for (int i = 0; i < input_rows_count; i++) { // Use gen value as seed for each row - auto seed = (*gen_column)[i].get(); + auto seed = gen_column.value_at(i); std::mt19937_64 generator(seed); std::uniform_int_distribution distribution(min, max); res_data[i] = distribution(generator); @@ -123,11 +124,12 @@ struct UniformDoubleImpl { } // Get gen column (seed values) - const auto& gen_column = block.get_by_position(arguments[2]).column; + auto gen_column = + ColumnView::create(block.get_by_position(arguments[2]).column); for (int i = 0; i < input_rows_count; i++) { // Use gen value as seed for each row - auto seed = (*gen_column)[i].get(); + auto seed = gen_column.value_at(i); std::mt19937_64 generator(seed); std::uniform_real_distribution distribution(min, max); res_data[i] = distribution(generator); @@ -146,6 +148,8 @@ class FunctionUniform : public IFunction { static FunctionPtr create() { return std::make_shared>(); } String get_name() const override { return name; } + bool use_default_implementation_for_constants() const override { return false; } + size_t get_number_of_arguments() const override { return get_variadic_argument_types_impl().size(); } @@ -158,6 +162,8 @@ class FunctionUniform : public IFunction { return Impl::get_variadic_argument_types(); } + ColumnNumbers get_arguments_that_are_always_constant() const override { return {0, 1}; } + Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) override { // init_function_context do set_constant_cols for FRAGMENT_LOCAL scope if (scope == FunctionContext::FRAGMENT_LOCAL) { diff --git a/be/src/storage/index/ann/ann_range_search_runtime.cpp b/be/src/storage/index/ann/ann_range_search_runtime.cpp index a223c96e6c6be8..b38576469e8a3a 100644 --- a/be/src/storage/index/ann/ann_range_search_runtime.cpp +++ b/be/src/storage/index/ann/ann_range_search_runtime.cpp @@ -35,8 +35,7 @@ namespace doris::segment_v2 { */ AnnRangeSearchParams AnnRangeSearchRuntime::to_range_search_params() const { AnnRangeSearchParams params; - const auto* query = assert_cast(query_value.get()); - params.query_value = query->get_data().data(); + params.query_value = query_value->get_data().data(); params.radius = static_cast(radius); params.roaring = nullptr; params.is_le_or_lt = is_le_or_lt; diff --git a/be/src/storage/index/ann/ann_range_search_runtime.h b/be/src/storage/index/ann/ann_range_search_runtime.h index c1063404f60466..7ca0a830d8b68f 100644 --- a/be/src/storage/index/ann/ann_range_search_runtime.h +++ b/be/src/storage/index/ann/ann_range_search_runtime.h @@ -133,7 +133,7 @@ struct AnnRangeSearchRuntime { double radius = 0.0; ///< Search radius/distance threshold AnnIndexMetric metric_type; ///< Distance metric (L2, Inner Product, etc.) doris::VectorSearchUserParams user_params; ///< User-defined search parameters - IColumn::Ptr query_value; ///< Query vector data (deep copied) + ColumnFloat32::Ptr query_value; ///< Query vector data }; #include "common/compile_check_end.h" } // namespace doris::segment_v2 \ No newline at end of file diff --git a/be/src/storage/index/ann/ann_topn_runtime.cpp b/be/src/storage/index/ann/ann_topn_runtime.cpp index 4ac4042395fed4..1742b65065ec63 100644 --- a/be/src/storage/index/ann/ann_topn_runtime.cpp +++ b/be/src/storage/index/ann/ann_topn_runtime.cpp @@ -29,6 +29,7 @@ #include "core/column/column_array.h" #include "core/column/column_const.h" #include "core/column/column_nullable.h" +#include "core/column/column_vector.h" #include "core/data_type/primitive_type.h" #include "exprs/function/array/function_array_distance.h" #include "exprs/vexpr_context.h" @@ -43,7 +44,7 @@ namespace doris::segment_v2 { #include "common/compile_check_begin.h" -Result extract_query_vector(std::shared_ptr arg_expr) { +Result extract_query_vector(std::shared_ptr arg_expr) { if (arg_expr->is_constant() == false) { return ResultError(Status::InvalidArgument("Ann topn expr must be constant, got\n{}", arg_expr->debug_string())); @@ -99,7 +100,14 @@ Result extract_query_vector(std::shared_ptr arg_expr) { values_holder_col = value_nullable_col->get_nested_column_ptr(); } - return values_holder_col; + auto float_col = check_and_get_column_ptr(values_holder_col); + if (float_col.get() == nullptr) { + return ResultError(Status::InvalidArgument( + "Ann topn query vector elements must be Float32, got column: {}", + values_holder_col->get_name())); + } + + return float_col; } Status AnnTopNRuntime::prepare(RuntimeState* state, const RowDescriptor& row_desc) { @@ -188,10 +196,10 @@ Status AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::AnnIndexIterator* DCHECK(ann_index_iterator != nullptr); DCHECK(_order_by_expr_ctx != nullptr); DCHECK(_order_by_expr_ctx->root() != nullptr); - size_t query_array_size = _query_array->size(); - if (_query_array.get() == nullptr || query_array_size == 0) { + if (_query_array.get() == nullptr || _query_array->size() == 0) { return Status::InternalError("Ann topn query vector is not initialized"); } + size_t query_array_size = _query_array->size(); // TODO:(zhiqiang) Maybe we can move this dimension check to prepare phase. @@ -203,9 +211,8 @@ Status AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::AnnIndexIterator* "Ann topn query vector dimension {} does not match index dimension {}", query_array_size, ann_index_reader->get_dimension()); } - const ColumnFloat32* query = assert_cast(_query_array.get()); segment_v2::AnnTopNParam ann_query_params { - .query_value = query->get_data().data(), + .query_value = _query_array->get_data().data(), .query_value_size = query_array_size, .limit = _limit, ._user_params = _user_params, diff --git a/be/src/storage/index/ann/ann_topn_runtime.h b/be/src/storage/index/ann/ann_topn_runtime.h index 63e04cc30b6256..9ad2cd7df0ba17 100644 --- a/be/src/storage/index/ann/ann_topn_runtime.h +++ b/be/src/storage/index/ann/ann_topn_runtime.h @@ -36,6 +36,7 @@ #pragma once #include "core/column/column.h" +#include "core/column/column_vector.h" #include "core/data_type/primitive_type.h" #include "exprs/vectorized_fn_call.h" #include "exprs/vexpr.h" @@ -49,7 +50,7 @@ namespace doris::segment_v2 { struct AnnIndexStats; class AnnIndexIterator; -Result extract_query_vector(std::shared_ptr arg_expr); +Result extract_query_vector(std::shared_ptr arg_expr); /** * @brief Runtime execution engine for ANN (Approximate Nearest Neighbor) Top-N queries. @@ -162,7 +163,7 @@ class AnnTopNRuntime { size_t _src_column_idx = -1; ///< Source vector column index size_t _dest_column_idx = -1; ///< Destination distance column index segment_v2::AnnIndexMetric _metric_type; ///< Distance metric type - IColumn::Ptr _query_array; ///< Query vector data (contiguous float buffer) + ColumnFloat32::Ptr _query_array; ///< Query vector data (contiguous float buffer) doris::VectorSearchUserParams _user_params; ///< User-defined search parameters }; #include "common/compile_check_end.h" diff --git a/be/test/core/column/column_const_test.cpp b/be/test/core/column/column_const_test.cpp index f6f81ec3aaba4f..e9f57df213bce3 100644 --- a/be/test/core/column/column_const_test.cpp +++ b/be/test/core/column/column_const_test.cpp @@ -41,6 +41,19 @@ TEST(ColumnConstTest, TestCreate) { EXPECT_TRUE(!is_column_const(column_const2->get_data_column())); } +TEST(ColumnConstTest, clone_resized_clones_nested_data) { + auto column_data = ColumnHelper::create_column({7}); + auto column_const = ColumnConst::create(column_data, 3); + + auto cloned = column_const->clone_resized(5); + const auto& cloned_const = assert_cast(*cloned); + + EXPECT_EQ(cloned_const.size(), 5); + EXPECT_EQ(cloned_const.get_data_column_ptr()->size(), 1); + EXPECT_EQ(cloned_const.get_data_column().get_int(0), 7); + EXPECT_NE(column_const->get_data_column_ptr().get(), cloned_const.get_data_column_ptr().get()); +} + TEST(ColumnConstTest, TestFilter) { { auto column_data = ColumnHelper::create_column({7}); diff --git a/be/test/core/data_type_serde/data_type_serde_mysql_test.cpp b/be/test/core/data_type_serde/data_type_serde_mysql_test.cpp index d0a6cbdbbaef5f..e8f289bbf54055 100644 --- a/be/test/core/data_type_serde/data_type_serde_mysql_test.cpp +++ b/be/test/core/data_type_serde/data_type_serde_mysql_test.cpp @@ -77,6 +77,10 @@ class TestBlockSerializer final : public MySQLResultBlockBuffer { public: TestBlockSerializer(RuntimeState* state) : MySQLResultBlockBuffer(state) {} ~TestBlockSerializer() override = default; + size_t queue_size() { + std::lock_guard l(_lock); + return _result_batch_queue.size(); + } std::shared_ptr get_block() { std::lock_guard l(_lock); DCHECK_EQ(_result_batch_queue.size(), 1); @@ -86,7 +90,7 @@ class TestBlockSerializer final : public MySQLResultBlockBuffer { } }; -void serialize_and_deserialize_mysql_test() { +void serialize_and_deserialize_mysql_test(bool dry_run) { Block block; // create_descriptor_tablet(); std::vector> cols { @@ -317,12 +321,25 @@ void serialize_and_deserialize_mysql_test() { auto serializer = std::make_shared(&state); VMysqlResultWriter mysql_writer(serializer, _output_vexpr_ctxs, nullptr, false); - Status st = mysql_writer.write(&runtime_stat, block); + TQueryOptions query_options; + query_options.__set_dry_run_query(dry_run); + runtime_stat.set_query_options(query_options); + + Status st = mysql_writer.init(&runtime_stat); EXPECT_TRUE(st.ok()); + + st = mysql_writer.write(&runtime_stat, block); + EXPECT_TRUE(st.ok()); + EXPECT_EQ(mysql_writer.get_written_rows(), row_num); + EXPECT_EQ(serializer->queue_size(), dry_run ? 0 : 1); } TEST(DataTypeSerDeMysqlTest, ScalaSerDeTest) { - serialize_and_deserialize_mysql_test(); + serialize_and_deserialize_mysql_test(false); +} + +TEST(DataTypeSerDeMysqlTest, DryRunSkipsSerialization) { + serialize_and_deserialize_mysql_test(true); } } // namespace doris diff --git a/be/test/exec/operator/sort_operator_test.cpp b/be/test/exec/operator/sort_operator_test.cpp index 23fa37e57b01ef..bd6c0ee68c32a4 100644 --- a/be/test/exec/operator/sort_operator_test.cpp +++ b/be/test/exec/operator/sort_operator_test.cpp @@ -192,21 +192,20 @@ TEST_F(SortOperatorTest, test_dep) { EXPECT_TRUE(is_ready(source_local_state->dependencies())); { - Block block = ColumnHelper::create_block({}); + MutableBlock merged_block = ColumnHelper::create_block({}); bool eos = false; - auto st = source->get_block(state.get(), &block, &eos); - EXPECT_TRUE(st.ok()) << st.msg(); - EXPECT_FALSE(eos); + while (!eos) { + Block block; + auto st = source->get_block(state.get(), &block, &eos); + EXPECT_TRUE(st.ok()) << st.msg(); + EXPECT_TRUE(merged_block.merge(block)); + } + + auto block = merged_block.to_block(); EXPECT_EQ(block.rows(), 6); std::cout << block.dump_data() << std::endl; EXPECT_TRUE(ColumnHelper::block_equal( block, ColumnHelper::create_block({1, 2, 3, 4, 5, 6}))); - - block.clear(); - st = source->get_block(state.get(), &block, &eos); - EXPECT_TRUE(st.ok()) << st.msg(); - EXPECT_TRUE(eos); - EXPECT_EQ(block.rows(), 0); } } diff --git a/be/test/exec/pipeline/vdata_stream_recvr_test.cpp b/be/test/exec/pipeline/vdata_stream_recvr_test.cpp index ab6b03b13c5572..f0c4e05c7e6528 100644 --- a/be/test/exec/pipeline/vdata_stream_recvr_test.cpp +++ b/be/test/exec/pipeline/vdata_stream_recvr_test.cpp @@ -577,7 +577,7 @@ TEST_F(DataStreamRecvrTest, TestRemoteLocalMultiSender) { struct MockVDataStreamMgr : public VDataStreamMgr { ~MockVDataStreamMgr() override = default; Status find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id, - std::shared_ptr* res, bool acquire_lock = true) override { + std::shared_ptr* res) override { *res = recvr; return Status::OK(); } diff --git a/be/test/exec/sort/heap_sorter_test.cpp b/be/test/exec/sort/heap_sorter_test.cpp index 90b06764175f1e..9c91db2e5e3833 100644 --- a/be/test/exec/sort/heap_sorter_test.cpp +++ b/be/test/exec/sort/heap_sorter_test.cpp @@ -100,20 +100,20 @@ TEST_F(HeapSorterTest, test_topn_sorter1) { EXPECT_TRUE(sorter->prepare_for_read(false)); { - Block block; + MutableBlock merged_block = ColumnHelper::create_block({}, {}); bool eos = false; - EXPECT_TRUE(sorter->get_next(&_state, &block, &eos)); + while (!eos) { + Block block; + EXPECT_TRUE(sorter->get_next(&_state, &block, &eos)); + EXPECT_TRUE(merged_block.merge(block)); + } + + auto block = merged_block.to_block(); EXPECT_EQ(block.rows(), 6); EXPECT_TRUE(ColumnHelper::block_equal( block, Block {ColumnHelper::create_column_with_name({1, 2, 3, 4, 5, 6}), ColumnHelper::create_column_with_name({1, 2, 3, 4, 5, 6})})); - - block.clear_column_data(); - - EXPECT_TRUE(sorter->get_next(&_state, &block, &eos)); - EXPECT_EQ(block.rows(), 0); - EXPECT_EQ(eos, true); } } diff --git a/be/test/exec/sort/merge_sorter_state.cpp b/be/test/exec/sort/merge_sorter_state.cpp index 0dc8a1a8937164..7af89e7cbdf70b 100644 --- a/be/test/exec/sort/merge_sorter_state.cpp +++ b/be/test/exec/sort/merge_sorter_state.cpp @@ -101,4 +101,71 @@ TEST_F(MergeSorterStateTest, test1) { ColumnHelper::create_block({5, 6}))); } } + +TEST_F(MergeSorterStateTest, whole_block_fast_path_swaps_block) { + state.reset(new MergeSorterState(*row_desc, 0)); + auto first_block = create_block({1, 2, 3}); + auto second_block = create_block({4, 5, 6}); + auto first_column = first_block->get_by_position(0).column; + + state->add_sorted_block(first_block); + state->add_sorted_block(second_block); + + SortDescription desc {SortColumnDescription {0, 1, -1}}; + ASSERT_TRUE(state->build_merge_tree(desc)); + + Block block; + bool eos = false; + Status status = state->merge_sort_read(&block, 3, &eos); + ASSERT_TRUE(status.ok()); + EXPECT_FALSE(eos); + EXPECT_TRUE( + ColumnHelper::block_equal(block, ColumnHelper::create_block({1, 2, 3}))); + EXPECT_EQ(block.get_by_position(0).column.get(), first_column.get()); +} + +TEST_F(MergeSorterStateTest, whole_block_fast_path_allows_smaller_than_batch) { + state.reset(new MergeSorterState(*row_desc, 0)); + auto first_block = create_block({1, 2, 3}); + auto second_block = create_block({4, 5, 6}); + auto first_column = first_block->get_by_position(0).column; + auto second_column = second_block->get_by_position(0).column; + + state->add_sorted_block(first_block); + state->add_sorted_block(second_block); + + SortDescription desc {SortColumnDescription {0, 1, -1}}; + ASSERT_TRUE(state->build_merge_tree(desc)); + + { + Block block; + bool eos = false; + Status status = state->merge_sort_read(&block, 4, &eos); + ASSERT_TRUE(status.ok()); + EXPECT_FALSE(eos); + EXPECT_TRUE(ColumnHelper::block_equal( + block, ColumnHelper::create_block({1, 2, 3}))); + EXPECT_EQ(block.get_by_position(0).column.get(), first_column.get()); + } + + { + Block block; + bool eos = false; + Status status = state->merge_sort_read(&block, 4, &eos); + ASSERT_TRUE(status.ok()); + EXPECT_FALSE(eos); + EXPECT_TRUE(ColumnHelper::block_equal( + block, ColumnHelper::create_block({4, 5, 6}))); + EXPECT_EQ(block.get_by_position(0).column.get(), second_column.get()); + } + + { + Block block; + bool eos = false; + Status status = state->merge_sort_read(&block, 4, &eos); + ASSERT_TRUE(status.ok()); + EXPECT_TRUE(eos); + EXPECT_EQ(block.rows(), 0); + } +} } // namespace doris \ No newline at end of file diff --git a/be/test/exprs/aggregate/aggregate_function_exception_test.cpp b/be/test/exprs/aggregate/aggregate_function_exception_test.cpp new file mode 100644 index 00000000000000..21ee64dba4aef1 --- /dev/null +++ b/be/test/exprs/aggregate/aggregate_function_exception_test.cpp @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include + +#include "core/arena.h" +#include "exprs/aggregate/aggregate_function.h" + +namespace doris { + +struct TrackingAggregateState { + TrackingAggregateState() { ++construct_count; } + ~TrackingAggregateState() { ++destroy_count; } + + static void reset_counters() { + construct_count = 0; + destroy_count = 0; + } + + static int construct_count; + static int destroy_count; +}; + +int TrackingAggregateState::construct_count = 0; +int TrackingAggregateState::destroy_count = 0; + +class ThrowOnDeserializeAggregateFunction final + : public IAggregateFunctionDataHelper { +public: + ThrowOnDeserializeAggregateFunction() + : IAggregateFunctionDataHelper( + DataTypes {std::make_shared()}) {} + + String get_name() const override { return "throw_on_deserialize"; } + + DataTypePtr get_return_type() const override { return std::make_shared(); } + + void add(AggregateDataPtr, const IColumn**, ssize_t, Arena&) const override {} + + void merge(AggregateDataPtr, ConstAggregateDataPtr, Arena&) const override {} + + void serialize(ConstAggregateDataPtr, BufferWritable& buf) const override { + String payload; + buf.write_binary(payload); + } + + void deserialize(AggregateDataPtr, BufferReadable& buf, Arena&) const override { + String payload; + buf.read_binary(payload); + if (payload == "throw") { + throw Exception(ErrorCode::INTERNAL_ERROR, "mock deserialize failure"); + } + } + + void insert_result_into(ConstAggregateDataPtr, IColumn&) const override {} +}; + +class AggregateFunctionExceptionTest : public testing::Test { +protected: + void SetUp() override { TrackingAggregateState::reset_counters(); } + + MutableColumnPtr make_column(std::initializer_list payloads) { + auto column = ColumnString::create(); + VectorBufferWriter writer(*column); + for (const auto& payload : payloads) { + writer.write_binary(payload); + writer.commit(); + } + return column; + } + + ThrowOnDeserializeAggregateFunction function; + Arena arena; +}; + +TEST_F(AggregateFunctionExceptionTest, DeserializeVecDestroysCurrentStateOnFailure) { + auto column = make_column({"ok", "throw"}); + std::vector states(function.size_of_data() * 2); + + bool thrown = false; + try { + function.deserialize_vec(states.data(), static_cast(column.get()), arena, 2); + } catch (const Exception&) { + thrown = true; + } + + EXPECT_TRUE(thrown); + if (!thrown) { + function.destroy_vec(states.data(), 2); + } + EXPECT_EQ(TrackingAggregateState::construct_count, 2); + EXPECT_EQ(TrackingAggregateState::destroy_count, 2); +} + +TEST_F(AggregateFunctionExceptionTest, DeserializeAndMergeVecDestroysRhsStateOnFailure) { + auto column = make_column({"throw"}); + std::vector place_storage(function.size_of_data()); + std::vector rhs_storage(function.size_of_data()); + auto* place = place_storage.data(); + function.create(place); + + std::array places {place}; + const auto destroy_count_before_call = TrackingAggregateState::destroy_count; + bool thrown = false; + try { + function.deserialize_and_merge_vec(places.data(), 0, rhs_storage.data(), column.get(), + arena, 1); + } catch (const Exception&) { + thrown = true; + } + + EXPECT_TRUE(thrown); + EXPECT_EQ(TrackingAggregateState::destroy_count - destroy_count_before_call, 1); + + function.destroy(place); + EXPECT_EQ(TrackingAggregateState::construct_count, TrackingAggregateState::destroy_count); +} + +TEST_F(AggregateFunctionExceptionTest, + DeserializeAndMergeVecSelectedDestroysAllCreatedRhsStatesOnFailure) { + auto column = make_column({"skip", "throw"}); + std::vector place_storage(function.size_of_data()); + std::vector rhs_storage(function.size_of_data() * 2); + auto* place = place_storage.data(); + function.create(place); + + std::array places {nullptr, place}; + const auto destroy_count_before_call = TrackingAggregateState::destroy_count; + bool thrown = false; + try { + function.deserialize_and_merge_vec_selected(places.data(), 0, rhs_storage.data(), + column.get(), arena, 2); + } catch (const Exception&) { + thrown = true; + } + + EXPECT_TRUE(thrown); + EXPECT_EQ(TrackingAggregateState::destroy_count - destroy_count_before_call, 2); + + function.destroy(place); + EXPECT_EQ(TrackingAggregateState::construct_count, TrackingAggregateState::destroy_count); +} + +} // namespace doris \ No newline at end of file diff --git a/be/test/exprs/function/function_math_test.cpp b/be/test/exprs/function/function_math_test.cpp index 4e51a5dc3e700b..cf1b3a442ea686 100644 --- a/be/test/exprs/function/function_math_test.cpp +++ b/be/test/exprs/function/function_math_test.cpp @@ -18,14 +18,17 @@ #include #include #include +#include #include +#include "core/column/column_const.h" #include "core/data_type/data_type_decimal.h" #include "core/data_type/data_type_number.h" #include "core/data_type/data_type_string.h" #include "core/types.h" #include "exprs/function/function_test_util.h" #include "testutil/any_type.h" +#include "testutil/column_helper.h" namespace doris { @@ -532,6 +535,11 @@ TEST(MathFunctionTest, hex_test) { } TEST(MathFunctionTest, random_test) { +#ifndef NDEBUG + GTEST_SKIP() << "random(seed) exact-value assertions are release-only; debug builds run " + "mock_const_execute before the real call."; +#endif + std::string func_name = "random"; // random(x) InputTypeSet input_types = {Consted {PrimitiveType::TYPE_BIGINT}}; DataSet data_set = {{{Null()}, Null()}, @@ -547,6 +555,56 @@ TEST(MathFunctionTest, random_test) { } } +TEST(MathFunctionTest, uniform_mixed_const_probe_test) { + auto input_type = std::make_shared(); + auto return_type = std::make_shared(); + + Block block; + auto min_data = ColumnHelper::create_column({1}); + auto max_data = ColumnHelper::create_column({10}); + auto seed_column = ColumnHelper::create_column({101, 202, 303}); + + block.insert({ColumnConst::create(min_data, 3), input_type, "min"}); + block.insert({ColumnConst::create(max_data, 3), input_type, "max"}); + block.insert({seed_column, input_type, "seed"}); + + FunctionBasePtr function = SimpleFunctionFactory::instance().get_function( + "uniform", block.get_columns_with_type_and_name(), return_type); + ASSERT_TRUE(function != nullptr); + + block.insert({nullptr, return_type, "result"}); + + FunctionUtils fn_utils(return_type, {input_type, input_type, input_type}, false); + auto* fn_ctx = fn_utils.get_fn_ctx(); + std::vector> constant_cols { + std::make_shared(block.get_by_position(0).column), + std::make_shared(block.get_by_position(1).column), + nullptr, + }; + fn_ctx->set_constant_cols(constant_cols); + + ASSERT_TRUE(function->open(fn_ctx, FunctionContext::FRAGMENT_LOCAL).ok()); + ASSERT_TRUE(function->open(fn_ctx, FunctionContext::THREAD_LOCAL).ok()); + + auto exec_status = function->execute(fn_ctx, block, {0, 1, 2}, 3, 3); + + static_cast(function->close(fn_ctx, FunctionContext::THREAD_LOCAL)); + static_cast(function->close(fn_ctx, FunctionContext::FRAGMENT_LOCAL)); + + ASSERT_TRUE(exec_status.ok()) << exec_status.to_string(); + + const auto& result_column = assert_cast(*block.get_by_position(3).column); + auto expected_uniform = [](int64_t seed) { + std::mt19937_64 generator(seed); + std::uniform_int_distribution distribution(1, 10); + return distribution(generator); + }; + + EXPECT_EQ(result_column.get_element(0), expected_uniform(101)); + EXPECT_EQ(result_column.get_element(1), expected_uniform(202)); + EXPECT_EQ(result_column.get_element(2), expected_uniform(303)); +} + TEST(MathFunctionTest, conv_test) { std::string func_name = "conv"; diff --git a/be/test/exprs/function/function_quantile_state_test.cpp b/be/test/exprs/function/function_quantile_state_test.cpp index 1cb1ced1dae561..e8f2fca702895f 100644 --- a/be/test/exprs/function/function_quantile_state_test.cpp +++ b/be/test/exprs/function/function_quantile_state_test.cpp @@ -213,4 +213,21 @@ TEST(function_quantile_state_test, function_quantile_state_roundtrip) { 0.01); } +TEST(function_quantile_state_test, function_quantile_percent_mixed_const_test) { + std::string func_name = "quantile_percent"; + InputTypeSet input_types = {PrimitiveType::TYPE_QUANTILE_STATE, + ConstedNotnull {PrimitiveType::TYPE_FLOAT}}; + + QuantileState quantile_state; + quantile_state.add_value(1.0); + quantile_state.add_value(2.0); + quantile_state.add_value(3.0); + quantile_state.add_value(4.0); + quantile_state.add_value(5.0); + + DataSet data_set = {{{&quantile_state, 0.5F}, 3.0}}; + + static_cast(check_function(func_name, input_types, data_set)); +} + } // namespace doris diff --git a/be/test/exprs/function/function_string_test.cpp b/be/test/exprs/function/function_string_test.cpp index edf888f2c8f1b3..90456da258a960 100644 --- a/be/test/exprs/function/function_string_test.cpp +++ b/be/test/exprs/function/function_string_test.cpp @@ -3854,4 +3854,20 @@ TEST(function_string_test, function_unicode_normalize_invalid_mode) { EXPECT_NE(Status::OK(), st); } +TEST(function_string_test, function_regexp_count_mixed_const_test) { + std::string func_name = "regexp_count"; + + InputTypeSet input_types = {PrimitiveType::TYPE_VARCHAR, PrimitiveType::TYPE_VARCHAR}; + DataSet data_set = { + {{std::string("a.b:c;d"), std::string("[.:;]")}, std::int32_t(3)}, + {{std::string("a1b2346c3d"), std::string("\\d+")}, std::int32_t(3)}, + {{std::string("abcd"), std::string("")}, std::int32_t(0)}, + {{std::string("book keeper"), std::string("oo|ee")}, std::int32_t(2)}, + {{Null(), std::string("\\d+")}, Null()}, + {{std::string("abcd"), Null()}, Null()}, + }; + + check_function_all_arg_comb(func_name, input_types, data_set); +} + } // namespace doris diff --git a/be/test/storage/index/ann/ann_range_search_test.cpp b/be/test/storage/index/ann/ann_range_search_test.cpp index 400e822695ca32..890856b8b0a925 100644 --- a/be/test/storage/index/ann/ann_range_search_test.cpp +++ b/be/test/storage/index/ann/ann_range_search_test.cpp @@ -100,8 +100,9 @@ TEST_F(VectorSearchTest, TestPrepareAnnRangeSearch) { EXPECT_EQ(ann_range_search_runtime.radius, 10.0f); std::vector query_array_groud_truth = {1, 2, 3, 4, 5, 6, 7, 20}; std::vector query_array_f32; + const auto& query_value = range_search_ctx->_ann_range_search_runtime.query_value; for (int i = 0; i < query_array_groud_truth.size(); ++i) { - query_array_f32.push_back(static_cast(ann_range_search_runtime.query_value[i])); + query_array_f32.push_back(static_cast(query_value->get_data()[i])); } for (int i = 0; i < query_array_f32.size(); ++i) { EXPECT_EQ(query_array_f32[i], query_array_groud_truth[i]); diff --git a/be/test/storage/index/ann/ann_topn_descriptor_test.cpp b/be/test/storage/index/ann/ann_topn_descriptor_test.cpp index 2cb9f293ee583b..880f42f6a9dd86 100644 --- a/be/test/storage/index/ann/ann_topn_descriptor_test.cpp +++ b/be/test/storage/index/ann/ann_topn_descriptor_test.cpp @@ -116,8 +116,7 @@ TEST_F(VectorSearchTest, AnnTopNRuntimeEvaluateTopN) { ASSERT_TRUE(st.ok()) << fmt::format("st: {}, expr {}", st.to_string(), predicate->get_order_by_expr_ctx()->root()->debug_string()); - const ColumnFloat32* query_column = - assert_cast(predicate->_query_array.get()); + const auto& query_column = predicate->_query_array; const float* query_value = query_column->get_data().data(); const size_t query_value_size = predicate->_query_array->size(); ASSERT_EQ(query_value_size, 8); diff --git a/be/test/storage/index/ann/extract_query_vector_test.cpp b/be/test/storage/index/ann/extract_query_vector_test.cpp index 8fd6850218ec10..22ab34ab32bc9a 100644 --- a/be/test/storage/index/ann/extract_query_vector_test.cpp +++ b/be/test/storage/index/ann/extract_query_vector_test.cpp @@ -178,7 +178,7 @@ TEST_F(ExtractQueryVectorTest, ValuesMatchInput) { auto result = extract_query_vector(mock); ASSERT_TRUE(result.has_value()); - auto* float_col = assert_cast(result.value().get()); + const auto& float_col = result.value(); ASSERT_EQ(float_col->size(), 4u); for (size_t i = 0; i < input.size(); ++i) { EXPECT_FLOAT_EQ(float_col->get_data()[i], input[i]); @@ -240,4 +240,20 @@ TEST_F(ExtractQueryVectorTest, NonArrayColumnFails) { EXPECT_TRUE(result.error().to_string().find("Array literal") != std::string::npos); } +TEST_F(ExtractQueryVectorTest, NonFloatArrayFails) { + auto int_col = ColumnInt32::create(); + int_col->insert_value(1); + int_col->insert_value(2); + auto offsets = ColumnArray::ColumnOffsets::create(); + offsets->insert_value(2); + auto array_col = ColumnArray::create(std::move(int_col), std::move(offsets)); + + auto mock = std::make_shared(); + mock->set_column(std::move(array_col)); + + auto result = extract_query_vector(mock); + ASSERT_FALSE(result.has_value()); + EXPECT_TRUE(result.error().to_string().find("must be Float32") != std::string::npos); +} + } // namespace doris::segment_v2 diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/DatetimeFunctionBinder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/DatetimeFunctionBinder.java index c93f151cf0c092..4e1a768bd97fe9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/DatetimeFunctionBinder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/DatetimeFunctionBinder.java @@ -55,6 +55,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursAdd; import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursDiff; import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursSub; +import org.apache.doris.nereids.trees.expressions.functions.scalar.MicroSecondsDiff; import org.apache.doris.nereids.trees.expressions.functions.scalar.MinuteCeil; import org.apache.doris.nereids.trees.expressions.functions.scalar.MinuteFloor; import org.apache.doris.nereids.trees.expressions.functions.scalar.MinuteMicrosecondAdd; @@ -301,9 +302,11 @@ private Expression processTimestampDiff(TimeUnit unit, Expression start, Express return new MinutesDiff(end, start); case SECOND: return new SecondsDiff(end, start); + case MICROSECOND: + return new MicroSecondsDiff(end, start); default: throw new AnalysisException("Unsupported time stamp diff time unit: " + unit - + ", supported time unit: YEAR/QUARTER/MONTH/WEEK/DAY/HOUR/MINUTE/SECOND"); + + ", supported time unit: YEAR/QUARTER/MONTH/WEEK/DAY/HOUR/MINUTE/SECOND/MICROSECOND"); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Interval.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Interval.java index 275e0f74fe1bc9..f490c225c444c0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Interval.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Interval.java @@ -106,6 +106,7 @@ public enum TimeUnit { MINUTE_SECOND("MINUTE_SECOND", false, 200), MINUTE_MICROSECOND("MINUTE_MICROSECOND", false, 200), SECOND("SECOND", true, 100), + MICROSECOND("MICROSECOND", true, 0), SECOND_MICROSECOND("SECOND_MICROSECOND", true, 100); private final String description; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/DatetimeFunctionBinderTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/DatetimeFunctionBinderTest.java index 81f24ed878bd4d..a63e4a3e6282a2 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/DatetimeFunctionBinderTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/DatetimeFunctionBinderTest.java @@ -42,6 +42,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursAdd; import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursDiff; import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursSub; +import org.apache.doris.nereids.trees.expressions.functions.scalar.MicroSecondsDiff; import org.apache.doris.nereids.trees.expressions.functions.scalar.MinuteCeil; import org.apache.doris.nereids.trees.expressions.functions.scalar.MinuteFloor; import org.apache.doris.nereids.trees.expressions.functions.scalar.MinutesAdd; @@ -110,6 +111,8 @@ public class DatetimeFunctionBinderTest { TinyIntType.INSTANCE, false, ImmutableList.of()); private final SlotReference secondUnit = new SlotReference(new ExprId(-1), "SECOND", TinyIntType.INSTANCE, false, ImmutableList.of()); + private final SlotReference microsecondUnit = new SlotReference(new ExprId(-1), "MICROSECOND", + TinyIntType.INSTANCE, false, ImmutableList.of()); private final SlotReference invalidUnit = new SlotReference(new ExprId(-1), "INVALID", TinyIntType.INSTANCE, false, ImmutableList.of()); @@ -172,6 +175,13 @@ void testTimestampDiff() { Assertions.assertEquals(dateTimeV2Literal2, result.child(0)); Assertions.assertEquals(dateTimeV2Literal1, result.child(1)); + timeDiff = new UnboundFunction(functionName, ImmutableList.of( + microsecondUnit, dateTimeV2Literal1, dateTimeV2Literal2)); + result = DatetimeFunctionBinder.INSTANCE.bind(timeDiff); + Assertions.assertInstanceOf(MicroSecondsDiff.class, result); + Assertions.assertEquals(dateTimeV2Literal2, result.child(0)); + Assertions.assertEquals(dateTimeV2Literal1, result.child(1)); + Assertions.assertThrowsExactly(AnalysisException.class, () -> DatetimeFunctionBinder.INSTANCE.bind( new UnboundFunction(functionName, ImmutableList.of(invalidUnit, diff --git a/regression-test/data/datatype_p0/timestamptz/test_timestamptz_agg_functions.out b/regression-test/data/datatype_p0/timestamptz/test_timestamptz_agg_functions.out index 850cbe14a980d5..f7ff2eb36d0dbe 100644 --- a/regression-test/data/datatype_p0/timestamptz/test_timestamptz_agg_functions.out +++ b/regression-test/data/datatype_p0/timestamptz/test_timestamptz_agg_functions.out @@ -11,3 +11,6 @@ true -- !group_array_union -- 3 +-- !group_array_nested_timestamptz -- +[["2024-01-01 00:00:00.000000+00:00", "2024-01-01 00:00:00.000000+00:00", "2024-01-02 00:00:00.000000+00:00"], ["2024-01-01 00:00:00.000000+00:00", "2024-01-02 00:00:00.000000+00:00", "2024-01-03 00:00:00.000000+00:00"]] + diff --git a/regression-test/data/datatype_p0/timestamptz/test_timestamptz_map_contains_entry.out b/regression-test/data/datatype_p0/timestamptz/test_timestamptz_map_contains_entry.out new file mode 100644 index 00000000000000..43746eee1800bd --- /dev/null +++ b/regression-test/data/datatype_p0/timestamptz/test_timestamptz_map_contains_entry.out @@ -0,0 +1,43 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !value_hit -- +true + +-- !value_miss -- +false + +-- !value_miss_key -- +false + +-- !key_hit -- +true + +-- !key_miss_value -- +false + +-- !key_miss_key -- +false + +-- !table_value_hit -- +1 true +2 false + +-- !table_value_miss -- +1 false +2 false + +-- !table_key_hit -- +1 true +2 false + +-- !table_key_miss -- +1 false +2 false + +-- !null_search_key -- +1 false +2 false + +-- !null_search_value -- +1 false +2 false + diff --git a/regression-test/data/nereids_syntax_p0/test_timestampdiff.out b/regression-test/data/nereids_syntax_p0/test_timestampdiff.out index 0e2dd6a537559e..15623515ed485f 100644 --- a/regression-test/data/nereids_syntax_p0/test_timestampdiff.out +++ b/regression-test/data/nereids_syntax_p0/test_timestampdiff.out @@ -17,3 +17,9 @@ -- !select -- 40 +-- !select -- +876543 + +-- !select -- +2024-01-01T10:00:00.999999 2024-01-01T10:00:00.123456 876543 + diff --git a/regression-test/suites/datatype_p0/timestamptz/test_timestamptz_agg_functions.groovy b/regression-test/suites/datatype_p0/timestamptz/test_timestamptz_agg_functions.groovy index 89126b5a284772..e5bf945225ef45 100644 --- a/regression-test/suites/datatype_p0/timestamptz/test_timestamptz_agg_functions.groovy +++ b/regression-test/suites/datatype_p0/timestamptz/test_timestamptz_agg_functions.groovy @@ -56,4 +56,41 @@ suite("test_timestamptz_agg_functions", "datatype_p0") { qt_group_array_union "SELECT size(group_array_union(arr)) FROM test_tz_agg" sql "DROP TABLE IF EXISTS test_tz_agg" + + sql "DROP TABLE IF EXISTS tz_group_array_crash" + sql """ + CREATE TABLE tz_group_array_crash ( + grp INT, + arr ARRAY + ) + DUPLICATE KEY(grp) + DISTRIBUTED BY HASH(grp) BUCKETS 1 + PROPERTIES('replication_num' = '1') + """ + + sql """ + INSERT INTO tz_group_array_crash VALUES + ( + 1, + ARRAY( + CAST('2024-01-01 00:00:00 +00:00' AS TIMESTAMPTZ(6)), + CAST('2024-01-01 08:00:00 +08:00' AS TIMESTAMPTZ(6)), + CAST('2024-01-02 00:00:00 +00:00' AS TIMESTAMPTZ(6)) + ) + ), + ( + 1, + ARRAY( + CAST('2024-01-01 00:00:00 +00:00' AS TIMESTAMPTZ(6)), + CAST('2024-01-02 08:00:00 +08:00' AS TIMESTAMPTZ(6)), + CAST('2024-01-03 00:00:00 +00:00' AS TIMESTAMPTZ(6)) + ) + ) + """ + + qt_group_array_nested_timestamptz """ + SELECT CAST(array_sort(group_array(arr)) AS STRING) + FROM tz_group_array_crash + GROUP BY grp + """ } diff --git a/regression-test/suites/datatype_p0/timestamptz/test_timestamptz_map_contains_entry.groovy b/regression-test/suites/datatype_p0/timestamptz/test_timestamptz_map_contains_entry.groovy new file mode 100644 index 00000000000000..2b814ef8b6e539 --- /dev/null +++ b/regression-test/suites/datatype_p0/timestamptz/test_timestamptz_map_contains_entry.groovy @@ -0,0 +1,155 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_timestamptz_map_contains_entry") { + + sql "set time_zone = '+08:00';" + sql "set enable_nereids_planner = true;" + sql "set enable_fallback_to_original_planner = false;" + + // --- inline literal tests (no table needed) --- + + // TIMESTAMPTZ as map value: hit + qt_value_hit """ + SELECT map_contains_entry( + map('a', cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), + 'b', cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6))), + 'a', + cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)) + ); + """ + + // TIMESTAMPTZ as map value: miss (wrong value) + qt_value_miss """ + SELECT map_contains_entry( + map('a', cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), + 'b', cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6))), + 'a', + cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6)) + ); + """ + + // TIMESTAMPTZ as map value: miss (wrong key) + qt_value_miss_key """ + SELECT map_contains_entry( + map('a', cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), + 'b', cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6))), + 'c', + cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)) + ); + """ + + // TIMESTAMPTZ as map key: hit + qt_key_hit """ + SELECT map_contains_entry( + map(cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), 'a', + cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6)), 'b'), + cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), + 'a' + ); + """ + + // TIMESTAMPTZ as map key: miss (wrong value) + qt_key_miss_value """ + SELECT map_contains_entry( + map(cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), 'a', + cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6)), 'b'), + cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), + 'b' + ); + """ + + // TIMESTAMPTZ as map key: miss (wrong key) + qt_key_miss_key """ + SELECT map_contains_entry( + map(cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), 'a', + cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6)), 'b'), + cast('2024-01-03 00:00:00.000000 +00:00' as timestamptz(6)), + 'a' + ); + """ + + // --- table-based tests --- + + sql "DROP TABLE IF EXISTS test_timestamptz_map_contains_entry_t;" + sql """ + CREATE TABLE test_timestamptz_map_contains_entry_t ( + id INT, + map_s_tz MAP, + map_tz_s MAP + ) + DUPLICATE KEY(id) + DISTRIBUTED BY HASH(id) BUCKETS 1 + PROPERTIES("replication_num" = "1"); + """ + + sql """ + INSERT INTO test_timestamptz_map_contains_entry_t VALUES ( + 1, + map('a', cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), + 'b', cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6))), + map(cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), 'a', + cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6)), 'b') + ), ( + 2, + map('x', cast('2024-06-15 12:00:00.000000 +05:30' as timestamptz(6))), + map(cast('2024-06-15 12:00:00.000000 +05:30' as timestamptz(6)), 'x') + ); + """ + + // TIMESTAMPTZ as map value, hit + qt_table_value_hit """ + SELECT id, map_contains_entry(map_s_tz, 'a', cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6))) + FROM test_timestamptz_map_contains_entry_t + ORDER BY id; + """ + + // TIMESTAMPTZ as map value, miss + qt_table_value_miss """ + SELECT id, map_contains_entry(map_s_tz, 'a', cast('2024-01-02 03:04:05.123456 +00:00' as timestamptz(6))) + FROM test_timestamptz_map_contains_entry_t + ORDER BY id; + """ + + // TIMESTAMPTZ as map key, hit + qt_table_key_hit """ + SELECT id, map_contains_entry(map_tz_s, cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), 'a') + FROM test_timestamptz_map_contains_entry_t + ORDER BY id; + """ + + // TIMESTAMPTZ as map key, miss + qt_table_key_miss """ + SELECT id, map_contains_entry(map_tz_s, cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6)), 'b') + FROM test_timestamptz_map_contains_entry_t + ORDER BY id; + """ + + // NULL search key + qt_null_search_key """ + SELECT id, map_contains_entry(map_s_tz, NULL, cast('2024-01-01 00:00:00.000000 +00:00' as timestamptz(6))) + FROM test_timestamptz_map_contains_entry_t + ORDER BY id; + """ + + // NULL search value + qt_null_search_value """ + SELECT id, map_contains_entry(map_s_tz, 'a', cast(NULL as timestamptz(6))) + FROM test_timestamptz_map_contains_entry_t + ORDER BY id; + """ +} diff --git a/regression-test/suites/nereids_function_p0/scalar_function/U.groovy b/regression-test/suites/nereids_function_p0/scalar_function/U.groovy index 68642fa31ec91b..f43bc4cb6eea97 100644 --- a/regression-test/suites/nereids_function_p0/scalar_function/U.groovy +++ b/regression-test/suites/nereids_function_p0/scalar_function/U.groovy @@ -62,6 +62,8 @@ suite("nereids_scalar_fn_U") { def result = sql """select uniform(1, 100, random()*10000) from numbers("number" = "10");""" assertTrue(result.size() == 10) + def doubleResult = sql """select uniform(1.23, 100.100, random()*10000) from numbers("number" = "10");""" + assertTrue(doubleResult.size() == 10) test { sql """select uniform(100, 1, random()*10000) from numbers("number" = "10");""" exception "uniform's min should be less than max" diff --git a/regression-test/suites/nereids_syntax_p0/test_timestampdiff.groovy b/regression-test/suites/nereids_syntax_p0/test_timestampdiff.groovy index 34500732e22920..0a3e563bd7f866 100644 --- a/regression-test/suites/nereids_syntax_p0/test_timestampdiff.groovy +++ b/regression-test/suites/nereids_syntax_p0/test_timestampdiff.groovy @@ -37,4 +37,32 @@ suite("test_timestampdiff") { qt_select """ SELECT TIMESTAMPDIFF(second,'2003-02-03 11:00:00','2003-02-03 11:00:40'); """ + + qt_select """ + SELECT TIMESTAMPDIFF(microsecond, + CAST('2024-01-01 10:00:00.123456' AS DATETIMEV2(6)), + CAST('2024-01-01 10:00:00.999999' AS DATETIMEV2(6))); + """ + + sql """drop table if exists test_timestampdiff_microsecond""" + sql """ + create table test_timestampdiff_microsecond ( + id int, + t datetimev2(6) + ) + duplicate key(id) + distributed by hash(id) buckets 1 + properties("replication_num" = "1"); + """ + + sql """ + insert into test_timestampdiff_microsecond values + (1, '2024-01-01 10:00:00.123456'), + (2, '2024-01-01 10:00:00.999999'); + """ + + qt_select """ + SELECT MAX(t), MIN(t), TIMESTAMPDIFF(MICROSECOND, MIN(t), MAX(t)) + FROM test_timestampdiff_microsecond; + """ } \ No newline at end of file