Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 29 additions & 27 deletions eval_protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,17 @@
tool-augmented models using self-contained task bundles.
"""

from importlib import import_module
from typing import Any
import warnings

from eval_protocol.adapters.braintrust import reward_fn_to_scorer, scorer_to_reward_fn

# Lightweight imports (no heavy optional dependencies)
from .integrations.braintrust import reward_fn_to_scorer, scorer_to_reward_fn
from .auth import get_fireworks_account_id, get_fireworks_api_key
from .common_utils import load_jsonl
from .config import RewardKitConfig, get_config, load_config
from .mcp_env import (
AnthropicPolicy,
FireworksPolicy,
LiteLLMPolicy,
OpenAIPolicy,
make,
rollout,
test_mcp,
)

# Try to import FireworksPolicy if available
try:
from .mcp_env import FireworksPolicy

_FIREWORKS_AVAILABLE = True
except (ImportError, AttributeError):
_FIREWORKS_AVAILABLE = False
# Import submodules to make them available via eval_protocol.rewards, etc.
from . import mcp, rewards
from .models import EvaluateResult, Message, MetricResult
from .playback_policy import PlaybackPolicyBase
from .resources import create_llm_resource
Expand All @@ -42,6 +27,7 @@

warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")

# Public API (static exports only; dynamic MCP symbols are provided via __getattr__)
__all__ = [
# Core interfaces
"Message",
Expand All @@ -60,14 +46,6 @@
"RewardKitConfig",
# Utilities
"load_jsonl",
# MCP Environment API
"make",
"rollout",
"LiteLLMPolicy",
"AnthropicPolicy",
"FireworksPolicy",
"OpenAIPolicy",
"test_mcp",
# Playback functionality
"PlaybackPolicyBase",
# Resource management
Expand All @@ -77,6 +55,30 @@
"mcp",
]


def __getattr__(name: str) -> Any:
"""Lazily import heavy MCP environment symbols to speed up package import.

This defers importing modules that depend on optional or heavy dependencies
(e.g., vendored tau2, OpenAI clients) until they are actually used.
"""
if name in {
"make",
"rollout",
"LiteLLMPolicy",
"AnthropicPolicy",
"FireworksPolicy",
"OpenAIPolicy",
"test_mcp",
}:
m = import_module(".mcp_env", __name__)
return getattr(m, name)
if name in {"mcp", "rewards"}:
# Lazy-load subpackages for attribute access like eval_protocol.mcp
return import_module(f".{name}", __name__)
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")


from . import _version

__version__ = _version.get_versions()["version"]
50 changes: 18 additions & 32 deletions eval_protocol/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,24 @@
- TRL integration (legacy)
"""

# Conditional imports based on available dependencies
try:
from .langfuse import LangfuseAdapter, create_langfuse_adapter
__all__ = ["LangfuseAdapter", "create_langfuse_adapter"]
except ImportError:
__all__ = []
from importlib import import_module
from typing import Any

try:
from .huggingface import (
HuggingFaceAdapter,
create_huggingface_adapter,
create_gsm8k_adapter,
create_math_adapter,
)
__all__.extend([
"HuggingFaceAdapter",
"create_huggingface_adapter",
"create_gsm8k_adapter",
"create_math_adapter",
])
except ImportError:
pass
__all__ = []

# Legacy adapters (always available)
try:
from .braintrust import reward_fn_to_scorer, scorer_to_reward_fn
__all__.extend(["scorer_to_reward_fn", "reward_fn_to_scorer"])
except ImportError:
pass

try:
from .trl import create_trl_adapter
__all__.extend(["create_trl_adapter"])
except ImportError:
pass
def __getattr__(name: str) -> Any:
# Lazy import optional adapters to avoid import-time side effects and heavy deps
if name in {"LangfuseAdapter", "create_langfuse_adapter"}:
m = import_module(".langfuse", __name__)
return getattr(m, name)
if name in {"HuggingFaceAdapter", "create_huggingface_adapter", "create_gsm8k_adapter", "create_math_adapter"}:
m = import_module(".huggingface", __name__)
return getattr(m, name)
if name in {"reward_fn_to_scorer", "scorer_to_reward_fn"}:
m = import_module(".braintrust", __name__)
return getattr(m, name)
if name in {"create_trl_adapter"}:
m = import_module(".trl", __name__)
return getattr(m, name)
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
6 changes: 3 additions & 3 deletions vendor/tau2/utils/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from litellm.caching.caching import Cache
from litellm.main import ModelResponse, Usage
from loguru import logger
import os

from vendor.tau2.config import (
DEFAULT_LLM_CACHE_TYPE,
Expand Down Expand Up @@ -70,9 +71,8 @@


ALLOW_SONNET_THINKING = False

if not ALLOW_SONNET_THINKING:
logger.warning("Sonnet thinking is disabled")
if os.getenv("TAU2_VERBOSE") == "1" and not ALLOW_SONNET_THINKING:
logger.info("Sonnet thinking is disabled")


def _parse_ft_model_name(model: str) -> str:
Expand Down
12 changes: 8 additions & 4 deletions vendor/tau2/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from dotenv import load_dotenv
from loguru import logger

_TAU2_VERBOSE = os.getenv("TAU2_VERBOSE") == "1"

res = load_dotenv()
if not res:
if not res and _TAU2_VERBOSE:
logger.warning("No .env file found")

# Try to get data directory from environment variable first
Expand All @@ -19,15 +21,17 @@
if DATA_DIR_ENV:
# Use environment variable if set
DATA_DIR = Path(DATA_DIR_ENV)
logger.info(f"Using data directory from environment: {DATA_DIR}")
if _TAU2_VERBOSE:
logger.info(f"Using data directory from environment: {DATA_DIR}")
else:
# Fallback to vendored tau2 directory
SOURCE_DIR = Path(__file__).parents[1] # vendor/tau2/
DATA_DIR = SOURCE_DIR / "data"
logger.info(f"Using data directory from vendored tau2: {DATA_DIR}")
if _TAU2_VERBOSE:
logger.info(f"Using data directory from vendored tau2: {DATA_DIR}")

# Check if data directory exists and is accessible
if not DATA_DIR.exists():
if not DATA_DIR.exists() and _TAU2_VERBOSE:
logger.warning(f"Data directory does not exist: {DATA_DIR}")
logger.warning(
"Set TAU2_DATA_DIR environment variable to point to your data directory"
Expand Down
Loading