From 753ab8c0048857190ec6d6593ddc30d746b7a400 Mon Sep 17 00:00:00 2001 From: benjibc Date: Sun, 10 Aug 2025 17:42:59 +0000 Subject: [PATCH] fix precommit --- .flake8 | 5 +- .pre-commit-config.yaml | 2 + LICENSE | 2 +- Makefile | 35 +- .../notes/pytest_integration_proposal.md | 6 +- eval_protocol/adapters/CONTRIBUTING.md | 82 ++--- eval_protocol/adapters/__init__.py | 22 +- eval_protocol/adapters/huggingface.py | 258 +++++++------- eval_protocol/adapters/langfuse.py | 6 +- eval_protocol/pytest/__init__.py | 2 +- .../pytest/default_dataset_adapter.py | 2 +- eval_protocol/pytest/types.py | 6 +- eval_protocol/types/types.py | 3 +- examples/adapters/README.md | 10 +- .../adapters/gsm8k_replacement_example.py | 112 +++--- examples/adapters/huggingface_example.py | 320 ++++++++--------- examples/adapters/langfuse_example.py | 112 +++--- .../tests/test_record_and_replay_e2e.py | 2 +- mypy.ini | 1 + scripts/fix_whitespace.py | 76 ++++ tests/conftest.py | 1 + tests/pytest/data/basic_coding_dataset.jsonl | 2 +- tests/pytest/data/lunar_lander_dataset.jsonl | 2 +- .../helper/word_count_to_evaluation_row.py | 5 +- tests/pytest/test_apps_coding.py | 13 +- tests/pytest/test_basic_coding.py | 36 +- tests/pytest/test_frozen_lake.py | 22 +- tests/pytest/test_lunar_lander.py | 28 +- ..._pytest_default_agent_rollout_processor.py | 2 +- tests/pytest/test_pytest_function_calling.py | 1 + tests/pytest/test_pytest_input_messages.py | 2 +- tests/pytest/test_pytest_json_schema.py | 1 + tests/pytest/test_pytest_mcp_url.py | 2 +- tests/test_adapters_e2e.py | 327 +++++++++--------- tests/test_cli_agent.py | 2 +- tests/test_url_handling.py | 1 + vendor/tau2/__init__.py | 1 - vendor/tau2/agent/README.md | 2 +- vendor/tau2/agent/base.py | 4 +- vendor/tau2/agent/llm_agent.py | 46 +-- vendor/tau2/cli.py | 4 +- .../user_simulator/simulation_guidelines.md | 4 +- .../simulation_guidelines_tools.md | 4 +- vendor/tau2/data_model/__init__.py | 1 - vendor/tau2/data_model/message.py | 58 +--- vendor/tau2/data_model/simulation.py | 58 +--- vendor/tau2/data_model/tasks.py | 95 ++--- vendor/tau2/domains/airline/data_model.py | 88 ++--- vendor/tau2/domains/airline/tools.py | 79 +---- vendor/tau2/domains/mock/data_model.py | 8 +- vendor/tau2/domains/mock/environment.py | 4 +- vendor/tau2/domains/mock/tools.py | 4 +- vendor/tau2/domains/retail/data_model.py | 96 ++--- vendor/tau2/domains/retail/tools.py | 33 +- vendor/tau2/domains/telecom/data_model.py | 104 ++---- vendor/tau2/domains/telecom/environment.py | 12 +- vendor/tau2/domains/telecom/tasks/const.py | 2 +- .../tau2/domains/telecom/tasks/mms_issues.py | 26 +- .../telecom/tasks/mobile_data_issues.py | 8 +- vendor/tau2/domains/telecom/tasks/utils.py | 4 +- vendor/tau2/domains/telecom/tools.py | 47 +-- .../tau2/domains/telecom/user_data_model.py | 74 +--- vendor/tau2/domains/telecom/user_tools.py | 100 ++---- vendor/tau2/environment/server.py | 18 +- vendor/tau2/environment/tool.py | 12 +- vendor/tau2/environment/toolkit.py | 16 +- .../tau2/environment/utils/interface_agent.py | 16 +- vendor/tau2/evaluator/__init__.py | 1 - vendor/tau2/evaluator/evaluator.py | 4 +- vendor/tau2/evaluator/evaluator_action.py | 5 +- .../tau2/evaluator/evaluator_communicate.py | 8 +- vendor/tau2/evaluator/evaluator_env.py | 24 +- vendor/tau2/metrics/agent_metrics.py | 8 +- vendor/tau2/metrics/break_down_metrics.py | 26 +- .../tau2/orchestrator/environment_manager.py | 20 +- vendor/tau2/orchestrator/orchestrator.py | 114 ++---- vendor/tau2/orchestrator/utils.py | 4 +- vendor/tau2/registry.py | 44 +-- vendor/tau2/run.py | 56 ++- vendor/tau2/scripts/show_domain_doc.py | 4 +- vendor/tau2/scripts/start_servers.py | 8 +- vendor/tau2/scripts/view_simulations.py | 70 ++-- vendor/tau2/user/base.py | 12 +- vendor/tau2/utils/display.py | 58 +--- vendor/tau2/utils/pydantic_utils.py | 4 +- vendor/tau2/utils/utils.py | 10 +- vite-app/.gitignore | 4 +- vite-app/index.html | 2 +- vite-app/src/index.css | 2 +- vite-app/src/types/README.md | 4 +- vite-app/vite.config.ts | 2 +- 91 files changed, 1224 insertions(+), 1809 deletions(-) create mode 100644 scripts/fix_whitespace.py diff --git a/.flake8 b/.flake8 index 06945f46..4929c23e 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,4 @@ [flake8] -max-line-length = 119 -ignore = E203, W503 +max-line-length = 200 +ignore = E203, W503, E501, E402, F401, F541, F811, F841, E704, E713, E712, E231, E731, E226, W291, W293, W292, E302, W504 +exclude = vendor diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 43c0f8c1..d9e5c936 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,6 +29,7 @@ repos: rev: 7.3.0 hooks: - id: flake8 + exclude: ^vendor/ args: [--max-line-length=119, --max-complexity=100, "--ignore=E402,F401,F541,W503,E203,F811,E226,F841,E704,E713,E712,E231,E731,E501"] # additional_dependencies: [flake8-docstrings, flake8-import-order] # Optional: add flake8 plugins @@ -36,6 +37,7 @@ repos: rev: v1.17.0 hooks: - id: mypy + exclude: ^vendor/ args: [--ignore-missing-imports, --install-types, --non-interactive] additional_dependencies: - types-requests diff --git a/LICENSE b/LICENSE index e926381a..4bff8e12 100644 --- a/LICENSE +++ b/LICENSE @@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. \ No newline at end of file +SOFTWARE. diff --git a/Makefile b/Makefile index 982f3ed0..b4cf84f8 100644 --- a/Makefile +++ b/Makefile @@ -1,32 +1,45 @@ PYTHON_DIRS = tests examples scripts eval_protocol -.PHONY: clean build dist upload test lint typecheck format release sync-docs version tag-version show-version bump-major bump-minor bump-patch full-release quick-release +# Prefer tools from local virtualenv if present +VENV ?= .venv +VENV_BIN := $(VENV)/bin +PYTHON := $(if $(wildcard $(VENV_BIN)/python),$(VENV_BIN)/python,python) +FLAKE8 := $(if $(wildcard $(VENV_BIN)/flake8),$(VENV_BIN)/flake8,flake8) +MYPY := $(if $(wildcard $(VENV_BIN)/mypy),$(VENV_BIN)/mypy,mypy) +BLACK := $(if $(wildcard $(VENV_BIN)/black),$(VENV_BIN)/black,black) +PRE_COMMIT := $(if $(wildcard $(VENV_BIN)/pre-commit),$(VENV_BIN)/pre-commit,pre-commit) +PYTEST := $(if $(wildcard $(VENV_BIN)/pytest),$(VENV_BIN)/pytest,pytest) +TWINE := $(if $(wildcard $(VENV_BIN)/twine),$(VENV_BIN)/twine,twine) + +.PHONY: clean build dist upload test lint typecheck format release sync-docs version tag-version show-version bump-major bump-minor bump-patch full-release quick-release pre-commit help clean: rm -rf build/ dist/ *.egg-info/ +# Run all pre-commit hooks (if installed) pre-commit: - pre-commit run --all-files + $(PRE_COMMIT) run --all-files build: clean - python -m build + $(PYTHON) -m build dist: build upload: - twine upload dist/* + $(TWINE) upload dist/* test: - pytest + $(PYTEST) lint: - flake8 $(PYTHON_DIRS) + $(PRE_COMMIT) run flake8 --all-files typecheck: - mypy $(PYTHON_DIRS) + $(PRE_COMMIT) run mypy --all-files format: - black $(PYTHON_DIRS) + $(PRE_COMMIT) run black --all-files && \ + $(PRE_COMMIT) run isort --all-files validate-docs: @echo "Validating documentation links..." @@ -140,9 +153,9 @@ help: @echo " dist - Alias for build" @echo " upload - Upload to PyPI (make sure to bump version first)" @echo " test - Run tests" - @echo " lint - Run flake8 linter" - @echo " typecheck - Run mypy type checker" - @echo " format - Run black code formatter" + @echo " lint - Run flake8 via pre-commit" + @echo " typecheck - Run mypy via pre-commit" + @echo " format - Run black + isort via pre-commit" @echo " validate-docs - Validate all documentation links in docs.json" @echo " sync-docs - Sync docs to ~/home/docs with links under 'evaluators'" @echo " release - Run lint, typecheck, test, build, then upload" diff --git a/development/notes/pytest_integration_proposal.md b/development/notes/pytest_integration_proposal.md index c9496587..784cc215 100644 --- a/development/notes/pytest_integration_proposal.md +++ b/development/notes/pytest_integration_proposal.md @@ -115,7 +115,7 @@ def tau2_rollout_processor(row: EvaluationRow, model: str, input_params: Dict, * # from the dataset and provide a simulated tool response. # 4. Call the model again with the tool response. # 5. Construct a final EvaluationRow with the full transcript. - + # The logic is encapsulated here, away from the test definition. processed_row = ep.default_rollout_processor(row, model, input_params)[0] # Simplified for example return [processed_row] @@ -186,11 +186,11 @@ def best_of_n_processor(row: EvaluationRow, model: str, input_params: Dict, **kw # Then, apply a reward function to score each candidate. scored_rows = ep.evaluate(candidate_rows, score_politeness) - + # Finally, select the best row. # This logic could be encapsulated in a helper, e.g., ep.select_best(). best_row = select_best_by_group(scored_rows, score_key='politeness') - + return [best_row] @evaluation_test( diff --git a/eval_protocol/adapters/CONTRIBUTING.md b/eval_protocol/adapters/CONTRIBUTING.md index 18f31378..e47e06e9 100644 --- a/eval_protocol/adapters/CONTRIBUTING.md +++ b/eval_protocol/adapters/CONTRIBUTING.md @@ -37,36 +37,36 @@ except ImportError: class YourCustomAdapter: """Adapter for integrating with Your Custom Data Source. - + This adapter loads data from Your Custom Data Source and converts it to EvaluationRow format for use in evaluation pipelines. - + Examples: Basic usage: >>> adapter = YourCustomAdapter(api_key="your_key") >>> rows = list(adapter.get_evaluation_rows(limit=10)) """ - + def __init__(self, **config): """Initialize the adapter with configuration.""" if not DEPENDENCY_AVAILABLE: raise ImportError("your_external_library not installed") - + # Initialize your client/connection here self.client = your_external_library.Client(**config) - + def get_evaluation_rows(self, **kwargs) -> Iterator[EvaluationRow]: """Main method to fetch and convert data to EvaluationRow format. - + Args: **kwargs: Adapter-specific parameters - + Yields: EvaluationRow: Converted evaluation rows """ # Implement your data fetching logic raw_data = self.client.fetch_data(**kwargs) - + for item in raw_data: try: eval_row = self._convert_to_evaluation_row(item) @@ -75,51 +75,51 @@ class YourCustomAdapter: except Exception as e: logger.warning(f"Failed to convert item: {e}") continue - + def _convert_to_evaluation_row(self, raw_item: Any) -> Optional[EvaluationRow]: """Convert a raw data item to EvaluationRow format. - + Args: raw_item: Raw data item from your source - + Returns: EvaluationRow or None if conversion fails """ # Extract messages from your data format messages = self._extract_messages(raw_item) - + # Extract metadata input_metadata = self._create_input_metadata(raw_item) - + # Extract ground truth if available ground_truth = self._extract_ground_truth(raw_item) - + # Extract tools if available (for tool calling scenarios) tools = self._extract_tools(raw_item) - + return EvaluationRow( messages=messages, tools=tools, input_metadata=input_metadata, ground_truth=ground_truth, ) - + def _extract_messages(self, raw_item: Any) -> List[Message]: """Extract conversation messages from raw data.""" # Implement message extraction logic # Convert your data format to List[Message] pass - + def _create_input_metadata(self, raw_item: Any) -> InputMetadata: """Create InputMetadata from raw data.""" # Implement metadata extraction pass - + def _extract_ground_truth(self, raw_item: Any) -> Optional[str]: """Extract ground truth if available.""" # Implement ground truth extraction pass - + def _extract_tools(self, raw_item: Any) -> Optional[List[Dict[str, Any]]]: """Extract tool definitions if available.""" # Implement tool extraction for tool calling scenarios @@ -149,7 +149,7 @@ message = Message( content="I'll help you with that calculation.", tool_calls=[{ "id": "call_123", - "type": "function", + "type": "function", "function": { "name": "calculate", "arguments": '{"x": 5, "y": 3}' @@ -185,7 +185,7 @@ input_metadata = InputMetadata( }, session_data={ "user_id": "user123", - "session_id": "session456", + "session_id": "session456", "timestamp": "2024-01-01T00:00:00Z", } ) @@ -259,7 +259,7 @@ def get_evaluation_rows(self, **kwargs) -> Iterator[EvaluationRow]: except Exception as e: logger.error(f"Failed to fetch data: {e}") return - + for item in data: try: row = self._convert_to_evaluation_row(item) @@ -298,36 +298,36 @@ from eval_protocol.models import EvaluationRow class TestYourCustomAdapter: """Test suite for YourCustomAdapter.""" - + def test_initialization(self): """Test adapter initialization.""" adapter = YourCustomAdapter(api_key="test_key") assert adapter.client is not None - + def test_get_evaluation_rows(self): """Test conversion to EvaluationRow format.""" adapter = YourCustomAdapter(api_key="test_key") - + # Mock the external API response with patch.object(adapter.client, 'fetch_data') as mock_fetch: mock_fetch.return_value = [ # Mock data in your format {"id": "1", "question": "Test?", "answer": "Yes"} ] - + rows = list(adapter.get_evaluation_rows(limit=1)) - + assert len(rows) == 1 assert isinstance(rows[0], EvaluationRow) assert len(rows[0].messages) > 0 - + def test_error_handling(self): """Test error handling.""" adapter = YourCustomAdapter(api_key="test_key") - + with patch.object(adapter.client, 'fetch_data') as mock_fetch: mock_fetch.side_effect = Exception("API Error") - + rows = list(adapter.get_evaluation_rows()) assert len(rows) == 0 # Should handle error gracefully ``` @@ -341,18 +341,18 @@ For simple chat data: ```python def _extract_messages(self, conversation: Dict) -> List[Message]: messages = [] - + # Add system prompt if available if conversation.get('system_prompt'): messages.append(Message(role="system", content=conversation['system_prompt'])) - + # Add conversation turns for turn in conversation['turns']: messages.append(Message( role=turn['role'], content=turn['content'] )) - + return messages ``` @@ -363,27 +363,27 @@ For tool calling scenarios: ```python def _extract_messages(self, trace: Dict) -> List[Message]: messages = [] - + for step in trace['steps']: if step['type'] == 'user_message': messages.append(Message(role="user", content=step['content'])) - + elif step['type'] == 'assistant_message': message = Message(role="assistant", content=step.get('content')) - + # Add tool calls if present if step.get('tool_calls'): message.tool_calls = step['tool_calls'] - + messages.append(message) - + elif step['type'] == 'tool_response': messages.append(Message( role="tool", content=step['content'], tool_call_id=step['tool_call_id'] )) - + return messages ``` @@ -515,10 +515,10 @@ Here are some potential adapters that would be valuable: - **OpenAI Evals**: Load data from OpenAI's evals repository - **LLM Evaluation Datasets**: MMLU, HellaSwag, etc. -- **Chat Platforms**: Discord, Slack conversation exports +- **Chat Platforms**: Discord, Slack conversation exports - **Monitoring Tools**: Other observability platforms - **Custom APIs**: Company-specific data sources - **File Formats**: Parquet, Excel, database exports - **Research Datasets**: Academic benchmarks and competitions -We welcome contributions for any of these or other creative integrations! \ No newline at end of file +We welcome contributions for any of these or other creative integrations! diff --git a/eval_protocol/adapters/__init__.py b/eval_protocol/adapters/__init__.py index fc04237b..fe79c7f4 100644 --- a/eval_protocol/adapters/__init__.py +++ b/eval_protocol/adapters/__init__.py @@ -13,35 +13,41 @@ # Conditional imports based on available dependencies try: from .langfuse import LangfuseAdapter, create_langfuse_adapter + __all__ = ["LangfuseAdapter", "create_langfuse_adapter"] except ImportError: __all__ = [] try: from .huggingface import ( - HuggingFaceAdapter, - create_huggingface_adapter, + HuggingFaceAdapter, create_gsm8k_adapter, + create_huggingface_adapter, create_math_adapter, ) - __all__.extend([ - "HuggingFaceAdapter", - "create_huggingface_adapter", - "create_gsm8k_adapter", - "create_math_adapter", - ]) + + __all__.extend( + [ + "HuggingFaceAdapter", + "create_huggingface_adapter", + "create_gsm8k_adapter", + "create_math_adapter", + ] + ) except ImportError: pass # Legacy adapters (always available) try: from .braintrust import reward_fn_to_scorer, scorer_to_reward_fn + __all__.extend(["scorer_to_reward_fn", "reward_fn_to_scorer"]) except ImportError: pass try: from .trl import create_trl_adapter + __all__.extend(["create_trl_adapter"]) except ImportError: pass diff --git a/eval_protocol/adapters/huggingface.py b/eval_protocol/adapters/huggingface.py index 15391181..b66a2a60 100644 --- a/eval_protocol/adapters/huggingface.py +++ b/eval_protocol/adapters/huggingface.py @@ -4,21 +4,20 @@ transformation functions to convert them to EvaluationRow format. """ -from typing import Any, Callable, Dict, Iterator, List, Optional import logging +from typing import Any, Callable, Dict, Iterator, List, Optional -from eval_protocol.models import EvaluationRow, Message, InputMetadata, CompletionParams +from eval_protocol.models import CompletionParams, EvaluationRow, InputMetadata, Message logger = logging.getLogger(__name__) try: - from datasets import load_dataset, Dataset, DatasetDict + from datasets import Dataset, DatasetDict, load_dataset + DATASETS_AVAILABLE = True except ImportError: DATASETS_AVAILABLE = False - logger.warning( - "HuggingFace datasets not installed. Install with: pip install 'eval-protocol[huggingface]'" - ) + logger.warning("HuggingFace datasets not installed. Install with: pip install 'eval-protocol[huggingface]'") # Type alias for transformation function TransformFunction = Callable[[Dict[str, Any]], Dict[str, Any]] @@ -26,11 +25,11 @@ class HuggingFaceAdapter: """Generic adapter to load HuggingFace datasets with custom transformations. - + This adapter loads datasets from HuggingFace Hub and applies a user-provided - transformation function to convert each row to the format expected by + transformation function to convert each row to the format expected by EvaluationRow. - + The transformation function should take a dataset row dictionary and return: { 'messages': List[Dict] - list of message dictionaries with 'role' and 'content' @@ -38,7 +37,7 @@ class HuggingFaceAdapter: 'metadata': Optional[Dict] - any additional metadata to preserve 'tools': Optional[List[Dict]] - tool definitions for tool calling scenarios } - + Examples: Simple Q&A dataset: >>> def transform(row): @@ -49,7 +48,7 @@ class HuggingFaceAdapter: ... } >>> adapter = HuggingFaceAdapter("my-dataset", transform_fn=transform) >>> rows = list(adapter.get_evaluation_rows(split="test", limit=10)) - + Math problems with system prompt: >>> def gsm8k_transform(row): ... return { @@ -62,7 +61,7 @@ class HuggingFaceAdapter: ... } >>> adapter = HuggingFaceAdapter("gsm8k", config_name="main", transform_fn=gsm8k_transform) """ - + def __init__( self, dataset_id: str, @@ -72,7 +71,7 @@ def __init__( **load_dataset_kwargs, ): """Initialize the HuggingFace adapter. - + Args: dataset_id: HuggingFace dataset identifier (e.g., "gsm8k", "squad", "org/dataset") transform_fn: Function to transform dataset rows to evaluation format @@ -84,16 +83,16 @@ def __init__( raise ImportError( "HuggingFace datasets not installed. Install with: pip install 'eval-protocol[huggingface]'" ) - + self.dataset_id = dataset_id self.transform_fn = transform_fn self.config_name = config_name self.revision = revision self.load_dataset_kwargs = load_dataset_kwargs - + # Load the dataset self.dataset = self._load_dataset() - + @classmethod def from_local( cls, @@ -102,53 +101,49 @@ def from_local( **load_dataset_kwargs, ) -> "HuggingFaceAdapter": """Create adapter from local dataset file. - + Args: path: Path to local dataset file (JSON, JSONL, CSV, etc.) transform_fn: Function to transform dataset rows **load_dataset_kwargs: Additional arguments to pass to load_dataset - + Returns: HuggingFaceAdapter instance """ # Determine file format - if path.endswith('.jsonl'): + if path.endswith(".jsonl"): dataset_type = "json" - elif path.endswith('.json'): + elif path.endswith(".json"): dataset_type = "json" - elif path.endswith('.csv'): + elif path.endswith(".csv"): dataset_type = "csv" - elif path.endswith('.parquet'): + elif path.endswith(".parquet"): dataset_type = "parquet" else: # Let HuggingFace auto-detect dataset_type = None - - load_kwargs = {'data_files': path, **load_dataset_kwargs} - - return cls( - dataset_id=dataset_type or "json", - transform_fn=transform_fn, - **load_kwargs - ) - + + load_kwargs = {"data_files": path, **load_dataset_kwargs} + + return cls(dataset_id=dataset_type or "json", transform_fn=transform_fn, **load_kwargs) + def _load_dataset(self) -> "Dataset | DatasetDict": """Load the dataset from HuggingFace Hub or local source.""" try: kwargs = {} if self.config_name: - kwargs['name'] = self.config_name + kwargs["name"] = self.config_name if self.revision: - kwargs['revision'] = self.revision - + kwargs["revision"] = self.revision + kwargs.update(self.load_dataset_kwargs) - + return load_dataset(self.dataset_id, **kwargs) - + except (OSError, ValueError, RuntimeError) as e: logger.error("Failed to load dataset %s: %s", self.dataset_id, e) raise - + def get_evaluation_rows( self, split: Optional[str] = None, @@ -160,7 +155,7 @@ def get_evaluation_rows( **completion_params_kwargs, ) -> Iterator[EvaluationRow]: """Convert dataset entries to EvaluationRow format. - + Args: split: Dataset split to use (if dataset has multiple splits) limit: Maximum number of rows to return @@ -169,7 +164,7 @@ def get_evaluation_rows( temperature: Temperature for completion parameters max_tokens: Max tokens for completion parameters **completion_params_kwargs: Additional completion parameters - + Yields: EvaluationRow: Converted evaluation rows """ @@ -183,15 +178,15 @@ def get_evaluation_rows( dataset = self.dataset[split] elif split is not None: logger.warning("Split '%s' specified but dataset is not split", split) - + # Apply offset and limit total_rows = len(dataset) end_idx = min(offset + limit, total_rows) if limit else total_rows - + if offset >= total_rows: logger.warning("Offset %d is greater than dataset size %d", offset, total_rows) return - + # Create completion parameters completion_params = CompletionParams( model=model_name, @@ -199,19 +194,17 @@ def get_evaluation_rows( max_tokens=max_tokens, **completion_params_kwargs, ) - + # Convert each row for i in range(offset, end_idx): try: raw_row = dataset[i] - eval_row = self._convert_row_to_evaluation_row( - raw_row, i, completion_params, split - ) + eval_row = self._convert_row_to_evaluation_row(raw_row, i, completion_params, split) yield eval_row except (AttributeError, ValueError, KeyError) as e: logger.warning("Failed to convert row %d: %s", i, e) continue - + def _convert_row_to_evaluation_row( self, raw_row: Dict[str, Any], @@ -220,83 +213,87 @@ def _convert_row_to_evaluation_row( split: Optional[str] = None, ) -> EvaluationRow: """Convert a single dataset row to EvaluationRow format. - + Args: raw_row: Raw dataset row dictionary row_index: Index of the row in the dataset completion_params: Completion parameters to use split: Dataset split name - + Returns: EvaluationRow object """ # Apply user transformation transformed = self.transform_fn(raw_row) - + # Validate required fields - if 'messages' not in transformed: + if "messages" not in transformed: raise ValueError("Transform function must return 'messages' field") - + # Convert message dictionaries to Message objects messages = [] - for msg_dict in transformed['messages']: + for msg_dict in transformed["messages"]: if not isinstance(msg_dict, dict): raise ValueError("Each message must be a dictionary") - if 'role' not in msg_dict: + if "role" not in msg_dict: raise ValueError("Each message must have a 'role' field") - - messages.append(Message( - role=msg_dict['role'], - content=msg_dict.get('content'), - name=msg_dict.get('name'), - tool_call_id=msg_dict.get('tool_call_id'), - tool_calls=msg_dict.get('tool_calls'), - function_call=msg_dict.get('function_call'), - )) - + + messages.append( + Message( + role=msg_dict["role"], + content=msg_dict.get("content"), + name=msg_dict.get("name"), + tool_call_id=msg_dict.get("tool_call_id"), + tool_calls=msg_dict.get("tool_calls"), + function_call=msg_dict.get("function_call"), + ) + ) + # Extract other fields - ground_truth = transformed.get('ground_truth') - tools = transformed.get('tools') - user_metadata = transformed.get('metadata', {}) - + ground_truth = transformed.get("ground_truth") + tools = transformed.get("tools") + user_metadata = transformed.get("metadata", {}) + # Create dataset info dataset_info = { - 'dataset_id': self.dataset_id, - 'config_name': self.config_name, - 'revision': self.revision, - 'split': split, - 'row_index': row_index, - 'transform_function': self.transform_fn.__name__ if hasattr(self.transform_fn, '__name__') else 'anonymous', + "dataset_id": self.dataset_id, + "config_name": self.config_name, + "revision": self.revision, + "split": split, + "row_index": row_index, + "transform_function": ( + self.transform_fn.__name__ if hasattr(self.transform_fn, "__name__") else "anonymous" + ), } - + # Add user metadata dataset_info.update(user_metadata) - + # Add original row data (with prefix to avoid conflicts) for key, value in raw_row.items(): - dataset_info[f'original_{key}'] = value - + dataset_info[f"original_{key}"] = value + # Create input metadata input_metadata = InputMetadata( row_id=f"{self.dataset_id}_{row_index}", completion_params=completion_params, dataset_info=dataset_info, session_data={ - 'dataset_source': 'huggingface', - 'timestamp': None, - } + "dataset_source": "huggingface", + "timestamp": None, + }, ) - + return EvaluationRow( messages=messages, tools=tools, input_metadata=input_metadata, ground_truth=str(ground_truth) if ground_truth is not None else None, ) - + def get_splits(self) -> List[str]: """Get available dataset splits. - + Returns: List of available split names """ @@ -304,27 +301,29 @@ def get_splits(self) -> List[str]: return list(self.dataset.keys()) else: return ["train"] # Default split name for non-split datasets - + def get_dataset_info(self) -> Dict[str, Any]: """Get information about the loaded dataset. - + Returns: Dictionary with dataset information """ info = { - 'dataset_id': self.dataset_id, - 'config_name': self.config_name, - 'revision': self.revision, - 'splits': self.get_splits(), - 'transform_function': self.transform_fn.__name__ if hasattr(self.transform_fn, '__name__') else 'anonymous', + "dataset_id": self.dataset_id, + "config_name": self.config_name, + "revision": self.revision, + "splits": self.get_splits(), + "transform_function": ( + self.transform_fn.__name__ if hasattr(self.transform_fn, "__name__") else "anonymous" + ), } - + # Add split sizes if isinstance(self.dataset, DatasetDict): - info['split_sizes'] = {split: len(data) for split, data in self.dataset.items()} + info["split_sizes"] = {split: len(data) for split, data in self.dataset.items()} else: - info['total_size'] = len(self.dataset) - + info["total_size"] = len(self.dataset) + return info @@ -336,14 +335,14 @@ def create_huggingface_adapter( **load_dataset_kwargs, ) -> HuggingFaceAdapter: """Factory function to create a HuggingFace adapter. - + Args: dataset_id: HuggingFace dataset identifier transform_fn: Function to transform dataset rows to evaluation format config_name: Optional configuration name revision: Optional dataset revision/commit hash **load_dataset_kwargs: Additional arguments for load_dataset - + Returns: HuggingFaceAdapter instance """ @@ -362,11 +361,11 @@ def create_gsm8k_adapter( revision: Optional[str] = None, ) -> HuggingFaceAdapter: """Create adapter specifically configured for GSM8K dataset. - + Args: system_prompt: Optional system prompt for math problems revision: Optional dataset revision/commit - + Returns: HuggingFaceAdapter configured for GSM8K """ @@ -374,24 +373,24 @@ def create_gsm8k_adapter( "You are a helpful assistant that solves math problems step by step. " "Show your work and provide the final answer." ) - + system_content = system_prompt or default_system_prompt - + def gsm8k_transform(row: Dict[str, Any]) -> Dict[str, Any]: """Transform GSM8K row to evaluation format.""" return { - 'messages': [ - {'role': 'system', 'content': system_content}, - {'role': 'user', 'content': row['question']}, + "messages": [ + {"role": "system", "content": system_content}, + {"role": "user", "content": row["question"]}, ], - 'ground_truth': row['answer'], - 'metadata': { - 'dataset': 'gsm8k', - 'question_length': len(row['question']), - 'answer_length': len(row['answer']), - } + "ground_truth": row["answer"], + "metadata": { + "dataset": "gsm8k", + "question_length": len(row["question"]), + "answer_length": len(row["answer"]), + }, } - + return create_huggingface_adapter( dataset_id="gsm8k", config_name="main", @@ -405,40 +404,39 @@ def create_math_adapter( revision: Optional[str] = None, ) -> HuggingFaceAdapter: """Create adapter specifically configured for MATH competition dataset. - + Args: system_prompt: Optional system prompt for math problems revision: Optional dataset revision/commit - + Returns: HuggingFaceAdapter configured for MATH dataset """ default_system_prompt = ( - "You are an expert mathematician. Solve this advanced math problem " - "step by step, showing detailed work." + "You are an expert mathematician. Solve this advanced math problem " "step by step, showing detailed work." ) - + system_content = system_prompt or default_system_prompt - + def math_transform(row: Dict[str, Any]) -> Dict[str, Any]: """Transform MATH dataset row to evaluation format.""" return { - 'messages': [ - {'role': 'system', 'content': system_content}, - {'role': 'user', 'content': row['problem']}, + "messages": [ + {"role": "system", "content": system_content}, + {"role": "user", "content": row["problem"]}, ], - 'ground_truth': row['solution'], - 'metadata': { - 'dataset': 'hendrycks_math', - 'type': row.get('type', 'unknown'), - 'level': row.get('level', 'unknown'), - 'problem_length': len(row['problem']), - 'solution_length': len(row['solution']), - } + "ground_truth": row["solution"], + "metadata": { + "dataset": "hendrycks_math", + "type": row.get("type", "unknown"), + "level": row.get("level", "unknown"), + "problem_length": len(row["problem"]), + "solution_length": len(row["solution"]), + }, } - + return create_huggingface_adapter( dataset_id="hendrycks/competition_math", transform_fn=math_transform, revision=revision, - ) \ No newline at end of file + ) diff --git a/eval_protocol/adapters/langfuse.py b/eval_protocol/adapters/langfuse.py index a3f35cba..f219a9a6 100644 --- a/eval_protocol/adapters/langfuse.py +++ b/eval_protocol/adapters/langfuse.py @@ -4,11 +4,11 @@ to EvaluationRow format for use in evaluation pipelines. """ -from typing import Any, Dict, Iterator, List, Optional -from datetime import datetime import logging +from datetime import datetime +from typing import Any, Dict, Iterator, List, Optional -from eval_protocol.models import EvaluationRow, Message, InputMetadata, CompletionParams +from eval_protocol.models import CompletionParams, EvaluationRow, InputMetadata, Message logger = logging.getLogger(__name__) diff --git a/eval_protocol/pytest/__init__.py b/eval_protocol/pytest/__init__.py index ce881ccc..a198def9 100644 --- a/eval_protocol/pytest/__init__.py +++ b/eval_protocol/pytest/__init__.py @@ -1,9 +1,9 @@ from .default_agent_rollout_processor import default_agent_rollout_processor +from .default_dataset_adapter import default_dataset_adapter from .default_no_op_rollout_process import default_no_op_rollout_processor from .default_single_turn_rollout_process import default_single_turn_rollout_processor from .evaluation_test import evaluation_test from .types import RolloutProcessor, RolloutProcessorConfig -from .default_dataset_adapter import default_dataset_adapter __all__ = [ "default_agent_rollout_processor", diff --git a/eval_protocol/pytest/default_dataset_adapter.py b/eval_protocol/pytest/default_dataset_adapter.py index 87377cff..7c4a7d73 100644 --- a/eval_protocol/pytest/default_dataset_adapter.py +++ b/eval_protocol/pytest/default_dataset_adapter.py @@ -7,4 +7,4 @@ def default_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]: """ Default dataset adapter that simply returns the rows as is. """ - return [EvaluationRow(**row) for row in rows] \ No newline at end of file + return [EvaluationRow(**row) for row in rows] diff --git a/eval_protocol/pytest/types.py b/eval_protocol/pytest/types.py index 880a7029..4546931a 100644 --- a/eval_protocol/pytest/types.py +++ b/eval_protocol/pytest/types.py @@ -39,8 +39,10 @@ class RolloutProcessorConfig: model: ModelParam input_params: RolloutInputParam # optional input parameters for inference - mcp_config_path: str - server_script_path: Optional[str] = None # TODO: change from server_script_path to mcp_config_path for agent rollout processor + mcp_config_path: str + server_script_path: Optional[str] = ( + None # TODO: change from server_script_path to mcp_config_path for agent rollout processor + ) max_concurrent_rollouts: int = 8 # maximum number of concurrent rollouts steps: int = 30 # max number of rollout steps diff --git a/eval_protocol/types/types.py b/eval_protocol/types/types.py index 7c0184f0..3f73ce3a 100644 --- a/eval_protocol/types/types.py +++ b/eval_protocol/types/types.py @@ -1,8 +1,9 @@ +from contextlib import AsyncExitStack from dataclasses import dataclass, field from enum import Enum from typing import Any, Dict, List, Optional + from mcp.client.session import ClientSession -from contextlib import AsyncExitStack class TerminationReason(str, Enum): diff --git a/examples/adapters/README.md b/examples/adapters/README.md index f51cd387..4b8501ef 100644 --- a/examples/adapters/README.md +++ b/examples/adapters/README.md @@ -51,7 +51,7 @@ pip install 'eval-protocol[huggingface]' # Run Langfuse example python examples/adapters/langfuse_example.py -# Run HuggingFace example +# Run HuggingFace example python examples/adapters/huggingface_example.py # Run GSM8K replacement example @@ -100,7 +100,7 @@ def custom_gsm8k_transform(row): from eval_protocol.adapters.huggingface import create_huggingface_adapter custom_adapter = create_huggingface_adapter( dataset_id="gsm8k", - config_name="main", + config_name="main", transform_fn=custom_gsm8k_transform ) ``` @@ -150,7 +150,7 @@ rows = list(adapter.get_evaluation_rows(limit=10)) for row in rows: # Add model response (you would generate this) row.messages.append(Message(role="assistant", content="...")) - + # Evaluate result = math_reward(messages=row.messages, ground_truth=row.ground_truth) print(f"Score: {result.score}") @@ -222,7 +222,7 @@ class MyCustomAdapter: def __init__(self, **config): # Initialize your data source connection pass - + def get_evaluation_rows(self, **kwargs) -> Iterator[EvaluationRow]: # Fetch data and convert to EvaluationRow format pass @@ -272,4 +272,4 @@ We welcome contributions of new adapters! Popular integrations that would be val - **File format adapters**: Parquet, Excel, etc. - **Monitoring platform adapters**: DataDog, New Relic, etc. -See the adapter contributing guide for detailed instructions. \ No newline at end of file +See the adapter contributing guide for detailed instructions. diff --git a/examples/adapters/gsm8k_replacement_example.py b/examples/adapters/gsm8k_replacement_example.py index a86de261..2bba1ebb 100644 --- a/examples/adapters/gsm8k_replacement_example.py +++ b/examples/adapters/gsm8k_replacement_example.py @@ -1,8 +1,8 @@ """ GSM8K Replacement Example -This example shows how to replace the static GSM8K JSONL file -(development/gsm8k_sample.jsonl) with the dynamic HuggingFace adapter +This example shows how to replace the static GSM8K JSONL file +(development/gsm8k_sample.jsonl) with the dynamic HuggingFace adapter to get fresh data from the GSM8K dataset. """ @@ -18,17 +18,17 @@ def load_original_gsm8k_sample() -> List[dict]: """Load the original GSM8K sample file for comparison.""" sample_file = Path("development/gsm8k_sample.jsonl") - + if not sample_file.exists(): print(f"โš ๏ธ Original sample file not found: {sample_file}") return [] - + data = [] - with open(sample_file, 'r') as f: + with open(sample_file, "r") as f: for line in f: if line.strip(): data.append(json.loads(line)) - + return data @@ -36,52 +36,52 @@ def demonstrate_old_vs_new_approach(): """Compare the old static file approach with the new adapter approach.""" print("๐Ÿ“Š Comparing Old vs New Approach") print("=" * 50) - + # OLD APPROACH: Static JSONL file print("๐Ÿ—‚๏ธ OLD APPROACH: Static JSONL File") print("-" * 35) - + original_data = load_original_gsm8k_sample() print(f"Loaded {len(original_data)} items from static file") - + if original_data: sample = original_data[0] print(f"Sample item fields: {list(sample.keys())}") print(f"Sample question: {sample.get('user_query', '')[:100]}...") print(f"Sample ground truth: {sample.get('ground_truth_for_eval', '')[:100]}...") - - print("\n" + "="*50 + "\n") - + + print("\n" + "=" * 50 + "\n") + # NEW APPROACH: HuggingFace Adapter print("๐Ÿค— NEW APPROACH: HuggingFace Adapter") print("-" * 38) - + try: # Create adapter adapter = create_gsm8k_adapter( system_prompt="You are a helpful assistant that solves math problems step by step." ) - + print("โœ… GSM8K adapter created successfully") - + # Get the same number of items as the original file num_items = len(original_data) if original_data else 6 rows = list(adapter.get_evaluation_rows(limit=num_items)) - + print(f"Retrieved {len(rows)} evaluation rows from HuggingFace") - + if rows: sample_row = rows[0] print(f"Sample EvaluationRow fields: messages, tools, input_metadata, ground_truth") - + # Show the question from messages user_msg = next((msg for msg in sample_row.messages if msg.role == "user"), None) if user_msg: print(f"Sample question: {user_msg.content[:100]}...") - + if sample_row.ground_truth: print(f"Sample ground truth: {sample_row.ground_truth[:100]}...") - + except ImportError as e: print(f"โŒ Error: {e}") print("Install HuggingFace dependencies: pip install 'eval-protocol[huggingface]'") @@ -89,9 +89,9 @@ def demonstrate_old_vs_new_approach(): except Exception as e: print(f"โŒ Error with adapter: {e}") return - - print("\n" + "="*50 + "\n") - + + print("\n" + "=" * 50 + "\n") + # COMPARISON print("๐Ÿ” Key Differences") print("-" * 20) @@ -101,7 +101,7 @@ def demonstrate_old_vs_new_approach(): print(" โŒ Manual data preparation required") print(" โŒ Limited to pre-selected subset") print(" โŒ Requires manual format conversion") - + print("\nNEW APPROACH:") print(" โœ… Access to full GSM8K dataset (8,792 test problems)") print(" โœ… Automatic format conversion to EvaluationRow") @@ -115,10 +115,11 @@ def show_migration_example(): """Show how to migrate existing code from JSONL to adapter.""" print("\n๐Ÿ”„ Code Migration Example") print("=" * 30) - + print("OLD CODE:") print("-" * 10) - print(""" + print( + """ # Old way with static JSONL file input_dataset = ["development/gsm8k_sample.jsonl"] @@ -134,11 +135,13 @@ def show_migration_example(): ] ground_truth = item["ground_truth_for_eval"] # ... more manual processing -""") - +""" + ) + print("\nNEW CODE:") print("-" * 10) - print(""" + print( + """ # New way with HuggingFace adapter from eval_protocol.adapters.huggingface import create_gsm8k_adapter @@ -149,7 +152,7 @@ def show_migration_example(): # Get evaluation rows (already in correct format) evaluation_rows = list(adapter.get_evaluation_rows( - split="test", # or "train" + split="test", # or "train" limit=100, # Can get much more data than static file model_name="gpt-4", temperature=0.0, @@ -175,8 +178,9 @@ def custom_gsm8k_transform(row): config_name="main", transform_fn=custom_gsm8k_transform ) -""") - +""" + ) + print("\nโœ… Benefits of Migration:") print(" - More data available (6 โ†’ 8,792 problems)") print(" - Automatic format handling") @@ -189,30 +193,30 @@ def practical_migration_demo(): """Show a practical example of using the adapter in evaluation.""" print("\n๐Ÿงช Practical Evaluation Example") print("=" * 35) - + try: # Create adapter adapter = create_gsm8k_adapter() - + # Get a few problems for evaluation print("Loading GSM8K problems...") rows = list(adapter.get_evaluation_rows(limit=3)) print(f"โœ… Loaded {len(rows)} problems from GSM8K test set") - + # Simulate evaluation workflow for i, row in enumerate(rows): print(f"\n๐Ÿ“ Problem {i+1}:") - + # Show the problem user_msg = next((msg for msg in row.messages if msg.role == "user"), None) if user_msg: print(f" Question: {user_msg.content[:150]}...") - + # In a real scenario, you'd generate a response with your LLM # For this demo, we'll add a dummy response dummy_response = "Let me solve this step by step. After working through the math, the answer is 42." row.messages.append(Message(role="assistant", content=dummy_response)) - + # Evaluate with math reward function if row.ground_truth: try: @@ -222,7 +226,7 @@ def practical_migration_demo(): ) print(f" ๐Ÿ“Š Math evaluation score: {result.score:.2f}") print(f" ๐Ÿ’ญ Evaluation reason: {result.reason[:100]}...") - + # Show metadata if row.input_metadata: print(f" ๐Ÿท๏ธ Row ID: {row.input_metadata.row_id}") @@ -230,12 +234,12 @@ def practical_migration_demo(): dataset_info = row.input_metadata.dataset_info print(f" ๐Ÿ“š Dataset: {dataset_info.get('dataset_name', 'N/A')}") print(f" ๐Ÿ“ Row index: {dataset_info.get('row_index', 'N/A')}") - + except Exception as e: print(f" โŒ Evaluation error: {e}") - + print(f"\nโœ… Successfully processed {len(rows)} problems using the new adapter approach!") - + except Exception as e: print(f"โŒ Error in practical demo: {e}") @@ -244,9 +248,9 @@ def performance_comparison(): """Compare performance characteristics of both approaches.""" print("\nโšก Performance Considerations") print("=" * 35) - + import time - + # Time the old approach (if file exists) original_data = load_original_gsm8k_sample() if original_data: @@ -259,7 +263,7 @@ def performance_comparison(): print("๐Ÿ“ Static file not available for timing") old_time = 0 processed_old = 0 - + # Time the new approach try: start_time = time.time() @@ -267,9 +271,9 @@ def performance_comparison(): rows = list(adapter.get_evaluation_rows(split="test", limit=max(6, processed_old))) new_time = time.time() - start_time processed_new = len(rows) - + print(f"๐Ÿค— HuggingFace adapter: {processed_new} items in {new_time:.4f}s") - + if old_time > 0: if new_time > old_time: factor = new_time / old_time @@ -277,11 +281,11 @@ def performance_comparison(): else: factor = old_time / new_time print(f" ๐Ÿ“Š Adapter is {factor:.1f}x faster!") - + print(f"\n๐Ÿ’ก Trade-offs:") print(f" Static file: Fast ({old_time:.4f}s) but limited data ({processed_old} items)") print(f" Adapter: Slower ({new_time:.4f}s) but access to full dataset ({processed_new}+ items)") - + except Exception as e: print(f"โŒ Error timing adapter: {e}") @@ -293,16 +297,16 @@ def main(): print("This example shows how to replace the static GSM8K JSONL file") print("with the dynamic HuggingFace adapter for better data access.") print() - + # Run all demonstrations demonstrate_old_vs_new_approach() show_migration_example() practical_migration_demo() performance_comparison() - - print("\n" + "="*50) + + print("\n" + "=" * 50) print("๐ŸŽฏ MIGRATION SUMMARY") - print("="*50) + print("=" * 50) print("1. โœ… Replace static JSONL with HuggingFace adapter") print("2. โœ… Get access to full GSM8K dataset (8,792 test problems)") print("3. โœ… Automatic conversion to EvaluationRow format") @@ -318,4 +322,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/adapters/huggingface_example.py b/examples/adapters/huggingface_example.py index 2d79eae3..408b621a 100644 --- a/examples/adapters/huggingface_example.py +++ b/examples/adapters/huggingface_example.py @@ -9,10 +9,10 @@ from typing import List from eval_protocol.adapters.huggingface import ( - create_huggingface_adapter, + HuggingFaceAdapter, create_gsm8k_adapter, + create_huggingface_adapter, create_math_adapter, - HuggingFaceAdapter, ) from eval_protocol.models import EvaluationRow @@ -21,44 +21,47 @@ def gsm8k_example(): """Example using the GSM8K dataset.""" print("๐Ÿ“š Example 1: GSM8K Dataset") print("-" * 30) - + try: # Create GSM8K adapter using the convenience method adapter = create_gsm8k_adapter( - split="test", - system_prompt="You are a helpful assistant that solves math problems step by step." + split="test", system_prompt="You are a helpful assistant that solves math problems step by step." ) - + print("โœ… GSM8K adapter created successfully") print(f"๐Ÿ“Š Dataset info: {adapter.get_dataset_info()}") - + # Get a few evaluation rows - rows = list(adapter.get_evaluation_rows( - limit=3, - model_name="gpt-4", - temperature=0.0, - )) - + rows = list( + adapter.get_evaluation_rows( + limit=3, + model_name="gpt-4", + temperature=0.0, + ) + ) + print(f"\nRetrieved {len(rows)} evaluation rows from GSM8K test set:") - + for i, row in enumerate(rows): print(f"\n Row {i+1}:") print(f" - ID: {row.input_metadata.row_id if row.input_metadata else 'N/A'}") print(f" - Messages: {len(row.messages)}") - + # Show the math problem user_message = next((msg for msg in row.messages if msg.role == "user"), None) if user_message: - problem = user_message.content[:200] + "..." if len(user_message.content) > 200 else user_message.content + problem = ( + user_message.content[:200] + "..." if len(user_message.content) > 200 else user_message.content + ) print(f" - Problem: {problem}") - + # Show ground truth answer if row.ground_truth: answer_preview = row.ground_truth[:100] + "..." if len(row.ground_truth) > 100 else row.ground_truth print(f" - Ground truth: {answer_preview}") - + print() - + except ImportError as e: print(f"โŒ Error: {e}") print("Install HuggingFace dependencies: pip install 'eval-protocol[huggingface]'") @@ -70,42 +73,44 @@ def math_dataset_example(): """Example using the MATH competition dataset.""" print("๐Ÿงฎ Example 2: MATH Competition Dataset") print("-" * 40) - + try: # Create MATH dataset adapter - adapter = create_math_adapter( - system_prompt="You are an expert mathematician. Solve this step by step." - ) - + adapter = create_math_adapter(system_prompt="You are an expert mathematician. Solve this step by step.") + print("โœ… MATH dataset adapter created successfully") print(f"๐Ÿ“Š Dataset info: {adapter.get_dataset_info()}") - + # Get a few examples - rows = list(adapter.get_evaluation_rows( - limit=2, - model_name="gpt-4", - temperature=0.1, - )) - + rows = list( + adapter.get_evaluation_rows( + limit=2, + model_name="gpt-4", + temperature=0.1, + ) + ) + print(f"\nRetrieved {len(rows)} evaluation rows from MATH test set:") - + for i, row in enumerate(rows): print(f"\n Row {i+1}:") - + # Show the problem user_message = next((msg for msg in row.messages if msg.role == "user"), None) if user_message: - problem = user_message.content[:150] + "..." if len(user_message.content) > 150 else user_message.content + problem = ( + user_message.content[:150] + "..." if len(user_message.content) > 150 else user_message.content + ) print(f" - Problem: {problem}") - + # Show metadata if row.input_metadata and row.input_metadata.dataset_info: dataset_info = row.input_metadata.dataset_info - if 'original_type' in dataset_info: + if "original_type" in dataset_info: print(f" - Problem type: {dataset_info['original_type']}") - if 'original_level' in dataset_info: + if "original_level" in dataset_info: print(f" - Level: {dataset_info['original_level']}") - + except Exception as e: print(f"โŒ Error with MATH dataset: {e}") @@ -114,66 +119,70 @@ def custom_dataset_example(): """Example using a custom dataset with transformation function.""" print("๐Ÿ”ง Example 3: Custom Dataset with Transform Function") print("-" * 55) - + try: # Define transformation function for SQuAD dataset def squad_transform(row): """Transform SQuAD row to evaluation format.""" - context = row['context'] - question = row['question'] - answers = row['answers'] - + context = row["context"] + question = row["question"] + answers = row["answers"] + # Get first answer text - answer_text = answers['text'][0] if answers['text'] else "No answer provided" - + answer_text = answers["text"][0] if answers["text"] else "No answer provided" + return { - 'messages': [ - {'role': 'system', 'content': 'Answer the question based on the given context.'}, - {'role': 'user', 'content': f"Context: {context}\\n\\nQuestion: {question}"}, + "messages": [ + {"role": "system", "content": "Answer the question based on the given context."}, + {"role": "user", "content": f"Context: {context}\\n\\nQuestion: {question}"}, ], - 'ground_truth': answer_text, - 'metadata': { - 'dataset': 'squad', - 'context_length': len(context), - 'question_length': len(question), - 'num_possible_answers': len(answers['text']), - } + "ground_truth": answer_text, + "metadata": { + "dataset": "squad", + "context_length": len(context), + "question_length": len(question), + "num_possible_answers": len(answers["text"]), + }, } - + # Create adapter with transformation function adapter = create_huggingface_adapter( dataset_id="squad", transform_fn=squad_transform, ) - + print("โœ… Custom dataset adapter created successfully") - + # Get dataset info info = adapter.get_dataset_info() print(f"๐Ÿ“Š Dataset info: {info}") - + # Get a few examples - rows = list(adapter.get_evaluation_rows( - split="validation", # SQuAD has train/validation splits - limit=2, - model_name="gpt-3.5-turbo", - )) - + rows = list( + adapter.get_evaluation_rows( + split="validation", # SQuAD has train/validation splits + limit=2, + model_name="gpt-3.5-turbo", + ) + ) + print(f"\nRetrieved {len(rows)} evaluation rows:") - + for i, row in enumerate(rows): print(f"\n Row {i+1}:") print(f" - Messages: {len(row.messages)}") - + # Show question user_message = next((msg for msg in row.messages if msg.role == "user"), None) if user_message: - question = user_message.content[:100] + "..." if len(user_message.content) > 100 else user_message.content + question = ( + user_message.content[:100] + "..." if len(user_message.content) > 100 else user_message.content + ) print(f" - Question: {question}") - + # SQuAD answers are complex, so just show if we have ground truth print(f" - Has ground truth: {'Yes' if row.ground_truth else 'No'}") - + except Exception as e: print(f"โŒ Error with custom dataset: {e}") @@ -182,93 +191,85 @@ def local_file_example(): """Example loading a local dataset file.""" print("๐Ÿ“ Example 4: Local Dataset File") print("-" * 35) - + # Create a sample JSONL file for demonstration sample_file = "/tmp/sample_qa.jsonl" sample_data = [ - { - "id": "q1", - "question": "What is the capital of France?", - "answer": "Paris", - "category": "geography" - }, - { - "id": "q2", - "question": "What is 2 + 2?", - "answer": "4", - "category": "math" - }, + {"id": "q1", "question": "What is the capital of France?", "answer": "Paris", "category": "geography"}, + {"id": "q2", "question": "What is 2 + 2?", "answer": "4", "category": "math"}, { "id": "q3", "question": "Who wrote Romeo and Juliet?", "answer": "William Shakespeare", - "category": "literature" - } + "category": "literature", + }, ] - + try: import json - + # Write sample data - with open(sample_file, 'w') as f: + with open(sample_file, "w") as f: for item in sample_data: - f.write(json.dumps(item) + '\n') - + f.write(json.dumps(item) + "\n") + print(f"๐Ÿ“ Created sample file: {sample_file}") - + # Define transformation function for local data def local_qa_transform(row): - """Transform local Q&A data to evaluation format.""" + """Transform local Q&A data to evaluation format.""" return { - 'messages': [ - {'role': 'system', 'content': 'You are a knowledgeable assistant.'}, - {'role': 'user', 'content': row['question']}, + "messages": [ + {"role": "system", "content": "You are a knowledgeable assistant."}, + {"role": "user", "content": row["question"]}, ], - 'ground_truth': row['answer'], - 'metadata': { - 'id': row.get('id'), - 'category': row.get('category'), - 'dataset': 'local_qa_sample', - } + "ground_truth": row["answer"], + "metadata": { + "id": row.get("id"), + "category": row.get("category"), + "dataset": "local_qa_sample", + }, } - + # Load with adapter adapter = HuggingFaceAdapter.from_local( path=sample_file, transform_fn=local_qa_transform, ) - + print("โœ… Local file adapter created successfully") - + # Get all rows - rows = list(adapter.get_evaluation_rows( - model_name="gpt-3.5-turbo", - temperature=0.0, - )) - + rows = list( + adapter.get_evaluation_rows( + model_name="gpt-3.5-turbo", + temperature=0.0, + ) + ) + print(f"\nLoaded {len(rows)} rows from local file:") - + for i, row in enumerate(rows): print(f"\n Row {i+1}:") - + # Show question and answer user_msg = next((msg for msg in row.messages if msg.role == "user"), None) if user_msg: print(f" - Question: {user_msg.content}") - + if row.ground_truth: print(f" - Answer: {row.ground_truth}") - + # Show original metadata if row.input_metadata and row.input_metadata.dataset_info: - original_data = {k: v for k, v in row.input_metadata.dataset_info.items() if k.startswith('original_')} + original_data = {k: v for k, v in row.input_metadata.dataset_info.items() if k.startswith("original_")} if original_data: print(f" - Original data: {original_data}") - + # Clean up os.remove(sample_file) print(f"\n๐Ÿงน Cleaned up sample file") - + except Exception as e: print(f"โŒ Error with local file: {e}") @@ -277,35 +278,34 @@ def evaluation_integration_example(): """Show how to integrate with evaluation functions.""" print("\n๐Ÿงช Example 5: Integration with Evaluation") print("-" * 45) - + try: # Import evaluation functions - from eval_protocol.rewards.math import math_reward from eval_protocol.rewards.accuracy import accuracy_reward - + from eval_protocol.rewards.math import math_reward + # Create GSM8K adapter adapter = create_gsm8k_adapter(split="test") - + # Get a few rows for evaluation rows = list(adapter.get_evaluation_rows(limit=2)) - + print(f"Running evaluation on {len(rows)} GSM8K problems:") - + for i, row in enumerate(rows): print(f"\n Problem {i+1}:") - + # Show the problem user_msg = next((msg for msg in row.messages if msg.role == "user"), None) if user_msg: print(f" Question: {user_msg.content[:100]}...") - + # For this example, we'll simulate an assistant response # In practice, this would come from your LLM - row.messages.append({ - "role": "assistant", - "content": "Let me solve this step by step... The answer is 42." - }) - + row.messages.append( + {"role": "assistant", "content": "Let me solve this step by step... The answer is 42."} + ) + # Evaluate with math reward if row.ground_truth: try: @@ -315,17 +315,17 @@ def evaluation_integration_example(): ) print(f" Math score: {math_result.score:.2f}") print(f" Reason: {math_result.reason[:100]}...") - + # Also try accuracy reward acc_result = accuracy_reward( messages=row.messages, ground_truth=row.ground_truth, ) print(f" Accuracy score: {acc_result.score:.2f}") - + except Exception as e: print(f" โŒ Evaluation error: {e}") - + except ImportError: print("Evaluation functions not available") except Exception as e: @@ -336,24 +336,26 @@ def batch_processing_example(): """Show how to process datasets in batches.""" print("\n๐Ÿ“ฆ Example 6: Batch Processing") print("-" * 35) - + try: adapter = create_gsm8k_adapter(split="test") - + batch_size = 5 total_processed = 0 - + print(f"Processing GSM8K test set in batches of {batch_size}:") - + # Process in batches for batch_start in range(0, 20, batch_size): # Process first 20 items - batch_rows = list(adapter.get_evaluation_rows( - limit=batch_size, - offset=batch_start, - )) - + batch_rows = list( + adapter.get_evaluation_rows( + limit=batch_size, + offset=batch_start, + ) + ) + print(f" Batch {batch_start//batch_size + 1}: {len(batch_rows)} rows") - + # Process each row in the batch for row in batch_rows: # Here you would typically: @@ -361,9 +363,9 @@ def batch_processing_example(): # 2. Evaluate the response # 3. Store results total_processed += 1 - + print(f"โœ… Processed {total_processed} rows total") - + except Exception as e: print(f"โŒ Error in batch processing: {e}") @@ -372,40 +374,40 @@ def main(): """Run all examples.""" print("๐Ÿค— HuggingFace Dataset Adapter Examples") print("=" * 50) - + # Run examples gsm8k_example() - print("\n" + "="*50 + "\n") - + print("\n" + "=" * 50 + "\n") + math_dataset_example() - print("\n" + "="*50 + "\n") - + print("\n" + "=" * 50 + "\n") + custom_dataset_example() - print("\n" + "="*50 + "\n") - + print("\n" + "=" * 50 + "\n") + local_file_example() - print("\n" + "="*50 + "\n") - + print("\n" + "=" * 50 + "\n") + evaluation_integration_example() - print("\n" + "="*50 + "\n") - + print("\n" + "=" * 50 + "\n") + batch_processing_example() if __name__ == "__main__": try: main() - + print("\nโœ… All examples completed!") print("\nNext steps:") print("1. Choose the dataset that fits your needs") - print("2. Customize the system prompts for your use case") + print("2. Customize the system prompts for your use case") print("3. Integrate with your evaluation pipeline") print("4. Scale up to process full datasets") print("5. Use the EvaluationRow data for training or evaluation") - + except ImportError as e: print(f"โŒ Missing dependencies: {e}") print("Install with: pip install 'eval-protocol[huggingface]'") except Exception as e: - print(f"โŒ Error running examples: {e}") \ No newline at end of file + print(f"โŒ Error running examples: {e}") diff --git a/examples/adapters/langfuse_example.py b/examples/adapters/langfuse_example.py index 78937c80..e55d740b 100644 --- a/examples/adapters/langfuse_example.py +++ b/examples/adapters/langfuse_example.py @@ -15,16 +15,16 @@ def main(): """Example usage of the Langfuse adapter.""" - + # Configuration - you can set these as environment variables public_key = os.getenv("LANGFUSE_PUBLIC_KEY", "your_public_key_here") - secret_key = os.getenv("LANGFUSE_SECRET_KEY", "your_secret_key_here") + secret_key = os.getenv("LANGFUSE_SECRET_KEY", "your_secret_key_here") host = os.getenv("LANGFUSE_HOST", "https://langfuse-web-prod-zfdbl7ykrq-uc.a.run.app") project_id = os.getenv("LANGFUSE_PROJECT_ID", "cmdj5yxhk0006s6022cyi0prv") - + print(f"Connecting to Langfuse at: {host}") print(f"Project ID: {project_id}\n") - + # Create the adapter try: adapter = create_langfuse_adapter( @@ -41,16 +41,18 @@ def main(): except Exception as e: print(f"โŒ Failed to create adapter: {e}") return - + # Example 1: Get recent evaluation rows print("\n๐Ÿ“Š Example 1: Get recent evaluation rows") try: - rows = list(adapter.get_evaluation_rows( - limit=5, - from_timestamp=datetime.now() - timedelta(days=7), - include_tool_calls=True, - )) - + rows = list( + adapter.get_evaluation_rows( + limit=5, + from_timestamp=datetime.now() - timedelta(days=7), + include_tool_calls=True, + ) + ) + print(f"Retrieved {len(rows)} evaluation rows") for i, row in enumerate(rows): print(f" Row {i+1}:") @@ -58,74 +60,80 @@ def main(): print(f" - Messages: {len(row.messages)}") print(f" - Has tools: {'Yes' if row.tools else 'No'}") print(f" - Ground truth: {'Yes' if row.ground_truth else 'No'}") - + # Show first message content (truncated) if row.messages: content = row.messages[0].content or "" preview = content[:100] + "..." if len(content) > 100 else content print(f" - First message: {preview}") print() - + except Exception as e: print(f"โŒ Error retrieving rows: {e}") - + # Example 2: Filter by specific criteria print("\n๐Ÿ” Example 2: Filter by specific criteria") try: - rows = list(adapter.get_evaluation_rows( - limit=3, - tags=["production"], # Filter by tags if available - include_tool_calls=True, - )) - + rows = list( + adapter.get_evaluation_rows( + limit=3, + tags=["production"], # Filter by tags if available + include_tool_calls=True, + ) + ) + print(f"Retrieved {len(rows)} rows with 'production' tag") - + except Exception as e: print(f"โŒ Error with filtered query: {e}") - + # Example 3: Get specific traces by ID print("\n๐ŸŽฏ Example 3: Get specific traces by ID") try: # Replace with actual trace IDs from your Langfuse deployment trace_ids = ["trace_id_1", "trace_id_2"] # These would be real IDs - - rows = list(adapter.get_evaluation_rows_by_ids( - trace_ids=trace_ids, - include_tool_calls=True, - )) - + + rows = list( + adapter.get_evaluation_rows_by_ids( + trace_ids=trace_ids, + include_tool_calls=True, + ) + ) + print(f"Retrieved {len(rows)} rows by specific IDs") - + except Exception as e: print(f"โŒ Error retrieving specific traces: {e}") - + # Example 4: Extract different types of conversations print("\n๐Ÿ’ฌ Example 4: Analyze conversation types") try: rows = list(adapter.get_evaluation_rows(limit=10, include_tool_calls=True)) - + chat_only = [] tool_calling = [] - + for row in rows: - if row.tools and any(msg.tool_calls for msg in row.messages if hasattr(msg, 'tool_calls') and msg.tool_calls): + if row.tools and any( + msg.tool_calls for msg in row.messages if hasattr(msg, "tool_calls") and msg.tool_calls + ): tool_calling.append(row) else: chat_only.append(row) - + print(f"Chat-only conversations: {len(chat_only)}") print(f"Tool calling conversations: {len(tool_calling)}") - + # Show example of tool calling conversation if tool_calling: row = tool_calling[0] print(f"\n๐Ÿ”ง Example tool calling conversation:") for i, msg in enumerate(row.messages): print(f" {i+1}. {msg.role}: {msg.content[:50] if msg.content else '[No content]'}...") - if hasattr(msg, 'tool_calls') and msg.tool_calls: + if hasattr(msg, "tool_calls") and msg.tool_calls: for tool_call in msg.tool_calls: print(f" ๐Ÿ›  Tool call: {tool_call}") - + except Exception as e: print(f"โŒ Error analyzing conversation types: {e}") @@ -133,11 +141,11 @@ def main(): def demonstrate_evaluation_integration(): """Show how to use Langfuse data with evaluation functions.""" print("\n๐Ÿงช Integration with Evaluation Functions") - + # This would typically be in a separate evaluation script try: from eval_protocol.rewards.math import math_reward - + # Create adapter (reuse configuration from main example) adapter = create_langfuse_adapter( public_key=os.getenv("LANGFUSE_PUBLIC_KEY", "your_public_key_here"), @@ -145,13 +153,13 @@ def demonstrate_evaluation_integration(): host=os.getenv("LANGFUSE_HOST", "https://langfuse-web-prod-zfdbl7ykrq-uc.a.run.app"), project_id=os.getenv("LANGFUSE_PROJECT_ID", "cmdj5yxhk0006s6022cyi0prv"), ) - + # Get data and evaluate rows = list(adapter.get_evaluation_rows(limit=3)) - + for i, row in enumerate(rows): print(f"\nEvaluating row {i+1}:") - + # Only evaluate if we have ground truth if row.ground_truth: try: @@ -165,7 +173,7 @@ def demonstrate_evaluation_integration(): print(f" โŒ Evaluation failed: {e}") else: print(f" โš ๏ธ No ground truth available for evaluation") - + except ImportError: print("Math reward function not available") except Exception as e: @@ -175,25 +183,27 @@ def demonstrate_evaluation_integration(): if __name__ == "__main__": print("๐Ÿš€ Langfuse Adapter Example") print("=" * 50) - + # Check if credentials are set - if not all([ - os.getenv("LANGFUSE_PUBLIC_KEY"), - os.getenv("LANGFUSE_SECRET_KEY"), - ]): + if not all( + [ + os.getenv("LANGFUSE_PUBLIC_KEY"), + os.getenv("LANGFUSE_SECRET_KEY"), + ] + ): print("โš ๏ธ To run this example with real data, set environment variables:") print(" export LANGFUSE_PUBLIC_KEY='your_public_key'") print(" export LANGFUSE_SECRET_KEY='your_secret_key'") print(" export LANGFUSE_HOST='your_langfuse_host' # optional") print(" export LANGFUSE_PROJECT_ID='your_project_id' # optional") print() - + main() demonstrate_evaluation_integration() - + print("\nโœ… Example completed!") print("\nNext steps:") print("1. Set up your Langfuse credentials") print("2. Modify the filters and parameters to match your data") print("3. Integrate with your evaluation pipeline") - print("4. Use the converted EvaluationRow data for training or evaluation") \ No newline at end of file + print("4. Use the converted EvaluationRow data for training or evaluation") diff --git a/examples/blackjack_mcp/tests/test_record_and_replay_e2e.py b/examples/blackjack_mcp/tests/test_record_and_replay_e2e.py index b77b7daa..55a44e23 100644 --- a/examples/blackjack_mcp/tests/test_record_and_replay_e2e.py +++ b/examples/blackjack_mcp/tests/test_record_and_replay_e2e.py @@ -1113,7 +1113,7 @@ async def test_static_policy_functionality(): # Test action generation for step in range(6): actions = await policy( - tool_schemas=[[], []], + tool_schemas=[[], []], # type: ignore[list-item] observations=[None, None], system_prompts=["Test system prompt 1", "Test system prompt 2"], user_prompts=["Test user prompt 1", "Test user prompt 2"], diff --git a/mypy.ini b/mypy.ini index 182ca82b..7caca4bc 100644 --- a/mypy.ini +++ b/mypy.ini @@ -13,6 +13,7 @@ no_implicit_optional = True strict_optional = True ignore_missing_imports = True disable_error_code = import-not-found, truthy-function, no-redef, assignment, union-attr, attr-defined, arg-type, method-assign, misc, return-value, var-annotated, operator, call-arg, index +exclude = ^vendor/ [mypy.plugins.pydantic.*] follow_imports = skip diff --git a/scripts/fix_whitespace.py b/scripts/fix_whitespace.py new file mode 100644 index 00000000..d8cdaa6f --- /dev/null +++ b/scripts/fix_whitespace.py @@ -0,0 +1,76 @@ +import os + +TEXT_EXTS = { + ".py", + ".md", + ".txt", + ".yml", + ".yaml", + ".toml", + ".ini", + ".cfg", + ".json", + ".rst", +} + + +def is_text_file(path: str) -> bool: + _, ext = os.path.splitext(path) + return ext.lower() in TEXT_EXTS + + +def fix_file(path: str) -> bool: + try: + with open(path, "rb") as f: + raw = f.read() + except Exception: + return False + + try: + text = raw.decode("utf-8") + except UnicodeDecodeError: + return False + + lines = text.splitlines() + changed = False + + # Strip trailing whitespace on each line + new_lines = [] + for line in lines: + new_line = line.rstrip() + if new_line != line: + changed = True + new_lines.append(new_line) + + # Ensure newline at EOF + new_text = "\n".join(new_lines) + "\n" + if not text.endswith("\n"): + changed = True + + if changed: + with open(path, "wb") as f: + f.write(new_text.encode("utf-8")) + return changed + + +def main(): + root = os.getcwd() + total = 0 + changed = 0 + for dirpath, dirnames, filenames in os.walk(root): + # Skip virtualenvs and build artifacts + parts = dirpath.split(os.sep) + if any(p in {".git", ".venv", "build", "dist", "node_modules"} for p in parts): + continue + for fn in filenames: + path = os.path.join(dirpath, fn) + if is_text_file(path): + total += 1 + if fix_file(path): + changed += 1 + print(f"Scanned {total} files; normalized whitespace in {changed} files") + + +if __name__ == "__main__": + main() + diff --git a/tests/conftest.py b/tests/conftest.py index 6a3526a7..9c93cbf8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import sys from pathlib import Path + import pytest # Add the project root to the Python path diff --git a/tests/pytest/data/basic_coding_dataset.jsonl b/tests/pytest/data/basic_coding_dataset.jsonl index 27573c1b..fc25abcd 100644 --- a/tests/pytest/data/basic_coding_dataset.jsonl +++ b/tests/pytest/data/basic_coding_dataset.jsonl @@ -7,4 +7,4 @@ {"prompt": "Write a Python function `multiply_by_two` that takes an integer and returns the integer multiplied by 2.", "input": "10", "expected_output": "20"} {"prompt": "Write a Python function `get_length` that takes a list and returns its length.", "input": "[1, 2, 3]", "expected_output": "3"} {"prompt": "Write a Python function `get_length` that takes a list and returns its length.", "input": "[]", "expected_output": "0"} -{"prompt": "Write a Python function `get_length` that takes a list and returns its length.", "input": "['a', 'b', 'c', 'd']", "expected_output": "4"} \ No newline at end of file +{"prompt": "Write a Python function `get_length` that takes a list and returns its length.", "input": "['a', 'b', 'c', 'd']", "expected_output": "4"} diff --git a/tests/pytest/data/lunar_lander_dataset.jsonl b/tests/pytest/data/lunar_lander_dataset.jsonl index af396fc1..a3de90c6 100644 --- a/tests/pytest/data/lunar_lander_dataset.jsonl +++ b/tests/pytest/data/lunar_lander_dataset.jsonl @@ -1,3 +1,3 @@ {"id": "multi_env_test_001", "system_prompt": "You are controlling a lunar lander spacecraft. Use the lander_action tool with actions: NOTHING, FIRE_LEFT, FIRE_MAIN, FIRE_RIGHT. Your goal is to land safely on the moon between the two flags without crashing.", "user_prompt_template": "Current state: {observation}. First, describe what is in the image attached and analyze the current state. You MUST explain your reasoning in picking the next best action (NOTHING, FIRE_LEFT, FIRE_MAIN, FIRE_RIGHT) and call lander_action tool with it to land the spacecraft.", "environment_context": {"game": "LunarLander", "continuous": false, "gravity": -10.0, "enable_wind": false, "seed": 42}} {"id": "multi_env_test_002", "system_prompt": "You are controlling a lunar lander spacecraft. Use the lander_action tool with actions: NOTHING, FIRE_LEFT, FIRE_MAIN, FIRE_RIGHT. Your goal is to land safely on the moon between the two flags without crashing.", "user_prompt_template": "Current state: {observation}. First, describe what is in the image attached and analyze the current state. You MUST explain your reasoning in picking the next best action (NOTHING, FIRE_LEFT, FIRE_MAIN, FIRE_RIGHT) and call lander_action tool with it to land the spacecraft.", "environment_context": {"game": "LunarLander", "continuous": false, "gravity": -8.0, "enable_wind": false, "seed": 123}} -{"id": "multi_env_test_003", "system_prompt": "You are controlling a lunar lander spacecraft. Use the lander_action tool with actions: NOTHING, FIRE_LEFT, FIRE_MAIN, FIRE_RIGHT. Your goal is to land safely on the moon between the two flags without crashing.", "user_prompt_template": "Current state: {observation}. First, describe what is in the image attached and analyze the current state. You MUST explain your reasoning in picking the next best action (NOTHING, FIRE_LEFT, FIRE_MAIN, FIRE_RIGHT) and call lander_action tool with it to land the spacecraft.", "environment_context": {"game": "LunarLander", "continuous": false, "gravity": -12.0, "enable_wind": false, "seed": 456}} \ No newline at end of file +{"id": "multi_env_test_003", "system_prompt": "You are controlling a lunar lander spacecraft. Use the lander_action tool with actions: NOTHING, FIRE_LEFT, FIRE_MAIN, FIRE_RIGHT. Your goal is to land safely on the moon between the two flags without crashing.", "user_prompt_template": "Current state: {observation}. First, describe what is in the image attached and analyze the current state. You MUST explain your reasoning in picking the next best action (NOTHING, FIRE_LEFT, FIRE_MAIN, FIRE_RIGHT) and call lander_action tool with it to land the spacecraft.", "environment_context": {"game": "LunarLander", "continuous": false, "gravity": -12.0, "enable_wind": false, "seed": 456}} diff --git a/tests/pytest/helper/word_count_to_evaluation_row.py b/tests/pytest/helper/word_count_to_evaluation_row.py index f0517dd0..dbb05cc4 100644 --- a/tests/pytest/helper/word_count_to_evaluation_row.py +++ b/tests/pytest/helper/word_count_to_evaluation_row.py @@ -7,8 +7,7 @@ def word_count_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationR """Convert gsm8k dataset format to EvaluationRow for word_count evaluation.""" return [ EvaluationRow( - messages=[Message(role="user", content=row["user_query"])], - ground_truth=row["ground_truth_for_eval"] + messages=[Message(role="user", content=row["user_query"])], ground_truth=row["ground_truth_for_eval"] ) for row in data - ] \ No newline at end of file + ] diff --git a/tests/pytest/test_apps_coding.py b/tests/pytest/test_apps_coding.py index 4780388a..1b2be188 100644 --- a/tests/pytest/test_apps_coding.py +++ b/tests/pytest/test_apps_coding.py @@ -18,10 +18,7 @@ def apps_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluatio Convert entries from APPS dataset to EvaluationRow objects. """ return [ - EvaluationRow( - messages=[Message(role="user", content=row["question"])], - ground_truth=row["input_output"] - ) + EvaluationRow(messages=[Message(role="user", content=row["question"])], ground_truth=row["input_output"]) for row in data ] @@ -42,7 +39,7 @@ def test_apps_code_evaluation(row: EvaluationRow) -> EvaluationRow: Args: row: EvaluationRow containing the conversation messages and ground_truth as JSON string - + Returns: EvaluationRow with the evaluation result """ @@ -51,8 +48,8 @@ def test_apps_code_evaluation(row: EvaluationRow) -> EvaluationRow: messages=row.messages, ground_truth=row.ground_truth, ) - + # Set the evaluation result on the row row.evaluation_result = result - - return row \ No newline at end of file + + return row diff --git a/tests/pytest/test_basic_coding.py b/tests/pytest/test_basic_coding.py index 35d1a1b3..7072665b 100644 --- a/tests/pytest/test_basic_coding.py +++ b/tests/pytest/test_basic_coding.py @@ -9,7 +9,7 @@ from eval_protocol.models import EvaluateResult, EvaluationRow, Message from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test -from eval_protocol.rewards.code_execution import extract_code_blocks, execute_python_code +from eval_protocol.rewards.code_execution import execute_python_code, extract_code_blocks def coding_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]: @@ -18,8 +18,8 @@ def coding_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluat """ return [ EvaluationRow( - messages=[Message(role="user", content=f"{row['prompt']} Input: {row['input']}")], - ground_truth=row["expected_output"] + messages=[Message(role="user", content=f"{row['prompt']} Input: {row['input']}")], + ground_truth=row["expected_output"], ) for row in data ] @@ -38,16 +38,16 @@ def coding_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluat def test_coding_code_evaluation(row: EvaluationRow) -> EvaluationRow: """ Evaluation function that tests code correctness by executing it locally. - + This function: 1. Extracts Python code from the assistant's response 2. Executes the code locally with timeout=10 3. Compares the output to ground_truth 4. Returns a score of 1.0 if output matches, 0.0 otherwise - + Args: row: EvaluationRow containing the conversation messages and expected_output in ground_truth - + Returns: EvaluationRow with the evaluation result """ @@ -55,38 +55,34 @@ def test_coding_code_evaluation(row: EvaluationRow) -> EvaluationRow: if len(row.messages) < 2 or row.messages[-1].role != "assistant": row.evaluation_result = EvaluateResult(score=0.0, reason="No assistant response found") return row - + assistant_content = row.messages[-1].content or "" expected_output = (row.ground_truth or "").strip() - + # Extract Python code blocks code_blocks = extract_code_blocks(assistant_content, language="python") if not code_blocks: row.evaluation_result = EvaluateResult(score=0.0, reason="No Python code block found") return row - + code = code_blocks[0]["code"] - + # Execute the code locally execution_result = execute_python_code(code, timeout=10) - + if not execution_result.get("success", False): error_msg = execution_result.get("error", "Code execution failed") row.evaluation_result = EvaluateResult(score=0.0, reason=f"Execution error: {error_msg}") return row - + # Compare output with expected actual_output = (execution_result.get("output", "") or "").strip() - + if actual_output == expected_output: - row.evaluation_result = EvaluateResult( - score=1.0, - reason=f"โœ… Output matches: '{actual_output}'" - ) + row.evaluation_result = EvaluateResult(score=1.0, reason=f"โœ… Output matches: '{actual_output}'") else: row.evaluation_result = EvaluateResult( - score=0.0, - reason=f"โŒ Expected: '{expected_output}', Got: '{actual_output}'" + score=0.0, reason=f"โŒ Expected: '{expected_output}', Got: '{actual_output}'" ) - + return row diff --git a/tests/pytest/test_frozen_lake.py b/tests/pytest/test_frozen_lake.py index 69f0c400..6f7afbdd 100644 --- a/tests/pytest/test_frozen_lake.py +++ b/tests/pytest/test_frozen_lake.py @@ -5,10 +5,9 @@ similar to the test_frozen_lake_e2e test but integrated with the pytest evaluation system. """ - from typing import Any, Dict, List -from eval_protocol.models import EvaluateResult, EvaluationRow, Message, InputMetadata, CompletionParams, MetricResult +from eval_protocol.models import CompletionParams, EvaluateResult, EvaluationRow, InputMetadata, Message, MetricResult from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor @@ -18,7 +17,7 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation Convert entries from frozen lake dataset to EvaluationRow objects. """ rows = [] - + for row in data: eval_row = EvaluationRow( messages=[Message(role="system", content=row["system_prompt"])], @@ -27,14 +26,15 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation dataset_info={ "environment_context": row["environment_context"], "user_prompt_template": row["user_prompt_template"], - } - ) + }, + ), ) - + rows.append(eval_row) - + return rows + @evaluation_test( input_dataset=["tests/pytest/data/frozen_lake_dataset.jsonl"], dataset_adapter=frozen_lake_to_evaluation_row, @@ -50,13 +50,13 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation def test_frozen_lake_evaluation(row: EvaluationRow) -> EvaluationRow: """ Test frozen lake evaluation using the pytest framework. - + This test evaluates how well the model can navigate the FrozenLake environment by checking if it successfully reaches the goal while avoiding holes. - + Args: row: EvaluationRow object from frozen lake dataset - + Returns: EvaluationRow object with evaluation results """ @@ -71,5 +71,5 @@ def test_frozen_lake_evaluation(row: EvaluationRow) -> EvaluationRow: score=score, reason=reason, ) - + return row diff --git a/tests/pytest/test_lunar_lander.py b/tests/pytest/test_lunar_lander.py index 896adc49..f3f5e9e3 100644 --- a/tests/pytest/test_lunar_lander.py +++ b/tests/pytest/test_lunar_lander.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List -from eval_protocol.models import EvaluateResult, EvaluationRow, Message, InputMetadata, CompletionParams +from eval_protocol.models import CompletionParams, EvaluateResult, EvaluationRow, InputMetadata, Message from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor @@ -17,7 +17,7 @@ def lunar_lander_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluatio Convert entries from lunar lander dataset to EvaluationRow objects. """ rows = [] - + for row in data: eval_row = EvaluationRow( messages=[Message(role="system", content=row["system_prompt"])], @@ -26,12 +26,12 @@ def lunar_lander_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluatio dataset_info={ "environment_context": row["environment_context"], "user_prompt_template": row["user_prompt_template"], - } - ) + }, + ), ) - + rows.append(eval_row) - + return rows @@ -51,24 +51,28 @@ def lunar_lander_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluatio def test_lunar_lander_evaluation(row: EvaluationRow) -> EvaluationRow: """ Test lunar lander evaluation using the pytest framework. - + This test evaluates how well the model can control the lunar lander to achieve a successful landing by checking the final reward and termination status. - + Args: row: EvaluationRow object from lunar lander dataset - + Returns: EvaluationRow object with evaluation results """ score = row.get_total_reward() evaluation_score = 1.0 if score >= 200 else 0.0 - reason = f"โœ… Successful landing with reward {score:.2f}" if score >= 200 else f"โŒ Failed landing with reward {score:.2f}" + reason = ( + f"โœ… Successful landing with reward {score:.2f}" + if score >= 200 + else f"โŒ Failed landing with reward {score:.2f}" + ) row.evaluation_result = EvaluateResult( score=evaluation_score, reason=reason, ) - - return row \ No newline at end of file + + return row diff --git a/tests/pytest/test_pytest_default_agent_rollout_processor.py b/tests/pytest/test_pytest_default_agent_rollout_processor.py index 06762046..e02ff487 100644 --- a/tests/pytest/test_pytest_default_agent_rollout_processor.py +++ b/tests/pytest/test_pytest_default_agent_rollout_processor.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import List -from eval_protocol.models import Message, EvaluationRow +from eval_protocol.models import EvaluationRow, Message from eval_protocol.pytest import default_agent_rollout_processor, evaluation_test diff --git a/tests/pytest/test_pytest_function_calling.py b/tests/pytest/test_pytest_function_calling.py index 7239de58..ee300857 100644 --- a/tests/pytest/test_pytest_function_calling.py +++ b/tests/pytest/test_pytest_function_calling.py @@ -1,5 +1,6 @@ import json from typing import Any, Dict, List + from eval_protocol.models import EvaluationRow from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test from eval_protocol.rewards.function_calling import exact_tool_match_reward diff --git a/tests/pytest/test_pytest_input_messages.py b/tests/pytest/test_pytest_input_messages.py index c1b643d0..1cbe007e 100644 --- a/tests/pytest/test_pytest_input_messages.py +++ b/tests/pytest/test_pytest_input_messages.py @@ -1,6 +1,6 @@ from typing import List -from eval_protocol.models import Message, EvaluationRow +from eval_protocol.models import EvaluationRow, Message from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test diff --git a/tests/pytest/test_pytest_json_schema.py b/tests/pytest/test_pytest_json_schema.py index 8463f873..7548767c 100644 --- a/tests/pytest/test_pytest_json_schema.py +++ b/tests/pytest/test_pytest_json_schema.py @@ -1,5 +1,6 @@ import json from typing import Any, Dict, List + from eval_protocol.models import EvaluationRow from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test from eval_protocol.rewards.json_schema import json_schema_reward diff --git a/tests/pytest/test_pytest_mcp_url.py b/tests/pytest/test_pytest_mcp_url.py index 2a1c1cfc..4cdaf759 100644 --- a/tests/pytest/test_pytest_mcp_url.py +++ b/tests/pytest/test_pytest_mcp_url.py @@ -1,4 +1,4 @@ -from eval_protocol.models import EvaluateResult, Message, EvaluationRow +from eval_protocol.models import EvaluateResult, EvaluationRow, Message from eval_protocol.pytest import default_agent_rollout_processor, evaluation_test diff --git a/tests/test_adapters_e2e.py b/tests/test_adapters_e2e.py index a598dff7..6ece42ce 100644 --- a/tests/test_adapters_e2e.py +++ b/tests/test_adapters_e2e.py @@ -6,31 +6,34 @@ """ import os -import pytest from datetime import datetime, timedelta -from typing import Dict, Any +from typing import Any, Dict + +import pytest -from eval_protocol.models import EvaluationRow, Message, InputMetadata +from eval_protocol.models import EvaluationRow, InputMetadata, Message class TestLangfuseAdapterE2E: """End-to-end tests for Langfuse adapter with real deployment.""" - + def _get_langfuse_credentials(self): """Get Langfuse credentials from environment.""" public_key = os.getenv("LANGFUSE_PUBLIC_KEY") secret_key = os.getenv("LANGFUSE_SECRET_KEY") host = os.getenv("LANGFUSE_HOST", "https://langfuse-web-prod-zfdbl7ykrq-uc.a.run.app") project_id = os.getenv("LANGFUSE_PROJECT_ID", "cmdj5yxhk0006s6022cyi0prv") - + return public_key, secret_key, host, project_id - + @pytest.mark.skipif( - not all([ - os.getenv("LANGFUSE_PUBLIC_KEY"), - os.getenv("LANGFUSE_SECRET_KEY"), - ]), - reason="Langfuse credentials not available in environment" + not all( + [ + os.getenv("LANGFUSE_PUBLIC_KEY"), + os.getenv("LANGFUSE_SECRET_KEY"), + ] + ), + reason="Langfuse credentials not available in environment", ) def test_langfuse_adapter_real_connection(self): """Test that we can connect to real Langfuse deployment and pull data.""" @@ -38,9 +41,9 @@ def test_langfuse_adapter_real_connection(self): from eval_protocol.adapters.langfuse import create_langfuse_adapter except ImportError: pytest.skip("Langfuse dependencies not installed") - + public_key, secret_key, host, project_id = self._get_langfuse_credentials() - + # Create adapter adapter = create_langfuse_adapter( public_key=public_key, @@ -48,40 +51,47 @@ def test_langfuse_adapter_real_connection(self): host=host, project_id=project_id, ) - + # Test basic connection by trying to get a small number of traces rows = list(adapter.get_evaluation_rows(limit=3)) - + # Verify we got some data assert isinstance(rows, list), "Should return a list of rows" print(f"Retrieved {len(rows)} evaluation rows from Langfuse") - + # Verify each row is properly formatted for i, row in enumerate(rows): assert isinstance(row, EvaluationRow), f"Row {i} should be EvaluationRow" assert isinstance(row.messages, list), f"Row {i} should have messages list" assert len(row.messages) > 0, f"Row {i} should have at least one message" - + # Verify messages are properly formatted for j, msg in enumerate(row.messages): assert isinstance(msg, Message), f"Row {i} message {j} should be Message object" - assert hasattr(msg, 'role'), f"Row {i} message {j} should have role" - assert msg.role in ['user', 'assistant', 'system', 'tool'], f"Row {i} message {j} has invalid role: {msg.role}" - + assert hasattr(msg, "role"), f"Row {i} message {j} should have role" + assert msg.role in [ + "user", + "assistant", + "system", + "tool", + ], f"Row {i} message {j} has invalid role: {msg.role}" + # Verify metadata if row.input_metadata: assert isinstance(row.input_metadata, InputMetadata), f"Row {i} should have InputMetadata" assert row.input_metadata.row_id, f"Row {i} should have row_id" print(f" Row {i}: ID={row.input_metadata.row_id}, Messages={len(row.messages)}") - + print(f" Row {i}: {len(row.messages)} messages, Tools={'Yes' if row.tools else 'No'}") - + @pytest.mark.skipif( - not all([ - os.getenv("LANGFUSE_PUBLIC_KEY"), - os.getenv("LANGFUSE_SECRET_KEY"), - ]), - reason="Langfuse credentials not available" + not all( + [ + os.getenv("LANGFUSE_PUBLIC_KEY"), + os.getenv("LANGFUSE_SECRET_KEY"), + ] + ), + reason="Langfuse credentials not available", ) def test_langfuse_adapter_with_filters(self): """Test Langfuse adapter with various filters.""" @@ -89,46 +99,52 @@ def test_langfuse_adapter_with_filters(self): from eval_protocol.adapters.langfuse import create_langfuse_adapter except ImportError: pytest.skip("Langfuse dependencies not installed") - + public_key, secret_key, host, project_id = self._get_langfuse_credentials() - + adapter = create_langfuse_adapter( public_key=public_key, secret_key=secret_key, host=host, project_id=project_id, ) - + # Test with time filter (last 7 days) - recent_rows = list(adapter.get_evaluation_rows( - limit=5, - from_timestamp=datetime.now() - timedelta(days=7), - include_tool_calls=True, - )) - + recent_rows = list( + adapter.get_evaluation_rows( + limit=5, + from_timestamp=datetime.now() - timedelta(days=7), + include_tool_calls=True, + ) + ) + print(f"Recent rows (last 7 days): {len(recent_rows)}") - + # Verify tool calling data is preserved tool_calling_rows = [row for row in recent_rows if row.tools] print(f"Rows with tool definitions: {len(tool_calling_rows)}") - + # Test specific filtering try: # This might not return data if no traces match, which is fine - tagged_rows = list(adapter.get_evaluation_rows( - limit=2, - tags=["production"], # May not exist, that's OK - )) + tagged_rows = list( + adapter.get_evaluation_rows( + limit=2, + tags=["production"], # May not exist, that's OK + ) + ) print(f"Tagged rows: {len(tagged_rows)}") except Exception as e: print(f"Tagged query failed (expected if no tags): {e}") - + @pytest.mark.skipif( - not all([ - os.getenv("LANGFUSE_PUBLIC_KEY"), - os.getenv("LANGFUSE_SECRET_KEY"), - ]), - reason="Langfuse credentials not available" + not all( + [ + os.getenv("LANGFUSE_PUBLIC_KEY"), + os.getenv("LANGFUSE_SECRET_KEY"), + ] + ), + reason="Langfuse credentials not available", ) def test_langfuse_conversation_analysis(self): """Test analysis of conversation types from Langfuse.""" @@ -136,51 +152,51 @@ def test_langfuse_conversation_analysis(self): from eval_protocol.adapters.langfuse import create_langfuse_adapter except ImportError: pytest.skip("Langfuse dependencies not installed") - + public_key, secret_key, host, project_id = self._get_langfuse_credentials() - + adapter = create_langfuse_adapter( public_key=public_key, secret_key=secret_key, host=host, project_id=project_id, ) - + # Get more data for analysis rows = list(adapter.get_evaluation_rows(limit=10, include_tool_calls=True)) - + # Analyze conversation patterns chat_only = [] tool_calling = [] multi_turn = [] - + for row in rows: # Check for tool calling has_tools = ( - row.tools or - any(hasattr(msg, 'tool_calls') and msg.tool_calls for msg in row.messages) or - any(msg.role == 'tool' for msg in row.messages) + row.tools + or any(hasattr(msg, "tool_calls") and msg.tool_calls for msg in row.messages) + or any(msg.role == "tool" for msg in row.messages) ) - + if has_tools: tool_calling.append(row) else: chat_only.append(row) - + # Check for multi-turn conversations if len(row.messages) > 2: # More than user + assistant multi_turn.append(row) - + print(f"Analysis of {len(rows)} conversations:") print(f" Chat-only: {len(chat_only)}") - print(f" Tool calling: {len(tool_calling)}") + print(f" Tool calling: {len(tool_calling)}") print(f" Multi-turn: {len(multi_turn)}") - + # Show example of each type if available if chat_only: row = chat_only[0] print(f" Example chat: {len(row.messages)} messages") - + if tool_calling: row = tool_calling[0] print(f" Example tool calling: {len(row.messages)} messages, {len(row.tools or [])} tools") @@ -188,165 +204,168 @@ def test_langfuse_conversation_analysis(self): class TestHuggingFaceAdapterE2E: """End-to-end tests for HuggingFace adapter with real datasets.""" - + def test_gsm8k_adapter_real_data(self): """Test loading real GSM8K data and converting to EvaluationRow.""" try: from eval_protocol.adapters.huggingface import create_huggingface_adapter except ImportError: pytest.skip("HuggingFace dependencies not installed") - + def gsm8k_transform(row: Dict[str, Any]) -> Dict[str, Any]: """Transform GSM8K row to our format.""" return { - 'messages': [ - {'role': 'system', 'content': 'You are a helpful assistant that solves math problems step by step.'}, - {'role': 'user', 'content': row['question']}, + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant that solves math problems step by step.", + }, + {"role": "user", "content": row["question"]}, ], - 'ground_truth': row['answer'], - 'metadata': { - 'dataset': 'gsm8k', - 'original_question': row['question'], - 'original_answer': row['answer'], - } + "ground_truth": row["answer"], + "metadata": { + "dataset": "gsm8k", + "original_question": row["question"], + "original_answer": row["answer"], + }, } - + # Create adapter with transform function adapter = create_huggingface_adapter( dataset_id="gsm8k", config_name="main", transform_fn=gsm8k_transform, ) - + # Test loading data rows = list(adapter.get_evaluation_rows(split="test", limit=5)) - + # Verify we got data assert len(rows) > 0, "Should retrieve some GSM8K data" print(f"Retrieved {len(rows)} GSM8K evaluation rows") - + # Verify each row is properly formatted for i, row in enumerate(rows): assert isinstance(row, EvaluationRow), f"Row {i} should be EvaluationRow" assert isinstance(row.messages, list), f"Row {i} should have messages" assert len(row.messages) >= 2, f"Row {i} should have system + user messages" - + # Check system prompt system_msg = row.messages[0] - assert system_msg.role == 'system', f"Row {i} first message should be system" - assert 'math problems' in system_msg.content.lower(), f"Row {i} should have math system prompt" - + assert system_msg.role == "system", f"Row {i} first message should be system" + assert "math problems" in system_msg.content.lower(), f"Row {i} should have math system prompt" + # Check user question user_msg = row.messages[1] - assert user_msg.role == 'user', f"Row {i} second message should be user" + assert user_msg.role == "user", f"Row {i} second message should be user" assert len(user_msg.content) > 0, f"Row {i} should have non-empty question" - + # Check ground truth assert row.ground_truth, f"Row {i} should have ground truth answer" - + # Check metadata assert row.input_metadata, f"Row {i} should have metadata" assert row.input_metadata.dataset_info, f"Row {i} should have dataset info" - + print(f" Row {i}: Question length={len(user_msg.content)}, Answer length={len(row.ground_truth)}") - + def test_math_dataset_real_data(self): """Test loading real MATH competition dataset.""" try: from eval_protocol.adapters.huggingface import create_huggingface_adapter except ImportError: pytest.skip("HuggingFace dependencies not installed") - + def math_transform(row: Dict[str, Any]) -> Dict[str, Any]: """Transform MATH dataset row.""" return { - 'messages': [ - {'role': 'system', 'content': 'You are an expert mathematician. Solve this step by step.'}, - {'role': 'user', 'content': row['problem']}, + "messages": [ + {"role": "system", "content": "You are an expert mathematician. Solve this step by step."}, + {"role": "user", "content": row["problem"]}, ], - 'ground_truth': row['solution'], - 'metadata': { - 'dataset': 'hendrycks_math', - 'type': row.get('type', 'unknown'), - 'level': row.get('level', 'unknown'), - 'original_problem': row['problem'], - 'original_solution': row['solution'], - } + "ground_truth": row["solution"], + "metadata": { + "dataset": "hendrycks_math", + "type": row.get("type", "unknown"), + "level": row.get("level", "unknown"), + "original_problem": row["problem"], + "original_solution": row["solution"], + }, } - + # Create adapter adapter = create_huggingface_adapter( dataset_id="SuperSecureHuman/competition_math_hf_dataset", transform_fn=math_transform, ) - + # Test loading data rows = list(adapter.get_evaluation_rows(split="test", limit=3)) - + # Verify data assert len(rows) > 0, "Should retrieve MATH dataset data" print(f"Retrieved {len(rows)} MATH dataset evaluation rows") - + for i, row in enumerate(rows): assert isinstance(row, EvaluationRow), f"Row {i} should be EvaluationRow" assert len(row.messages) >= 2, f"Row {i} should have system + user messages" assert row.ground_truth, f"Row {i} should have solution" - + # Check for MATH-specific metadata dataset_info = row.input_metadata.dataset_info - assert 'type' in dataset_info, f"Row {i} should have problem type" - assert 'level' in dataset_info, f"Row {i} should have difficulty level" - + assert "type" in dataset_info, f"Row {i} should have problem type" + assert "level" in dataset_info, f"Row {i} should have difficulty level" + print(f" Row {i}: Type={dataset_info.get('type')}, Level={dataset_info.get('level')}") - + def test_custom_dataset_transform(self): """Test adapter with a completely custom transformation.""" try: from eval_protocol.adapters.huggingface import create_huggingface_adapter except ImportError: pytest.skip("HuggingFace dependencies not installed") - + def squad_transform(row: Dict[str, Any]) -> Dict[str, Any]: """Custom transform for SQuAD dataset.""" - context = row['context'] - question = row['question'] - answers = row['answers'] - + context = row["context"] + question = row["question"] + answers = row["answers"] + # Get first answer - answer_text = answers['text'][0] if answers['text'] else "No answer" - + answer_text = answers["text"][0] if answers["text"] else "No answer" + return { - 'messages': [ - {'role': 'system', 'content': 'Answer the question based on the given context.'}, - {'role': 'user', 'content': f"Context: {context}\n\nQuestion: {question}"}, + "messages": [ + {"role": "system", "content": "Answer the question based on the given context."}, + {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}, ], - 'ground_truth': answer_text, - 'metadata': { - 'dataset': 'squad', - 'context_length': len(context), - 'question_length': len(question), - 'num_answers': len(answers['text']), - } + "ground_truth": answer_text, + "metadata": { + "dataset": "squad", + "context_length": len(context), + "question_length": len(question), + "num_answers": len(answers["text"]), + }, } - + # Create adapter for SQuAD adapter = create_huggingface_adapter( dataset_id="squad", transform_fn=squad_transform, ) - + # Test loading rows = list(adapter.get_evaluation_rows(split="validation", limit=2)) - + assert len(rows) > 0, "Should retrieve SQuAD data" print(f"Retrieved {len(rows)} SQuAD evaluation rows") - + for i, row in enumerate(rows): assert isinstance(row, EvaluationRow), f"Row {i} should be EvaluationRow" - user_msg = next(msg for msg in row.messages if msg.role == 'user') - assert 'Context:' in user_msg.content, f"Row {i} should have context" - assert 'Question:' in user_msg.content, f"Row {i} should have question" - + user_msg = next(msg for msg in row.messages if msg.role == "user") + assert "Context:" in user_msg.content, f"Row {i} should have context" + assert "Question:" in user_msg.content, f"Row {i} should have question" + dataset_info = row.input_metadata.dataset_info print(f" Row {i}: Context length={dataset_info.get('context_length')}") @@ -354,54 +373,54 @@ def squad_transform(row: Dict[str, Any]) -> Dict[str, Any]: def test_adapters_integration(): """Test that adapters work with evaluation pipeline.""" print("Testing adapter integration with evaluation pipeline...") - + # This test doesn't require external credentials try: from eval_protocol.adapters.huggingface import create_huggingface_adapter from eval_protocol.rewards.accuracy import accuracy_reward except ImportError as e: pytest.skip(f"Dependencies not available: {e}") - + def simple_transform(row: Dict[str, Any]) -> Dict[str, Any]: """Simple transform for testing.""" return { - 'messages': [ - {'role': 'user', 'content': row['question']}, - {'role': 'assistant', 'content': 'Test response'}, # Simulated response + "messages": [ + {"role": "user", "content": row["question"]}, + {"role": "assistant", "content": "Test response"}, # Simulated response ], - 'ground_truth': row['answer'], - 'metadata': {'test': True} + "ground_truth": row["answer"], + "metadata": {"test": True}, } - + # Create adapter with GSM8K (small sample) adapter = create_huggingface_adapter( dataset_id="gsm8k", - config_name="main", + config_name="main", transform_fn=simple_transform, ) - + # Get one row rows = list(adapter.get_evaluation_rows(split="test", limit=1)) assert len(rows) == 1, "Should get exactly one row" - + row = rows[0] - + # Test evaluation result = accuracy_reward( messages=row.messages, ground_truth=row.ground_truth, ) - - assert hasattr(result, 'score'), "Should have evaluation score" + + assert hasattr(result, "score"), "Should have evaluation score" assert 0 <= result.score <= 1, "Score should be between 0 and 1" - + print(f"Integration test successful: Score={result.score}") if __name__ == "__main__": # Run tests manually for development import sys - + print("Running Langfuse E2E tests...") if all([os.getenv("LANGFUSE_PUBLIC_KEY"), os.getenv("LANGFUSE_SECRET_KEY")]): try: @@ -415,20 +434,20 @@ def simple_transform(row: Dict[str, Any]) -> Dict[str, Any]: print(" This is expected if Langfuse API has changed - the adapter needs updating") else: print("โš ๏ธ Skipping Langfuse tests (credentials not available)") - + print("\nRunning HuggingFace E2E tests...") try: test_hf = TestHuggingFaceAdapterE2E() test_hf.test_gsm8k_adapter_real_data() print("โœ… GSM8K adapter test passed!") - + # Skip MATH dataset test for now (dataset may not be available) try: test_hf.test_math_dataset_real_data() print("โœ… MATH dataset test passed!") except Exception as e: print(f"โš ๏ธ MATH dataset test failed (dataset may not be available): {e}") - + # Skip SQuAD test for now (focus on core functionality) try: test_hf.test_custom_dataset_transform() @@ -439,9 +458,9 @@ def simple_transform(row: Dict[str, Any]) -> Dict[str, Any]: except Exception as e: print(f"โŒ HuggingFace tests failed: {e}") sys.exit(1) - + print("\nRunning integration test...") test_adapters_integration() print("โœ… Integration test passed!") - - print("\n๐ŸŽ‰ All E2E tests completed successfully!") \ No newline at end of file + + print("\n๐ŸŽ‰ All E2E tests completed successfully!") diff --git a/tests/test_cli_agent.py b/tests/test_cli_agent.py index cc50376f..9c12f221 100644 --- a/tests/test_cli_agent.py +++ b/tests/test_cli_agent.py @@ -41,7 +41,7 @@ class TestAgentEvalCommand: def test_agent_eval_success_yaml(self, MockPath, MockTaskManager, caplog): # Configure caplog to capture logs from the agent_eval logger caplog.set_level(logging.INFO, logger="agent_eval") - + # Setup Path mock mock_path_instance = Mock() MockPath.return_value = mock_path_instance diff --git a/tests/test_url_handling.py b/tests/test_url_handling.py index fbd71b28..dece18a7 100644 --- a/tests/test_url_handling.py +++ b/tests/test_url_handling.py @@ -1,4 +1,5 @@ from unittest.mock import AsyncMock, patch + import httpx import pytest from werkzeug.wrappers import Response diff --git a/vendor/tau2/__init__.py b/vendor/tau2/__init__.py index 8b137891..e69de29b 100644 --- a/vendor/tau2/__init__.py +++ b/vendor/tau2/__init__.py @@ -1 +0,0 @@ - diff --git a/vendor/tau2/agent/README.md b/vendor/tau2/agent/README.md index fa201945..ee77cfb6 100644 --- a/vendor/tau2/agent/README.md +++ b/vendor/tau2/agent/README.md @@ -32,4 +32,4 @@ tau2 run \ --agent-llm \ --user-llm \ ... -``` \ No newline at end of file +``` diff --git a/vendor/tau2/agent/base.py b/vendor/tau2/agent/base.py index 7a345432..4d6d9dbd 100644 --- a/vendor/tau2/agent/base.py +++ b/vendor/tau2/agent/base.py @@ -73,9 +73,7 @@ def set_seed(self, seed: int): """ Set the seed for the agent. [Optional] """ - logger.warning( - f"Setting seed for agent is not implemented for class {self.__class__.__name__}" - ) + logger.warning(f"Setting seed for agent is not implemented for class {self.__class__.__name__}") class LocalAgent(BaseAgent[AgentState]): diff --git a/vendor/tau2/agent/llm_agent.py b/vendor/tau2/agent/llm_agent.py index 01201f35..9c01a69c 100644 --- a/vendor/tau2/agent/llm_agent.py +++ b/vendor/tau2/agent/llm_agent.py @@ -69,13 +69,9 @@ def __init__( @property def system_prompt(self) -> str: - return SYSTEM_PROMPT.format( - domain_policy=self.domain_policy, agent_instruction=AGENT_INSTRUCTION - ) + return SYSTEM_PROMPT.format(domain_policy=self.domain_policy, agent_instruction=AGENT_INSTRUCTION) - def get_init_state( - self, message_history: Optional[list[Message]] = None - ) -> LLMAgentState: + def get_init_state(self, message_history: Optional[list[Message]] = None) -> LLMAgentState: """Get the initial state of the agent. Args: @@ -86,9 +82,9 @@ def get_init_state( """ if message_history is None: message_history = [] - assert all(is_valid_agent_history_message(m) for m in message_history), ( - "Message history must contain only AssistantMessage, UserMessage, or ToolMessage to Agent." - ) + assert all( + is_valid_agent_history_message(m) for m in message_history + ), "Message history must contain only AssistantMessage, UserMessage, or ToolMessage to Agent." return LLMAgentState( system_messages=[SystemMessage(role="system", content=self.system_prompt)], messages=message_history, @@ -172,9 +168,7 @@ def __init__( If provide_function_args is True, the resolution steps will include the function arguments. """ super().__init__(tools=tools, domain_policy=domain_policy) - assert self.check_valid_task(task), ( - f"Task {task.id} is not valid. Cannot run GT agent." - ) + assert self.check_valid_task(task), f"Task {task.id} is not valid. Cannot run GT agent." self.task = task self.llm = llm self.llm_args = deepcopy(llm_args) if llm_args is not None else {} @@ -201,9 +195,7 @@ def system_prompt(self) -> str: resolution_steps=self.make_agent_instructions_from_actions(), ) - def get_init_state( - self, message_history: Optional[list[Message]] = None - ) -> LLMAgentState: + def get_init_state(self, message_history: Optional[list[Message]] = None) -> LLMAgentState: """Get the initial state of the agent. Args: @@ -214,9 +206,9 @@ def get_init_state( """ if message_history is None: message_history = [] - assert all(is_valid_agent_history_message(m) for m in message_history), ( - "Message history must contain only AssistantMessage, UserMessage, or ToolMessage to Agent." - ) + assert all( + is_valid_agent_history_message(m) for m in message_history + ), "Message history must contain only AssistantMessage, UserMessage, or ToolMessage to Agent." return LLMAgentState( system_messages=[SystemMessage(role="system", content=self.system_prompt)], messages=message_history, @@ -263,9 +255,7 @@ def make_agent_instructions_from_actions(self) -> str: return "\n".join(lines) @classmethod - def make_agent_instructions_from_action( - cls, action: Action, include_function_args: bool = False - ) -> str: + def make_agent_instructions_from_action(cls, action: Action, include_function_args: bool = False) -> str: """ Make agent instructions from an action. If the action is a user action, returns instructions for the agent to give to the user. @@ -332,9 +322,7 @@ def __init__( Initialize the LLMAgent. """ super().__init__(tools=tools, domain_policy=domain_policy) - assert self.check_valid_task(task), ( - f"Task {task.id} is not valid. Cannot run GT agent." - ) + assert self.check_valid_task(task), f"Task {task.id} is not valid. Cannot run GT agent." self.task = task self.llm = llm self.llm_args = llm_args if llm_args is not None else {} @@ -417,9 +405,7 @@ def is_stop(cls, message: AssistantMessage) -> bool: return False return cls.STOP_TOKEN in message.content - def get_init_state( - self, message_history: Optional[list[Message]] = None - ) -> LLMAgentState: + def get_init_state(self, message_history: Optional[list[Message]] = None) -> LLMAgentState: """Get the initial state of the agent. Args: @@ -430,9 +416,9 @@ def get_init_state( """ if message_history is None: message_history = [] - assert all(is_valid_agent_history_message(m) for m in message_history), ( - "Message history must contain only AssistantMessage, UserMessage, or ToolMessage to Agent." - ) + assert all( + is_valid_agent_history_message(m) for m in message_history + ), "Message history must contain only AssistantMessage, UserMessage, or ToolMessage to Agent." return LLMAgentState( system_messages=[SystemMessage(role="system", content=self.system_prompt)], messages=message_history, diff --git a/vendor/tau2/cli.py b/vendor/tau2/cli.py index 65b2d115..89056109 100644 --- a/vendor/tau2/cli.py +++ b/vendor/tau2/cli.py @@ -193,9 +193,7 @@ def main(): start_parser.set_defaults(func=lambda args: run_start_servers()) # Check data command - check_data_parser = subparsers.add_parser( - "check-data", help="Check if data directory is properly configured" - ) + check_data_parser = subparsers.add_parser("check-data", help="Check if data directory is properly configured") check_data_parser.set_defaults(func=lambda args: run_check_data()) args = parser.parse_args() diff --git a/vendor/tau2/data/user_simulator/simulation_guidelines.md b/vendor/tau2/data/user_simulator/simulation_guidelines.md index 8bf34059..f7a559fd 100644 --- a/vendor/tau2/data/user_simulator/simulation_guidelines.md +++ b/vendor/tau2/data/user_simulator/simulation_guidelines.md @@ -1,5 +1,5 @@ # User Simulation Guidelines -You are playing the role of a customer contacting a customer service representative. +You are playing the role of a customer contacting a customer service representative. Your goal is to simulate realistic customer interactions while following specific scenario instructions. ## Core Principles @@ -15,4 +15,4 @@ Your goal is to simulate realistic customer interactions while following specifi - If you are transferred to another agent, generate the '###TRANSFER###' token to indicate the transfer. - If you find yourself in a situation in which the scenario does not provide enough information for you to continue the conversation, generate the '###OUT-OF-SCOPE###' token to end the conversation. -Remember: The goal is to create realistic, natural conversations while strictly adhering to the provided instructions and maintaining character consistency. \ No newline at end of file +Remember: The goal is to create realistic, natural conversations while strictly adhering to the provided instructions and maintaining character consistency. diff --git a/vendor/tau2/data/user_simulator/simulation_guidelines_tools.md b/vendor/tau2/data/user_simulator/simulation_guidelines_tools.md index 09f85a50..33908510 100644 --- a/vendor/tau2/data/user_simulator/simulation_guidelines_tools.md +++ b/vendor/tau2/data/user_simulator/simulation_guidelines_tools.md @@ -1,6 +1,6 @@ # User Simulation Guidelines -You are playing the role of a customer contacting a customer service representative agent. +You are playing the role of a customer contacting a customer service representative agent. Your goal is to simulate realistic customer interactions while following specific scenario instructions. You have some tools to perform the actions on your end that might be requested by the agent to diagnose and resolve your issue. @@ -27,4 +27,4 @@ You have some tools to perform the actions on your end that might be requested b - If you have been transferred to another agent, generate the '###TRANSFER###' token to indicate the transfer. Only do this after the agent has clearly indicated that you are being transferred. - If you find yourself in a situation in which the scenario does not provide enough information for you to continue the conversation, generate the '###OUT-OF-SCOPE###' token to end the conversation. -Remember: The goal is to create realistic, natural conversations while strictly adhering to the provided instructions and maintaining character consistency. \ No newline at end of file +Remember: The goal is to create realistic, natural conversations while strictly adhering to the provided instructions and maintaining character consistency. diff --git a/vendor/tau2/data_model/__init__.py b/vendor/tau2/data_model/__init__.py index 8b137891..e69de29b 100644 --- a/vendor/tau2/data_model/__init__.py +++ b/vendor/tau2/data_model/__init__.py @@ -1 +0,0 @@ - diff --git a/vendor/tau2/data_model/message.py b/vendor/tau2/data_model/message.py index ef5f1f7b..077c176d 100644 --- a/vendor/tau2/data_model/message.py +++ b/vendor/tau2/data_model/message.py @@ -18,15 +18,9 @@ class SystemMessage(BaseModel): """ role: SystemRole = Field(description="The role of the message sender.") - content: Optional[str] = Field( - description="The content of the message.", default=None - ) - turn_idx: Optional[int] = Field( - description="The index of the turn in the conversation.", default=None - ) - timestamp: Optional[str] = Field( - description="The timestamp of the message.", default_factory=get_now - ) + content: Optional[str] = Field(description="The content of the message.", default=None) + turn_idx: Optional[int] = Field(description="The index of the turn in the conversation.", default=None) + timestamp: Optional[str] = Field(description="The timestamp of the message.", default_factory=get_now) def __str__(self) -> str: lines = [ @@ -87,35 +81,21 @@ class ParticipantMessageBase(BaseModel): role: str = Field(description="The role of the message sender.") - content: Optional[str] = Field( - description="The content of the message.", default=None - ) - tool_calls: Optional[list[ToolCall]] = Field( - description="The tool calls made in the message.", default=None - ) - turn_idx: Optional[int] = Field( - description="The index of the turn in the conversation.", default=None - ) - timestamp: Optional[str] = Field( - description="The timestamp of the message.", default_factory=get_now - ) + content: Optional[str] = Field(description="The content of the message.", default=None) + tool_calls: Optional[list[ToolCall]] = Field(description="The tool calls made in the message.", default=None) + turn_idx: Optional[int] = Field(description="The index of the turn in the conversation.", default=None) + timestamp: Optional[str] = Field(description="The timestamp of the message.", default_factory=get_now) cost: Optional[float] = Field(description="The cost of the message.", default=None) - usage: Optional[dict] = Field( - description="The token usage of the message.", default=None - ) - raw_data: Optional[dict] = Field( - description="The raw data of the message.", default=None - ) + usage: Optional[dict] = Field(description="The token usage of the message.", default=None) + raw_data: Optional[dict] = Field(description="The raw data of the message.", default=None) def validate(self): # NOTE: It would be better to do this in the Pydantic model """ Validate the message. """ if not (self.has_text_content() or self.is_tool_call()): - raise ValueError( - f"AssistantMessage must have either content or tool calls. Got {self}" - ) + raise ValueError(f"AssistantMessage must have either content or tool calls. Got {self}") def has_text_content(self) -> bool: """ @@ -151,11 +131,7 @@ def __str__(self) -> str: def __eq__(self, other: object) -> bool: if type(other) is not type(self): return False - return ( - self.role == other.role - and self.content == other.content - and self.tool_calls == other.tool_calls - ) + return self.role == other.role and self.content == other.content and self.tool_calls == other.tool_calls class AssistantMessage(ParticipantMessageBase): @@ -187,12 +163,8 @@ class ToolMessage(BaseModel): description="The requestor of the tool call.", ) error: bool = Field(description="Whether the tool call failed.", default=False) - turn_idx: Optional[int] = Field( - description="The index of the turn in the conversation.", default=None - ) - timestamp: Optional[str] = Field( - description="The timestamp of the message.", default_factory=get_now - ) + turn_idx: Optional[int] = Field(description="The index of the turn in the conversation.", default=None) + timestamp: Optional[str] = Field(description="The timestamp of the message.", default_factory=get_now) def __str__(self) -> str: lines = [f"ToolMessage (responding to {self.requestor})"] @@ -228,6 +200,4 @@ class MultiToolMessage(BaseModel): APICompatibleMessage = SystemMessage | AssistantMessage | UserMessage | ToolMessage -Message = ( - SystemMessage | AssistantMessage | UserMessage | ToolMessage | MultiToolMessage -) +Message = SystemMessage | AssistantMessage | UserMessage | ToolMessage | MultiToolMessage diff --git a/vendor/tau2/data_model/simulation.py b/vendor/tau2/data_model/simulation.py index 41de1d72..ee630cf7 100644 --- a/vendor/tau2/data_model/simulation.py +++ b/vendor/tau2/data_model/simulation.py @@ -216,9 +216,7 @@ class RewardInfo(BaseModel): """ reward: Annotated[float, Field(description="The reward received by the agent.")] - db_check: Annotated[ - Optional[DBCheck], Field(description="The database check.", default=None) - ] + db_check: Annotated[Optional[DBCheck], Field(description="The database check.", default=None)] env_assertions: Annotated[ Optional[list[EnvAssertionCheck]], Field(description="The environment assertions.", default=None), @@ -265,9 +263,7 @@ class AgentInfo(BaseModel): implementation: str = Field(description="The type of agent.") llm: Optional[str] = Field(description="The LLM used by the agent.", default=None) - llm_args: Optional[dict] = Field( - description="The arguments to pass to the LLM for the agent.", default=None - ) + llm_args: Optional[dict] = Field(description="The arguments to pass to the LLM for the agent.", default=None) class UserInfo(BaseModel): @@ -277,9 +273,7 @@ class UserInfo(BaseModel): implementation: str = Field(description="The type of user.") llm: Optional[str] = Field(description="The LLM used by the user.", default=None) - llm_args: Optional[dict] = Field( - description="The arguments to pass to the LLM for the user.", default=None - ) + llm_args: Optional[dict] = Field(description="The arguments to pass to the LLM for the user.", default=None) global_simulation_guidelines: Optional[str] = Field( description="The global simulation guidelines for the user.", default=None ) @@ -295,9 +289,7 @@ class Info(BaseModel): user_info: UserInfo = Field(description="User information.") agent_info: AgentInfo = Field(description="Agent information.") environment_info: EnvironmentInfo = Field(description="Environment information.") - seed: Optional[int] = Field( - description="The seed used for the simulation.", default=None - ) + seed: Optional[int] = Field(description="The seed used for the simulation.", default=None) class TerminationReason(str, Enum): @@ -314,31 +306,17 @@ class SimulationRun(BaseModel): id: str = Field(description="The unique identifier for the simulation run.") task_id: str = Field(description="The unique identifier for the task.") - timestamp: str = Field( - description="The timestamp of the simulation.", default_factory=get_now - ) + timestamp: str = Field(description="The timestamp of the simulation.", default_factory=get_now) start_time: str = Field(description="The start time of the simulation.") end_time: str = Field(description="The end time of the simulation.") duration: float = Field(description="The duration of the simulation.") - termination_reason: TerminationReason = Field( - description="The reason for the termination of the simulation." - ) - agent_cost: Optional[float] = Field( - description="The cost of the agent.", default=None - ) - user_cost: Optional[float] = Field( - description="The cost of the user.", default=None - ) - reward_info: Optional[RewardInfo] = Field( - description="The reward received by the agent.", default=None - ) - messages: list[Message] = Field( - description="The messages exchanged between the user, agent and environment." - ) + termination_reason: TerminationReason = Field(description="The reason for the termination of the simulation.") + agent_cost: Optional[float] = Field(description="The cost of the agent.", default=None) + user_cost: Optional[float] = Field(description="The cost of the user.", default=None) + reward_info: Optional[RewardInfo] = Field(description="The reward received by the agent.", default=None) + messages: list[Message] = Field(description="The messages exchanged between the user, agent and environment.") trial: Optional[int] = Field(description="Trial number", default=None) - seed: Optional[int] = Field( - description="Seed used for the simulation.", default=None - ) + seed: Optional[int] = Field(description="Seed used for the simulation.", default=None) class Results(BaseModel): @@ -346,9 +324,7 @@ class Results(BaseModel): Run results """ - timestamp: Optional[str] = Field( - description="The timestamp of the simulation.", default_factory=get_now - ) + timestamp: Optional[str] = Field(description="The timestamp of the simulation.", default_factory=get_now) info: Info = Field(description="Information.") tasks: list[Task] = Field(description="The list of tasks.") simulations: list[SimulationRun] = Field(description="The list of simulations.") @@ -387,14 +363,8 @@ def transfer_only(task: Task) -> bool: return False def get_task_metrics(task: Task) -> dict: - eval_metrics = ( - task.evaluation_criteria.info() - if task.evaluation_criteria is not None - else {} - ) - num_actions = ( - eval_metrics["num_agent_actions"] + eval_metrics["num_user_actions"] - ) + eval_metrics = task.evaluation_criteria.info() if task.evaluation_criteria is not None else {} + num_actions = eval_metrics["num_agent_actions"] + eval_metrics["num_user_actions"] if transfer_only(task): num_actions = -1 info = { diff --git a/vendor/tau2/data_model/tasks.py b/vendor/tau2/data_model/tasks.py index ef17dc3c..c3105557 100644 --- a/vendor/tau2/data_model/tasks.py +++ b/vendor/tau2/data_model/tasks.py @@ -18,9 +18,7 @@ class StructuredUserInstructions(BaseModel): """ domain: Annotated[str, Field(description="The domain of the task.")] - reason_for_call: Annotated[ - str, Field(description="The reason for the user to call the agent.") - ] + reason_for_call: Annotated[str, Field(description="The reason for the user to call the agent.")] known_info: Annotated[ Optional[str], Field(description="Known information about the user.", default=None), @@ -40,9 +38,7 @@ def __str__(self) -> str: lines.append(f"Known info:\n{textwrap.indent(self.known_info, tab)}") if self.unknown_info is not None: lines.append(f"Unknown info:\n{textwrap.indent(self.unknown_info, tab)}") - lines.append( - f"Task instructions:\n{textwrap.indent(self.task_instructions, tab)}" - ) + lines.append(f"Task instructions:\n{textwrap.indent(self.task_instructions, tab)}") return "\n".join(lines) @@ -128,18 +124,14 @@ class Action(BaseModel): If compare_args is None, will check all the arguments. """ - action_id: str = Field( - description="The unique identifier for the action within a scenario." - ) + action_id: str = Field(description="The unique identifier for the action within a scenario.") requestor: ToolRequestor = Field( description="The requestor of the action.", default="assistant", ) name: str = Field(description="The name of the action.") arguments: dict = Field(description="The arguments for the action.") - info: Optional[str] = Field( - description="Information about the action.", default=None - ) + info: Optional[str] = Field(description="Information about the action.", default=None) compare_args: Optional[list[str]] = Field( description="The arguments to check in tool call. If None, will check all the arguments.", default=None, @@ -159,9 +151,7 @@ def get_func_format(self) -> str: """ Get the function format of the action. """ - return ( - f"{self.name}({', '.join([f'{k}={v}' for k, v in self.arguments.items()])})" - ) + return f"{self.name}({', '.join([f'{k}={v}' for k, v in self.arguments.items()])})" def compare_with_tool_call(self, tool_call: ToolCall) -> bool: """ @@ -193,9 +183,7 @@ class EnvFunctionCall(BaseModel): Field(description="The type of environment to call the function on."), ] func_name: Annotated[str, Field(description="The name of the function to call.")] - arguments: Annotated[ - dict, Field(description="The arguments to pass to the function.") - ] + arguments: Annotated[dict, Field(description="The arguments to pass to the function.")] def __str__(self) -> str: lines = [] @@ -210,9 +198,7 @@ class EnvAssertion(EnvFunctionCall): An assertion on the agent or user environment. """ - assert_value: Annotated[ - bool, Field(default=True, description="The value to assert on.") - ] + assert_value: Annotated[bool, Field(default=True, description="The value to assert on.")] message: Annotated[ Optional[str], Field( @@ -279,27 +265,16 @@ def __str__(self) -> str: lines = [] if self.actions is not None: lines.append("Actions:") - lines.extend( - [textwrap.indent(str(action), "\t") for action in self.actions] - ) + lines.extend([textwrap.indent(str(action), "\t") for action in self.actions]) if self.env_assertions is not None: lines.append("Env Assertions:") - lines.extend( - [ - textwrap.indent(str(assertion), "\t") - for assertion in self.env_assertions - ] - ) + lines.extend([textwrap.indent(str(assertion), "\t") for assertion in self.env_assertions]) if self.communicate_info is not None: lines.append("Communicate Info:") - lines.extend( - [textwrap.indent(info, "\t") for info in self.communicate_info] - ) + lines.extend([textwrap.indent(info, "\t") for info in self.communicate_info]) if self.nl_assertions is not None: lines.append("NL Assertions:") - lines.extend( - [textwrap.indent(assertion, "\t") for assertion in self.nl_assertions] - ) + lines.extend([textwrap.indent(assertion, "\t") for assertion in self.nl_assertions]) return "\n".join(lines) def info(self) -> dict: @@ -309,16 +284,10 @@ def info(self) -> dict: else 0 ) num_user_actions = ( - len([action for action in self.actions if action.requestor == "user"]) - if self.actions is not None - else 0 - ) - num_env_assertions = ( - len(self.env_assertions) if self.env_assertions is not None else 0 - ) - num_nl_assertions = ( - len(self.nl_assertions) if self.nl_assertions is not None else 0 + len([action for action in self.actions if action.requestor == "user"]) if self.actions is not None else 0 ) + num_env_assertions = len(self.env_assertions) if self.env_assertions is not None else 0 + num_nl_assertions = len(self.nl_assertions) if self.nl_assertions is not None else 0 return { "num_agent_actions": num_agent_actions, "num_user_actions": num_user_actions, @@ -354,9 +323,7 @@ class InitialState(BaseModel): ] initialization_actions: Annotated[ Optional[list[EnvFunctionCall]], - Field( - description="Initial actions to be taken on the environment.", default=None - ), + Field(description="Initial actions to be taken on the environment.", default=None), ] message_history: Annotated[ Optional[list[Message]], @@ -370,29 +337,13 @@ def __str__(self) -> str: lines = [] if self.initialization_data is not None: lines.append("Initialization Data:") - lines.extend( - [ - textwrap.indent( - self.initialization_data.model_dump_json(indent=2), "\t" - ) - ] - ) + lines.extend([textwrap.indent(self.initialization_data.model_dump_json(indent=2), "\t")]) if self.initialization_actions is not None: lines.append("Initialization Actions:") - lines.extend( - [ - textwrap.indent(str(action), "\t") - for action in self.initialization_actions - ] - ) + lines.extend([textwrap.indent(str(action), "\t") for action in self.initialization_actions]) if self.message_history is not None: lines.append("Message History:") - lines.extend( - [ - textwrap.indent(str(message), "\t") - for message in self.message_history - ] - ) + lines.extend([textwrap.indent(str(message), "\t") for message in self.message_history]) return "\n".join(lines) @@ -411,9 +362,7 @@ class Task(BaseModel): ] user_scenario: Annotated[ UserScenario, - Field( - description="User scenario. This information will be sent to the user simulator." - ), + Field(description="User scenario. This information will be sent to the user simulator."), ] ticket: Annotated[ Optional[str], @@ -478,11 +427,7 @@ def make_task( if message_history is not None: # Patch to consider empty list of tool calls as None. for message in message_history: - if ( - message.role == "assistant" - and isinstance(message.tool_calls, list) - and len(message.tool_calls) == 0 - ): + if message.role == "assistant" and isinstance(message.tool_calls, list) and len(message.tool_calls) == 0: message.tool_calls = None initial_state = InitialState( diff --git a/vendor/tau2/domains/airline/data_model.py b/vendor/tau2/domains/airline/data_model.py index f2733727..c046d228 100644 --- a/vendor/tau2/domains/airline/data_model.py +++ b/vendor/tau2/domains/airline/data_model.py @@ -10,9 +10,7 @@ Insurance = Literal["yes", "no"] -MembershipLevel = Annotated[ - Literal["gold", "silver", "regular"], Field(description="Membership level") -] +MembershipLevel = Annotated[Literal["gold", "silver", "regular"], Field(description="Membership level")] class AirportCode(BaseModel): @@ -30,9 +28,7 @@ class Name(BaseModel): class Address(BaseModel): address1: str = Field(description="Primary address line") - address2: Optional[str] = Field( - None, description="Secondary address line (optional)" - ) + address2: Optional[str] = Field(None, description="Secondary address line (optional)") city: str = Field(description="City name") country: str = Field(description="Country name") state: str = Field(description="State or province name") @@ -51,25 +47,19 @@ class PaymentMethodBase(BaseModel): class CreditCard(PaymentMethodBase): - source: Literal["credit_card"] = Field( - description="Indicates this is a credit card payment method" - ) + source: Literal["credit_card"] = Field(description="Indicates this is a credit card payment method") brand: str = Field(description="Credit card brand (e.g., visa, mastercard)") last_four: str = Field(description="Last four digits of the credit card") class GiftCard(PaymentMethodBase): - source: Literal["gift_card"] = Field( - description="Indicates this is a gift card payment method" - ) + source: Literal["gift_card"] = Field(description="Indicates this is a gift card payment method") amount: float = Field(description="Gift card value amount") id: str = Field(description="Unique identifier for the gift card") class Certificate(PaymentMethodBase): - source: Literal["certificate"] = Field( - description="Indicates this is a certificate payment method" - ) + source: Literal["certificate"] = Field(description="Indicates this is a certificate payment method") amount: float = Field(description="Certificate value amount") @@ -82,9 +72,7 @@ class Passenger(BaseModel): dob: str = Field(description="Date of birth in YYYY-MM-DD format") -SeatPrices = Annotated[ - dict[CabinClass, int], Field(description="Prices for different cabin classes") -] +SeatPrices = Annotated[dict[CabinClass, int], Field(description="Prices for different cabin classes")] AvailableSeats = Annotated[ dict[CabinClass, int], Field(description="Available seats for different cabin classes"), @@ -92,9 +80,7 @@ class Passenger(BaseModel): class FlightDateStatusAvailable(BaseModel): - status: Literal["available"] = Field( - description="Indicates flight is available for booking" - ) + status: Literal["available"] = Field(description="Indicates flight is available for booking") available_seats: AvailableSeats = Field(description="Available seats by class") prices: SeatPrices = Field(description="Current prices by class") @@ -166,24 +152,18 @@ class Flight(FlightBase): scheduled_arrival_time_est: str = Field( description="Scheduled arrival time in EST in the format HH:MM:SS, e.g 07:00:00" ) - dates: Dict[str, FlightDateStatus] = Field( - description="Flight status by date (YYYY-MM-DD)" - ) + dates: Dict[str, FlightDateStatus] = Field(description="Flight status by date (YYYY-MM-DD)") class DirectFlight(FlightBase): - status: Literal["available"] = Field( - description="Indicates flight is available for booking" - ) + status: Literal["available"] = Field(description="Indicates flight is available for booking") scheduled_departure_time_est: str = Field( description="Scheduled departure time in EST in the format HH:MM:SS, e.g 06:00:00" ) scheduled_arrival_time_est: str = Field( description="Scheduled arrival time in EST in the format HH:MM:SS, e.g 07:00:00" ) - date: Optional[str] = Field( - description="Flight date in YYYY-MM-DD format", default=None - ) + date: Optional[str] = Field(description="Flight date in YYYY-MM-DD format", default=None) available_seats: AvailableSeats = Field(description="Available seats by class") prices: SeatPrices = Field(description="Current prices by class") @@ -195,9 +175,7 @@ class ReservationFlight(FlightBase): class FlightInfo(BaseModel): flight_number: str = Field(description="Flight number, such as 'HAT001'.") - date: str = Field( - description="The date for the flight in the format 'YYYY-MM-DD', such as '2024-05-01'." - ) + date: str = Field(description="The date for the flight in the format 'YYYY-MM-DD', such as '2024-05-01'.") class User(BaseModel): @@ -205,15 +183,9 @@ class User(BaseModel): name: Name = Field(description="User's full name") address: Address = Field(description="User's address information") email: str = Field(description="User's email address") - dob: str = Field( - description="User's date of birth in the format YYYY-MM-DD, e.g 1990-04-05" - ) - payment_methods: Dict[str, PaymentMethod] = Field( - description="User's saved payment methods" - ) - saved_passengers: List[Passenger] = Field( - description="User's saved passenger information" - ) + dob: str = Field(description="User's date of birth in the format YYYY-MM-DD, e.g 1990-04-05") + payment_methods: Dict[str, PaymentMethod] = Field(description="User's saved payment methods") + saved_passengers: List[Passenger] = Field(description="User's saved passenger information") membership: MembershipLevel = Field(description="User's membership level") reservations: List[str] = Field(description="List of user's reservation IDs") @@ -226,35 +198,21 @@ class Reservation(BaseModel): destination: str = Field(description="IATA code for trip destination") flight_type: FlightType = Field(description="Type of trip") cabin: CabinClass = Field(description="Selected cabin class") - flights: List[ReservationFlight] = Field( - description="List of flights in the reservation" - ) - passengers: List[Passenger] = Field( - description="List of passengers on the reservation" - ) - payment_history: List[Payment] = Field( - description="History of payments for this reservation" - ) - created_at: str = Field( - description="Timestamp when reservation was created in the format YYYY-MM-DDTHH:MM:SS" - ) + flights: List[ReservationFlight] = Field(description="List of flights in the reservation") + passengers: List[Passenger] = Field(description="List of passengers on the reservation") + payment_history: List[Payment] = Field(description="History of payments for this reservation") + created_at: str = Field(description="Timestamp when reservation was created in the format YYYY-MM-DDTHH:MM:SS") total_baggages: int = Field(description="Total number of bags in reservation") nonfree_baggages: int = Field(description="Number of paid bags in reservation") insurance: Insurance = Field(description="Whether travel insurance was purchased") - status: Optional[Literal["cancelled"]] = Field( - description="Status of the reservation", default=None - ) + status: Optional[Literal["cancelled"]] = Field(description="Status of the reservation", default=None) class FlightDB(DB): """Database of all flights, users, and reservations.""" - flights: Dict[str, Flight] = Field( - description="Dictionary of all flights indexed by flight number" - ) - users: Dict[str, User] = Field( - description="Dictionary of all users indexed by user ID" - ) + flights: Dict[str, Flight] = Field(description="Dictionary of all flights indexed by flight number") + users: Dict[str, User] = Field(description="Dictionary of all users indexed by user ID") reservations: Dict[str, Reservation] = Field( description="Dictionary of all reservations indexed by reservation ID" ) @@ -262,9 +220,7 @@ class FlightDB(DB): def get_statistics(self) -> dict[str, Any]: """Get the statistics of the database.""" num_flights = len(self.flights) - num_flights_instances = sum( - len(flight.dates) for flight in self.flights.values() - ) + num_flights_instances = sum(len(flight.dates) for flight in self.flights.values()) num_users = len(self.users) num_reservations = len(self.reservations) return { diff --git a/vendor/tau2/domains/airline/tools.py b/vendor/tau2/domains/airline/tools.py index d4f45694..d854f725 100644 --- a/vendor/tau2/domains/airline/tools.py +++ b/vendor/tau2/domains/airline/tools.py @@ -62,15 +62,11 @@ def _get_flight_instance(self, flight_number: str, date: str) -> FlightDateStatu raise ValueError(f"Flight {flight_number} not found on date {date}") return flight.dates[date] - def _get_flights_from_flight_infos( - self, flight_infos: List[FlightInfo] - ) -> list[FlightDateStatus]: + def _get_flights_from_flight_infos(self, flight_infos: List[FlightInfo]) -> list[FlightDateStatus]: """Get the flight from the reservation.""" flights = [] for flight_info in flight_infos: - flights.append( - self._get_flight_instance(flight_info.flight_number, flight_info.date) - ) + flights.append(self._get_flight_instance(flight_info.flight_number, flight_info.date)) return flights def _get_new_reservation_id(self) -> str: @@ -123,10 +119,7 @@ def _search_direct_flight( and (destination is None or flight.destination == destination) and (date in flight.dates) and (flight.dates[date].status == "available") - and ( - leave_after is None - or flight.scheduled_departure_time_est >= leave_after - ) + and (leave_after is None or flight.scheduled_departure_time_est >= leave_after) ) if check: direct_flight = DirectFlight( @@ -142,9 +135,7 @@ def _search_direct_flight( results.append(direct_flight) return results - def _payment_for_update( - self, user: User, payment_id: str, total_price: int - ) -> Optional[Payment]: + def _payment_for_update(self, user: User, payment_id: str, total_price: int) -> Optional[Payment]: """ Process payment for update reservation @@ -165,9 +156,7 @@ def _payment_for_update( payment_method = user.payment_methods[payment_id] if payment_method.source == "certificate": raise ValueError("Certificate cannot be used to update reservation") - elif ( - payment_method.source == "gift_card" and payment_method.amount < total_price - ): + elif payment_method.source == "gift_card" and payment_method.amount < total_price: raise ValueError("Gift card balance is not enough") # Deduct payment @@ -219,9 +208,7 @@ def book_reservation( if all(isinstance(passenger, dict) for passenger in passengers): passengers = [Passenger(**passenger) for passenger in passengers] if all(isinstance(payment_method, dict) for payment_method in payment_methods): - payment_methods = [ - Payment(**payment_method) for payment_method in payment_methods - ] + payment_methods = [Payment(**payment_method) for payment_method in payment_methods] user = self._get_user(user_id) reservation_id = self._get_new_reservation_id() @@ -248,14 +235,10 @@ def book_reservation( for flight_info in flights: flight_number = flight_info.flight_number flight = self._get_flight(flight_number) - flight_date_data = self._get_flight_instance( - flight_number=flight_number, date=flight_info.date - ) + flight_date_data = self._get_flight_instance(flight_number=flight_number, date=flight_info.date) # Checking flight availability if not isinstance(flight_date_data, FlightDateStatusAvailable): - raise ValueError( - f"Flight {flight_number} not available on date {flight_info.date}" - ) + raise ValueError(f"Flight {flight_number} not available on date {flight_info.date}") # Checking seat availability if flight_date_data.available_seats[cabin] < len(passengers): raise ValueError(f"Not enough seats on flight {flight_number}") @@ -290,15 +273,11 @@ def book_reservation( user_payment_method = user.payment_methods[payment_id] if user_payment_method.source in {"gift_card", "certificate"}: if user_payment_method.amount < amount: - raise ValueError( - f"Not enough balance in payment method {payment_id}" - ) + raise ValueError(f"Not enough balance in payment method {payment_id}") total_payment = sum(payment.amount for payment in payment_methods) if total_payment != total_price: - raise ValueError( - f"Payment amount does not add up, total price is {total_price}, but paid {total_payment}" - ) + raise ValueError(f"Payment amount does not add up, total price is {total_price}, but paid {total_payment}") # if checks pass, deduct payment for payment_method in payment_methods: @@ -430,9 +409,7 @@ def list_all_airports(self) -> AirportInfo: # DONE ] @is_tool(ToolType.READ) - def search_direct_flight( - self, origin: str, destination: str, date: str - ) -> list[DirectFlight]: + def search_direct_flight(self, origin: str, destination: str, date: str) -> list[DirectFlight]: """ Search for direct flights between two cities on a specific date. @@ -444,9 +421,7 @@ def search_direct_flight( Returns: The direct flights between the two cities on the specific date. """ - return self._search_direct_flight( - date=date, origin=origin, destination=destination - ) + return self._search_direct_flight(date=date, origin=origin, destination=destination) @is_tool(ToolType.READ) def search_onestop_flight( @@ -464,15 +439,9 @@ def search_onestop_flight( A list of pairs of DirectFlight objects. """ results = [] - for result1 in self._search_direct_flight( - date=date, origin=origin, destination=None - ): + for result1 in self._search_direct_flight(date=date, origin=origin, destination=None): result1.date = date - date2 = ( - f"2024-05-{int(date[-2:]) + 1}" - if "+1" in result1.scheduled_arrival_time_est - else date - ) + date2 = f"2024-05-{int(date[-2:]) + 1}" if "+1" in result1.scheduled_arrival_time_est else date # TODO: flight1.scheduled_arrival_time_est could have a +1? for result2 in self._search_direct_flight( date=date2, @@ -637,9 +606,7 @@ def update_reservation_flights( None, ) if matching_reservation_flight: - total_price += matching_reservation_flight.price * len( - reservation.passengers - ) + total_price += matching_reservation_flight.price * len(reservation.passengers) reservation_flights.append(matching_reservation_flight) continue @@ -651,15 +618,11 @@ def update_reservation_flights( date=flight_info.date, ) if not isinstance(flight_date_data, FlightDateStatusAvailable): - raise ValueError( - f"Flight {flight_info.flight_number} not available on date {flight_info.date}" - ) + raise ValueError(f"Flight {flight_info.flight_number} not available on date {flight_info.date}") # Check seat availability if flight_date_data.available_seats[cabin] < len(reservation.passengers): - raise ValueError( - f"Not enough seats on flight {flight_info.flight_number}" - ) + raise ValueError(f"Not enough seats on flight {flight_info.flight_number}") # Calculate price and add to reservation reservation_flight = ReservationFlight( @@ -673,9 +636,7 @@ def update_reservation_flights( reservation_flights.append(reservation_flight) # Deduct amount already paid for reservation - total_price -= sum(flight.price for flight in reservation.flights) * len( - reservation.passengers - ) + total_price -= sum(flight.price for flight in reservation.flights) * len(reservation.passengers) # Create payment payment = self._payment_for_update(user, payment_id, total_price) @@ -690,9 +651,7 @@ def update_reservation_flights( return reservation @is_tool(ToolType.WRITE) - def update_reservation_passengers( - self, reservation_id: str, passengers: List[Passenger | dict] - ) -> Reservation: + def update_reservation_passengers(self, reservation_id: str, passengers: List[Passenger | dict]) -> Reservation: """ Update the passenger information of a reservation. diff --git a/vendor/tau2/domains/mock/data_model.py b/vendor/tau2/domains/mock/data_model.py index f643d3e0..bff026df 100644 --- a/vendor/tau2/domains/mock/data_model.py +++ b/vendor/tau2/domains/mock/data_model.py @@ -24,12 +24,8 @@ class User(BaseModel): class MockDB(DB): """Simple database with users and their tasks.""" - tasks: Dict[str, Task] = Field( - description="Dictionary of all tasks indexed by task ID" - ) - users: Dict[str, User] = Field( - description="Dictionary of all users indexed by user ID" - ) + tasks: Dict[str, Task] = Field(description="Dictionary of all tasks indexed by task ID") + users: Dict[str, User] = Field(description="Dictionary of all users indexed by user ID") def get_db(): diff --git a/vendor/tau2/domains/mock/environment.py b/vendor/tau2/domains/mock/environment.py index d7063315..925fd297 100644 --- a/vendor/tau2/domains/mock/environment.py +++ b/vendor/tau2/domains/mock/environment.py @@ -13,9 +13,7 @@ from vendor.tau2.environment.environment import Environment -def get_environment( - db: Optional[MockDB] = None, solo_mode: bool = False -) -> Environment: +def get_environment(db: Optional[MockDB] = None, solo_mode: bool = False) -> Environment: if db is None: db = MockDB.load(MOCK_DB_PATH) tools = MockTools(db) diff --git a/vendor/tau2/domains/mock/tools.py b/vendor/tau2/domains/mock/tools.py index b36f46af..7c2ab361 100644 --- a/vendor/tau2/domains/mock/tools.py +++ b/vendor/tau2/domains/mock/tools.py @@ -30,9 +30,7 @@ def create_task(self, user_id: str, title: str, description: str = None) -> Task raise ValueError(f"User {user_id} not found") task_id = f"task_{len(self.db.tasks) + 1}" - task = Task( - task_id=task_id, title=title, description=description, status="pending" - ) + task = Task(task_id=task_id, title=title, description=description, status="pending") self.db.tasks[task_id] = task self.db.users[user_id].tasks.append(task_id) diff --git a/vendor/tau2/domains/retail/data_model.py b/vendor/tau2/domains/retail/data_model.py index ddb45e3d..d0415856 100644 --- a/vendor/tau2/domains/retail/data_model.py +++ b/vendor/tau2/domains/retail/data_model.py @@ -22,9 +22,7 @@ class Product(BaseModel): name: str = Field(description="Name of the product") product_id: str = Field(description="Unique identifier for the product") - variants: Dict[str, Variant] = Field( - description="Dictionary of variants indexed by variant ID" - ) + variants: Dict[str, Variant] = Field(description="Dictionary of variants indexed by variant ID") class UserName(BaseModel): @@ -51,23 +49,17 @@ class PaymentMethodBase(BaseModel): class CreditCard(PaymentMethodBase): - source: Literal["credit_card"] = Field( - description="Indicates this is a credit card payment method" - ) + source: Literal["credit_card"] = Field(description="Indicates this is a credit card payment method") brand: str = Field(description="Credit card brand (e.g., visa, mastercard)") last_four: str = Field(description="Last four digits of the credit card") class Paypal(PaymentMethodBase): - source: Literal["paypal"] = Field( - description="Indicates this is a paypal payment method" - ) + source: Literal["paypal"] = Field(description="Indicates this is a paypal payment method") class GiftCard(PaymentMethodBase): - source: Literal["gift_card"] = Field( - description="Indicates this is a gift card payment method" - ) + source: Literal["gift_card"] = Field(description="Indicates this is a gift card payment method") balance: float = Field(description="Gift card value amount") id: str = Field(description="Unique identifier for the gift card") @@ -92,9 +84,7 @@ class OrderFullfilment(BaseModel): """Represents the fulfillment details for items in an order""" tracking_id: list[str] = Field(description="List of tracking IDs for shipments") - item_ids: list[str] = Field( - description="List of item IDs included in this fulfillment" - ) + item_ids: list[str] = Field(description="List of item IDs included in this fulfillment") class OrderItem(BaseModel): @@ -113,9 +103,7 @@ class OrderItem(BaseModel): class OrderPayment(BaseModel): """Represents a payment or refund transaction for an order""" - transaction_type: OrderPaymentType = Field( - description="Type of transaction (payment or refund)" - ) + transaction_type: OrderPaymentType = Field(description="Type of transaction (payment or refund)") amount: float = Field(description="Amount of the transaction") payment_method_id: str = Field(description="ID of the payment method used") @@ -141,32 +129,18 @@ class BaseOrder(BaseModel): address: UserAddress = Field(description="Address of the user") items: List[OrderItem] = Field(description="Items in the order") status: OrderStatus = Field(description="Status of the order") - fulfillments: List[OrderFullfilment] = Field( - description="Fulfillments of the order" - ) + fulfillments: List[OrderFullfilment] = Field(description="Fulfillments of the order") payment_history: List[OrderPayment] = Field(description="Payments of the order") cancel_reason: Optional[CancelReason] = Field( description="Reason for cancelling the order. Can'no longer needed' or 'ordered by mistake'", default=None, ) - exchange_items: Optional[List[str]] = Field( - description="Items to be exchanged", default=None - ) - exchange_new_items: Optional[List[str]] = Field( - description="Items exchanged for", default=None - ) - exchange_payment_method_id: Optional[str] = Field( - description="Payment method ID for the exchange", default=None - ) - exchange_price_difference: Optional[float] = Field( - description="Price difference for the exchange", default=None - ) - return_items: Optional[List[str]] = Field( - description="Items to be returned", default=None - ) - return_payment_method_id: Optional[str] = Field( - description="Payment method ID for the return", default=None - ) + exchange_items: Optional[List[str]] = Field(description="Items to be exchanged", default=None) + exchange_new_items: Optional[List[str]] = Field(description="Items exchanged for", default=None) + exchange_payment_method_id: Optional[str] = Field(description="Payment method ID for the exchange", default=None) + exchange_price_difference: Optional[float] = Field(description="Price difference for the exchange", default=None) + return_items: Optional[List[str]] = Field(description="Items to be returned", default=None) + return_payment_method_id: Optional[str] = Field(description="Payment method ID for the return", default=None) class Order(BaseModel): @@ -177,55 +151,33 @@ class Order(BaseModel): address: UserAddress = Field(description="Address of the user") items: List[OrderItem] = Field(description="Items in the order") status: OrderStatus = Field(description="Status of the order") - fulfillments: List[OrderFullfilment] = Field( - description="Fulfillments of the order" - ) + fulfillments: List[OrderFullfilment] = Field(description="Fulfillments of the order") payment_history: List[OrderPayment] = Field(description="Payments of the order") cancel_reason: Optional[CancelReason] = Field( description="Reason for cancelling the order. Should be 'no longer needed' or 'ordered by mistake'", default=None, ) - exchange_items: Optional[List[str]] = Field( - description="Items to be exchanged", default=None - ) - exchange_new_items: Optional[List[str]] = Field( - description="Items exchanged for", default=None - ) - exchange_payment_method_id: Optional[str] = Field( - description="Payment method ID for the exchange", default=None - ) - exchange_price_difference: Optional[float] = Field( - description="Price difference for the exchange", default=None - ) - return_items: Optional[List[str]] = Field( - description="Items to be returned", default=None - ) - return_payment_method_id: Optional[str] = Field( - description="Payment method ID for the return", default=None - ) + exchange_items: Optional[List[str]] = Field(description="Items to be exchanged", default=None) + exchange_new_items: Optional[List[str]] = Field(description="Items exchanged for", default=None) + exchange_payment_method_id: Optional[str] = Field(description="Payment method ID for the exchange", default=None) + exchange_price_difference: Optional[float] = Field(description="Price difference for the exchange", default=None) + return_items: Optional[List[str]] = Field(description="Items to be returned", default=None) + return_payment_method_id: Optional[str] = Field(description="Payment method ID for the return", default=None) class RetailDB(DB): """Database containing all retail-related data including products, users and orders""" - products: Dict[str, Product] = Field( - description="Dictionary of all products indexed by product ID" - ) - users: Dict[str, User] = Field( - description="Dictionary of all users indexed by user ID" - ) - orders: Dict[str, Order] = Field( - description="Dictionary of all orders indexed by order ID" - ) + products: Dict[str, Product] = Field(description="Dictionary of all products indexed by product ID") + users: Dict[str, User] = Field(description="Dictionary of all users indexed by user ID") + orders: Dict[str, Order] = Field(description="Dictionary of all orders indexed by order ID") def get_statistics(self) -> dict[str, Any]: """Get the statistics of the database.""" num_products = len(self.products) num_users = len(self.users) num_orders = len(self.orders) - total_num_items = sum( - len(product.variants) for product in self.products.values() - ) + total_num_items = sum(len(product.variants) for product in self.products.values()) return { "num_products": num_products, "num_users": num_users, diff --git a/vendor/tau2/domains/retail/tools.py b/vendor/tau2/domains/retail/tools.py index 944e206c..6fc91e15 100644 --- a/vendor/tau2/domains/retail/tools.py +++ b/vendor/tau2/domains/retail/tools.py @@ -92,9 +92,7 @@ def _get_variant(self, product_id: str, variant_id: str) -> Variant: raise ValueError("Variant not found") return product.variants[variant_id] - def _get_payment_method( - self, user_id: str, payment_method_id: str - ) -> PaymentMethod: + def _get_payment_method(self, user_id: str, payment_method_id: str) -> PaymentMethod: """Get the payment method from the database. Args: @@ -252,9 +250,7 @@ def exchange_delivered_order_items( payment_method = self._get_payment_method(order.user_id, payment_method_id) if isinstance(payment_method, GiftCard) and payment_method.balance < diff_price: - raise ValueError( - "Insufficient gift card balance to pay for the price difference" - ) + raise ValueError("Insufficient gift card balance to pay for the price difference") # modify the order order.status = "exchange requested" @@ -266,9 +262,7 @@ def exchange_delivered_order_items( return order @is_tool(ToolType.READ) - def find_user_id_by_name_zip( - self, first_name: str, last_name: str, zip: str - ) -> str: + def find_user_id_by_name_zip(self, first_name: str, last_name: str, zip: str) -> str: """Find user id by first name, last name, and zip code. If the user is not found, the function will return an error message. By default, find user id by email, and only call this function if the user is not found by email or cannot remember email. @@ -368,9 +362,7 @@ def list_all_product_types(self) -> str: Returns: str: A JSON string mapping product names to their product IDs, sorted alphabetically by name. """ - product_dict = { - product.name: product.product_id for product in self.db.products.values() - } + product_dict = {product.name: product.product_id for product in self.db.products.values()} return json.dumps(product_dict, sort_keys=True) @is_tool(ToolType.WRITE) @@ -461,9 +453,7 @@ def modify_pending_order_items( diff_price = 0 for item_id, new_item_id in zip(item_ids, new_item_ids): if item_id == new_item_id: - raise ValueError( - "The new item id should be different from the old item id" - ) + raise ValueError("The new item id should be different from the old item id") item = next((item for item in order.items if item.item_id == item_id), None) if item is None: raise ValueError(f"Item {item_id} not found") @@ -538,17 +528,12 @@ def modify_pending_order_payment( payment_method = self._get_payment_method(order.user_id, payment_method_id) # Check that the payment history should only have one payment - if ( - len(order.payment_history) != 1 - or order.payment_history[0].transaction_type != "payment" - ): + if len(order.payment_history) != 1 or order.payment_history[0].transaction_type != "payment": raise ValueError("There should be exactly one payment for a pending order") # Check that the payment method is different if order.payment_history[0].payment_method_id == payment_method_id: - raise ValueError( - "The new payment method should be different from the current one" - ) + raise ValueError("The new payment method should be different from the current one") amount = order.payment_history[0].amount @@ -578,9 +563,7 @@ def modify_pending_order_payment( payment_method.balance = round(payment_method.balance, 2) # If refund is made to a gift card, update the balance - old_payment_method = self._get_payment_method( - order.user_id, order.payment_history[0].payment_method_id - ) + old_payment_method = self._get_payment_method(order.user_id, order.payment_history[0].payment_method_id) if isinstance(old_payment_method, GiftCard): old_payment_method.balance += amount old_payment_method.balance = round(old_payment_method.balance, 2) diff --git a/vendor/tau2/domains/telecom/data_model.py b/vendor/tau2/domains/telecom/data_model.py index b5dea830..ffc8b1a0 100644 --- a/vendor/tau2/domains/telecom/data_model.py +++ b/vendor/tau2/domains/telecom/data_model.py @@ -23,9 +23,7 @@ class Plan(BaseModelNoExtra): name: str = Field(description="Display name of the plan") data_limit_gb: float = Field(description="Monthly data allowance in gigabytes (GB)") price_per_month: float = Field(description="Monthly price of the plan in USD") - data_refueling_price_per_gb: float = Field( - description="Price per gigabyte for data refueling" - ) + data_refueling_price_per_gb: float = Field(description="Price per gigabyte for data refueling") class DeviceType(str, Enum): @@ -40,15 +38,9 @@ class Device(BaseModelNoExtra): device_id: str = Field(description="Unique identifier for the device") device_type: DeviceType = Field(description="Type/category of the device") model: str = Field(description="Model name/number of the device") - imei: Optional[str] = Field( - None, description="International Mobile Equipment Identity number" - ) - is_esim_capable: bool = Field( - description="Whether the device supports eSIM technology" - ) - activated: bool = Field( - False, description="Whether the device has been activated on the network" - ) + imei: Optional[str] = Field(None, description="International Mobile Equipment Identity number") + is_esim_capable: bool = Field(description="Whether the device supports eSIM technology") + activated: bool = Field(False, description="Whether the device has been activated on the network") activation_date: Optional[datetime.datetime] = Field( None, description="Date and time when the device was activated (format: YYYY-MM-DDTHH:MM:SS, timezone: EST)", @@ -69,22 +61,12 @@ class LineStatus(str, Enum): class Line(BaseModelNoExtra): line_id: str = Field(description="Unique identifier for the line") phone_number: str = Field(description="Phone number associated with the line") - status: LineStatus = Field( - LineStatus.PENDING_ACTIVATION, description="Current status of the line" - ) + status: LineStatus = Field(LineStatus.PENDING_ACTIVATION, description="Current status of the line") plan_id: str = Field(description="Plan associated with this line") - device_id: Optional[str] = Field( - None, description="Device associated with this line" - ) - data_used_gb: float = Field( - 0.0, description="Data used in the current billing cycle in gigabytes (GB)" - ) - data_refueling_gb: float = Field( - 0.0, description="Data refueled in the current billing cycle in gigabytes (GB)" - ) - roaming_enabled: bool = Field( - False, description="Whether international roaming is enabled for this line" - ) + device_id: Optional[str] = Field(None, description="Device associated with this line") + data_used_gb: float = Field(0.0, description="Data used in the current billing cycle in gigabytes (GB)") + data_refueling_gb: float = Field(0.0, description="Data refueled in the current billing cycle in gigabytes (GB)") + roaming_enabled: bool = Field(False, description="Whether international roaming is enabled for this line") contract_end_date: Optional[datetime.date] = Field( None, description="End date of the current contract, if applicable (format: YYYY-MM-DD, timezone: EST)", @@ -105,15 +87,9 @@ class Line(BaseModelNoExtra): class LineItem(BaseModelNoExtra): description: str = Field(description="Descriptive text for the line item") - amount: float = Field( - description="Monetary amount in USD (positive for charges, negative for credits)" - ) - date: datetime.date = Field( - description="Date the line item was applied (format: YYYY-MM-DD, timezone: EST)" - ) - item_type: str = Field( - description="Category of the line item (e.g., Plan Charge, Overage, Fee, Credit, Payment)" - ) + amount: float = Field(description="Monetary amount in USD (positive for charges, negative for credits)") + date: datetime.date = Field(description="Date the line item was applied (format: YYYY-MM-DD, timezone: EST)") + item_type: str = Field(description="Category of the line item (e.g., Plan Charge, Overage, Fee, Credit, Payment)") class BillStatus(str, Enum): @@ -131,23 +107,17 @@ class Bill(BaseModelNoExtra): period_start: datetime.date = Field( description="Start date of the billing period (format: YYYY-MM-DD, timezone: EST)" ) - period_end: datetime.date = Field( - description="End date of the billing period (format: YYYY-MM-DD, timezone: EST)" - ) + period_end: datetime.date = Field(description="End date of the billing period (format: YYYY-MM-DD, timezone: EST)") issue_date: datetime.date = Field( description="Date the bill was issued/generated (format: YYYY-MM-DD, timezone: EST)" ) total_due: float = Field(description="Total amount due in USD") - due_date: datetime.date = Field( - description="Date by which payment is due (format: YYYY-MM-DD, timezone: EST)" - ) + due_date: datetime.date = Field(description="Date by which payment is due (format: YYYY-MM-DD, timezone: EST)") line_items: List[LineItem] = Field( default_factory=list, description="Individual charges, credits, and payments on this bill", ) - status: BillStatus = Field( - BillStatus.DRAFT, description="Current status of the bill" - ) + status: BillStatus = Field(BillStatus.DRAFT, description="Current status of the bill") class AccountStatus(str, Enum): @@ -165,20 +135,14 @@ class PaymentMethodType(str, Enum): class PaymentMethod(BaseModelNoExtra): method_type: PaymentMethodType = Field(description="Type of payment method") - account_number_last_4: str = Field( - description="Last 4 digits of the account number" - ) - expiration_date: str = Field( - description="The expiration date of the payment method in the format MM/YYYY" - ) + account_number_last_4: str = Field(description="Last 4 digits of the account number") + expiration_date: str = Field(description="The expiration date of the payment method in the format MM/YYYY") class Customer(BaseModelNoExtra): customer_id: str = Field(description="Unique identifier for the customer") full_name: str = Field(description="Customer's full name") - date_of_birth: str = Field( - description="Customer's date of birth for identity verification (format: YYYY-MM-DD)" - ) + date_of_birth: str = Field(description="Customer's date of birth for identity verification (format: YYYY-MM-DD)") email: str = Field(description="Customer's email address") phone_number: str = Field(description="Customer's primary contact phone number") address: Address = Field(description="Customer's billing address") @@ -189,12 +153,8 @@ class Customer(BaseModelNoExtra): payment_methods: List[PaymentMethod] = Field( default_factory=list, description="Stored payment methods for this customer" ) - line_ids: List[str] = Field( - default_factory=list, description="Phone/data lines owned by this customer" - ) - bill_ids: List[str] = Field( - default_factory=list, description="Bills associated with this customer" - ) + line_ids: List[str] = Field(default_factory=list, description="Phone/data lines owned by this customer") + bill_ids: List[str] = Field(default_factory=list, description="Bills associated with this customer") created_at: datetime.datetime = Field( DEFAULT_START_DATE, description="Date and time when the customer account was created (format: YYYY-MM-DDTHH:MM:SS, timezone: EST)", @@ -211,21 +171,11 @@ class Customer(BaseModelNoExtra): class TelecomDB(DB): """Database interface for telecom domain.""" - plans: List[Plan] = Field( - default_factory=list, description="Available service plans" - ) - customers: List[Customer] = Field( - default_factory=list, description="All customers in the system" - ) - lines: List[Line] = Field( - default_factory=list, description="All lines in the system" - ) - bills: List[Bill] = Field( - default_factory=list, description="All bills in the system" - ) - devices: List[Device] = Field( - default_factory=list, description="All devices in the system" - ) + plans: List[Plan] = Field(default_factory=list, description="Available service plans") + customers: List[Customer] = Field(default_factory=list, description="All customers in the system") + lines: List[Line] = Field(default_factory=list, description="All lines in the system") + bills: List[Bill] = Field(default_factory=list, description="All bills in the system") + devices: List[Device] = Field(default_factory=list, description="All devices in the system") def get_statistics(self) -> Dict[str, Any]: """Get the statistics of the database.""" @@ -234,9 +184,7 @@ def get_statistics(self) -> Dict[str, Any]: num_lines = len(self.lines) num_bills = len(self.bills) num_devices = len(self.devices) - num_payment_methods = sum( - len(customer.payment_methods) for customer in self.customers - ) + num_payment_methods = sum(len(customer.payment_methods) for customer in self.customers) return { "num_plans": num_plans, diff --git a/vendor/tau2/domains/telecom/environment.py b/vendor/tau2/domains/telecom/environment.py index 00ab1b9e..5dcbaa79 100644 --- a/vendor/tau2/domains/telecom/environment.py +++ b/vendor/tau2/domains/telecom/environment.py @@ -47,9 +47,7 @@ def sync_tools(self): phone_number = self.user_tools.db.surroundings.phone_number line = self.tools._get_line_by_phone(phone_number) if line is None: - raise ValueError( - f"Wrong scenario, line not found for phone number: {phone_number}" - ) + raise ValueError(f"Wrong scenario, line not found for phone number: {phone_number}") # Check if the line is active if line.status == LineStatus.ACTIVE: self.user_tools.db.surroundings.line_active = True @@ -65,9 +63,7 @@ def sync_tools(self): # Check if the user has exceeded their data usage limit plan = self.tools._get_plan_by_id(line.plan_id) if plan is None: - raise ValueError( - f"Wrong scenario, invalid plan id ({line.plan_id}) for the phone number {phone_number}" - ) + raise ValueError(f"Wrong scenario, invalid plan id ({line.plan_id}) for the phone number {phone_number}") if line.data_used_gb >= plan.data_limit_gb + line.data_refueling_gb: self.user_tools.db.surroundings.mobile_data_usage_exceeded = True else: @@ -82,9 +78,7 @@ def sync_tools(self): # Check if the user has a payment request current_payment_request = self.user_tools.db.surroundings.payment_request - if ( - current_payment_request is None - ): # If there already is a payment request, do nothing + if current_payment_request is None: # If there already is a payment request, do nothing customer = self.tools.get_customer_by_phone(phone_number) bills = self.tools._get_bills_awaiting_payment(customer) if len(bills) != 0: diff --git a/vendor/tau2/domains/telecom/tasks/const.py b/vendor/tau2/domains/telecom/tasks/const.py index 8c41d7dc..555d935d 100644 --- a/vendor/tau2/domains/telecom/tasks/const.py +++ b/vendor/tau2/domains/telecom/tasks/const.py @@ -1,6 +1,6 @@ TOOL_CALL_INFO_CHECK = "If the tool call does not return updated status information, you might need to perform another tool call to get the updated status." TOOL_CALL_GROUNDING = """ -Whenever the agent asks you about your device, always ground your responses on the results of tool calls. +Whenever the agent asks you about your device, always ground your responses on the results of tool calls. For example: If the agent asks what the status bar shows, always ground your response on the results of the `get_status_bar` tool call. If the agent asks if you are able to send an MMS message, always ground your response on the results of the `can_send_mms` tool call. Never make up the results of tool calls, always ground your responses on the results of tool calls. If you are unsure about whether an action is necessary, always ask the agent for clarification. diff --git a/vendor/tau2/domains/telecom/tasks/mms_issues.py b/vendor/tau2/domains/telecom/tasks/mms_issues.py index 13319476..313ce272 100644 --- a/vendor/tau2/domains/telecom/tasks/mms_issues.py +++ b/vendor/tau2/domains/telecom/tasks/mms_issues.py @@ -92,9 +92,7 @@ def break_apn_mms_setting(*args, **kwargs) -> list[EnvFunctionCall]: ] -def _get_remove_app_permission_actions( - app_name: str = "messaging", permission: str = "sms" -): +def _get_remove_app_permission_actions(app_name: str = "messaging", permission: str = "sms"): """ Get the remove app permission actions for the mms issue task. """ @@ -116,9 +114,7 @@ def break_app_storage_permission(*args, **kwargs) -> list[EnvFunctionCall]: """ Break the app storage permission for the mms issue task. """ - return [ - _get_remove_app_permission_actions(app_name="messaging", permission="storage") - ] + return [_get_remove_app_permission_actions(app_name="messaging", permission="storage")] def break_app_both_permissions(*args, **kwargs) -> list[EnvFunctionCall]: @@ -163,9 +159,7 @@ def fix_break_apn_mms_setting(*args, **kwargs) -> list[ToolCall]: ] -def _get_grant_app_permission_actions( - app_name: str = "messaging", permission: str = "sms" -) -> ToolCall: +def _get_grant_app_permission_actions(app_name: str = "messaging", permission: str = "sms") -> ToolCall: """ Get the grant app permission actions for the mms issue task. """ @@ -187,9 +181,7 @@ def fix_break_app_storage_permission(*args, **kwargs) -> list[ToolCall]: """ Fix the break app storage permission issue. """ - return [ - _get_grant_app_permission_actions(app_name="messaging", permission="storage") - ] + return [_get_grant_app_permission_actions(app_name="messaging", permission="storage")] def fix_break_app_both_permissions(*args, **kwargs) -> list[ToolCall]: @@ -277,11 +269,7 @@ def fix_break_app_both_permissions(*args, **kwargs) -> list[ToolCall]: app_permission_issues, # Step3.5 ] -selection_sets = ( - service_issues_sample_sets - + mobile_data_issues_sample_sets - + mms_issues_selection_sets -) +selection_sets = service_issues_sample_sets + mobile_data_issues_sample_sets + mms_issues_selection_sets def task_validator(tasks: list[Optional[BaseTask]]): @@ -304,9 +292,7 @@ def task_validator(tasks: list[Optional[BaseTask]]): num_tasks_mms_issues = len( [ task - for task in tasks[ - len(service_issues_sample_sets) + len(mobile_data_issues_sample_sets) : - ] + for task in tasks[len(service_issues_sample_sets) + len(mobile_data_issues_sample_sets) :] if task is not None ] ) diff --git a/vendor/tau2/domains/telecom/tasks/mobile_data_issues.py b/vendor/tau2/domains/telecom/tasks/mobile_data_issues.py index 8e1caa70..b5405d44 100644 --- a/vendor/tau2/domains/telecom/tasks/mobile_data_issues.py +++ b/vendor/tau2/domains/telecom/tasks/mobile_data_issues.py @@ -485,9 +485,7 @@ def assert_data_refueling_amount(env: TelecomEnvironment) -> list[EnvAssertion]: # Path 2.2: Slow Mobile Data # Requires workflow Step 2.2.1 -data_usage_exceeded_issues = SelectionSet( - tasks=[data_usage_exceeded_task, data_usage_exceeded_no_refuel_task] -) +data_usage_exceeded_issues = SelectionSet(tasks=[data_usage_exceeded_task, data_usage_exceeded_no_refuel_task]) # Requires workflow Step 2.2.2 data_saver_mode_issues = SelectionSet(tasks=[data_saver_mode_on_task]) @@ -518,9 +516,7 @@ def task_validator(tasks: list[Optional[BaseTask]]): # num_tasks_service_issues = len( # [task for task in tasks[: len(service_issues_sample_sets)] if task is not None] # ) - num_tasks_mobile_data_issues = len( - [task for task in tasks[len(service_issues_sample_sets) :] if task is not None] - ) + num_tasks_mobile_data_issues = len([task for task in tasks[len(service_issues_sample_sets) :] if task is not None]) return num_tasks_mobile_data_issues > 0 diff --git a/vendor/tau2/domains/telecom/tasks/utils.py b/vendor/tau2/domains/telecom/tasks/utils.py index 73c22b37..360660e5 100644 --- a/vendor/tau2/domains/telecom/tasks/utils.py +++ b/vendor/tau2/domains/telecom/tasks/utils.py @@ -66,9 +66,7 @@ def compose_tasks( Return all the combinations of selecting 0 or more tasks from the selection sets """ - product_tasks = list( - product(*[selection_set.tasks + [None] for selection_set in selection_sets]) - ) + product_tasks = list(product(*[selection_set.tasks + [None] for selection_set in selection_sets])) composed_tasks = [] for tasks in product_tasks: if task_validator is not None: diff --git a/vendor/tau2/domains/telecom/tools.py b/vendor/tau2/domains/telecom/tools.py index 0a7fd05d..c092d35e 100644 --- a/vendor/tau2/domains/telecom/tools.py +++ b/vendor/tau2/domains/telecom/tools.py @@ -102,10 +102,7 @@ def get_customer_by_name(self, full_name: str, dob: str) -> List[Customer]: matching_customers = [] for customer in self.db.customers: - if ( - customer.full_name.lower() == full_name.lower() - and customer.date_of_birth == dob - ): + if customer.full_name.lower() == full_name.lower() and customer.date_of_birth == dob: matching_customers.append(customer) return matching_customers @@ -259,9 +256,7 @@ def get_details_by_id(self, id: str) -> Dict[str, Any]: raise ValueError(f"Unknown ID format or type: {id}") @is_tool(ToolType.WRITE) - def suspend_line( - self, customer_id: str, line_id: str, reason: str - ) -> Dict[str, Any]: + def suspend_line(self, customer_id: str, line_id: str, reason: str) -> Dict[str, Any]: """ Suspends a specific line (max 6 months). Checks: Line status must be Active. @@ -411,9 +406,7 @@ def _set_bill_to_paid(self, bill_id: str) -> None: bill.status = BillStatus.PAID return f"Bill {bill_id} set to paid" - def _apply_one_time_charge( - self, customer_id: str, amount: float, description: str - ) -> None: + def _apply_one_time_charge(self, customer_id: str, amount: float, description: str) -> None: """ Internal function to add a specific charge LineItem to the customer's next bill. Creates a pending bill if none exists. @@ -453,11 +446,7 @@ def _apply_one_time_charge( period_start=next_month, period_end=next_month.replace( month=next_month.month + 1 if next_month.month < 12 else 1, - year=( - next_month.year - if next_month.month < 12 - else next_month.year + 1 - ), + year=(next_month.year if next_month.month < 12 else next_month.year + 1), ) - timedelta(days=1), issue_date=next_month, @@ -501,9 +490,7 @@ def get_data_usage(self, customer_id: str, line_id: str) -> Dict[str, Any]: plan = self._get_plan_by_id(target_line.plan_id) today = get_today() - cycle_end_date = date( - today.year, today.month + 1 if today.month < 12 else 1, 1 - ) - timedelta(days=1) + cycle_end_date = date(today.year, today.month + 1 if today.month < 12 else 1, 1) - timedelta(days=1) return { "line_id": line_id, @@ -513,9 +500,7 @@ def get_data_usage(self, customer_id: str, line_id: str) -> Dict[str, Any]: "cycle_end_date": cycle_end_date, } - def set_data_usage( - self, customer_id: str, line_id: str, data_used_gb: float - ) -> str: + def set_data_usage(self, customer_id: str, line_id: str, data_used_gb: float) -> str: """ Sets the data usage for a line. Note: This method is not decorated as a tool but follows similar error handling. @@ -605,9 +590,7 @@ def transfer_to_human_agents(self, summary: str) -> str: return "Transfer successful" @is_tool(ToolType.WRITE) - def refuel_data( - self, customer_id: str, line_id: str, gb_amount: float - ) -> Dict[str, Any]: + def refuel_data(self, customer_id: str, line_id: str, gb_amount: float) -> Dict[str, Any]: """ Refuels data for a specific line, adding to the customer's bill. Checks: Line status must be Active, Customer owns the line. @@ -646,9 +629,7 @@ def refuel_data( f"Data refueling: {gb_amount} GB at ${plan.data_refueling_price_per_gb}/GB", ) - logger.info( - f"Data refueled for line {line_id}: {gb_amount} GB added, charge: ${charge_amount:.2f}" - ) + logger.info(f"Data refueled for line {line_id}: {gb_amount} GB added, charge: ${charge_amount:.2f}") return { "message": f"Successfully added {gb_amount} GB of data for line {line_id} for ${charge_amount:.2f}", @@ -721,27 +702,21 @@ def suspend_line_for_overdue_bill( return f"Line {line_id} suspended for unpaid bill {new_bill_id}. Contract ended: {contract_ended}" ### Assertions - def assert_data_refueling_amount( - self, customer_id: str, line_id: str, expected_amount: float - ) -> bool: + def assert_data_refueling_amount(self, customer_id: str, line_id: str, expected_amount: float) -> bool: """ Assert that the data refueling amount is as expected. """ target_line = self._get_target_line(customer_id, line_id) return abs(target_line.data_refueling_gb - expected_amount) < 1e-6 - def assert_line_status( - self, customer_id: str, line_id: str, expected_status: LineStatus - ) -> bool: + def assert_line_status(self, customer_id: str, line_id: str, expected_status: LineStatus) -> bool: """ Assert that the line status is as expected. """ target_line = self._get_target_line(customer_id, line_id) return target_line.status == expected_status - def assert_overdue_bill_exists( - self, customer_id: str, overdue_bill_id: str - ) -> bool: + def assert_overdue_bill_exists(self, customer_id: str, overdue_bill_id: str) -> bool: """ Assert that the overdue bill exists. """ diff --git a/vendor/tau2/domains/telecom/user_data_model.py b/vendor/tau2/domains/telecom/user_data_model.py index 77d838f9..218bbf93 100644 --- a/vendor/tau2/domains/telecom/user_data_model.py +++ b/vendor/tau2/domains/telecom/user_data_model.py @@ -99,12 +99,8 @@ def is_mms_basic_configured(self) -> bool: class VpnDetails(BaseModelNoExtra): """Holds details about the VPN connection if active.""" - server_address: Optional[str] = Field( - None, description="Address of the connected VPN server." - ) - protocol: Optional[str] = Field( - None, description="VPN protocol being used (e.g., WireGuard, OpenVPN)." - ) + server_address: Optional[str] = Field(None, description="Address of the connected VPN server.") + protocol: Optional[str] = Field(None, description="VPN protocol being used (e.g., WireGuard, OpenVPN).") server_performance: PerformanceLevel = Field( default=PerformanceLevel.UNKNOWN, validate_default=True, @@ -118,9 +114,7 @@ class AppPermissions(BaseModelNoExtra): sms: bool = Field(False, description="Permission to send/read SMS/MMS.") storage: bool = Field(False, description="Permission to access device storage.") phone: bool = Field(False, description="Permission to make/manage phone calls.") - network: bool = Field( - False, description="Permission to access network state/internet." - ) + network: bool = Field(False, description="Permission to access network state/internet.") class AppStatus(BaseModelNoExtra): @@ -146,22 +140,14 @@ class StatusBar(BaseModelNoExtra): validate_default=True, description="The network technology (2G, 3G, 4G, etc.) shown in the status bar.", ) - wifi_connected: bool = Field( - False, description="Whether WiFi is connected and shown in the status bar." - ) - airplane_mode: bool = Field( - False, description="Whether airplane mode is on and shown in the status bar." - ) - vpn_active: bool = Field( - False, description="Whether a VPN is active and shown in the status bar." - ) + wifi_connected: bool = Field(False, description="Whether WiFi is connected and shown in the status bar.") + airplane_mode: bool = Field(False, description="Whether airplane mode is on and shown in the status bar.") + vpn_active: bool = Field(False, description="Whether a VPN is active and shown in the status bar.") data_saver_active: bool = Field( False, description="Whether data saver mode is active and shown in the status bar.", ) - battery_level: int = Field( - 100, description="The battery level (0-100) shown in the status bar." - ) + battery_level: int = Field(100, description="The battery level (0-100) shown in the status bar.") # --- Main Device State Model --- @@ -201,9 +187,7 @@ class MockPhoneAttributes(BaseModelNoExtra): ) # --- Battery --- - battery_level: int = Field( - 80, description="The current battery level, from 0 to 100 percent." - ) + battery_level: int = Field(80, description="The current battery level, from 0 to 100 percent.") # --- Mobile Data --- data_enabled: bool = Field( @@ -230,9 +214,7 @@ class MockPhoneAttributes(BaseModelNoExtra): False, description="Whether the device is currently connected to a Wi-Fi network.", ) - wifi_ssid: Optional[str] = Field( - None, description="The name (SSID) of the connected Wi-Fi network, if any." - ) + wifi_ssid: Optional[str] = Field(None, description="The name (SSID) of the connected Wi-Fi network, if any.") wifi_signal_strength: SignalStrength = Field( default=SignalStrength.NONE, validate_default=True, @@ -240,9 +222,7 @@ class MockPhoneAttributes(BaseModelNoExtra): ) # --- Calling Features --- - wifi_calling_enabled: bool = Field( - False, description="Whether the Wi-Fi Calling feature is enabled." - ) + wifi_calling_enabled: bool = Field(False, description="Whether the Wi-Fi Calling feature is enabled.") wifi_calling_mms_over_wifi: bool = Field( False, description="Preference/capability to send/receive MMS over Wi-Fi (depends on carrier and device support).", @@ -259,9 +239,7 @@ class MockPhoneAttributes(BaseModelNoExtra): False, description="Whether a VPN profile is configured and potentially set to be 'always on' or manually enabled in settings.", ) - vpn_connected: bool = Field( - False, description="Whether there currently is an active VPN connection tunnel." - ) + vpn_connected: bool = Field(False, description="Whether there currently is an active VPN connection tunnel.") vpn_details: Optional[VpnDetails] = Field( None, description="Details about the active VPN connection, if connected." ) @@ -321,13 +299,9 @@ class UserSurroundings(BaseModelNoExtra): """Represents the physical surroundings of the user.""" name: Optional[str] = Field(None, description="The name of the user.") - phone_number: Optional[str] = Field( - None, description="The phone number of the user." - ) + phone_number: Optional[str] = Field(None, description="The phone number of the user.") is_abroad: bool = Field(False, description="Whether the user is currently abroad.") - roaming_allowed: bool = Field( - False, description="Whether the user is allowed to roam." - ) + roaming_allowed: bool = Field(False, description="Whether the user is allowed to roam.") signal_strength: dict[NetworkTechnology, SignalStrength] = Field( default_factory=lambda: { NetworkTechnology.TWO_G: SignalStrength.POOR, @@ -341,17 +315,13 @@ class UserSurroundings(BaseModelNoExtra): False, description="Whether the user has exceeded their data usage limit." ) line_active: bool = Field(True, description="Whether the user has an active line.") - payment_request: Optional[PaymentRequest] = Field( - None, description="The payment that the agent has requested." - ) + payment_request: Optional[PaymentRequest] = Field(None, description="The payment that the agent has requested.") class TelecomUserDB(DB): """Database interface for telecom domain.""" - device: MockPhoneAttributes = Field( - default_factory=MockPhoneAttributes, description="Mock phone device" - ) + device: MockPhoneAttributes = Field(default_factory=MockPhoneAttributes, description="Mock phone device") surroundings: UserSurroundings = Field( default_factory=UserSurroundings, description="User's physical surroundings" ) @@ -381,24 +351,16 @@ def main(): print("\n--- State after enabling Airplane Mode ---") print(f"Airplane Mode: {db.device.airplane_mode}") print(f"Network Status: {db.device.network_connection_status}") - print( - f"Helper - Potentially Online Mobile: {db.device.is_potentially_online_mobile()}" - ) + print(f"Helper - Potentially Online Mobile: {db.device.is_potentially_online_mobile()}") # 3. Simulate another problem: User disables Mobile Data and has wrong APN MMS URL # Start from default state again for clarity db = TelecomUserDB() update_2 = { "data_enabled": False, - "active_apn_settings": { # Update nested model - "mmsc_url": None # Simulate missing MMS config - }, + "active_apn_settings": {"mmsc_url": None}, # Update nested model # Simulate missing MMS config "app_statuses": { # Update nested dictionary/model - "messaging": { - "permissions": { - "storage": False - } # Update nested AppPermissions model field - } + "messaging": {"permissions": {"storage": False}} # Update nested AppPermissions model field }, } db.update_device(update_2) diff --git a/vendor/tau2/domains/telecom/user_tools.py b/vendor/tau2/domains/telecom/user_tools.py index aedf9fc4..89cedceb 100644 --- a/vendor/tau2/domains/telecom/user_tools.py +++ b/vendor/tau2/domains/telecom/user_tools.py @@ -28,9 +28,7 @@ class TelecomUserTools(ToolKitBase): db: TelecomUserDB - network_mode_preference: NetworkModePreference = ( - NetworkModePreference.FOUR_G_5G_PREFERRED - ) + network_mode_preference: NetworkModePreference = NetworkModePreference.FOUR_G_5G_PREFERRED default_vpn_details: VpnDetails = VpnDetails( server_address="192.168.1.1", @@ -100,19 +98,14 @@ def _check_status_bar(self) -> str: SignalStrength.GOOD: "๐Ÿ“ถยณ Good", SignalStrength.EXCELLENT: "๐Ÿ“ถโด Excellent", } - indicators.append( - signal_map.get(device.network_signal_strength, "๐Ÿ“ต No Signal") - ) + indicators.append(signal_map.get(device.network_signal_strength, "๐Ÿ“ต No Signal")) # Network technology if device.network_technology_connected != NetworkTechnology.NONE: indicators.append(device.network_technology_connected.value) # Data enabled indicator - if ( - device.data_enabled - and device.network_technology_connected != NetworkTechnology.NONE - ): + if device.data_enabled and device.network_technology_connected != NetworkTechnology.NONE: indicators.append("๐Ÿ“ฑ Data Enabled") if device.data_saver_mode: indicators.append("๐Ÿ”ฝ Data Saver") @@ -186,9 +179,7 @@ def _check_network_mode_preference(self) -> NetworkModePreference: return self.device.network_mode_preference @is_tool(ToolType.WRITE) - def set_network_mode_preference( - self, mode: Union[NetworkModePreference, str] - ) -> str: + def set_network_mode_preference(self, mode: Union[NetworkModePreference, str]) -> str: """Changes the type of cellular network your phone prefers to connect to (e.g., 5G, LTE/4G, 3G). Higher-speed networks (LTE/5G) provide faster data but may use more battery.""" valid_mode = self._set_network_mode_preference(mode) if valid_mode is None: @@ -196,9 +187,7 @@ def set_network_mode_preference( status_update = f"Preferred Network Mode set to: {valid_mode.value}" return f"{status_update}\nStatus Bar: {self._check_status_bar()}" - def _set_network_mode_preference( - self, mode: Union[NetworkModePreference, str] - ) -> Optional[NetworkModePreference]: + def _set_network_mode_preference(self, mode: Union[NetworkModePreference, str]) -> Optional[NetworkModePreference]: """Sets the preferred network mode. This will trigger a network search. """ @@ -222,10 +211,7 @@ def _get_mobile_data_working(self) -> bool: - Data is not enabled - Data usage is exceeded """ - if ( - self.device.airplane_mode - or self.device.network_signal_strength == SignalStrength.NONE - ): + if self.device.airplane_mode or self.device.network_signal_strength == SignalStrength.NONE: return False if self.device.network_connection_status == NetworkStatus.NO_SERVICE: @@ -255,9 +241,7 @@ def run_speed_test(self) -> str: if description == "Very Poor": advice = "Connection is very slow. Basic web browsing might be difficult." elif description == "Poor": - advice = ( - "Connection is slow. Web browsing may be sluggish, streaming difficult." - ) + advice = "Connection is slow. Web browsing may be sluggish, streaming difficult." elif description == "Fair": advice = "Connection is okay for web browsing and some standard definition streaming." elif description == "Good": @@ -328,9 +312,7 @@ def _run_speed_test(self) -> Tuple[Optional[float], Optional[str]]: NetworkTechnology.FIVE_G: (50.0, 500.0), NetworkTechnology.NONE: (0.0, 0.0), } - min_speed, max_speed = tech_speed_map.get( - self.device.network_technology_connected, (0.0, 0.0) - ) + min_speed, max_speed = tech_speed_map.get(self.device.network_technology_connected, (0.0, 0.0)) # Adjust speed based on signal strength signal_factor_map = { @@ -343,9 +325,7 @@ def _run_speed_test(self) -> Tuple[Optional[float], Optional[str]]: signal_factor = signal_factor_map.get(self.device.network_signal_strength, 0.0) # Calculate simulated speed - simulated_speed = ( - (min_speed + max_speed) / 2.0 * signal_factor * base_speed_factor - ) + simulated_speed = (min_speed + max_speed) / 2.0 * signal_factor * base_speed_factor simulated_speed = round(simulated_speed, 2) # Determine description @@ -634,7 +614,9 @@ def check_wifi_status(self) -> str: if not status["enabled"]: return "Wi-Fi is turned OFF." if status["connected"]: - return f"Wi-Fi is ON and connected to '{status['ssid']}'. Signal strength: {status['signal_strength'].value}." + return ( + f"Wi-Fi is ON and connected to '{status['ssid']}'. Signal strength: {status['signal_strength'].value}." + ) else: return "Wi-Fi is ON but not connected to any network." @@ -702,9 +684,7 @@ def _toggle_wifi_calling(self) -> bool: self.device.wifi_calling_enabled = new_state return new_state - def set_wifi_calling( - self, enabled: bool, mms_over_wifi: Optional[bool] = None - ) -> str: + def set_wifi_calling(self, enabled: bool, mms_over_wifi: Optional[bool] = None) -> str: """Set the Wi-Fi Calling setting. Set MMS over WIFI accordingly if provided.""" if self.device.wifi_calling_enabled != enabled: self._toggle_wifi_calling() @@ -736,9 +716,7 @@ def _check_vpn_status(self) -> Dict[str, Any]: "enabled_setting": self.device.vpn_enabled_setting, "connected": self.device.vpn_connected, "details": ( - self.device.vpn_details.model_dump() - if self.device.vpn_details and self.device.vpn_connected - else None + self.device.vpn_details.model_dump() if self.device.vpn_details and self.device.vpn_connected else None ), } @@ -748,11 +726,7 @@ def connect_vpn(self) -> str: connected = self._connect_vpn() if connected is None: return "VPN already connected." - status_update = ( - "VPN connected successfully." - if connected - else "No VPN connection to connect." - ) + status_update = "VPN connected successfully." if connected else "No VPN connection to connect." return f"{status_update}\nStatus Bar: {self._check_status_bar()}" def _connect_vpn(self) -> Optional[bool]: @@ -769,11 +743,7 @@ def _connect_vpn(self) -> Optional[bool]: def disconnect_vpn(self) -> str: """Disconnects any active VPN (Virtual Private Network) connection. Stops routing your internet traffic through a VPN server, which might affect connection speed or access to content.""" disconnected = self._disconnect_vpn() - status_update = ( - "VPN disconnected successfully." - if disconnected - else "No active VPN connection to disconnect." - ) + status_update = "VPN disconnected successfully." if disconnected else "No active VPN connection to disconnect." return f"{status_update}\nStatus Bar: {self._check_status_bar()}" def _disconnect_vpn(self) -> bool: @@ -975,46 +945,34 @@ def simulate_network_search(self): self.device.network_connection_status = NetworkStatus.CONNECTED pref = self.device.network_mode_preference if pref == NetworkModePreference.FOUR_G_5G_PREFERRED: - five_g_signal = self.surroundings.signal_strength.get( - NetworkTechnology.FIVE_G, SignalStrength.NONE - ) + five_g_signal = self.surroundings.signal_strength.get(NetworkTechnology.FIVE_G, SignalStrength.NONE) if five_g_signal == SignalStrength.NONE: self.device.network_technology_connected = NetworkTechnology.FOUR_G - self.device.network_signal_strength = ( - self.surroundings.signal_strength.get( - NetworkTechnology.FOUR_G, SignalStrength.NONE - ) + self.device.network_signal_strength = self.surroundings.signal_strength.get( + NetworkTechnology.FOUR_G, SignalStrength.NONE ) else: self.device.network_technology_connected = NetworkTechnology.FIVE_G self.device.network_signal_strength = five_g_signal elif pref == NetworkModePreference.FOUR_G_ONLY: self.device.network_technology_connected = NetworkTechnology.FOUR_G - self.device.network_signal_strength = ( - self.surroundings.signal_strength.get( - NetworkTechnology.FOUR_G, SignalStrength.NONE - ) + self.device.network_signal_strength = self.surroundings.signal_strength.get( + NetworkTechnology.FOUR_G, SignalStrength.NONE ) elif pref == NetworkModePreference.THREE_G_ONLY: self.device.network_technology_connected = NetworkTechnology.THREE_G - self.device.network_signal_strength = ( - self.surroundings.signal_strength.get( - NetworkTechnology.THREE_G, SignalStrength.NONE - ) + self.device.network_signal_strength = self.surroundings.signal_strength.get( + NetworkTechnology.THREE_G, SignalStrength.NONE ) elif pref == NetworkModePreference.TWO_G_ONLY: self.device.network_technology_connected = NetworkTechnology.TWO_G - self.device.network_signal_strength = ( - self.surroundings.signal_strength.get( - NetworkTechnology.TWO_G, SignalStrength.NONE - ) + self.device.network_signal_strength = self.surroundings.signal_strength.get( + NetworkTechnology.TWO_G, SignalStrength.NONE ) else: # Default fallback self.device.network_technology_connected = NetworkTechnology.FOUR_G - self.device.network_signal_strength = ( - self.surroundings.signal_strength.get( - NetworkTechnology.FOUR_G, SignalStrength.NONE - ) + self.device.network_signal_strength = self.surroundings.signal_strength.get( + NetworkTechnology.FOUR_G, SignalStrength.NONE ) elif sim_status in [SimStatus.MISSING]: @@ -1120,9 +1078,7 @@ def assert_mobile_data_saver_mode_status(self, expected_status: bool) -> bool: """ return self.device.data_saver_mode == expected_status - def assert_internet_speed( - self, expected_speed: float, expected_desc: Optional[str] = None - ) -> bool: + def assert_internet_speed(self, expected_speed: float, expected_desc: Optional[str] = None) -> bool: """ Assert that the internet speed is as expected. """ diff --git a/vendor/tau2/environment/server.py b/vendor/tau2/environment/server.py index 8faddd13..cce3fb2b 100644 --- a/vendor/tau2/environment/server.py +++ b/vendor/tau2/environment/server.py @@ -85,7 +85,8 @@ def _format_description(self, policy: str) -> str: description.append(content) # Add the tools section - description.append(""" + description.append( + """ ## Tools @@ -98,7 +99,8 @@ def _format_description(self, policy: str) -> str: ### Response Format All successful responses will return the tool's output directly. Errors will return a 400 status code with an error message. -""") +""" + ) return "\n".join(description) @@ -161,20 +163,14 @@ async def tool_endpoint( ) -> Any: try: if route_prefix == "user_tools": - result = self.environment.use_user_tool( - tool_name=tool_name, **request.model_dump() - ) + result = self.environment.use_user_tool(tool_name=tool_name, **request.model_dump()) else: - result = self.environment.use_tool( - tool_name=tool_name, **request.model_dump() - ) + result = self.environment.use_tool(tool_name=tool_name, **request.model_dump()) return result except Exception as e: raise HTTPException(status_code=400, detail=str(e)) - def _format_tool_description( - self, doc: str, returns: Optional[dict] = None, is_user_tool: bool = False - ) -> str: + def _format_tool_description(self, doc: str, returns: Optional[dict] = None, is_user_tool: bool = False) -> str: """Format tool documentation for better ReDoc rendering""" import re diff --git a/vendor/tau2/environment/tool.py b/vendor/tau2/environment/tool.py index cf55463c..fa63bdcb 100644 --- a/vendor/tau2/environment/tool.py +++ b/vendor/tau2/environment/tool.py @@ -49,9 +49,7 @@ class Tool(BaseTool): """The parameters of the Tool.""" returns: type[BaseModel] = Field(..., description="The return of the Tool") """The return of the Tool.""" - raises: List[Dict[str, Optional[str]]] = Field( - [], description="The exceptions raised by the Tool" - ) + raises: List[Dict[str, Optional[str]]] = Field([], description="The exceptions raised by the Tool") """The exceptions raised by the Tool.""" examples: List[str] = Field([], description="The examples of the Tool") """The examples of the Tool.""" @@ -79,9 +77,7 @@ def __init__(self, func: Callable, use_short_desc: bool = False, **predefined: A self.__doc__ = doc # overwrite the doc string @classmethod - def parse_data( - cls, sig: Signature, docstring: Optional[str], predefined: Dict[str, Any] - ) -> Dict[str, Any]: + def parse_data(cls, sig: Signature, docstring: Optional[str], predefined: Dict[str, Any]) -> Dict[str, Any]: """Parse data from the signature and docstring of a function.""" doc = parse(docstring or "") data: Dict[str, Any] = { @@ -127,9 +123,7 @@ def parse_data( data["returns"] = create_model("returns", returns=(anno, default)) # build raises - data["raises"] = [ - {"type": exc.type_name, "desc": exc.description} for exc in doc.raises - ] + data["raises"] = [{"type": exc.type_name, "desc": exc.description} for exc in doc.raises] # build examples data["examples"] = doc.examples diff --git a/vendor/tau2/environment/toolkit.py b/vendor/tau2/environment/toolkit.py index fff04d60..d21923e2 100644 --- a/vendor/tau2/environment/toolkit.py +++ b/vendor/tau2/environment/toolkit.py @@ -102,18 +102,10 @@ def tool_type(self, tool_name: str) -> ToolType: def get_statistics(self) -> dict[str, Any]: """Get the statistics of the ToolKit.""" num_tools = len(self.tools) - num_read_tools = sum( - self.tool_type(name) == ToolType.READ for name in self.tools - ) - num_write_tools = sum( - self.tool_type(name) == ToolType.WRITE for name in self.tools - ) - num_think_tools = sum( - self.tool_type(name) == ToolType.THINK for name in self.tools - ) - num_generic_tools = sum( - self.tool_type(name) == ToolType.GENERIC for name in self.tools - ) + num_read_tools = sum(self.tool_type(name) == ToolType.READ for name in self.tools) + num_write_tools = sum(self.tool_type(name) == ToolType.WRITE for name in self.tools) + num_think_tools = sum(self.tool_type(name) == ToolType.THINK for name in self.tools) + num_generic_tools = sum(self.tool_type(name) == ToolType.GENERIC for name in self.tools) return { "num_tools": num_tools, "num_read_tools": num_read_tools, diff --git a/vendor/tau2/environment/utils/interface_agent.py b/vendor/tau2/environment/utils/interface_agent.py index 015b9cd2..773bbfb9 100644 --- a/vendor/tau2/environment/utils/interface_agent.py +++ b/vendor/tau2/environment/utils/interface_agent.py @@ -216,15 +216,11 @@ def get_prompt_text() -> str: if message == ":n": console.print("[info]Starting new session...[/]") - interface_agent, message_history = init_session( - current_domain - ) + interface_agent, message_history = init_session(current_domain) continue with console.status("[info]Processing query...[/]"): - response, message_history = interface_agent.respond( - message, message_history - ) + response, message_history = interface_agent.respond(message, message_history) # Try to parse response as markdown for better formatting try: @@ -232,9 +228,7 @@ def get_prompt_text() -> str: console.print("\n[bold]Response:[/]") console.print(md) except Exception as e: - console.print( - f"\n[error]Error parsing response:[/] {str(e)}" - ) + console.print(f"\n[error]Error parsing response:[/] {str(e)}") console.print("\n[bold]Response:[/]", response.content) except KeyboardInterrupt: @@ -244,9 +238,7 @@ def get_prompt_text() -> str: console.print(f"\n[error]Error processing message:[/] {str(e)}") except Exception as e: - console.print( - f"\n[error]Error initializing domain '{current_domain}':[/] {str(e)}" - ) + console.print(f"\n[error]Error initializing domain '{current_domain}':[/] {str(e)}") new_domain = change_domain(console) if new_domain is None: return diff --git a/vendor/tau2/evaluator/__init__.py b/vendor/tau2/evaluator/__init__.py index 8b137891..e69de29b 100644 --- a/vendor/tau2/evaluator/__init__.py +++ b/vendor/tau2/evaluator/__init__.py @@ -1 +0,0 @@ - diff --git a/vendor/tau2/evaluator/evaluator.py b/vendor/tau2/evaluator/evaluator.py index fa3b4791..f206c3dc 100644 --- a/vendor/tau2/evaluator/evaluator.py +++ b/vendor/tau2/evaluator/evaluator.py @@ -33,9 +33,7 @@ def evaluate_simulation( }: return RewardInfo( reward=0.0, - info={ - "note": f"Simulation terminated prematurely. Termination reason: {simulation.termination_reason}" - }, + info={"note": f"Simulation terminated prematurely. Termination reason: {simulation.termination_reason}"}, ) if task.evaluation_criteria is None: return RewardInfo( diff --git a/vendor/tau2/evaluator/evaluator_action.py b/vendor/tau2/evaluator/evaluator_action.py index 518475ba..3576ce2e 100644 --- a/vendor/tau2/evaluator/evaluator_action.py +++ b/vendor/tau2/evaluator/evaluator_action.py @@ -59,10 +59,7 @@ def evaluate_actions( predicted_tool_calls: list[ToolCall] = [] for message in full_trajectory: - if ( - isinstance(message, AssistantMessage) - or isinstance(message, UserMessage) - ) and message.is_tool_call(): + if (isinstance(message, AssistantMessage) or isinstance(message, UserMessage)) and message.is_tool_call(): predicted_tool_calls.extend(message.tool_calls) # Check if all the gold actions are in the predicted actions diff --git a/vendor/tau2/evaluator/evaluator_communicate.py b/vendor/tau2/evaluator/evaluator_communicate.py index 43eecebf..8eada207 100644 --- a/vendor/tau2/evaluator/evaluator_communicate.py +++ b/vendor/tau2/evaluator/evaluator_communicate.py @@ -32,9 +32,7 @@ def calculate_reward( reward_breakdown={RewardType.COMMUNICATE: 1.0}, ) - communicate_info_checks = cls.evaluate_communicate_info( - full_trajectory, communicate_info - ) + communicate_info_checks = cls.evaluate_communicate_info(full_trajectory, communicate_info) # Calculate reward: 1 if all expectations are met, 0 otherwise all_expectations_met = all(result.met for result in communicate_info_checks) @@ -66,9 +64,7 @@ def evaluate_communicate_info( continue if not message.has_text_content(): continue - if info_str.lower() in message.content.lower().replace( - ",", "" - ): # TODO: This could be improved! + if info_str.lower() in message.content.lower().replace(",", ""): # TODO: This could be improved! found = True break if found: diff --git a/vendor/tau2/evaluator/evaluator_env.py b/vendor/tau2/evaluator/evaluator_env.py index 46c3e296..d2b724a4 100644 --- a/vendor/tau2/evaluator/evaluator_env.py +++ b/vendor/tau2/evaluator/evaluator_env.py @@ -49,24 +49,15 @@ def calculate_reward( ) initialization_data = None - if ( - task.initial_state is not None - and task.initial_state.initialization_data is not None - ): + if task.initial_state is not None and task.initial_state.initialization_data is not None: initialization_data = task.initial_state.initialization_data initialization_actions = None - if ( - task.initial_state is not None - and task.initial_state.initialization_actions is not None - ): + if task.initial_state is not None and task.initial_state.initialization_actions is not None: initialization_actions = task.initial_state.initialization_actions message_history = [] - if ( - task.initial_state is not None - and task.initial_state.message_history is not None - ): + if task.initial_state is not None and task.initial_state.message_history is not None: message_history = task.initial_state.message_history predicted_environment = environment_constructor(solo_mode=solo_mode) @@ -77,10 +68,7 @@ def calculate_reward( ) predicted_tool_calls: list[ToolCall] = [] for message in full_trajectory: - if ( - isinstance(message, AssistantMessage) - or isinstance(message, UserMessage) - ) and message.is_tool_call(): + if (isinstance(message, AssistantMessage) or isinstance(message, UserMessage)) and message.is_tool_call(): predicted_tool_calls.extend(message.tool_calls) # Setting up gold environment @@ -99,9 +87,7 @@ def calculate_reward( **action.arguments, ) except Exception as e: - logger.warning( - f"Error in golden actions {action.name}({action.arguments}): {e}" - ) + logger.warning(f"Error in golden actions {action.name}({action.arguments}): {e}") # Comparing the environments agent_db_hash = gold_environment.get_db_hash() diff --git a/vendor/tau2/metrics/agent_metrics.py b/vendor/tau2/metrics/agent_metrics.py index 3192e698..8f8c30b8 100644 --- a/vendor/tau2/metrics/agent_metrics.py +++ b/vendor/tau2/metrics/agent_metrics.py @@ -55,9 +55,7 @@ def get_metrics_df(results: Results) -> tuple[pd.DataFrame, int]: df = results.to_df() df["success"] = df.reward.apply(is_successful) if len(df.info_num_trials.unique()) > 1: - logger.warning( - f"All simulations must have the same number of trials. Found {df.info_num_trials.unique()}" - ) + logger.warning(f"All simulations must have the same number of trials. Found {df.info_num_trials.unique()}") max_k = df.info_num_trials.max() task_ids_counts = [(tid, count) for tid, count in df.task_id.value_counts().items()] @@ -78,9 +76,7 @@ def get_tasks_pass_hat_k(results: Results) -> pd.DataFrame: df, max_k = get_metrics_df(results) dfs = [] for k in range(1, max_k + 1): - res = df.groupby("task_id")["success"].apply( - lambda df: pass_hat_k(len(df), df.sum(), k) - ) + res = df.groupby("task_id")["success"].apply(lambda df: pass_hat_k(len(df), df.sum(), k)) res.name = f"pass^{k}" dfs.append(res) df_pass_hat_k = pd.concat(dfs, axis=1) diff --git a/vendor/tau2/metrics/break_down_metrics.py b/vendor/tau2/metrics/break_down_metrics.py index 3b6e9571..3ecf4e35 100644 --- a/vendor/tau2/metrics/break_down_metrics.py +++ b/vendor/tau2/metrics/break_down_metrics.py @@ -24,9 +24,7 @@ def get_write_tools(domain): return set(agent_write_tools), set(user_write_tools) -def analyze_reward( - reward_info: RewardInfo, agent_write_tools: set[str], user_write_tools: set[str] -): +def analyze_reward(reward_info: RewardInfo, agent_write_tools: set[str], user_write_tools: set[str]): """ Analyze the reward breakdown. """ @@ -34,26 +32,18 @@ def analyze_reward( try: if RewardType.COMMUNICATE in reward_info.reward_basis: communicate_success = ( - is_successful(reward_breakdown[RewardType.COMMUNICATE]) - if reward_breakdown is not None - else 0 + is_successful(reward_breakdown[RewardType.COMMUNICATE]) if reward_breakdown is not None else 0 ) else: communicate_success = None if RewardType.ENV_ASSERTION in reward_info.reward_basis: env_success = ( - is_successful(reward_breakdown[RewardType.ENV_ASSERTION]) - if reward_breakdown is not None - else 0 + is_successful(reward_breakdown[RewardType.ENV_ASSERTION]) if reward_breakdown is not None else 0 ) else: env_success = None if RewardType.DB in reward_info.reward_basis: - db_success = ( - is_successful(reward_breakdown[RewardType.DB]) - if reward_breakdown is not None - else 0 - ) + db_success = is_successful(reward_breakdown[RewardType.DB]) if reward_breakdown is not None else 0 else: db_success = None except Exception as e: @@ -110,13 +100,9 @@ def result_reward_analysis(results: Results): Analyze the reward breakdown. """ rows = [] - agent_write_tools, user_write_tools = get_write_tools( - results.info.environment_info.domain_name - ) + agent_write_tools, user_write_tools = get_write_tools(results.info.environment_info.domain_name) for simulation in results.simulations: - reward_analysis = analyze_reward( - simulation.reward_info, agent_write_tools, user_write_tools - ) + reward_analysis = analyze_reward(simulation.reward_info, agent_write_tools, user_write_tools) reward_analysis["task_id"] = simulation.task_id reward_analysis["trial"] = simulation.trial rows.append(reward_analysis) diff --git a/vendor/tau2/orchestrator/environment_manager.py b/vendor/tau2/orchestrator/environment_manager.py index 1b5b7bc9..f34d0758 100644 --- a/vendor/tau2/orchestrator/environment_manager.py +++ b/vendor/tau2/orchestrator/environment_manager.py @@ -145,9 +145,7 @@ async def status(): @self.app.post("/start_environment") async def start_env(request: StartEnvironmentRequest) -> EnvironmentResponse: - env_id = self.start_environment( - domain=request.domain, env_id=request.env_id - ) + env_id = self.start_environment(domain=request.domain, env_id=request.env_id) return EnvironmentResponse(env_id=env_id) @self.app.post("/{env_id}/set_state") @@ -169,9 +167,7 @@ async def get_info(env_id: str) -> EnvironmentInfo: return self.get_environment_info(env_id) @self.app.post("/{env_id}/tools/{tool_name}") - async def execute_tool( - env_id: str, tool_name: str, request: ToolCall - ) -> ToolMessage: + async def execute_tool(env_id: str, tool_name: str, request: ToolCall) -> ToolMessage: return self.execute_tool(env_id=env_id, tool_call=request) def get_environment_id(self) -> str: @@ -210,12 +206,8 @@ def set_environment_state( Set the state of an environment. """ - self.environments[env_id].set_state( - initialization_data, initialization_actions, message_history - ) - self.trajectories[env_id] = [ - msg for msg in message_history if is_valid_environment_message(msg) - ] + self.environments[env_id].set_state(initialization_data, initialization_actions, message_history) + self.trajectories[env_id] = [msg for msg in message_history if is_valid_environment_message(msg)] def stop_environment(self, env_id: str): """ @@ -225,9 +217,7 @@ def stop_environment(self, env_id: str): # Get the router instance router = self.app.router # Filter out the routes we want to remove - router.routes = [ - route for route in router.routes if route not in self.routes[env_id] - ] + router.routes = [route for route in router.routes if route not in self.routes[env_id]] del self.routes[env_id] if env_id in self.environments: diff --git a/vendor/tau2/orchestrator/orchestrator.py b/vendor/tau2/orchestrator/orchestrator.py index 172519c1..20e8c494 100644 --- a/vendor/tau2/orchestrator/orchestrator.py +++ b/vendor/tau2/orchestrator/orchestrator.py @@ -31,9 +31,7 @@ class Role(str, Enum): ENV = "env" -DEFAULT_FIRST_AGENT_MESSAGE = AssistantMessage( - role="assistant", content="Hi! How can I help you today?", cost=0.0 -) +DEFAULT_FIRST_AGENT_MESSAGE = AssistantMessage(role="assistant", content="Hi! How can I help you today?", cost=0.0) class Orchestrator: @@ -82,12 +80,8 @@ def initialize(self): - Send the first message (default message from the agent to the user). """ initial_state = self.task.initial_state - initialization_data = ( - initial_state.initialization_data if initial_state is not None else None - ) - initialization_actions = ( - initial_state.initialization_actions if initial_state is not None else None - ) + initialization_data = initial_state.initialization_data if initial_state is not None else None + initialization_actions = initial_state.initialization_actions if initial_state is not None else None message_history = ( deepcopy(initial_state.message_history) if initial_state is not None and initial_state.message_history is not None @@ -101,12 +95,8 @@ def initialize(self): if self.solo_mode: assert self.environment.solo_mode, "Environment should be in solo mode" - assert isinstance(self.agent, LLMSoloAgent), ( - "Agent must be a LLMSoloAgent in solo mode" - ) - assert isinstance(self.user, DummyUser), ( - "User must be a DummyUser in solo mode" - ) + assert isinstance(self.agent, LLMSoloAgent), "Agent must be a LLMSoloAgent in solo mode" + assert isinstance(self.user, DummyUser), "User must be a DummyUser in solo mode" # Initialize Environment state self._initialize_environment( @@ -133,18 +123,10 @@ def initialize(self): else: # Last message is for the environment self.to_role = Role.ENV self.agent_state = self.agent.get_init_state( - message_history=[ - msg - for msg in message_history - if is_valid_agent_history_message(msg) - ] + message_history=[msg for msg in message_history if is_valid_agent_history_message(msg)] ) self.user_state = self.user.get_init_state( - message_history=[ - msg - for msg in message_history[:-1] - if is_valid_user_history_message(msg) - ] + message_history=[msg for msg in message_history[:-1] if is_valid_user_history_message(msg)] ) self.message = last_message if self.agent.is_stop(last_message): @@ -158,18 +140,10 @@ def initialize(self): else: # Last message is for the environment self.to_role = Role.ENV self.user_state = self.user.get_init_state( - message_history=[ - msg - for msg in message_history - if is_valid_user_history_message(msg) - ] + message_history=[msg for msg in message_history if is_valid_user_history_message(msg)] ) self.agent_state = self.agent.get_init_state( - message_history=[ - msg - for msg in message_history[:-1] - if is_valid_agent_history_message(msg) - ] + message_history=[msg for msg in message_history[:-1] if is_valid_agent_history_message(msg)] ) self.message = last_message self.done = UserSimulator.is_stop(last_message) @@ -181,34 +155,18 @@ def initialize(self): if last_message.requestor == "assistant": self.to_role = Role.AGENT self.agent_state = self.agent.get_init_state( - message_history=[ - msg - for msg in message_history[:-1] - if is_valid_agent_history_message(msg) - ] + message_history=[msg for msg in message_history[:-1] if is_valid_agent_history_message(msg)] ) self.user_state = self.user.get_init_state( - message_history=[ - msg - for msg in message_history - if is_valid_user_history_message(msg) - ] + message_history=[msg for msg in message_history if is_valid_user_history_message(msg)] ) else: self.to_role = Role.USER self.agent_state = self.agent.get_init_state( - message_history=[ - msg - for msg in message_history - if is_valid_agent_history_message(msg) - ] + message_history=[msg for msg in message_history if is_valid_agent_history_message(msg)] ) self.user_state = self.user.get_init_state( - message_history=[ - msg - for msg in message_history[:-1] - if is_valid_user_history_message(msg) - ] + message_history=[msg for msg in message_history[:-1] if is_valid_user_history_message(msg)] ) self.message = last_message else: @@ -228,9 +186,7 @@ def initialize(self): self.from_role = Role.AGENT self.to_role = Role.USER else: - first_message, agent_state = self.agent.generate_next_message( - None, self.agent_state - ) + first_message, agent_state = self.agent.generate_next_message(None, self.agent_state) self.trajectory = [first_message] self.message = first_message self.from_role = Role.AGENT @@ -290,17 +246,13 @@ def step(self): """ if self.done: raise ValueError("Simulation is done") - logger.debug( - f"Step {self.step_count}. Sending message from {self.from_role} to {self.to_role}" - ) + logger.debug(f"Step {self.step_count}. Sending message from {self.from_role} to {self.to_role}") logger.debug( f"Step {self.step_count}.\nFrom role: {self.from_role}\nTo role: {self.to_role}\nMessage: {self.message}" ) # AGENT/ENV -> USER if self.from_role in [Role.AGENT, Role.ENV] and self.to_role == Role.USER: - user_msg, self.user_state = self.user.generate_next_message( - self.message, self.user_state - ) + user_msg, self.user_state = self.user.generate_next_message(self.message, self.user_state) user_msg.validate() if UserSimulator.is_stop(user_msg): self.done = True @@ -313,12 +265,8 @@ def step(self): else: self.to_role = Role.AGENT # USER/ENV -> AGENT - elif ( - self.from_role == Role.USER or self.from_role == Role.ENV - ) and self.to_role == Role.AGENT: - agent_msg, self.agent_state = self.agent.generate_next_message( - self.message, self.agent_state - ) + elif (self.from_role == Role.USER or self.from_role == Role.ENV) and self.to_role == Role.AGENT: + agent_msg, self.agent_state = self.agent.generate_next_message(self.message, self.agent_state) agent_msg.validate() if self.agent.is_stop(agent_msg): self.done = True @@ -338,13 +286,11 @@ def step(self): for tool_call in self.message.tool_calls: tool_msg = self.environment.get_response(tool_call) tool_msgs.append(tool_msg) - assert len(self.message.tool_calls) == len(tool_msgs), ( - "Number of tool calls and tool messages should be the same" - ) + assert len(self.message.tool_calls) == len( + tool_msgs + ), "Number of tool calls and tool messages should be the same" self.trajectory.extend(tool_msgs) - if ( - len(tool_msgs) > 1 - ): # Packaging multiple tool messages into a MultiToolMessage + if len(tool_msgs) > 1: # Packaging multiple tool messages into a MultiToolMessage self.message = MultiToolMessage( role="tool", tool_messages=tool_msgs, @@ -354,9 +300,7 @@ def step(self): self.to_role = self.from_role self.from_role = Role.ENV else: - raise ValueError( - f"Invalid role combination. From role: {self.from_role}, To role: {self.to_role}" - ) + raise ValueError(f"Invalid role combination. From role: {self.from_role}, To role: {self.to_role}") self.step_count += 1 self.environment.sync_tools() @@ -403,9 +347,7 @@ def validate_message_history(cls, message_history: list[Message]): if num_expected_tool_messages == 0 or requestor is None: raise ValueError("No tool messages expected.") if requestor != msg.requestor: - raise ValueError( - f"Got tool message from {msg.requestor}, expected {requestor}." - ) + raise ValueError(f"Got tool message from {msg.requestor}, expected {requestor}.") num_expected_tool_messages -= 1 else: raise ValueError(f"Invalid message type: {type(msg)}") @@ -435,13 +377,9 @@ def _count_errors(self, message_history: list[Message]) -> int: """ Count the number of errors in the message history. """ - return sum( - 1 for msg in message_history if isinstance(msg, ToolMessage) and msg.error - ) + return sum(1 for msg in message_history if isinstance(msg, ToolMessage) and msg.error) - def _add_timestamps( - self, message_history: list[Message] - ) -> list[tuple[str, Message]]: + def _add_timestamps(self, message_history: list[Message]) -> list[tuple[str, Message]]: """ Add timestamps to the message history. This is used to sort the messages by timestamp. diff --git a/vendor/tau2/orchestrator/utils.py b/vendor/tau2/orchestrator/utils.py index 5b119813..d15b6435 100644 --- a/vendor/tau2/orchestrator/utils.py +++ b/vendor/tau2/orchestrator/utils.py @@ -5,6 +5,4 @@ def is_valid_environment_message(msg: Message) -> bool: """ Check if the message is valid to the environment. """ - return isinstance(msg, ToolMessage) or ( - isinstance(msg, AssistantMessage) and msg.is_tool_call() - ) + return isinstance(msg, ToolMessage) or (isinstance(msg, AssistantMessage) and msg.is_tool_call()) diff --git a/vendor/tau2/registry.py b/vendor/tau2/registry.py index 764b917c..ac3971a3 100644 --- a/vendor/tau2/registry.py +++ b/vendor/tau2/registry.py @@ -7,29 +7,21 @@ from vendor.tau2.agent.base import BaseAgent from vendor.tau2.agent.llm_agent import LLMAgent, LLMGTAgent, LLMSoloAgent from vendor.tau2.data_model.tasks import Task -from vendor.tau2.domains.airline.environment import \ - get_environment as airline_domain_get_environment -from vendor.tau2.domains.airline.environment import \ - get_tasks as airline_domain_get_tasks -from vendor.tau2.domains.mock.environment import \ - get_environment as mock_domain_get_environment +from vendor.tau2.domains.airline.environment import get_environment as airline_domain_get_environment +from vendor.tau2.domains.airline.environment import get_tasks as airline_domain_get_tasks +from vendor.tau2.domains.mock.environment import get_environment as mock_domain_get_environment from vendor.tau2.domains.mock.environment import get_tasks as mock_domain_get_tasks -from vendor.tau2.domains.retail.environment import \ - get_environment as retail_domain_get_environment -from vendor.tau2.domains.retail.environment import \ - get_tasks as retail_domain_get_tasks -from vendor.tau2.domains.telecom.environment import \ - get_environment_manual_policy as \ - telecom_domain_get_environment_manual_policy -from vendor.tau2.domains.telecom.environment import \ - get_environment_workflow_policy as \ - telecom_domain_get_environment_workflow_policy -from vendor.tau2.domains.telecom.environment import \ - get_tasks as telecom_domain_get_tasks -from vendor.tau2.domains.telecom.environment import \ - get_tasks_full as telecom_domain_get_tasks_full -from vendor.tau2.domains.telecom.environment import \ - get_tasks_small as telecom_domain_get_tasks_small +from vendor.tau2.domains.retail.environment import get_environment as retail_domain_get_environment +from vendor.tau2.domains.retail.environment import get_tasks as retail_domain_get_tasks +from vendor.tau2.domains.telecom.environment import ( + get_environment_manual_policy as telecom_domain_get_environment_manual_policy, +) +from vendor.tau2.domains.telecom.environment import ( + get_environment_workflow_policy as telecom_domain_get_environment_workflow_policy, +) +from vendor.tau2.domains.telecom.environment import get_tasks as telecom_domain_get_tasks +from vendor.tau2.domains.telecom.environment import get_tasks_full as telecom_domain_get_tasks_full +from vendor.tau2.domains.telecom.environment import get_tasks_small as telecom_domain_get_tasks_small from vendor.tau2.environment.environment import Environment from vendor.tau2.user.base import BaseUser from vendor.tau2.user.user_simulator import DummyUser, UserSimulator @@ -184,13 +176,13 @@ def get_info(self) -> RegistryInfo: registry.register_domain(retail_domain_get_environment, "retail") registry.register_tasks(retail_domain_get_tasks, "retail") registry.register_domain(telecom_domain_get_environment_manual_policy, "telecom") - registry.register_domain( - telecom_domain_get_environment_workflow_policy, "telecom-workflow" - ) + registry.register_domain(telecom_domain_get_environment_workflow_policy, "telecom-workflow") registry.register_tasks(telecom_domain_get_tasks_full, "telecom_full") registry.register_tasks(telecom_domain_get_tasks_small, "telecom_small") registry.register_tasks(telecom_domain_get_tasks, "telecom") registry.register_tasks(telecom_domain_get_tasks, "telecom-workflow") - logger.debug(f"Default components registered successfully. Registry info: {json.dumps(registry.get_info().model_dump(), indent=2)}") + logger.debug( + f"Default components registered successfully. Registry info: {json.dumps(registry.get_info().model_dump(), indent=2)}" + ) except Exception as e: logger.error(f"Error initializing registry: {str(e)}") diff --git a/vendor/tau2/run.py b/vendor/tau2/run.py index c2813d34..d49521ab 100644 --- a/vendor/tau2/run.py +++ b/vendor/tau2/run.py @@ -8,8 +8,7 @@ from loguru import logger from vendor.tau2.agent.llm_agent import LLMAgent, LLMGTAgent, LLMSoloAgent -from vendor.tau2.data_model.simulation import (AgentInfo, Info, Results, RunConfig, - SimulationRun, UserInfo) +from vendor.tau2.data_model.simulation import AgentInfo, Info, Results, RunConfig, SimulationRun, UserInfo from vendor.tau2.data_model.tasks import Task from vendor.tau2.environment.environment import Environment, EnvironmentInfo from vendor.tau2.evaluator.evaluator import EvaluationType, evaluate_simulation @@ -29,9 +28,7 @@ def get_options() -> RegistryInfo: return registry.get_info() -def get_environment_info( - domain_name: str, include_tool_info: bool = False -) -> EnvironmentInfo: +def get_environment_info(domain_name: str, include_tool_info: bool = False) -> EnvironmentInfo: """Get information about the environment for a registered Domain""" global registry env_constructor = registry.get_env_constructor(domain_name) @@ -59,14 +56,10 @@ def get_tasks( if task_ids is None: tasks = load_tasks(task_set_name=task_set_name) else: - tasks = [ - task for task in load_tasks(task_set_name=task_set_name) if task.id in task_ids - ] + tasks = [task for task in load_tasks(task_set_name=task_set_name) if task.id in task_ids] if task_ids is not None and len(tasks) != len(task_ids): missing_tasks = set(task_ids) - set([task.id for task in tasks]) - raise ValueError( - f"Not all tasks were found for task set {task_set_name}: {missing_tasks}" - ) + raise ValueError(f"Not all tasks were found for task set {task_set_name}: {missing_tasks}") if num_tasks is not None: tasks = tasks[:num_tasks] return tasks @@ -100,13 +93,17 @@ def run_domain(config: RunConfig) -> Results: total_num_tasks = len(tasks) tasks = [task for task in tasks if LLMGTAgent.check_valid_task(task)] num_tasks = len(tasks) - console_text = Text(text=f"Running {num_tasks} out of {total_num_tasks} tasks for GT agent.", style="bold green") + console_text = Text( + text=f"Running {num_tasks} out of {total_num_tasks} tasks for GT agent.", style="bold green" + ) ConsoleDisplay.console.print(console_text) if "solo" in config.agent: total_num_tasks = len(tasks) tasks = [task for task in tasks if LLMSoloAgent.check_valid_task(task)] num_tasks = len(tasks) - console_text = Text(text=f"Running {num_tasks} out of {total_num_tasks} tasks for solo agent.", style="bold green") + console_text = Text( + text=f"Running {num_tasks} out of {total_num_tasks} tasks for solo agent.", style="bold green" + ) ConsoleDisplay.console.print(console_text) num_trials = config.num_trials @@ -244,9 +241,7 @@ def run_tasks( with open(save_to, "r") as fp: prev_simulation_results = Results.model_validate_json(fp.read()) # Check if the run config has changed - if get_pydantic_hash(prev_simulation_results.info) != get_pydantic_hash( - simulation_results.info - ): + if get_pydantic_hash(prev_simulation_results.info) != get_pydantic_hash(simulation_results.info): diff = show_dict_diff( prev_simulation_results.info.model_dump(), simulation_results.info.model_dump(), @@ -279,14 +274,12 @@ def run_tasks( "The task set has changed. Please delete the existing file or use a different save_to name." ) # Check which of the runs have already been done - done_runs = set( - [ - (sim.trial, sim.task_id, sim.seed) - for sim in prev_simulation_results.simulations - ] - ) + done_runs = set([(sim.trial, sim.task_id, sim.seed) for sim in prev_simulation_results.simulations]) simulation_results = prev_simulation_results - console_text = Text(text=f"Resuming run from {len(done_runs)} runs. {len(tasks) * num_trials - len(done_runs)} runs remaining.", style="bold yellow") + console_text = Text( + text=f"Resuming run from {len(done_runs)} runs. {len(tasks) * num_trials - len(done_runs)} runs remaining.", + style="bold yellow", + ) ConsoleDisplay.console.print(console_text) # Create new save file else: @@ -338,7 +331,10 @@ def _run(task: Task, trial: int, seed: int, progress_str: str) -> SimulationRun: for trial in range(num_trials): for i, task in enumerate(tasks): if (trial, task.id, seeds[trial]) in done_runs: - console_text = Text(text=f"Skipping task {task.id}, trial {trial} because it has already been run.", style="bold yellow") + console_text = Text( + text=f"Skipping task {task.id}, trial {trial} because it has already been run.", + style="bold yellow", + ) ConsoleDisplay.console.print(console_text) continue progress_str = f"{i}/{len(tasks)} (trial {trial + 1}/{num_trials})" @@ -394,9 +390,7 @@ def run_task( if max_errors <= 0: raise ValueError("Max errors must be greater than 0") global registry - logger.info( - f"STARTING SIMULATION: Domain: {domain}, Task: {task.id}, Agent: {agent}, User: {user}" - ) + logger.info(f"STARTING SIMULATION: Domain: {domain}, Task: {task.id}, Agent: {agent}, User: {user}") environment_constructor = registry.get_env_constructor(domain) environment = environment_constructor() AgentConstructor = registry.get_agent_constructor(agent) @@ -429,9 +423,7 @@ def run_task( task=task, ) else: - raise ValueError( - f"Unknown agent type: {AgentConstructor}. Should be LLMAgent or LLMSoloAgent" - ) + raise ValueError(f"Unknown agent type: {AgentConstructor}. Should be LLMAgent or LLMSoloAgent") try: user_tools = environment.get_user_tools() except Exception: @@ -439,9 +431,7 @@ def run_task( UserConstructor = registry.get_user_constructor(user) if issubclass(UserConstructor, DummyUser): - assert isinstance(agent, LLMSoloAgent), ( - "Dummy user can only be used with solo agent" - ) + assert isinstance(agent, LLMSoloAgent), "Dummy user can only be used with solo agent" user = UserConstructor( tools=user_tools, diff --git a/vendor/tau2/scripts/show_domain_doc.py b/vendor/tau2/scripts/show_domain_doc.py index a7d55cdd..b1b4c8ee 100755 --- a/vendor/tau2/scripts/show_domain_doc.py +++ b/vendor/tau2/scripts/show_domain_doc.py @@ -59,9 +59,7 @@ def main(domain: str): except KeyError: available_domains = registry.get_domains() - logger.error( - f"Domain '{domain}' not found. Available domains: {available_domains}" - ) + logger.error(f"Domain '{domain}' not found. Available domains: {available_domains}") exit(1) except Exception as e: logger.error(f"Failed to start domain documentation server: {str(e)}") diff --git a/vendor/tau2/scripts/start_servers.py b/vendor/tau2/scripts/start_servers.py index 7a7596be..ecbec88c 100755 --- a/vendor/tau2/scripts/start_servers.py +++ b/vendor/tau2/scripts/start_servers.py @@ -18,9 +18,7 @@ def kill_process_on_port(port): connections = proc.net_connections() for conn in connections: if hasattr(conn, "laddr") and conn.laddr.port == port: - logger.warning( - f"Killing existing process {proc.pid} on port {port}" - ) + logger.warning(f"Killing existing process {proc.pid} on port {port}") proc.terminate() time.sleep(0.5) # Give it a moment to terminate if proc.is_running(): # If still running @@ -82,9 +80,7 @@ def signal_handler(signum, frame): try: with ThreadPoolExecutor(max_workers=len(servers)) as executor: # Start each server in a separate thread - futures = [ - executor.submit(run_server, command, port) for command, port in servers - ] + futures = [executor.submit(run_server, command, port) for command, port in servers] # Wait for all servers to complete for future in futures: diff --git a/vendor/tau2/scripts/view_simulations.py b/vendor/tau2/scripts/view_simulations.py index b8fad2be..e3357477 100644 --- a/vendor/tau2/scripts/view_simulations.py +++ b/vendor/tau2/scripts/view_simulations.py @@ -23,9 +23,7 @@ def get_available_simulations(): return sorted([f for f in sim_dir.glob("*.json")]) -def display_simulation_list( - results: Results, only_show_failed: bool = False, only_show_all_failed: bool = False -): +def display_simulation_list(results: Results, only_show_failed: bool = False, only_show_all_failed: bool = False): """Display a numbered list of simulations with basic info.""" ConsoleDisplay.console.print("\n[bold blue]Available Simulations:[/]") @@ -74,9 +72,7 @@ def display_available_files(files): ConsoleDisplay.console.print(f"[cyan]{i}.[/] {file.name}") -def display_simulation_with_task( - simulation, task, results_file: str, sim_index: int, show_details: bool = True -): +def display_simulation_with_task(simulation, task, results_file: str, sim_index: int, show_details: bool = True): """Display a simulation along with its associated task.""" ConsoleDisplay.console.print("\n" + "=" * 80) # Separator ConsoleDisplay.console.print("[bold blue]Task Details:[/]") @@ -113,18 +109,12 @@ def find_task_by_id(tasks, task_id): def find_simulation_by_task_id_and_trial(results, task_id, trial): """Get a simulation by its task ID and trial number.""" return next( - ( - sim - for sim in results.simulations - if sim.task_id == task_id and sim.trial == trial - ), + (sim for sim in results.simulations if sim.task_id == task_id and sim.trial == trial), None, ) -def save_simulation_note( - simulation, task, note: str, results_file: str, sim_index: int -): +def save_simulation_note(simulation, task, note: str, results_file: str, sim_index: int): """Save a note about a simulation to a CSV file.""" notes_file = Path(DATA_DIR) / "simulations" / "simulation_notes.csv" file_exists = notes_file.exists() @@ -137,9 +127,11 @@ def save_simulation_note( "trial": simulation.trial, "duration": simulation.duration, "reward": simulation.reward_info.reward if simulation.reward_info else None, - "db_match": simulation.reward_info.db_check.db_match - if simulation.reward_info and simulation.reward_info.db_check - else None, + "db_match": ( + simulation.reward_info.db_check.db_match + if simulation.reward_info and simulation.reward_info.db_check + else None + ), "results_file": results_file, "sim_index": sim_index, "note": note, @@ -165,9 +157,7 @@ def main( sim_files = [Path(sim_file)] if not sim_files: - ConsoleDisplay.console.print( - "[red]No simulation files found in data/simulations/[/]" - ) + ConsoleDisplay.console.print("[red]No simulation files found in data/simulations/[/]") return results = None @@ -176,20 +166,14 @@ def main( # Show main menu ConsoleDisplay.console.print("\n[bold yellow]Main Menu:[/]") ConsoleDisplay.console.print("1. Select simulation file") - ConsoleDisplay.console.print( - " [dim]Choose a simulation results file to load and analyze[/]" - ) + ConsoleDisplay.console.print(" [dim]Choose a simulation results file to load and analyze[/]") if results: ConsoleDisplay.console.print("2. View agent performance metrics") ConsoleDisplay.console.print(" [dim]Display agent performance metrics[/]") ConsoleDisplay.console.print("3. View simulation") - ConsoleDisplay.console.print( - " [dim]Examine a specific simulation in detail with all its data[/]" - ) + ConsoleDisplay.console.print(" [dim]Examine a specific simulation in detail with all its data[/]") ConsoleDisplay.console.print("4. View task details") - ConsoleDisplay.console.print( - " [dim]Look at the configuration and parameters of a specific task[/]" - ) + ConsoleDisplay.console.print(" [dim]Look at the configuration and parameters of a specific task[/]") ConsoleDisplay.console.print("5. Exit") ConsoleDisplay.console.print(" [dim]Close the simulation viewer[/]") choices = ["1", "2", "3", "4", "5"] @@ -200,17 +184,13 @@ def main( choices = ["1", "2"] default_choice = "1" - choice = Prompt.ask( - "\nWhat would you like to do?", choices=choices, default=default_choice - ) + choice = Prompt.ask("\nWhat would you like to do?", choices=choices, default=default_choice) if choice == "1": # Show available files and get selection display_available_files(sim_files) # default to view the last file - file_num = IntPrompt.ask( - f"\nSelect file number (1-{len(sim_files)})", default=len(sim_files) - ) + file_num = IntPrompt.ask(f"\nSelect file number (1-{len(sim_files)})", default=len(sim_files)) if 1 <= file_num <= len(sim_files): try: @@ -219,13 +199,9 @@ def main( ConsoleDisplay.console.print( f"\n[bold green]Loaded {len(results.simulations)} simulations from {current_file}[/]" ) - results.simulations = sorted( - results.simulations, key=lambda x: (x.task_id, x.trial) - ) + results.simulations = sorted(results.simulations, key=lambda x: (x.task_id, x.trial)) except Exception as e: - ConsoleDisplay.console.print( - f"[red]Error loading results:[/] {str(e)}" - ) + ConsoleDisplay.console.print(f"[red]Error loading results:[/] {str(e)}") else: ConsoleDisplay.console.print("[red]Invalid file number[/]") @@ -245,21 +221,15 @@ def main( # Get simulation selection by index sim_count = len(results.simulations) - sim_index = IntPrompt.ask( - f"\nEnter simulation number (1-{sim_count})", default=1 - ) + sim_index = IntPrompt.ask(f"\nEnter simulation number (1-{sim_count})", default=1) if 1 <= sim_index <= sim_count: sim = results.simulations[sim_index - 1] task = find_task_by_id(results.tasks, sim.task_id) if task: - display_simulation_with_task( - sim, task, current_file, sim_index, show_details=True - ) + display_simulation_with_task(sim, task, current_file, sim_index, show_details=True) else: - ConsoleDisplay.console.print( - f"[red]Warning: Could not find task for simulation {sim.id}[/]" - ) + ConsoleDisplay.console.print(f"[red]Warning: Could not find task for simulation {sim.id}[/]") ConsoleDisplay.display_simulation(sim, show_details=True) continue else: diff --git a/vendor/tau2/user/base.py b/vendor/tau2/user/base.py index b47bdb9a..1fe018b6 100644 --- a/vendor/tau2/user/base.py +++ b/vendor/tau2/user/base.py @@ -63,9 +63,7 @@ def flip_roles(self) -> list[APICompatibleMessage]: ) ) else: - raise ValueError( - f"Tool calls are not supported in the flipped messages: {message}" - ) + raise ValueError(f"Tool calls are not supported in the flipped messages: {message}") elif isinstance(message, ToolMessage): if message.requestor == "user": # Only add tool messages for the user @@ -77,9 +75,7 @@ def flip_roles(self) -> list[APICompatibleMessage]: ) ) else: - raise ValueError( - f"Tool messages should be sent to the user in this message history: {message}" - ) + raise ValueError(f"Tool messages should be sent to the user in this message history: {message}") else: print(message, type(message)) raise ValueError(f"Unknown message role: {message.role}") @@ -100,9 +96,7 @@ def __init__( self.instructions = instructions @abstractmethod - async def get_init_state( - self, message_history: Optional[list[Message]] = None - ) -> UserState: + async def get_init_state(self, message_history: Optional[list[Message]] = None) -> UserState: """Get the initial state of the user simulator. Args: diff --git a/vendor/tau2/utils/display.py b/vendor/tau2/utils/display.py index 674efb92..8e83e2cd 100644 --- a/vendor/tau2/utils/display.py +++ b/vendor/tau2/utils/display.py @@ -95,9 +95,7 @@ def display_task(cls, task: Task): if task.description.purpose: content_parts.append(f"[white]Purpose:[/] {task.description.purpose}") if task.description.relevant_policies: - content_parts.append( - f"[white]Relevant Policies:[/] {task.description.relevant_policies}" - ) + content_parts.append(f"[white]Relevant Policies:[/] {task.description.relevant_policies}") if task.description.notes: content_parts.append(f"[white]Notes:[/] {task.description.notes}") @@ -108,14 +106,10 @@ def display_task(cls, task: Task): scenario_parts.append(f"[white]Persona:[/] {task.user_scenario.persona}") # User Instruction - scenario_parts.append( - f"[white]Task Instructions:[/] {task.user_scenario.instructions}" - ) + scenario_parts.append(f"[white]Task Instructions:[/] {task.user_scenario.instructions}") if scenario_parts: - content_parts.append( - "[bold cyan]User Scenario:[/]\n" + "\n".join(scenario_parts) - ) + content_parts.append("[bold cyan]User Scenario:[/]\n" + "\n".join(scenario_parts)) # Initial State section if task.initial_state: @@ -134,9 +128,7 @@ def display_task(cls, task: Task): ) if initial_state_parts: - content_parts.append( - "[bold cyan]Initial State:[/]\n" + "\n".join(initial_state_parts) - ) + content_parts.append("[bold cyan]Initial State:[/]\n" + "\n".join(initial_state_parts)) # Evaluation Criteria section if task.evaluation_criteria: @@ -154,15 +146,11 @@ def display_task(cls, task: Task): f"[white]Information to Communicate:[/]\n{json.dumps(task.evaluation_criteria.communicate_info, indent=2)}" ) if eval_parts: - content_parts.append( - "[bold cyan]Evaluation Criteria:[/]\n" + "\n".join(eval_parts) - ) + content_parts.append("[bold cyan]Evaluation Criteria:[/]\n" + "\n".join(eval_parts)) content = "\n\n".join(content_parts) # Create and display panel - task_panel = Panel( - content, title="[bold blue]Task Details", border_style="blue", expand=True - ) + task_panel = Panel(content, title="[bold blue]Task Details", border_style="blue", expand=True) cls.console.print(task_panel) @@ -203,18 +191,11 @@ def display_simulation(cls, simulation: SimulationRun, show_details: bool = True marker = "โœ…" if is_successful(simulation.reward_info.reward) else "โŒ" sim_info.append("Reward: ", style="bold cyan") if simulation.reward_info.reward_breakdown: - breakdown = sorted( - [ - f"{k.value}: {v:.1f}" - for k, v in simulation.reward_info.reward_breakdown.items() - ] - ) + breakdown = sorted([f"{k.value}: {v:.1f}" for k, v in simulation.reward_info.reward_breakdown.items()]) else: breakdown = [] - sim_info.append( - f"{marker} {simulation.reward_info.reward:.4f} ({', '.join(breakdown)})\n" - ) + sim_info.append(f"{marker} {simulation.reward_info.reward:.4f} ({', '.join(breakdown)})\n") # Add DB check info if present if simulation.reward_info.db_check: @@ -243,9 +224,7 @@ def display_simulation(cls, simulation: SimulationRun, show_details: bool = True if simulation.reward_info.communicate_checks: sim_info.append("\nCommunicate Checks:\n", style="bold magenta") for i, check in enumerate(simulation.reward_info.communicate_checks): - sim_info.append( - f"- {i}: {check.info} {'โœ…' if check.met else 'โŒ'}\n" - ) + sim_info.append(f"- {i}: {check.info} {'โœ…' if check.met else 'โŒ'}\n") # Add NL assertions if present if simulation.reward_info.nl_assertions: @@ -261,9 +240,7 @@ def display_simulation(cls, simulation: SimulationRun, show_details: bool = True for key, value in simulation.reward_info.info.items(): sim_info.append(f"{key}: {value}\n") - cls.console.print( - Panel(sim_info, title="Simulation Overview", border_style="blue") - ) + cls.console.print(Panel(sim_info, title="Simulation Overview", border_style="blue")) # Create messages table if simulation.messages: @@ -390,15 +367,8 @@ def display_simulation(cls, sim: SimulationRun) -> str: # Add reward info if present if sim.reward_info: - breakdown = sorted( - [ - f"{k.value}: {v:.1f}" - for k, v in sim.reward_info.reward_breakdown.items() - ] - ) - output.append( - f"**Reward**: {sim.reward_info.reward:.4f} ({', '.join(breakdown)})\n" - ) + breakdown = sorted([f"{k.value}: {v:.1f}" for k, v in sim.reward_info.reward_breakdown.items()]) + output.append(f"**Reward**: {sim.reward_info.reward:.4f} ({', '.join(breakdown)})\n") output.append(f"**Reward**: {sim.reward_info.reward:.4f}") # Add DB check info if present @@ -428,9 +398,7 @@ def display_simulation(cls, sim: SimulationRun) -> str: if sim.reward_info.communicate_checks: output.append("\n**Communicate Checks**") for i, check in enumerate(sim.reward_info.communicate_checks): - output.append( - f"- {i}: {check.info} {'โœ…' if check.met else 'โŒ'} {check.justification}" - ) + output.append(f"- {i}: {check.info} {'โœ…' if check.met else 'โŒ'} {check.justification}") # Add NL assertions if present if sim.reward_info.nl_assertions: diff --git a/vendor/tau2/utils/pydantic_utils.py b/vendor/tau2/utils/pydantic_utils.py index 5c34acd8..e46454eb 100644 --- a/vendor/tau2/utils/pydantic_utils.py +++ b/vendor/tau2/utils/pydantic_utils.py @@ -21,9 +21,7 @@ def get_pydantic_hash(obj: BaseModel) -> str: return get_dict_hash(hash_dict) -def update_pydantic_model_with_dict( - model_instance: T, update_data: Dict[str, Any] -) -> T: +def update_pydantic_model_with_dict(model_instance: T, update_data: Dict[str, Any]) -> T: """ Return an updated BaseModel instance based on the update_data. """ diff --git a/vendor/tau2/utils/utils.py b/vendor/tau2/utils/utils.py index c1103fe9..33c9b511 100644 --- a/vendor/tau2/utils/utils.py +++ b/vendor/tau2/utils/utils.py @@ -29,9 +29,7 @@ # Check if data directory exists and is accessible if not DATA_DIR.exists(): logger.warning(f"Data directory does not exist: {DATA_DIR}") - logger.warning( - "Set TAU2_DATA_DIR environment variable to point to your data directory" - ) + logger.warning("Set TAU2_DATA_DIR environment variable to point to your data directory") logger.warning("Or ensure the data directory exists in the expected location") @@ -72,11 +70,7 @@ def get_commit_hash() -> str: Get the commit hash of the current directory. """ try: - commit_hash = ( - subprocess.check_output(["git", "rev-parse", "HEAD"], text=True) - .strip() - .split("\n")[0] - ) + commit_hash = subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip().split("\n")[0] except Exception as e: logger.error(f"Failed to get git hash: {e}") commit_hash = "unknown" diff --git a/vite-app/.gitignore b/vite-app/.gitignore index 9b89d067..af6b9481 100644 --- a/vite-app/.gitignore +++ b/vite-app/.gitignore @@ -20,7 +20,7 @@ dist-ssr *.ntvs* *.njsproj *.sln -*.sw? +*.sw? !package.json -!dist/ \ No newline at end of file +!dist/ diff --git a/vite-app/index.html b/vite-app/index.html index f8fba7b1..ca9157d7 100644 --- a/vite-app/index.html +++ b/vite-app/index.html @@ -11,4 +11,4 @@
- \ No newline at end of file + diff --git a/vite-app/src/index.css b/vite-app/src/index.css index 022157d5..bdb109c5 100644 --- a/vite-app/src/index.css +++ b/vite-app/src/index.css @@ -1,2 +1,2 @@ -@import "tailwindcss"; \ No newline at end of file +@import "tailwindcss"; diff --git a/vite-app/src/types/README.md b/vite-app/src/types/README.md index b86b1ceb..7a605ea4 100644 --- a/vite-app/src/types/README.md +++ b/vite-app/src/types/README.md @@ -90,7 +90,7 @@ import { EvaluationRowSchema } from '@/types'; async function fetchEvaluationData(): Promise { const response = await fetch('/api/evaluation'); const data = await response.json(); - + // Validate the response return EvaluationRowSchema.parse(data); } @@ -140,4 +140,4 @@ The TypeScript types closely mirror the Python Pydantic models: - `Optional[T]` โ†’ `z.optional()` - `List[T]` โ†’ `z.array()` - `Dict[str, Any]` โ†’ `z.record(z.any())` -- `extra="allow"` โ†’ `.passthrough()` \ No newline at end of file +- `extra="allow"` โ†’ `.passthrough()` diff --git a/vite-app/vite.config.ts b/vite-app/vite.config.ts index 78f64d17..8038c4fb 100644 --- a/vite-app/vite.config.ts +++ b/vite-app/vite.config.ts @@ -14,4 +14,4 @@ export default defineConfig({ port: 5173, host: true } -}) \ No newline at end of file +})