Merged
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yml
@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13', '3.14']
python-version: ['3.10', '3.11', '3.12', '3.13', '3.14']

steps:
- name: Checkout code
120 changes: 65 additions & 55 deletions vertexai/_genai/_agent_engines_utils.py
@@ -32,7 +32,6 @@
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Dict,
@@ -59,6 +58,12 @@
from . import types as genai_types


if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias


try:
_BUILTIN_MODULE_NAMES: Sequence[str] = sys.builtin_module_names
except AttributeError:
@@ -78,20 +83,30 @@
_STDLIB_MODULE_NAMES: frozenset[str] = frozenset() # type: ignore[no-redef]


try:
from google.cloud import storage
if typing.TYPE_CHECKING:
from google.cloud import storage # type: ignore[attr-defined]

_StorageBucket: type[Any] = storage.Bucket
except (ImportError, AttributeError):
_StorageBucket: type[Any] = Any # type: ignore[no-redef]
_StorageBucket: TypeAlias = storage.Bucket
else:
try:
from google.cloud import storage # type: ignore[attr-defined]

_StorageBucket: type[Any] = storage.Bucket
except (ImportError, AttributeError):
_StorageBucket: type[Any] = Any # type: ignore[no-redef]

try:

if typing.TYPE_CHECKING:
import packaging

_SpecifierSet: type[Any] = packaging.specifiers.SpecifierSet
except (ImportError, AttributeError):
_SpecifierSet: type[Any] = Any # type: ignore[no-redef]
_SpecifierSet = packaging.specifiers.SpecifierSet
else:
try:
import packaging

_SpecifierSet: type[Any] = packaging.specifiers.SpecifierSet
except (ImportError, AttributeError):
_SpecifierSet: type[Any] = Any # type: ignore[no-redef]
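
Aside (not part of the diff): the hunks above replace runtime try/except imports with a typing.TYPE_CHECKING branch, so type checkers always resolve the real classes while the runtime alias still degrades to Any when the optional dependency is missing. A minimal sketch of the pattern, using a hypothetical optional package some_optional_dep:

import sys
import typing
from typing import Any

if sys.version_info >= (3, 10):
    from typing import TypeAlias
else:
    from typing_extensions import TypeAlias

if typing.TYPE_CHECKING:
    # Static type checkers always see the real class here.
    import some_optional_dep

    _Widget: TypeAlias = some_optional_dep.Widget
else:
    try:
        # At runtime the import may fail; fall back to Any so annotations stay usable.
        import some_optional_dep

        _Widget: type[Any] = some_optional_dep.Widget
    except (ImportError, AttributeError):
        _Widget: type[Any] = Any  # type: ignore[no-redef]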


try:
@@ -258,16 +273,22 @@ class OperationRegistrable(Protocol):
"""Protocol for agents that have registered operations."""

@abc.abstractmethod
def register_operations(self, **kwargs) -> Dict[str, Sequence[str]]: # type: ignore[no-untyped-def]
def register_operations(self, **kwargs: Any) -> dict[str, list[str]]:
"""Register the user provided operations (modes and methods)."""
pass
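
Aside (not part of the diff): a minimal sketch of an agent satisfying the tightened register_operations signature; MyAgent and its methods are hypothetical.

from typing import Any


class MyAgent:
    def query(self, *, prompt: str) -> str:
        return f"echo: {prompt}"

    async def async_query(self, *, prompt: str) -> str:
        return f"echo: {prompt}"

    def register_operations(self, **kwargs: Any) -> dict[str, list[str]]:
        # Maps each API mode to the method names exposed under that mode.
        return {"": ["query"], "async": ["async_query"]}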


try:
if typing.TYPE_CHECKING:
from google.adk.agents import BaseAgent

ADKAgent: type[Any] = BaseAgent
except (ImportError, AttributeError):
ADKAgent: type[Any] = Any # type: ignore[no-redef]
ADKAgent: TypeAlias = BaseAgent
else:
try:
from google.adk.agents import BaseAgent

ADKAgent: Optional[TypeAlias] = BaseAgent
except (ImportError, AttributeError):
ADKAgent = None # type: ignore[no-redef]

_AgentEngineInterface = Union[
ADKAgent,
@@ -283,8 +304,9 @@ def register_operations(self, **kwargs) -> Dict[str, Sequence[str]]: # type: ig
class _ModuleAgentAttributes(TypedDict, total=False):
module_name: str
agent_name: str
register_operations: Dict[str, Sequence[str]]
register_operations: Dict[str, list[str]]
sys_paths: Optional[Sequence[str]]
agent: _AgentEngineInterface


class ModuleAgent(Cloneable, OperationRegistrable):
Expand All @@ -300,7 +322,7 @@ def __init__(
*,
module_name: str,
agent_name: str,
register_operations: Dict[str, Sequence[str]],
register_operations: Dict[str, list[str]],
sys_paths: Optional[Sequence[str]] = None,
):
"""Initializes a module-based agent.
@@ -310,7 +332,7 @@ def __init__(
Required. The name of the module to import.
agent_name (str):
Required. The name of the agent in the module to instantiate.
register_operations (Dict[str, Sequence[str]]):
register_operations (Dict[str, list[str]]):
Required. A dictionary of API modes to a list of method names.
sys_paths (Sequence[str]):
Optional. The system paths to search for the module. It should
@@ -336,8 +358,11 @@ def clone(self) -> "ModuleAgent":
sys_paths=self._tmpl_attrs.get("sys_paths"),
)

def register_operations(self) -> Dict[str, Sequence[str]]:
self._tmpl_attrs.get("register_operations")
def register_operations(self, **kwargs: Any) -> dict[str, list[str]]:
reg_operations = self._tmpl_attrs.get("register_operations")
if reg_operations is None:
raise ValueError("Register operations is not set.")
return reg_operations
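
Aside (not part of the diff): illustrative use of the updated ModuleAgent; the module name, agent name, and operation map below are hypothetical.

agent = ModuleAgent(
    module_name="my_package.my_agent_module",  # module imported during set_up()
    agent_name="root_agent",                   # object looked up in that module
    register_operations={"": ["query"]},       # API mode -> exposed method names
)
operations = agent.register_operations()       # returns the mapping passed above,
                                               # or raises ValueError if unset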

def set_up(self) -> None:
"""Sets up the agent for execution of queries at runtime.
@@ -411,7 +436,7 @@ def __call__(
class GetAsyncOperationFunction(Protocol):
async def __call__(
self, *, operation_name: str, **kwargs: Any
) -> Awaitable[AgentEngineOperationUnion]:
) -> AgentEngineOperationUnion:
pass
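
Aside (not part of the diff): the Awaitable wrapper is dropped because an async def callable is annotated with the type of the awaited result, not Awaitable[...]; callers still await it. A short sketch with a hypothetical protocol:

from typing import Protocol


class GetThing(Protocol):
    async def __call__(self, *, operation_name: str) -> int:
        # Annotate the awaited value, not Awaitable[int].
        ...


async def use(fn: GetThing) -> int:
    # Awaiting the call yields an int.
    return await fn(operation_name="operations/123")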


@@ -507,7 +532,7 @@ def _await_operation(
def _compare_requirements(
*,
requirements: Mapping[str, str],
constraints: Union[Sequence[str], Mapping[str, "_SpecifierSet"]],
constraints: Union[Sequence[str], Mapping[str, Optional["_SpecifierSet"]]],
required_packages: Optional[Iterator[str]] = None,
) -> _RequirementsValidationResult:
"""Compares the requirements with the constraints.
@@ -536,7 +561,7 @@
"""
packaging_version = _import_packaging_version_or_raise()
if required_packages is None:
required_packages = _DEFAULT_REQUIRED_PACKAGES
required_packages = _DEFAULT_REQUIRED_PACKAGES # type: ignore[assignment]
result = _RequirementsValidationResult(
warnings=_RequirementsValidationWarnings(missing=set(), incompatible=set()),
actions=_RequirementsValidationActions(append=set()),
@@ -583,7 +608,7 @@ def _generate_class_methods_spec_or_raise(
if isinstance(agent, ModuleAgent):
# We do a dry-run of setting up the agent engine to have the operations
# needed for registration.
agent: ModuleAgent = agent.clone()
agent: ModuleAgent = agent.clone() # type: ignore[no-redef]
try:
agent.set_up()
except Exception as e:
@@ -819,13 +844,13 @@ def _get_gcs_bucket(
new_bucket = storage_client.bucket(staging_bucket)
gcs_bucket = storage_client.create_bucket(new_bucket, location=location)
logger.info(f"Creating bucket {staging_bucket} in {location=}")
return gcs_bucket # type: ignore[no-any-return]
return gcs_bucket


def _get_registered_operations(
*,
agent: _AgentEngineInterface,
) -> Dict[str, List[str]]:
) -> dict[str, list[str]]:
"""Retrieves registered operations for a AgentEngine."""
if isinstance(agent, OperationRegistrable):
return agent.register_operations()
@@ -859,13 +884,13 @@ def _import_cloudpickle_or_raise() -> types.ModuleType:
def _import_cloud_storage_or_raise() -> types.ModuleType:
"""Tries to import the Cloud Storage module."""
try:
from google.cloud import storage
from google.cloud import storage # type: ignore[attr-defined]
except ImportError as e:
raise ImportError(
"Cloud Storage is not installed. Please call "
"'pip install google-cloud-aiplatform[agent_engines]'."
) from e
return storage
return storage # type: ignore[no-any-return]


def _import_packaging_requirements_or_raise() -> types.ModuleType:
@@ -1202,7 +1227,7 @@ def _upload_agent_engine(
) -> None:
"""Uploads the agent engine to GCS."""
cloudpickle = _import_cloudpickle_or_raise()
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_BLOB_FILENAME}") # type: ignore[attr-defined]
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_BLOB_FILENAME}")
with blob.open("wb") as f:
try:
cloudpickle.dump(agent, f)
@@ -1216,7 +1241,7 @@ def _upload_agent_engine(
_ = cloudpickle.load(f)
except Exception as e:
raise TypeError("Agent engine serialized to an invalid format") from e
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}" # type: ignore[attr-defined]
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}"
logger.info(f"Wrote to {dir_name}/{_BLOB_FILENAME}")


@@ -1227,9 +1252,9 @@ def _upload_requirements(
gcs_dir_name: str,
) -> None:
"""Uploads the requirements file to GCS."""
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_REQUIREMENTS_FILE}") # type: ignore[attr-defined]
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_REQUIREMENTS_FILE}")
blob.upload_from_string("\n".join(requirements))
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}" # type: ignore[attr-defined]
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}"
logger.info(f"Writing to {dir_name}/{_REQUIREMENTS_FILE}")


@@ -1246,9 +1271,9 @@ def _upload_extra_packages(
for file in extra_packages:
tar.add(file)
tar_fileobj.seek(0)
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_EXTRA_PACKAGES_FILE}") # type: ignore[attr-defined]
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_EXTRA_PACKAGES_FILE}")
blob.upload_from_string(tar_fileobj.read())
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}" # type: ignore[attr-defined]
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}"
logger.info(f"Writing to {dir_name}/{_EXTRA_PACKAGES_FILE}")


@@ -1369,7 +1394,7 @@ def _validate_requirements_or_warn(
*,
obj: Any,
requirements: List[str],
) -> Mapping[str, str]:
) -> List[str]:
"""Compiles the requirements into a list of requirements."""
requirements = requirements.copy()
try:
Expand All @@ -1380,16 +1405,14 @@ def _validate_requirements_or_warn(
requirements=current_requirements,
constraints=constraints,
)
for warning_type, warnings in missing_requirements.get(
_WARNINGS_KEY, {}
).items():
for warning_type, warnings in missing_requirements["warnings"].items():
if warnings:
logger.warning(
f"The following requirements are {warning_type}: {warnings}"
)
for action_type, actions in missing_requirements.get(_ACTIONS_KEY, {}).items():
for action_type, actions in missing_requirements["actions"].items():
if actions and action_type == _ACTION_APPEND:
for action in actions:
for action in actions: # type: ignore[attr-defined]
requirements.append(action)
logger.info(f"The following requirements are appended: {actions}")
except Exception as e:
@@ -1413,7 +1436,7 @@ def _validate_requirements_or_raise(
logger.info(f"Read the following lines: {requirements}")
except IOError as err:
raise IOError(f"Failed to read requirements from {requirements=}") from err
requirements = _validate_requirements_or_warn( # type: ignore[assignment]
requirements = _validate_requirements_or_warn(
obj=agent,
requirements=requirements,
)
@@ -1560,19 +1583,6 @@ def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
return _method


AgentEngineOperationUnion = Union[
genai_types.AgentEngineOperation,
genai_types.AgentEngineMemoryOperation,
genai_types.AgentEngineGenerateMemoriesOperation,
]


class GetOperationFunction(Protocol):
def __call__( # noqa: E704
self, *, operation_name: str, **kwargs: Any
) -> AgentEngineOperationUnion: ...


def _wrap_query_operation(*, method_name: str) -> Callable[..., Any]:
"""Wraps an Agent Engine method, creating a callable for `query` API.

@@ -1835,7 +1845,7 @@ async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]

return response

return _method
return _method # type: ignore[return-value]


def _yield_parsed_json(http_response: google_genai_types.HttpResponse) -> Iterator[Any]: