diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/resources/mcp_server.py b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/resources/mcp_server.py index 4a6c331b5..5c1d90a81 100644 --- a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/resources/mcp_server.py +++ b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/resources/mcp_server.py @@ -18,6 +18,7 @@ try: import dotenv + dotenv.load_dotenv() except ImportError: # dotenv is optional for this test server @@ -34,6 +35,7 @@ def ask_sum(a: int, b: int) -> str: """Prompt of add tool.""" return f"Can you please calculate the sum of {a} and {b}?" + @mcp.tool() async def add(a: int, b: int) -> int: """Get the detailed information of a specified IP address. @@ -47,4 +49,5 @@ async def add(a: int, b: int) -> int: """ return a + b + mcp.run("streamable-http") diff --git a/e2e-test/test-scripts/check_resource_consistency.py b/e2e-test/test-scripts/check_resource_consistency.py index 6734c9bd1..82a23ab6f 100644 --- a/e2e-test/test-scripts/check_resource_consistency.py +++ b/e2e-test/test-scripts/check_resource_consistency.py @@ -34,9 +34,7 @@ def parse_java_resource_name(java_path: Path) -> dict: r'public\s+static\s+final\s+String\s+([A-Za-z0-9_]+)\s*=\s*"([^"]+)";', re.DOTALL, ) - class_re = re.compile( - r"public\s+(?:static\s+)?final\s+class\s+(\w+)\s*\{" - ) + class_re = re.compile(r"public\s+(?:static\s+)?final\s+class\s+(\w+)\s*\{") class_stack = [] brace_depth = 0 @@ -200,35 +198,42 @@ def _parse_python_resource_name(python_path: Path) -> dict: elif len(parts) == 3 and parts[2] == "Java": python_map[(rt, "Java")] = consts if "ResourceName" in result and "MCP_SERVER" in result["ResourceName"]: - python_map[("MCP", "Python")] = {"MCP_SERVER": result["ResourceName"]["MCP_SERVER"]} + python_map[("MCP", "Python")] = { + "MCP_SERVER": result["ResourceName"]["MCP_SERVER"] + } return python_map -_JAVA_ONLY_NAMES = frozenset({ - "PYTHON_WRAPPER_CONNECTION", "PYTHON_WRAPPER_SETUP", - "PYTHON_WRAPPER_VECTOR_STORE", "PYTHON_WRAPPER_COLLECTION_MANAGEABLE_VECTOR_STORE", -}) -_PYTHON_ONLY_NAMES = frozenset({ - "JAVA_WRAPPER_CONNECTION", "JAVA_WRAPPER_SETUP", - "JAVA_WRAPPER_VECTOR_STORE", "JAVA_WRAPPER_COLLECTION_MANAGEABLE_VECTOR_STORE", -}) +_JAVA_ONLY_NAMES = frozenset( + { + "PYTHON_WRAPPER_CONNECTION", + "PYTHON_WRAPPER_SETUP", + "PYTHON_WRAPPER_VECTOR_STORE", + "PYTHON_WRAPPER_COLLECTION_MANAGEABLE_VECTOR_STORE", + } +) +_PYTHON_ONLY_NAMES = frozenset( + { + "JAVA_WRAPPER_CONNECTION", + "JAVA_WRAPPER_SETUP", + "JAVA_WRAPPER_VECTOR_STORE", + "JAVA_WRAPPER_COLLECTION_MANAGEABLE_VECTOR_STORE", + } +) def _find_python_name_for_value(impls: dict, value: str, java_name: str) -> str | None: return java_name if impls.get(java_name) == value else None -def check_consistency( - java_map: dict, python_map: dict -) -> tuple[list[str], list[str]]: - +def check_consistency(java_map: dict, python_map: dict) -> tuple[list[str], list[str]]: errors = [] warnings = [] all_resource_types = set() - for (rt, _) in java_map: + for rt, _ in java_map: all_resource_types.add(rt) - for (rt, _) in python_map: + for rt, _ in python_map: all_resource_types.add(rt) for resource_type in sorted(all_resource_types): @@ -287,7 +292,10 @@ def check_consistency( def main() -> int: root = Path(__file__).resolve().parent.parent.parent - java_path = root / "api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java" + java_path = ( + root + / "api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java" + ) python_path = root / "python/flink_agents/api/resource.py" if not java_path.exists(): @@ -303,8 +311,19 @@ def main() -> int: debug = __import__("os").environ.get("RESOURCE_DEBUG") if debug: import json - print("Java map:", json.dumps({str(k): v for k, v in java_map.items()}, indent=2, ensure_ascii=False)) - print("Python map:", json.dumps({str(k): v for k, v in python_map.items()}, indent=2, ensure_ascii=False)) + + print( + "Java map:", + json.dumps( + {str(k): v for k, v in java_map.items()}, indent=2, ensure_ascii=False + ), + ) + print( + "Python map:", + json.dumps( + {str(k): v for k, v in python_map.items()}, indent=2, ensure_ascii=False + ), + ) errors, warnings = check_consistency(java_map, python_map) diff --git a/python/_build_backend/tests/test_backend.py b/python/_build_backend/tests/test_backend.py index 3883dc70d..a8fc9b6d3 100644 --- a/python/_build_backend/tests/test_backend.py +++ b/python/_build_backend/tests/test_backend.py @@ -54,7 +54,7 @@ def _write_manifest(path: Path, manifest: dict) -> None: # --------------------------------------------------------------------------- -class TestJarFilename: # noqa: D101 +class TestJarFilename: def test_without_classifier(self) -> None: entry = { "artifact_id": "flink-agents-dist-common", @@ -76,7 +76,7 @@ def test_with_classifier(self) -> None: ) -class TestLoadManifest: # noqa: D101 +class TestLoadManifest: def test_load(self, tmp_path) -> None: manifest = { "maven_base_url": "https://repo1.maven.org/maven2", @@ -89,7 +89,7 @@ def test_load(self, tmp_path) -> None: assert loaded == manifest -class TestVerifyChecksum: # noqa: D101 +class TestVerifyChecksum: def test_valid_checksum(self, tmp_path) -> None: content = b"fake jar content" jar = tmp_path / "test.jar" @@ -106,7 +106,7 @@ def test_invalid_checksum(self, tmp_path) -> None: assert not jar.exists() -class TestEnsureJars: # noqa: D101 +class TestEnsureJars: def test_skip_when_no_manifest(self, tmp_path, monkeypatch) -> None: monkeypatch.chdir(tmp_path) _ensure_jars() diff --git a/python/flink_agents/api/agents/agent.py b/python/flink_agents/api/agents/agent.py index 7aab9e771..3ea75bbfe 100644 --- a/python/flink_agents/api/agents/agent.py +++ b/python/flink_agents/api/agents/agent.py @@ -133,7 +133,10 @@ def add_action( return self def add_resource( - self, name: str, resource_type: ResourceType, instance: SerializableResource | ResourceDescriptor + self, + name: str, + resource_type: ResourceType, + instance: SerializableResource | ResourceDescriptor, ) -> "Agent": """Add resource to agent instance. diff --git a/python/flink_agents/api/agents/tests/test_row_schema.py b/python/flink_agents/api/agents/tests/test_row_schema.py index b3a663eaa..3f50ebb9a 100644 --- a/python/flink_agents/api/agents/tests/test_row_schema.py +++ b/python/flink_agents/api/agents/tests/test_row_schema.py @@ -21,7 +21,7 @@ from flink_agents.api.agents.react_agent import OutputSchema -def test_output_schema_serializable() -> None: # noqa: D103 +def test_output_schema_serializable() -> None: schema = OutputSchema( output_schema=RowTypeInfo( [BasicTypeInfo.INT_TYPE_INFO()], diff --git a/python/flink_agents/api/chat_models/chat_model.py b/python/flink_agents/api/chat_models/chat_model.py index b3d8c5c8a..5919fef5b 100644 --- a/python/flink_agents/api/chat_models/chat_model.py +++ b/python/flink_agents/api/chat_models/chat_model.py @@ -214,7 +214,9 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: # Call chat model connection to execute chat merged_kwargs = self.model_kwargs.copy() merged_kwargs.update(kwargs) - return self._get_connection().chat(messages, tools=self._get_tools(), **merged_kwargs) + return self._get_connection().chat( + messages, tools=self._get_tools(), **merged_kwargs + ) def _record_token_metrics( self, model_name: str, prompt_tokens: int, completion_tokens: int @@ -256,4 +258,3 @@ def _get_tools(self) -> List[Tool]: err_msg = f"Expect Tool, but is {tool.__class__.__name__}" raise TypeError(err_msg) return self.tools - diff --git a/python/flink_agents/api/chat_models/java_chat_model.py b/python/flink_agents/api/chat_models/java_chat_model.py index 4c7415b34..be0768fa3 100644 --- a/python/flink_agents/api/chat_models/java_chat_model.py +++ b/python/flink_agents/api/chat_models/java_chat_model.py @@ -31,7 +31,8 @@ class JavaChatModelConnection(BaseChatModelConnection): unlike JavaChatModelSetup, it does not provide direct chat functionality in Python. """ - java_class_name: str="" + java_class_name: str = "" + @java_resource class JavaChatModelSetup(BaseChatModelSetup): @@ -43,4 +44,4 @@ class JavaChatModelSetup(BaseChatModelSetup): implementation. """ - java_class_name: str="" + java_class_name: str = "" diff --git a/python/flink_agents/api/configuration.py b/python/flink_agents/api/configuration.py index 940655d26..ea1dd2376 100644 --- a/python/flink_agents/api/configuration.py +++ b/python/flink_agents/api/configuration.py @@ -30,11 +30,14 @@ class ConfigOption: default: The default value for this configuration option """ - def __init__(self, key: str, config_type: Type[Any], default: Any | None=None) -> None: + def __init__( + self, key: str, config_type: Type[Any], default: Any | None = None + ) -> None: """Initialize a configuration option.""" self._key = key self._type = config_type self._default_value = default + def get_key(self) -> str: """Gets the configuration key.""" return self._key @@ -47,6 +50,7 @@ def get_default_value(self) -> Any: """Returns the default value.""" return self._default_value + class WritableConfiguration(ABC): """Abstract base class providing write access to a configuration object. @@ -98,6 +102,7 @@ def set(self, option: ConfigOption, value: Any) -> None: value: The value to set for the key """ + class ReadableConfiguration(ABC): """Abstract base class providing read access to a configuration object. @@ -105,7 +110,7 @@ class ReadableConfiguration(ABC): """ @abstractmethod - def get_int(self, key: str, default: int | None=None) -> int: + def get_int(self, key: str, default: int | None = None) -> int: """Get the int configuration value by key. Args: @@ -117,7 +122,7 @@ def get_int(self, key: str, default: int | None=None) -> int: """ @abstractmethod - def get_float(self, key: str, default: float | None=None) -> float: + def get_float(self, key: str, default: float | None = None) -> float: """Get the float configuration value by key. Args: @@ -129,7 +134,7 @@ def get_float(self, key: str, default: float | None=None) -> float: """ @abstractmethod - def get_bool(self, key: str, default: bool | None=None) -> bool: + def get_bool(self, key: str, default: bool | None = None) -> bool: """Get the boolean configuration value by key. Args: @@ -141,7 +146,7 @@ def get_bool(self, key: str, default: bool | None=None) -> bool: """ @abstractmethod - def get_str(self, key: str, default: str | None=None) -> str: + def get_str(self, key: str, default: str | None = None) -> str: """Get the string configuration value by key. Args: @@ -163,6 +168,7 @@ def get(self, option: ConfigOption) -> Any: The value of the given option """ + class Configuration(WritableConfiguration, ReadableConfiguration, ABC): """A configuration object that provides both read and write access to a configuration object. diff --git a/python/flink_agents/api/decorators.py b/python/flink_agents/api/decorators.py index 3289d8a74..231d52729 100644 --- a/python/flink_agents/api/decorators.py +++ b/python/flink_agents/api/decorators.py @@ -189,6 +189,7 @@ def vector_store(func: Callable) -> Callable: func._is_vector_store = True return func + def java_resource(cls: Type) -> Type: """Decorator to mark a class as Java resource.""" cls._is_java_resource = True diff --git a/python/flink_agents/api/embedding_models/java_embedding_model.py b/python/flink_agents/api/embedding_models/java_embedding_model.py index 18d2f9a0b..53b1d3170 100644 --- a/python/flink_agents/api/embedding_models/java_embedding_model.py +++ b/python/flink_agents/api/embedding_models/java_embedding_model.py @@ -32,7 +32,8 @@ class JavaEmbeddingModelConnection(BaseEmbeddingModelConnection): functionality in Python. """ - java_class_name: str="" + java_class_name: str = "" + @java_resource class JavaEmbeddingModelSetup(BaseEmbeddingModelSetup): @@ -44,4 +45,4 @@ class JavaEmbeddingModelSetup(BaseEmbeddingModelSetup): implementation. """ - java_class_name: str="" + java_class_name: str = "" diff --git a/python/flink_agents/api/events/context_retrieval_event.py b/python/flink_agents/api/events/context_retrieval_event.py index 24b0e1cca..a41a3b7c1 100644 --- a/python/flink_agents/api/events/context_retrieval_event.py +++ b/python/flink_agents/api/events/context_retrieval_event.py @@ -34,6 +34,7 @@ class ContextRetrievalRequestEvent(Event): max_results : int Maximum number of results to return (default: 3) """ + query: str vector_store: str max_results: int = 3 @@ -51,6 +52,7 @@ class ContextRetrievalResponseEvent(Event): documents : List[Document] List of retrieved documents from the vector store """ + request_id: UUID query: str documents: List[Document] diff --git a/python/flink_agents/api/execution_environment.py b/python/flink_agents/api/execution_environment.py index d9cdeff1f..965daa4d4 100644 --- a/python/flink_agents/api/execution_environment.py +++ b/python/flink_agents/api/execution_environment.py @@ -237,7 +237,10 @@ def execute(self, job_name: str | None = None) -> None: """Execute agent individually.""" def add_resource( - self, name: str, resource_type: ResourceType, instance: SerializableResource | ResourceDescriptor + self, + name: str, + resource_type: ResourceType, + instance: SerializableResource | ResourceDescriptor, ) -> "AgentsExecutionEnvironment": """Register resource to agent execution environment. diff --git a/python/flink_agents/api/memory/long_term_memory.py b/python/flink_agents/api/memory/long_term_memory.py index 6f7934899..c74ee1513 100644 --- a/python/flink_agents/api/memory/long_term_memory.py +++ b/python/flink_agents/api/memory/long_term_memory.py @@ -34,6 +34,7 @@ ItemType = str | ChatMessage + class CompactionConfig(BaseModel): """Compaction configuration. @@ -48,6 +49,7 @@ class CompactionConfig(BaseModel): prompt: str | Prompt | None = None limit: int = 1 + class LongTermMemoryBackend(Enum): """Backend for Long-Term Memory.""" diff --git a/python/flink_agents/api/memory/tests/test_long_term_memory.py b/python/flink_agents/api/memory/tests/test_long_term_memory.py index 916ac2c28..64fa96913 100644 --- a/python/flink_agents/api/memory/tests/test_long_term_memory.py +++ b/python/flink_agents/api/memory/tests/test_long_term_memory.py @@ -23,7 +23,7 @@ ) -def test_memory_set_serialization() -> None: # noqa:D103 +def test_memory_set_serialization() -> None: memory_set = MemorySet( name="chat_history", item_type=ChatMessage, diff --git a/python/flink_agents/api/memory_object.py b/python/flink_agents/api/memory_object.py index 5e700b096..2361abeb5 100644 --- a/python/flink_agents/api/memory_object.py +++ b/python/flink_agents/api/memory_object.py @@ -24,11 +24,14 @@ if TYPE_CHECKING: from flink_agents.api.memory_reference import MemoryRef + class MemoryType(Enum): """Memory types based on MemoryObject.""" - SENSORY = "sensory", + + SENSORY = ("sensory",) SHORT_TERM = "short_term" + class MemoryObject(BaseModel, ABC): """Representation of an object in the short-term memory. @@ -38,7 +41,7 @@ class MemoryObject(BaseModel, ABC): """ @abstractmethod - def get(self, path_or_ref: Union[str,"MemoryRef"] ) -> Any: + def get(self, path_or_ref: Union[str, "MemoryRef"]) -> Any: """Get the value of a (direct or indirect) field or a MemoryRef in the object. Parameters diff --git a/python/flink_agents/api/runner_context.py b/python/flink_agents/api/runner_context.py index bf8d4ff9e..c97018ec5 100644 --- a/python/flink_agents/api/runner_context.py +++ b/python/flink_agents/api/runner_context.py @@ -37,7 +37,9 @@ class AsyncExecutionResult: and `asyncio.sleep` are NOT supported because there is no asyncio event loop. """ - def __init__(self, executor: Any, func: Callable, args: tuple, kwargs: dict) -> None: + def __init__( + self, executor: Any, func: Callable, args: tuple, kwargs: dict + ) -> None: """Initialize an AsyncExecutionResult. Parameters @@ -90,7 +92,9 @@ def send_event(self, event: Event) -> None: """ @abstractmethod - def get_resource(self, name: str, type: ResourceType, metric_group: MetricGroup = None) -> Resource: + def get_resource( + self, name: str, type: ResourceType, metric_group: MetricGroup = None + ) -> Resource: """Get resource from context. Parameters diff --git a/python/flink_agents/api/tests/test_decorators.py b/python/flink_agents/api/tests/test_decorators.py index 697f6c981..f4242c1ba 100644 --- a/python/flink_agents/api/tests/test_decorators.py +++ b/python/flink_agents/api/tests/test_decorators.py @@ -24,7 +24,7 @@ from flink_agents.api.runner_context import RunnerContext -def test_action_decorator() -> None: # noqa D103 +def test_action_decorator() -> None: @action(InputEvent) def forward_action(event: Event, ctx: RunnerContext) -> None: input = event.input @@ -35,7 +35,7 @@ def forward_action(event: Event, ctx: RunnerContext) -> None: assert listen_events == (InputEvent,) -def test_action_decorator_listen_multi_events() -> None: # noqa D103 +def test_action_decorator_listen_multi_events() -> None: @action(InputEvent, OutputEvent) def forward_action(event: Event, ctx: RunnerContext) -> None: input = event.input @@ -46,7 +46,7 @@ def forward_action(event: Event, ctx: RunnerContext) -> None: assert listen_events == (InputEvent, OutputEvent) -def test_action_decorator_listen_no_event() -> None: # noqa D103 +def test_action_decorator_listen_no_event() -> None: with pytest.raises(AssertionError): @action() @@ -55,7 +55,7 @@ def forward_action(event: Event, ctx: RunnerContext) -> None: ctx.send_event(OutputEvent(output=input)) -def test_action_decorator_listen_non_event_type() -> None: # noqa D103 +def test_action_decorator_listen_non_event_type() -> None: with pytest.raises(AssertionError): @action(List) diff --git a/python/flink_agents/api/tests/test_event.py b/python/flink_agents/api/tests/test_event.py index 0b909091b..5ba0825fa 100644 --- a/python/flink_agents/api/tests/test_event.py +++ b/python/flink_agents/api/tests/test_event.py @@ -25,45 +25,45 @@ from flink_agents.api.events.event import Event, InputEvent, OutputEvent -def test_event_init_serializable() -> None: # noqa D103 +def test_event_init_serializable() -> None: Event(a=1, b=InputEvent(input=1), c=OutputEvent(output="111")) -def test_event_init_non_serializable() -> None: # noqa D103 +def test_event_init_non_serializable() -> None: with pytest.raises(ValidationError): Event(a=1, b=Type[InputEvent]) -def test_event_setattr_serializable() -> None: # noqa D103 +def test_event_setattr_serializable() -> None: event = Event(a=1) event.c = Event() -def test_event_setattr_non_serializable() -> None: # noqa D103 +def test_event_setattr_non_serializable() -> None: event = Event(a=1) with pytest.raises(PydanticSerializationError): event.c = Type[InputEvent] -def test_input_event_ignore_row_unserializable() -> None: # noqa D103 +def test_input_event_ignore_row_unserializable() -> None: InputEvent(input=Row({"a": 1})) -def test_event_row_with_non_serializable_fails() -> None: # noqa D103 +def test_event_row_with_non_serializable_fails() -> None: with pytest.raises(ValidationError): Event(row_field=Row({"a": 1}), non_serializable_field=Type[InputEvent]) -def test_event_multiple_rows_serializable() -> None: # noqa D103 +def test_event_multiple_rows_serializable() -> None: Event(row1=Row({"a": 1}), row2=Row({"b": 2}), normal_field="test") -def test_event_setattr_row_serializable() -> None: # noqa D103 +def test_event_setattr_row_serializable() -> None: event = Event(a=1) event.row_field = Row({"key": "value"}) -def test_event_json_serialization_with_row() -> None: # noqa D103 +def test_event_json_serialization_with_row() -> None: event = InputEvent(input=Row({"test": "data"})) json_str = event.model_dump_json() assert "test" in json_str diff --git a/python/flink_agents/api/tests/test_prompt.py b/python/flink_agents/api/tests/test_prompt.py index 71ba0752e..599e600dc 100644 --- a/python/flink_agents/api/tests/test_prompt.py +++ b/python/flink_agents/api/tests/test_prompt.py @@ -22,7 +22,7 @@ @pytest.fixture(scope="module") -def text_prompt() -> Prompt: # noqa: D103 +def text_prompt() -> Prompt: template = ( "You ara a product review analyzer, please generate a score and the dislike reasons" "(if any) for the review. " @@ -32,7 +32,7 @@ def text_prompt() -> Prompt: # noqa: D103 return Prompt.from_text(text=template) -def test_prompt_from_text_to_string(text_prompt: LocalPrompt) -> None: # noqa: D103 +def test_prompt_from_text_to_string(text_prompt: LocalPrompt) -> None: assert text_prompt.format_string( product_id="12345", description="wireless noise-canceling headphones with 20-hour battery life", @@ -45,7 +45,7 @@ def test_prompt_from_text_to_string(text_prompt: LocalPrompt) -> None: # noqa: ) -def test_prompt_from_text_to_messages(text_prompt: LocalPrompt) -> None: # noqa: D103 +def test_prompt_from_text_to_messages(text_prompt: LocalPrompt) -> None: assert text_prompt.format_messages( product_id="12345", description="wireless noise-canceling headphones with 20-hour battery life", @@ -62,7 +62,7 @@ def test_prompt_from_text_to_messages(text_prompt: LocalPrompt) -> None: # noqa @pytest.fixture(scope="module") -def messages_prompt() -> Prompt: # noqa: D103 +def messages_prompt() -> Prompt: template = [ ChatMessage( role=MessageRole.SYSTEM, @@ -78,7 +78,7 @@ def messages_prompt() -> Prompt: # noqa: D103 return Prompt.from_messages(messages=template) -def test_prompt_from_messages_to_string(messages_prompt: LocalPrompt) -> None: # noqa: D103 +def test_prompt_from_messages_to_string(messages_prompt: LocalPrompt) -> None: assert messages_prompt.format_string( product_id="12345", description="wireless noise-canceling headphones with 20-hour battery life", @@ -92,7 +92,7 @@ def test_prompt_from_messages_to_string(messages_prompt: LocalPrompt) -> None: ) -def test_prompt_from_messages_to_messages(messages_prompt: LocalPrompt) -> None: # noqa: D103 +def test_prompt_from_messages_to_messages(messages_prompt: LocalPrompt) -> None: assert messages_prompt.format_messages( product_id="12345", description="wireless noise-canceling headphones with 20-hour battery life", @@ -112,7 +112,7 @@ def test_prompt_from_messages_to_messages(messages_prompt: LocalPrompt) -> None: ] -def test_prompt_lack_one_argument(text_prompt: LocalPrompt) -> None: # noqa: D103 +def test_prompt_lack_one_argument(text_prompt: LocalPrompt) -> None: assert text_prompt.format_string( product_id="12345", review="The headphones broke after one week of use. Very poor quality", @@ -123,7 +123,7 @@ def test_prompt_lack_one_argument(text_prompt: LocalPrompt) -> None: # noqa: D1 ) -def test_prompt_contain_json_schema() -> None: # noqa: D103 +def test_prompt_contain_json_schema() -> None: prompt = Prompt.from_text( text=f"The json schema is {LocalPrompt.model_json_schema(mode='serialization')}", ) diff --git a/python/flink_agents/api/tests/test_resource_name.py b/python/flink_agents/api/tests/test_resource_name.py index 5575f3509..fe4a2e984 100644 --- a/python/flink_agents/api/tests/test_resource_name.py +++ b/python/flink_agents/api/tests/test_resource_name.py @@ -16,6 +16,7 @@ # limitations under the License. ################################################################################ """Verify ResourceName Python paths resolve to existing, importable classes.""" + from __future__ import annotations import importlib @@ -40,9 +41,7 @@ def _collect_python_class_paths() -> list[tuple[str, str]]: val = getattr(resource_cls, attr) if isinstance(val, str) and val.startswith(PYTHON_PREFIX): paths.append((f"{resource_name}.{attr}", val)) - if hasattr(ResourceName, "MCP_SERVER") and isinstance( - ResourceName.MCP_SERVER, str - ): + if hasattr(ResourceName, "MCP_SERVER") and isinstance(ResourceName.MCP_SERVER, str): if ResourceName.MCP_SERVER.startswith(PYTHON_PREFIX): paths.append(("MCP_SERVER", ResourceName.MCP_SERVER)) return paths @@ -58,7 +57,10 @@ def _class_exists(full_class_path: str) -> tuple[bool, str]: module = importlib.import_module(module_path) cls = getattr(module, class_name, None) if cls is None: - return False, f"module {module_path!r} Attribute does not exist {class_name!r}" + return ( + False, + f"module {module_path!r} Attribute does not exist {class_name!r}", + ) if not inspect.isclass(cls): return False, f"{full_class_path!r} is not a class" except Exception as e: diff --git a/python/flink_agents/api/tests/test_tool.py b/python/flink_agents/api/tests/test_tool.py index 203d65452..23be82db8 100644 --- a/python/flink_agents/api/tests/test_tool.py +++ b/python/flink_agents/api/tests/test_tool.py @@ -47,7 +47,7 @@ def foo(bar: int, baz: str) -> str: @pytest.fixture(scope="module") -def tool_metadata() -> ToolMetadata: # noqa: D103 +def tool_metadata() -> ToolMetadata: return ToolMetadata( name="foo", description="Function for testing ToolMetadata", @@ -55,7 +55,7 @@ def tool_metadata() -> ToolMetadata: # noqa: D103 ) -def test_serialize_tool_metadata(tool_metadata: ToolMetadata) -> None: # noqa: D103 +def test_serialize_tool_metadata(tool_metadata: ToolMetadata) -> None: json_value = tool_metadata.model_dump_json(serialize_as_any=True) with Path(f"{current_dir}/resources/tool_metadata.json").open() as f: expected_json = f.read() @@ -64,7 +64,7 @@ def test_serialize_tool_metadata(tool_metadata: ToolMetadata) -> None: # noqa: assert actual == expected -def test_deserialize_tool_metadata(tool_metadata: ToolMetadata) -> None: # noqa: D103 +def test_deserialize_tool_metadata(tool_metadata: ToolMetadata) -> None: with Path(f"{current_dir}/resources/tool_metadata.json").open() as f: expected_json = f.read() actual_tool_metadata = tool_metadata.model_validate_json(expected_json) diff --git a/python/flink_agents/api/tests/test_version_compatibility.py b/python/flink_agents/api/tests/test_version_compatibility.py index 95839704a..a4c969cd3 100644 --- a/python/flink_agents/api/tests/test_version_compatibility.py +++ b/python/flink_agents/api/tests/test_version_compatibility.py @@ -24,35 +24,35 @@ # Tests for _normalize_version function -def test_normalize_three_part_version() -> None: # noqa: D103 +def test_normalize_three_part_version() -> None: assert _normalize_version("1.20.3") == "1.20.3" assert _normalize_version("2.2.0") == "2.2.0" -def test_normalize_two_part_version() -> None: # noqa: D103 +def test_normalize_two_part_version() -> None: assert _normalize_version("2.2") == "2.2.0" assert _normalize_version("1.20") == "1.20.0" -def test_normalize_version_with_suffix() -> None: # noqa: D103 +def test_normalize_version_with_suffix() -> None: assert _normalize_version("2.2-SNAPSHOT") == "2.2.0" assert _normalize_version("1.20.dev0") == "1.20.0" assert _normalize_version("2.0.rc1") == "2.0.0" -def test_normalize_long_version() -> None: # noqa: D103 +def test_normalize_long_version() -> None: assert _normalize_version("1.20.3.4") == "1.20.3" assert _normalize_version("2.2.0.1.5") == "2.2.0" # Tests for FlinkVersionManager class -def test_version_property_with_flink_installed() -> None: # noqa: D103 +def test_version_property_with_flink_installed() -> None: with patch("importlib.metadata.version", return_value="1.20.3"): manager = FlinkVersionManager() assert manager.version == "1.20.3" -def test_version_property_without_flink_installed() -> None: # noqa: D103 +def test_version_property_without_flink_installed() -> None: with patch( "importlib.metadata.version", side_effect=Exception("Package not found") ): @@ -60,7 +60,7 @@ def test_version_property_without_flink_installed() -> None: # noqa: D103 assert manager.version is None -def test_major_version_property() -> None: # noqa: D103 +def test_major_version_property() -> None: with patch("importlib.metadata.version", return_value="1.20.3"): manager = FlinkVersionManager() assert manager.major_version == "1.20" @@ -70,13 +70,13 @@ def test_major_version_property() -> None: # noqa: D103 assert manager.major_version == "2.2" -def test_major_version_with_snapshot() -> None: # noqa: D103 +def test_major_version_with_snapshot() -> None: with patch("importlib.metadata.version", return_value="2.2.0-SNAPSHOT"): manager = FlinkVersionManager() assert manager.major_version == "2.2" -def test_major_version_without_flink() -> None: # noqa: D103 +def test_major_version_without_flink() -> None: with patch( "importlib.metadata.version", side_effect=Exception("Package not found") ): @@ -84,7 +84,7 @@ def test_major_version_without_flink() -> None: # noqa: D103 assert manager.major_version is None -def test_ge_method() -> None: # noqa: D103 +def test_ge_method() -> None: with patch("importlib.metadata.version", return_value="1.20.3"): manager = FlinkVersionManager() assert manager.ge("1.20.0") is True @@ -92,7 +92,7 @@ def test_ge_method() -> None: # noqa: D103 assert manager.ge("1.21.0") is False -def test_ge_with_two_part_version() -> None: # noqa: D103 +def test_ge_with_two_part_version() -> None: with patch("importlib.metadata.version", return_value="2.2"): manager = FlinkVersionManager() assert manager.ge("2.0.0") is True @@ -100,7 +100,7 @@ def test_ge_with_two_part_version() -> None: # noqa: D103 assert manager.ge("2.3") is False -def test_ge_without_flink_installed() -> None: # noqa: D103 +def test_ge_without_flink_installed() -> None: with patch( "importlib.metadata.version", side_effect=Exception("Package not found") ): @@ -108,7 +108,7 @@ def test_ge_without_flink_installed() -> None: # noqa: D103 assert manager.ge("1.20.0") is False -def test_lt_method() -> None: # noqa: D103 +def test_lt_method() -> None: with patch("importlib.metadata.version", return_value="1.20.3"): manager = FlinkVersionManager() assert manager.lt("1.21.0") is True @@ -116,7 +116,7 @@ def test_lt_method() -> None: # noqa: D103 assert manager.lt("1.20.0") is False -def test_lt_with_two_part_version() -> None: # noqa: D103 +def test_lt_with_two_part_version() -> None: with patch("importlib.metadata.version", return_value="2.2"): manager = FlinkVersionManager() assert manager.lt("2.3") is True @@ -124,7 +124,7 @@ def test_lt_with_two_part_version() -> None: # noqa: D103 assert manager.lt("2.0") is False -def test_lt_without_flink_installed() -> None: # noqa: D103 +def test_lt_without_flink_installed() -> None: with patch( "importlib.metadata.version", side_effect=Exception("Package not found") ): @@ -132,7 +132,7 @@ def test_lt_without_flink_installed() -> None: # noqa: D103 assert manager.lt("2.0.0") is False -def test_lazy_initialization() -> None: # noqa: D103 +def test_lazy_initialization() -> None: with patch("importlib.metadata.version", return_value="1.20.3") as mock_version: manager = FlinkVersionManager() # Version should not be fetched yet @@ -149,7 +149,7 @@ def test_lazy_initialization() -> None: # noqa: D103 mock_version.assert_called_once() # Still called only once -def test_version_comparison_with_snapshot_versions() -> None: # noqa: D103 +def test_version_comparison_with_snapshot_versions() -> None: with patch("importlib.metadata.version", return_value="2.2-SNAPSHOT"): manager = FlinkVersionManager() assert manager.ge("2.2.0") is True @@ -157,7 +157,7 @@ def test_version_comparison_with_snapshot_versions() -> None: # noqa: D103 assert manager.lt("2.3.0") is True -def test_version_comparison_edge_cases() -> None: # noqa: D103 +def test_version_comparison_edge_cases() -> None: # Test boundary versions with patch("importlib.metadata.version", return_value="1.20.3"): manager = FlinkVersionManager() diff --git a/python/flink_agents/api/tools/tool.py b/python/flink_agents/api/tools/tool.py index 585e64f5e..a2ee041a5 100644 --- a/python/flink_agents/api/tools/tool.py +++ b/python/flink_agents/api/tools/tool.py @@ -77,9 +77,7 @@ def __custom_deserialize(self) -> "ToolMetadata": args_schema = self["args_schema"] if isinstance(args_schema, dict): title = args_schema.get("title", "default") - self["args_schema"] = create_model_from_schema( - title, args_schema - ) + self["args_schema"] = create_model_from_schema(title, args_schema) return self def __eq__(self, other: "ToolMetadata") -> bool: diff --git a/python/flink_agents/api/tools/utils.py b/python/flink_agents/api/tools/utils.py index 51ea1bd94..f94fc818a 100644 --- a/python/flink_agents/api/tools/utils.py +++ b/python/flink_agents/api/tools/utils.py @@ -127,7 +127,7 @@ def resolve_field_type(field_schema: dict) -> type[typing.Any]: if type(None) in types: types.remove(type(None)) if len(types) == 1: - return typing.Optional[types[0]] # noqa: UP007 + return typing.Optional[types[0]] # noqa: UP007 return Optional[tuple(types)] # # noqa: UP007 else: return Union[tuple(types)] # noqa: UP007 @@ -178,7 +178,10 @@ def resolve_field_type(field_schema: dict) -> type[typing.Any]: return create_model(name, **main_fields, __doc__=schema.get("description", "")) -def create_model_from_java_tool_schema_str(name: str, schema_str: str) -> type[BaseModel]: + +def create_model_from_java_tool_schema_str( + name: str, schema_str: str +) -> type[BaseModel]: """Create Pydantic model from a java tool input schema.""" json_schema = json.loads(schema_str) properties = json_schema["properties"] @@ -192,6 +195,7 @@ def create_model_from_java_tool_schema_str(name: str, schema_str: str) -> type[B fields[param_name] = (type, FieldInfo(description=description)) return create_model(name, **fields) + def create_java_tool_schema_str_from_model(model: type[BaseModel]) -> str: """Create a java tool input schema string from a Pydantic model. diff --git a/python/flink_agents/api/vector_stores/java_vector_store.py b/python/flink_agents/api/vector_stores/java_vector_store.py index f3f7e8d0c..cf87285c5 100644 --- a/python/flink_agents/api/vector_stores/java_vector_store.py +++ b/python/flink_agents/api/vector_stores/java_vector_store.py @@ -26,10 +26,13 @@ class JavaVectorStore(BaseVectorStore): """Java-based implementation of VectorStore that wraps a Java vector store.""" - java_class_name: str="" + java_class_name: str = "" + @java_resource -class JavaCollectionManageableVectorStore(JavaVectorStore, CollectionManageableVectorStore): +class JavaCollectionManageableVectorStore( + JavaVectorStore, CollectionManageableVectorStore +): """Java-based implementation of VectorStore with collection management capabilities that bridges Python and Java vector store functionality. """ diff --git a/python/flink_agents/api/version_compatibility.py b/python/flink_agents/api/version_compatibility.py index 4060c4091..a3ab066fa 100644 --- a/python/flink_agents/api/version_compatibility.py +++ b/python/flink_agents/api/version_compatibility.py @@ -34,11 +34,11 @@ def _normalize_version(version_str: str) -> str: str: Normalized version string in format "major.minor.patch" """ # Remove any version suffix with hyphen (e.g., -SNAPSHOT, -dev) - base_version = version_str.split('-')[0] + base_version = version_str.split("-")[0] # Split by dot and keep only numeric parts parts = [] - for part in base_version.split('.'): + for part in base_version.split("."): # Only keep parts that are purely numeric if part.isdigit(): parts.append(part) @@ -48,9 +48,9 @@ def _normalize_version(version_str: str) -> str: # Ensure we have at least three parts (major.minor.patch) while len(parts) < 3: - parts.append('0') + parts.append("0") - return '.'.join(parts[:3]) + return ".".join(parts[:3]) class FlinkVersionManager: @@ -98,6 +98,7 @@ def _get_pyflink_version(self) -> str | None: """ try: from importlib.metadata import version as get_version + return get_version("apache-flink") except Exception: return None @@ -129,7 +130,7 @@ def major_version(self) -> str | None: # Extract major.minor from full version string # Examples: "2.2.0" -> "2.2", "1.20.3" -> "1.20", "2.2.0-SNAPSHOT" -> "2.2" - version_parts = self.version.split('-')[0].split('.') + version_parts = self.version.split("-")[0].split(".") if len(version_parts) >= 2: return f"{version_parts[0]}.{version_parts[1]}" return self.version diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py index 25f4555f3..3fd911afa 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py @@ -41,7 +41,7 @@ def open(self) -> None: """Do nothing.""" @property - def model_kwargs(self) -> Dict[str, Any]: # noqa: D102 + def model_kwargs(self) -> Dict[str, Any]: return {} @override @@ -68,7 +68,7 @@ class AsyncTestAgent(Agent): @chat_model_setup @staticmethod - def slow_chat_model() -> ResourceDescriptor: # noqa: D102 + def slow_chat_model() -> ResourceDescriptor: return ResourceDescriptor( clazz=f"{SlowMockChatModel.__module__}.{SlowMockChatModel.__name__}", connection="placement", @@ -84,7 +84,7 @@ def add(a: int, b: int) -> int: @action(InputEvent) @staticmethod - def process_input(event: InputEvent, ctx: RunnerContext) -> None: # noqa: D102 + def process_input(event: InputEvent, ctx: RunnerContext) -> None: input = event.input ctx.send_event( ChatRequestEvent( @@ -99,7 +99,7 @@ def process_input(event: InputEvent, ctx: RunnerContext) -> None: # noqa: D102 @action(ChatResponseEvent) @staticmethod - def process_chat_response(event: ChatResponseEvent, ctx: RunnerContext) -> None: # noqa: D102 + def process_chat_response(event: ChatResponseEvent, ctx: RunnerContext) -> None: input = event.response ctx.send_event(OutputEvent(output=input.content)) diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/chat_model_integration_agent.py b/python/flink_agents/e2e_tests/e2e_tests_integration/chat_model_integration_agent.py index f98795c5a..eee6dd60b 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/chat_model_integration_agent.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/chat_model_integration_agent.py @@ -42,7 +42,8 @@ class ChatModelTestAgent(Agent): def openai_connection() -> ResourceDescriptor: """ChatModelConnection responsible for openai model service connection.""" return ResourceDescriptor( - clazz=ResourceName.ChatModel.OPENAI_COMPLETIONS_CONNECTION, api_key=os.environ.get("OPENAI_API_KEY") + clazz=ResourceName.ChatModel.OPENAI_COMPLETIONS_CONNECTION, + api_key=os.environ.get("OPENAI_API_KEY"), ) @chat_model_connection diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/chat_model_integration_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/chat_model_integration_test.py index f53b593e9..bda497dba 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/chat_model_integration_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/chat_model_integration_test.py @@ -36,7 +36,9 @@ os.environ["OPENAI_CHAT_MODEL"] = OPENAI_MODEL AZURE_OPENAI_MODEL = os.environ.get("AZURE_OPENAI_CHAT_MODEL", "gpt-5") os.environ["AZURE_OPENAI_CHAT_MODEL"] = AZURE_OPENAI_MODEL -AZURE_OPENAI_API_VERSION= os.environ.get("AZURE_OPENAI_API_VERSION", "2025-04-01-preview") +AZURE_OPENAI_API_VERSION = os.environ.get( + "AZURE_OPENAI_API_VERSION", "2025-04-01-preview" +) os.environ["AZURE_OPENAI_API_VERSION"] = AZURE_OPENAI_API_VERSION DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY") @@ -76,7 +78,7 @@ ), ], ) -def test_chat_model_integration(model_provider: str) -> None: # noqa: D103 +def test_chat_model_integration(model_provider: str) -> None: os.environ["MODEL_PROVIDER"] = model_provider env = AgentsExecutionEnvironment.get_execution_environment() input_list = [] diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/e2e_tests_mcp/mcp_server_without_prompts.py b/python/flink_agents/e2e_tests/e2e_tests_integration/e2e_tests_mcp/mcp_server_without_prompts.py index 456a57153..acdeed507 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/e2e_tests_mcp/mcp_server_without_prompts.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/e2e_tests_mcp/mcp_server_without_prompts.py @@ -23,7 +23,8 @@ dotenv.load_dotenv() # Create MCP server -mcp = FastMCP("MathServer", port = 8001) +mcp = FastMCP("MathServer", port=8001) + @mcp.tool() async def add(a: int, b: int) -> int: diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/e2e_tests_mcp/mcp_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/e2e_tests_mcp/mcp_test.py index db2163feb..f18099026 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/e2e_tests_mcp/mcp_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/e2e_tests_mcp/mcp_test.py @@ -101,7 +101,9 @@ def math_chat_model() -> ResourceDescriptor: } # Only add prompt if using server with prompts if mcp_mode == "with_prompts": - descriptor_kwargs["prompt"] = "ask_sum" # MCP prompt registered from my_mcp_server + descriptor_kwargs["prompt"] = ( + "ask_sum" # MCP prompt registered from my_mcp_server + ) return ResourceDescriptor(**descriptor_kwargs) @action(InputEvent) @@ -154,7 +156,11 @@ def run_mcp_server(server_file: str) -> None: ("mcp_server_mode", "server_file", "server_endpoint"), [ ("with_prompts", "mcp_server.py", MCP_SERVER_ENDPOINT), - ("without_prompts", "mcp_server_without_prompts.py", MCP_SERVER_ENDPOINT_WITHOUT_PROMPTS), + ( + "without_prompts", + "mcp_server_without_prompts.py", + MCP_SERVER_ENDPOINT_WITHOUT_PROMPTS, + ), ], ) @pytest.mark.skipif( diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/execute_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/execute_test.py index fec68dade..d9c55b350 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/execute_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/execute_test.py @@ -255,4 +255,3 @@ def test_durable_execute_async_exception_flink(tmp_path: Path) -> None: f"{current_dir}/../resources/ground_truth/test_execute_async_exception.txt" ), ) - diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/execute_test_agent.py b/python/flink_agents/e2e_tests/e2e_tests_integration/execute_test_agent.py index c9a675d40..84570b765 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/execute_test_agent.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/execute_test_agent.py @@ -104,7 +104,9 @@ def process(event: Event, ctx: RunnerContext) -> None: input_data: ExecuteTestData = event.input # Use synchronous durable execute result = ctx.durable_execute(compute_value, input_data.value, 10) - ctx.send_event(OutputEvent(output=ExecuteTestOutput(id=input_data.id, result=result))) + ctx.send_event( + OutputEvent(output=ExecuteTestOutput(id=input_data.id, result=result)) + ) class ExecuteMultipleTestAgent(Agent): @@ -117,7 +119,9 @@ def process(event: Event, ctx: RunnerContext) -> None: input_data: ExecuteTestData = event.input result1 = ctx.durable_execute(compute_value, input_data.value, 5) result2 = ctx.durable_execute(multiply_value, result1, 2) - ctx.send_event(OutputEvent(output=ExecuteTestOutput(id=input_data.id, result=result2))) + ctx.send_event( + OutputEvent(output=ExecuteTestOutput(id=input_data.id, result=result2)) + ) class ExecuteWithAsyncTestAgent(Agent): @@ -155,4 +159,3 @@ async def process(event: Event, ctx: RunnerContext) -> None: output=ExecuteTestErrorOutput(id=input_data.id, error=str(exc)) ) ) - diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py b/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py index 989cc0ec9..3af227d23 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py @@ -50,7 +50,7 @@ class ItemData(BaseModel): memory_info: dict | None = None -class MyEvent(Event): # noqa D101 +class MyEvent(Event): value: Any diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/flink_intergration_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/flink_intergration_test.py index 2d5a0c55e..1be3b3800 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/flink_intergration_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/flink_intergration_test.py @@ -47,7 +47,7 @@ os.environ["PYTHONPATH"] = sysconfig.get_paths()["purelib"] -def test_from_datastream_to_datastream(tmp_path: Path) -> None: # noqa: D103 +def test_from_datastream_to_datastream(tmp_path: Path) -> None: config = Configuration() config.set_string("state.backend.type", "rocksdb") config.set_string("checkpointing.interval", "1s") @@ -99,7 +99,7 @@ def test_from_datastream_to_datastream(tmp_path: Path) -> None: # noqa: D103 ) -def test_from_table_to_table(tmp_path: Path) -> None: # noqa: D103 +def test_from_table_to_table(tmp_path: Path) -> None: env = StreamExecutionEnvironment.get_execution_environment() env.set_runtime_mode(RuntimeExecutionMode.STREAMING) @@ -174,7 +174,7 @@ def test_from_table_to_table(tmp_path: Path) -> None: # noqa: D103 ) -def test_from_datastream_to_table(tmp_path: Path) -> None: # noqa: D103 +def test_from_datastream_to_table(tmp_path: Path) -> None: env = StreamExecutionEnvironment.get_execution_environment() env.set_runtime_mode(RuntimeExecutionMode.STREAMING) diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/long_term_memory_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/long_term_memory_test.py index 3dd49cf59..87f257afb 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/long_term_memory_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/long_term_memory_test.py @@ -96,7 +96,7 @@ class ItemData(BaseModel): memory_info: dict | None = None -class Record(BaseModel): # noqa: D101 +class Record(BaseModel): id: int count: int timestamp_before_add: str @@ -105,7 +105,7 @@ class Record(BaseModel): # noqa: D101 items: List[MemorySetItem] | None = None -class MyEvent(Event): # noqa D101 +class MyEvent(Event): value: Any @@ -142,14 +142,14 @@ def ollama_qwen3() -> ResourceDescriptor: @embedding_model_connection @staticmethod - def ollama_embedding_connection() -> ResourceDescriptor: # noqa D102 + def ollama_embedding_connection() -> ResourceDescriptor: return ResourceDescriptor( clazz=ResourceName.EmbeddingModel.OLLAMA_CONNECTION, request_timeout=240.0 ) @embedding_model_setup @staticmethod - def ollama_nomic_embed_text() -> ResourceDescriptor: # noqa D102 + def ollama_nomic_embed_text() -> ResourceDescriptor: return ResourceDescriptor( clazz=ResourceName.EmbeddingModel.OLLAMA_SETUP, connection="ollama_embedding_connection", @@ -218,7 +218,7 @@ async def retrieve_items(event: Event, ctx: RunnerContext): # noqa D102 "flink-agent doesn't allow get resource in async thread. We will deprecate VectorStoreLongTermMemory in 0.3.0," "so we will not fix this issue for now." ) -def test_long_term_memory_async_execution_in_action(tmp_path: Path) -> None: # noqa: D103 +def test_long_term_memory_async_execution_in_action(tmp_path: Path) -> None: env = StreamExecutionEnvironment.get_execution_environment() env.set_runtime_mode(RuntimeExecutionMode.STREAMING) env.set_parallelism(1) @@ -271,7 +271,7 @@ def test_long_term_memory_async_execution_in_action(tmp_path: Path) -> None: # check_result(result_dir=result_dir) -def check_result(*, result_dir: Path) -> None: # noqa: D103 +def check_result(*, result_dir: Path) -> None: actual_result = [] for file in result_dir.iterdir(): if file.is_dir(): diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/python_event_logging_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/python_event_logging_test.py index 2db308042..f5dea5e47 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/python_event_logging_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/python_event_logging_test.py @@ -105,9 +105,7 @@ def test_python_event_logging(tmp_path: Path) -> None: log_files = list(event_log_dir.glob("events-*.log")) # At least one log file should exist - assert len(log_files) > 0, ( - f"Event log files should be created in {event_log_dir}" - ) + assert len(log_files) > 0, f"Event log files should be created in {event_log_dir}" # Check that log files contain structured event content record = None diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py index 5e4880dfe..85cd42bfd 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py @@ -54,13 +54,13 @@ os.environ["OLLAMA_CHAT_MODEL"] = OLLAMA_MODEL -class InputData(BaseModel): # noqa: D101 +class InputData(BaseModel): a: int b: int c: int -class OutputData(BaseModel): # noqa: D101 +class OutputData(BaseModel): result: int @@ -78,7 +78,7 @@ def get_key(self, value: Row) -> int: @pytest.mark.skipif( client is None, reason="Ollama client is not available or test model is missing" ) -def test_react_agent_on_local_runner() -> None: # noqa: D103 +def test_react_agent_on_local_runner() -> None: env = AgentsExecutionEnvironment.get_execution_environment() env.get_config().set( AgentExecutionOptions.ERROR_HANDLING_STRATEGY, ErrorHandlingStrategy.RETRY @@ -90,9 +90,11 @@ def test_react_agent_on_local_runner() -> None: # noqa: D103 env.add_resource( "ollama", ResourceType.CHAT_MODEL_CONNECTION, - ResourceDescriptor(clazz=ResourceName.ChatModel.OLLAMA_CONNECTION, request_timeout=240.0), + ResourceDescriptor( + clazz=ResourceName.ChatModel.OLLAMA_CONNECTION, request_timeout=240.0 + ), ) - .add_resource("add", ResourceType.TOOL, Tool.from_callable(add)) + .add_resource("add", ResourceType.TOOL, Tool.from_callable(add)) .add_resource("multiply", ResourceType.TOOL, Tool.from_callable(multiply)) ) @@ -136,7 +138,7 @@ def test_react_agent_on_local_runner() -> None: # noqa: D103 @pytest.mark.skipif( client is None, reason="Ollama client is not available or test model is missing" ) -def test_react_agent_on_remote_runner(tmp_path: Path) -> None: # noqa: D103 +def test_react_agent_on_remote_runner(tmp_path: Path) -> None: stream_env = StreamExecutionEnvironment.get_execution_environment() stream_env.set_parallelism(1) @@ -169,7 +171,9 @@ def test_react_agent_on_remote_runner(tmp_path: Path) -> None: # noqa: D103 env.add_resource( "ollama", ResourceType.CHAT_MODEL_CONNECTION, - ResourceDescriptor(clazz=ResourceName.ChatModel.OLLAMA_CONNECTION, request_timeout=240.0), + ResourceDescriptor( + clazz=ResourceName.ChatModel.OLLAMA_CONNECTION, request_timeout=240.0 + ), ) .add_resource("add", ResourceType.TOOL, Tool.from_callable(add)) .add_resource("multiply", ResourceType.TOOL, Tool.from_callable(multiply)) diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/workflow_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/workflow_test.py index 173155166..205f0891b 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/workflow_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/workflow_test.py @@ -33,12 +33,12 @@ current_dir = Path(__file__).parent -class ProcessedData(BaseModel): # noqa D101 +class ProcessedData(BaseModel): content: str visit_count: int -class MyEvent(Event): # noqa D101 +class MyEvent(Event): value: Any @@ -93,7 +93,7 @@ def second_action(event: Event, ctx: RunnerContext): # noqa D102 ctx.send_event(OutputEvent(output={key_with_count: final_content})) -def test_workflow() -> None: # noqa: D103 +def test_workflow() -> None: env = AgentsExecutionEnvironment.get_execution_environment() input_list = [] diff --git a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/chat_model_cross_language_test.py b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/chat_model_cross_language_test.py index 9cf16904d..42a5dd719 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/chat_model_cross_language_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/chat_model_cross_language_test.py @@ -47,8 +47,11 @@ os.environ["PYTHONPATH"] = sysconfig.get_paths()["purelib"] -@pytest.mark.skipif(client is None, reason="Ollama client is not available or test model is missing.") -def test_java_chat_model_integration(tmp_path: Path) -> None: # noqa: D103 + +@pytest.mark.skipif( + client is None, reason="Ollama client is not available or test model is missing." +) +def test_java_chat_model_integration(tmp_path: Path) -> None: env = StreamExecutionEnvironment.get_execution_environment() env.set_runtime_mode(RuntimeExecutionMode.STREAMING) env.set_parallelism(1) @@ -57,20 +60,19 @@ def test_java_chat_model_integration(tmp_path: Path) -> None: # noqa: D103 # we use continuous file source here. input_datastream = env.from_source( source=FileSource.for_record_stream_format( - StreamFormat.text_line_format(), f"file:///{current_dir}/../resources/java_chat_module_input" + StreamFormat.text_line_format(), + f"file:///{current_dir}/../resources/java_chat_module_input", ).build(), watermark_strategy=WatermarkStrategy.no_watermarks(), source_name="streaming_agent_example", ) - deserialize_datastream = input_datastream.map( - lambda x: str(x) - ) + deserialize_datastream = input_datastream.map(lambda x: str(x)) agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env) output_datastream = ( agents_env.from_datastream( - input=deserialize_datastream, key_selector= lambda x: "orderKey" + input=deserialize_datastream, key_selector=lambda x: "orderKey" ) .apply(ChatModelCrossLanguageAgent()) .to_datastream() @@ -79,13 +81,16 @@ def test_java_chat_model_integration(tmp_path: Path) -> None: # noqa: D103 result_dir = tmp_path / "results" result_dir.mkdir(parents=True, exist_ok=True) - (output_datastream.map(lambda x: str(x).replace('\n', '') - .replace('\r', ''), Types.STRING()).add_sink( - StreamingFileSink.for_row_format( - base_path=str(result_dir.absolute()), - encoder=Encoder.simple_string_encoder(), - ).build() - )) + ( + output_datastream.map( + lambda x: str(x).replace("\n", "").replace("\r", ""), Types.STRING() + ).add_sink( + StreamingFileSink.for_row_format( + base_path=str(result_dir.absolute()), + encoder=Encoder.simple_string_encoder(), + ).build() + ) + ) agents_env.execute() diff --git a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/embedding_model_cross_language_agent.py b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/embedding_model_cross_language_agent.py index 56a85ffcb..c8be4fb43 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/embedding_model_cross_language_agent.py +++ b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/embedding_model_cross_language_agent.py @@ -71,40 +71,52 @@ def process_input(event: InputEvent, ctx: RunnerContext) -> None: try: # Get embedding model - embeddingModel = ctx.get_resource("embedding_model", ResourceType.EMBEDDING_MODEL) + embeddingModel = ctx.get_resource( + "embedding_model", ResourceType.EMBEDDING_MODEL + ) # Test single text embedding embedding = embeddingModel.embed(input_text) print(f"[TEST] Generated embedding with dimension: {len(embedding)}") # Validate single embedding result - if embedding is None or not isinstance(embedding, list) or len(embedding) == 0: + if ( + embedding is None + or not isinstance(embedding, list) + or len(embedding) == 0 + ): err_msg = "Embedding cannot be null or empty" - raise AssertionError(err_msg) # noqa: TRY301 + raise AssertionError(err_msg) # noqa: TRY301 if not all(isinstance(x, float) for x in embedding): err_msg = "All embedding values must be floats" - raise AssertionError(err_msg) # noqa: TRY301 + raise AssertionError(err_msg) # noqa: TRY301 - print(f"[TEST] Validated single embedding: Text={short_doc}, Dimension={len(embedding)}, Text='{input_text[:30]}...'") + print( + f"[TEST] Validated single embedding: Text={short_doc}, Dimension={len(embedding)}, Text='{input_text[:30]}...'" + ) # Test batch embedding embeddings = embeddingModel.embed([input_text]) print(f"[TEST] Generated batch embeddings: count={len(embeddings)}") # Validate batch embedding results - if embeddings is None or not isinstance(embeddings, list) or len(embeddings) == 0: + if ( + embeddings is None + or not isinstance(embeddings, list) + or len(embeddings) == 0 + ): err_msg = "Batch embeddings cannot be null or empty" - raise AssertionError(err_msg) # noqa: TRY301 + raise AssertionError(err_msg) # noqa: TRY301 if len(embeddings) != 1: err_msg = f"Expected 1 embedding but got {len(embeddings)}" - raise AssertionError(err_msg) # noqa: TRY301 + raise AssertionError(err_msg) # noqa: TRY301 for i, emb in enumerate(embeddings): if not isinstance(emb, list) or len(emb) == 0: err_msg = f"Embedding at index {i} is invalid" - raise AssertionError(err_msg) # noqa: TRY301 + raise AssertionError(err_msg) # noqa: TRY301 print(f"[TEST] Validated batch embedding {i}: Dimension={len(emb)}") # Create test result as a single string @@ -112,7 +124,9 @@ def process_input(event: InputEvent, ctx: RunnerContext) -> None: ctx.send_event(OutputEvent(output=test_result)) - print(f"[TEST] Embedding generation test PASSED for: '{input_text[:50]}...'") + print( + f"[TEST] Embedding generation test PASSED for: '{input_text[:50]}...'" + ) except Exception as e: # Create error result as a single string diff --git a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/embedding_model_cross_language_test.py b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/embedding_model_cross_language_test.py index 6ba762ab1..a0afc56f2 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/embedding_model_cross_language_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/embedding_model_cross_language_test.py @@ -47,8 +47,11 @@ os.environ["PYTHONPATH"] = sysconfig.get_paths()["purelib"] -@pytest.mark.skipif(client is None, reason="Ollama client is not available or test model is missing.") -def test_java_embedding_model_integration(tmp_path: Path) -> None: # noqa: D103 + +@pytest.mark.skipif( + client is None, reason="Ollama client is not available or test model is missing." +) +def test_java_embedding_model_integration(tmp_path: Path) -> None: env = StreamExecutionEnvironment.get_execution_environment() env.set_runtime_mode(RuntimeExecutionMode.STREAMING) env.set_parallelism(1) @@ -57,20 +60,19 @@ def test_java_embedding_model_integration(tmp_path: Path) -> None: # noqa: D103 # we use continuous file source here. input_datastream = env.from_source( source=FileSource.for_record_stream_format( - StreamFormat.text_line_format(), f"file:///{current_dir}/../resources/java_chat_module_input" + StreamFormat.text_line_format(), + f"file:///{current_dir}/../resources/java_chat_module_input", ).build(), watermark_strategy=WatermarkStrategy.no_watermarks(), source_name="streaming_agent_example", ) - deserialize_datastream = input_datastream.map( - lambda x: str(x) - ) + deserialize_datastream = input_datastream.map(lambda x: str(x)) agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env) output_datastream = ( agents_env.from_datastream( - input=deserialize_datastream, key_selector= lambda x: "orderKey" + input=deserialize_datastream, key_selector=lambda x: "orderKey" ) .apply(EmbeddingModelCrossLanguageAgent()) .to_datastream() @@ -79,13 +81,16 @@ def test_java_embedding_model_integration(tmp_path: Path) -> None: # noqa: D103 result_dir = tmp_path / "results" result_dir.mkdir(parents=True, exist_ok=True) - (output_datastream.map(lambda x: str(x).replace('\n', '') - .replace('\r', ''), Types.STRING()).add_sink( - StreamingFileSink.for_row_format( - base_path=str(result_dir.absolute()), - encoder=Encoder.simple_string_encoder(), - ).build() - )) + ( + output_datastream.map( + lambda x: str(x).replace("\n", "").replace("\r", ""), Types.STRING() + ).add_sink( + StreamingFileSink.for_row_format( + base_path=str(result_dir.absolute()), + encoder=Encoder.simple_string_encoder(), + ).build() + ) + ) agents_env.execute() diff --git a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_agent.py b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_agent.py index cc41e5a1c..cb00dbddb 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_agent.py +++ b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_agent.py @@ -46,6 +46,7 @@ TEST_COLLECTION = "test_collection" MAX_RETRIES_TIMES = 10 + class VectorStoreCrossLanguageAgent(Agent): """Example agent demonstrating cross-language embedding model testing.""" @@ -115,7 +116,9 @@ def process_input(event: InputEvent, ctx: RunnerContext) -> None: vector_store = ctx.get_resource("vector_store", ResourceType.VECTOR_STORE) if isinstance(vector_store, CollectionManageableVectorStore): - vector_store.get_or_create_collection(TEST_COLLECTION , metadata={"key1": "value1", "key2": "value2"}) + vector_store.get_or_create_collection( + TEST_COLLECTION, metadata={"key1": "value1", "key2": "value2"} + ) collection = vector_store.get_collection(name=TEST_COLLECTION) @@ -168,18 +171,24 @@ def process_input(event: InputEvent, ctx: RunnerContext) -> None: doc = vector_store.get(ids="doc2") assert doc is not None assert doc[0].id == "doc2" - assert doc[0].content == "Why did the cat sit on the computer? Because it wanted to keep an eye on the mouse." + assert ( + doc[0].content + == "Why did the cat sit on the computer? Because it wanted to keep an eye on the mouse." + ) print("[TEST] Vector store Document Management PASSED") stm.set("is_initialized", True) - - ctx.send_event(ContextRetrievalRequestEvent(query=input_text, vector_store="vector_store")) + ctx.send_event( + ContextRetrievalRequestEvent(query=input_text, vector_store="vector_store") + ) @action(ContextRetrievalResponseEvent) @staticmethod - def contextRetrievalResponseEvent(event: ContextRetrievalResponseEvent, ctx: RunnerContext) -> None: + def contextRetrievalResponseEvent( + event: ContextRetrievalResponseEvent, ctx: RunnerContext + ) -> None: """User defined action for processing context retrieval response. In this action, we will test Vector store Context Retrieval. @@ -195,6 +204,8 @@ def contextRetrievalResponseEvent(event: ContextRetrievalResponseEvent, ctx: Run assert document.content is not None test_result = f"[PASS] retrieved_count={len(documents)}, first_doc_id={documents[0].id}, first_doc_preview={documents[0].content[:50]}" - print(f"[TEST] Vector store Context Retrieval PASSED, first_doc_id={documents[0].id}, first_doc_preview={documents[0].content[:50]}") + print( + f"[TEST] Vector store Context Retrieval PASSED, first_doc_id={documents[0].id}, first_doc_preview={documents[0].content[:50]}" + ) ctx.send_event(OutputEvent(output=test_result)) diff --git a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_test.py b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_test.py index ea0930ff8..36825a779 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_test.py @@ -49,9 +49,13 @@ os.environ["PYTHONPATH"] = sysconfig.get_paths()["purelib"] -@pytest.mark.skipif(client is None or ES_HOST is None, reason="Ollama client or Elasticsearch host is missing.") + +@pytest.mark.skipif( + client is None or ES_HOST is None, + reason="Ollama client or Elasticsearch host is missing.", +) @pytest.mark.parametrize("embedding_type", ["JAVA", "PYTHON"]) -def test_java_vector_store_integration(tmp_path: Path, embedding_type: str) -> None: # noqa: D103 +def test_java_vector_store_integration(tmp_path: Path, embedding_type: str) -> None: os.environ["EMBEDDING_TYPE"] = embedding_type env = StreamExecutionEnvironment.get_execution_environment() @@ -62,20 +66,19 @@ def test_java_vector_store_integration(tmp_path: Path, embedding_type: str) -> N # we use continuous file source here. input_datastream = env.from_source( source=FileSource.for_record_stream_format( - StreamFormat.text_line_format(), f"file:///{current_dir}/../resources/java_chat_module_input" + StreamFormat.text_line_format(), + f"file:///{current_dir}/../resources/java_chat_module_input", ).build(), watermark_strategy=WatermarkStrategy.no_watermarks(), source_name="streaming_agent_example", ) - deserialize_datastream = input_datastream.map( - lambda x: str(x) - ) + deserialize_datastream = input_datastream.map(lambda x: str(x)) agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env) output_datastream = ( agents_env.from_datastream( - input=deserialize_datastream, key_selector= lambda x: "orderKey" + input=deserialize_datastream, key_selector=lambda x: "orderKey" ) .apply(VectorStoreCrossLanguageAgent()) .to_datastream() @@ -84,13 +87,16 @@ def test_java_vector_store_integration(tmp_path: Path, embedding_type: str) -> N result_dir = tmp_path / "results" result_dir.mkdir(parents=True, exist_ok=True) - (output_datastream.map(lambda x: str(x).replace('\n', '') - .replace('\r', ''), Types.STRING()).add_sink( - StreamingFileSink.for_row_format( - base_path=str(result_dir.absolute()), - encoder=Encoder.simple_string_encoder(), - ).build() - )) + ( + output_datastream.map( + lambda x: str(x).replace("\n", "").replace("\r", ""), Types.STRING() + ).add_sink( + StreamingFileSink.for_row_format( + base_path=str(result_dir.absolute()), + encoder=Encoder.simple_string_encoder(), + ).build() + ) + ) agents_env.execute() diff --git a/python/flink_agents/examples/quickstart/react_agent_example.py b/python/flink_agents/examples/quickstart/react_agent_example.py index 34d159712..71d3ea760 100644 --- a/python/flink_agents/examples/quickstart/react_agent_example.py +++ b/python/flink_agents/examples/quickstart/react_agent_example.py @@ -62,9 +62,13 @@ def main() -> None: agents_env.add_resource( "ollama_server", ResourceType.CHAT_MODEL_CONNECTION, - ResourceDescriptor(clazz=ResourceName.ChatModel.OLLAMA_CONNECTION, request_timeout=120), + ResourceDescriptor( + clazz=ResourceName.ChatModel.OLLAMA_CONNECTION, request_timeout=120 + ), ).add_resource( - "notify_shipping_manager", ResourceType.TOOL, Tool.from_callable(notify_shipping_manager) + "notify_shipping_manager", + ResourceType.TOOL, + Tool.from_callable(notify_shipping_manager), ) # Read product reviews from a text file as a streaming source. diff --git a/python/flink_agents/examples/rag/knowledge_base_setup.py b/python/flink_agents/examples/rag/knowledge_base_setup.py index e574378a1..de3f6f056 100644 --- a/python/flink_agents/examples/rag/knowledge_base_setup.py +++ b/python/flink_agents/examples/rag/knowledge_base_setup.py @@ -72,10 +72,12 @@ def populate_knowledge_base() -> None: "documents": documents, "embeddings": embeddings, "metadatas": metadatas, - "ids": [f"doc{i + 1}" for i in range(len(documents))] + "ids": [f"doc{i + 1}" for i in range(len(documents))], } # Add documents to ChromaDB collection.add(**test_data) - print(f"Knowledge base setup complete! Added {len(documents)} documents to ChromaDB.") + print( + f"Knowledge base setup complete! Added {len(documents)} documents to ChromaDB." + ) diff --git a/python/flink_agents/examples/rag/rag_agent_example.py b/python/flink_agents/examples/rag/rag_agent_example.py index 8d9352a20..0e228732f 100644 --- a/python/flink_agents/examples/rag/rag_agent_example.py +++ b/python/flink_agents/examples/rag/rag_agent_example.py @@ -100,7 +100,7 @@ def chat_model() -> ResourceDescriptor: return ResourceDescriptor( clazz=ResourceName.ChatModel.OLLAMA_SETUP, connection="ollama_chat_connection", - model=OLLAMA_CHAT_MODEL + model=OLLAMA_CHAT_MODEL, ) @action(InputEvent) @@ -119,7 +119,7 @@ def process_input(event: InputEvent, ctx: RunnerContext) -> None: @action(ContextRetrievalResponseEvent) @staticmethod def process_retrieved_context( - event: ContextRetrievalResponseEvent, ctx: RunnerContext + event: ContextRetrievalResponseEvent, ctx: RunnerContext ) -> None: """Process retrieved context and create enhanced chat request.""" user_query = event.query @@ -131,10 +131,11 @@ def process_retrieved_context( ) # Get prompt resource and format it - prompt_resource = ctx.get_resource("context_enhanced_prompt", ResourceType.PROMPT) + prompt_resource = ctx.get_resource( + "context_enhanced_prompt", ResourceType.PROMPT + ) enhanced_prompt = prompt_resource.format_string( - context=context_text, - user_query=user_query + context=context_text, user_query=user_query ) # Send chat request with enhanced prompt @@ -174,8 +175,16 @@ def process_chat_response(event: ChatResponseEvent, ctx: RunnerContext) -> None: agents_env = AgentsExecutionEnvironment.get_execution_environment() # Setup Ollama embedding and chat model connections - agents_env.add_resource("ollama_embedding_connection", ResourceType.EMBEDDING_MODEL_CONNECTION, ResourceDescriptor(clazz=ResourceName.EmbeddingModel.OLLAMA_CONNECTION)) - agents_env.add_resource("ollama_chat_connection", ResourceType.EMBEDDING_MODEL, ResourceDescriptor(clazz=ResourceName.ChatModel.OLLAMA_CONNECTION)) + agents_env.add_resource( + "ollama_embedding_connection", + ResourceType.EMBEDDING_MODEL_CONNECTION, + ResourceDescriptor(clazz=ResourceName.EmbeddingModel.OLLAMA_CONNECTION), + ) + agents_env.add_resource( + "ollama_chat_connection", + ResourceType.EMBEDDING_MODEL, + ResourceDescriptor(clazz=ResourceName.ChatModel.OLLAMA_CONNECTION), + ) output_list = agents_env.from_list(input_list).apply(agent).to_list() diff --git a/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py b/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py index 95a77e5cf..e6b0acb2a 100644 --- a/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py +++ b/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py @@ -32,10 +32,8 @@ @pytest.mark.skipif(api_key is None, reason="TEST_API_KEY is not set") -def test_anthropic_chat_model() -> None: # noqa: D103 - connection = AnthropicChatModelConnection( - name="anthropic_server", api_key=api_key - ) +def test_anthropic_chat_model() -> None: + connection = AnthropicChatModelConnection(name="anthropic_server", api_key=api_key) def get_resource(name: str, type: ResourceType) -> Resource: if type == ResourceType.CHAT_MODEL_CONNECTION: @@ -44,7 +42,10 @@ def get_resource(name: str, type: ResourceType) -> Resource: return get_resource(name, ResourceType.TOOL) chat_model = AnthropicChatModelSetup( - name="anthropic", model=test_model, connection="anthropic_server", get_resource=get_resource + name="anthropic", + model=test_model, + connection="anthropic_server", + get_resource=get_resource, ) response = chat_model.chat([ChatMessage(role=MessageRole.USER, content="Hello!")]) assert response is not None @@ -70,10 +71,8 @@ def add(a: int, b: int) -> int: @pytest.mark.skipif(api_key is None, reason="TEST_API_KEY is not set") -def test_anthropic_chat_with_tools() -> None: # noqa : D103 - connection = AnthropicChatModelConnection( - name="anthropic_server", api_key=api_key - ) +def test_anthropic_chat_with_tools() -> None: + connection = AnthropicChatModelConnection(name="anthropic_server", api_key=api_key) def get_resource(name: str, type: ResourceType) -> Resource: if type == ResourceType.CHAT_MODEL_CONNECTION: diff --git a/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py b/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py index 74a8fe2ea..64da68900 100644 --- a/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py +++ b/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py @@ -58,7 +58,7 @@ class AzureOpenAIChatModelConnection(BaseChatModelConnection): ) azure_endpoint: str = Field( default=None, - description="Supported Azure OpenAI endpoints. Example: https://{your-resource-name}.openai.azure.com" + description="Supported Azure OpenAI endpoints. Example: https://{your-resource-name}.openai.azure.com", ) timeout: float = Field( default=60.0, @@ -72,14 +72,14 @@ class AzureOpenAIChatModelConnection(BaseChatModelConnection): ) def __init__( - self, - *, - api_key: str | None = None, - api_version: str | None = None, - azure_endpoint: str | None = None, - timeout: float = 60.0, - max_retries: int = 3, - **kwargs: Any, + self, + *, + api_key: str | None = None, + api_version: str | None = None, + azure_endpoint: str | None = None, + timeout: float = 60.0, + max_retries: int = 3, + **kwargs: Any, ) -> None: """Init method.""" super().__init__( @@ -106,7 +106,12 @@ def client(self) -> AzureOpenAI: ) return self._client - def chat(self, messages: Sequence[ChatMessage], tools: List[Tool] | None = None, **kwargs: Any,) -> ChatMessage: + def chat( + self, + messages: Sequence[ChatMessage], + tools: List[Tool] | None = None, + **kwargs: Any, + ) -> ChatMessage: """Direct communication with model service for chat conversation. Parameters @@ -194,26 +199,26 @@ class AzureOpenAIChatModelSetup(BaseChatModelSetup): model_of_azure_deployment: str | None = Field( default=None, description="The underlying model name of the Azure deployment (e.g., 'gpt-4', " - "'gpt-35-turbo'). Used for token counting and cost calculation. " - "Required for token metrics tracking.", + "'gpt-35-turbo'). Used for token counting and cost calculation. " + "Required for token metrics tracking.", ) temperature: float | None = Field( default=None, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output " - "more random, while lower values like 0.2 will make it more focused and deterministic. " - "Not supported by reasoning models (e.g. gpt-5, o-series).", + "more random, while lower values like 0.2 will make it more focused and deterministic. " + "Not supported by reasoning models (e.g. gpt-5, o-series).", ge=0.0, le=2.0, ) max_tokens: int | None = Field( default=None, description="The maximum number of tokens that can be generated in the chat completion. The total length of " - "input tokens and generated tokens is limited by the model's context length.", + "input tokens and generated tokens is limited by the model's context length.", gt=0, ) logprobs: bool | None = Field( description="Whether to return log probabilities of the output tokens or not. If true, returns the log " - "probabilities of each output token returned in the content of message.", + "probabilities of each output token returned in the content of message.", default=False, ) additional_kwargs: Dict[str, Any] = Field( @@ -221,15 +226,15 @@ class AzureOpenAIChatModelSetup(BaseChatModelSetup): ) def __init__( - self, - *, - model: str, - model_of_azure_deployment: str | None = None, - temperature: float | None = None, - max_tokens: int | None = None, - logprobs: bool | None = False, - additional_kwargs: Dict[str, Any] | None = None, - **kwargs: Any, + self, + *, + model: str, + model_of_azure_deployment: str | None = None, + temperature: float | None = None, + max_tokens: int | None = None, + logprobs: bool | None = False, + additional_kwargs: Dict[str, Any] | None = None, + **kwargs: Any, ) -> None: """Init method.""" additional_kwargs = additional_kwargs or {} diff --git a/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py b/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py index d8a1e9f00..58d4a5c0d 100644 --- a/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py +++ b/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py @@ -34,7 +34,7 @@ @pytest.mark.skipif(api_key is None, reason="AZURE_OPENAI_API_KEY is not set") -def test_azure_openai_chat_model() -> None: # noqa: D103 +def test_azure_openai_chat_model() -> None: connection = AzureOpenAIChatModelConnection( name="azure_openai", api_key=api_key, @@ -78,7 +78,7 @@ def add(a: int, b: int) -> int: @pytest.mark.skipif(api_key is None, reason="AZURE_OPENAI_API_KEY is not set") -def test_azure_openai_chat_with_tools() -> None: # noqa : D103 +def test_azure_openai_chat_with_tools() -> None: connection = AzureOpenAIChatModelConnection( name="azure_openai", api_key=api_key, @@ -100,7 +100,12 @@ def get_resource(name: str, type: ResourceType) -> Resource: get_resource=get_resource, ) response = chat_model.chat( - [ChatMessage(role=MessageRole.USER, content="You MUST use the add tool to calculate: What is 377 + 688?")] + [ + ChatMessage( + role=MessageRole.USER, + content="You MUST use the add tool to calculate: What is 377 + 688?", + ) + ] ) tool_calls = response.tool_calls assert len(tool_calls) == 1 diff --git a/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py b/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py index dfd67eaed..15249a9cd 100644 --- a/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py +++ b/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py @@ -33,7 +33,7 @@ @pytest.mark.skipif(api_key is None, reason="TEST_API_KEY is not set") -def test_openai_chat_model() -> None: # noqa: D103 +def test_openai_chat_model() -> None: connection = OpenAIChatModelConnection( name="openai", api_key=api_key, api_base_url=api_base_url ) @@ -71,7 +71,7 @@ def add(a: int, b: int) -> int: @pytest.mark.skipif(api_key is None, reason="TEST_API_KEY is not set") -def test_openai_chat_with_tools() -> None: # noqa : D103 +def test_openai_chat_with_tools() -> None: connection = OpenAIChatModelConnection( name="openai", api_key=api_key, api_base_url=api_base_url ) diff --git a/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py b/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py index 5fb080a23..34702d3a4 100644 --- a/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py +++ b/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py @@ -61,7 +61,7 @@ @pytest.mark.skipif( client is None, reason="Ollama client is not available or test model is missing" ) -def test_ollama_chat() -> None: # noqa :D103 +def test_ollama_chat() -> None: server = OllamaChatModelConnection(name="ollama", request_timeout=120.0) response = server.chat( [ChatMessage(role=MessageRole.USER, content="Hello!")], model=test_model @@ -88,14 +88,14 @@ def add(a: int, b: int) -> int: return a + b -def get_tool(name: str, type: ResourceType) -> FunctionTool: # noqa :D103 +def get_tool(name: str, type: ResourceType) -> FunctionTool: return from_callable(func=add) @pytest.mark.skipif( client is None, reason="Ollama client is not available or test model is missing" ) -def test_ollama_chat_with_tools() -> None: # noqa :D103 +def test_ollama_chat_with_tools() -> None: connection = OllamaChatModelConnection(name="ollama", request_timeout=120.0) def get_resource(name: str, type: ResourceType) -> Resource: diff --git a/python/flink_agents/integrations/embedding_models/local/ollama_embedding_model.py b/python/flink_agents/integrations/embedding_models/local/ollama_embedding_model.py index 174fd7fc4..29b62a00d 100644 --- a/python/flink_agents/integrations/embedding_models/local/ollama_embedding_model.py +++ b/python/flink_agents/integrations/embedding_models/local/ollama_embedding_model.py @@ -58,10 +58,10 @@ class OllamaEmbeddingModelConnection(BaseEmbeddingModelConnection): __client: Client = None def __init__( - self, - base_url: str = "http://localhost:11434", - request_timeout: float | None = DEFAULT_REQUEST_TIMEOUT, - **kwargs: Any, + self, + base_url: str = "http://localhost:11434", + request_timeout: float | None = DEFAULT_REQUEST_TIMEOUT, + **kwargs: Any, ) -> None: """Init method.""" super().__init__( @@ -77,7 +77,9 @@ def client(self) -> Client: self.__client = Client(host=self.base_url, timeout=self.request_timeout) return self.__client - def embed(self, text: str | Sequence[str], **kwargs: Any) -> list[float] | list[list[float]]: + def embed( + self, text: str | Sequence[str], **kwargs: Any + ) -> list[float] | list[list[float]]: """Generate embedding vector for a single text query.""" # Extract specific parameters model = kwargs.pop("model") @@ -117,6 +119,7 @@ class OllamaEmbeddingModelSetup(BaseEmbeddingModelSetup): Additional model parameters for the Ollama embeddings API, e.g. num_ctx, temperature, etc. """ + truncate: bool = Field( default=True, description="Controls what happens if input text exceeds model's maximum length (default: True).", @@ -124,7 +127,7 @@ class OllamaEmbeddingModelSetup(BaseEmbeddingModelSetup): keep_alive: float | str | None = Field( default="5m", description="Controls how long the model will stay loaded into memory following the " - "request(default: 5m)", + "request(default: 5m)", ) additional_kwargs: Dict[str, Any] = Field( default_factory=dict, @@ -132,14 +135,14 @@ class OllamaEmbeddingModelSetup(BaseEmbeddingModelSetup): ) def __init__( - self, - *, - connection: str, - model: str, - truncate: bool = True, - additional_kwargs: Dict[str, Any] | None = None, - keep_alive: float | str | None = None, - **kwargs: Any, + self, + *, + connection: str, + model: str, + truncate: bool = True, + additional_kwargs: Dict[str, Any] | None = None, + keep_alive: float | str | None = None, + **kwargs: Any, ) -> None: """Init method.""" if additional_kwargs is None: diff --git a/python/flink_agents/integrations/embedding_models/local/tests/test_ollama_embedding_model.py b/python/flink_agents/integrations/embedding_models/local/tests/test_ollama_embedding_model.py index 3dfc6e01d..e275a97fe 100644 --- a/python/flink_agents/integrations/embedding_models/local/tests/test_ollama_embedding_model.py +++ b/python/flink_agents/integrations/embedding_models/local/tests/test_ollama_embedding_model.py @@ -54,13 +54,13 @@ @pytest.mark.skipif( - client is None, reason="Ollama client is not available or test embedding model is missing" + client is None, + reason="Ollama client is not available or test embedding model is missing", ) def test_ollama_embedding_setup() -> None: """Test embedding functionality with OllamaEmbeddingModelSetup.""" connection = OllamaEmbeddingModelConnection( - name="ollama_embed", - base_url="http://localhost:11434" + name="ollama_embed", base_url="http://localhost:11434" ) def get_resource(name: str, type: ResourceType) -> Resource: @@ -71,7 +71,7 @@ def get_resource(name: str, type: ResourceType) -> Resource: connection="ollama_embed", model=test_model, truncate=True, - get_resource=get_resource + get_resource=get_resource, ) # Test embedding through setup diff --git a/python/flink_agents/integrations/embedding_models/openai_embedding_model.py b/python/flink_agents/integrations/embedding_models/openai_embedding_model.py index 234ad8d1c..eac4c8f10 100644 --- a/python/flink_agents/integrations/embedding_models/openai_embedding_model.py +++ b/python/flink_agents/integrations/embedding_models/openai_embedding_model.py @@ -111,7 +111,9 @@ def client(self) -> OpenAI: ) return self.__client - def embed(self, text: str | Sequence[str], **kwargs: Any) -> list[float] | list[list[float]]: + def embed( + self, text: str | Sequence[str], **kwargs: Any + ) -> list[float] | list[list[float]]: """Generate embedding vector for a single text query.""" # Extract OpenAI specific parameters model = kwargs.pop("model") @@ -123,7 +125,9 @@ def embed(self, text: str | Sequence[str], **kwargs: Any) -> list[float] | list[ response = self.client.embeddings.create( model=model, input=text, - encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN, + encoding_format=encoding_format + if encoding_format is not None + else NOT_GIVEN, dimensions=dimensions if dimensions is not None else NOT_GIVEN, user=user if user is not None else NOT_GIVEN, ) diff --git a/python/flink_agents/integrations/embedding_models/tests/test_openai_embedding_model.py b/python/flink_agents/integrations/embedding_models/tests/test_openai_embedding_model.py index 3f50736c6..c1a736f2e 100644 --- a/python/flink_agents/integrations/embedding_models/tests/test_openai_embedding_model.py +++ b/python/flink_agents/integrations/embedding_models/tests/test_openai_embedding_model.py @@ -30,10 +30,8 @@ @pytest.mark.skipif(api_key is None, reason="TEST_API_KEY is not set") -def test_openai_embedding_model() -> None: # noqa: D103 - connection = OpenAIEmbeddingModelConnection( - name="openai", api_key=api_key - ) +def test_openai_embedding_model() -> None: + connection = OpenAIEmbeddingModelConnection(name="openai", api_key=api_key) def get_resource(name: str, type: ResourceType) -> Resource: if type == ResourceType.EMBEDDING_MODEL_CONNECTION: diff --git a/python/flink_agents/integrations/embedding_models/tests/test_tongyi_embedding_model.py b/python/flink_agents/integrations/embedding_models/tests/test_tongyi_embedding_model.py index 59712efa9..54c595780 100644 --- a/python/flink_agents/integrations/embedding_models/tests/test_tongyi_embedding_model.py +++ b/python/flink_agents/integrations/embedding_models/tests/test_tongyi_embedding_model.py @@ -49,7 +49,9 @@ def get_resource(name: str, type: ResourceType) -> Resource: ) embedding_model.open() - response = embedding_model.embed("The quality of the clothes is excellent, very beautiful, worth the wait, I like it and will buy here again") + response = embedding_model.embed( + "The quality of the clothes is excellent, very beautiful, worth the wait, I like it and will buy here again" + ) assert response is not None assert isinstance(response, list) assert len(response) > 0 diff --git a/python/flink_agents/integrations/mcp/mcp.py b/python/flink_agents/integrations/mcp/mcp.py index 856be4a39..d45aab739 100644 --- a/python/flink_agents/integrations/mcp/mcp.py +++ b/python/flink_agents/integrations/mcp/mcp.py @@ -133,6 +133,7 @@ def close(self) -> None: finally: self.mcp_server = None + class MCPServer(Resource, ABC): """Resource representing an MCP server and exposing its tools/prompts. @@ -145,7 +146,7 @@ class MCPServer(Resource, ABC): endpoint: str headers: Dict[str, Any] | None = None timeout: int = 30 - sse_read_timeout: int = 60*5 + sse_read_timeout: int = 60 * 5 auth: httpx.Auth | None = None session: ClientSession = Field(default=None, exclude=True) @@ -187,20 +188,19 @@ async def _get_session(self) -> AsyncIterator[ClientSession]: msg = f"Invalid HTTP endpoint: {self.endpoint}" raise ValueError(msg) - async with streamablehttp_client( + async with ( + streamablehttp_client( url=self.endpoint, headers=self.headers, timeout=timedelta(seconds=self.timeout), sse_read_timeout=timedelta(seconds=self.sse_read_timeout), auth=self.auth, - ) as (read, write, _), ClientSession( - read, - write - ) as session: + ) as (read, write, _), + ClientSession(read, write) as session, + ): await session.initialize() yield session - async def _cleanup_connection(self) -> None: """Clean up connection resources.""" try: @@ -214,11 +214,12 @@ async def _cleanup_connection(self) -> None: async def call_tool_async(self, tool_name: str, *args: Any, **kwargs: Any) -> Any: """Call a tool on the MCP server asynchronously.""" async with self._get_session() as session: - arguments = kwargs if kwargs else (args[0] if args else {}) result = await session.call_tool( - tool_name, arguments, read_timeout_seconds=timedelta(seconds=self.timeout), + tool_name, + arguments, + read_timeout_seconds=timedelta(seconds=self.timeout), ) content = [extract_mcp_content_item(item) for item in result.content] diff --git a/python/flink_agents/integrations/mcp/tests/mcp_server.py b/python/flink_agents/integrations/mcp/tests/mcp_server.py index 2e295d0b7..5c1d90a81 100644 --- a/python/flink_agents/integrations/mcp/tests/mcp_server.py +++ b/python/flink_agents/integrations/mcp/tests/mcp_server.py @@ -18,6 +18,7 @@ try: import dotenv + dotenv.load_dotenv() except ImportError: # dotenv is optional for this test server @@ -34,6 +35,7 @@ def ask_sum(a: int, b: int) -> str: """Prompt of add tool.""" return f"Can you please calculate the sum of {a} and {b}?" + @mcp.tool() async def add(a: int, b: int) -> int: """Get the detailed information of a specified IP address. @@ -47,5 +49,5 @@ async def add(a: int, b: int) -> int: """ return a + b -mcp.run("streamable-http") +mcp.run("streamable-http") diff --git a/python/flink_agents/integrations/mcp/tests/test_mcp.py b/python/flink_agents/integrations/mcp/tests/test_mcp.py index 46ae03156..7a6e14877 100644 --- a/python/flink_agents/integrations/mcp/tests/test_mcp.py +++ b/python/flink_agents/integrations/mcp/tests/test_mcp.py @@ -29,11 +29,14 @@ from flink_agents.integrations.mcp.mcp import MCPServer -def run_server() -> None: # noqa : D103 +def run_server() -> None: runpy.run_path(f"{current_dir}/mcp_server.py") + current_dir = Path(__file__).parent -def test_mcp() -> None: # noqa : D103 + + +def test_mcp() -> None: process = multiprocessing.Process(target=run_server) process.start() time.sleep(5) @@ -44,10 +47,12 @@ def test_mcp() -> None: # noqa : D103 prompt = prompts[0] assert prompt.name == "ask_sum" message = prompt.format_messages(role=MessageRole.SYSTEM, a="1", b="2") - assert [ChatMessage( + assert [ + ChatMessage( role=MessageRole.USER, content="Can you please calculate the sum of 1 and 2?", - )] == message + ) + ] == message tools = mcp_server.list_tools() assert len(tools) == 1 tool = tools[0] @@ -56,11 +61,10 @@ def test_mcp() -> None: # noqa : D103 process.kill() - class InMemoryTokenStorage(TokenStorage): """Demo In-memory token storage implementation.""" - def __init__(self) -> None: # noqa:D107 + def __init__(self) -> None: self.tokens: OAuthToken | None = None self.client_info: OAuthClientInformationFull | None = None @@ -81,16 +85,17 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None self.client_info = client_info -async def handle_redirect(auth_url: str) -> None: # noqa:D103 +async def handle_redirect(auth_url: str) -> None: print(f"Visit: {auth_url}") -async def handle_callback() -> tuple[str, str | None]: # noqa:D103 +async def handle_callback() -> tuple[str, str | None]: callback_url = input("Paste callback URL: ") params = parse_qs(urlparse(callback_url).query) return params["code"][0], params.get("state", [None])[0] -def test_serialize_mcp_server() -> None: # noqa:D103 + +def test_serialize_mcp_server() -> None: oauth_auth = OAuthClientProvider( server_url="http://localhost:8001", client_metadata=OAuthClientMetadata( @@ -119,8 +124,3 @@ def test_serialize_mcp_server() -> None: # noqa:D103 deserialized.auth.context.client_metadata == mcp_server.auth.context.client_metadata ) - - - - - diff --git a/python/flink_agents/integrations/mcp/utils.py b/python/flink_agents/integrations/mcp/utils.py index ba96aa734..fb6cf8d4a 100644 --- a/python/flink_agents/integrations/mcp/utils.py +++ b/python/flink_agents/integrations/mcp/utils.py @@ -42,24 +42,28 @@ def extract_mcp_content_item(content_item: Any) -> Dict[str, Any] | str: return { "type": "image", "data": content_item.data, - "mimeType": content_item.mimeType + "mimeType": content_item.mimeType, } elif isinstance(content_item, types.EmbeddedResource): if isinstance(content_item.resource, types.TextResourceContents): return { "type": "resource", "uri": content_item.resource.uri, - "text": content_item.resource.text + "text": content_item.resource.text, } elif isinstance(content_item.resource, types.BlobResourceContents): return { "type": "resource", "uri": content_item.resource.uri, - "blob": content_item.resource.blob + "blob": content_item.resource.blob, } else: err_msg = f"Unsupported content type: {type(content_item)}" raise TypeError(err_msg) else: # Handle unknown content types as generic dict - return content_item.model_dump() if hasattr(content_item, 'model_dump') else str(content_item) + return ( + content_item.model_dump() + if hasattr(content_item, "model_dump") + else str(content_item) + ) diff --git a/python/flink_agents/integrations/vector_stores/chroma/tests/test_chroma_vector_store.py b/python/flink_agents/integrations/vector_stores/chroma/tests/test_chroma_vector_store.py index 438d02ea3..96fe4b649 100644 --- a/python/flink_agents/integrations/vector_stores/chroma/tests/test_chroma_vector_store.py +++ b/python/flink_agents/integrations/vector_stores/chroma/tests/test_chroma_vector_store.py @@ -43,19 +43,19 @@ database = os.environ.get("TEST_DATABASE") -class MockEmbeddingModel(BaseEmbeddingModelSetup): # noqa: D101 +class MockEmbeddingModel(BaseEmbeddingModelSetup): name: str connection: str = "mock" model: str = "mock" - def open(self) -> None: # noqa: D102 + def open(self) -> None: pass @property - def model_kwargs(self) -> Dict[str, Any]: # noqa: D102 + def model_kwargs(self) -> Dict[str, Any]: return {} - def embed(self, text: str, **kwargs: Any) -> list[float]: # noqa: D102 + def embed(self, text: str, **kwargs: Any) -> list[float]: if "ChromaDB" in text: return [0.2, 0.3, 0.4, 0.5, 0.6] else: diff --git a/python/flink_agents/plan/actions/action.py b/python/flink_agents/plan/actions/action.py index da40a3551..a059df9a2 100644 --- a/python/flink_agents/plan/actions/action.py +++ b/python/flink_agents/plan/actions/action.py @@ -27,6 +27,7 @@ _CONFIG_TYPE = "__config_type__" + class Action(BaseModel): """Representation of an agent action with event listening and function execution. @@ -42,6 +43,7 @@ class Action(BaseModel): listen_event_types : List[str] List of event types that will trigger this Action's execution. """ + model_config = ConfigDict(arbitrary_types_allowed=True) name: str diff --git a/python/flink_agents/plan/actions/chat_model_action.py b/python/flink_agents/plan/actions/chat_model_action.py index 0a0c780e6..a697f7237 100644 --- a/python/flink_agents/plan/actions/chat_model_action.py +++ b/python/flink_agents/plan/actions/chat_model_action.py @@ -122,10 +122,13 @@ def _accumulate_retry_stats( ) -> None: """Accumulate retry stats for a given initial request across tool call rounds.""" retry_stats_context = sensory_memory.get(_RETRY_STATS_CONTEXT) or {} - stats = retry_stats_context.get(initial_request_id, { - "total_retry_count": 0, - "total_retry_wait_sec": 0, - }) + stats = retry_stats_context.get( + initial_request_id, + { + "total_retry_count": 0, + "total_retry_wait_sec": 0, + }, + ) stats["total_retry_count"] += retry_count stats["total_retry_wait_sec"] += retry_wait_sec retry_stats_context[initial_request_id] = stats @@ -138,10 +141,13 @@ def _get_retry_stats( ) -> dict: """Get accumulated retry stats for a given initial request.""" retry_stats_context = sensory_memory.get(_RETRY_STATS_CONTEXT) or {} - return retry_stats_context.get(initial_request_id, { - "total_retry_count": 0, - "total_retry_wait_sec": 0, - }) + return retry_stats_context.get( + initial_request_id, + { + "total_retry_count": 0, + "total_retry_wait_sec": 0, + }, + ) def _record_retry_metrics( @@ -296,7 +302,10 @@ async def chat( if actual_retry_count > 0: _accumulate_retry_stats( - ctx.sensory_memory, initial_request_id, actual_retry_count, total_wait_time_sec + ctx.sensory_memory, + initial_request_id, + actual_retry_count, + total_wait_time_sec, ) if ( @@ -310,7 +319,9 @@ async def chat( total_retry_count = retry_stats["total_retry_count"] total_retry_wait_sec = retry_stats["total_retry_wait_sec"] - _record_retry_metrics(ctx, chat_model.connection, total_retry_count, total_retry_wait_sec) + _record_retry_metrics( + ctx, chat_model.connection, total_retry_count, total_retry_wait_sec + ) ctx.send_event( ChatResponseEvent( diff --git a/python/flink_agents/plan/actions/tool_call_action.py b/python/flink_agents/plan/actions/tool_call_action.py index 6fbefaee7..0ff1a4e15 100644 --- a/python/flink_agents/plan/actions/tool_call_action.py +++ b/python/flink_agents/plan/actions/tool_call_action.py @@ -26,6 +26,7 @@ _logger = logging.getLogger(__name__) + async def process_tool_request(event: ToolRequestEvent, ctx: RunnerContext) -> None: """Built-in action for processing tool call requests.""" tool_call_async = ctx.config.get(AgentExecutionOptions.TOOL_CALL_ASYNC) diff --git a/python/flink_agents/plan/agent_plan.py b/python/flink_agents/plan/agent_plan.py index 2f31ff3c2..b737cb0e3 100644 --- a/python/flink_agents/plan/agent_plan.py +++ b/python/flink_agents/plan/agent_plan.py @@ -250,7 +250,9 @@ def _get_actions(agent: Agent) -> List[Action]: return actions -def _get_resource_providers(agent: Agent, config: AgentConfiguration) -> List[ResourceProvider]: +def _get_resource_providers( + agent: Agent, config: AgentConfiguration +) -> List[ResourceProvider]: resource_providers = [] # retrieve resource declared by decorator for name, value in agent.__class__.__dict__.items(): @@ -340,7 +342,10 @@ def _get_resource_providers(agent: Agent, config: AgentConfiguration) -> List[Re def _add_mcp_server( - name: str, resource_providers: List[ResourceProvider], descriptor: ResourceDescriptor, config: AgentConfiguration + name: str, + resource_providers: List[ResourceProvider], + descriptor: ResourceDescriptor, + config: AgentConfiguration, ) -> None: provider = PythonResourceProvider.get(name=name, descriptor=descriptor) @@ -349,7 +354,9 @@ def _add_mcp_server( def get_resource(name: str, type: ResourceType) -> Any: """Placeholder - MCP server construction doesn't need resource resolution.""" - mcp_server = cast("MCPServer", provider.provide(get_resource=get_resource, config=config)) + mcp_server = cast( + "MCPServer", provider.provide(get_resource=get_resource, config=config) + ) resource_providers.extend( [ diff --git a/python/flink_agents/plan/configuration.py b/python/flink_agents/plan/configuration.py index 06ed9955a..8db17ebcd 100644 --- a/python/flink_agents/plan/configuration.py +++ b/python/flink_agents/plan/configuration.py @@ -28,7 +28,7 @@ ) -def flatten_dict(d: Dict, parent_key: str = '', sep: str = '.') -> Dict[str, Any]: +def flatten_dict(d: Dict, parent_key: str = "", sep: str = ".") -> Dict[str, Any]: """Flatten a nested dictionary into a single-level dictionary. This function recursively traverses the dictionary, converting multi-level @@ -55,6 +55,7 @@ def flatten_dict(d: Dict, parent_key: str = '', sep: str = '.') -> Dict[str, Any items[new_key] = v return items + class AgentConfiguration(BaseModel, Configuration): """Base class for config objects in the system. Provides a flat dict interface to access nested config values. @@ -65,11 +66,13 @@ class AgentConfiguration(BaseModel, Configuration): def __init__(self, conf_data: Dict[str, Any] | None = None) -> None: """Initialize with optional configuration data.""" if conf_data is None: - super().__init__(conf_data = {}) + super().__init__(conf_data={}) else: - super().__init__(conf_data = conf_data) + super().__init__(conf_data=conf_data) - def get_value_with_type(self, key: str, config_type: Type[Any], default: Any) -> Any: + def get_value_with_type( + self, key: str, config_type: Type[Any], default: Any + ) -> Any: """Helper method for all the get_xxx functions to avoid duplicate code. Args: @@ -92,24 +95,26 @@ def get_value_with_type(self, key: str, config_type: Type[Any], default: Any) -> raise ValueError(msg) from e @override - def get_int(self, key: str, default: int | None=None) -> int: + def get_int(self, key: str, default: int | None = None) -> int: return self.get_value_with_type(key, int, default) @override - def get_float(self, key: str, default: float | None=None) -> float: + def get_float(self, key: str, default: float | None = None) -> float: return self.get_value_with_type(key, float, default) @override - def get_bool(self, key: str, default: bool | None=None) -> bool: + def get_bool(self, key: str, default: bool | None = None) -> bool: return self.get_value_with_type(key, bool, default) @override - def get_str(self, key: str, default: str | None=None) -> str: + def get_str(self, key: str, default: str | None = None) -> str: return self.get_value_with_type(key, str, default) @override def get(self, option: ConfigOption) -> Any: - return self.get_value_with_type(option.get_key(), option.get_type(), option.get_default_value()) + return self.get_value_with_type( + option.get_key(), option.get_type(), option.get_default_value() + ) @override def set_str(self, key: str, value: str) -> None: @@ -141,7 +146,7 @@ def load_from_file(self, config_path: str | None = None) -> None: path = Path(config_path) with path.open() as f: raw_config = yaml.safe_load(f) - self.conf_data.update(flatten_dict(raw_config.get('agent', {}))) + self.conf_data.update(flatten_dict(raw_config.get("agent", {}))) def get_conf_data(self) -> dict: """Get the configuration data dictionary. diff --git a/python/flink_agents/plan/function.py b/python/flink_agents/plan/function.py index bb5fa5a12..d8c10a4ba 100644 --- a/python/flink_agents/plan/function.py +++ b/python/flink_agents/plan/function.py @@ -360,7 +360,10 @@ def call_python_awaitable(awaitable: Any) -> Tuple[bool, Any]: return True, e.value if hasattr(e, "value") else None except RuntimeError as e: err_msg = str(e) - if "no running event loop" in err_msg or "await wasn't used with future" in err_msg: + if ( + "no running event loop" in err_msg + or "await wasn't used with future" in err_msg + ): raise RuntimeError(_ASYNCIO_ERROR_MESSAGE) from e raise except Exception: diff --git a/python/flink_agents/plan/resource_provider.py b/python/flink_agents/plan/resource_provider.py index 714a01c26..d8e00e684 100644 --- a/python/flink_agents/plan/resource_provider.py +++ b/python/flink_agents/plan/resource_provider.py @@ -97,11 +97,10 @@ def get(name: str, descriptor: ResourceDescriptor) -> "PythonResourceProvider": """Create PythonResourceProvider instance.""" clazz = descriptor.clazz return PythonResourceProvider( - name=name, - type=clazz.resource_type(), - descriptor=descriptor, - ) - + name=name, + type=clazz.resource_type(), + descriptor=descriptor, + ) def provide(self, get_resource: Callable, config: AgentConfiguration) -> Resource: """Create resource in runtime.""" @@ -154,6 +153,7 @@ def provide(self, get_resource: Callable, config: AgentConfiguration) -> Resourc self.resource = clazz.model_validate(self.serialized) return self.resource + JAVA_RESOURCE_MAPPING: dict[ResourceType, str] = { ResourceType.CHAT_MODEL: "flink_agents.runtime.java.java_chat_model.JavaChatModelSetupImpl", ResourceType.CHAT_MODEL_CONNECTION: "flink_agents.runtime.java.java_chat_model.JavaChatModelConnectionImpl", @@ -162,6 +162,7 @@ def provide(self, get_resource: Callable, config: AgentConfiguration) -> Resourc ResourceType.VECTOR_STORE: "flink_agents.runtime.java.java_vector_store.JavaVectorStoreImpl", } + class JavaResourceProvider(ResourceProvider): """Represent Resource Provider declared by Java. @@ -179,7 +180,7 @@ def get(name: str, descriptor: ResourceDescriptor) -> "JavaResourceProvider": kwargs.update(descriptor.arguments) clazz = descriptor.arguments.get("java_clazz", "") - if len(clazz) <1: + if len(clazz) < 1: err_msg = f"java_clazz are not set for {wrapper_clazz.__name__}" raise KeyError(err_msg) @@ -204,8 +205,12 @@ def provide(self, get_resource: Callable, config: AgentConfiguration) -> Resourc cls = get_resource_class(module_path, class_name) kwargs = self.descriptor.arguments - return cls(**kwargs, get_resource=get_resource, j_resource=j_resource, j_resource_adapter= self._j_resource_adapter) - + return cls( + **kwargs, + get_resource=get_resource, + j_resource=j_resource, + j_resource_adapter=self._j_resource_adapter, + ) def set_java_resource_adapter(self, j_resource_adapter: Any) -> None: """Set java resource adapter for java resource initialization.""" diff --git a/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py b/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py index a60143233..922a98b3a 100644 --- a/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py +++ b/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py @@ -116,7 +116,9 @@ def _create_mock_runner_context( id(AgentExecutionOptions.CHAT_ASYNC): False, } config.get = MagicMock( - side_effect=lambda option: option_values.get(id(option), option.get_default_value()) + side_effect=lambda option: option_values.get( + id(option), option.get_default_value() + ) ) ctx = MagicMock() @@ -125,7 +127,9 @@ def _create_mock_runner_context( ctx.action_metric_group = metric_group ctx.send_event = MagicMock(side_effect=lambda e: sent_events.append(e)) ctx.get_resource = MagicMock(return_value=chat_model) - ctx.durable_execute = MagicMock(side_effect=lambda fn, *args, **kwargs: fn(*args, **kwargs)) + ctx.durable_execute = MagicMock( + side_effect=lambda fn, *args, **kwargs: fn(*args, **kwargs) + ) return ctx, sent_events, metric_group, sensory_memory @@ -149,7 +153,13 @@ def test_chat_succeeds_without_retry(self) -> None: request_id = uuid4() asyncio.run( - chat(request_id, chat_model.connection, [ChatMessage(role=MessageRole.USER, content="hi")], None, ctx) + chat( + request_id, + chat_model.connection, + [ChatMessage(role=MessageRole.USER, content="hi")], + None, + ctx, + ) ) assert len(sent_events) == 1 @@ -183,7 +193,13 @@ def mock_chat(messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: start = time.monotonic() asyncio.run( - chat(request_id, "test-model", [ChatMessage(role=MessageRole.USER, content="hi")], None, ctx) + chat( + request_id, + "test-model", + [ChatMessage(role=MessageRole.USER, content="hi")], + None, + ctx, + ) ) elapsed = time.monotonic() - start @@ -212,7 +228,13 @@ def test_chat_exhausts_retries_and_raises(self) -> None: with pytest.raises(RuntimeError, match="persistent error"): asyncio.run( - chat(request_id, "test-model", [ChatMessage(role=MessageRole.USER, content="hi")], None, ctx) + chat( + request_id, + "test-model", + [ChatMessage(role=MessageRole.USER, content="hi")], + None, + ctx, + ) ) assert len(sent_events) == 0 diff --git a/python/flink_agents/plan/tests/compatibility/generate_agent_plan_json.py b/python/flink_agents/plan/tests/compatibility/generate_agent_plan_json.py index 24ad9cf8a..f1dc4e9c8 100644 --- a/python/flink_agents/plan/tests/compatibility/generate_agent_plan_json.py +++ b/python/flink_agents/plan/tests/compatibility/generate_agent_plan_json.py @@ -30,7 +30,9 @@ # correspond modification should be applied to it when modify this file. if __name__ == "__main__": json_path = sys.argv[1] - agent_plan = AgentPlan.from_agent(PythonAgentPlanCompatibilityTestAgent(), AgentConfiguration()) + agent_plan = AgentPlan.from_agent( + PythonAgentPlanCompatibilityTestAgent(), AgentConfiguration() + ) json_value = agent_plan.model_dump_json(serialize_as_any=True, indent=4) with Path(json_path).open("w") as f: f.write(json_value) diff --git a/python/flink_agents/plan/tests/compatibility/python_agent_plan_compatibility_test_agent.py b/python/flink_agents/plan/tests/compatibility/python_agent_plan_compatibility_test_agent.py index ee0075681..80437128d 100644 --- a/python/flink_agents/plan/tests/compatibility/python_agent_plan_compatibility_test_agent.py +++ b/python/flink_agents/plan/tests/compatibility/python_agent_plan_compatibility_test_agent.py @@ -60,7 +60,10 @@ def second_action(event: InputEvent, ctx: RunnerContext) -> None: def chat_model() -> ResourceDescriptor: """ChatModel can be used in action.""" return ResourceDescriptor( - clazz=f"{MockChatModel.__module__}.{MockChatModel.__name__}", name="chat_model", prompt="prompt", tools=["add"] + clazz=f"{MockChatModel.__module__}.{MockChatModel.__name__}", + name="chat_model", + prompt="prompt", + tools=["add"], ) @tool diff --git a/python/flink_agents/plan/tests/test_action.py b/python/flink_agents/plan/tests/test_action.py index 4410dd43d..9e0b9c61f 100644 --- a/python/flink_agents/plan/tests/test_action.py +++ b/python/flink_agents/plan/tests/test_action.py @@ -28,15 +28,15 @@ from flink_agents.plan.function import PythonFunction -def legal_signature(event: InputEvent, ctx: RunnerContext) -> None: # noqa: D103 +def legal_signature(event: InputEvent, ctx: RunnerContext) -> None: pass -def illegal_signature(value: int, ctx: RunnerContext) -> None: # noqa: D103 +def illegal_signature(value: int, ctx: RunnerContext) -> None: pass -def test_action_signature_legal() -> None: # noqa: D103 +def test_action_signature_legal() -> None: Action( name="legal", exec=PythonFunction.from_callable(legal_signature), @@ -44,7 +44,7 @@ def test_action_signature_legal() -> None: # noqa: D103 ) -def test_action_signature_illegal() -> None: # noqa: D103 +def test_action_signature_illegal() -> None: with pytest.raises(TypeError): Action( name="illegal", @@ -54,7 +54,7 @@ def test_action_signature_illegal() -> None: # noqa: D103 @pytest.fixture(scope="module") -def action() -> Action: # noqa: D103 +def action() -> Action: func = PythonFunction.from_callable(legal_signature) return Action( name="legal", @@ -74,7 +74,7 @@ def action() -> Action: # noqa: D103 current_dir = Path(__file__).parent -def test_action_serialize(action: Action) -> None: # noqa: D103 +def test_action_serialize(action: Action) -> None: json_value = action.model_dump_json(serialize_as_any=True, indent=4) with Path.open(Path(f"{current_dir}/resources/action.json")) as f: expected_json = f.read() @@ -83,7 +83,7 @@ def test_action_serialize(action: Action) -> None: # noqa: D103 assert actual == expected -def test_action_deserialize(action: Action) -> None: # noqa: D103 +def test_action_deserialize(action: Action) -> None: with Path.open(Path(f"{current_dir}/resources/action.json")) as f: expected_json = f.read() action = Action.model_validate_json(expected_json) diff --git a/python/flink_agents/plan/tests/test_agent_plan.py b/python/flink_agents/plan/tests/test_agent_plan.py index 583940cb9..aed52caa2 100644 --- a/python/flink_agents/plan/tests/test_agent_plan.py +++ b/python/flink_agents/plan/tests/test_agent_plan.py @@ -48,16 +48,16 @@ from flink_agents.runtime.resource_cache import ResourceCache -class AgentForTest(Agent): # noqa D101 +class AgentForTest(Agent): @action(InputEvent) @staticmethod - def increment(event: Event, ctx: RunnerContext) -> None: # noqa D102 + def increment(event: Event, ctx: RunnerContext) -> None: value = event.input value += 1 ctx.send_event(OutputEvent(output=value)) -def test_from_agent(): # noqa D102 +def test_from_agent(): agent = AgentForTest() agent_plan = AgentPlan.from_agent(agent, AgentConfiguration()) event_type = f"{InputEvent.__module__}.{InputEvent.__name__}" @@ -72,14 +72,14 @@ def test_from_agent(): # noqa D102 assert action.listen_event_types == [event_type] -class InvalidAgent(Agent): # noqa D101 +class InvalidAgent(Agent): @action(InputEvent) @staticmethod - def invalid_signature_action(event: Event) -> None: # noqa D102 + def invalid_signature_action(event: Event) -> None: pass -def test_to_agent_invalid_signature() -> None: # noqa D103 +def test_to_agent_invalid_signature() -> None: agent = InvalidAgent() with pytest.raises(TypeError): AgentPlan.from_agent(agent, AgentConfiguration()) @@ -89,7 +89,7 @@ class MyEvent(Event): """Event for testing purposes.""" -class MockChatModelImpl(BaseChatModelSetup): # noqa: D101 +class MockChatModelImpl(BaseChatModelSetup): host: str desc: str @@ -97,11 +97,11 @@ def open(self) -> None: """Do nothing.""" @property - def model_kwargs(self) -> Dict[str, Any]: # noqa: D102 + def model_kwargs(self) -> Dict[str, Any]: return {} @classmethod - def resource_type(cls) -> ResourceType: # noqa: D102 + def resource_type(cls) -> ResourceType: return ResourceType.CHAT_MODEL def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: @@ -111,7 +111,7 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: ) -class MockEmbeddingModelConnection(BaseEmbeddingModelConnection): # noqa: D101 +class MockEmbeddingModelConnection(BaseEmbeddingModelConnection): api_key: str def embed(self, text: str | Sequence[str], **kwargs: Any) -> list[float]: @@ -121,19 +121,19 @@ def embed(self, text: str | Sequence[str], **kwargs: Any) -> list[float]: return [[0.1234, -0.5678, 0.9012, -0.3456, 0.7890]] -class MockEmbeddingModelSetup(BaseEmbeddingModelSetup): # noqa: D101 +class MockEmbeddingModelSetup(BaseEmbeddingModelSetup): @property - def model_kwargs(self) -> Dict[str, Any]: # noqa: D102 + def model_kwargs(self) -> Dict[str, Any]: return {"model": self.model} -class MockVectorStore(BaseVectorStore): # noqa: D101 +class MockVectorStore(BaseVectorStore): host: str port: int collection_name: str @property - def store_kwargs(self) -> Dict[str, Any]: # noqa: D102 + def store_kwargs(self) -> Dict[str, Any]: return {"collection_name": self.collection_name} def size(self, collection_name: str | None = None) -> int: @@ -182,10 +182,10 @@ def _query_embedding( ][:limit] -class MyAgent(Agent): # noqa: D101 +class MyAgent(Agent): @chat_model_setup @staticmethod - def mock() -> ResourceDescriptor: # noqa: D102 + def mock() -> ResourceDescriptor: return ResourceDescriptor( clazz=f"{MockChatModelImpl.__module__}.{MockChatModelImpl.__name__}", host="8.8.8.8", @@ -195,14 +195,15 @@ def mock() -> ResourceDescriptor: # noqa: D102 @embedding_model_connection @staticmethod - def mock_embedding_conn() -> ResourceDescriptor: # noqa: D102 + def mock_embedding_conn() -> ResourceDescriptor: return ResourceDescriptor( - clazz=f"{MockEmbeddingModelConnection.__module__}.{MockEmbeddingModelConnection.__name__}", api_key="mock-api-key" + clazz=f"{MockEmbeddingModelConnection.__module__}.{MockEmbeddingModelConnection.__name__}", + api_key="mock-api-key", ) @embedding_model_setup @staticmethod - def mock_embedding() -> ResourceDescriptor: # noqa: D102 + def mock_embedding() -> ResourceDescriptor: return ResourceDescriptor( clazz=f"{MockEmbeddingModelSetup.__module__}.{MockEmbeddingModelSetup.__name__}", model="test-model", @@ -211,7 +212,7 @@ def mock_embedding() -> ResourceDescriptor: # noqa: D102 @vector_store @staticmethod - def mock_vector_store() -> ResourceDescriptor: # noqa: D102 + def mock_vector_store() -> ResourceDescriptor: return ResourceDescriptor( clazz=f"{MockVectorStore.__module__}.{MockVectorStore.__name__}", embedding_model="mock_embedding", @@ -222,17 +223,17 @@ def mock_vector_store() -> ResourceDescriptor: # noqa: D102 @action(InputEvent) @staticmethod - def first_action(event: InputEvent, ctx: RunnerContext) -> None: # noqa: D102 + def first_action(event: InputEvent, ctx: RunnerContext) -> None: pass @action(InputEvent, MyEvent) @staticmethod - def second_action(event: InputEvent, ctx: RunnerContext) -> None: # noqa: D102 + def second_action(event: InputEvent, ctx: RunnerContext) -> None: pass @pytest.fixture(scope="module") -def agent_plan() -> AgentPlan: # noqa: D103 +def agent_plan() -> AgentPlan: return AgentPlan.from_agent( MyAgent(), AgentConfiguration({"mock.key": "mock.value"}) ) @@ -241,7 +242,7 @@ def agent_plan() -> AgentPlan: # noqa: D103 current_dir = Path(__file__).parent -def test_agent_plan_serialize(agent_plan: AgentPlan) -> None: # noqa: D103 +def test_agent_plan_serialize(agent_plan: AgentPlan) -> None: json_value = agent_plan.model_dump_json(serialize_as_any=True, indent=4) with Path.open(Path(f"{current_dir}/resources/agent_plan.json")) as f: expected_json = f.read() @@ -250,14 +251,14 @@ def test_agent_plan_serialize(agent_plan: AgentPlan) -> None: # noqa: D103 assert actual == expected -def test_agent_plan_deserialize(agent_plan: AgentPlan) -> None: # noqa: D103 +def test_agent_plan_deserialize(agent_plan: AgentPlan) -> None: with Path.open(Path(f"{current_dir}/resources/agent_plan.json")) as f: expected_json = f.read() deserialized_agent_plan = AgentPlan.model_validate_json(expected_json) assert deserialized_agent_plan == agent_plan -def test_get_resource() -> None: # noqa: D103 +def test_get_resource() -> None: agent_plan = AgentPlan.from_agent(MyAgent(), AgentConfiguration()) cache = ResourceCache(agent_plan.resource_providers, agent_plan.config) mock = cache.get_resource("mock", ResourceType.CHAT_MODEL) @@ -267,7 +268,7 @@ def test_get_resource() -> None: # noqa: D103 ) -def test_add_action_and_resource_to_agent() -> None: # noqa: D103 +def test_add_action_and_resource_to_agent() -> None: my_agent = Agent() my_agent.add_action( name="first_action", events=[InputEvent], func=MyAgent.first_action @@ -290,7 +291,8 @@ def test_add_action_and_resource_to_agent() -> None: # noqa: D103 name="mock_embedding_conn", resource_type=ResourceType.EMBEDDING_MODEL_CONNECTION, instance=ResourceDescriptor( - clazz=f"{MockEmbeddingModelConnection.__module__}.{MockEmbeddingModelConnection.__name__}", api_key="mock-api-key" + clazz=f"{MockEmbeddingModelConnection.__module__}.{MockEmbeddingModelConnection.__name__}", + api_key="mock-api-key", ), ) my_agent.add_resource( diff --git a/python/flink_agents/plan/tests/test_configuration.py b/python/flink_agents/plan/tests/test_configuration.py index 7a874edfd..91996783c 100644 --- a/python/flink_agents/plan/tests/test_configuration.py +++ b/python/flink_agents/plan/tests/test_configuration.py @@ -40,7 +40,7 @@ def test_load_configuration_from_file() -> None: } } - with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: yaml.dump(test_data, f) config_file = f.name @@ -50,13 +50,13 @@ def test_load_configuration_from_file() -> None: config.load_from_file(config_file) # Test that nested configuration is properly flattened - assert config.get_str('database.host') == 'localhost' - assert config.get_int('database.port') == 5432 - assert config.get_str('database.credentials.username') == 'admin' - assert config.get_str('database.credentials.password') == 'secret' - assert config.get_str('api.endpoint') == '/api/v1' - assert config.get_float('api.timeout') == 30.0 - assert config.get_bool('debug') is True + assert config.get_str("database.host") == "localhost" + assert config.get_int("database.port") == 5432 + assert config.get_str("database.credentials.username") == "admin" + assert config.get_str("database.credentials.password") == "secret" + assert config.get_str("api.endpoint") == "/api/v1" + assert config.get_float("api.timeout") == 30.0 + assert config.get_bool("debug") is True finally: config_file = Path(config_file) config_file.unlink() @@ -76,13 +76,13 @@ def test_load_configuration_with_invalid_file() -> None: """Test loading configuration with a non-existent file.""" config = AgentConfiguration() with pytest.raises(FileNotFoundError): - config.load_from_file('/path/to/nonexistent/file.yaml') + config.load_from_file("/path/to/nonexistent/file.yaml") def test_load_configuration_with_invalid_yaml() -> None: """Test loading configuration with invalid YAML content.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: - f.write('invalid: yaml: content: [') + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write("invalid: yaml: content: [") config_file = f.name try: @@ -96,91 +96,103 @@ def test_load_configuration_with_invalid_yaml() -> None: def test_get_int() -> None: """Test get_int method with various inputs.""" - config = AgentConfiguration({'int_key': 42, 'str_key': '123', 'invalid_key': 'not_an_int'}) + config = AgentConfiguration( + {"int_key": 42, "str_key": "123", "invalid_key": "not_an_int"} + ) # Test normal integer value - assert config.get_int('int_key') == 42 + assert config.get_int("int_key") == 42 # Test string that can be converted to int - assert config.get_int('str_key') == 123 + assert config.get_int("str_key") == 123 # Test default value when key is not found - assert config.get_int('missing_key', 999) == 999 + assert config.get_int("missing_key", 999) == 999 # Test default value when no default specified - assert config.get_int('missing_key') is None + assert config.get_int("missing_key") is None # Test invalid value that cannot be converted to int with pytest.raises(ValueError, match="Invalid value for invalid_key: not_an_int"): - config.get_int('invalid_key') + config.get_int("invalid_key") def test_get_float() -> None: """Test get_float method with various inputs.""" - config = AgentConfiguration({'float_key': 3.14, 'int_key': 42, 'str_key': '2.5', 'invalid_key': 'not_a_float'}) + config = AgentConfiguration( + { + "float_key": 3.14, + "int_key": 42, + "str_key": "2.5", + "invalid_key": "not_a_float", + } + ) # Test normal float value - assert config.get_float('float_key') == 3.14 + assert config.get_float("float_key") == 3.14 # Test int value converted to float - assert config.get_float('int_key') == 42.0 + assert config.get_float("int_key") == 42.0 # Test string that can be converted to float - assert config.get_float('str_key') == 2.5 + assert config.get_float("str_key") == 2.5 # Test default value when key is not found - assert config.get_float('missing_key', 1.23) == 1.23 + assert config.get_float("missing_key", 1.23) == 1.23 # Test default value when no default specified - assert config.get_float('missing_key') is None + assert config.get_float("missing_key") is None # Test invalid value that cannot be converted to float with pytest.raises(ValueError, match="Invalid value for invalid_key: not_a_float"): - config.get_float('invalid_key') + config.get_float("invalid_key") def test_get_bool() -> None: """Test get_bool method with various inputs.""" - config = AgentConfiguration({'bool_key': True, 'false_key': False, 'str_key': 'true'}) + config = AgentConfiguration( + {"bool_key": True, "false_key": False, "str_key": "true"} + ) # Test normal boolean values - assert config.get_bool('bool_key') is True - assert config.get_bool('false_key') is False + assert config.get_bool("bool_key") is True + assert config.get_bool("false_key") is False # Test default value when key is not found - assert config.get_bool('missing_key', True) is True + assert config.get_bool("missing_key", True) is True # Test default value when no default specified - assert config.get_bool('missing_key') is None + assert config.get_bool("missing_key") is None # Note: bool() in Python behaves differently than might be expected # bool('true') is True, but that's Python behavior, not a bug in our code - assert config.get_bool('str_key') is True + assert config.get_bool("str_key") is True def test_get_str() -> None: """Test get_str method with various inputs.""" - config = AgentConfiguration({'str_key': 'hello', 'int_key': 42, 'float_key': 3.14}) + config = AgentConfiguration({"str_key": "hello", "int_key": 42, "float_key": 3.14}) # Test normal string value - assert config.get_str('str_key') == 'hello' + assert config.get_str("str_key") == "hello" # Test int value converted to string - assert config.get_str('int_key') == '42' + assert config.get_str("int_key") == "42" # Test float value converted to string - assert config.get_str('float_key') == '3.14' + assert config.get_str("float_key") == "3.14" # Test default value when key is not found - assert config.get_str('missing_key', 'default') == 'default' + assert config.get_str("missing_key", "default") == "default" # Test default value when no default specified - assert config.get_str('missing_key') is None + assert config.get_str("missing_key") is None # Test None value - assert config.get_str('none_key') is None + assert config.get_str("none_key") is None + -def test_get_with_config_option() -> None: # noqa: D103 +def test_get_with_config_option() -> None: data = { "config.str": "config.value", "config.int": 6789, @@ -207,7 +219,7 @@ def test_get_with_config_option() -> None: # noqa: D103 assert config.get(missing_key) is None -def test_get_with_default_value() -> None: # noqa: D103 +def test_get_with_default_value() -> None: default_str = ConfigOption("default.str", str, "default_value") default_int = ConfigOption("default.int", int, 100) default_double = ConfigOption("default.double", float, 2.5) @@ -219,7 +231,7 @@ def test_get_with_default_value() -> None: # noqa: D103 assert config.get(default_double) == 2.5 -def test_get_with_null_and_default() -> None: # noqa: D103 +def test_get_with_null_and_default() -> None: nullable_str = ConfigOption("nullable.str", str, "default") config = AgentConfiguration() diff --git a/python/flink_agents/plan/tests/test_function.py b/python/flink_agents/plan/tests/test_function.py index 832b33642..a42ac01e5 100644 --- a/python/flink_agents/plan/tests/test_function.py +++ b/python/flink_agents/plan/tests/test_function.py @@ -37,77 +37,77 @@ from flink_agents.plan.function import Function -def check_class(input_event: InputEvent, output_event: OutputEvent) -> None: # noqa: D103 +def check_class(input_event: InputEvent, output_event: OutputEvent) -> None: pass -def test_function_signature_same_class() -> None: # noqa: D103 +def test_function_signature_same_class() -> None: func = PythonFunction.from_callable(check_class) func.check_signature(InputEvent, OutputEvent) -def test_function_signature_subclass() -> None: # noqa: D103 +def test_function_signature_subclass() -> None: func = PythonFunction.from_callable(check_class) func.check_signature(Event, Event) -def test_function_signature_mismatch_class() -> None: # noqa: D103 +def test_function_signature_mismatch_class() -> None: func = PythonFunction.from_callable(check_class) with pytest.raises(TypeError): func.check_signature(OutputEvent, InputEvent) -def test_function_signature_mismatch_args_num() -> None: # noqa: D103 +def test_function_signature_mismatch_args_num() -> None: func = PythonFunction.from_callable(check_class) with pytest.raises(TypeError): func.check_signature(InputEvent) -def check_primitive(value: int) -> None: # noqa: D103 +def check_primitive(value: int) -> None: pass -def test_function_signature_same_primitive() -> None: # noqa: D103 +def test_function_signature_same_primitive() -> None: func = PythonFunction.from_callable(check_primitive) func.check_signature(int) -def test_function_signature_mismatch_primitive() -> None: # noqa: D103 +def test_function_signature_mismatch_primitive() -> None: func = PythonFunction.from_callable(check_primitive) with pytest.raises(TypeError): func.check_signature(float) -def check_mix(a: int, b: InputEvent) -> None: # noqa: D103 +def check_mix(a: int, b: InputEvent) -> None: pass -def test_function_signature_match_mix() -> None: # noqa: D103 +def test_function_signature_match_mix() -> None: func = PythonFunction.from_callable(check_mix) func.check_signature(int, Event) -def test_function_signature_mismatch_mix() -> None: # noqa: D103 +def test_function_signature_mismatch_mix() -> None: func = PythonFunction.from_callable(check_mix) with pytest.raises(TypeError): func.check_signature(Event, int) -def check_generic_type(*args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> None: # noqa: D103 +def check_generic_type(*args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> None: pass -def test_function_signature_generic_type_same() -> None: # noqa: D103 +def test_function_signature_generic_type_same() -> None: func = PythonFunction.from_callable(check_generic_type) func.check_signature(Tuple[Any, ...], Dict[str, Any]) -def test_function_signature_generic_type_match() -> None: # noqa: D103 +def test_function_signature_generic_type_match() -> None: func = PythonFunction.from_callable(check_generic_type) func.check_signature(tuple, dict) -def test_function_signature_generic_type_mismatch() -> None: # noqa: D103 +def test_function_signature_generic_type_mismatch() -> None: func = PythonFunction.from_callable(check_generic_type) with pytest.raises(TypeError): func.check_signature(Tuple[str, ...], Dict[str, Any]) @@ -117,11 +117,11 @@ def test_function_signature_generic_type_mismatch() -> None: # noqa: D103 @pytest.fixture(scope="module") -def func() -> Function: # noqa: D103 +def func() -> Function: return PythonFunction.from_callable(check_class) -def test_python_function_serialize(func: Function) -> None: # noqa: D103 +def test_python_function_serialize(func: Function) -> None: json_value = func.model_dump_json(serialize_as_any=True) with Path.open(Path(f"{current_dir}/resources/python_function.json")) as f: expected_json = f.read() @@ -130,7 +130,7 @@ def test_python_function_serialize(func: Function) -> None: # noqa: D103 assert actual == expected -def test_python_function_deserialize(func: Function) -> None: # noqa: D103 +def test_python_function_deserialize(func: Function) -> None: with Path.open(Path(f"{current_dir}/resources/python_function.json")) as f: expected_json = f.read() deserialized_func = PythonFunction.model_validate_json(expected_json) diff --git a/python/flink_agents/plan/tests/test_resource_provider.py b/python/flink_agents/plan/tests/test_resource_provider.py index 8a0597f69..355c76435 100644 --- a/python/flink_agents/plan/tests/test_resource_provider.py +++ b/python/flink_agents/plan/tests/test_resource_provider.py @@ -26,25 +26,29 @@ current_dir = Path(__file__).parent -class MockChatModelImpl(Resource): # noqa: D101 +class MockChatModelImpl(Resource): host: str desc: str @classmethod - def resource_type(cls) -> ResourceType: # noqa: D102 + def resource_type(cls) -> ResourceType: return ResourceType.CHAT_MODEL @pytest.fixture(scope="module") -def resource_provider() -> ResourceProvider: # noqa: D103 +def resource_provider() -> ResourceProvider: return PythonResourceProvider( name="mock", type=MockChatModelImpl.resource_type(), - descriptor=ResourceDescriptor(clazz=f"{MockChatModelImpl.__module__}.{MockChatModelImpl.__name__}", host="8.8.8.8", desc="mock chat model"), + descriptor=ResourceDescriptor( + clazz=f"{MockChatModelImpl.__module__}.{MockChatModelImpl.__name__}", + host="8.8.8.8", + desc="mock chat model", + ), ) -def test_python_resource_provider_serialize( # noqa: D103 +def test_python_resource_provider_serialize( resource_provider: ResourceProvider, ) -> None: json_value = resource_provider.model_dump_json(serialize_as_any=True) @@ -55,7 +59,7 @@ def test_python_resource_provider_serialize( # noqa: D103 assert actual == expected -def test_python_resource_provider_deserialize( # noqa: D103 +def test_python_resource_provider_deserialize( resource_provider: ResourceProvider, ) -> None: with Path.open(Path(f"{current_dir}/resources/resource_provider.json")) as f: diff --git a/python/flink_agents/plan/tests/tools/test_function_tool.py b/python/flink_agents/plan/tests/tools/test_function_tool.py index 60127fbef..d4a3f512c 100644 --- a/python/flink_agents/plan/tests/tools/test_function_tool.py +++ b/python/flink_agents/plan/tests/tools/test_function_tool.py @@ -44,11 +44,11 @@ def foo(bar: int, baz: str) -> str: @pytest.fixture(scope="module") -def func_tool() -> FunctionTool: # noqa: D103 +def func_tool() -> FunctionTool: return from_callable(foo) -def test_serialize_function_tool(func_tool: FunctionTool) -> None: # noqa: D103 +def test_serialize_function_tool(func_tool: FunctionTool) -> None: json_value = func_tool.model_dump_json(serialize_as_any=True, indent=4) with Path(f"{current_dir}/resources/function_tool.json").open() as f: expected_json = f.read() @@ -57,7 +57,7 @@ def test_serialize_function_tool(func_tool: FunctionTool) -> None: # noqa: D103 assert actual == expected -def test_deserialize_function_tool(func_tool: FunctionTool) -> None: # noqa: D103 +def test_deserialize_function_tool(func_tool: FunctionTool) -> None: with Path(f"{current_dir}/resources/function_tool.json").open() as f: json_value = f.read() actual_func_tool = FunctionTool.model_validate_json(json_value) diff --git a/python/flink_agents/runtime/flink_memory_object.py b/python/flink_agents/runtime/flink_memory_object.py index f0a8e9daf..2412f9081 100644 --- a/python/flink_agents/runtime/flink_memory_object.py +++ b/python/flink_agents/runtime/flink_memory_object.py @@ -74,7 +74,9 @@ def set(self, path: str, value: Any) -> MemoryRef: def new_object(self, path: str, *, overwrite: bool = False) -> "FlinkMemoryObject": """Create a new object at the given path.""" try: - return FlinkMemoryObject(self.__type, self._j_memory_object.newObject(path, overwrite)) + return FlinkMemoryObject( + self.__type, self._j_memory_object.newObject(path, overwrite) + ) except Exception as e: msg = f"Failed to create new object at path '{path}'" raise MemoryObjectError(msg) from e diff --git a/python/flink_agents/runtime/flink_runner_context.py b/python/flink_agents/runtime/flink_runner_context.py index 4daeb6279..3e81e456d 100644 --- a/python/flink_agents/runtime/flink_runner_context.py +++ b/python/flink_agents/runtime/flink_runner_context.py @@ -205,7 +205,9 @@ def __await__(self) -> Any: ) if plan.mode == "replay": - result = self._ctx._replay_terminal_call(self._func, self._args, self._kwargs) + result = self._ctx._replay_terminal_call( + self._func, self._args, self._kwargs + ) if False: yield return result @@ -304,7 +306,9 @@ def send_event(self, event: Event) -> None: raise RuntimeError(err_msg) from e @override - def get_resource(self, name: str, type: ResourceType, metric_group: MetricGroup = None) -> Resource: + def get_resource( + self, name: str, type: ResourceType, metric_group: MetricGroup = None + ) -> Resource: self._j_runner_context.checkMailboxThread() resource = self.__resource_cache.get_resource(name, type) # Bind metric group to the resource @@ -492,7 +496,9 @@ def _peek_current_call_result(self) -> _PersistedCallResult | None: function_id=function_id, args_digest=args_digest, status=status, - result_payload=bytes(result_payload) if result_payload is not None else None, + result_payload=bytes(result_payload) + if result_payload is not None + else None, exception_payload=( bytes(exception_payload) if exception_payload is not None else None ), @@ -793,6 +799,7 @@ def flink_runner_context_switch_action_context( if ctx.long_term_memory is not None: ctx.long_term_memory.switch_context(str(key)) + def close_flink_runner_context( ctx: FlinkRunnerContext, ) -> None: @@ -804,7 +811,9 @@ def create_async_thread_pool(max_workers: int | None) -> ThreadPoolExecutor: """Used to create a thread pool to execute asynchronous code block in action. """ - logging.info(f"Initialize fixed thread pool for async task with {max_workers} threads") + logging.info( + f"Initialize fixed thread pool for async task with {max_workers} threads" + ) return ThreadPoolExecutor(max_workers=max_workers or os.cpu_count() * 2) diff --git a/python/flink_agents/runtime/java/java_chat_model.py b/python/flink_agents/runtime/java/java_chat_model.py index f22d079cd..28ca408d1 100644 --- a/python/flink_agents/runtime/java/java_chat_model.py +++ b/python/flink_agents/runtime/java/java_chat_model.py @@ -36,7 +36,6 @@ class JavaChatModelConnectionImpl(JavaChatModelConnection): unlike JavaChatModelSetup, it does not provide direct chat functionality in Python. """ - _j_resource: Any _j_resource_adapter: Any @@ -49,15 +48,15 @@ def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> N **kwargs: Additional keyword arguments """ super().__init__(**kwargs) - self._j_resource=j_resource - self._j_resource_adapter=j_resource_adapter + self._j_resource = j_resource + self._j_resource_adapter = j_resource_adapter @override def chat( - self, - messages: Sequence[ChatMessage], - tools: List[Tool] | None = None, - **kwargs: Any, + self, + messages: Sequence[ChatMessage], + tools: List[Tool] | None = None, + **kwargs: Any, ) -> ChatMessage: """Chat method that throws UnsupportedOperationException. @@ -70,7 +69,8 @@ def chat( for message in messages ] java_tools = [ - self._j_resource_adapter.getResource(tool.name, ResourceType.TOOL.value) for tool in tools + self._j_resource_adapter.getResource(tool.name, ResourceType.TOOL.value) + for tool in tools ] j_response_message = self._j_resource.chat(java_messages, java_tools, kwargs) @@ -108,10 +108,10 @@ def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> N """ # connection is a required parameter for BaseChatModelSetup connection = kwargs.pop("connection", "") - super().__init__(connection = connection, **kwargs) + super().__init__(connection=connection, **kwargs) - self._j_resource=j_resource - self._j_resource_adapter=j_resource_adapter + self._j_resource = j_resource + self._j_resource_adapter = j_resource_adapter @property @override @@ -149,11 +149,15 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: Model response message """ # Convert Python messages to Java format - java_messages = [self._j_resource_adapter.fromPythonChatMessage(message) for message in messages] + java_messages = [ + self._j_resource_adapter.fromPythonChatMessage(message) + for message in messages + ] j_response_message = self._j_resource.chat(java_messages, kwargs) # Convert Java response back to Python format from flink_agents.runtime.python_java_utils import ( from_java_chat_message, ) + return from_java_chat_message(j_response_message) diff --git a/python/flink_agents/runtime/java/java_embedding_model.py b/python/flink_agents/runtime/java/java_embedding_model.py index a2dbc99a9..2cb15b819 100644 --- a/python/flink_agents/runtime/java/java_embedding_model.py +++ b/python/flink_agents/runtime/java/java_embedding_model.py @@ -45,10 +45,12 @@ def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> N **kwargs: Additional keyword arguments """ super().__init__(**kwargs) - self._j_resource=j_resource - self._j_resource_adapter=j_resource_adapter + self._j_resource = j_resource + self._j_resource_adapter = j_resource_adapter - def embed(self, text: str | Sequence[str], **kwargs: Any) -> list[float] | list[list[float]]: + def embed( + self, text: str | Sequence[str], **kwargs: Any + ) -> list[float] | list[list[float]]: """Generate embedding vector for a single text input. Converts the input text into a high-dimensional vector representation suitable for semantic similarity search and retrieval operations. @@ -70,6 +72,7 @@ class JavaEmbeddingModelSetupImpl(JavaEmbeddingModelSetup): but unlike JavaEmbeddingModelConnection, it does not provide direct embedding functionality in Python. """ + _j_resource: Any _j_resource_adapter: Any @@ -84,10 +87,10 @@ def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> N # connection,model are required parameters for BaseEmbeddingModelSetup connection = kwargs.pop("connection", "") model = kwargs.pop("model", "") - super().__init__(connection = connection, model = model, **kwargs) + super().__init__(connection=connection, model=model, **kwargs) - self._j_resource=j_resource - self._j_resource_adapter=j_resource_adapter + self._j_resource = j_resource + self._j_resource_adapter = j_resource_adapter @property def model_kwargs(self) -> Dict[str, Any]: @@ -103,7 +106,9 @@ def open(self) -> None: """Open the java resource.""" self._j_resource.open() - def embed(self, text: str | Sequence[str], **kwargs: Any) -> list[float] | list[list[float]]: + def embed( + self, text: str | Sequence[str], **kwargs: Any + ) -> list[float] | list[list[float]]: """Generate embedding vector for a single text query. Converts the input text into a high-dimensional vector representation suitable for semantic similarity search and retrieval operations. diff --git a/python/flink_agents/runtime/java/java_resource_wrapper.py b/python/flink_agents/runtime/java/java_resource_wrapper.py index 0d1e8a46b..8b2a9296e 100644 --- a/python/flink_agents/runtime/java/java_resource_wrapper.py +++ b/python/flink_agents/runtime/java/java_resource_wrapper.py @@ -40,10 +40,11 @@ def call(self, *args: Any, **kwargs: Any) -> Any: err_msg = "Java tool is defined in Java and needs to be executed through the Java runtime." raise NotImplementedError(err_msg) + class JavaPrompt(Prompt): """Python wrapper for Java's Prompt.""" - j_prompt: Any= Field(exclude=True) + j_prompt: Any = Field(exclude=True) @override def format_string(self, **kwargs: str) -> str: @@ -54,18 +55,29 @@ def format_messages( self, role: MessageRole = MessageRole.SYSTEM, **kwargs: str ) -> List[ChatMessage]: from pemja import findClass - j_MessageRole = findClass("org.apache.flink.agents.api.chat.messages.MessageRole") - j_chat_messages = self.j_prompt.formatMessages(j_MessageRole.fromValue(role.value), kwargs) - chatMessages = [ChatMessage(role=MessageRole(j_chat_message.getRole().getValue()), - content=j_chat_message.getContent(), - tool_calls= j_chat_message.getToolCalls(), - extra_args=j_chat_message.getExtraArgs()) for j_chat_message in j_chat_messages] + + j_MessageRole = findClass( + "org.apache.flink.agents.api.chat.messages.MessageRole" + ) + j_chat_messages = self.j_prompt.formatMessages( + j_MessageRole.fromValue(role.value), kwargs + ) + chatMessages = [ + ChatMessage( + role=MessageRole(j_chat_message.getRole().getValue()), + content=j_chat_message.getContent(), + tool_calls=j_chat_message.getToolCalls(), + extra_args=j_chat_message.getExtraArgs(), + ) + for j_chat_message in j_chat_messages + ] return chatMessages @override def close(self) -> None: self.j_prompt.close() + class JavaGetResourceWrapper: """Python wrapper for Java ResourceAdapter.""" @@ -73,7 +85,6 @@ def __init__(self, j_resource_adapter: Any) -> None: """Initialize with a Java ResourceAdapter.""" self._j_resource_adapter = j_resource_adapter - def get_resource(self, name: str, type: ResourceType) -> Resource: """Get a resource by name and type.""" return self._j_resource_adapter.getResource(name, type.value) diff --git a/python/flink_agents/runtime/java/java_vector_store.py b/python/flink_agents/runtime/java/java_vector_store.py index 1cb4abe96..2f5253a67 100644 --- a/python/flink_agents/runtime/java/java_vector_store.py +++ b/python/flink_agents/runtime/java/java_vector_store.py @@ -44,6 +44,7 @@ class JavaVectorStoreImpl(JavaCollectionManageableVectorStore): but unlike JavaEmbeddingModelConnection, it does not provide direct embedding functionality in Python. """ + _j_resource: Any _j_resource_adapter: Any @@ -57,10 +58,10 @@ def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> N """ # embedding_model are required parameters for BaseVectorStore embedding_model = kwargs.pop("embedding_model", "") - super().__init__(embedding_model = embedding_model, **kwargs) + super().__init__(embedding_model=embedding_model, **kwargs) - self._j_resource=j_resource - self._j_resource_adapter=j_resource_adapter + self._j_resource = j_resource + self._j_resource_adapter = j_resource_adapter @override @property @@ -73,12 +74,11 @@ def open(self) -> None: @override def add( - self, - documents: Document | List[Document], - collection_name: str | None = None, - **kwargs: Any, + self, + documents: Document | List[Document], + collection_name: str | None = None, + **kwargs: Any, ) -> List[str]: - documents = _maybe_cast_to_list(documents) j_documents = [ self._j_resource_adapter.fromPythonDocument(document) @@ -99,10 +99,10 @@ def size(self, collection_name: str | None = None) -> int: @override def get( - self, - ids: str | List[str] | None = None, - collection_name: str | None = None, - **kwargs: Any, + self, + ids: str | List[str] | None = None, + collection_name: str | None = None, + **kwargs: Any, ) -> List[Document]: ids = _maybe_cast_to_list(ids) j_documents = self._j_resource.get(ids, collection_name, kwargs) @@ -120,7 +120,7 @@ def delete( @override def get_or_create_collection( - self, name: str, metadata: Dict[str, Any] | None = None + self, name: str, metadata: Dict[str, Any] | None = None ) -> Collection: j_collection = self._j_resource.getOrCreateCollection(name, metadata) return from_java_collection(j_collection) diff --git a/python/flink_agents/runtime/local_memory_object.py b/python/flink_agents/runtime/local_memory_object.py index 20e03f69d..a9fb44b82 100644 --- a/python/flink_agents/runtime/local_memory_object.py +++ b/python/flink_agents/runtime/local_memory_object.py @@ -37,7 +37,9 @@ class LocalMemoryObject(MemoryObject): __store: dict[str, Any] __prefix: str - def __init__(self, type: MemoryType, store: Dict[str, Any], prefix: str = ROOT_KEY) -> None: + def __init__( + self, type: MemoryType, store: Dict[str, Any], prefix: str = ROOT_KEY + ) -> None: """Initialize a LocalMemoryObject. Parameters diff --git a/python/flink_agents/runtime/local_runner.py b/python/flink_agents/runtime/local_runner.py index 53823b61d..84402e44b 100644 --- a/python/flink_agents/runtime/local_runner.py +++ b/python/flink_agents/runtime/local_runner.py @@ -125,7 +125,9 @@ def send_event(self, event: Event) -> None: self.events.append(event) @override - def get_resource(self, name: str, type: ResourceType, metric_group: MetricGroup = None) -> Resource: + def get_resource( + self, name: str, type: ResourceType, metric_group: MetricGroup = None + ) -> Resource: return self.__resource_cache.get_resource(name, type) @property diff --git a/python/flink_agents/runtime/memory/compaction_functions.py b/python/flink_agents/runtime/memory/compaction_functions.py index 80d447b1f..635832b91 100644 --- a/python/flink_agents/runtime/memory/compaction_functions.py +++ b/python/flink_agents/runtime/memory/compaction_functions.py @@ -142,7 +142,7 @@ def _generate_summarization( item_type: Type, compaction_config: CompactionConfig, ctx: RunnerContext, - metric_group: MetricGroup + metric_group: MetricGroup, ) -> ChatMessage: """Generate summarization of the items by llm.""" # get arguments @@ -161,7 +161,9 @@ def _generate_summarization( # generate summary model: BaseChatModelSetup = cast( "BaseChatModelSetup", - ctx.get_resource(name=model_name, type=ResourceType.CHAT_MODEL, metric_group=metric_group), + ctx.get_resource( + name=model_name, type=ResourceType.CHAT_MODEL, metric_group=metric_group + ), ) input_variable = {} for msg in msgs: @@ -171,14 +173,18 @@ def _generate_summarization( if isinstance(prompt, str): prompt: Prompt = cast( "Prompt", - ctx.get_resource(prompt, ResourceType.PROMPT, metric_group=metric_group), + ctx.get_resource( + prompt, ResourceType.PROMPT, metric_group=metric_group + ), ) prompt_messages = prompt.format_messages( role=MessageRole.USER, **input_variable ) msgs.extend(prompt_messages) else: - msgs.extend(DEFAULT_ANALYSIS_PROMPT.format_messages(limit=str(compaction_config.limit))) + msgs.extend( + DEFAULT_ANALYSIS_PROMPT.format_messages(limit=str(compaction_config.limit)) + ) response: ChatMessage = model.chat(messages=msgs) diff --git a/python/flink_agents/runtime/memory/tests/test_vector_store_long_term_memory.py b/python/flink_agents/runtime/memory/tests/test_vector_store_long_term_memory.py index 390d54e0f..0311fe595 100644 --- a/python/flink_agents/runtime/memory/tests/test_vector_store_long_term_memory.py +++ b/python/flink_agents/runtime/memory/tests/test_vector_store_long_term_memory.py @@ -59,25 +59,25 @@ current_dir = Path(__file__).parent -class MockEmbeddingModel(BaseEmbeddingModelSetup): # noqa: D101 +class MockEmbeddingModel(BaseEmbeddingModelSetup): model_config = ConfigDict(arbitrary_types_allowed=True) ef: EmbeddingFunction = embedding_functions.DefaultEmbeddingFunction() connection: str = "mock" model: str = "mock" - def open(self) -> None: # noqa: D102 + def open(self) -> None: pass @property - def model_kwargs(self) -> Dict[str, Any]: # noqa: D102 + def model_kwargs(self) -> Dict[str, Any]: return {} - def embed(self, text: str, **kwargs: Any) -> Any: # noqa: D102 + def embed(self, text: str, **kwargs: Any) -> Any: return self.ef([text])[0] @pytest.fixture(scope="module") -def long_term_memory() -> VectorStoreLongTermMemory: # noqa: D103 +def long_term_memory() -> VectorStoreLongTermMemory: embedding_model_connection = OllamaEmbeddingModelConnection() chat_model_connection = OllamaChatModelConnection(request_timeout=240) @@ -134,7 +134,7 @@ def get_resource(name: str, type: ResourceType) -> Resource: ) -def prepare_memory_set( # noqa: D103 +def prepare_memory_set( long_term_memory: VectorStoreLongTermMemory, compaction_config: CompactionConfig = CompactionConfig(model="llm"), # noqa:B008 ) -> (MemorySet, List[ChatMessage]): @@ -154,7 +154,7 @@ def prepare_memory_set( # noqa: D103 return memory_set, msgs -def test_get_memory_set( # noqa:D103 +def test_get_memory_set( long_term_memory: VectorStoreLongTermMemory, ) -> None: memory_set, _ = prepare_memory_set(long_term_memory) @@ -164,7 +164,7 @@ def test_get_memory_set( # noqa:D103 long_term_memory.delete_memory_set(name="chat_history") -def test_add_and_get( # noqa:D103 +def test_add_and_get( long_term_memory: VectorStoreLongTermMemory, ) -> None: memory_set, msgs = prepare_memory_set(long_term_memory) @@ -177,7 +177,7 @@ def test_add_and_get( # noqa:D103 long_term_memory.delete_memory_set(name="chat_history") -def test_search( # noqa:D103 +def test_search( long_term_memory: VectorStoreLongTermMemory, ) -> None: memory_set, msgs = prepare_memory_set(long_term_memory) @@ -193,7 +193,7 @@ def test_search( # noqa:D103 @pytest.mark.skip("Depend on ollama server") -def test_compact( # noqa:D103 +def test_compact( long_term_memory: VectorStoreLongTermMemory, ) -> None: memory_set: MemorySet = long_term_memory.get_or_create_memory_set( diff --git a/python/flink_agents/runtime/python_java_utils.py b/python/flink_agents/runtime/python_java_utils.py index 42b4a8a65..51b3c1462 100644 --- a/python/flink_agents/runtime/python_java_utils.py +++ b/python/flink_agents/runtime/python_java_utils.py @@ -75,7 +75,10 @@ def get_output_from_output_event(bytesObject: bytes) -> Any: """Get output data from OutputEvent and serialize.""" return cloudpickle.dumps(convert_to_python_object(bytesObject).output) -def create_resource(resource_module: str, resource_clazz: str, func_kwargs: Dict[str, Any]) -> Resource: + +def create_resource( + resource_module: str, resource_clazz: str, func_kwargs: Dict[str, Any] +) -> Resource: """Dynamically create a resource instance from module and class name. Args: @@ -90,6 +93,7 @@ def create_resource(resource_module: str, resource_clazz: str, func_kwargs: Dict cls = getattr(module, resource_clazz) return cls(**func_kwargs) + def get_resource_function(j_resource_adapter: Any) -> Callable: """Create a callable wrapper for Java resource adapter. @@ -101,6 +105,7 @@ def get_resource_function(j_resource_adapter: Any) -> Callable: """ return JavaGetResourceWrapper(j_resource_adapter).get_resource + def from_java_tool(j_tool: Any) -> JavaTool: """Convert a Java tool object to a Python JavaTool instance. @@ -114,10 +119,13 @@ def from_java_tool(j_tool: Any) -> JavaTool: metadata = ToolMetadata( name=name, description=j_tool.getDescription(), - args_schema=create_model_from_java_tool_schema_str(name, j_tool.getMetadata().getInputSchema()), + args_schema=create_model_from_java_tool_schema_str( + name, j_tool.getMetadata().getInputSchema() + ), ) return JavaTool(metadata=metadata) + def from_java_prompt(j_prompt: Any) -> JavaPrompt: """Convert a Java prompt object to a Python JavaPrompt instance. @@ -129,6 +137,7 @@ def from_java_prompt(j_prompt: Any) -> JavaPrompt: """ return JavaPrompt(j_prompt=j_prompt) + def from_java_resource(type_name: str, kwargs: Dict[str, Any]) -> Resource: """Convert a Java resource object to a Python Resource instance. This function is used to convert a Java resource object to a Python Resource @@ -150,6 +159,7 @@ def from_java_resource(type_name: str, kwargs: Dict[str, Any]) -> Resource: return cls(**kwargs) + def normalize_tool_call_id(tool_call: Dict[str, Any]) -> Dict[str, Any]: """Normalize tool call by converting the ID field to string format while preserving all other fields. @@ -171,17 +181,24 @@ def normalize_tool_call_id(tool_call: Dict[str, Any]) -> Dict[str, Any]: return normalized_call + def from_java_chat_message(j_chat_message: Any) -> ChatMessage: """Convert a chat message to a python chat message.""" - return ChatMessage(role=MessageRole(j_chat_message.getRole().getValue()), - content=j_chat_message.getContent(), - tool_calls=[normalize_tool_call_id(tool_call) for tool_call in j_chat_message.getToolCalls()], - extra_args=j_chat_message.getExtraArgs()) + return ChatMessage( + role=MessageRole(j_chat_message.getRole().getValue()), + content=j_chat_message.getContent(), + tool_calls=[ + normalize_tool_call_id(tool_call) + for tool_call in j_chat_message.getToolCalls() + ], + extra_args=j_chat_message.getExtraArgs(), + ) def to_java_chat_message(chat_message: ChatMessage) -> Any: """Convert a chat message to a java chat message.""" from pemja import findClass + j_ChatMessage = findClass("org.apache.flink.agents.api.chat.messages.ChatMessage") j_chat_message = j_ChatMessage() @@ -190,22 +207,28 @@ def to_java_chat_message(chat_message: ChatMessage) -> Any: j_chat_message.setContent(chat_message.content) j_chat_message.setExtraArgs(chat_message.extra_args) if chat_message.tool_calls: - tool_calls = [normalize_tool_call_id(tool_call) for tool_call in chat_message.tool_calls] + tool_calls = [ + normalize_tool_call_id(tool_call) for tool_call in chat_message.tool_calls + ] j_chat_message.setToolCalls(tool_calls) return j_chat_message + # TODO: Replace this with `to_java_chat_message()` when the `find_class` bug is fixed. def update_java_chat_message(chat_message: ChatMessage, j_chat_message: Any) -> str: """Update a Java chat message using Python chat message.""" j_chat_message.setContent(chat_message.content) j_chat_message.setExtraArgs(chat_message.extra_args) if chat_message.tool_calls: - tool_calls = [normalize_tool_call_id(tool_call) for tool_call in chat_message.tool_calls] + tool_calls = [ + normalize_tool_call_id(tool_call) for tool_call in chat_message.tool_calls + ] j_chat_message.setToolCalls(tool_calls) return chat_message.role.value + def from_java_document(j_document: Any) -> Document: """Convert a Java documents to a Python document.""" document = Document( @@ -217,6 +240,7 @@ def from_java_document(j_document: Any) -> Document: document.embedding = list(j_document.getEmbedding()) return document + def update_java_document(document: Document, j_document: Any) -> None: """Update a Java document using Python document.""" j_document.setContent(document.content) @@ -233,15 +257,19 @@ def from_java_vector_store_query(j_query: Any) -> VectorStoreQuery: query_text=j_query.getQueryText(), limit=j_query.getLimit(), collection_name=j_query.getCollection(), - extra_args=j_query.getExtraArgs() + extra_args=j_query.getExtraArgs(), ) + def from_java_vector_store_query_result(j_query: Any) -> VectorStoreQueryResult: """Convert a Java vector store query result to a Python query result.""" return VectorStoreQueryResult( - documents=[from_java_document(j_document) for j_document in j_query.getDocuments()], + documents=[ + from_java_document(j_document) for j_document in j_query.getDocuments() + ], ) + def from_java_collection(j_collection: Any) -> Collection: """Convert a Java collection to a Python collection.""" return Collection( @@ -249,18 +277,28 @@ def from_java_collection(j_collection: Any) -> Collection: metadata=j_collection.getMetadata(), ) + def from_java_message_role(j_role: Any) -> MessageRole: """Convert a Java message role to a Python message role.""" return MessageRole(j_role.getValue()) + def get_java_tool_metadata_from_tool(tool: Tool) -> typing.Dict[str, str]: """Retrieve Java format tool metadata from a tool input schema string.""" - return {"name": tool.name, "description": tool.metadata.description, "inputSchema": create_java_tool_schema_str_from_model(tool.metadata.args_schema)} + return { + "name": tool.name, + "description": tool.metadata.description, + "inputSchema": create_java_tool_schema_str_from_model( + tool.metadata.args_schema + ), + } + def get_mode_value(query: VectorStoreQuery) -> str: """Get the mode value of a VectorStoreQuery.""" return query.mode.value + def call_method(obj: Any, method_name: str, kwargs: Dict[str, Any]) -> Any: """Calls a method on `obj` by name and passes in positional and keyword arguments. diff --git a/python/flink_agents/runtime/remote_execution_environment.py b/python/flink_agents/runtime/remote_execution_environment.py index 134365c68..3d755520d 100644 --- a/python/flink_agents/runtime/remote_execution_environment.py +++ b/python/flink_agents/runtime/remote_execution_environment.py @@ -46,6 +46,7 @@ _CONFIG_FILE_NAME = "config.yaml" _LEGACY_CONFIG_FILE_NAME = "flink-conf.yaml" + class RemoteAgentBuilder(AgentBuilder): """RemoteAgentBuilder for integrating datastream/table and agent.""" @@ -268,7 +269,6 @@ def execute(self, job_name: str | None = None) -> None: """Execute agent.""" self.__env.execute(job_name=job_name) - def __load_config_from_flink_conf_dir(self) -> None: """Load agent configuration from FLINK_CONF_DIR if available.""" flink_conf_dir = os.environ.get("FLINK_CONF_DIR") @@ -299,9 +299,7 @@ def __find_config_file(self, flink_conf_dir: str) -> Path | None: # Try legacy config file name first legacy_config_path = Path(flink_conf_dir).joinpath(_LEGACY_CONFIG_FILE_NAME) if legacy_config_path.exists(): - logging.warning( - f"Using legacy config file {_LEGACY_CONFIG_FILE_NAME}" - ) + logging.warning(f"Using legacy config file {_LEGACY_CONFIG_FILE_NAME}") return legacy_config_path # Try new config file name as fallback diff --git a/python/flink_agents/runtime/tests/test_built_in_actions.py b/python/flink_agents/runtime/tests/test_built_in_actions.py index 8959b12b3..07258a35f 100644 --- a/python/flink_agents/runtime/tests/test_built_in_actions.py +++ b/python/flink_agents/runtime/tests/test_built_in_actions.py @@ -129,7 +129,9 @@ def prompt() -> Prompt: @staticmethod def mock_connection() -> ResourceDescriptor: """Chat model server can be used by ChatModel.""" - return ResourceDescriptor(clazz=f"{MockChatModelConnection.__module__}.{MockChatModelConnection.__name__}") + return ResourceDescriptor( + clazz=f"{MockChatModelConnection.__module__}.{MockChatModelConnection.__name__}" + ) @chat_model_setup @staticmethod @@ -188,7 +190,7 @@ def process_chat_response(event: ChatResponseEvent, ctx: RunnerContext) -> None: ctx.send_event(OutputEvent(output=input.content)) -def test_built_in_actions() -> None: # noqa: D103 +def test_built_in_actions() -> None: env = AgentsExecutionEnvironment.get_execution_environment() input_list = [] diff --git a/python/flink_agents/runtime/tests/test_durable_execution.py b/python/flink_agents/runtime/tests/test_durable_execution.py index 52d777a2c..f652b0ee8 100644 --- a/python/flink_agents/runtime/tests/test_durable_execution.py +++ b/python/flink_agents/runtime/tests/test_durable_execution.py @@ -152,6 +152,7 @@ def test_validate_reconciler_callable_accepts_none() -> None: def test_validate_reconciler_callable_accepts_zero_arg_function() -> None: """Accept a zero-argument reconciler function.""" + def reconciler() -> str: return "ok" diff --git a/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py b/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py index 4d64eae41..a41a33c7b 100644 --- a/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py +++ b/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py @@ -290,9 +290,16 @@ def reconciler() -> str: assert result == "call:order-1" assert reconciler_called is False - assert j_runner_context.operations == ["peek", "clear", "append_pending", "finalize"] + assert j_runner_context.operations == [ + "peek", + "clear", + "append_pending", + "finalize", + ] assert len(j_runner_context.call_results) == 1 - assert j_runner_context.call_results[0].function_id == _compute_function_id(_call_value) + assert j_runner_context.call_results[0].function_id == _compute_function_id( + _call_value + ) assert j_runner_context.call_results[0].args_digest == _compute_args_digest( ("order-1",), {} ) diff --git a/python/flink_agents/runtime/tests/test_get_resource_in_action.py b/python/flink_agents/runtime/tests/test_get_resource_in_action.py index f6028003f..de508397a 100644 --- a/python/flink_agents/runtime/tests/test_get_resource_in_action.py +++ b/python/flink_agents/runtime/tests/test_get_resource_in_action.py @@ -27,7 +27,7 @@ from flink_agents.api.runner_context import RunnerContext -class MockChatModelImpl(BaseChatModelSetup): # noqa: D101 +class MockChatModelImpl(BaseChatModelSetup): host: str desc: str @@ -35,22 +35,26 @@ def open(self) -> None: """Do nothing.""" @property - def model_kwargs(self) -> Dict[str, Any]: # noqa: D102 + def model_kwargs(self) -> Dict[str, Any]: return {} - def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: # noqa: D102 + def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: return ChatMessage( role=MessageRole.ASSISTANT, content=f"{messages[0].content} {self.host} {self.desc}", ) -class MyAgent(Agent): # noqa: D101 +class MyAgent(Agent): @chat_model_setup @staticmethod - def mock_chat_model() -> ResourceDescriptor: # noqa: D102 - return ResourceDescriptor(clazz=f"{MockChatModelImpl.__module__}.{MockChatModelImpl.__name__}", host="8.8.8.8", - desc="mock chat model just for testing.", connection="mock") + def mock_chat_model() -> ResourceDescriptor: + return ResourceDescriptor( + clazz=f"{MockChatModelImpl.__module__}.{MockChatModelImpl.__name__}", + host="8.8.8.8", + desc="mock chat model just for testing.", + connection="mock", + ) @tool @staticmethod @@ -71,7 +75,7 @@ def mock_tool(input: str) -> str: @action(InputEvent) @staticmethod - def mock_action(event: InputEvent, ctx: RunnerContext) -> None: # noqa: D102 + def mock_action(event: InputEvent, ctx: RunnerContext) -> None: input = event.input mock_chat_model = ctx.get_resource( type=ResourceType.CHAT_MODEL, name="mock_chat_model" @@ -88,7 +92,7 @@ def mock_action(event: InputEvent, ctx: RunnerContext) -> None: # noqa: D102 ) -def test_get_resource_in_action() -> None: # noqa: D103 +def test_get_resource_in_action() -> None: env = AgentsExecutionEnvironment.get_execution_environment() input_list = [] diff --git a/python/flink_agents/runtime/tests/test_local_execution_environment.py b/python/flink_agents/runtime/tests/test_local_execution_environment.py index a34492c60..383e270ec 100644 --- a/python/flink_agents/runtime/tests/test_local_execution_environment.py +++ b/python/flink_agents/runtime/tests/test_local_execution_environment.py @@ -26,7 +26,7 @@ from flink_agents.api.runner_context import RunnerContext -class Agent1(Agent): # noqa: D101 +class Agent1(Agent): @action(InputEvent) @staticmethod def increment(event: Event, ctx: RunnerContext): # noqa D102 @@ -35,7 +35,7 @@ def increment(event: Event, ctx: RunnerContext): # noqa D102 ctx.send_event(OutputEvent(output=value)) -class Agent1WithAsync(Agent): # noqa: D101 +class Agent1WithAsync(Agent): @action(InputEvent) @staticmethod async def increment(event: Event, ctx: RunnerContext): # noqa D102 @@ -48,7 +48,7 @@ def my_func(value: int) -> int: ctx.send_event(OutputEvent(output=value)) -class Agent2(Agent): # noqa: D101 +class Agent2(Agent): @action(InputEvent) @staticmethod def decrease(event: Event, ctx: RunnerContext): # noqa D102 @@ -57,7 +57,7 @@ def decrease(event: Event, ctx: RunnerContext): # noqa D102 ctx.send_event(OutputEvent(output=value)) -def test_local_execution_environment() -> None: # noqa: D103 +def test_local_execution_environment() -> None: env = AgentsExecutionEnvironment.get_execution_environment() input_list = [] @@ -73,7 +73,7 @@ def test_local_execution_environment() -> None: # noqa: D103 assert output_list == [{"bob": 2}, {"john": 3}] -def test_local_execution_environment_with_async() -> None: # noqa: D103 +def test_local_execution_environment_with_async() -> None: env = AgentsExecutionEnvironment.get_execution_environment() input_list = [] @@ -89,7 +89,7 @@ def test_local_execution_environment_with_async() -> None: # noqa: D103 assert output_list == [{"bob": 2}, {"john": 3}] -def test_local_execution_environment_apply_multi_agents() -> None: # noqa: D103 +def test_local_execution_environment_apply_multi_agents() -> None: env = AgentsExecutionEnvironment.get_execution_environment() input_list = [] @@ -100,7 +100,7 @@ def test_local_execution_environment_apply_multi_agents() -> None: # noqa: D103 env.from_list(input_list).apply(agent1).apply(agent2).to_list() -def test_local_execution_environment_execute_multi_times() -> None: # noqa: D103 +def test_local_execution_environment_execute_multi_times() -> None: env = AgentsExecutionEnvironment.get_execution_environment() input_list = [] @@ -116,7 +116,7 @@ def test_local_execution_environment_execute_multi_times() -> None: # noqa: D10 env.execute() -def test_local_execution_environment_call_from_list_twice() -> None: # noqa: D103 +def test_local_execution_environment_call_from_list_twice() -> None: env = AgentsExecutionEnvironment.get_execution_environment() input_list = [] diff --git a/python/flink_agents/runtime/tests/test_local_memory_object.py b/python/flink_agents/runtime/tests/test_local_memory_object.py index 305aba257..5268077a7 100644 --- a/python/flink_agents/runtime/tests/test_local_memory_object.py +++ b/python/flink_agents/runtime/tests/test_local_memory_object.py @@ -26,7 +26,7 @@ def create_memory() -> LocalMemoryObject: return LocalMemoryObject(MemoryType.SHORT_TERM, {}) -class User: # noqa: D101 +class User: def __init__(self, name: str, age: int) -> None: """Store for later comparison.""" self.name = name @@ -40,7 +40,7 @@ def __eq__(self, other: object) -> bool: ) -def test_basic_set_get_various_types() -> None: # noqa: D103 +def test_basic_set_get_various_types() -> None: mem = create_memory() # int / float / str @@ -74,7 +74,7 @@ def test_basic_set_get_various_types() -> None: # noqa: D103 assert mem.get("user") == user -def test_nested_set_and_get() -> None: # noqa: D103 +def test_nested_set_and_get() -> None: mem = create_memory() mem.set("a.b.c", True) tmp_obj = mem.get("a.b") @@ -88,7 +88,7 @@ def test_nested_set_and_get() -> None: # noqa: D103 assert mem.get("a.b").get("c") is True -def test_new_object_and_is_exist() -> None: # noqa: D103 +def test_new_object_and_is_exist() -> None: mem = create_memory() mem.new_object("foo.bar") assert mem.is_exist("foo") @@ -98,7 +98,7 @@ def test_new_object_and_is_exist() -> None: # noqa: D103 assert fields["bar"] == "NestedObject" -def test_overwrite_behavior() -> None: # noqa: D103 +def test_overwrite_behavior() -> None: mem = create_memory() mem.set("profile", "active") @@ -113,7 +113,7 @@ def test_overwrite_behavior() -> None: # noqa: D103 assert mem.get("profile.status") == "ok" -def test_auto_parent_fill_and_children() -> None: # noqa: D103 +def test_auto_parent_fill_and_children() -> None: mem = create_memory() mem.new_object("x.y.z") @@ -125,7 +125,7 @@ def test_auto_parent_fill_and_children() -> None: # noqa: D103 assert root_fields["x"] == "NestedObject" -def test_disallow_overwrite_object_with_primitive() -> None: # noqa: D103 +def test_disallow_overwrite_object_with_primitive() -> None: mem = create_memory() mem.new_object("obj") try: diff --git a/python/flink_agents/runtime/tests/test_memory_reference.py b/python/flink_agents/runtime/tests/test_memory_reference.py index 497de1a28..10d885b1e 100644 --- a/python/flink_agents/runtime/tests/test_memory_reference.py +++ b/python/flink_agents/runtime/tests/test_memory_reference.py @@ -20,13 +20,13 @@ from flink_agents.runtime.local_memory_object import LocalMemoryObject -class MockRunnerContext: # noqa D101 +class MockRunnerContext: def __init__(self, memory: LocalMemoryObject) -> None: """Mock RunnerContext for testing resolve() method.""" self._memory = memory @property - def short_term_memory(self) -> LocalMemoryObject: # noqa D102 + def short_term_memory(self) -> LocalMemoryObject: return self._memory @@ -35,7 +35,7 @@ def create_memory() -> LocalMemoryObject: return LocalMemoryObject(MemoryType.SHORT_TERM, {}) -class User: # noqa: D101 +class User: def __init__(self, name: str, age: int) -> None: """Store for later comparison.""" self.name = name @@ -43,13 +43,13 @@ def __init__(self, name: str, age: int) -> None: def __eq__(self, other: object) -> bool: return ( - isinstance(other, User) - and other.name == self.name - and other.age == self.age + isinstance(other, User) + and other.name == self.name + and other.age == self.age ) -def test_set_get_involved_ref() -> None: # noqa: D103 +def test_set_get_involved_ref() -> None: mem = create_memory() # Test cases: (path, value, type_name) @@ -72,7 +72,7 @@ def test_set_get_involved_ref() -> None: # noqa: D103 assert retrieved_value == value -def test_memory_ref_create() -> None: # noqa: D103 +def test_memory_ref_create() -> None: path = "a.b.c" ref = MemoryRef.create(MemoryType.SHORT_TERM, path) @@ -80,7 +80,7 @@ def test_memory_ref_create() -> None: # noqa: D103 assert ref.path == path -def test_memory_ref_resolve() -> None: # noqa: D103 +def test_memory_ref_resolve() -> None: mem = create_memory() ctx = MockRunnerContext(mem) @@ -100,7 +100,7 @@ def test_memory_ref_resolve() -> None: # noqa: D103 assert resolved_value == value -def test_get_with_ref_to_nested_object() -> None: # noqa: D103 +def test_get_with_ref_to_nested_object() -> None: mem = create_memory() obj = mem.new_object("a.b") obj.set("c", 10) @@ -112,15 +112,17 @@ def test_get_with_ref_to_nested_object() -> None: # noqa: D103 assert resolved_obj.get("b.c") == 10 -def test_get_with_non_existent_ref() -> None: # noqa: D103 +def test_get_with_non_existent_ref() -> None: mem = create_memory() - non_existent_ref = MemoryRef.create(MemoryType.SHORT_TERM, "this.path.does.not.exist") + non_existent_ref = MemoryRef.create( + MemoryType.SHORT_TERM, "this.path.does.not.exist" + ) assert mem.get(non_existent_ref) is None -def test_ref_equality_and_hashing() -> None: # noqa: D103 +def test_ref_equality_and_hashing() -> None: ref1 = MemoryRef(path="a.b") ref2 = MemoryRef(path="a.b") ref3 = MemoryRef(path="a.c") diff --git a/python/flink_agents/runtime/tests/test_runner_context_execute.py b/python/flink_agents/runtime/tests/test_runner_context_execute.py index fd3fef34e..d93f8f861 100644 --- a/python/flink_agents/runtime/tests/test_runner_context_execute.py +++ b/python/flink_agents/runtime/tests/test_runner_context_execute.py @@ -187,4 +187,3 @@ def test_durable_execute_with_kwargs() -> None: env.execute() assert output_list == [{"alice": 25}] - diff --git a/python/pyproject.toml b/python/pyproject.toml index b85681db1..de60c66b1 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -180,8 +180,8 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "dependencies.py" = ["ICN001"] -"tests/**/*.py" = ["ANN001", "ANN201", "D100", "D102", "D103", "B018", "FBT001"] -"_build_backend/tests/**/*.py" = ["ANN001", "ANN201", "D100", "D102", "D103", "B018", "FBT001"] +"**/tests/**/*.py" = ["ANN001", "ANN201", "D100", "D101", "D102", "D103", "D107", "B018", "FBT001", "W505"] +"**/e2e_tests/**/*.py" = ["ANN001", "ANN201", "D100", "D101", "D102", "D103", "D107", "B018", "FBT001", "W505"] [tool.ruff.lint.pycodestyle] max-doc-length = 88