diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java index 1abd392d0..aa08d597e 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java @@ -204,10 +204,19 @@ private static void handleToolCalls( ctx.sendEvent(toolRequestEvent); } + static String cleanLlmResponse(String rawResponse) { + String trimmed = rawResponse.trim(); + if (trimmed.startsWith("```")) { + return trimmed.replaceAll("(?s)^```(?:json)?\\s*(.*?)\\s*```$", "$1"); + } + return trimmed; + } + @SuppressWarnings("unchecked") private static ChatMessage generateStructuredOutput(ChatMessage response, Object outputSchema) throws JsonProcessingException { String output = response.getContent(); + output = cleanLlmResponse(output); Object structuredOutput; if (outputSchema instanceof Class) { structuredOutput = mapper.readValue(String.valueOf(output), (Class) outputSchema); diff --git a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java new file mode 100644 index 000000000..d7f117850 --- /dev/null +++ b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.plan.actions; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** Tests for {@link ChatModelAction}. */ +class ChatModelActionTest { + + @Test + void testCleanLlmResponseWithJsonBlock() { + String input = "```json\n{\"key\": \"value\"}\n```"; + String expected = "{\"key\": \"value\"}"; + assertEquals(expected, ChatModelAction.cleanLlmResponse(input)); + } + + @Test + void testCleanLlmResponseWithGenericCodeBlock() { + String input = "```\n{\"key\": \"value\"}\n```"; + String expected = "{\"key\": \"value\"}"; + assertEquals(expected, ChatModelAction.cleanLlmResponse(input)); + } + + @Test + void testCleanLlmResponseWithWhitespace() { + String input = " ```json\n{\"key\": \"value\"}\n``` "; + String expected = "{\"key\": \"value\"}"; + assertEquals(expected, ChatModelAction.cleanLlmResponse(input)); + } + + @Test + void testCleanLlmResponseWithoutBlock() { + String input = "{\"key\": \"value\"}"; + String expected = "{\"key\": \"value\"}"; + assertEquals(expected, ChatModelAction.cleanLlmResponse(input)); + } + + @Test + void testCleanLlmResponseWithTextAround() { + // Current implementation uses replaceAll with ^ and $ anchors, + // so it only matches if the whole (trimmed) string is a code block. + String input = "Here is the json: ```json {\"key\": \"value\"} ```"; + String expected = "Here is the json: ```json {\"key\": \"value\"} ```"; + assertEquals(expected, ChatModelAction.cleanLlmResponse(input)); + } + + @Test + void testCleanLlmResponseWithMultipleLinesInBlock() { + String input = "```json\n{\n \"key\": \"value\"\n}\n```"; + String expected = "{\n \"key\": \"value\"\n}"; + assertEquals(expected, ChatModelAction.cleanLlmResponse(input)); + } +} diff --git a/python/flink_agents/plan/actions/chat_model_action.py b/python/flink_agents/plan/actions/chat_model_action.py index a697f7237..55aa3965a 100644 --- a/python/flink_agents/plan/actions/chat_model_action.py +++ b/python/flink_agents/plan/actions/chat_model_action.py @@ -18,6 +18,7 @@ import copy import json import logging +import re import time from typing import TYPE_CHECKING, Dict, List, cast from uuid import UUID @@ -198,7 +199,7 @@ def _generate_structured_output( ) -> ChatMessage: """Deserialize output to expected output schema.""" output_schema = output_schema.output_schema - output = json.loads(response.content.strip()) + output = json.loads(_clean_llm_response(response.content)) if isinstance(output_schema, type) and issubclass(output_schema, BaseModel): output = output_schema.model_validate(output) @@ -213,6 +214,13 @@ def _generate_structured_output( return response +def _clean_llm_response(raw_response: str) -> str: + trimmed = raw_response.strip() + if trimmed.startswith("```"): + return re.sub(r"(?s)^```(?:json)?\s*(.*?)\s*```$", r"\1", trimmed) + return trimmed + + async def chat( initial_request_id: UUID, model: str, diff --git a/python/flink_agents/plan/tests/actions/test_chat_model_action.py b/python/flink_agents/plan/tests/actions/test_chat_model_action.py new file mode 100644 index 000000000..45f5944e9 --- /dev/null +++ b/python/flink_agents/plan/tests/actions/test_chat_model_action.py @@ -0,0 +1,54 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +from flink_agents.plan.actions.chat_model_action import _clean_llm_response + + +def test_clean_llm_response_with_json_block(): + input_str = "```json\n{\"key\": \"value\"}\n```" + expected = "{\"key\": \"value\"}" + assert _clean_llm_response(input_str) == expected + + +def test_clean_llm_response_with_generic_code_block(): + input_str = "```\n{\"key\": \"value\"}\n```" + expected = "{\"key\": \"value\"}" + assert _clean_llm_response(input_str) == expected + + +def test_clean_llm_response_with_whitespace(): + input_str = " ```json\n{\"key\": \"value\"}\n``` " + expected = "{\"key\": \"value\"}" + assert _clean_llm_response(input_str) == expected + + +def test_clean_llm_response_without_block(): + input_str = "{\"key\": \"value\"}" + expected = "{\"key\": \"value\"}" + assert _clean_llm_response(input_str) == expected + + +def test_clean_llm_response_with_text_around(): + input_str = "Here is the json: ```json {\"key\": \"value\"} ```" + expected = "Here is the json: ```json {\"key\": \"value\"} ```" + assert _clean_llm_response(input_str) == expected + + +def test_clean_llm_response_with_multiple_lines_in_block(): + input_str = "```json\n{\n \"key\": \"value\"\n}\n```" + expected = "{\n \"key\": \"value\"\n}" + assert _clean_llm_response(input_str) == expected