Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

try:
import dotenv

dotenv.load_dotenv()
except ImportError:
# dotenv is optional for this test server
Expand All @@ -34,6 +35,7 @@ def ask_sum(a: int, b: int) -> str:
"""Prompt of add tool."""
return f"Can you please calculate the sum of {a} and {b}?"


@mcp.tool()
async def add(a: int, b: int) -> int:
"""Get the detailed information of a specified IP address.
Expand All @@ -47,4 +49,5 @@ async def add(a: int, b: int) -> int:
"""
return a + b


mcp.run("streamable-http")
61 changes: 40 additions & 21 deletions e2e-test/test-scripts/check_resource_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ def parse_java_resource_name(java_path: Path) -> dict:
r'public\s+static\s+final\s+String\s+([A-Za-z0-9_]+)\s*=\s*"([^"]+)";',
re.DOTALL,
)
class_re = re.compile(
r"public\s+(?:static\s+)?final\s+class\s+(\w+)\s*\{"
)
class_re = re.compile(r"public\s+(?:static\s+)?final\s+class\s+(\w+)\s*\{")

class_stack = []
brace_depth = 0
Expand Down Expand Up @@ -200,35 +198,42 @@ def _parse_python_resource_name(python_path: Path) -> dict:
elif len(parts) == 3 and parts[2] == "Java":
python_map[(rt, "Java")] = consts
if "ResourceName" in result and "MCP_SERVER" in result["ResourceName"]:
python_map[("MCP", "Python")] = {"MCP_SERVER": result["ResourceName"]["MCP_SERVER"]}
python_map[("MCP", "Python")] = {
"MCP_SERVER": result["ResourceName"]["MCP_SERVER"]
}
return python_map


_JAVA_ONLY_NAMES = frozenset({
"PYTHON_WRAPPER_CONNECTION", "PYTHON_WRAPPER_SETUP",
"PYTHON_WRAPPER_VECTOR_STORE", "PYTHON_WRAPPER_COLLECTION_MANAGEABLE_VECTOR_STORE",
})
_PYTHON_ONLY_NAMES = frozenset({
"JAVA_WRAPPER_CONNECTION", "JAVA_WRAPPER_SETUP",
"JAVA_WRAPPER_VECTOR_STORE", "JAVA_WRAPPER_COLLECTION_MANAGEABLE_VECTOR_STORE",
})
_JAVA_ONLY_NAMES = frozenset(
{
"PYTHON_WRAPPER_CONNECTION",
"PYTHON_WRAPPER_SETUP",
"PYTHON_WRAPPER_VECTOR_STORE",
"PYTHON_WRAPPER_COLLECTION_MANAGEABLE_VECTOR_STORE",
}
)
_PYTHON_ONLY_NAMES = frozenset(
{
"JAVA_WRAPPER_CONNECTION",
"JAVA_WRAPPER_SETUP",
"JAVA_WRAPPER_VECTOR_STORE",
"JAVA_WRAPPER_COLLECTION_MANAGEABLE_VECTOR_STORE",
}
)


def _find_python_name_for_value(impls: dict, value: str, java_name: str) -> str | None:
return java_name if impls.get(java_name) == value else None


def check_consistency(
java_map: dict, python_map: dict
) -> tuple[list[str], list[str]]:

def check_consistency(java_map: dict, python_map: dict) -> tuple[list[str], list[str]]:
errors = []
warnings = []

all_resource_types = set()
for (rt, _) in java_map:
for rt, _ in java_map:
all_resource_types.add(rt)
for (rt, _) in python_map:
for rt, _ in python_map:
all_resource_types.add(rt)

