diff --git a/include/dsn/cpp/serverlet.h b/include/dsn/cpp/serverlet.h index 423ee72e5..35389c0ba 100644 --- a/include/dsn/cpp/serverlet.h +++ b/include/dsn/cpp/serverlet.h @@ -71,12 +71,17 @@ namespace dsn void operator () (const TResponse& resp) { - if (_response != nullptr) + if (_response == nullptr) { - ::dsn::marshall(_response, resp); - auto err = dsn_rpc_reply(_response); - dassert(err == ERR_OK, "dsn_rpc_reply failed: %s", error_code(err).to_string()); + derror("rpc_replier got null response"); + return; } + + auto err = ::dsn::try_marshall(_response, resp); + dassert(err == ERR_OK, "marshall response failed: %s", err.to_string()); + + auto reply_err = dsn_rpc_reply(_response); + dassert(reply_err == ERR_OK, "dsn_rpc_reply failed: %s", error_code(reply_err).to_string()); } bool is_empty() const @@ -326,9 +331,12 @@ namespace dsn { auto msg = dsn_msg_create_response(request); dassert(msg != nullptr, "dsn_msg_create_response failed"); - ::dsn::marshall(msg, resp); - auto err = dsn_rpc_reply(msg); - dassert(err == ERR_OK, "dsn_rpc_reply failed: %s", error_code(err).to_string()); + + auto err = ::dsn::try_marshall(msg, resp); + dassert(err == ERR_OK, "marshall response failed: %s", err.to_string()); + + auto reply_err = dsn_rpc_reply(msg); + dassert(reply_err == ERR_OK, "dsn_rpc_reply failed: %s", error_code(reply_err).to_string()); } /*@}*/ } // end namespace diff --git a/include/dsn/cpp/test_utils.h b/include/dsn/cpp/test_utils.h index f4891567e..ba98560bf 100644 --- a/include/dsn/cpp/test_utils.h +++ b/include/dsn/cpp/test_utils.h @@ -56,6 +56,8 @@ DEFINE_TASK_CODE_RPC(RPC_TEST_HASH1, TASK_PRIORITY_COMMON, THREAD_POOL_TEST_SERV DEFINE_TASK_CODE_RPC(RPC_TEST_HASH2, TASK_PRIORITY_COMMON, THREAD_POOL_TEST_SERVER) DEFINE_TASK_CODE_RPC(RPC_TEST_HASH3, TASK_PRIORITY_COMMON, THREAD_POOL_TEST_SERVER) DEFINE_TASK_CODE_RPC(RPC_TEST_HASH4, TASK_PRIORITY_COMMON, THREAD_POOL_TEST_SERVER) +DEFINE_TASK_CODE_RPC(RPC_TEST_HASH5, TASK_PRIORITY_COMMON, THREAD_POOL_TEST_SERVER) +DEFINE_TASK_CODE_RPC(RPC_TEST_HASH6, TASK_PRIORITY_COMMON, THREAD_POOL_TEST_SERVER) DEFINE_TASK_CODE_RPC(RPC_TEST_STRING_COMMAND, TASK_PRIORITY_COMMON, THREAD_POOL_TEST_SERVER) DEFINE_TASK_CODE_AIO(LPC_AIO_TEST, TASK_PRIORITY_COMMON, THREAD_POOL_DEFAULT) @@ -124,6 +126,8 @@ class test_client : register_async_rpc_handler(RPC_TEST_HASH2, "rpc.test.hash2", &test_client::on_rpc_test); register_async_rpc_handler(RPC_TEST_HASH3, "rpc.test.hash3", &test_client::on_rpc_test); register_async_rpc_handler(RPC_TEST_HASH4, "rpc.test.hash4", &test_client::on_rpc_test); + register_async_rpc_handler(RPC_TEST_HASH5, "rpc.test.hash5", &test_client::on_rpc_test); + register_async_rpc_handler(RPC_TEST_HASH6, "rpc.test.hash6", &test_client::on_rpc_test); register_rpc_handler(RPC_TEST_STRING_COMMAND, "rpc.test.string.command", &test_client::on_rpc_string_test); } diff --git a/include/dsn/tool-api/message_parser.h b/include/dsn/tool-api/message_parser.h index f08435b3a..9b08dfe97 100644 --- a/include/dsn/tool-api/message_parser.h +++ b/include/dsn/tool-api/message_parser.h @@ -53,11 +53,14 @@ namespace dsn : _buffer_occupied(0), _buffer_block_size(buffer_block_size) {} ~message_reader() {} - // called before read to extend read buffer + // called before read to extend read buffer; returns nullptr if the buffer cannot be prepared. DSN_API char* read_buffer_ptr(unsigned int read_next); // get remaining buffer capacity - unsigned int read_buffer_capacity() const { return _buffer.length() - _buffer_occupied; } + unsigned int read_buffer_capacity() const + { + return _buffer.length() >= _buffer_occupied ? _buffer.length() - _buffer_occupied : 0; + } // called after read to mark data occupied void mark_read(unsigned int read_length) { _buffer_occupied += read_length; } diff --git a/include/dsn/tool-api/rpc_message.h b/include/dsn/tool-api/rpc_message.h index 939af077d..b7d874fc9 100644 --- a/include/dsn/tool-api/rpc_message.h +++ b/include/dsn/tool-api/rpc_message.h @@ -151,11 +151,11 @@ namespace dsn // // routines for buffer management // - DSN_API void write_next(void** ptr, size_t* size, size_t min_size); - DSN_API void write_commit(size_t size); + DSN_API bool write_next(void** ptr, size_t* size, size_t min_size); + DSN_API bool write_commit(size_t size); DSN_API void write_append(const blob& data); DSN_API bool read_next(void** ptr, size_t* size); - DSN_API void read_commit(size_t size); + DSN_API bool read_commit(size_t size); size_t body_size() { return (size_t)header->body_length; } DSN_API void* rw_ptr(size_t offset_begin); diff --git a/include/dsn/utility/factory_store.h b/include/dsn/utility/factory_store.h index 34a0c2f15..adb099b06 100644 --- a/include/dsn/utility/factory_store.h +++ b/include/dsn/utility/factory_store.h @@ -183,8 +183,6 @@ namespace dsn { fprintf(stderr, "\t\t%s (type: %s)\n", it->c_str(), entry.type == PROVIDER_TYPE_MAIN ? "provider" : "aspect"); } fprintf(stderr, "\tPlease specify the correct factory name in your tool_app or in configuration file\n"); - - std::abort(); } private: diff --git a/src/core/src/address.cpp b/src/core/src/address.cpp index 1a433c0aa..476713f02 100644 --- a/src/core/src/address.cpp +++ b/src/core/src/address.cpp @@ -64,6 +64,7 @@ # include # include "group_address.h" # include "uri_address.h" +# include "c_api_guard.h" namespace dsn { @@ -99,6 +100,7 @@ static bool net_init() // name to ip etc. DSN_API uint32_t dsn_ipv4_from_host(const char* name) { + DSN_C_GUARD_BEGIN if ((name == nullptr) || (name[0] == '\0')) { derror("dsn_ipv4_from_host got null or empty name"); @@ -153,6 +155,7 @@ DSN_API uint32_t dsn_ipv4_from_host(const char* name) // converts from network byte order to host byte order return (uint32_t)ntohl(addr.sin_addr.s_addr); + DSN_C_GUARD_END(0) } static bool interface_has_prefix(const char* network_interface, const char* prefix) @@ -201,6 +204,7 @@ static bool is_default_interface(const struct ifaddrs* ifa) // an address that is inconsistent with localhost-based tests. DSN_API uint32_t dsn_ipv4_local(const char* network_interface) { + DSN_C_GUARD_BEGIN uint32_t ret = 0; # if !defined(_WIN32) @@ -283,6 +287,7 @@ DSN_API uint32_t dsn_ipv4_local(const char* network_interface) # endif return ret; + DSN_C_GUARD_END(0) } DSN_API const char* dsn_address_to_string(dsn_address_t addr) @@ -393,6 +398,7 @@ DSN_API dsn_address_t dsn_address_build_uri( DSN_API dsn_group_t dsn_group_build(const char* name) // must be paired with release later { + DSN_C_GUARD_BEGIN if (name == nullptr || name[0] == '\0') { derror("dsn_group_build got null or empty name"); @@ -401,6 +407,7 @@ DSN_API dsn_group_t dsn_group_build(const char* name) // must be paired with rel auto g = new ::dsn::rpc_group_address(name); return g; + DSN_C_GUARD_END(nullptr) } DSN_API int dsn_group_count(dsn_group_t g) @@ -417,6 +424,7 @@ DSN_API int dsn_group_count(dsn_group_t g) DSN_API bool dsn_group_add(dsn_group_t g, dsn_address_t ep) { + DSN_C_GUARD_BEGIN if (g == nullptr) { derror("dsn_group_add got null group"); @@ -426,6 +434,7 @@ DSN_API bool dsn_group_add(dsn_group_t g, dsn_address_t ep) auto grp = (::dsn::rpc_group_address*)(g); ::dsn::rpc_address addr(ep); return grp->add(addr); + DSN_C_GUARD_END(false) } DSN_API void dsn_group_set_leader(dsn_group_t g, dsn_address_t ep) @@ -517,6 +526,7 @@ DSN_API dsn_address_t dsn_group_forward_leader(dsn_group_t g) DSN_API bool dsn_group_remove(dsn_group_t g, dsn_address_t ep) { + DSN_C_GUARD_BEGIN if (g == nullptr) { derror("dsn_group_remove got null group"); @@ -526,6 +536,7 @@ DSN_API bool dsn_group_remove(dsn_group_t g, dsn_address_t ep) auto grp = (::dsn::rpc_group_address*)(g); ::dsn::rpc_address addr(ep); return grp->remove(addr); + DSN_C_GUARD_END(false) } DSN_API void dsn_group_destroy(dsn_group_t g) @@ -542,6 +553,7 @@ DSN_API void dsn_group_destroy(dsn_group_t g) DSN_API dsn_uri_t dsn_uri_build(const char* url) // must be paired with destroy later { + DSN_C_GUARD_BEGIN if (url == nullptr || url[0] == '\0') { derror("dsn_uri_build got null or empty url"); @@ -549,6 +561,7 @@ DSN_API dsn_uri_t dsn_uri_build(const char* url) // must be paired with destroy } return (dsn_uri_t)new ::dsn::rpc_uri_address(url); + DSN_C_GUARD_END(nullptr) } DSN_API void dsn_uri_destroy(dsn_uri_t uri) diff --git a/src/core/src/app_manager.cpp b/src/core/src/app_manager.cpp index f6e5c745a..cc9bc3177 100644 --- a/src/core/src/app_manager.cpp +++ b/src/core/src/app_manager.cpp @@ -36,9 +36,11 @@ # include "service_engine.h" # include "rpc_engine.h" # include +# include "c_api_guard.h" DSN_API bool dsn_register_app(dsn_app* app_type) { + DSN_C_GUARD_BEGIN if (app_type == nullptr) { derror("dsn_register_app got null app_type"); @@ -72,10 +74,12 @@ DSN_API bool dsn_register_app(dsn_app* app_type) delete app; } return r; + DSN_C_GUARD_END(false) } DSN_API bool dsn_get_app_callbacks(const char* name, /* out */ dsn_app_callbacks* callbacks) { + DSN_C_GUARD_BEGIN if (name == nullptr || name[0] == '\0') { derror("dsn_get_app_callbacks got null or empty name"); @@ -100,6 +104,7 @@ DSN_API bool dsn_get_app_callbacks(const char* name, /* out */ dsn_app_callbacks dwarn("application model '%s' is not found, make sure it is registered", name); return false; } + DSN_C_GUARD_END(false) } DSN_API dsn_error_t dsn_hosted_app_create( @@ -110,6 +115,7 @@ DSN_API dsn_error_t dsn_hosted_app_create( /*out*/void** app_context_for_callbacks ) { + DSN_C_GUARD_BEGIN if (type == nullptr || type[0] == '\0') { derror("dsn_hosted_app_create got null or empty type"); @@ -152,10 +158,12 @@ DSN_API dsn_error_t dsn_hosted_app_create( return node->get_l2_handler() .create_app(type, gpid, data_dir, app_context_for_downcalls, app_context_for_callbacks) .get(); + DSN_C_GUARD_END(::dsn::ERR_UNKNOWN.get()) } DSN_API dsn_error_t dsn_hosted_app_start(void* app_context, int argc, char** argv) { + DSN_C_GUARD_BEGIN if (app_context == nullptr) { derror("dsn_hosted_app_start got null app_context"); @@ -191,10 +199,12 @@ DSN_API dsn_error_t dsn_hosted_app_start(void* app_context, int argc, char** arg } return node->get_l2_handler().start_app(app_context, argc, argv).get(); + DSN_C_GUARD_END(::dsn::ERR_UNKNOWN.get()) } DSN_API dsn_error_t dsn_hosted_app_destroy(void* app_context, bool cleanup) { + DSN_C_GUARD_BEGIN if (app_context == nullptr) { derror("dsn_hosted_app_destroy got null app_context"); @@ -209,10 +219,12 @@ DSN_API dsn_error_t dsn_hosted_app_destroy(void* app_context, bool cleanup) } return node->get_l2_handler().destroy_app(app_context, cleanup).get(); + DSN_C_GUARD_END(::dsn::ERR_UNKNOWN.get()) } DSN_API dsn_error_t dsn_hosted_app_commit_rpc_request(void* app_context, dsn_message_t msg, bool exec_inline) { + DSN_C_GUARD_BEGIN if (app_context == nullptr) { derror("dsn_hosted_app_commit_rpc_request got null app_context"); @@ -255,6 +267,7 @@ DSN_API dsn_error_t dsn_hosted_app_commit_rpc_request(void* app_context, dsn_mes } return ::dsn::ERR_OK.get(); + DSN_C_GUARD_END(::dsn::ERR_UNKNOWN.get()) } diff --git a/src/core/src/c_api_guard.h b/src/core/src/c_api_guard.h new file mode 100644 index 000000000..d70541759 --- /dev/null +++ b/src/core/src/c_api_guard.h @@ -0,0 +1,109 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015 Microsoft Corporation + * + * -=- Robust Distributed System Nucleus (rDSN) -=- + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +/* + * Description: + * no-throw boundary guard for the rDSN C API (DSN_API) entry points. + * + * The rDSN public API is a C ABI. A C++ exception that unwinds across the + * C ABI boundary is undefined behavior, so every DSN_API entry point that + * may invoke throwing C++ code (allocation, STL containers, serialization, + * calls into the engines) must translate an escaping exception into the + * API's documented failure value (dsn_error_t / bool / nullptr / 0) + * instead of letting it propagate. + * + * Usage: + * + * DSN_API dsn_message_t dsn_msg_create_request(...) + * { + * DSN_C_GUARD_BEGIN + * ... body that may throw ... + * return msg; + * DSN_C_GUARD_END(nullptr) // fail value for this API + * } + * + * DSN_API void dsn_some_void_api(...) + * { + * DSN_C_GUARD_BEGIN + * ... body that may throw ... + * DSN_C_GUARD_END_VOID() + * } + * + * Two catch clauses are intentional: catch(const std::exception&) reports + * the diagnostic via ex.what(); catch(...) is still required because the C + * ABI boundary must also stop non-std exceptions (e.g. thrown ints, or + * foreign exception objects) from unwinding into the C caller. + * + * This guard does NOT catch dsn_coredump()/abort() (e.g. from dassert): + * those are deliberate fatal-invariant aborts, not exceptions, and are + * handled separately by converting recoverable aborts into error returns. + * + * Revision history: + * 2026-xx-xx, first version + */ + +# pragma once + +# include +# include + +// open the guarded region; pair with DSN_C_GUARD_END / DSN_C_GUARD_END_VOID +# define DSN_C_GUARD_BEGIN try { + +// close the guarded region for an API that returns a value; on an escaping +// C++ exception, logs the diagnostic and returns (fail_ret). The trailing +// return after the handlers is intentional: it makes (fail_ret) the function's +// last statement so -Wreturn-type (-Werror) never fires regardless of how the +// guarded body is structured, and it is the value returned on the exception +// path. +# define DSN_C_GUARD_END(fail_ret) \ + } \ + catch (const std::exception &ex) \ + { \ + derror("%s: C API call aborted by C++ exception: %s", \ + __FUNCTION__, ex.what()); \ + } \ + catch (...) \ + { \ + derror("%s: C API call aborted by unknown (non-std) C++ exception", \ + __FUNCTION__); \ + } \ + return (fail_ret); + +// close the guarded region for an API that returns void; on an escaping C++ +// exception, logs the diagnostic and returns normally. +# define DSN_C_GUARD_END_VOID() \ + } \ + catch (const std::exception &ex) \ + { \ + derror("%s: C API call aborted by C++ exception: %s", \ + __FUNCTION__, ex.what()); \ + } \ + catch (...) \ + { \ + derror("%s: C API call aborted by unknown (non-std) C++ exception", \ + __FUNCTION__); \ + } diff --git a/src/core/src/command_manager.cpp b/src/core/src/command_manager.cpp index 0ab8b4899..d8aa19728 100644 --- a/src/core/src/command_manager.cpp +++ b/src/core/src/command_manager.cpp @@ -46,6 +46,7 @@ # include # include "rpc_engine.h" # include +# include "c_api_guard.h" # ifdef __TITLE__ # undef __TITLE__ @@ -54,6 +55,7 @@ DSN_API const char* dsn_cli_run(const char* command_line) // return command output { + DSN_C_GUARD_BEGIN ::dsn::safe_string output; if (!dsn::run_command(command_line, output)) { @@ -70,6 +72,7 @@ DSN_API const char* dsn_cli_run(const char* command_line) // return command outp memcpy(c_output, &output[0], output.length()); c_output[output.length()] = '\0'; return c_output; + DSN_C_GUARD_END(nullptr) } DSN_API void dsn_cli_free(const char* command_output) @@ -92,6 +95,7 @@ DSN_API dsn_handle_t dsn_cli_register( dsn_cli_free_handler output_freer ) { + DSN_C_GUARD_BEGIN if (cmd_handler == nullptr) { derror("dsn_cli_register got null command handler"); @@ -122,6 +126,7 @@ DSN_API dsn_handle_t dsn_cli_register( return cpp_output; } ); + DSN_C_GUARD_END(nullptr) } DSN_API dsn_handle_t dsn_cli_app_register( @@ -133,6 +138,7 @@ DSN_API dsn_handle_t dsn_cli_app_register( dsn_cli_free_handler output_freer ) { + DSN_C_GUARD_BEGIN if (command == nullptr || command[0] == '\0') { derror("dsn_cli_app_register got null or empty command"); @@ -162,6 +168,7 @@ DSN_API dsn_handle_t dsn_cli_app_register( dsn::command_manager::instance().set_cli_target_address(handle, dsn::task::get_current_rpc()->primary_address()); return handle; + DSN_C_GUARD_END(nullptr) } DSN_API void dsn_cli_deregister(dsn_handle_t handle) diff --git a/src/core/src/disk_engine.cpp b/src/core/src/disk_engine.cpp index 128703561..9b3bf0ef5 100644 --- a/src/core/src/disk_engine.cpp +++ b/src/core/src/disk_engine.cpp @@ -38,6 +38,7 @@ # include # include # include "transient_memory.h" +# include # ifdef __TITLE__ # undef __TITLE__ @@ -331,10 +332,30 @@ void disk_engine::write(aio_task* aio) void disk_engine::process_write(aio_task* aio, uint32_t sz) { + auto complete_write_with_error = [this](aio_task* wk, error_code err) + { + auto df = (disk_file*)wk->aio()->file_object; + uint32_t next_sz; + auto next = df->on_write_completed(wk, (void*)&next_sz, err, 0); + if (next) + { + process_write(next, next_sz); + } + }; + // no batching if (aio->aio()->buffer_size == sz) { - aio->collapse(); + try + { + aio->collapse(); + } + catch (const std::exception& ex) + { + derror("collapse aio write buffer failed: %s", ex.what()); + complete_write_with_error(aio, ERR_FILE_OPERATION_FAILED); + return; + } return _provider->aio(aio); } @@ -342,36 +363,49 @@ void disk_engine::process_write(aio_task* aio, uint32_t sz) else { // merge the buffers - auto bb = tls_trans_mem_alloc_blob((size_t)sz); - char* ptr = (char*)bb.data(); - auto current_wk = aio; - do + try { - current_wk->copy_to(ptr); - ptr += current_wk->aio()->buffer_size; - current_wk = (aio_task*)current_wk->next; - } while (current_wk); - + auto bb = tls_trans_mem_alloc_blob((size_t)sz); + char* ptr = (char*)bb.data(); + auto current_wk = aio; + do + { + current_wk->copy_to(ptr); + ptr += current_wk->aio()->buffer_size; + current_wk = (aio_task*)current_wk->next; + } while (current_wk); - dassert(ptr == (char*)bb.data() + bb.length(), ""); + if (ptr != (char*)bb.data() + bb.length()) + { + derror("merge aio write buffers copied unexpected size"); + complete_write_with_error(aio, ERR_FILE_OPERATION_FAILED); + return; + } - // setup io task - auto new_task = new batch_write_io_task( - aio, - bb - ); - auto dio = new_task->aio(); - dio->buffer = (void*)bb.data(); - dio->buffer_size = sz; - dio->file_offset = aio->aio()->file_offset; - - dio->file = aio->aio()->file; - dio->file_object = aio->aio()->file_object; - dio->engine = aio->aio()->engine; - dio->type = AIO_Write; - - new_task->add_ref(); // released in complete_io - return _provider->aio(new_task); + // setup io task + auto new_task = new batch_write_io_task( + aio, + bb + ); + auto dio = new_task->aio(); + dio->buffer = (void*)bb.data(); + dio->buffer_size = sz; + dio->file_offset = aio->aio()->file_offset; + + dio->file = aio->aio()->file; + dio->file_object = aio->aio()->file_object; + dio->engine = aio->aio()->engine; + dio->type = AIO_Write; + + new_task->add_ref(); // released in complete_io + return _provider->aio(new_task); + } + catch (const std::exception& ex) + { + derror("merge aio write buffers failed: %s", ex.what()); + complete_write_with_error(aio, ERR_FILE_OPERATION_FAILED); + return; + } } } diff --git a/src/core/src/global_checkers.cpp b/src/core/src/global_checkers.cpp index e9da20475..63306f5b3 100644 --- a/src/core/src/global_checkers.cpp +++ b/src/core/src/global_checkers.cpp @@ -36,6 +36,7 @@ # include # include # include +# include "c_api_guard.h" namespace dsn { @@ -54,6 +55,7 @@ namespace dsn DSN_API bool dsn_register_app_checker(const char* name, dsn_checker_create create, dsn_checker_apply apply) { + DSN_C_GUARD_BEGIN if (name == nullptr || name[0] == '\0') { derror("dsn_register_app_checker got null or empty name"); @@ -79,4 +81,5 @@ DSN_API bool dsn_register_app_checker(const char* name, dsn_checker_create creat ::dsn::global_checker_store::instance().checkers.push_back(ck); return true; + DSN_C_GUARD_END(false) } diff --git a/src/core/src/global_config.cpp b/src/core/src/global_config.cpp index 73522105b..061136166 100644 --- a/src/core/src/global_config.cpp +++ b/src/core/src/global_config.cpp @@ -37,7 +37,9 @@ # include # include # include +# include # include +# include # include # include # include @@ -52,6 +54,54 @@ namespace dsn { +static bool read_all_config_keys(const char* section, std::vector& keys) +{ + int key_capacity = 0; + int key_count = dsn_config_get_all_keys(section, nullptr, &key_capacity); + if (key_count < 0) + { + fprintf(stderr, "failed to read config keys from section [%s]\n", section); + return false; + } + + if (key_count == 0) + { + keys.clear(); + return true; + } + + try + { + keys.resize(key_count); + } + catch (const std::exception& ex) + { + fprintf(stderr, "failed to allocate config keys for section [%s]: %s\n", section, ex.what()); + return false; + } + + key_capacity = key_count; + int actual_key_count = dsn_config_get_all_keys(section, keys.data(), &key_capacity); + if (actual_key_count < 0) + { + fprintf(stderr, "failed to read config keys from section [%s]\n", section); + return false; + } + + if (actual_key_count > key_capacity) + { + fprintf(stderr, + "config keys in section [%s] changed while reading: capacity = %d, actual = %d\n", + section, + key_capacity, + actual_key_count); + return false; + } + + keys.resize(key_capacity); + return true; +} + static bool build_client_network_confs( const char* section, /*out*/ network_client_configs& nss, @@ -59,14 +109,21 @@ static bool build_client_network_confs( { nss.clear(); - const char* keys[128]; - int kcapacity = 128; - int kcount = dsn_config_get_all_keys(section, keys, &kcapacity); - dassert(kcount <= 128, ""); + std::vector keys; + if (!read_all_config_keys(section, keys)) + { + return false; + } - for (int i = 0; i < kcapacity; i++) + for (const char* key : keys) { - std::string k = keys[i]; + if (key == nullptr) + { + fprintf(stderr, "config section [%s] has null key\n", section); + return false; + } + + std::string k = key; if (k.length() <= strlen("network.client.")) continue; @@ -140,14 +197,21 @@ static bool build_server_network_confs( { nss.clear(); - const char* keys[128]; - int kcapacity = 128; - int kcount = dsn_config_get_all_keys(section, keys, &kcapacity); - dassert(kcount <= 128, ""); + std::vector keys; + if (!read_all_config_keys(section, keys)) + { + return false; + } - for (int i = 0; i < kcapacity; i++) + for (const char* key : keys) { - std::string k = keys[i]; + if (key == nullptr) + { + fprintf(stderr, "config section [%s] has null key\n", section); + return false; + } + + std::string k = key; if (k.length() <= strlen("network.server.")) continue; diff --git a/src/core/src/main.cpp b/src/core/src/main.cpp index efb9d0375..6adf46c65 100644 --- a/src/core/src/main.cpp +++ b/src/core/src/main.cpp @@ -56,6 +56,7 @@ # include "coredump.h" # include "transient_memory.h" # include "library_utils.h" +# include "c_api_guard.h" # include # include # include @@ -100,6 +101,7 @@ static struct _all_info_ DSN_API const char* dsn_config_get_value_string(const char* section, const char* key, const char* default_value, const char* dsptr) { + DSN_C_GUARD_BEGIN if (section == nullptr || section[0] == '\0') { derror("dsn_config_get_value_string got null or empty section"); @@ -125,10 +127,12 @@ DSN_API const char* dsn_config_get_value_string(const char* section, const char* } return dsn_all.config->get_string_value(section, key, default_value, dsptr); + DSN_C_GUARD_END(default_value) } DSN_API bool dsn_config_get_value_bool(const char* section, const char* key, bool default_value, const char* dsptr) { + DSN_C_GUARD_BEGIN if (section == nullptr || section[0] == '\0') { derror("dsn_config_get_value_bool got null or empty section"); @@ -148,10 +152,12 @@ DSN_API bool dsn_config_get_value_bool(const char* section, const char* key, boo } return dsn_all.config->get_value(section, key, default_value, dsptr); + DSN_C_GUARD_END(default_value) } DSN_API uint64_t dsn_config_get_value_uint64(const char* section, const char* key, uint64_t default_value, const char* dsptr) { + DSN_C_GUARD_BEGIN if (section == nullptr || section[0] == '\0') { derror("dsn_config_get_value_uint64 got null or empty section"); @@ -171,10 +177,12 @@ DSN_API uint64_t dsn_config_get_value_uint64(const char* section, const char* ke } return dsn_all.config->get_value(section, key, default_value, dsptr); + DSN_C_GUARD_END(default_value) } DSN_API double dsn_config_get_value_double(const char* section, const char* key, double default_value, const char* dsptr) { + DSN_C_GUARD_BEGIN if (section == nullptr || section[0] == '\0') { derror("dsn_config_get_value_double got null or empty section"); @@ -194,10 +202,12 @@ DSN_API double dsn_config_get_value_double(const char* section, const char* key, } return dsn_all.config->get_value(section, key, default_value, dsptr); + DSN_C_GUARD_END(default_value) } DSN_API int dsn_config_get_all_sections(const char** buffers, /*inout*/ int* buffer_count) { + DSN_C_GUARD_BEGIN if (buffer_count == nullptr) { derror("dsn_config_get_all_sections got null buffer_count"); @@ -235,10 +245,12 @@ DSN_API int dsn_config_get_all_sections(const char** buffers, /*inout*/ int* buf } return scount; + DSN_C_GUARD_END(-1) } DSN_API int dsn_config_get_all_keys(const char* section, const char** buffers, /*inout*/ int* buffer_count) // return all key count (may greater than buffer_count) { + DSN_C_GUARD_BEGIN if (section == nullptr || section[0] == '\0') { derror("dsn_config_get_all_keys got null or empty section"); @@ -282,10 +294,12 @@ DSN_API int dsn_config_get_all_keys(const char* section, const char** buffers, / } return kcount; + DSN_C_GUARD_END(-1) } DSN_API void dsn_config_dump(const char* file) { + DSN_C_GUARD_BEGIN if (file == nullptr || file[0] == '\0') { derror("dsn_config_dump got null or empty file"); @@ -301,6 +315,7 @@ DSN_API void dsn_config_dump(const char* file) std::ofstream os(file, std::ios::out); dsn_all.config->dump(os); os.close(); + DSN_C_GUARD_END_VOID() } extern bool dsn_log_init(); @@ -816,6 +831,7 @@ DSN_API bool dsn_run_config(const char* config, bool sleep_after_init) DSN_API int dsn_get_all_apps(dsn_app_info* info_buffer, int count) { + DSN_C_GUARD_BEGIN if (info_buffer == nullptr) { derror("dsn_get_all_apps got null info_buffer"); @@ -861,6 +877,7 @@ DSN_API int dsn_get_all_apps(dsn_app_info* info_buffer, int count) } } return i; + DSN_C_GUARD_END(-1) } diff --git a/src/core/src/message_parser.cpp b/src/core/src/message_parser.cpp index 8aef543a6..d08389766 100644 --- a/src/core/src/message_parser.cpp +++ b/src/core/src/message_parser.cpp @@ -35,6 +35,8 @@ # include "message_parser_manager.h" # include +# include +# include # ifdef __TITLE__ # undef __TITLE__ @@ -164,27 +166,47 @@ namespace dsn { //-------------------- msg reader -------------------- char* message_reader::read_buffer_ptr(unsigned int read_next) { - if (read_next + _buffer_occupied > _buffer.length()) + if (read_next > std::numeric_limits::max() - _buffer_occupied) { - // remember currently read content - blob rb; - if (_buffer_occupied > 0) - rb = _buffer.range(0, _buffer_occupied); - - // switch to next - unsigned int sz = (read_next + _buffer_occupied > _buffer_block_size ? - read_next + _buffer_occupied : _buffer_block_size); - _buffer.assign(dsn::make_shared_array(sz), 0, sz); - _buffer_occupied = 0; - - // copy - if (rb.length() > 0) + derror("message_reader::read_buffer_ptr got too large read size, read_next = %u, occupied = %u", + read_next, + _buffer_occupied); + return nullptr; + } + + const unsigned int required_size = read_next + _buffer_occupied; + if (required_size > _buffer.length()) + { + try + { + // remember currently read content + blob rb; + if (_buffer_occupied > 0) + rb = _buffer.range(0, _buffer_occupied); + + // switch to next + unsigned int sz = (required_size > _buffer_block_size ? required_size : _buffer_block_size); + _buffer.assign(dsn::make_shared_array(sz), 0, sz); + _buffer_occupied = 0; + + // copy + if (rb.length() > 0) + { + memcpy((void*)_buffer.data(), (const void*)rb.data(), rb.length()); + _buffer_occupied = rb.length(); + } + + if (read_next + _buffer_occupied > _buffer.length()) + { + derror("message_reader::read_buffer_ptr failed to prepare enough buffer"); + return nullptr; + } + } + catch (const std::exception& ex) { - memcpy((void*)_buffer.data(), (const void*)rb.data(), rb.length()); - _buffer_occupied = rb.length(); + derror("message_reader::read_buffer_ptr failed: %s", ex.what()); + return nullptr; } - - dassert (read_next + _buffer_occupied <= _buffer.length(), ""); } return (char*)(_buffer.data() + _buffer_occupied); diff --git a/src/core/src/perf_counters.cpp b/src/core/src/perf_counters.cpp index 41cc47575..a17d17caa 100644 --- a/src/core/src/perf_counters.cpp +++ b/src/core/src/perf_counters.cpp @@ -41,9 +41,11 @@ # include # include "service_engine.h" # include "perf_counters.h" +# include "c_api_guard.h" DSN_API dsn_handle_t dsn_perf_counter_create(const char* section, const char* name, dsn_perf_counter_type_t type, const char* description) { + DSN_C_GUARD_BEGIN if (section == nullptr || section[0] == '\0') { derror("dsn_perf_counter_create got null or empty section"); @@ -73,6 +75,7 @@ DSN_API dsn_handle_t dsn_perf_counter_create(const char* section, const char* na auto c = dsn::perf_counters::instance().get_counter(cnode->name(), section, name, type, description, true); c->add_ref(); return c.get(); + DSN_C_GUARD_END(nullptr) } DSN_API void dsn_perf_counter_remove(dsn_handle_t handle) diff --git a/src/core/src/rpc_message.cpp b/src/core/src/rpc_message.cpp index c443c0ad1..8438b1d2b 100644 --- a/src/core/src/rpc_message.cpp +++ b/src/core/src/rpc_message.cpp @@ -38,9 +38,11 @@ # include # include # include // for isprint() +# include # include "task_engine.h" # include "transient_memory.h" +# include "c_api_guard.h" using namespace dsn::utils; @@ -65,6 +67,7 @@ DSN_API dsn_message_t dsn_msg_create_request( uint64_t partition_hash ) { + DSN_C_GUARD_BEGIN const auto spec = ::dsn::task_spec::get(rpc_code); if (rpc_code == ::dsn::TASK_CODE_INVALID || spec == nullptr || spec->type != TASK_TYPE_RPC_REQUEST) { @@ -73,6 +76,7 @@ DSN_API dsn_message_t dsn_msg_create_request( } return ::dsn::message_ex::create_request(rpc_code, timeout_milliseconds, thread_hash, partition_hash); + DSN_C_GUARD_END(nullptr) } DSN_API dsn_message_t dsn_msg_create_received_request( @@ -84,6 +88,7 @@ DSN_API dsn_message_t dsn_msg_create_received_request( uint64_t partition_hash ) { + DSN_C_GUARD_BEGIN const auto spec = ::dsn::task_spec::get(rpc_code); if (spec == nullptr || spec->type != TASK_TYPE_RPC_REQUEST) { @@ -112,16 +117,24 @@ DSN_API dsn_message_t dsn_msg_create_received_request( ::dsn::blob bb((const char*)buffer, 0, size); auto msg = ::dsn::message_ex::create_receive_message_with_standalone_header(bb); + if (msg == nullptr) + { + derror("dsn_msg_create_received_request failed to create message"); + return nullptr; + } + msg->local_rpc_code = rpc_code; msg->header->client.thread_hash = thread_hash; msg->header->client.partition_hash = partition_hash; msg->header->context.u.serialize_format = serialization_type; msg->add_ref(); // released by callers explicitly using dsn_msg_release return msg; + DSN_C_GUARD_END(nullptr) } DSN_API dsn_message_t dsn_msg_copy(dsn_message_t msg, bool clone_content, bool copy_for_receive) { + DSN_C_GUARD_BEGIN if (msg == nullptr) { derror("dsn_msg_copy got null message"); @@ -129,10 +142,12 @@ DSN_API dsn_message_t dsn_msg_copy(dsn_message_t msg, bool clone_content, bool c } return ((::dsn::message_ex*)msg)->copy(clone_content, copy_for_receive); + DSN_C_GUARD_END(nullptr) } DSN_API dsn_message_t dsn_msg_create_response(dsn_message_t request) { + DSN_C_GUARD_BEGIN if (request == nullptr) { derror("dsn_msg_create_response got null request"); @@ -147,10 +162,12 @@ DSN_API dsn_message_t dsn_msg_create_response(dsn_message_t request) auto msg = ((::dsn::message_ex*)request)->create_response(); return msg; + DSN_C_GUARD_END(nullptr) } DSN_API bool dsn_msg_write_next(dsn_message_t msg, void** ptr, size_t* size, size_t min_size) { + DSN_C_GUARD_BEGIN if (msg == nullptr) { derror("dsn_msg_write_next got null message"); @@ -177,20 +194,21 @@ DSN_API bool dsn_msg_write_next(dsn_message_t msg, void** ptr, size_t* size, siz return false; } - ((::dsn::message_ex*)msg)->write_next(ptr, size, min_size); - return true; + return ((::dsn::message_ex*)msg)->write_next(ptr, size, min_size); + DSN_C_GUARD_END(false) } DSN_API bool dsn_msg_write_commit(dsn_message_t msg, size_t size) { + DSN_C_GUARD_BEGIN if (msg == nullptr) { derror("dsn_msg_write_commit got null message"); return false; } - ((::dsn::message_ex*)msg)->write_commit(size); - return true; + return ((::dsn::message_ex*)msg)->write_commit(size); + DSN_C_GUARD_END(false) } DSN_API bool dsn_msg_read_next(dsn_message_t msg, void** ptr, size_t* size) @@ -232,8 +250,7 @@ DSN_API bool dsn_msg_read_commit(dsn_message_t msg, size_t size) return false; } - ((::dsn::message_ex*)msg)->read_commit(size); - return true; + return ((::dsn::message_ex*)msg)->read_commit(size); } DSN_API size_t dsn_msg_body_size(dsn_message_t msg) @@ -554,19 +571,38 @@ message_ex* message_ex::create_receive_message(const blob& data) message_ex* message_ex::create_receive_message_with_standalone_header(const blob& data) { - message_ex* msg = new message_ex(); - std::shared_ptr header_holder(static_cast(dsn_transient_malloc(sizeof(message_header))), [](char* c) {dsn_transient_free(c);}); - msg->header = reinterpret_cast(header_holder.get()); - memset(reinterpret_cast(msg->header), 0, sizeof(message_header)); - msg->buffers.emplace_back(blob(std::move(header_holder), sizeof(message_header))); - msg->buffers.push_back(data); + message_ex* msg = nullptr; + try + { + msg = new message_ex(); + char* header_ptr = static_cast(dsn_transient_malloc(sizeof(message_header))); + if (header_ptr == nullptr) + { + derror("message_ex::create_receive_message_with_standalone_header failed to allocate header"); + delete msg; + return nullptr; + } - msg->header->body_length = data.length(); - msg->_is_read = true; - //we skip the message header - msg->_rw_index = 1; + std::shared_ptr header_holder(header_ptr, [](char* c) { dsn_transient_free(c); }); + msg->header = reinterpret_cast(header_holder.get()); + memset(reinterpret_cast(msg->header), 0, sizeof(message_header)); + msg->buffers.emplace_back(blob(std::move(header_holder), sizeof(message_header))); + msg->buffers.push_back(data); - return msg; + msg->header->body_length = data.length(); + msg->_is_read = true; + //we skip the message header + msg->_rw_index = 1; + + return msg; + } + catch (const std::exception& ex) + { + derror("message_ex::create_receive_message_with_standalone_header failed: %s", ex.what()); + } + + delete msg; + return nullptr; } message_ex* message_ex::copy(bool clone_content, bool copy_for_receive) @@ -779,12 +815,39 @@ void message_ex::prepare_buffer_header() header = (message_header*)ptr; } -void message_ex::write_next(void** ptr, size_t* size, size_t min_size) +bool message_ex::write_next(void** ptr, size_t* size, size_t min_size) { // printf("%p %s\n", this, __FUNCTION__); - dassert(!this->_is_read && this->_rw_committed, "there are pending msg write not committed" - ", please invoke dsn_msg_write_next and dsn_msg_write_commit in pairs"); + if (ptr == nullptr || size == nullptr) + { + derror("message_ex::write_next: null ptr or size out-parameter"); + return false; + } + + if (this->_is_read || !this->_rw_committed) + { + derror("message_ex::write_next: there are pending msg write not committed" + ", please invoke dsn_msg_write_next and dsn_msg_write_commit in pairs"); + *ptr = nullptr; + *size = 0; + return false; + } ::dsn::tls_trans_mem_next(ptr, size, min_size); + + // tls_trans_mem_next must hand back a valid transient buffer. A genuine + // allocation failure throws (translated to false at the dsn_msg_write_next + // boundary), so a null *ptr here means the allocator broke its contract. + // Bail out instead of doing pointer arithmetic on null below; re-commit 0 + // bytes first to keep the tls_trans_mem_next/commit pairing consistent for + // later writes on this thread. + if (*ptr == nullptr) + { + derror("message_ex::write_next: transient memory returned a null buffer"); + ::dsn::tls_trans_mem_commit(0); + *size = 0; + return false; + } + this->_rw_committed = false; // optimization @@ -802,7 +865,7 @@ void message_ex::write_next(void** ptr, size_t* size, size_t min_size) (int)(lbb.length() + *size) ); - return; + return true; } } @@ -816,13 +879,18 @@ void message_ex::write_next(void** ptr, size_t* size, size_t min_size) this->buffers.push_back(buffer); dassert(this->_rw_index + 1 == (int)this->buffers.size(), "message write buffer count is not right"); + return true; } -void message_ex::write_commit(size_t size) +bool message_ex::write_commit(size_t size) { // printf("%p %s\n", this, __FUNCTION__); - dassert(!this->_rw_committed, "there are no pending msg write to be committed" - ", please invoke dsn_msg_write_next and dsn_msg_write_commit in pairs"); + if (this->_rw_committed) + { + derror("message_ex::write_commit: there are no pending msg write to be committed" + ", please invoke dsn_msg_write_next and dsn_msg_write_commit in pairs"); + return false; + } ::dsn::tls_trans_mem_commit(size); @@ -830,6 +898,7 @@ void message_ex::write_commit(size_t size) *this->buffers.rbegin() = this->buffers.rbegin()->range(0, (int)this->_rw_offset); this->_rw_committed = true; this->header->body_length += (int)size; + return true; } void message_ex::write_append(const blob& data) @@ -851,8 +920,20 @@ void message_ex::write_append(const blob& data) bool message_ex::read_next(void** ptr, size_t* size) { // printf("%p %s %d\n", this, __FUNCTION__, utils::get_current_tid()); - dassert(this->_is_read && this->_rw_committed, "there are pending msg read not committed" - ", please invoke dsn_msg_read_next and dsn_msg_read_commit in pairs"); + if (ptr == nullptr || size == nullptr) + { + derror("message_ex::read_next: null ptr or size out-parameter"); + return false; + } + + if (!this->_is_read || !this->_rw_committed) + { + derror("message_ex::read_next: there are pending msg read not committed" + ", please invoke dsn_msg_read_next and dsn_msg_read_commit in pairs"); + *ptr = nullptr; + *size = 0; + return false; + } int idx = this->_rw_index; if (-1 == idx || @@ -864,9 +945,21 @@ bool message_ex::read_next(void** ptr, size_t* size) if (idx < (int)this->buffers.size()) { - this->_rw_committed = false; *ptr = (void*)(this->buffers[idx].data() + this->_rw_offset); *size = (size_t)this->buffers[idx].length() - this->_rw_offset; + + // Don't hand back a null *ptr as a successful read: buffers[idx] can be + // an empty blob whose data() is null (and _rw_offset is 0 in that case). + // Validate the produced *ptr here, before mutating the read state, so the + // read next/commit pairing stays consistent on this failure path. + if (*ptr == nullptr) + { + derror("message_ex::read_next: null buffer pointer at current read index"); + *size = 0; + return false; + } + + this->_rw_committed = false; return true; } else @@ -877,15 +970,24 @@ bool message_ex::read_next(void** ptr, size_t* size) } } -void message_ex::read_commit(size_t size) +bool message_ex::read_commit(size_t size) { // printf("%p %s\n", this, __FUNCTION__); - dassert(!this->_rw_committed, "there are no pending msg read to be committed" - ", please invoke dsn_msg_read_next and dsn_msg_read_commit in pairs"); + if (this->_rw_committed) + { + derror("message_ex::read_commit: there are no pending msg read to be committed" + ", please invoke dsn_msg_read_next and dsn_msg_read_commit in pairs"); + return false; + } - dassert(-1 != this->_rw_index, "no buffer in curent msg is under read"); + if (-1 == this->_rw_index) + { + derror("message_ex::read_commit: no buffer in current msg is under read"); + return false; + } this->_rw_offset += (int)size; this->_rw_committed = true; + return true; } void* message_ex::rw_ptr(size_t offset_begin) diff --git a/src/core/src/rpc_message.test.cpp b/src/core/src/rpc_message.test.cpp index 486d35d8c..da4a7e588 100644 --- a/src/core/src/rpc_message.test.cpp +++ b/src/core/src/rpc_message.test.cpp @@ -246,6 +246,77 @@ TEST(core, message_ex) } } +TEST(core, message_ex_rw_pairing) +{ + // write_next/write_commit and read_next/read_commit used to abort the + // process via dassert when they were called out of order. They now report + // the sequencing violation by returning false instead, so a faulty caller + // can recover. Exercise those pairing checks here. + const char* data = "adaoihfeuifgggggisdosghkbvjhzxvdafdiofgeof"; + size_t data_size = strlen(data); + + void* ptr = nullptr; + size_t sz = 0; + + { // write side + message_ex* request = message_ex::create_request(RPC_CODE_FOR_TEST, 100, 1); + request->add_ref(); + + // nothing is pending on a fresh message, so there is nothing to commit + ASSERT_FALSE(request->write_commit(data_size)); + + tls_trans_mem_alloc(1024); // reset tls buffer + ASSERT_TRUE(request->write_next(&ptr, &sz, data_size)); + ASSERT_NE(nullptr, ptr); + + // a second write_next before commit is rejected and clears the out-params + void* pending_ptr = reinterpret_cast(1); + size_t pending_sz = 1; + ASSERT_FALSE(request->write_next(&pending_ptr, &pending_sz, data_size)); + ASSERT_EQ(nullptr, pending_ptr); + ASSERT_EQ(0u, pending_sz); + + memcpy(ptr, data, data_size); + ASSERT_TRUE(request->write_commit(data_size)); + // committing again is rejected: there is no longer a pending write + ASSERT_FALSE(request->write_commit(data_size)); + + request->release_ref(); + } + + { // read side + message_ex* request = message_ex::create_request(RPC_CODE_FOR_TEST, 100, 1); + request->add_ref(); + + ASSERT_TRUE(request->write_next(&ptr, &sz, data_size)); + memcpy(ptr, data, data_size); + ASSERT_TRUE(request->write_commit(data_size)); + + message_ex* receive = message_ex::create_receive_message(request->buffers[0]); + receive->add_ref(); + + // no buffer is under read yet, so there is nothing to commit + ASSERT_FALSE(receive->read_commit(data_size)); + + ASSERT_TRUE(receive->read_next(&ptr, &sz)); + ASSERT_EQ(data_size, sz); + + // a second read_next before commit is rejected and clears the out-params + void* pending_ptr = reinterpret_cast(1); + size_t pending_sz = 1; + ASSERT_FALSE(receive->read_next(&pending_ptr, &pending_sz)); + ASSERT_EQ(nullptr, pending_ptr); + ASSERT_EQ(0u, pending_sz); + + ASSERT_TRUE(receive->read_commit(sz)); + // committing again is rejected: there is no longer a pending read + ASSERT_FALSE(receive->read_commit(sz)); + + receive->release_ref(); + request->release_ref(); + } +} + TEST(core, dsn_msg_invalid_parameters) { void* ptr = reinterpret_cast(1); diff --git a/src/core/src/service_api_c.cpp b/src/core/src/service_api_c.cpp index 255b6d75f..002a47b75 100644 --- a/src/core/src/service_api_c.cpp +++ b/src/core/src/service_api_c.cpp @@ -57,6 +57,7 @@ # include "crc.h" # include "transient_memory.h" # include "library_utils.h" +# include "c_api_guard.h" # include # if defined(_WIN32) @@ -119,6 +120,7 @@ DSN_API const char* dsn_error_to_string(dsn_error_t err) DSN_API dsn_error_t dsn_error_from_string(const char* s, dsn_error_t default_err) { + DSN_C_GUARD_BEGIN if (s == nullptr || s[0] == '\0') { derror("dsn_error_from_string got null or empty string"); @@ -127,6 +129,7 @@ DSN_API dsn_error_t dsn_error_from_string(const char* s, dsn_error_t default_err auto r = error_code_mgr::instance().get_id(s); return r == -1 ? default_err : r; + DSN_C_GUARD_END(default_err) } DSN_API volatile int* dsn_task_queue_virtual_length_ptr( @@ -134,6 +137,7 @@ DSN_API volatile int* dsn_task_queue_virtual_length_ptr( int hash ) { + DSN_C_GUARD_BEGIN if (code < 0 || code == ::dsn::TASK_CODE_INVALID) { derror("dsn_task_queue_virtual_length_ptr got invalid code = %d", code); @@ -148,6 +152,7 @@ DSN_API volatile int* dsn_task_queue_virtual_length_ptr( } return node->computation()->get_task_queue_virtual_length_ptr(code, hash); + DSN_C_GUARD_END(nullptr) } // use ::dsn::threadpool_code2; for parsing purpose @@ -171,6 +176,7 @@ DSN_API const char* dsn_threadpool_code_to_string(dsn_threadpool_code_t pool_cod DSN_API dsn_threadpool_code_t dsn_threadpool_code_from_string(const char* s, dsn_threadpool_code_t default_code) { + DSN_C_GUARD_BEGIN if (s == nullptr || s[0] == '\0') { derror("dsn_threadpool_code_from_string got null or empty string"); @@ -179,6 +185,7 @@ DSN_API dsn_threadpool_code_t dsn_threadpool_code_from_string(const char* s, dsn auto r = ::dsn::utils::customized_id_mgr< ::dsn::threadpool_code2_>::instance().get_id(s); return r == -1 ? default_code : r; + DSN_C_GUARD_END(default_code) } DSN_API int dsn_threadpool_code_max() @@ -303,6 +310,7 @@ DSN_API const char* dsn_task_code_to_string(dsn_task_code_t code) DSN_API dsn_task_code_t dsn_task_code_from_string(const char* s, dsn_task_code_t default_code) { + DSN_C_GUARD_BEGIN if (s == nullptr || s[0] == '\0') { derror("dsn_task_code_from_string got null or empty string"); @@ -311,6 +319,7 @@ DSN_API dsn_task_code_t dsn_task_code_from_string(const char* s, dsn_task_code_t auto r = ::dsn::utils::customized_id_mgr::instance().get_id(s); return r == -1 ? default_code : r; + DSN_C_GUARD_END(default_code) } DSN_API int dsn_task_code_max() @@ -400,6 +409,7 @@ DSN_API uint64_t dsn_crc64_concatenate(uint32_t xy_init, uint64_t x_init, uint64 DSN_API dsn_task_t dsn_task_create(dsn_task_code_t code, dsn_task_handler_t cb, void* context, int hash, dsn_task_tracker_t tracker) { + DSN_C_GUARD_BEGIN auto sp = ::dsn::task_spec::get(code); if (code == ::dsn::TASK_CODE_INVALID || sp == nullptr || sp->type != TASK_TYPE_COMPUTE) { @@ -418,11 +428,13 @@ DSN_API dsn_task_t dsn_task_create(dsn_task_code_t code, dsn_task_handler_t cb, t->set_tracker((dsn::task_tracker*)tracker); t->spec().on_task_create.execute(::dsn::task::get_current_task(), t); return t; + DSN_C_GUARD_END(nullptr) } DSN_API dsn_task_t dsn_task_create_timer(dsn_task_code_t code, dsn_task_handler_t cb, void* context, int hash, int interval_milliseconds, dsn_task_tracker_t tracker) { + DSN_C_GUARD_BEGIN auto sp = ::dsn::task_spec::get(code); if (code == ::dsn::TASK_CODE_INVALID || sp == nullptr || sp->type != TASK_TYPE_COMPUTE) { @@ -446,11 +458,13 @@ DSN_API dsn_task_t dsn_task_create_timer(dsn_task_code_t code, dsn_task_handler_ t->set_tracker((dsn::task_tracker*)tracker); t->spec().on_task_create.execute(::dsn::task::get_current_task(), t); return t; + DSN_C_GUARD_END(nullptr) } DSN_API dsn_task_t dsn_task_create_ex(dsn_task_code_t code, dsn_task_handler_t cb, dsn_task_cancelled_handler_t on_cancel, void* context, int hash, dsn_task_tracker_t tracker) { + DSN_C_GUARD_BEGIN auto sp = ::dsn::task_spec::get(code); if (code == ::dsn::TASK_CODE_INVALID || sp == nullptr || sp->type != TASK_TYPE_COMPUTE) { @@ -468,12 +482,14 @@ DSN_API dsn_task_t dsn_task_create_ex(dsn_task_code_t code, dsn_task_handler_t c t->set_tracker((dsn::task_tracker*)tracker); t->spec().on_task_create.execute(::dsn::task::get_current_task(), t); return t; + DSN_C_GUARD_END(nullptr) } DSN_API dsn_task_t dsn_task_create_timer_ex(dsn_task_code_t code, dsn_task_handler_t cb, dsn_task_cancelled_handler_t on_cancel, void* context, int hash, int interval_milliseconds, dsn_task_tracker_t tracker) { + DSN_C_GUARD_BEGIN auto sp = ::dsn::task_spec::get(code); if (code == ::dsn::TASK_CODE_INVALID || sp == nullptr || sp->type != TASK_TYPE_COMPUTE) { @@ -498,10 +514,12 @@ DSN_API dsn_task_t dsn_task_create_timer_ex(dsn_task_code_t code, dsn_task_handl t->set_tracker((dsn::task_tracker*)tracker); t->spec().on_task_create.execute(::dsn::task::get_current_task(), t); return t; + DSN_C_GUARD_END(nullptr) } DSN_API dsn_task_tracker_t dsn_task_tracker_create(int task_bucket_count) { + DSN_C_GUARD_BEGIN if (task_bucket_count <= 0) { derror("dsn_task_tracker_create got invalid task_bucket_count = %d", task_bucket_count); @@ -509,6 +527,7 @@ DSN_API dsn_task_tracker_t dsn_task_tracker_create(int task_bucket_count) } return (dsn_task_tracker_t)(new ::dsn::task_tracker(task_bucket_count)); + DSN_C_GUARD_END(nullptr) } DSN_API void dsn_task_tracker_destroy(dsn_task_tracker_t tracker) @@ -546,6 +565,7 @@ DSN_API void dsn_task_tracker_wait_all(dsn_task_tracker_t tracker) DSN_API bool dsn_task_call(dsn_task_t task, int delay_milliseconds) { + DSN_C_GUARD_BEGIN if (task == nullptr) { derror("dsn_task_call got null task"); @@ -568,6 +588,7 @@ DSN_API bool dsn_task_call(dsn_task_t task, int delay_milliseconds) t->set_delay(delay_milliseconds); t->enqueue(); return true; + DSN_C_GUARD_END(false) } DSN_API void dsn_task_add_ref(dsn_task_t task) @@ -666,8 +687,11 @@ DSN_API void dsn_task_wait(dsn_task_t task) auto t = (::dsn::task*)task; auto r = t->wait(); - dassert(r, - "task wait without timeout must succeeds (task_id = %016" PRIx64 ")", + // dsn_task_wait returns void, so it has no channel to report a failed wait; + // per design it keeps the dassert. For an infinite (no-timeout) wait the + // only way r is false is a self-wait, which task::wait already logged. + dassert(r, + "dsn_task_wait failed: a task must not wait on itself (task_id = %016" PRIx64 ")", t->id() ); } @@ -699,34 +723,52 @@ DSN_API dsn_error_t dsn_task_error(dsn_task_t task) // synchronization - concurrent access and coordination among threads // //------------------------------------------------------------------------------ -DSN_API dsn_handle_t dsn_exlock_create(bool recursive) +namespace { + +template +::dsn::ilock* create_lock_chain(const char* api_name, + const char* factory_name, + const TAspects& aspects) { - if (recursive) + TProvider* last = ::dsn::utils::factory_store::create( + factory_name, ::dsn::PROVIDER_TYPE_MAIN, nullptr); + if (last == nullptr) { - ::dsn::lock_provider* last = ::dsn::utils::factory_store< ::dsn::lock_provider>::create( - ::dsn::service_engine::fast_instance().spec().lock_factory_name.c_str(), ::dsn::PROVIDER_TYPE_MAIN, nullptr); + derror("%s got null provider factory result for '%s'", api_name, factory_name); + return nullptr; + } - // TODO: perf opt by saving the func ptrs somewhere - for (auto& s : ::dsn::service_engine::fast_instance().spec().lock_aspects) + // TODO: perf opt by saving the func ptrs somewhere + for (auto& s : aspects) + { + TProvider* next = ::dsn::utils::factory_store::create( + s.c_str(), ::dsn::PROVIDER_TYPE_ASPECT, last); + if (next == nullptr) { - last = ::dsn::utils::factory_store< ::dsn::lock_provider>::create(s.c_str(), ::dsn::PROVIDER_TYPE_ASPECT, last); + derror("%s got null aspect factory result for '%s'", api_name, s.c_str()); + // Provider destructors own their inner_provider, so deleting the head + // releases the whole chain built so far. + delete last; + return nullptr; } - - return (dsn_handle_t)dynamic_cast< ::dsn::ilock*>(last); + last = next; } - else - { - ::dsn::lock_nr_provider* last = ::dsn::utils::factory_store< ::dsn::lock_nr_provider>::create( - ::dsn::service_engine::fast_instance().spec().lock_nr_factory_name.c_str(), ::dsn::PROVIDER_TYPE_MAIN, nullptr); - // TODO: perf opt by saving the func ptrs somewhere - for (auto& s : ::dsn::service_engine::fast_instance().spec().lock_nr_aspects) - { - last = ::dsn::utils::factory_store< ::dsn::lock_nr_provider>::create(s.c_str(), ::dsn::PROVIDER_TYPE_ASPECT, last); - } + return static_cast< ::dsn::ilock*>(last); +} - return (dsn_handle_t)dynamic_cast< ::dsn::ilock*>(last); - } +} // anonymous namespace + +DSN_API dsn_handle_t dsn_exlock_create(bool recursive) +{ + DSN_C_GUARD_BEGIN + const auto& spec = ::dsn::service_engine::fast_instance().spec(); + return recursive + ? (dsn_handle_t)create_lock_chain< ::dsn::lock_provider>( + __FUNCTION__, spec.lock_factory_name.c_str(), spec.lock_aspects) + : (dsn_handle_t)create_lock_chain< ::dsn::lock_nr_provider>( + __FUNCTION__, spec.lock_nr_factory_name.c_str(), spec.lock_nr_aspects); + DSN_C_GUARD_END(nullptr) } DSN_API void dsn_exlock_destroy(dsn_handle_t l) @@ -783,15 +825,34 @@ DSN_API void dsn_exlock_unlock(dsn_handle_t l) // non-recursive rwlock DSN_API dsn_handle_t dsn_rwlock_nr_create() { + DSN_C_GUARD_BEGIN + const auto& spec = ::dsn::service_engine::fast_instance().spec(); ::dsn::rwlock_nr_provider* last = ::dsn::utils::factory_store< ::dsn::rwlock_nr_provider>::create( - ::dsn::service_engine::fast_instance().spec().rwlock_nr_factory_name.c_str(), ::dsn::PROVIDER_TYPE_MAIN, nullptr); + spec.rwlock_nr_factory_name.c_str(), ::dsn::PROVIDER_TYPE_MAIN, nullptr); + if (last == nullptr) + { + derror("dsn_rwlock_nr_create got null provider factory result for '%s'", + spec.rwlock_nr_factory_name.c_str()); + return nullptr; + } // TODO: perf opt by saving the func ptrs somewhere - for (auto& s : ::dsn::service_engine::fast_instance().spec().rwlock_nr_aspects) + for (auto& s : spec.rwlock_nr_aspects) { - last = ::dsn::utils::factory_store< ::dsn::rwlock_nr_provider>::create(s.c_str(), ::dsn::PROVIDER_TYPE_ASPECT, last); + ::dsn::rwlock_nr_provider* next = ::dsn::utils::factory_store< ::dsn::rwlock_nr_provider>::create( + s.c_str(), ::dsn::PROVIDER_TYPE_ASPECT, last); + if (next == nullptr) + { + derror("dsn_rwlock_nr_create got null aspect factory result for '%s'", s.c_str()); + // Provider destructors own their inner_provider, so deleting the head + // releases the whole chain built so far. + delete last; + return nullptr; + } + last = next; } return (dsn_handle_t)(last); + DSN_C_GUARD_END(nullptr) } DSN_API void dsn_rwlock_nr_destroy(dsn_handle_t l) @@ -881,22 +942,40 @@ DSN_API bool dsn_rwlock_nr_try_lock_write(dsn_handle_t l) DSN_API dsn_handle_t dsn_semaphore_create(int initial_count) { + DSN_C_GUARD_BEGIN if (initial_count < 0) { derror("dsn_semaphore_create got invalid initial_count = %d", initial_count); return nullptr; } + const auto& spec = ::dsn::service_engine::fast_instance().spec(); ::dsn::semaphore_provider* last = ::dsn::utils::factory_store< ::dsn::semaphore_provider>::create( - ::dsn::service_engine::fast_instance().spec().semaphore_factory_name.c_str(), ::dsn::PROVIDER_TYPE_MAIN, initial_count, nullptr); + spec.semaphore_factory_name.c_str(), ::dsn::PROVIDER_TYPE_MAIN, initial_count, nullptr); + if (last == nullptr) + { + derror("dsn_semaphore_create got null provider factory result for '%s'", + spec.semaphore_factory_name.c_str()); + return nullptr; + } // TODO: perf opt by saving the func ptrs somewhere - for (auto& s : ::dsn::service_engine::fast_instance().spec().semaphore_aspects) + for (auto& s : spec.semaphore_aspects) { - last = ::dsn::utils::factory_store< ::dsn::semaphore_provider>::create( + ::dsn::semaphore_provider* next = ::dsn::utils::factory_store< ::dsn::semaphore_provider>::create( s.c_str(), ::dsn::PROVIDER_TYPE_ASPECT, initial_count, last); + if (next == nullptr) + { + derror("dsn_semaphore_create got null aspect factory result for '%s'", s.c_str()); + // Provider destructors own their inner_provider, so deleting the head + // releases the whole chain built so far. + delete last; + return nullptr; + } + last = next; } return (dsn_handle_t)(last); + DSN_C_GUARD_END(nullptr) } DSN_API void dsn_semaphore_destroy(dsn_handle_t s) @@ -977,6 +1056,7 @@ DSN_API bool dsn_rpc_register_handler( dsn_gpid gpid ) { + DSN_C_GUARD_BEGIN auto sp = ::dsn::task_spec::get(code); if (code == ::dsn::TASK_CODE_INVALID || sp == nullptr || sp->type != TASK_TYPE_RPC_REQUEST) { @@ -1009,10 +1089,12 @@ DSN_API bool dsn_rpc_register_handler( delete h; } return r; + DSN_C_GUARD_END(false) } DSN_API void* dsn_rpc_unregiser_handler(dsn_task_code_t code, dsn_gpid gpid) { + DSN_C_GUARD_BEGIN auto sp = ::dsn::task_spec::get(code); if (code == ::dsn::TASK_CODE_INVALID || sp == nullptr || sp->type != TASK_TYPE_RPC_REQUEST) { @@ -1031,11 +1113,13 @@ DSN_API void* dsn_rpc_unregiser_handler(dsn_task_code_t code, dsn_gpid gpid) } return param; + DSN_C_GUARD_END(nullptr) } DSN_API dsn_task_t dsn_rpc_create_response_task(dsn_message_t request, dsn_rpc_response_handler_t cb, void* context, int reply_thread_hash, dsn_task_tracker_t tracker) { + DSN_C_GUARD_BEGIN auto msg = ((::dsn::message_ex*)request); if (msg == nullptr) { @@ -1047,12 +1131,14 @@ DSN_API dsn_task_t dsn_rpc_create_response_task(dsn_message_t request, dsn_rpc_r t->set_tracker((dsn::task_tracker*)tracker); t->spec().on_task_create.execute(::dsn::task::get_current_task(), t); return t; + DSN_C_GUARD_END(nullptr) } DSN_API dsn_task_t dsn_rpc_create_response_task_ex(dsn_message_t request, dsn_rpc_response_handler_t cb, dsn_task_cancelled_handler_t on_cancel, void* context, int reply_thread_hash, dsn_task_tracker_t tracker) { + DSN_C_GUARD_BEGIN auto msg = ((::dsn::message_ex*)request); if (msg == nullptr) { @@ -1064,10 +1150,12 @@ DSN_API dsn_task_t dsn_rpc_create_response_task_ex(dsn_message_t request, dsn_rp t->set_tracker((dsn::task_tracker*)tracker); t->spec().on_task_create.execute(::dsn::task::get_current_task(), t); return t; + DSN_C_GUARD_END(nullptr) } DSN_API dsn_error_t dsn_rpc_call(dsn_address_t server, dsn_task_t rpc_call) { + DSN_C_GUARD_BEGIN ::dsn::rpc_response_task* task = (::dsn::rpc_response_task*)rpc_call; if (task == nullptr) { @@ -1104,10 +1192,12 @@ DSN_API dsn_error_t dsn_rpc_call(dsn_address_t server, dsn_task_t rpc_call) msg->server_address = server; rpc->call(msg, task); return ::dsn::ERR_OK.get(); + DSN_C_GUARD_END(::dsn::ERR_UNKNOWN.get()) } DSN_API dsn_message_t dsn_rpc_call_wait(dsn_address_t server, dsn_message_t request) { + DSN_C_GUARD_BEGIN auto msg = ((::dsn::message_ex*)request); if (msg == nullptr) { @@ -1146,10 +1236,12 @@ DSN_API dsn_message_t dsn_rpc_call_wait(dsn_address_t server, dsn_message_t requ rtask->release_ref(); // added above return nullptr; } + DSN_C_GUARD_END(nullptr) } DSN_API dsn_error_t dsn_rpc_call_one_way(dsn_address_t server, dsn_message_t request) { + DSN_C_GUARD_BEGIN auto msg = ((::dsn::message_ex*)request); if (msg == nullptr) { @@ -1174,10 +1266,12 @@ DSN_API dsn_error_t dsn_rpc_call_one_way(dsn_address_t server, dsn_message_t req rpc->call(msg, nullptr); return ::dsn::ERR_OK.get(); + DSN_C_GUARD_END(::dsn::ERR_UNKNOWN.get()) } DSN_API dsn_error_t dsn_rpc_reply(dsn_message_t response, dsn_error_t err) { + DSN_C_GUARD_BEGIN auto msg = ((::dsn::message_ex*)response); if (msg == nullptr) { @@ -1194,10 +1288,12 @@ DSN_API dsn_error_t dsn_rpc_reply(dsn_message_t response, dsn_error_t err) rpc->reply(msg, err); return ::dsn::ERR_OK.get(); + DSN_C_GUARD_END(::dsn::ERR_UNKNOWN.get()) } DSN_API dsn_error_t dsn_rpc_forward(dsn_message_t request, dsn_address_t addr) { + DSN_C_GUARD_BEGIN auto msg = (::dsn::message_ex*)(request); if (msg == nullptr) { @@ -1221,6 +1317,7 @@ DSN_API dsn_error_t dsn_rpc_forward(dsn_message_t request, dsn_address_t addr) rpc->forward(msg, target); return ::dsn::ERR_OK.get(); + DSN_C_GUARD_END(::dsn::ERR_UNKNOWN.get()) } DSN_API dsn_message_t dsn_rpc_get_response(dsn_task_t rpc_call) @@ -1250,6 +1347,7 @@ DSN_API dsn_message_t dsn_rpc_get_response(dsn_task_t rpc_call) DSN_API dsn_error_t dsn_rpc_enqueue_response(dsn_task_t rpc_call, dsn_error_t err, dsn_message_t response) { + DSN_C_GUARD_BEGIN ::dsn::rpc_response_task* task = (::dsn::rpc_response_task*)rpc_call; if (task == nullptr) { @@ -1266,6 +1364,7 @@ DSN_API dsn_error_t dsn_rpc_enqueue_response(dsn_task_t rpc_call, dsn_error_t er auto resp = ((::dsn::message_ex*)response); task->enqueue(err, resp); return ::dsn::ERR_OK.get(); + DSN_C_GUARD_END(::dsn::ERR_UNKNOWN.get()) } //------------------------------------------------------------------------------ @@ -1276,6 +1375,7 @@ DSN_API dsn_error_t dsn_rpc_enqueue_response(dsn_task_t rpc_call, dsn_error_t er DSN_API dsn_handle_t dsn_file_open(const char* file_name, int flag, int pmode) { + DSN_C_GUARD_BEGIN if (file_name == nullptr || file_name[0] == '\0') { derror("dsn_file_open got null or empty file_name"); @@ -1290,10 +1390,12 @@ DSN_API dsn_handle_t dsn_file_open(const char* file_name, int flag, int pmode) } return disk->open(file_name, flag, pmode); + DSN_C_GUARD_END(nullptr) } DSN_API dsn_error_t dsn_file_close(dsn_handle_t file) { + DSN_C_GUARD_BEGIN if (file == nullptr) { derror("dsn_file_close got null file handle"); @@ -1308,10 +1410,12 @@ DSN_API dsn_error_t dsn_file_close(dsn_handle_t file) } return disk->close(file).get(); + DSN_C_GUARD_END(::dsn::ERR_UNKNOWN.get()) } DSN_API dsn_error_t dsn_file_flush(dsn_handle_t file) { + DSN_C_GUARD_BEGIN if (file == nullptr) { derror("dsn_file_flush got null file handle"); @@ -1326,6 +1430,7 @@ DSN_API dsn_error_t dsn_file_flush(dsn_handle_t file) } return disk->flush(file).get(); + DSN_C_GUARD_END(::dsn::ERR_UNKNOWN.get()) } // native HANDLE: HANDLE for windows, int for non-windows @@ -1343,6 +1448,7 @@ DSN_API void* dsn_file_native_handle(dsn_handle_t file) DSN_API dsn_task_t dsn_file_create_aio_task(dsn_task_code_t code, dsn_aio_handler_t cb, void* context, int hash, dsn_task_tracker_t tracker) { + DSN_C_GUARD_BEGIN auto sp = ::dsn::task_spec::get(code); if (code == ::dsn::TASK_CODE_INVALID || sp == nullptr || sp->type != TASK_TYPE_AIO) { @@ -1354,12 +1460,14 @@ DSN_API dsn_task_t dsn_file_create_aio_task(dsn_task_code_t code, dsn_aio_handle t->set_tracker((dsn::task_tracker*)tracker); t->spec().on_task_create.execute(::dsn::task::get_current_task(), t); return t; + DSN_C_GUARD_END(nullptr) } DSN_API dsn_task_t dsn_file_create_aio_task_ex(dsn_task_code_t code, dsn_aio_handler_t cb, dsn_task_cancelled_handler_t on_cancel, void* context, int hash, dsn_task_tracker_t tracker) { + DSN_C_GUARD_BEGIN auto sp = ::dsn::task_spec::get(code); if (code == ::dsn::TASK_CODE_INVALID || sp == nullptr || sp->type != TASK_TYPE_AIO) { @@ -1371,10 +1479,12 @@ DSN_API dsn_task_t dsn_file_create_aio_task_ex(dsn_task_code_t code, dsn_aio_han t->set_tracker((dsn::task_tracker*)tracker); t->spec().on_task_create.execute(::dsn::task::get_current_task(), t); return t; + DSN_C_GUARD_END(nullptr) } DSN_API dsn_error_t dsn_file_read(dsn_handle_t file, char* buffer, int count, uint64_t offset, dsn_task_t cb) { + DSN_C_GUARD_BEGIN if (file == nullptr) { derror("dsn_file_read got null file handle"); @@ -1416,10 +1526,12 @@ DSN_API dsn_error_t dsn_file_read(dsn_handle_t file, char* buffer, int count, ui disk->read(callback); return ::dsn::ERR_OK.get(); + DSN_C_GUARD_END(::dsn::ERR_UNKNOWN.get()) } DSN_API dsn_error_t dsn_file_write(dsn_handle_t file, const char* buffer, int count, uint64_t offset, dsn_task_t cb) { + DSN_C_GUARD_BEGIN if (file == nullptr) { derror("dsn_file_write got null file handle"); @@ -1461,10 +1573,12 @@ DSN_API dsn_error_t dsn_file_write(dsn_handle_t file, const char* buffer, int co disk->write(callback); return ::dsn::ERR_OK.get(); + DSN_C_GUARD_END(::dsn::ERR_UNKNOWN.get()) } DSN_API dsn_error_t dsn_file_write_vector(dsn_handle_t file, const dsn_file_buffer_t* buffers, int buffer_count, uint64_t offset, dsn_task_t cb) { + DSN_C_GUARD_BEGIN if (file == nullptr) { derror("dsn_file_write_vector got null file handle"); @@ -1528,11 +1642,13 @@ DSN_API dsn_error_t dsn_file_write_vector(dsn_handle_t file, const dsn_file_buff disk->write(callback); return ::dsn::ERR_OK.get(); + DSN_C_GUARD_END(::dsn::ERR_UNKNOWN.get()) } DSN_API dsn_error_t dsn_file_copy_remote_directory(dsn_address_t remote, const char* source_dir, const char* dest_dir, bool overwrite, dsn_task_t cb) { + DSN_C_GUARD_BEGIN if (::dsn::rpc_address(remote).is_invalid()) { derror("dsn_file_copy_remote_directory got invalid remote address"); @@ -1574,10 +1690,12 @@ DSN_API dsn_error_t dsn_file_copy_remote_directory(dsn_address_t remote, const c nfs->call(rci, callback); return ::dsn::ERR_OK.get(); + DSN_C_GUARD_END(::dsn::ERR_UNKNOWN.get()) } DSN_API dsn_error_t dsn_file_copy_remote_files(dsn_address_t remote, const char* source_dir, const char** source_files, const char* dest_dir, bool overwrite, dsn_task_t cb) { + DSN_C_GUARD_BEGIN if (::dsn::rpc_address(remote).is_invalid()) { derror("dsn_file_copy_remote_files got invalid remote address"); @@ -1638,6 +1756,7 @@ DSN_API dsn_error_t dsn_file_copy_remote_files(dsn_address_t remote, const char* nfs->call(rci, callback); return ::dsn::ERR_OK.get(); + DSN_C_GUARD_END(::dsn::ERR_UNKNOWN.get()) } DSN_API size_t dsn_file_get_io_size(dsn_task_t cb_task) @@ -1660,6 +1779,7 @@ DSN_API size_t dsn_file_get_io_size(dsn_task_t cb_task) DSN_API dsn_error_t dsn_file_task_enqueue(dsn_task_t cb_task, dsn_error_t err, size_t size) { + DSN_C_GUARD_BEGIN ::dsn::task* task = (::dsn::task*)cb_task; if (task == nullptr) { @@ -1675,6 +1795,7 @@ DSN_API dsn_error_t dsn_file_task_enqueue(dsn_task_t cb_task, dsn_error_t err, s ((::dsn::aio_task*)task)->enqueue(err, size); return ::dsn::ERR_OK.get(); + DSN_C_GUARD_END(::dsn::ERR_UNKNOWN.get()) } //------------------------------------------------------------------------------ @@ -1796,6 +1917,7 @@ NORETURN DSN_API void dsn_exit(int code) DSN_API bool dsn_mimic_app(const char* app_name, int index) { + DSN_C_GUARD_BEGIN if (app_name == nullptr || app_name[0] == '\0') { derror("dsn_mimic_app got null or empty app_name"); @@ -1844,10 +1966,12 @@ DSN_API bool dsn_mimic_app(const char* app_name, int index) derror("cannot find host app %s with index %d", app_name, index); return false; + DSN_C_GUARD_END(false) } DSN_API const char* dsn_get_app_data_dir(dsn_gpid gpid) { + DSN_C_GUARD_BEGIN if (gpid.value != 0 && (gpid.u.app_id <= 0 || gpid.u.partition_index < 0)) { derror("dsn_get_app_data_dir got invalid gpid = %d.%d", @@ -1858,6 +1982,7 @@ DSN_API const char* dsn_get_app_data_dir(dsn_gpid gpid) auto info = dsn_get_app_info_ptr(gpid); return info ? info->data_dir : nullptr; + DSN_C_GUARD_END(nullptr) } DSN_API bool dsn_get_current_app_info(/*out*/ dsn_app_info* app_info) @@ -1880,6 +2005,7 @@ DSN_API bool dsn_get_current_app_info(/*out*/ dsn_app_info* app_info) DSN_API dsn_app_info* dsn_get_app_info_ptr(dsn_gpid gpid) { + DSN_C_GUARD_BEGIN if (gpid.value != 0 && (gpid.u.app_id <= 0 || gpid.u.partition_index < 0)) { derror("dsn_get_app_info_ptr got invalid gpid = %d.%d", @@ -1900,6 +2026,7 @@ DSN_API dsn_app_info* dsn_get_app_info_ptr(dsn_gpid gpid) } else return nullptr; + DSN_C_GUARD_END(nullptr) } ::dsn::utils::notify_event s_loader_event; diff --git a/src/core/src/service_api_c.test.cpp b/src/core/src/service_api_c.test.cpp index 7ad988383..0ab42f735 100644 --- a/src/core/src/service_api_c.test.cpp +++ b/src/core/src/service_api_c.test.cpp @@ -38,6 +38,7 @@ # include # include # include +# include # include # include "service_engine.h" # include @@ -242,6 +243,25 @@ TEST(core, dsn_config) ASSERT_STREQ("count", buffers[0]); } +TEST(core, dsn_config_get_all_keys_reports_total_count) +{ + const std::string section = "core.test.large_key_count"; + for (int i = 0; i < 140; ++i) + { + const std::string key = "key." + std::to_string(i); + get_main_config()->set(section.c_str(), key.c_str(), "value", ""); + } + + const char* buffers[128]; + int buffer_count = 128; + ASSERT_EQ(140, dsn_config_get_all_keys(section.c_str(), buffers, &buffer_count)); + ASSERT_EQ(128, buffer_count); + + buffer_count = 0; + ASSERT_EQ(140, dsn_config_get_all_keys(section.c_str(), nullptr, &buffer_count)); + ASSERT_EQ(0, buffer_count); +} + TEST(core, dsn_config_invalid_numeric_values) { get_main_config()->set("core.invalid_config", "bad_uint64", "12bad", ""); diff --git a/src/core/src/task.cpp b/src/core/src/task.cpp index 89bf08ba1..0f05f0457 100644 --- a/src/core/src/task.cpp +++ b/src/core/src/task.cpp @@ -283,7 +283,16 @@ void task::signal_waiters() // multiple callers may wait on this bool task::wait(int timeout_milliseconds, bool on_cancel) { - dassert (this != task::get_current_task(), "task cannot wait itself"); + // a task waiting on itself would deadlock; this is a caller misuse rather + // than an internal invariant, so report it through the bool return instead + // of aborting the whole process. dsn_task_wait_timeout (bool) propagates + // this to the caller; the void dsn_task_wait keeps its own dassert because + // it has no error channel. + if (this == task::get_current_task()) + { + derror("task %016llx cannot wait itself", static_cast(id())); + return false; + } auto cs = state(); if (!on_cancel) diff --git a/src/core/src/test/test.config.core.fj.ini b/src/core/src/test/test.config.core.fj.ini index 79589eaea..a6f3094de 100644 --- a/src/core/src/test/test.config.core.fj.ini +++ b/src/core/src/test/test.config.core.fj.ini @@ -96,7 +96,7 @@ rpc_message_crc_required = true rpc_request_drop_ratio = 0 rpc_timeout_milliseconds = 1000 rpc_request_data_corrupted_ratio = 1 -rpc_message_data_corrupted_type = header +rpc_message_data_corrupted_type = header_rpc [task.RPC_TEST_HASH2] is_trace = true @@ -112,7 +112,7 @@ rpc_message_crc_required = true rpc_response_drop_ratio = 0 rpc_timeout_milliseconds = 1000 rpc_response_data_corrupted_ratio = 1 -rpc_message_data_corrupted_type = header +rpc_message_data_corrupted_type = header_id [task.RPC_TEST_HASH4_ACK] is_trace = true @@ -122,6 +122,22 @@ rpc_timeout_milliseconds = 1000 rpc_response_data_corrupted_ratio = 1 rpc_message_data_corrupted_type = body +[task.RPC_TEST_HASH5] +is_trace = true +rpc_message_crc_required = true +rpc_request_drop_ratio = 0 +rpc_timeout_milliseconds = 1000 +rpc_request_data_corrupted_ratio = 1 +rpc_message_data_corrupted_type = header + +[task.RPC_TEST_HASH6_ACK] +is_trace = true +rpc_message_crc_required = true +rpc_response_drop_ratio = 0 +rpc_timeout_milliseconds = 1000 +rpc_response_data_corrupted_ratio = 1 +rpc_message_data_corrupted_type = header + [task.LPC_AIO_IMMEDIATE_CALLBACK] is_trace = false is_profile = false diff --git a/src/plugins/tools.common/asio_net_provider.cpp b/src/plugins/tools.common/asio_net_provider.cpp index dcf862786..07af1fb19 100644 --- a/src/plugins/tools.common/asio_net_provider.cpp +++ b/src/plugins/tools.common/asio_net_provider.cpp @@ -215,13 +215,17 @@ namespace dsn { return _parsers[hdr_format]; } - void asio_udp_provider::do_receive() + bool asio_udp_provider::do_receive() { std::shared_ptr< ::boost::asio::ip::udp::endpoint> send_endpoint(new ::boost::asio::ip::udp::endpoint); _recv_reader.truncate_read(); auto buffer_ptr = _recv_reader.read_buffer_ptr(max_udp_packet_size); - dassert(_recv_reader.read_buffer_capacity() >= max_udp_packet_size, "failed to load enough buffer in parser"); + if (buffer_ptr == nullptr || _recv_reader.read_buffer_capacity() < max_udp_packet_size) + { + derror("%s: asio udp read failed: unable to prepare read buffer", _address.to_string()); + return false; + } _socket->async_receive_from( ::boost::asio::buffer(buffer_ptr, max_udp_packet_size), @@ -281,6 +285,8 @@ namespace dsn { do_receive(); } ); + + return true; } error_code asio_udp_provider::start(rpc_channel channel, int port, bool client_only, io_modifer& ctx) @@ -370,7 +376,10 @@ namespace dsn { return ERR_NETWORK_START_FAILED; } - do_receive(); + if (!do_receive()) + { + return ERR_NETWORK_START_FAILED; + } return ERR_OK; } diff --git a/src/plugins/tools.common/asio_net_provider.h b/src/plugins/tools.common/asio_net_provider.h index a4d1f36e6..d807a50d2 100644 --- a/src/plugins/tools.common/asio_net_provider.h +++ b/src/plugins/tools.common/asio_net_provider.h @@ -85,7 +85,7 @@ namespace dsn { } private: - void do_receive(); + bool do_receive(); // create parser on demand message_parser* get_message_parser(network_header_format hdr_format); diff --git a/src/plugins/tools.common/asio_rpc_session.cpp b/src/plugins/tools.common/asio_rpc_session.cpp index 95bc23fef..2640caa0f 100644 --- a/src/plugins/tools.common/asio_rpc_session.cpp +++ b/src/plugins/tools.common/asio_rpc_session.cpp @@ -100,6 +100,13 @@ namespace dsn { void* ptr = _reader.read_buffer_ptr(read_next); int remaining = _reader.read_buffer_capacity(); + if (ptr == nullptr || remaining <= 0) + { + derror("asio read from %s failed: unable to prepare read buffer", _remote_addr.to_string()); + on_failure(); + release_ref(); + return; + } _socket->async_read_some(boost::asio::buffer(ptr, remaining), [this](boost::system::error_code ec, std::size_t length) diff --git a/src/plugins/tools.common/dsn_message_parser.test.cpp b/src/plugins/tools.common/dsn_message_parser.test.cpp index 887bfd3c5..5edcf2aa6 100644 --- a/src/plugins/tools.common/dsn_message_parser.test.cpp +++ b/src/plugins/tools.common/dsn_message_parser.test.cpp @@ -27,6 +27,7 @@ #include "dsn_message_parser.h" #include #include +#include using namespace dsn; @@ -75,3 +76,15 @@ TEST(tools_common, dsn_message_parser_rejects_non_terminated_names) ASSERT_EQ(nullptr, parser.get_message_on_receive(&reader2, read_next)); ASSERT_EQ(-1, read_next); } + +TEST(tools_common, message_reader_rejects_overflow_read_size) +{ + message_reader reader(16); + ASSERT_NE(nullptr, reader.read_buffer_ptr(8)); + reader.mark_read(8); + + ASSERT_EQ(nullptr, + reader.read_buffer_ptr(std::numeric_limits::max() - 7)); + ASSERT_EQ(8u, reader._buffer_occupied); + ASSERT_EQ(8u, reader.read_buffer_capacity()); +} diff --git a/src/plugins/tools.common/fault_injector.cpp b/src/plugins/tools.common/fault_injector.cpp index 0a2db4d17..94498eefd 100644 --- a/src/plugins/tools.common/fault_injector.cpp +++ b/src/plugins/tools.common/fault_injector.cpp @@ -82,7 +82,7 @@ namespace dsn { CONFIG_FLD(double, double, rpc_request_data_corrupted_ratio, 0, "data corrupted ratio for rpc request message") CONFIG_FLD(double, double, rpc_response_data_corrupted_ratio, 0, "data corrupted ratio for rpc response message") - CONFIG_FLD_STRING(rpc_message_data_corrupted_type, "random", "data corrupted type: random/header/body") + CONFIG_FLD_STRING(rpc_message_data_corrupted_type, "random", "data corrupted type: random/header/header_rpc/header_id/body") CONFIG_FLD(double, double, rpc_request_drop_ratio, 0, "drop ratio for rpc request messages") CONFIG_FLD(double, double, rpc_response_drop_ratio, 0, "drop ratio for rpc response messages") @@ -200,6 +200,69 @@ namespace dsn { } } + static uint32_t compute_body_crc32(message_ex* request) + { + auto& buffers = request->buffers; + int i_max = (int)buffers.size() - 1; + uint32_t crc32 = 0; + size_t len = 0; + for (int i = 0; i <= i_max; i++) + { + uint32_t lcrc; + const void* ptr; + size_t sz; + + if (i == 0) + { + ptr = (const void*)(buffers[i].data() + sizeof(message_header)); + sz = (size_t)buffers[i].length() - sizeof(message_header); + } + else + { + ptr = (const void*)buffers[i].data(); + sz = (size_t)buffers[i].length(); + } + + lcrc = dsn_crc32_compute(ptr, sz, crc32); + crc32 = dsn_crc32_concatenate(0, 0, crc32, len, crc32, lcrc, sz); + len += sz; + } + + dassert(len == (size_t)request->header->body_length, "data length is wrong"); + return crc32; + } + + static void invalidate_body_crc32(message_ex* request) + { + task_spec* spec = task_spec::get(request->local_rpc_code); + if (spec != nullptr && spec->rpc_message_crc_required) + { + uint32_t crc32 = compute_body_crc32(request); + request->header->body_crc32 = (crc32 == 1 ? 2 : 1); + } + } + + static void corrupt_body(message_ex* request) + { + if (request->body_size() == 0) + { + dwarn("skip body data corruption for empty message body"); + return; + } + + replace_value(request->buffers, + dsn_random32(0, request->body_size() - 1) + sizeof(message_header)); + invalidate_body_crc32(request); + } + + static void corrupt_rpc_lookup_header(message_ex* request) + { + replace_value(request->buffers, static_cast(offsetof(message_header, rpc_name))); + replace_value(request->buffers, + static_cast(offsetof(message_header, rpc_code) + + offsetof(fast_code, local_hash))); + } + static void corrupt_data(message_ex* request, const std::string& corrupt_type) { if (corrupt_type == "header") @@ -213,23 +276,21 @@ namespace dsn { header_mutable_offset + dsn_random32(0, header_mutable_size - 1)); } + else if (corrupt_type == "header_id") + { + replace_value(request->buffers, static_cast(offsetof(message_header, id))); + } + else if (corrupt_type == "header_rpc") + { + corrupt_rpc_lookup_header(request); + } else if (corrupt_type == "body") { - if (request->body_size() == 0) - { - dwarn("skip body data corruption for empty message body"); - return; - } - replace_value(request->buffers, dsn_random32(0, request->body_size()-1) + sizeof(message_header)); + corrupt_body(request); } else if (corrupt_type == "random") { - if (request->body_size() == 0) - { - dwarn("skip random data corruption for empty message body"); - return; - } - replace_value(request->buffers, dsn_random32(0, request->body_size()-1) + sizeof(message_header)); + corrupt_body(request); } else { diff --git a/src/plugins/tools.common/http_message_parser.cpp b/src/plugins/tools.common/http_message_parser.cpp index 1a9965088..da69fd33c 100644 --- a/src/plugins/tools.common/http_message_parser.cpp +++ b/src/plugins/tools.common/http_message_parser.cpp @@ -62,6 +62,12 @@ http_message_parser::http_message_parser() auto owner = static_cast(parser->data); owner->_current_message.reset(message_ex::create_receive_message_with_standalone_header(blob())); + if (!owner->_current_message) + { + derror("http message creation failed"); + return 1; + } + owner->_response_parse_state = parsing_nothing; message_header* header = owner->_current_message->header; @@ -379,10 +385,18 @@ message_ex* http_message_parser::get_message_on_receive(message_reader* reader, if (reader->_buffer_occupied > 0) { _current_buffer = reader->_buffer; - auto nparsed = http_parser_execute(&_parser, &_parser_setting, reader->_buffer.data(), reader->_buffer_occupied); + auto occupied = reader->_buffer_occupied; + auto nparsed = http_parser_execute(&_parser, &_parser_setting, reader->_buffer.data(), occupied); _current_buffer = blob(); reader->_buffer = reader->_buffer.range(nparsed); reader->_buffer_occupied -= nparsed; + if (nparsed != occupied) + { + derror("http message parse failed"); + read_next = -1; + return nullptr; + } + if (_parser.upgrade) { derror("unsupported http protocol"); diff --git a/src/plugins/tools.common/raw_message_parser.cpp b/src/plugins/tools.common/raw_message_parser.cpp index 9d87f6922..189f1040d 100644 --- a/src/plugins/tools.common/raw_message_parser.cpp +++ b/src/plugins/tools.common/raw_message_parser.cpp @@ -52,6 +52,12 @@ void raw_message_parser::notify_rpc_session_disconnected(rpc_session *sp) if (!sp->is_client()) { message_ex* special_msg = message_ex::create_receive_message_with_standalone_header(blob()); + if (special_msg == nullptr) + { + derror("raw disconnect message creation failed"); + return; + } + dsn::message_header* header = special_msg->header; header->context.u.is_request = 1; header->context.u.is_forwarded = 0; @@ -102,6 +108,13 @@ message_ex* raw_message_parser::get_message_on_receive(message_reader* reader, / auto msg_length = reader->_buffer_occupied; dsn::blob msg_blob = reader->_buffer.range(0, msg_length); message_ex* new_message = message_ex::create_receive_message_with_standalone_header(msg_blob); + if (new_message == nullptr) + { + derror("raw message creation failed"); + read_next = -1; + return nullptr; + } + message_header* header = new_message->header; header->hdr_length = sizeof(*header); diff --git a/src/plugins/tools.common/thrift_message_parser.cpp b/src/plugins/tools.common/thrift_message_parser.cpp index aad887972..97d62c960 100644 --- a/src/plugins/tools.common/thrift_message_parser.cpp +++ b/src/plugins/tools.common/thrift_message_parser.cpp @@ -287,6 +287,12 @@ namespace dsn { dsn::blob body_data = message_data.range(thrift_header.hdr_length); dsn::message_ex* msg = message_ex::create_receive_message_with_standalone_header(body_data); + if (msg == nullptr) + { + derror("thrift message creation failed"); + return nullptr; + } + dsn::message_header* dsn_hdr = msg->header; std::string fname; diff --git a/src/plugins_ext/rDSN.dist.service b/src/plugins_ext/rDSN.dist.service index b15e9c04d..82ec2cd5b 160000 --- a/src/plugins_ext/rDSN.dist.service +++ b/src/plugins_ext/rDSN.dist.service @@ -1 +1 @@ -Subproject commit b15e9c04d29d84e3ecd9917189056f3327da4290 +Subproject commit 82ec2cd5b43120837f438edee286371ae0bb94c0 diff --git a/src/plugins_ext/rDSN.tools.hpc b/src/plugins_ext/rDSN.tools.hpc index 6a701008a..758c7b5ff 160000 --- a/src/plugins_ext/rDSN.tools.hpc +++ b/src/plugins_ext/rDSN.tools.hpc @@ -1 +1 @@ -Subproject commit 6a701008a1aeec525aeb3e6a6f2e87a40a3755c1 +Subproject commit 758c7b5ff6ce3786fd1a2cf0dc787fb47db7237b