diff --git a/src/BUILD b/src/BUILD index 95a7c5df29..41f7ac8116 100644 --- a/src/BUILD +++ b/src/BUILD @@ -669,6 +669,7 @@ ovms_cc_library( "@tensorflow_serving//tensorflow_serving/util:json_tensor", "@org_tensorflow//tensorflow/core:framework", "//src/kfserving_api:kfserving_api_cpp", + "//src/utils:rapidjson_utils", "libovms_kfs_grpc_inference_service_h", "libovms_kfs_utils", "libovms_tensorinfo", diff --git a/src/http_rest_api_handler.cpp b/src/http_rest_api_handler.cpp index 54d10c94e3..0dd225accb 100644 --- a/src/http_rest_api_handler.cpp +++ b/src/http_rest_api_handler.cpp @@ -63,6 +63,7 @@ #include "status.hpp" #include "stringutils.hpp" #include "timer.hpp" +#include "utils/rapidjson_utils.hpp" #if (MEDIAPIPE_DISABLE == 0) #include "copyable_object_wrapper.hpp" @@ -520,13 +521,17 @@ static Status createV3HttpPayload( } else if (isApplicationJson) { { OVMS_PROFILE_SCOPE("rapidjson parse"); - parsedJson->Parse(request_body.c_str()); + auto outcome = parseJsonWithDepthLimit(*parsedJson, request_body.c_str()); + if (outcome == JsonParseOutcome::DepthExceeded) { + ensureJsonParserInErrorState(parsedJson); + return Status(StatusCode::JSON_INVALID, "JSON body exceeds maximum nesting depth"); + } + if (outcome == JsonParseOutcome::ParseError) { + ensureJsonParserInErrorState(parsedJson); + return Status(StatusCode::JSON_INVALID, "Cannot parse JSON body"); + } } OVMS_PROFILE_SCOPE("rapidjson validate"); - if (parsedJson->HasParseError()) { - return Status(StatusCode::JSON_INVALID, "Cannot parse JSON body"); - } - if (!parsedJson->IsObject()) { return Status(StatusCode::JSON_INVALID, "JSON body must be an object"); } diff --git a/src/rest_parser.cpp b/src/rest_parser.cpp index 6547009b58..542235b2f0 100644 --- a/src/rest_parser.cpp +++ b/src/rest_parser.cpp @@ -26,9 +26,12 @@ #include "rest_utils.hpp" #include "status.hpp" #include "tfs_frontend/tfs_utils.hpp" +#include "utils/rapidjson_utils.hpp" namespace ovms { +static constexpr int MAX_NESTING_DEPTH = 100; + TFSRestParser::TFSRestParser(const tensor_map_t& tensors) { for (const auto& kv : tensors) { const auto& name = kv.first; @@ -444,10 +447,17 @@ Status TFSRestParser::parseColumnFormat(rapidjson::Value& node) { Status TFSRestParser::parse(const char* json) { rapidjson::Document doc; - if (doc.Parse(json).HasParseError()) { + int errorCode = 0; + std::size_t errorOffset = 0; + auto outcome = parseJsonWithDepthLimit(doc, json, MAX_NESTING_DEPTH, &errorCode, &errorOffset); + if (outcome == JsonParseOutcome::DepthExceeded) { + SPDLOG_DEBUG("Request JSON exceeds maximum nesting depth"); + return Status(StatusCode::JSON_INVALID, "JSON body exceeds maximum nesting depth"); + } + if (outcome == JsonParseOutcome::ParseError) { std::stringstream ss; - ss << "Error: " << rapidjson::GetParseError_En(doc.GetParseError()) - << " Offset: " << doc.GetErrorOffset(); + ss << "Error: " << rapidjson::GetParseError_En(static_cast(errorCode)) + << " Offset: " << errorOffset; const std::string details = ss.str(); SPDLOG_DEBUG("Request is not a valid JSON. {}", details); return Status(StatusCode::JSON_INVALID, details); @@ -765,10 +775,17 @@ Status KFSRestParser::parseInputs(rapidjson::Value& node) { Status KFSRestParser::parse(const char* json) { rapidjson::Document doc; - if (doc.Parse(json).HasParseError()) { + int errorCode = 0; + std::size_t errorOffset = 0; + auto outcome = parseJsonWithDepthLimit(doc, json, MAX_NESTING_DEPTH, &errorCode, &errorOffset); + if (outcome == JsonParseOutcome::DepthExceeded) { + SPDLOG_DEBUG("Request JSON exceeds maximum nesting depth"); + return Status(StatusCode::JSON_INVALID, "JSON body exceeds maximum nesting depth"); + } + if (outcome == JsonParseOutcome::ParseError) { std::stringstream ss; - ss << "Error: " << rapidjson::GetParseError_En(doc.GetParseError()) - << " Offset: " << doc.GetErrorOffset(); + ss << "Error: " << rapidjson::GetParseError_En(static_cast(errorCode)) + << " Offset: " << errorOffset; const std::string details = ss.str(); SPDLOG_DEBUG("Request is not a valid JSON. {}", details); return Status(StatusCode::JSON_INVALID, details); diff --git a/src/test/http_openai_handler_test.cpp b/src/test/http_openai_handler_test.cpp index a4e6585af0..112c1d69ad 100644 --- a/src/test/http_openai_handler_test.cpp +++ b/src/test/http_openai_handler_test.cpp @@ -318,6 +318,40 @@ TEST_F(HttpOpenAIHandlerTest, JsonBodyValidButNotAnObject) { ASSERT_EQ(status.string(), "The file is not valid json - JSON body must be an object"); } +TEST_F(HttpOpenAIHandlerTest, JsonBodyExceedsNestingDepth_NestedObjects) { + // Deeply nested objects: {"a":{"a":{"a":...}}} - 200 levels + // Make it valid JSON by using key-value pairs + std::string requestBody; + for (int i = 0; i < 200; i++) { + requestBody += R"({"a":)"; + } + requestBody += "{}"; + for (int i = 0; i < 200; i++) { + requestBody += "}"; + } + + EXPECT_CALL(*writer, PartialReplyEnd()).Times(0); + EXPECT_CALL(*writer, PartialReply(::testing::_)).Times(0); + EXPECT_CALL(*writer, IsDisconnected()).Times(0); + + auto status = handler->dispatchToProcessor("/v3/completions", requestBody, &response, comp, responseComponents, writer, multiPartParser); + ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID); + ASSERT_EQ(status.string(), "The file is not valid json - JSON body exceeds maximum nesting depth"); +} + +TEST_F(HttpOpenAIHandlerTest, JsonBodyExceedsNestingDepth_NestedArrays) { + // Deeply nested arrays inside a valid object: {"model":"m","data":[[[...]]]} + std::string requestBody = R"({"model":"m","data":)" + std::string(200, '[') + "0" + std::string(200, ']') + "}"; + + EXPECT_CALL(*writer, PartialReplyEnd()).Times(0); + EXPECT_CALL(*writer, PartialReply(::testing::_)).Times(0); + EXPECT_CALL(*writer, IsDisconnected()).Times(0); + + auto status = handler->dispatchToProcessor("/v3/completions", requestBody, &response, comp, responseComponents, writer, multiPartParser); + ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID); + ASSERT_EQ(status.string(), "The file is not valid json - JSON body exceeds maximum nesting depth"); +} + TEST_F(HttpOpenAIHandlerTest, GraphWithANameDoesNotExist) { std::string requestBody = R"( { diff --git a/src/test/kfs_rest_parser_test.cpp b/src/test/kfs_rest_parser_test.cpp index 86e3f62fe3..1037f6719f 100644 --- a/src/test/kfs_rest_parser_test.cpp +++ b/src/test/kfs_rest_parser_test.cpp @@ -1021,3 +1021,13 @@ TEST_F(KFSRestParserTest, parseNegativeBatch) { ASSERT_NE(status, StatusCode::OK) << "for value: " << replace; } } + +static std::string makeNestedArrayJson(int depth) { + return std::string(depth, '[') + "0" + std::string(depth, ']'); +} + +TEST_F(KFSRestParserTest, NestingDepthExceeded_FP32) { + std::string request = R"({"inputs":[{"name":"input0","shape":[1],"datatype":"FP32","data":)" + makeNestedArrayJson(200) + "}]}"; + auto status = parser.parse(request.c_str()); + EXPECT_EQ(status, StatusCode::REST_COULD_NOT_PARSE_INPUT); +} diff --git a/src/test/tfs_rest_parser_column_test.cpp b/src/test/tfs_rest_parser_column_test.cpp index 3366b72dae..fdb8adaec7 100644 --- a/src/test/tfs_rest_parser_column_test.cpp +++ b/src/test/tfs_rest_parser_column_test.cpp @@ -773,3 +773,19 @@ TEST(TFSRestParserColumn, RemoveUnnecessaryInputs_UnexpectedScalarInRequest) { ASSERT_EQ(parser.getProto().inputs().count("m"), 1); // missing in endpoint metadata but exists in request, expect exists after conversion ASSERT_EQ(parser.getProto().inputs().size(), 3); } + +static std::string makeNestedArrayJson(int depth) { + return std::string(depth, '[') + "0" + std::string(depth, ']'); +} + +TEST(TFSRestParserColumn, NestingDepthExceeded_ColumnNamedInputs) { + TFSRestParser parser(prepareTensors({{"i", {1}}})); + std::string request = R"({"signature_name":"","inputs":{"i":)" + makeNestedArrayJson(200) + "}}"; + EXPECT_EQ(parser.parse(request.c_str()), StatusCode::REST_COULD_NOT_PARSE_INPUT); +} + +TEST(TFSRestParserColumn, NestingDepthExceeded_ColumnNoNamedInputs) { + TFSRestParser parser(prepareTensors({{"i", {1}}})); + std::string request = R"({"signature_name":"","inputs":)" + makeNestedArrayJson(200) + "}"; + EXPECT_EQ(parser.parse(request.c_str()), StatusCode::REST_COULD_NOT_PARSE_INPUT); +} diff --git a/src/test/tfs_rest_parser_row_test.cpp b/src/test/tfs_rest_parser_row_test.cpp index dc332f6196..1cd059e054 100644 --- a/src/test/tfs_rest_parser_row_test.cpp +++ b/src/test/tfs_rest_parser_row_test.cpp @@ -739,3 +739,19 @@ TEST(TFSRestParserRow, RemoveUnnecessaryInputs) { ASSERT_EQ(parser.getProto().inputs().count("k"), 1); ASSERT_EQ(parser.getProto().inputs().count("l"), 1); } + +static std::string makeNestedArrayJson(int depth) { + return std::string(depth, '[') + "0" + std::string(depth, ']'); +} + +TEST(TFSRestParserRow, NestingDepthExceeded_RowNamedInstances) { + TFSRestParser parser(prepareTensors({{"i", {1}}})); + std::string request = R"({"signature_name":"","instances":[{"i":)" + makeNestedArrayJson(200) + "}]}"; + EXPECT_EQ(parser.parse(request.c_str()), StatusCode::REST_COULD_NOT_PARSE_INSTANCE); +} + +TEST(TFSRestParserRow, NestingDepthExceeded_RowNoNamedInstances) { + TFSRestParser parser(prepareTensors({{"i", {1}}})); + std::string request = R"({"signature_name":"","instances":)" + makeNestedArrayJson(200) + "}"; + EXPECT_EQ(parser.parse(request.c_str()), StatusCode::REST_COULD_NOT_PARSE_INSTANCE); +} diff --git a/src/utils/rapidjson_utils.cpp b/src/utils/rapidjson_utils.cpp index e41dd08b9b..dffecc95c5 100644 --- a/src/utils/rapidjson_utils.cpp +++ b/src/utils/rapidjson_utils.cpp @@ -15,11 +15,15 @@ //***************************************************************************** #include "rapidjson_utils.hpp" +#include #include #pragma warning(push) #pragma warning(disable : 6313) #include +#include +#include +#include #include "src/port/rapidjson_stringbuffer.hpp" #include "src/port/rapidjson_writer.hpp" #pragma warning(pop) @@ -31,4 +35,28 @@ std::string documentToString(const rapidjson::Document& doc) { doc.Accept(writer); return buffer.GetString(); } + +JsonParseOutcome parseJsonWithDepthLimit( + rapidjson::Document& doc, + const char* json, + std::size_t maxDepth, + int* errorCode, + std::size_t* errorOffset) { + rapidjson::Reader reader; + rapidjson::StringStream ss(json); + DepthLimitFilter filter(doc, maxDepth); + if (!reader.Parse(ss, filter)) { + if (errorCode != nullptr) { + *errorCode = static_cast(reader.GetParseErrorCode()); + } + if (errorOffset != nullptr) { + *errorOffset = reader.GetErrorOffset(); + } + if (reader.GetParseErrorCode() == rapidjson::kParseErrorTermination) { + return JsonParseOutcome::DepthExceeded; + } + return JsonParseOutcome::ParseError; + } + return JsonParseOutcome::Ok; +} } // namespace ovms diff --git a/src/utils/rapidjson_utils.hpp b/src/utils/rapidjson_utils.hpp index 08c4d0ed82..95f3034ab9 100644 --- a/src/utils/rapidjson_utils.hpp +++ b/src/utils/rapidjson_utils.hpp @@ -14,10 +14,77 @@ // See the License for the specific language governing permissions and // limitations under the License. //***************************************************************************** +#include #include #include "src/port/rapidjson_document.hpp" namespace ovms { std::string documentToString(const rapidjson::Document& doc); + +// Default maximum nesting depth allowed for incoming JSON request bodies. +inline constexpr std::size_t DEFAULT_MAX_JSON_NESTING_DEPTH = 100; + +// SAX filter that forwards events to an inner handler while enforcing a +// maximum nesting depth. Returning false from StartObject/StartArray aborts +// parsing with rapidjson::kParseErrorTermination, which avoids materializing +// the entire DOM for pathologically nested payloads. +template +struct DepthLimitFilter { + Inner& inner; + std::size_t depth{0}; + const std::size_t maxDepth; + + DepthLimitFilter(Inner& i, std::size_t m) : + inner(i), + maxDepth(m) {} + DepthLimitFilter(const DepthLimitFilter&) = delete; + DepthLimitFilter& operator=(const DepthLimitFilter&) = delete; + + bool Null() { return inner.Null(); } + bool Bool(bool b) { return inner.Bool(b); } + bool Int(int v) { return inner.Int(v); } + bool Uint(unsigned v) { return inner.Uint(v); } + bool Int64(int64_t v) { return inner.Int64(v); } + bool Uint64(uint64_t v) { return inner.Uint64(v); } + bool Double(double v) { return inner.Double(v); } + bool RawNumber(const char* s, rapidjson::SizeType l, bool c) { return inner.RawNumber(s, l, c); } + bool String(const char* s, rapidjson::SizeType l, bool c) { return inner.String(s, l, c); } + bool Key(const char* s, rapidjson::SizeType l, bool c) { return inner.Key(s, l, c); } + bool StartObject() { + if (++depth > maxDepth) + return false; + return inner.StartObject(); + } + bool StartArray() { + if (++depth > maxDepth) + return false; + return inner.StartArray(); + } + bool EndObject(rapidjson::SizeType n) { + --depth; + return inner.EndObject(n); + } + bool EndArray(rapidjson::SizeType n) { + --depth; + return inner.EndArray(n); + } +}; + +enum class JsonParseOutcome { + Ok, + DepthExceeded, + ParseError, +}; + +// Populates `doc` from `json` using rapidjson's iterative parser wrapped with a +// DepthLimitFilter. On DepthExceeded / ParseError the rapidjson error details +// (code, offset) are accessible through `errorCode` / `errorOffset` if non-null. +// Iterative parsing avoids native stack recursion regardless of input depth. +JsonParseOutcome parseJsonWithDepthLimit( + rapidjson::Document& doc, + const char* json, + std::size_t maxDepth = DEFAULT_MAX_JSON_NESTING_DEPTH, + int* errorCode = nullptr, + std::size_t* errorOffset = nullptr); } // namespace ovms