for resource_type in sorted(all_resource_types):
Expand Down Expand Up @@ -287,7 +292,10 @@ def check_consistency(

def main() -> int:
root = Path(__file__).resolve().parent.parent.parent
java_path = root / "api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java"
java_path = (
root
/ "api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java"
)
python_path = root / "python/flink_agents/api/resource.py"

if not java_path.exists():
Expand All @@ -303,8 +311,19 @@ def main() -> int:
debug = __import__("os").environ.get("RESOURCE_DEBUG")
if debug:
import json
print("Java map:", json.dumps({str(k): v for k, v in java_map.items()}, indent=2, ensure_ascii=False))
print("Python map:", json.dumps({str(k): v for k, v in python_map.items()}, indent=2, ensure_ascii=False))

print(
"Java map:",
json.dumps(
{str(k): v for k, v in java_map.items()}, indent=2, ensure_ascii=False
),
)
print(
"Python map:",
json.dumps(
{str(k): v for k, v in python_map.items()}, indent=2, ensure_ascii=False
),
)

errors, warnings = check_consistency(java_map, python_map)

Expand Down
8 changes: 4 additions & 4 deletions python/_build_backend/tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _write_manifest(path: Path, manifest: dict) -> None:
# ---------------------------------------------------------------------------


class TestJarFilename: # noqa: D101
class TestJarFilename:
def test_without_classifier(self) -> None:
entry = {
"artifact_id": "flink-agents-dist-common",
Expand All @@ -76,7 +76,7 @@ def test_with_classifier(self) -> None:
)


class TestLoadManifest: # noqa: D101
class TestLoadManifest:
def test_load(self, tmp_path) -> None:
manifest = {
"maven_base_url": "https://repo1.maven.org/maven2",
Expand All @@ -89,7 +89,7 @@ def test_load(self, tmp_path) -> None:
assert loaded == manifest


class TestVerifyChecksum: # noqa: D101
class TestVerifyChecksum:
def test_valid_checksum(self, tmp_path) -> None:
content = b"fake jar content"
jar = tmp_path / "test.jar"
Expand All @@ -106,7 +106,7 @@ def test_invalid_checksum(self, tmp_path) -> None:
assert not jar.exists()


class TestEnsureJars: # noqa: D101
class TestEnsureJars:
def test_skip_when_no_manifest(self, tmp_path, monkeypatch) -> None:
monkeypatch.chdir(tmp_path)
_ensure_jars()
Expand Down
5 changes: 4 additions & 1 deletion python/flink_agents/api/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,10 @@ def add_action(
return self

def add_resource(
self, name: str, resource_type: ResourceType, instance: SerializableResource | ResourceDescriptor
self,
name: str,
resource_type: ResourceType,
instance: SerializableResource | ResourceDescriptor,
) -> "Agent":
"""Add resource to agent instance.

Expand Down
2 changes: 1 addition & 1 deletion python/flink_agents/api/agents/tests/test_row_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from flink_agents.api.agents.react_agent import OutputSchema


def test_output_schema_serializable() -> None: # noqa: D103
def test_output_schema_serializable() -> None:
schema = OutputSchema(
output_schema=RowTypeInfo(
[BasicTypeInfo.INT_TYPE_INFO()],
Expand Down
5 changes: 3 additions & 2 deletions python/flink_agents/api/chat_models/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,9 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage:
# Call chat model connection to execute chat
merged_kwargs = self.model_kwargs.copy()
merged_kwargs.update(kwargs)
return self._get_connection().chat(messages, tools=self._get_tools(), **merged_kwargs)
return self._get_connection().chat(
messages, tools=self._get_tools(), **merged_kwargs
)

def _record_token_metrics(
self, model_name: str, prompt_tokens: int, completion_tokens: int
Expand Down Expand Up @@ -256,4 +258,3 @@ def _get_tools(self) -> List[Tool]:
err_msg = f"Expect Tool, but is {tool.__class__.__name__}"
raise TypeError(err_msg)
return self.tools

5 changes: 3 additions & 2 deletions python/flink_agents/api/chat_models/java_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ class JavaChatModelConnection(BaseChatModelConnection):
unlike JavaChatModelSetup, it does not provide direct chat functionality in Python.
"""

java_class_name: str=""
java_class_name: str = ""


@java_resource
class JavaChatModelSetup(BaseChatModelSetup):
Expand All @@ -43,4 +44,4 @@ class JavaChatModelSetup(BaseChatModelSetup):
implementation.
"""

java_class_name: str=""
java_class_name: str = ""
16 changes: 11 additions & 5 deletions python/flink_agents/api/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@ class ConfigOption:
default: The default value for this configuration option
"""

def __init__(self, key: str, config_type: Type[Any], default: Any | None=None) -> None:
def __init__(
self, key: str, config_type: Type[Any], default: Any | None = None
) -> None:
"""Initialize a configuration option."""
self._key = key
self._type = config_type
self._default_value = default

def get_key(self) -> str:
"""Gets the configuration key."""
return self._key
Expand All @@ -47,6 +50,7 @@ def get_default_value(self) -> Any:
"""Returns the default value."""
return self._default_value


class WritableConfiguration(ABC):
"""Abstract base class providing write access to a configuration object.

Expand Down Expand Up @@ -98,14 +102,15 @@ def set(self, option: ConfigOption, value: Any) -> None:
value: The value to set for the key
"""


class ReadableConfiguration(ABC):
"""Abstract base class providing read access to a configuration object.

This class enables retrieval of configuration settings.
"""

@abstractmethod
def get_int(self, key: str, default: int | None=None) -> int:
def get_int(self, key: str, default: int | None = None) -> int:
"""Get the int configuration value by key.

Args:
Expand All @@ -117,7 +122,7 @@ def get_int(self, key: str, default: int | None=None) -> int:
"""

@abstractmethod
def get_float(self, key: str, default: float | None=None) -> float:
def get_float(self, key: str, default: float | None = None) -> float:
"""Get the float configuration value by key.

Args:
Expand All @@ -129,7 +134,7 @@ def get_float(self, key: str, default: float | None=None) -> float:
"""

@abstractmethod
def get_bool(self, key: str, default: bool | None=None) -> bool:
def get_bool(self, key: str, default: bool | None = None) -> bool:
"""Get the boolean configuration value by key.

Args:
Expand All @@ -141,7 +146,7 @@ def get_bool(self, key: str, default: bool | None=None) -> bool:
"""

@abstractmethod
def get_str(self, key: str, default: str | None=None) -> str:
def get_str(self, key: str, default: str | None = None) -> str:
"""Get the string configuration value by key.

Args:
Expand All @@ -163,6 +168,7 @@ def get(self, option: ConfigOption) -> Any:
The value of the given option
"""


class Configuration(WritableConfiguration, ReadableConfiguration, ABC):
"""A configuration object that provides both read and write access to a
configuration object.
Expand Down
1 change: 1 addition & 0 deletions python/flink_agents/api/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def vector_store(func: Callable) -> Callable:
func._is_vector_store = True
return func


def java_resource(cls: Type) -> Type:
"""Decorator to mark a class as Java resource."""
cls._is_java_resource = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ class JavaEmbeddingModelConnection(BaseEmbeddingModelConnection):
functionality in Python.
"""

java_class_name: str=""
java_class_name: str = ""


@java_resource
class JavaEmbeddingModelSetup(BaseEmbeddingModelSetup):
Expand All @@ -44,4 +45,4 @@ class JavaEmbeddingModelSetup(BaseEmbeddingModelSetup):
implementation.
"""

java_class_name: str=""
java_class_name: str = ""
2 changes: 2 additions & 0 deletions python/flink_agents/api/events/context_retrieval_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class ContextRetrievalRequestEvent(Event):
max_results : int
Maximum number of results to return (default: 3)
"""

query: str
vector_store: str
max_results: int = 3
Expand All @@ -51,6 +52,7 @@ class ContextRetrievalResponseEvent(Event):
documents : List[Document]
List of retrieved documents from the vector store
"""

request_id: UUID
query: str
documents: List[Document]
5 changes: 4 additions & 1 deletion python/flink_agents/api/execution_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,10 @@ def execute(self, job_name: str | None = None) -> None:
"""Execute agent individually."""

def add_resource(
self, name: str, resource_type: ResourceType, instance: SerializableResource | ResourceDescriptor
self,
name: str,
resource_type: ResourceType,
instance: SerializableResource | ResourceDescriptor,
) -> "AgentsExecutionEnvironment":
"""Register resource to agent execution environment.

Expand Down
2 changes: 2 additions & 0 deletions python/flink_agents/api/memory/long_term_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

ItemType = str | ChatMessage


class CompactionConfig(BaseModel):
"""Compaction configuration.

Expand All @@ -48,6 +49,7 @@ class CompactionConfig(BaseModel):
prompt: str | Prompt | None = None
limit: int = 1


class LongTermMemoryBackend(Enum):
"""Backend for Long-Term Memory."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)


def test_memory_set_serialization() -> None: # noqa:D103
def test_memory_set_serialization() -> None:
memory_set = MemorySet(
name="chat_history",
item_type=ChatMessage,
Expand Down
7 changes: 5 additions & 2 deletions python/flink_agents/api/memory_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@
if TYPE_CHECKING:
from flink_agents.api.memory_reference import MemoryRef


class MemoryType(Enum):
"""Memory types based on MemoryObject."""
SENSORY = "sensory",

SENSORY = ("sensory",)
SHORT_TERM = "short_term"


class MemoryObject(BaseModel, ABC):
"""Representation of an object in the short-term memory.

Expand All @@ -38,7 +41,7 @@ class MemoryObject(BaseModel, ABC):
"""

@abstractmethod
def get(self, path_or_ref: Union[str,"MemoryRef"] ) -> Any:
def get(self, path_or_ref: Union[str, "MemoryRef"]) -> Any:
"""Get the value of a (direct or indirect) field or a MemoryRef in the object.

Parameters
Expand Down
Loading
Loading