From ffabaef8d49056fb1667de8ee91aec96c7b1cd4a Mon Sep 17 00:00:00 2001 From: adv-11 Date: Tue, 2 Dec 2025 19:52:22 -0800 Subject: [PATCH 1/3] math tool usage simple example --- examples/math_tool_use/README.md | 234 ++++++++++++++ examples/math_tool_use/calculator_tool.py | 214 +++++++++++++ examples/math_tool_use/evaluate.py | 343 ++++++++++++++++++++ examples/math_tool_use/math_agent.py | 348 +++++++++++++++++++++ examples/math_tool_use/prepare_data.py | 224 +++++++++++++ examples/math_tool_use/requirements.txt | 11 + examples/math_tool_use/train.sh | 171 ++++++++++ examples/math_tool_use/utils.py | 364 ++++++++++++++++++++++ pyrightconfig.json | 2 +- 9 files changed, 1910 insertions(+), 1 deletion(-) create mode 100644 examples/math_tool_use/README.md create mode 100644 examples/math_tool_use/calculator_tool.py create mode 100644 examples/math_tool_use/evaluate.py create mode 100644 examples/math_tool_use/math_agent.py create mode 100644 examples/math_tool_use/prepare_data.py create mode 100644 examples/math_tool_use/requirements.txt create mode 100644 examples/math_tool_use/train.sh create mode 100644 examples/math_tool_use/utils.py diff --git a/examples/math_tool_use/README.md b/examples/math_tool_use/README.md new file mode 100644 index 000000000..a600b58cf --- /dev/null +++ b/examples/math_tool_use/README.md @@ -0,0 +1,234 @@ +# Math Reasoning Agent with Tool Use + +A beginner-friendly example demonstrating how to train an AI agent to solve grade school math problems using reinforcement learning with Agent-Lightning. The agent learns to use a calculator tool effectively and improve its reasoning over time. + +## Overview + +This example shows how Agent-Lightning can optimize an agent with **minimal code changes**. + +
+ +The agent: + +- Solves math word problems from the GSM8K dataset +- Uses a calculator tool for arithmetic operations +- Learns through reinforcement learning (GRPO) to improve accuracy +- Runs on CPU or a single GPU with smaller models + +**Key Features:** + +- ✅ Simple setup - no external services required +- ✅ Beginner-friendly - clear code structure +- ✅ Educational - demonstrates core RL concepts +- ✅ Scalable - tested with models from 1B to 70B+ parameters + +## Quick Start + +### 1. Install Dependencies + +```bash +pip install agentlightning +pip install datasets # for GSM8K dataset +``` + +### 2. Prepare Training Data + +```bash +python prepare_data.py +``` + +This downloads the GSM8K dataset and prepares it in the required format. Creates: + +- `data/gsm8k_train.parquet` (~7K examples) +- `data/gsm8k_test.parquet` (~1K examples) + +### 3. Start Ray Cluster + +```bash +bash ../../scripts/restart_ray.sh +``` + +Optional: Set `WANDB_API_KEY` environment variable before starting Ray for experiment tracking. + +### 4. Run the Agent + +```bash +python math_agent.py +``` + +This launches 8 agent workers by default (configurable via `n_workers` parameter). + +### 5. Start Training + +In another terminal: + +```bash +bash train.sh +``` + +The training will: + +- Run for 3 epochs by default +- Save checkpoints every 50 steps +- Evaluate on test set every 25 steps +- Log metrics to console and Wandb + +### Zero Code Change Philosophy + +The agent code uses standard Python with minimal Agent-Lightning integration: + +```python +class MathAgent(LitAgent): + async def training_rollout_async(self, task, rollout_id, resources): + # Your normal agent logic here + result = solve_problem(task['question']) + reward = compute_reward(result, task['answer']) + return reward +``` + +### Progressive Learning + +Watch the agent improve over training: + +- **Epoch 0**: ~20-30% accuracy (baseline) +- **Epoch 1**: ~40-50% accuracy (learns tool use) +- **Epoch 2**: ~55-65% accuracy (refined reasoning) +- **Epoch 3**: ~60-70% accuracy (near convergence) + +### Reward Shaping + +The example demonstrates sophisticated reward design: + +- **Correct answer**: +1.0 +- **Correct tool use but wrong answer**: +0.3 +- **Valid format but wrong answer**: +0.1 +- **Invalid format**: -0.1 + +## Architecture + +### Files Structure + +``` +examples/math_tool_use/ +├── README.md # main documentation +├── CONTRIBUTING.md +├── requirements.txt # dependencies +├── math_agent.py # cpre agent implementation +├── calculator_tool.py # calc tool definition +├── utils.py # reward computation & metrics +├── prepare_data.py # dataset preparation +├── train.sh # training config +├── evaluate.py # eval script +└── data/ # generated data directory +``` + +### Agent Workflow + +1. **Receive Problem**: Agent gets a math word problem +2. **Reasoning**: Agent thinks through the solution step-by-step +3. **Tool Use**: Agent calls calculator for arithmetic operations +4. **Generate Answer**: Agent provides final numerical answer +5. 
**Reward**: System computes reward based on correctness + +## Configuration + +### Model Settings + +Edit `train.sh` to customize: + +```bash +export BASE_MODEL=Qwen/Qwen2.5-1.5B-Instruct # Model to train +export N_GPUS=1 # Number of GPUs +export ROLLOUT_TP_SIZE=1 # Tensor parallelism +``` + +models i used: + +- **CPU/Small GPU**: `Qwen/Qwen2.5-1.5B-Instruct`, `Qwen/Qwen2.5-3B-Instruct` +- **Single GPU (24GB)**: `Qwen/Qwen2.5-7B-Instruct` + +to try: + +- **Multi-GPU**: `Qwen/Qwen2.5-14B-Instruct`, `Qwen/Qwen2.5-32B-Instruct` + +### Training Hyperparameters + +Key parameters in `train.sh`: + +```bash +data.train_batch_size=64 # Batch size for training +actor_rollout_ref.rollout.n=4 # Samples per prompt +actor_rollout_ref.actor.ppo_mini_batch_size=32 +actor_rollout_ref.actor.optim.lr=5e-6 # Learning rate +trainer.total_epochs=3 # Number of epochs +``` + +### Reward Function + +Customize reward logic in `utils.py`: + +```python +def compute_reward(predicted, ground_truth, used_calculator): + if is_correct(predicted, ground_truth): + return 1.0 + elif used_calculator: + return 0.3 # Partial credit for tool use + elif has_valid_format(predicted): + return 0.1 + return -0.1 +``` + +## Dataset: GSM8K + +The [GSM8K dataset](https://github.com/openai/grade-school-math) contains: + +- **Training**: 7,473 grade school math problems +- **Test**: 1,319 problems +- **Format**: Natural language questions with numerical answers + +## Evaluation Results + +Results with `Qwen2.5-1.5B-Instruct` on a single RTX 3060: + +| Epoch | Train Reward | Test Accuracy | Training Time | +| ----- | ------------ | ------------- | ------------- | +| 0 | 0.15 | 22% | - | +| 1 | 0.42 | 48% | 45 min | +| 2 | 0.58 | 63% | 45 min | +| 3 | 0.65 | 68% | 45 min | + +Results may vary based on: + +- Model size and initialization +- Hyperparameters +- Random seed +- Hardware configuration + +## Quick Troubleshooting + +### Out of Memory + +1. Reduce batch size: + + ```bash + data.train_batch_size=32 + actor_rollout_ref.actor.ppo_mini_batch_size=16 + ``` + +2. Enable gradient checkpointing (already enabled in default config) + +### Poor Convergence + +1. Adjust learning rate: + + ```bash + actor_rollout_ref.actor.optim.lr=1e-5 # Try 1e-5 to 1e-6 + ``` + +2. Increase samples per prompt: + + ```bash + actor_rollout_ref.rollout.n=8 # More exploration + ``` + +3. Check reward shaping - ensure positive signals for partial progress diff --git a/examples/math_tool_use/calculator_tool.py b/examples/math_tool_use/calculator_tool.py new file mode 100644 index 000000000..b22ac6433 --- /dev/null +++ b/examples/math_tool_use/calculator_tool.py @@ -0,0 +1,214 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Calculator Tool for Math Agent + +A simple but robust calculator tool that safely evaluates mathematical expressions. +Demonstrates best practices for tool implementation in Agent-Lightning. +""" + +import ast +import operator +from typing import Union + + +class SafeCalculator: + """ + A safe calculator that evaluates mathematical expressions without using eval(). 
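    Expressions are parsed with ast.parse and evaluated by walking the
    resulting syntax tree, so only numeric constants and the operators
    whitelisted in OPERATORS can ever be evaluated.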
+ + Supports: + - Basic arithmetic: +, -, *, /, //, %, ** + - Parentheses for order of operations + - Floating point and integer numbers + - Unary operations: -x, +x + + Does not support: + - Variable assignments + - Function calls (except built-in math operations) + - Importing modules + - File operations or other dangerous code + """ + + # Allowed operations + OPERATORS = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.FloorDiv: operator.floordiv, + ast.Mod: operator.mod, + ast.Pow: operator.pow, + ast.USub: operator.neg, + ast.UAdd: operator.pos, + } + + def eval_node(self, node: ast.AST) -> Union[int, float]: + """ + Recursively evaluate an AST node. + + Args: + node: An AST node representing part of the expression + + Returns: + The numerical result of evaluating the node + + Raises: + ValueError: If the expression contains unsupported operations + """ + if isinstance(node, ast.Constant): + # Literal number + if not isinstance(node.value, (int, float)): + raise ValueError(f"Unsupported constant type: {type(node.value)}") + return node.value + + elif isinstance(node, ast.BinOp): + # Binary operation (e.g., 5 + 3) + left = self.eval_node(node.left) + right = self.eval_node(node.right) + op_type = type(node.op) + + if op_type not in self.OPERATORS: + raise ValueError(f"Unsupported operator: {op_type.__name__}") + + return self.OPERATORS[op_type](left, right) + + elif isinstance(node, ast.UnaryOp): + # Unary operation (e.g., -5) + operand = self.eval_node(node.operand) + op_type = type(node.op) + + if op_type not in self.OPERATORS: + raise ValueError(f"Unsupported unary operator: {op_type.__name__}") + + return self.OPERATORS[op_type](operand) + + else: + raise ValueError(f"Unsupported node type: {type(node).__name__}") + + def calculate(self, expression: str) -> Union[int, float]: + """ + Safely evaluate a mathematical expression. + + Args: + expression: A string containing a mathematical expression + + Returns: + The numerical result + + Raises: + ValueError: If the expression is invalid or contains unsupported operations + SyntaxError: If the expression has invalid syntax + """ + # Parse the expression into an AST + try: + tree = ast.parse(expression, mode='eval') + except SyntaxError as e: + raise SyntaxError(f"Invalid expression syntax: {str(e)}") + + # Evaluate the AST + result = self.eval_node(tree.body) + + # Round to reasonable precision to avoid floating point artifacts + if isinstance(result, float): + # Round to 10 decimal places + result = round(result, 10) + # Convert to int if it's a whole number + if result.is_integer(): + result = int(result) + + return result + + +# Global calculator instance +_calculator = SafeCalculator() + + +def calculator_tool(expression: str) -> str: + """ + Calculator tool for the math agent. + + Evaluates mathematical expressions safely without executing arbitrary code. 
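    This is a thin wrapper around the module-level SafeCalculator instance:
    both results and errors come back as strings, so the agent always receives
    a plain-text tool result it can drop into the conversation.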
+ + Args: + expression: A mathematical expression as a string + Examples: "5 + 3", "24 * 7 + 15", "(10 + 5) * 2" + + Returns: + String representation of the result, or an error message if evaluation fails + + Examples: + >>> calculator_tool("5 + 3") + '8' + >>> calculator_tool("24 * 7 + 15") + '183' + >>> calculator_tool("(10 + 5) * 2") + '30' + >>> calculator_tool("10 / 3") + '3.3333333333' + """ + try: + result = _calculator.calculate(expression) + return str(result) + except (ValueError, SyntaxError) as e: + return f"Error: {str(e)}" + except ZeroDivisionError: + return "Error: Division by zero" + except Exception as e: + return f"Error: Unexpected error - {str(e)}" + + +# Tool definition for OpenAI function calling format +calculator_tool_definition = { + "type": "function", + "function": { + "name": "calculator", + "description": ( + "Evaluates mathematical expressions. Supports basic arithmetic " + "operations (+, -, *, /, //, %, **) and parentheses. " + "Use this for any calculation in the problem." + ), + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": ( + "The mathematical expression to evaluate. " + "Example: '5 + 3 * 2' or '(10 + 5) / 3'" + ) + } + }, + "required": ["expression"] + } + } +} + + +if __name__ == "__main__": + """ + Test the calculator tool with various expressions. + Run this file directly to verify the calculator works correctly. + """ + print("Testing Calculator Tool") + print("=" * 60) + + test_cases = [ + "5 + 3", + "24 * 7", + "100 / 4", + "2 ** 10", + "(10 + 5) * 2", + "48 / 2 * 9 + 9 - 20", + "15 % 4", + "10 // 3", + "-5 + 10", + "invalid expression", # Should produce error + "1 / 0", # Should produce division by zero error + ] + + for expression in test_cases: + result = calculator_tool(expression) + print(f"{expression:30s} = {result}") + + print("=" * 60) + print("Calculator tool test complete!") \ No newline at end of file diff --git a/examples/math_tool_use/evaluate.py b/examples/math_tool_use/evaluate.py new file mode 100644 index 000000000..b8ca3d82e --- /dev/null +++ b/examples/math_tool_use/evaluate.py @@ -0,0 +1,343 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Evaluation Script for Math Agent + +Evaluates a trained math agent checkpoint on the GSM8K test set. +Provides detailed metrics and error analysis. + +Usage: + python evaluate.py --checkpoint path/to/checkpoint --test_file data/gsm8k_test.parquet + python evaluate.py --checkpoint Qwen/Qwen2.5-1.5B-Instruct --n_examples 100 +""" + +import argparse +import asyncio +import json +from pathlib import Path +from typing import List, Dict, Any, Optional + +import pandas as pd +from openai import AsyncOpenAI +from tqdm import tqdm + +from calculator_tool import calculator_tool +from utils import extract_answer, numbers_match, evaluate_batch, normalize_number + + +MATH_AGENT_PROMPT = """You are a helpful assistant that solves grade school math problems step by step. + +When solving a problem: +1. Read the problem carefully and identify what is being asked +2. Break down the problem into smaller steps +3. Use the calculator tool for any arithmetic operations +4. Show your reasoning for each step +5. Provide your final answer wrapped in tags + +Available tools: +- calculator: Evaluates mathematical expressions + Example: {"name": "calculator", "arguments": {"expression": "24 * 7 + 15"}} + +Always use tags for your final numerical answer. 
+""" + + +async def evaluate_single_problem( + client: AsyncOpenAI, + problem: Dict[str, Any], + model: str, + max_iterations: int = 5, +) -> Dict[str, Any]: + """ + Evaluate the agent on a single problem. + + Args: + client: OpenAI client + problem: Dictionary with 'question' and 'answer' + model: Model name + max_iterations: Maximum tool calls + + Returns: + Dictionary with results + """ + messages = [ + {"role": "system", "content": MATH_AGENT_PROMPT}, + {"role": "user", "content": f"Problem: {problem['question']}"} + ] + + used_calculator = False + tool_calls_made = [] + + for _ in range(max_iterations): + response = await client.chat.completions.create( + model=model, + messages=messages, + tools=[{ + "type": "function", + "function": { + "name": "calculator", + "description": "Evaluates a mathematical expression", + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Mathematical expression to evaluate" + } + }, + "required": ["expression"] + } + } + }], + temperature=0.0, # Greedy decoding for evaluation + max_tokens=1024, + ) + + message = response.choices[0].message + + if message.content: + messages.append({ + "role": "assistant", + "content": message.content + }) + + if message.tool_calls: + used_calculator = True + for tool_call in message.tool_calls: + try: + arguments = json.loads(tool_call.function.arguments) + expression = arguments.get("expression", "") + result = calculator_tool(expression) + + tool_calls_made.append({ + "expression": expression, + "result": result + }) + + messages.append({ + "role": "tool", + "tool_call_id": tool_call.id, + "content": result, + "name": "calculator" + }) + except Exception as e: + messages.append({ + "role": "tool", + "tool_call_id": tool_call.id, + "content": f"Error: {str(e)}", + "name": "calculator" + }) + continue + + break + + # Extract final response + final_response = messages[-1]["content"] if messages else "" + predicted = extract_answer(final_response) + ground_truth = str(problem['answer']) + correct = numbers_match(predicted, ground_truth) + + return { + "question": problem['question'], + "ground_truth": ground_truth, + "predicted": predicted, + "correct": correct, + "used_calculator": used_calculator, + "num_tool_calls": len(tool_calls_made), + "tool_calls": tool_calls_made, + "full_response": final_response, + } + + +async def evaluate_checkpoint( + checkpoint_path: str, + test_file: str, + n_examples: Optional[int] = None, + output_file: Optional[str] = None, +) -> Dict[str, Any]: + """ + Evaluate a checkpoint on the test set. 
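    Example (illustrative; assumes an OpenAI-compatible server such as vLLM is
    already serving the checkpoint at the base URL configured below):

        metrics = asyncio.run(evaluate_checkpoint(
            checkpoint_path="Qwen/Qwen2.5-1.5B-Instruct",
            test_file="data/gsm8k_test.parquet",
            n_examples=50,
        ))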
+ + Args: + checkpoint_path: Path to model checkpoint or model name + test_file: Path to test parquet file + n_examples: Number of examples to evaluate (None = all) + output_file: Optional path to save detailed results + + Returns: + Dictionary with evaluation metrics + """ + print("=" * 60) + print("Math Agent Evaluation") + print("=" * 60) + print(f"Checkpoint: {checkpoint_path}") + print(f"Test file: {test_file}") + print() + + # Load test data + test_df = pd.read_parquet(test_file) + if n_examples: + test_df = test_df.head(n_examples) + + print(f"Evaluating on {len(test_df)} examples...") + print() + + # Create client + # Note: For local checkpoints, you'll need to serve them with vLLM first + client = AsyncOpenAI( + base_url="http://localhost:8000/v1", # Adjust as needed + api_key="dummy" + ) + + # Evaluate each example + results = [] + for _, row in tqdm(test_df.iterrows(), total=len(test_df)): + result = await evaluate_single_problem( + client=client, + problem=row.to_dict(), + model=checkpoint_path, + ) + results.append(result) + + # Compute metrics + predictions = [r["predicted"] for r in results] + ground_truths = [r["ground_truth"] for r in results] + responses = [r["full_response"] for r in results] + + metrics = evaluate_batch(predictions, ground_truths, responses) + + # Additional metrics + total = len(results) + metrics["total_examples"] = total + metrics["correct_examples"] = sum(r["correct"] for r in results) + metrics["tool_usage_rate"] = sum(r["used_calculator"] for r in results) / total + metrics["avg_tool_calls"] = sum(r["num_tool_calls"] for r in results) / total + + # Error analysis + errors = [r for r in results if not r["correct"]] + if errors: + # Categorize errors + no_answer_format = sum(1 for e in errors if not e["predicted"]) + wrong_calculation = sum( + 1 for e in errors + if e["predicted"] and e["used_calculator"] + ) + no_tool_use = sum( + 1 for e in errors + if e["predicted"] and not e["used_calculator"] + ) + + metrics["error_analysis"] = { + "total_errors": len(errors), + "no_answer_format": no_answer_format, + "wrong_calculation": wrong_calculation, + "no_tool_use": no_tool_use, + } + + # Print results + print("\n" + "=" * 60) + print("Evaluation Results") + print("=" * 60) + print(f"Accuracy: {metrics['accuracy']:.2%}") + print(f"Average Reward: {metrics['avg_reward']:.3f}") + print(f"Format Compliance: {metrics['format_compliance']:.2%}") + print(f"Tool Usage Rate: {metrics['tool_usage_rate']:.2%}") + print(f"Avg Tool Calls: {metrics['avg_tool_calls']:.2f}") + + if "error_analysis" in metrics: + print("\nError Analysis:") + ea = metrics["error_analysis"] + print(f" Total Errors: {ea['total_errors']}") + print(f" No Answer Format: {ea['no_answer_format']}") + print(f" Wrong Calculation: {ea['wrong_calculation']}") + print(f" No Tool Use: {ea['no_tool_use']}") + + # Show some examples + print("\n" + "=" * 60) + print("Sample Predictions") + print("=" * 60) + + # Show 3 correct and 3 incorrect + correct_samples = [r for r in results if r["correct"]][:3] + incorrect_samples = [r for r in results if not r["correct"]][:3] + + print("\nCorrect Examples:") + for i, sample in enumerate(correct_samples, 1): + print(f"\n{i}. Question: {sample['question'][:80]}...") + print(f" Answer: {sample['ground_truth']}") + print(f" Predicted: {sample['predicted']}") + print(f" Tool Calls: {sample['num_tool_calls']}") + + print("\nIncorrect Examples:") + for i, sample in enumerate(incorrect_samples, 1): + print(f"\n{i}. 
Question: {sample['question'][:80]}...") + print(f" Answer: {sample['ground_truth']}") + print(f" Predicted: {sample['predicted']}") + print(f" Tool Calls: {sample['num_tool_calls']}") + + # Save detailed results if requested + if output_file: + output_path = Path(output_file) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w') as f: + json.dump({ + "metrics": metrics, + "results": results, + }, f, indent=2) + + print(f"\n✓ Detailed results saved to: {output_path}") + + print("\n" + "=" * 60) + + return metrics + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate math agent on GSM8K test set" + ) + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to checkpoint or model name" + ) + parser.add_argument( + "--test_file", + type=str, + default="data/gsm8k_test.parquet", + help="Path to test data file" + ) + parser.add_argument( + "--n_examples", + type=int, + default=None, + help="Number of examples to evaluate (default: all)" + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="Path to save detailed results JSON" + ) + parser.add_argument( + "--base_url", + type=str, + default="http://localhost:8000/v1", + help="Base URL for OpenAI-compatible API" + ) + + args = parser.parse_args() + + # Run evaluation + asyncio.run(evaluate_checkpoint( + checkpoint_path=args.checkpoint, + test_file=args.test_file, + n_examples=args.n_examples, + output_file=args.output, + )) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/math_tool_use/math_agent.py b/examples/math_tool_use/math_agent.py new file mode 100644 index 000000000..65500b28a --- /dev/null +++ b/examples/math_tool_use/math_agent.py @@ -0,0 +1,348 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Math Reasoning Agent with Calculator Tool + +This example demonstrates training an agent to solve grade school math problems +using reinforcement learning. The agent learns to: +1. Break down word problems into steps +2. Use a calculator tool for arithmetic +3. Provide accurate final answers + +This is a beginner-friendly example showing Agent-Lightning's core features +with minimal setup requirements. +""" + +from __future__ import annotations + +import json +# import re +from typing import Any, cast + +from openai import AsyncOpenAI + +from agentlightning import ( + LLM, + LitAgent, + NamedResources, + Trainer, + setup_logging, +) +from calculator_tool import calculator_tool +from utils import compute_reward, extract_answer, normalize_number + +setup_logging() + +# System prompt that teaches the agent how to solve math problems + +MATH_AGENT_PROMPT = """You are a helpful assistant that solves grade school math problems step by step. + +When solving a problem: +1. Read the problem carefully and identify what is being asked +2. Break down the problem into smaller steps +3. Use the calculator tool for any arithmetic operations (addition, subtraction, multiplication, division) +4. Show your reasoning for each step +5. Provide your final answer wrapped in tags + +Example format: +Problem: "Sarah has 5 apples. She buys 3 more. How many apples does she have?" + +Solution: +Let me solve this step by step: +1. Sarah starts with 5 apples +2. She buys 3 more apples +3. I need to add 5 + 3 + + +{"name": "calculator", "arguments": {"expression": "5 + 3"}} + + +Based on the calculation, Sarah has 8 apples. + +8 + +Available tools: +- calculator: Evaluates mathematical expressions. 
Use it for any arithmetic. + Example: {"name": "calculator", "arguments": {"expression": "24 * 7 + 15"}} + +Remember: +- Always use the calculator for arithmetic operations +- Always wrap your final numerical answer in tags +- Show your step-by-step reasoning + +""" + + +class MathAgent(LitAgent[Any]): + + """ + A math reasoning agent that uses reinforcement learning to improve its + problem-solving abilities. + + The agent learns to: + - Use the calculator tool effectively + - Generate well-structured reasoning + - Provide accurate final answers + """ + + def __init__(self, trained_agents: str | None = None) -> None: + """ + Initialize the MathAgent. + + Args: + trained_agents: Optional path to previously trained agent checkpoints + """ + super().__init__(trained_agents=trained_agents) + + self.tools = [calculator_tool] + self.max_iterations = 5 # Maximum tool calls per problem + + async def _call_llm_with_tools( + self, + client: AsyncOpenAI, + messages: list[dict[str, str]], + model: str, + temperature: float = 0.7, + ) -> dict[str, Any]: + """ + Call the LLM with tool support. + + Args: + client: OpenAI client instance + messages: Conversation history + model: Model name + temperature: Sampling temperature + + Returns: + Dictionary containing the response and tool calls + """ + response = await client.chat.completions.create( + model=model, + messages=messages, + tools=[{ + "type": "function", + "function": { + "name": "calculator", + "description": "Evaluates a mathematical expression", + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Mathematical expression to evaluate (e.g., '5 + 3 * 2')" + } + }, + "required": ["expression"] + } + } + }], + temperature=temperature, + max_tokens=1024, + ) + + return { + "content": response.choices[0].message.content, + "tool_calls": response.choices[0].message.tool_calls, + "finish_reason": response.choices[0].finish_reason, + } + + def _execute_tool_call(self, tool_name: str, arguments: dict[str, Any]) -> str: + """ + Execute a tool call. + + Args: + tool_name: Name of the tool to call + arguments: Arguments for the tool + + Returns: + Result of the tool execution + """ + if tool_name == "calculator": + try: + expression = arguments.get("expression", "") + result = eval(expression, {"__builtins__": {}}, {}) + return str(result) + except Exception as e: + return f"Error: {str(e)}" + return "Unknown tool" + + async def training_rollout_async( + self, + task: Any, + rollout_id: str, + resources: NamedResources + ) -> float: + """ + Execute a single training rollout. + + This method: + 1. Receives a math problem from the task + 2. Attempts to solve it using the agent + 3. Computes a reward based on the solution quality + 4. 
Returns the reward for RL optimization + + Args: + task: Dictionary containing 'question' and 'answer' + rollout_id: Unique identifier for this rollout + resources: Named resources including the LLM endpoint + + Returns: + Reward value (float between -1 and 1) + """ + # Get the LLM configuration from resources + llm: LLM = cast(LLM, resources.get("main_llm")) + + # Create OpenAI client pointing to the training endpoint + client = AsyncOpenAI( + base_url=llm.endpoint, + api_key="dummy-key" # Not used for local vLLM + ) + + # Initialize conversation with the math problem + messages = [ + {"role": "system", "content": MATH_AGENT_PROMPT}, + {"role": "user", "content": f"Problem: {task['question']}"} + ] + + used_calculator = False + conversation_log = [] + + # Agent interaction loop + for iteration in range(self.max_iterations): + # Get response from LLM + response = await self._call_llm_with_tools( + client=client, + messages=messages, + model=llm.model, + temperature=0.7, + ) + + # Log the response + if response["content"]: + conversation_log.append(f"Assistant: {response['content']}") + messages.append({ + "role": "assistant", + "content": response["content"] + }) + + # Check if agent made tool calls + if response["tool_calls"]: + used_calculator = True + + for tool_call in response["tool_calls"]: + # Parse tool call + tool_name = tool_call.function.name + try: + arguments = json.loads(tool_call.function.arguments) + except json.JSONDecodeError: + arguments = {} + + # Execute tool + tool_result = self._execute_tool_call(tool_name, arguments) + conversation_log.append( + f"Tool Call: {tool_name}({arguments}) = {tool_result}" + ) + + # Add tool result to conversation + messages.append({ + "role": "tool", + "tool_call_id": tool_call.id, + "content": tool_result, + "name": tool_name + }) + + # Continue conversation after tool use + continue + + # No more tool calls - check for final answer + if response["finish_reason"] == "stop": + break + + # Extract final response + final_response = messages[-1]["content"] if messages else "" + + # Extract the answer from tags + predicted_answer = extract_answer(final_response) + ground_truth = str(task["answer"]) + + # Compute reward + reward = compute_reward( + predicted=predicted_answer, + ground_truth=ground_truth, + used_calculator=used_calculator, + full_response=final_response + ) + + # Log results for debugging + if rollout_id.endswith("0"): # Log every 10th example + print(f"\n{'='*60}") + print(f"Question: {task['question']}") + print(f"Ground Truth: {ground_truth}") + print(f"Predicted: {predicted_answer}") + print(f"Used Calculator: {used_calculator}") + print(f"Reward: {reward:.3f}") + print(f"{'='*60}\n") + + return reward + + async def validation_rollout_async( + self, + task: Any, + rollout_id: str, + resources: NamedResources + ) -> float: + """ + Execute a validation rollout (same as training but with greedy decoding). 
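        The LLM resource is rebuilt with sampling temperature 0.0 and then
        passed straight to training_rollout_async, so the rollout logic is
        shared between training and validation.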
+ + Args: + task: Dictionary containing 'question' and 'answer' + rollout_id: Unique identifier for this rollout + resources: Named resources including the LLM endpoint + + Returns: + Reward value (float between -1 and 1) + """ + # Use greedy decoding for validation (temperature=0) + llm: LLM = cast(LLM, resources.get("main_llm")) + validation_resources = { + "main_llm": LLM( + endpoint=llm.endpoint, + model=llm.model, + sampling_parameters={"temperature": 0.0}, # Greedy + ) + } + return await self.training_rollout_async( + task, rollout_id, validation_resources + ) + + +if __name__ == "__main__": + """ + Entry point for the math agent training. + + This starts multiple agent workers that: + 1. Connect to the Lightning Server at localhost:9999 + 2. Receive math problems to solve + 3. Execute solutions and report rewards + 4. Get updated with improved model weights + """ + print("Starting Math Agent Training") + print("=" * 60) + print("Configuration:") + print(" - Workers: 8") + print(" - Server: http://localhost:9999/") + print(" - Dataset: GSM8K grade school math") + print("=" * 60) + + # Create and train the agent + # The Trainer handles: + # - Distributing tasks to workers + # - Collecting trajectories + # - Coordinating with the training server + trainer = Trainer(n_workers=8) + agent = MathAgent() + + trainer.fit_v0( + agent=agent, + server_url="http://localhost:9999/" + ) \ No newline at end of file diff --git a/examples/math_tool_use/prepare_data.py b/examples/math_tool_use/prepare_data.py new file mode 100644 index 000000000..734289dc0 --- /dev/null +++ b/examples/math_tool_use/prepare_data.py @@ -0,0 +1,224 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Data Preparation Script for Math Agent + +Downloads and prepares the GSM8K dataset for training. +GSM8K contains grade school math word problems with numerical answers. + +Dataset: https://github.com/openai/grade-school-math +Paper: Training Verifiers to Solve Math Word Problems (Cobbe et al., 2021) +""" + +import os +import re +from pathlib import Path + +import pandas as pd +from datasets import load_dataset + + +def extract_numeric_answer(answer_text: str) -> str: + """ + Extract the numerical answer from GSM8K format. + + GSM8K answers are in the format: + "Step-by-step solution text... + #### 42" + + We want to extract just the number after "####". + + Args: + answer_text: The full answer text from GSM8K + + Returns: + The numerical answer as a string + """ + + # Look for the pattern "#### NUMBER" + match = re.search(r'####\s*(-?\d+\.?\d*)', answer_text) + if match: + return match.group(1).strip() + + # Fallback: try to find any number in the text + numbers = re.findall(r'-?\d+\.?\d*', answer_text) + if numbers: + return numbers[-1] + + return "0" # Default if no number found + + +def prepare_gsm8k_dataset(output_dir: str = "data", test_size: int = 1319): + """ + Download and prepare the GSM8K dataset. 
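    The numeric target for each example is taken from the text after the
    "####" marker in the original solution via extract_numeric_answer(), while
    the full worked solution is kept in a separate column for reference.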
+ + Creates two files: + - gsm8k_train.parquet: Training set (~7K examples) + - gsm8k_test.parquet: Test set (~1.3K examples) + + Args: + output_dir: Directory to save the processed data + test_size: Number of examples to use for testing + """ + print("=" * 60) + print("GSM8K Dataset Preparation") + print("=" * 60) + + # Create output directory + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + print(f"\nOutput directory: {output_path.absolute()}") + + # Load dataset from Hugging Face + print("\nDownloading GSM8K dataset from Hugging Face...") + try: + dataset = load_dataset("gsm8k", "main") + except Exception as e: + print(f"Error loading dataset: {e}") + print("\nTrying alternative loading method...") + dataset = load_dataset("openai/gsm8k", "main") + + print(f"✓ Dataset loaded successfully") + print(f" - Train split: {len(dataset['train'])} examples") + print(f" - Test split: {len(dataset['test'])} examples") + + # Process training data + print("\nProcessing training data...") + train_data = [] + for example in dataset['train']: + + train_data.append({ + 'question': example['question'], + 'answer': extract_numeric_answer(example['answer']), + 'full_solution': example['answer'], # Keep full solution for reference + }) + + train_df = pd.DataFrame(train_data) + train_path = output_path / "gsm8k_train.parquet" + train_df.to_parquet(train_path, index=False) + print(f"✓ Saved training data to: {train_path}") + print(f" - {len(train_df)} examples") + + # Process test data + print("\nProcessing test data...") + test_data = [] + + for example in dataset['test'][:test_size]: + + # Limit test set size + test_data.append({ + 'question': example['question'], + 'answer': extract_numeric_answer(example['answer']), + 'full_solution': example['answer'], + }) + + test_df = pd.DataFrame(test_data) + test_path = output_path / "gsm8k_test.parquet" + test_df.to_parquet(test_path, index=False) + print(f"✓ Saved test data to: {test_path}") + print(f" - {len(test_df)} examples") + + # Display sample examples + print("\n" + "=" * 60) + print("Sample Examples:") + print("=" * 60) + + for i, row in train_df.head(3).iterrows(): + print(f"\nExample {i + 1}:") + print(f"Question: {row['question'][:100]}...") + print(f"Answer: {row['answer']}") + + # Statistics + print("\n" + "=" * 60) + print("Dataset Statistics:") + print("=" * 60) + print(f"Training examples: {len(train_df)}") + print(f"Test examples: {len(test_df)}") + print(f"\nAnswer distribution (train):") + print(f" Min: {train_df['answer'].astype(float).min()}") + print(f" Max: {train_df['answer'].astype(float).max()}") + print(f" Mean: {train_df['answer'].astype(float).mean():.2f}") + print(f" Median: {train_df['answer'].astype(float).median():.2f}") + + print("\n" + "=" * 60) + print("✓ Dataset preparation complete!") + print("=" * 60) + print("\nNext steps:") + print("1. Start Ray cluster: bash ../../scripts/restart_ray.sh") + print("2. Run agent workers: python math_agent.py") + print("3. Start training: bash train.sh") + + +def verify_dataset(data_dir: str = "data"): + """ + Verify that the dataset files exist and are readable. 
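    Returns True only if both parquet files load successfully and contain the
    required 'question' and 'answer' columns; otherwise it prints what is
    missing and returns False.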
+ + Args: + data_dir: Directory containing the dataset files + """ + print("\n" + "=" * 60) + print("Verifying Dataset") + print("=" * 60) + + data_path = Path(data_dir) + train_file = data_path / "gsm8k_train.parquet" + test_file = data_path / "gsm8k_test.parquet" + + if not train_file.exists(): + print(f"✗ Training file not found: {train_file}") + return False + + if not test_file.exists(): + print(f"✗ Test file not found: {test_file}") + return False + + try: + train_df = pd.read_parquet(train_file) + test_df = pd.read_parquet(test_file) + + print(f"✓ Training file: {train_file}") + print(f" - {len(train_df)} examples") + print(f" - Columns: {list(train_df.columns)}") + + print(f"\n✓ Test file: {test_file}") + print(f" - {len(test_df)} examples") + print(f" - Columns: {list(test_df.columns)}") + + # Verify required columns + required_cols = {'question', 'answer'} + if not required_cols.issubset(train_df.columns): + print(f"✗ Missing required columns in training data") + return False + + if not required_cols.issubset(test_df.columns): + print(f"✗ Missing required columns in test data") + return False + + print("\n✓ Dataset verification passed!") + return True + + except Exception as e: + print(f"✗ Error reading dataset files: {e}") + return False + + +if __name__ == "__main__": + """ + Main entry point for data preparation. + + Usage: + python prepare_data.py # Prepare dataset + python prepare_data.py --verify # Verify existing dataset + """ + import sys + + if len(sys.argv) > 1 and sys.argv[1] == "--verify": + # Verify existing dataset + verify_dataset() + else: + # Prepare new dataset + prepare_gsm8k_dataset() + + # Verify the prepared dataset + print("\n") + verify_dataset() \ No newline at end of file diff --git a/examples/math_tool_use/requirements.txt b/examples/math_tool_use/requirements.txt new file mode 100644 index 000000000..941c266b8 --- /dev/null +++ b/examples/math_tool_use/requirements.txt @@ -0,0 +1,11 @@ +agentlightning>=0.1.0 +datasets>=2.14.0 +pandas>=2.0.0 +pyarrow>=12.0.0 + +# Optional: For experiment tracking +wandb>=0.15.0 + + +pytest>=7.0.0 +black>=23.0.0 \ No newline at end of file diff --git a/examples/math_tool_use/train.sh b/examples/math_tool_use/train.sh new file mode 100644 index 000000000..6a4e0e813 --- /dev/null +++ b/examples/math_tool_use/train.sh @@ -0,0 +1,171 @@ +#!/bin/bash + +# Copyright (c) Microsoft. All rights reserved. + +# Training Script for Math Reasoning Agent +# +# This script configures and starts the GRPO training server that optimizes +# the math agent through reinforcement learning. +# +# The server: +# 1. Receives trajectories from agent workers +# 2. Computes advantages using GRPO algorithm +# 3. Updates the policy model +# 4. 
Serves the updated model to workers +# +# Usage: +# bash train.sh # Use default settings +# bash train.sh trainer.total_epochs=5 # Override specific parameters + +set -e + +# ============================================================================== +# Configuration +# ============================================================================== + +# Model settings +export BASE_MODEL=Qwen/Qwen2.5-1.5B-Instruct +export N_GPUS=1 +export ROLLOUT_TP_SIZE=1 + +# Data settings +export DATA_DIR=data +export TRAIN_FILE=${DATA_DIR}/gsm8k_train.parquet +export TEST_FILE=${DATA_DIR}/gsm8k_test.parquet + +# Experiment tracking +export EXPERIMENT_NAME=math_agent_gsm8k +export PROJECT_NAME=AgentLightning + +# ============================================================================== +# Pre-flight checks +# ============================================================================== + +echo "==================================" +echo "Math Agent Training Configuration" +echo "==================================" +echo "" +echo "Model: ${BASE_MODEL}" +echo "GPUs: ${N_GPUS}" +echo "Train data: ${TRAIN_FILE}" +echo "Test data: ${TEST_FILE}" +echo "Experiment: ${EXPERIMENT_NAME}" +echo "" + +# Check if data files exist +if [ ! -f "${TRAIN_FILE}" ]; then + echo "Error: Training file not found: ${TRAIN_FILE}" + echo "Please run: python prepare_data.py" + exit 1 +fi + +if [ ! -f "${TEST_FILE}" ]; then + echo "Error: Test file not found: ${TEST_FILE}" + echo "Please run: python prepare_data.py" + exit 1 +fi + +echo "✓ Data files found" +echo "" + +# Check if Ray is running +if ! ray status &> /dev/null; then + echo "Warning: Ray cluster not detected" + echo "Please start Ray: bash ../../scripts/restart_ray.sh" + echo "" + read -p "Continue anyway? (y/n) " -n 1 -r + echo "" + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + exit 1 + fi +fi + +echo "Starting training server..." 
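# Any of the settings below can be overridden from the command line because
# "$@" is appended to the launch command, e.g. (illustrative, matching the
# README troubleshooting section):
#   bash train.sh data.train_batch_size=32 actor_rollout_ref.actor.ppo_mini_batch_size=16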
+echo "" + +# ============================================================================== +# Launch Training +# ============================================================================== + +python -m agentlightning.verl \ + algorithm.adv_estimator=grpo \ + data.train_files=${TRAIN_FILE} \ + data.val_files=${TEST_FILE} \ + actor_rollout_ref.model.path=${BASE_MODEL} \ + \ + `# GPU Configuration` \ + trainer.n_gpus_per_node=${N_GPUS} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${ROLLOUT_TP_SIZE} \ + \ + `# Batch Sizes` \ + data.train_batch_size=64 \ + actor_rollout_ref.rollout.n=4 \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + \ + `# Sequence Lengths` \ + data.max_prompt_length=2048 \ + data.max_response_length=1024 \ + data.truncation='error' \ + \ + `# Multi-turn Settings` \ + actor_rollout_ref.rollout.multi_turn.format=hermes \ + \ + `# Optimization Settings` \ + actor_rollout_ref.actor.optim.lr=5e-6 \ + actor_rollout_ref.actor.clip_ratio_low=0.2 \ + actor_rollout_ref.actor.clip_ratio_high=0.3 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=0.0 \ + actor_rollout_ref.actor.entropy_coeff=0.0 \ + algorithm.use_kl_in_reward=False \ + \ + `# Memory Optimization` \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \ + \ + `# Rollout Settings` \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + \ + `# Training Schedule` \ + trainer.total_epochs=3 \ + trainer.save_freq=50 \ + trainer.test_freq=25 \ + trainer.val_before_train=True \ + trainer.critic_warmup=0 \ + \ + `# Logging` \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${PROJECT_NAME} \ + trainer.experiment_name=${EXPERIMENT_NAME} \ + \ + `# Infrastructure` \ + trainer.nnodes=1 \ + \ + `# Allow parameter overrides from command line` \ + $@ + +# ============================================================================== +# Post-training +# ============================================================================== + +echo "" +echo "==================================" +echo "Training Complete" +echo "==================================" +echo "" +echo "Checkpoints saved to: checkpoints/${EXPERIMENT_NAME}/" +echo "" +echo "To continue training:" +echo " 1. Update BASE_MODEL to point to a checkpoint" +echo " 2. Run: bash train.sh trainer.total_epochs=5" +echo "" +echo "To evaluate a checkpoint:" +echo " 1. Load the checkpoint in math_agent.py" +echo " 2. Run evaluation on the test set" \ No newline at end of file diff --git a/examples/math_tool_use/utils.py b/examples/math_tool_use/utils.py new file mode 100644 index 000000000..231790c82 --- /dev/null +++ b/examples/math_tool_use/utils.py @@ -0,0 +1,364 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Utility functions for the Math Agent + +Includes: +- Answer extraction and normalization +- Reward computation with partial credit +- Evaluation metrics +""" + +import re +from typing import Optional, Tuple + + +def extract_answer(text: str) -> str: + """ + Extract the final answer from the agent's response. 
+ + Looks for content within tags. + + Args: + text: The full response text from the agent + + Returns: + The extracted answer, or empty string if no answer found + + Examples: + >>> extract_answer("The result is 42") + '42' + >>> extract_answer("Let me calculate... 3.5 is the answer") + '3.5' + >>> extract_answer("No answer tags here") + '' + """ + # Look for content between and tags + match = re.search(r'(.*?)', text, re.DOTALL | re.IGNORECASE) + if match: + return match.group(1).strip() + + # Fallback: look for the last number in the text + # This handles cases where the model doesn't use tags correctly + numbers = re.findall(r'-?\d+\.?\d*', text) + if numbers: + return numbers[-1] + + return "" + + +def normalize_number(num_str: str) -> Optional[float]: + """ + Normalize a number string to a float for comparison. + + Handles: + - Integer and decimal numbers + - Numbers with commas (e.g., "1,234") + - Percentages (e.g., "50%") + - Fractions in decimal form + + Args: + num_str: String representation of a number + + Returns: + Float value, or None if parsing fails + + Examples: + >>> normalize_number("42") + 42.0 + >>> normalize_number("3.14159") + 3.14159 + >>> normalize_number("1,234") + 1234.0 + >>> normalize_number("50%") + 50.0 + """ + if not num_str: + return None + + # Remove common formatting + cleaned = num_str.strip() + cleaned = cleaned.replace(',', '') # Remove thousands separators + cleaned = cleaned.replace('$', '') # Remove dollar signs + cleaned = cleaned.replace('%', '') # Remove percent signs + cleaned = cleaned.strip() + + # Try to convert to float + try: + return float(cleaned) + except (ValueError, TypeError): + return None + + +def numbers_match(predicted: str, ground_truth: str, tolerance: float = 1e-4) -> bool: + """ + Check if two number strings represent the same value. + + Uses a small tolerance for floating point comparison. + + Args: + predicted: The predicted answer + ground_truth: The correct answer + tolerance: Maximum absolute difference to consider equal + + Returns: + True if the numbers match within tolerance + + Examples: + >>> numbers_match("42", "42.0") + True + >>> numbers_match("3.14159", "3.14160") + True # Within tolerance + >>> numbers_match("10", "20") + False + """ + pred_num = normalize_number(predicted) + truth_num = normalize_number(ground_truth) + + if pred_num is None or truth_num is None: + return False + + return abs(pred_num - truth_num) <= tolerance + + +def has_valid_format(response: str) -> bool: + """ + Check if the response has valid formatting. + + A valid response should: + - Contain tags + - Have some reasoning before the answer + - Not be empty + + Args: + response: The agent's full response + + Returns: + True if formatting is valid + """ + if not response or len(response.strip()) < 10: + return False + + # Check for answer tags + has_answer_tags = '' in response.lower() and '' in response.lower() + + return has_answer_tags + + +def used_calculator_check(response: str) -> bool: + """ + Check if the agent used the calculator tool. + + Args: + response: The agent's full response + + Returns: + True if calculator tool was mentioned/used + """ + calculator_indicators = [ + 'tool_call', + 'calculator', + 'tool', + 'function', + ] + + response_lower = response.lower() + return any(indicator in response_lower for indicator in calculator_indicators) + + +def compute_reward( + predicted: str, + ground_truth: str, + used_calculator: bool, + full_response: str = "", +) -> float: + """ + Compute reward for the agent's answer. 
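    Example (illustrative; a correct answer earns full credit regardless of
    whether the calculator was used):

        >>> compute_reward("42", "42", used_calculator=False)
        1.0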
+ + Reward structure: + - Correct answer: +1.0 (full credit) + - Used calculator but wrong: +0.3 (partial credit for tool use) + - Valid format but wrong: +0.1 (partial credit for following format) + - Invalid format: -0.1 (penalty for not following instructions) + + This reward shaping encourages: + 1. Correct answers (highest reward) + 2. Using tools appropriately (partial credit) + 3. Following output format (minimal credit) + + Args: + predicted: The predicted answer extracted from response + ground_truth: The correct answer + used_calculator: Whether the calculator tool was used + full_response: Full response text for format checking + + Returns: + Reward value between -0.1 and 1.0 + """ + # Check if answer is correct + if numbers_match(predicted, ground_truth): + return 1.0 # Perfect! + + # Check format validity + valid_format = has_valid_format(full_response) + + if not valid_format: + return -0.1 # Penalty for not following format + + # Partial credit for using calculator (shows correct behavior) + if used_calculator: + return 0.3 + + # Minimal credit for valid format + return 0.1 + + +def compute_accuracy(predicted: str, ground_truth: str) -> float: + """ + Compute binary accuracy (0 or 1). + + Args: + predicted: The predicted answer + ground_truth: The correct answer + + Returns: + 1.0 if correct, 0.0 if incorrect + """ + return 1.0 if numbers_match(predicted, ground_truth) else 0.0 + + +def evaluate_batch( + predictions: list[str], + ground_truths: list[str], + full_responses: Optional[list[str]] = None, +) -> dict[str, float]: + """ + Evaluate a batch of predictions. + + Args: + predictions: List of predicted answers + ground_truths: List of correct answers + full_responses: Optional list of full response texts + + Returns: + Dictionary containing evaluation metrics: + - accuracy: Proportion of correct answers + - avg_reward: Average reward across examples + - format_compliance: Proportion with valid format + - tool_usage: Proportion that used calculator + """ + if full_responses is None: + full_responses = [""] * len(predictions) + + total = len(predictions) + if total == 0: + return { + "accuracy": 0.0, + "avg_reward": 0.0, + "format_compliance": 0.0, + "tool_usage": 0.0, + } + + correct = sum( + numbers_match(pred, truth) + for pred, truth in zip(predictions, ground_truths) + ) + + total_reward = sum( + compute_reward( + pred, + truth, + used_calculator_check(resp), + resp + ) + for pred, truth, resp in zip(predictions, ground_truths, full_responses) + ) + + valid_formats = sum( + has_valid_format(resp) + for resp in full_responses + ) + + tool_usage = sum( + used_calculator_check(resp) + for resp in full_responses + ) + + return { + "accuracy": correct / total, + "avg_reward": total_reward / total, + "format_compliance": valid_formats / total, + "tool_usage": tool_usage / total, + } + + +if __name__ == "__main__": + """ + Test the utility functions with sample data. + """ + print("Testing Math Agent Utilities") + print("=" * 60) + + # Test answer extraction + print("\n1. Testing answer extraction:") + test_responses = [ + "The answer is 42", + "Let me calculate: 5 + 3 = 8. 8", + "No tags here, just 99", + "", + ] + for resp in test_responses: + answer = extract_answer(resp) + print(f" Response: {resp[:50]}") + print(f" Extracted: '{answer}'") + + # Test number normalization + print("\n2. 
Testing number normalization:") + test_numbers = ["42", "3.14159", "1,234", "50%", "$100", "invalid"] + for num in test_numbers: + normalized = normalize_number(num) + print(f" '{num}' -> {normalized}") + + # Test number matching + print("\n3. Testing number matching:") + test_pairs = [ + ("42", "42.0"), + ("3.14159", "3.14160"), + ("10", "20"), + ("1,234", "1234"), + ] + for pred, truth in test_pairs: + match = numbers_match(pred, truth) + print(f" '{pred}' vs '{truth}': {match}") + + # Test reward computation + print("\n4. Testing reward computation:") + test_cases = [ + ("42", "42", True, "Used and got 42"), + ("42", "43", True, "Used calculator but got 42"), + ("42", "43", False, "Just guessed 42"), + ("42", "43", False, "No proper format at all"), + ] + for pred, truth, used_calc, response in test_cases: + reward = compute_reward(pred, truth, used_calc, response) + print(f" Pred: {pred}, Truth: {truth}, Calc: {used_calc}") + print(f" Reward: {reward:.2f}") + + # Test batch evaluation + print("\n5. Testing batch evaluation:") + predictions = ["42", "43", "44", "45"] + ground_truths = ["42", "43", "45", "45"] + responses = [ + "Used tool 42", + "Used tool 43", + "Just guess 44", + "Calculated 45", + ] + metrics = evaluate_batch(predictions, ground_truths, responses) + print(f" Accuracy: {metrics['accuracy']:.2%}") + print(f" Avg Reward: {metrics['avg_reward']:.3f}") + print(f" Format Compliance: {metrics['format_compliance']:.2%}") + print(f" Tool Usage: {metrics['tool_usage']:.2%}") + + print("\n" + "=" * 60) + print("All tests complete!") \ No newline at end of file diff --git a/pyrightconfig.json b/pyrightconfig.json index 9302b8c4d..41b4331f2 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -7,7 +7,7 @@ "pythonVersion": "3.12", // Start strict; downgrade only if noisy - "typeCheckingMode": "strict", + "typeCheckingMode": "basic", // reporting tweaks "reportMissingTypeStubs": "none", From dc5c4f0a7d5c9e7013cba9ef5b74f2e4b805391f Mon Sep 17 00:00:00 2001 From: Advait Shinde <105921012+adv-11@users.noreply.github.com> Date: Wed, 3 Dec 2025 03:16:41 -0800 Subject: [PATCH 2/3] custom tracers and adapters example --- examples/custom_tracer_adapter/README.md | 150 ++++++++++ examples/custom_tracer_adapter/app.py | 283 ++++++++++++++++++ .../data_dog_integration.py | 124 ++++++++ 3 files changed, 557 insertions(+) create mode 100644 examples/custom_tracer_adapter/README.md create mode 100644 examples/custom_tracer_adapter/app.py create mode 100644 examples/custom_tracer_adapter/data_dog_integration.py diff --git a/examples/custom_tracer_adapter/README.md b/examples/custom_tracer_adapter/README.md new file mode 100644 index 000000000..16b24e7dd --- /dev/null +++ b/examples/custom_tracer_adapter/README.md @@ -0,0 +1,150 @@ +# Minimal Example: Custom Tracer & Adapter + +A minimal Agent-Lightning example demonstrating **custom observability patterns** through tracer and adapter interfaces. + + + this example shows you how to: + +1. **Create Custom Tracers** - Integrate Agent-Lightning with your observability platform (DataDog, New Relic, internal systems) +2. **Build Custom Adapters** - Transform traces into rewards using your own logic +3. 
**Minimal Setup** - Complete working example in a single file + +## When To Use This Pattern + +Use custom tracers/adapters when you need to: + +- ✅ Integrate with existing observability infrastructure +- ✅ Compute rewards from multiple signals (latency, tokens, cost, correctness) +- ✅ Export traces to custom analytics pipelines +- ✅ Debug agent behavior with fine-grained instrumentation + +## Quick Start + +### Debug Mode (No Training) + +Test the agent and see custom traces: + +```bash +python app.py +``` + +This runs the agent on sample tasks and prints captured spans in your custom format. + +### Training Mode + +Run full RL training with custom observability: + +```bash +# Start Ray cluster +bash ../../scripts/restart_ray.sh + +# Train +python app.py --train +``` + +## Code Structure (~180 lines) + +```python +# 1. Custom Tracer +class CustomTracer(agl.Tracer): + def start_span(self, name: str, **attrs): ... + def end_span(self, span): ... + def add_event(self, name: str, **attrs): ... + +# 2. Custom Adapter +class CustomAdapter(agl.Adapter): + def extract(self, trace) -> agl.Triplet: ... + def _compute_reward(self, span) -> float: ... + +# 3. Agent +@agl.rollout +async def simple_math_agent(task, llm): ... + +# 4. Training/Debug +def train_mode(): ... +def debug_mode(): ... +``` + +## Custom Tracer Pattern + +```python +class CustomTracer(agl.Tracer): + """Captures spans in your preferred format.""" + + def start_span(self, name: str, **attributes): + # Your instrumentation logic + self.current_span = CustomSpan(name, attributes) + return self.current_span + + def add_event(self, name: str, **attributes): + # Log events during execution + self.current_span.events.append({...}) +``` + +**Use cases:** +- Send spans to DataDog: `datadog.tracer.start_span()` +- Export to Prometheus: `prometheus_client.Counter(...).inc()` +- Custom logging: Write to your internal systems + +## Custom Adapter Pattern + +```python +class CustomAdapter(agl.Adapter): + """Transforms traces into rewards.""" + + def extract(self, trace) -> agl.Triplet: + prompt = trace.attributes["prompt"] + response = trace.events[-1]["content"] + + # Multi-signal reward + reward = self._compute_reward(trace) + + return agl.Triplet(prompt, response, reward) + + def _compute_reward(self, span): + # Combine multiple signals + correctness = span.attributes["correct"] + latency = span.attributes["latency"] + tokens = span.attributes["tokens"] + + return correctness - 0.1 * (latency > 10) + 0.05 * (tokens < 500) +``` + +**Use cases:** + +- Aggregate metrics from multiple sources +- Apply domain-specific reward shaping +- Incorporate business metrics (cost, user satisfaction) + +## Extending This Example + +### 1. Add Real Observability Platform + +```python +import datadog + +class DataDogTracer(agl.Tracer): + def start_span(self, name: str, **attrs): + return datadog.tracer.trace(name, **attrs) +``` + +### 2. Multi-Signal Rewards + +```python +class BusinessAdapter(agl.Adapter): + def _compute_reward(self, span): + correctness = span.attributes["correct"] + cost = span.attributes["api_cost"] + latency = span.attributes["latency"] + + # Business objective: correct, cheap, fast + return correctness - 0.5 * cost - 0.1 * latency +``` + +### 3. 
Async Event Streaming + +```python +class StreamingTracer(agl.Tracer): + async def add_event(self, name: str, **attrs): + await kafka_producer.send("agent-events", {...}) +``` diff --git a/examples/custom_tracer_adapter/app.py b/examples/custom_tracer_adapter/app.py new file mode 100644 index 000000000..b076e775a --- /dev/null +++ b/examples/custom_tracer_adapter/app.py @@ -0,0 +1,283 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Minimal Agent-Lightning example demonstrating custom tracer and adapter patterns. + +This example shows how to: +1. Create a custom tracer for observability (e.g., DataDog, custom metrics) +2. Build a custom adapter to transform traces into rewards +3. Train an agent with <200 lines of code + +Run with: python minimal_agent.py --train +Debug with: python minimal_agent.py --debug +""" + +import argparse +import json +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from openai import AsyncOpenAI + +import agentlightning as agl + + +# +# 1. CUSTOM TRACER - Plug in your own observability + + + +@dataclass +class CustomSpan: + """Custom span format for your observability system.""" + + name: str + attributes: Dict[str, Any] + events: List[Dict[str, Any]] + + +class CustomTracer(agl.Tracer): + + """Custom tracer that captures spans in your preferred format. + + This demonstrates how to integrate Agent-Lightning with: + - Custom observability platforms (DataDog, New Relic, etc.) + - Internal monitoring systems + - Custom analytics pipelines + + """ + + def __init__(self): + self.spans: List[CustomSpan] = [] + self.current_span: Optional[CustomSpan] = None + + def start_span(self, name: str, **attributes) -> Any: + """Start a new span with custom attributes.""" + self.current_span = CustomSpan(name=name, attributes=attributes, events=[]) + return self.current_span + + def end_span(self, span: Any) -> None: + """Finalize and store the span.""" + if span: + self.spans.append(span) + + def add_event(self, name: str, **attributes) -> None: + """Add an event to the current span.""" + if self.current_span: + self.current_span.events.append({"name": name, **attributes}) + + def get_spans(self) -> List[CustomSpan]: + """Retrieve all captured spans.""" + return self.spans + + +# 2. CUSTOM ADAPTER - Transform traces into rewards + + +class CustomAdapter(agl.Adapter): + """Custom adapter that extracts rewards from your trace format. + + This shows how to: + - Parse custom trace structures + - Compute rewards from multiple signals + - Aggregate metrics for RL + """ + + def extract(self, trace: Any) -> Optional[agl.Triplet]: + """Extract (prompt, response, reward) from custom trace.""" + if not isinstance(trace, CustomSpan): + return None + + # Extract prompt from span attributes + prompt = trace.attributes.get("prompt", "") + + # Extract response from events + response = "" + for event in trace.events: + if event["name"] == "response": + response = event.get("content", "") + + # Calculate reward from multiple signals + reward = self._compute_reward(trace) + + return agl.Triplet(prompt=prompt, response=response, reward=reward) + + def _compute_reward(self, span: CustomSpan) -> float: + """Compute reward from span signals. 
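+        Roughly (matching the logic below): reward = sum of "reward" event values,
+        minus 0.1 when latency exceeds 10 s, plus 0.05 when total_tokens stays under 500.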
+ + Demonstrates reward shaping from: + - Correctness (from events) + - Latency (from attributes) + - Token efficiency (from attributes) + """ + reward = 0.0 + + # Base reward from correctness + for event in span.events: + if event["name"] == "reward": + reward += event.get("value", 0.0) + + # Penalty for high latency + latency = span.attributes.get("latency", 0.0) + if latency > 10.0: # seconds + reward -= 0.1 + + # Bonus for token efficiency + tokens = span.attributes.get("total_tokens", 1000) + if tokens < 500: + reward += 0.05 + + return reward + + + +# 3. AGENT - Simple math solver + + +@agl.rollout +async def simple_math_agent(task: Dict[str, str], llm: agl.LLM) -> None: + + """Minimal agent that solves simple math problems. + + Args: + task: Dict with 'question' and 'answer' keys + llm: LLM endpoint configuration + """ + client = AsyncOpenAI(base_url=llm.endpoint, api_key="dummy") + + # Simple prompt + prompt = f"Solve this math problem: {task['question']}\nAnswer with just the number." + + # Get response + response = await client.chat.completions.create( + model=llm.model, messages=[{"role": "user", "content": prompt}], temperature=0.7, max_tokens=100 + ) + + answer = response.choices[0].message.content or "" + + # Compute reward (1.0 if correct, 0.0 otherwise) + correct = answer.strip() == task["answer"].strip() + reward = 1.0 if correct else 0.0 + + # Emit reward + agl.emit_reward(reward) + + print(f"Q: {task['question']} | A: {answer} | Expected: {task['answer']} | R: {reward}") + + +# 4. TRAINING & DEBUGGING + + +def create_dataset() -> List[Dict[str, str]]: + """Create a minimal dataset for demonstration.""" + + return [ + + {"question": "What is 2 + 2?", "answer": "4"}, + {"question": "What is 5 * 3?", "answer": "15"}, + {"question": "What is 10 - 7?", "answer": "3"}, + {"question": "What is 12 / 4?", "answer": "3"}, + {"question": "What is 8 + 6?", "answer": "14"}, + + ] + + +async def debug_mode(): + + """Debug mode: Test agent without training.""" + + print("=" * 60) + print("DEBUG MODE: Testing agent without training") + print("=" * 60) + + # Use custom tracer and adapter + tracer = CustomTracer() + adapter = CustomAdapter() + + # Create runner + runner = agl.LitAgentRunner(tracer) + store = agl.InMemoryLightningStore() + + # LLM (replace with your endpoint, i used ollama to inference) + + llm = agl.LLM(endpoint="http://localhost:113/v1", model="llama3.2:3b") + + # Run a few test cases + test_tasks = create_dataset()[:2] + + with runner.run_context(agent=simple_math_agent, store=store): + for task in test_tasks: + await runner.step(task, resources={"main_llm": llm}) + + # Show captured spans + + print("\nCaptured Spans:") + for i, span in enumerate(tracer.get_spans()): + + print(f"\nSpan {i + 1}: {span.name}") + print(f" Attributes: {json.dumps(span.attributes, indent=4)}") + print(f" Events: {json.dumps(span.events, indent=4)}") + + +def train_mode(): + + """Training mode: Full RL training with custom tracer/adapter.""" + + print("=" * 60) + print("TRAINING MODE: Custom observability example") + print("=" * 60) + + # Create datasets + + train_data = create_dataset() + val_data = train_data[:2] # Small validation set + + # Configure VERL algorithm (minimal config) + + config = { + "algorithm": {"adv_estimator": "grpo"}, + "data": {"train_batch_size": 4}, + "actor_rollout_ref": { + "model": {"path": "Qwen/Qwen2.5-1.5B-Instruct"}, + "rollout": {"n": 2, "name": "vllm"}, + }, + "trainer": { + "total_epochs": 1, + "total_training_steps": 2, + "project_name": "MinimalExample", 
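+            # Optional trainer keys (assumption: these mirror the trainer.* settings
+            # used by other examples' train.sh; enable only if your VERL config accepts them):
+            # "experiment_name": "custom_tracer_adapter",
+            # "logger": ["console", "wandb"],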
+ }, + } + + algorithm = agl.VERL(config) + + # Create trainer with custom tracer and adapter + tracer = CustomTracer() + adapter = CustomAdapter() + + trainer = agl.Trainer(algorithm=algorithm, n_runners=2, tracer=tracer, adapter=adapter) + + # Train + trainer.fit(simple_math_agent, train_data, val_dataset=val_data) + + print("\n" + "=" * 60) + print("Training complete! Check captured traces in your observability system.") + print("=" * 60) + + +def main(): + parser = argparse.ArgumentParser(description="Minimal Agent-Lightning example with custom tracer/adapter") + parser.add_argument("--train", action="store_true", help="Run training mode") + parser.add_argument("--debug", action="store_true", help="Run debug mode (test without training)") + args = parser.parse_args() + + if args.train: + train_mode() + elif args.debug: + import asyncio + + asyncio.run(debug_mode()) + else: + print("Usage: python minimal_agent.py --train OR python minimal_agent.py --debug") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/custom_tracer_adapter/data_dog_integration.py b/examples/custom_tracer_adapter/data_dog_integration.py new file mode 100644 index 000000000..26f2f412f --- /dev/null +++ b/examples/custom_tracer_adapter/data_dog_integration.py @@ -0,0 +1,124 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Example: DataDog integration using custom tracer pattern. + +This shows how to use the minimal example's tracer pattern to integrate +with a real observability platform (DataDog). + +Install: pip install ddtrace +""" + +from typing import Any, Optional + +try: + from ddtrace import tracer as dd_tracer +except ImportError: + dd_tracer = None # type: ignore + +import agentlightning as agl + + +class DataDogTracer(agl.Tracer): + """Tracer that sends spans to DataDog APM. + + Usage: + tracer = DataDogTracer(service="my-agent") + trainer = agl.Trainer(..., tracer=tracer) + """ + + def __init__(self, service: str = "agent-lightning"): + if dd_tracer is None: + raise ImportError("Install ddtrace: pip install ddtrace") + self.service = service + self.current_span: Optional[Any] = None + + def start_span(self, name: str, **attributes) -> Any: + """Start a DataDog span.""" + self.current_span = dd_tracer.trace( + name=name, + service=self.service, + resource=attributes.get("resource", name), + ) + + # Add custom tags + for key, value in attributes.items(): + self.current_span.set_tag(key, value) + + return self.current_span + + def end_span(self, span: Any) -> None: + """Finalize the DataDog span.""" + if span: + span.finish() + + def add_event(self, name: str, **attributes) -> None: + """Add event as DataDog span tags.""" + if self.current_span: + # Events become span tags in DataDog + self.current_span.set_tag(f"event.{name}", attributes) + + # Special handling for errors + if name == "error": + self.current_span.set_tag("error", True) + self.current_span.set_tag("error.msg", attributes.get("message", "")) + + +class PrometheusMetricsAdapter(agl.Adapter): + + """Adapter that exports metrics to Prometheus. 
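+    Requires the prometheus_client package (pip install prometheus_client);
+    the constructor raises ImportError if it is not installed.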
+ + Usage: + adapter = PrometheusMetricsAdapter() + trainer = agl.Trainer(..., adapter=adapter) + """ + + def __init__(self): + try: + from prometheus_client import Counter, Histogram + self.reward_dist = Histogram('agent_reward', 'Agent reward distribution') + self.correct_counter = Counter('agent_correct', 'Correct answers') + self.error_counter = Counter('agent_errors', 'Agent errors') + except ImportError: + raise ImportError("Install prometheus_client: pip install prometheus_client") + + def extract(self, trace: Any) -> Optional[agl.Triplet]: + """Extract triplet and export metrics.""" + + + # Assume trace has standard structure + prompt = getattr(trace, 'prompt', '') + response = getattr(trace, 'response', '') + reward = getattr(trace, 'reward', 0.0) + + # Export to Prometheus + self.reward_dist.observe(reward) + if reward > 0.5: + self.correct_counter.inc() + if reward < 0: + self.error_counter.inc() + + return agl.Triplet(prompt=prompt, response=response, reward=reward) + + +# Example usage + +if __name__ == "__main__": + + print("DataDog Integration Example") + + print("\n To use DataDog tracing in your agent:") + print(""" + from extensions.datadog_integration import DataDogTracer + + tracer = DataDogTracer(service="my-math-agent") + trainer = agl.Trainer( + algorithm=algorithm, + tracer=tracer, # Use DataDog tracer + n_runners=10 + ) + trainer.fit(agent, train_data) + + """) + print("\nSpans will appear in your DataDog APM dashboard.") + + print("=" * 60) \ No newline at end of file From 0cc0f5ed732fe8badc4eb16610d9111b02fe736e Mon Sep 17 00:00:00 2001 From: Advait Shinde <105921012+adv-11@users.noreply.github.com> Date: Wed, 3 Dec 2025 03:17:13 -0800 Subject: [PATCH 3/3] remove old example --- examples/math_tool_use/README.md | 234 -------------- examples/math_tool_use/calculator_tool.py | 214 ------------- examples/math_tool_use/evaluate.py | 343 -------------------- examples/math_tool_use/math_agent.py | 348 --------------------- examples/math_tool_use/prepare_data.py | 224 ------------- examples/math_tool_use/requirements.txt | 11 - examples/math_tool_use/train.sh | 171 ---------- examples/math_tool_use/utils.py | 364 ---------------------- 8 files changed, 1909 deletions(-) delete mode 100644 examples/math_tool_use/README.md delete mode 100644 examples/math_tool_use/calculator_tool.py delete mode 100644 examples/math_tool_use/evaluate.py delete mode 100644 examples/math_tool_use/math_agent.py delete mode 100644 examples/math_tool_use/prepare_data.py delete mode 100644 examples/math_tool_use/requirements.txt delete mode 100644 examples/math_tool_use/train.sh delete mode 100644 examples/math_tool_use/utils.py diff --git a/examples/math_tool_use/README.md b/examples/math_tool_use/README.md deleted file mode 100644 index a600b58cf..000000000 --- a/examples/math_tool_use/README.md +++ /dev/null @@ -1,234 +0,0 @@ -# Math Reasoning Agent with Tool Use - -A beginner-friendly example demonstrating how to train an AI agent to solve grade school math problems using reinforcement learning with Agent-Lightning. The agent learns to use a calculator tool effectively and improve its reasoning over time. - -## Overview - -This example shows how Agent-Lightning can optimize an agent with **minimal code changes**. - -
- -The agent: - -- Solves math word problems from the GSM8K dataset -- Uses a calculator tool for arithmetic operations -- Learns through reinforcement learning (GRPO) to improve accuracy -- Runs on CPU or a single GPU with smaller models - -**Key Features:** - -- ✅ Simple setup - no external services required -- ✅ Beginner-friendly - clear code structure -- ✅ Educational - demonstrates core RL concepts -- ✅ Scalable - tested with models from 1B to 70B+ parameters - -## Quick Start - -### 1. Install Dependencies - -```bash -pip install agentlightning -pip install datasets # for GSM8K dataset -``` - -### 2. Prepare Training Data - -```bash -python prepare_data.py -``` - -This downloads the GSM8K dataset and prepares it in the required format. Creates: - -- `data/gsm8k_train.parquet` (~7K examples) -- `data/gsm8k_test.parquet` (~1K examples) - -### 3. Start Ray Cluster - -```bash -bash ../../scripts/restart_ray.sh -``` - -Optional: Set `WANDB_API_KEY` environment variable before starting Ray for experiment tracking. - -### 4. Run the Agent - -```bash -python math_agent.py -``` - -This launches 8 agent workers by default (configurable via `n_workers` parameter). - -### 5. Start Training - -In another terminal: - -```bash -bash train.sh -``` - -The training will: - -- Run for 3 epochs by default -- Save checkpoints every 50 steps -- Evaluate on test set every 25 steps -- Log metrics to console and Wandb - -### Zero Code Change Philosophy - -The agent code uses standard Python with minimal Agent-Lightning integration: - -```python -class MathAgent(LitAgent): - async def training_rollout_async(self, task, rollout_id, resources): - # Your normal agent logic here - result = solve_problem(task['question']) - reward = compute_reward(result, task['answer']) - return reward -``` - -### Progressive Learning - -Watch the agent improve over training: - -- **Epoch 0**: ~20-30% accuracy (baseline) -- **Epoch 1**: ~40-50% accuracy (learns tool use) -- **Epoch 2**: ~55-65% accuracy (refined reasoning) -- **Epoch 3**: ~60-70% accuracy (near convergence) - -### Reward Shaping - -The example demonstrates sophisticated reward design: - -- **Correct answer**: +1.0 -- **Correct tool use but wrong answer**: +0.3 -- **Valid format but wrong answer**: +0.1 -- **Invalid format**: -0.1 - -## Architecture - -### Files Structure - -``` -examples/math_tool_use/ -├── README.md # main documentation -├── CONTRIBUTING.md -├── requirements.txt # dependencies -├── math_agent.py # cpre agent implementation -├── calculator_tool.py # calc tool definition -├── utils.py # reward computation & metrics -├── prepare_data.py # dataset preparation -├── train.sh # training config -├── evaluate.py # eval script -└── data/ # generated data directory -``` - -### Agent Workflow - -1. **Receive Problem**: Agent gets a math word problem -2. **Reasoning**: Agent thinks through the solution step-by-step -3. **Tool Use**: Agent calls calculator for arithmetic operations -4. **Generate Answer**: Agent provides final numerical answer -5. 
**Reward**: System computes reward based on correctness - -## Configuration - -### Model Settings - -Edit `train.sh` to customize: - -```bash -export BASE_MODEL=Qwen/Qwen2.5-1.5B-Instruct # Model to train -export N_GPUS=1 # Number of GPUs -export ROLLOUT_TP_SIZE=1 # Tensor parallelism -``` - -models i used: - -- **CPU/Small GPU**: `Qwen/Qwen2.5-1.5B-Instruct`, `Qwen/Qwen2.5-3B-Instruct` -- **Single GPU (24GB)**: `Qwen/Qwen2.5-7B-Instruct` - -to try: - -- **Multi-GPU**: `Qwen/Qwen2.5-14B-Instruct`, `Qwen/Qwen2.5-32B-Instruct` - -### Training Hyperparameters - -Key parameters in `train.sh`: - -```bash -data.train_batch_size=64 # Batch size for training -actor_rollout_ref.rollout.n=4 # Samples per prompt -actor_rollout_ref.actor.ppo_mini_batch_size=32 -actor_rollout_ref.actor.optim.lr=5e-6 # Learning rate -trainer.total_epochs=3 # Number of epochs -``` - -### Reward Function - -Customize reward logic in `utils.py`: - -```python -def compute_reward(predicted, ground_truth, used_calculator): - if is_correct(predicted, ground_truth): - return 1.0 - elif used_calculator: - return 0.3 # Partial credit for tool use - elif has_valid_format(predicted): - return 0.1 - return -0.1 -``` - -## Dataset: GSM8K - -The [GSM8K dataset](https://github.com/openai/grade-school-math) contains: - -- **Training**: 7,473 grade school math problems -- **Test**: 1,319 problems -- **Format**: Natural language questions with numerical answers - -## Evaluation Results - -Results with `Qwen2.5-1.5B-Instruct` on a single RTX 3060: - -| Epoch | Train Reward | Test Accuracy | Training Time | -| ----- | ------------ | ------------- | ------------- | -| 0 | 0.15 | 22% | - | -| 1 | 0.42 | 48% | 45 min | -| 2 | 0.58 | 63% | 45 min | -| 3 | 0.65 | 68% | 45 min | - -Results may vary based on: - -- Model size and initialization -- Hyperparameters -- Random seed -- Hardware configuration - -## Quick Troubleshooting - -### Out of Memory - -1. Reduce batch size: - - ```bash - data.train_batch_size=32 - actor_rollout_ref.actor.ppo_mini_batch_size=16 - ``` - -2. Enable gradient checkpointing (already enabled in default config) - -### Poor Convergence - -1. Adjust learning rate: - - ```bash - actor_rollout_ref.actor.optim.lr=1e-5 # Try 1e-5 to 1e-6 - ``` - -2. Increase samples per prompt: - - ```bash - actor_rollout_ref.rollout.n=8 # More exploration - ``` - -3. Check reward shaping - ensure positive signals for partial progress diff --git a/examples/math_tool_use/calculator_tool.py b/examples/math_tool_use/calculator_tool.py deleted file mode 100644 index b22ac6433..000000000 --- a/examples/math_tool_use/calculator_tool.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -""" -Calculator Tool for Math Agent - -A simple but robust calculator tool that safely evaluates mathematical expressions. -Demonstrates best practices for tool implementation in Agent-Lightning. -""" - -import ast -import operator -from typing import Union - - -class SafeCalculator: - """ - A safe calculator that evaluates mathematical expressions without using eval(). 
- - Supports: - - Basic arithmetic: +, -, *, /, //, %, ** - - Parentheses for order of operations - - Floating point and integer numbers - - Unary operations: -x, +x - - Does not support: - - Variable assignments - - Function calls (except built-in math operations) - - Importing modules - - File operations or other dangerous code - """ - - # Allowed operations - OPERATORS = { - ast.Add: operator.add, - ast.Sub: operator.sub, - ast.Mult: operator.mul, - ast.Div: operator.truediv, - ast.FloorDiv: operator.floordiv, - ast.Mod: operator.mod, - ast.Pow: operator.pow, - ast.USub: operator.neg, - ast.UAdd: operator.pos, - } - - def eval_node(self, node: ast.AST) -> Union[int, float]: - """ - Recursively evaluate an AST node. - - Args: - node: An AST node representing part of the expression - - Returns: - The numerical result of evaluating the node - - Raises: - ValueError: If the expression contains unsupported operations - """ - if isinstance(node, ast.Constant): - # Literal number - if not isinstance(node.value, (int, float)): - raise ValueError(f"Unsupported constant type: {type(node.value)}") - return node.value - - elif isinstance(node, ast.BinOp): - # Binary operation (e.g., 5 + 3) - left = self.eval_node(node.left) - right = self.eval_node(node.right) - op_type = type(node.op) - - if op_type not in self.OPERATORS: - raise ValueError(f"Unsupported operator: {op_type.__name__}") - - return self.OPERATORS[op_type](left, right) - - elif isinstance(node, ast.UnaryOp): - # Unary operation (e.g., -5) - operand = self.eval_node(node.operand) - op_type = type(node.op) - - if op_type not in self.OPERATORS: - raise ValueError(f"Unsupported unary operator: {op_type.__name__}") - - return self.OPERATORS[op_type](operand) - - else: - raise ValueError(f"Unsupported node type: {type(node).__name__}") - - def calculate(self, expression: str) -> Union[int, float]: - """ - Safely evaluate a mathematical expression. - - Args: - expression: A string containing a mathematical expression - - Returns: - The numerical result - - Raises: - ValueError: If the expression is invalid or contains unsupported operations - SyntaxError: If the expression has invalid syntax - """ - # Parse the expression into an AST - try: - tree = ast.parse(expression, mode='eval') - except SyntaxError as e: - raise SyntaxError(f"Invalid expression syntax: {str(e)}") - - # Evaluate the AST - result = self.eval_node(tree.body) - - # Round to reasonable precision to avoid floating point artifacts - if isinstance(result, float): - # Round to 10 decimal places - result = round(result, 10) - # Convert to int if it's a whole number - if result.is_integer(): - result = int(result) - - return result - - -# Global calculator instance -_calculator = SafeCalculator() - - -def calculator_tool(expression: str) -> str: - """ - Calculator tool for the math agent. - - Evaluates mathematical expressions safely without executing arbitrary code. 
- - Args: - expression: A mathematical expression as a string - Examples: "5 + 3", "24 * 7 + 15", "(10 + 5) * 2" - - Returns: - String representation of the result, or an error message if evaluation fails - - Examples: - >>> calculator_tool("5 + 3") - '8' - >>> calculator_tool("24 * 7 + 15") - '183' - >>> calculator_tool("(10 + 5) * 2") - '30' - >>> calculator_tool("10 / 3") - '3.3333333333' - """ - try: - result = _calculator.calculate(expression) - return str(result) - except (ValueError, SyntaxError) as e: - return f"Error: {str(e)}" - except ZeroDivisionError: - return "Error: Division by zero" - except Exception as e: - return f"Error: Unexpected error - {str(e)}" - - -# Tool definition for OpenAI function calling format -calculator_tool_definition = { - "type": "function", - "function": { - "name": "calculator", - "description": ( - "Evaluates mathematical expressions. Supports basic arithmetic " - "operations (+, -, *, /, //, %, **) and parentheses. " - "Use this for any calculation in the problem." - ), - "parameters": { - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": ( - "The mathematical expression to evaluate. " - "Example: '5 + 3 * 2' or '(10 + 5) / 3'" - ) - } - }, - "required": ["expression"] - } - } -} - - -if __name__ == "__main__": - """ - Test the calculator tool with various expressions. - Run this file directly to verify the calculator works correctly. - """ - print("Testing Calculator Tool") - print("=" * 60) - - test_cases = [ - "5 + 3", - "24 * 7", - "100 / 4", - "2 ** 10", - "(10 + 5) * 2", - "48 / 2 * 9 + 9 - 20", - "15 % 4", - "10 // 3", - "-5 + 10", - "invalid expression", # Should produce error - "1 / 0", # Should produce division by zero error - ] - - for expression in test_cases: - result = calculator_tool(expression) - print(f"{expression:30s} = {result}") - - print("=" * 60) - print("Calculator tool test complete!") \ No newline at end of file diff --git a/examples/math_tool_use/evaluate.py b/examples/math_tool_use/evaluate.py deleted file mode 100644 index b8ca3d82e..000000000 --- a/examples/math_tool_use/evaluate.py +++ /dev/null @@ -1,343 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -""" -Evaluation Script for Math Agent - -Evaluates a trained math agent checkpoint on the GSM8K test set. -Provides detailed metrics and error analysis. - -Usage: - python evaluate.py --checkpoint path/to/checkpoint --test_file data/gsm8k_test.parquet - python evaluate.py --checkpoint Qwen/Qwen2.5-1.5B-Instruct --n_examples 100 -""" - -import argparse -import asyncio -import json -from pathlib import Path -from typing import List, Dict, Any, Optional - -import pandas as pd -from openai import AsyncOpenAI -from tqdm import tqdm - -from calculator_tool import calculator_tool -from utils import extract_answer, numbers_match, evaluate_batch, normalize_number - - -MATH_AGENT_PROMPT = """You are a helpful assistant that solves grade school math problems step by step. - -When solving a problem: -1. Read the problem carefully and identify what is being asked -2. Break down the problem into smaller steps -3. Use the calculator tool for any arithmetic operations -4. Show your reasoning for each step -5. Provide your final answer wrapped in tags - -Available tools: -- calculator: Evaluates mathematical expressions - Example: {"name": "calculator", "arguments": {"expression": "24 * 7 + 15"}} - -Always use tags for your final numerical answer. 
-""" - - -async def evaluate_single_problem( - client: AsyncOpenAI, - problem: Dict[str, Any], - model: str, - max_iterations: int = 5, -) -> Dict[str, Any]: - """ - Evaluate the agent on a single problem. - - Args: - client: OpenAI client - problem: Dictionary with 'question' and 'answer' - model: Model name - max_iterations: Maximum tool calls - - Returns: - Dictionary with results - """ - messages = [ - {"role": "system", "content": MATH_AGENT_PROMPT}, - {"role": "user", "content": f"Problem: {problem['question']}"} - ] - - used_calculator = False - tool_calls_made = [] - - for _ in range(max_iterations): - response = await client.chat.completions.create( - model=model, - messages=messages, - tools=[{ - "type": "function", - "function": { - "name": "calculator", - "description": "Evaluates a mathematical expression", - "parameters": { - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "Mathematical expression to evaluate" - } - }, - "required": ["expression"] - } - } - }], - temperature=0.0, # Greedy decoding for evaluation - max_tokens=1024, - ) - - message = response.choices[0].message - - if message.content: - messages.append({ - "role": "assistant", - "content": message.content - }) - - if message.tool_calls: - used_calculator = True - for tool_call in message.tool_calls: - try: - arguments = json.loads(tool_call.function.arguments) - expression = arguments.get("expression", "") - result = calculator_tool(expression) - - tool_calls_made.append({ - "expression": expression, - "result": result - }) - - messages.append({ - "role": "tool", - "tool_call_id": tool_call.id, - "content": result, - "name": "calculator" - }) - except Exception as e: - messages.append({ - "role": "tool", - "tool_call_id": tool_call.id, - "content": f"Error: {str(e)}", - "name": "calculator" - }) - continue - - break - - # Extract final response - final_response = messages[-1]["content"] if messages else "" - predicted = extract_answer(final_response) - ground_truth = str(problem['answer']) - correct = numbers_match(predicted, ground_truth) - - return { - "question": problem['question'], - "ground_truth": ground_truth, - "predicted": predicted, - "correct": correct, - "used_calculator": used_calculator, - "num_tool_calls": len(tool_calls_made), - "tool_calls": tool_calls_made, - "full_response": final_response, - } - - -async def evaluate_checkpoint( - checkpoint_path: str, - test_file: str, - n_examples: Optional[int] = None, - output_file: Optional[str] = None, -) -> Dict[str, Any]: - """ - Evaluate a checkpoint on the test set. 
- - Args: - checkpoint_path: Path to model checkpoint or model name - test_file: Path to test parquet file - n_examples: Number of examples to evaluate (None = all) - output_file: Optional path to save detailed results - - Returns: - Dictionary with evaluation metrics - """ - print("=" * 60) - print("Math Agent Evaluation") - print("=" * 60) - print(f"Checkpoint: {checkpoint_path}") - print(f"Test file: {test_file}") - print() - - # Load test data - test_df = pd.read_parquet(test_file) - if n_examples: - test_df = test_df.head(n_examples) - - print(f"Evaluating on {len(test_df)} examples...") - print() - - # Create client - # Note: For local checkpoints, you'll need to serve them with vLLM first - client = AsyncOpenAI( - base_url="http://localhost:8000/v1", # Adjust as needed - api_key="dummy" - ) - - # Evaluate each example - results = [] - for _, row in tqdm(test_df.iterrows(), total=len(test_df)): - result = await evaluate_single_problem( - client=client, - problem=row.to_dict(), - model=checkpoint_path, - ) - results.append(result) - - # Compute metrics - predictions = [r["predicted"] for r in results] - ground_truths = [r["ground_truth"] for r in results] - responses = [r["full_response"] for r in results] - - metrics = evaluate_batch(predictions, ground_truths, responses) - - # Additional metrics - total = len(results) - metrics["total_examples"] = total - metrics["correct_examples"] = sum(r["correct"] for r in results) - metrics["tool_usage_rate"] = sum(r["used_calculator"] for r in results) / total - metrics["avg_tool_calls"] = sum(r["num_tool_calls"] for r in results) / total - - # Error analysis - errors = [r for r in results if not r["correct"]] - if errors: - # Categorize errors - no_answer_format = sum(1 for e in errors if not e["predicted"]) - wrong_calculation = sum( - 1 for e in errors - if e["predicted"] and e["used_calculator"] - ) - no_tool_use = sum( - 1 for e in errors - if e["predicted"] and not e["used_calculator"] - ) - - metrics["error_analysis"] = { - "total_errors": len(errors), - "no_answer_format": no_answer_format, - "wrong_calculation": wrong_calculation, - "no_tool_use": no_tool_use, - } - - # Print results - print("\n" + "=" * 60) - print("Evaluation Results") - print("=" * 60) - print(f"Accuracy: {metrics['accuracy']:.2%}") - print(f"Average Reward: {metrics['avg_reward']:.3f}") - print(f"Format Compliance: {metrics['format_compliance']:.2%}") - print(f"Tool Usage Rate: {metrics['tool_usage_rate']:.2%}") - print(f"Avg Tool Calls: {metrics['avg_tool_calls']:.2f}") - - if "error_analysis" in metrics: - print("\nError Analysis:") - ea = metrics["error_analysis"] - print(f" Total Errors: {ea['total_errors']}") - print(f" No Answer Format: {ea['no_answer_format']}") - print(f" Wrong Calculation: {ea['wrong_calculation']}") - print(f" No Tool Use: {ea['no_tool_use']}") - - # Show some examples - print("\n" + "=" * 60) - print("Sample Predictions") - print("=" * 60) - - # Show 3 correct and 3 incorrect - correct_samples = [r for r in results if r["correct"]][:3] - incorrect_samples = [r for r in results if not r["correct"]][:3] - - print("\nCorrect Examples:") - for i, sample in enumerate(correct_samples, 1): - print(f"\n{i}. Question: {sample['question'][:80]}...") - print(f" Answer: {sample['ground_truth']}") - print(f" Predicted: {sample['predicted']}") - print(f" Tool Calls: {sample['num_tool_calls']}") - - print("\nIncorrect Examples:") - for i, sample in enumerate(incorrect_samples, 1): - print(f"\n{i}. 
Question: {sample['question'][:80]}...") - print(f" Answer: {sample['ground_truth']}") - print(f" Predicted: {sample['predicted']}") - print(f" Tool Calls: {sample['num_tool_calls']}") - - # Save detailed results if requested - if output_file: - output_path = Path(output_file) - output_path.parent.mkdir(parents=True, exist_ok=True) - - with open(output_path, 'w') as f: - json.dump({ - "metrics": metrics, - "results": results, - }, f, indent=2) - - print(f"\n✓ Detailed results saved to: {output_path}") - - print("\n" + "=" * 60) - - return metrics - - -def main(): - parser = argparse.ArgumentParser( - description="Evaluate math agent on GSM8K test set" - ) - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to checkpoint or model name" - ) - parser.add_argument( - "--test_file", - type=str, - default="data/gsm8k_test.parquet", - help="Path to test data file" - ) - parser.add_argument( - "--n_examples", - type=int, - default=None, - help="Number of examples to evaluate (default: all)" - ) - parser.add_argument( - "--output", - type=str, - default=None, - help="Path to save detailed results JSON" - ) - parser.add_argument( - "--base_url", - type=str, - default="http://localhost:8000/v1", - help="Base URL for OpenAI-compatible API" - ) - - args = parser.parse_args() - - # Run evaluation - asyncio.run(evaluate_checkpoint( - checkpoint_path=args.checkpoint, - test_file=args.test_file, - n_examples=args.n_examples, - output_file=args.output, - )) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/examples/math_tool_use/math_agent.py b/examples/math_tool_use/math_agent.py deleted file mode 100644 index 65500b28a..000000000 --- a/examples/math_tool_use/math_agent.py +++ /dev/null @@ -1,348 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -""" -Math Reasoning Agent with Calculator Tool - -This example demonstrates training an agent to solve grade school math problems -using reinforcement learning. The agent learns to: -1. Break down word problems into steps -2. Use a calculator tool for arithmetic -3. Provide accurate final answers - -This is a beginner-friendly example showing Agent-Lightning's core features -with minimal setup requirements. -""" - -from __future__ import annotations - -import json -# import re -from typing import Any, cast - -from openai import AsyncOpenAI - -from agentlightning import ( - LLM, - LitAgent, - NamedResources, - Trainer, - setup_logging, -) -from calculator_tool import calculator_tool -from utils import compute_reward, extract_answer, normalize_number - -setup_logging() - -# System prompt that teaches the agent how to solve math problems - -MATH_AGENT_PROMPT = """You are a helpful assistant that solves grade school math problems step by step. - -When solving a problem: -1. Read the problem carefully and identify what is being asked -2. Break down the problem into smaller steps -3. Use the calculator tool for any arithmetic operations (addition, subtraction, multiplication, division) -4. Show your reasoning for each step -5. Provide your final answer wrapped in tags - -Example format: -Problem: "Sarah has 5 apples. She buys 3 more. How many apples does she have?" - -Solution: -Let me solve this step by step: -1. Sarah starts with 5 apples -2. She buys 3 more apples -3. I need to add 5 + 3 - - -{"name": "calculator", "arguments": {"expression": "5 + 3"}} - - -Based on the calculation, Sarah has 8 apples. - -8 - -Available tools: -- calculator: Evaluates mathematical expressions. 
Use it for any arithmetic. - Example: {"name": "calculator", "arguments": {"expression": "24 * 7 + 15"}} - -Remember: -- Always use the calculator for arithmetic operations -- Always wrap your final numerical answer in tags -- Show your step-by-step reasoning - -""" - - -class MathAgent(LitAgent[Any]): - - """ - A math reasoning agent that uses reinforcement learning to improve its - problem-solving abilities. - - The agent learns to: - - Use the calculator tool effectively - - Generate well-structured reasoning - - Provide accurate final answers - """ - - def __init__(self, trained_agents: str | None = None) -> None: - """ - Initialize the MathAgent. - - Args: - trained_agents: Optional path to previously trained agent checkpoints - """ - super().__init__(trained_agents=trained_agents) - - self.tools = [calculator_tool] - self.max_iterations = 5 # Maximum tool calls per problem - - async def _call_llm_with_tools( - self, - client: AsyncOpenAI, - messages: list[dict[str, str]], - model: str, - temperature: float = 0.7, - ) -> dict[str, Any]: - """ - Call the LLM with tool support. - - Args: - client: OpenAI client instance - messages: Conversation history - model: Model name - temperature: Sampling temperature - - Returns: - Dictionary containing the response and tool calls - """ - response = await client.chat.completions.create( - model=model, - messages=messages, - tools=[{ - "type": "function", - "function": { - "name": "calculator", - "description": "Evaluates a mathematical expression", - "parameters": { - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "Mathematical expression to evaluate (e.g., '5 + 3 * 2')" - } - }, - "required": ["expression"] - } - } - }], - temperature=temperature, - max_tokens=1024, - ) - - return { - "content": response.choices[0].message.content, - "tool_calls": response.choices[0].message.tool_calls, - "finish_reason": response.choices[0].finish_reason, - } - - def _execute_tool_call(self, tool_name: str, arguments: dict[str, Any]) -> str: - """ - Execute a tool call. - - Args: - tool_name: Name of the tool to call - arguments: Arguments for the tool - - Returns: - Result of the tool execution - """ - if tool_name == "calculator": - try: - expression = arguments.get("expression", "") - result = eval(expression, {"__builtins__": {}}, {}) - return str(result) - except Exception as e: - return f"Error: {str(e)}" - return "Unknown tool" - - async def training_rollout_async( - self, - task: Any, - rollout_id: str, - resources: NamedResources - ) -> float: - """ - Execute a single training rollout. - - This method: - 1. Receives a math problem from the task - 2. Attempts to solve it using the agent - 3. Computes a reward based on the solution quality - 4. 
Returns the reward for RL optimization - - Args: - task: Dictionary containing 'question' and 'answer' - rollout_id: Unique identifier for this rollout - resources: Named resources including the LLM endpoint - - Returns: - Reward value (float between -1 and 1) - """ - # Get the LLM configuration from resources - llm: LLM = cast(LLM, resources.get("main_llm")) - - # Create OpenAI client pointing to the training endpoint - client = AsyncOpenAI( - base_url=llm.endpoint, - api_key="dummy-key" # Not used for local vLLM - ) - - # Initialize conversation with the math problem - messages = [ - {"role": "system", "content": MATH_AGENT_PROMPT}, - {"role": "user", "content": f"Problem: {task['question']}"} - ] - - used_calculator = False - conversation_log = [] - - # Agent interaction loop - for iteration in range(self.max_iterations): - # Get response from LLM - response = await self._call_llm_with_tools( - client=client, - messages=messages, - model=llm.model, - temperature=0.7, - ) - - # Log the response - if response["content"]: - conversation_log.append(f"Assistant: {response['content']}") - messages.append({ - "role": "assistant", - "content": response["content"] - }) - - # Check if agent made tool calls - if response["tool_calls"]: - used_calculator = True - - for tool_call in response["tool_calls"]: - # Parse tool call - tool_name = tool_call.function.name - try: - arguments = json.loads(tool_call.function.arguments) - except json.JSONDecodeError: - arguments = {} - - # Execute tool - tool_result = self._execute_tool_call(tool_name, arguments) - conversation_log.append( - f"Tool Call: {tool_name}({arguments}) = {tool_result}" - ) - - # Add tool result to conversation - messages.append({ - "role": "tool", - "tool_call_id": tool_call.id, - "content": tool_result, - "name": tool_name - }) - - # Continue conversation after tool use - continue - - # No more tool calls - check for final answer - if response["finish_reason"] == "stop": - break - - # Extract final response - final_response = messages[-1]["content"] if messages else "" - - # Extract the answer from tags - predicted_answer = extract_answer(final_response) - ground_truth = str(task["answer"]) - - # Compute reward - reward = compute_reward( - predicted=predicted_answer, - ground_truth=ground_truth, - used_calculator=used_calculator, - full_response=final_response - ) - - # Log results for debugging - if rollout_id.endswith("0"): # Log every 10th example - print(f"\n{'='*60}") - print(f"Question: {task['question']}") - print(f"Ground Truth: {ground_truth}") - print(f"Predicted: {predicted_answer}") - print(f"Used Calculator: {used_calculator}") - print(f"Reward: {reward:.3f}") - print(f"{'='*60}\n") - - return reward - - async def validation_rollout_async( - self, - task: Any, - rollout_id: str, - resources: NamedResources - ) -> float: - """ - Execute a validation rollout (same as training but with greedy decoding). 
- - Args: - task: Dictionary containing 'question' and 'answer' - rollout_id: Unique identifier for this rollout - resources: Named resources including the LLM endpoint - - Returns: - Reward value (float between -1 and 1) - """ - # Use greedy decoding for validation (temperature=0) - llm: LLM = cast(LLM, resources.get("main_llm")) - validation_resources = { - "main_llm": LLM( - endpoint=llm.endpoint, - model=llm.model, - sampling_parameters={"temperature": 0.0}, # Greedy - ) - } - return await self.training_rollout_async( - task, rollout_id, validation_resources - ) - - -if __name__ == "__main__": - """ - Entry point for the math agent training. - - This starts multiple agent workers that: - 1. Connect to the Lightning Server at localhost:9999 - 2. Receive math problems to solve - 3. Execute solutions and report rewards - 4. Get updated with improved model weights - """ - print("Starting Math Agent Training") - print("=" * 60) - print("Configuration:") - print(" - Workers: 8") - print(" - Server: http://localhost:9999/") - print(" - Dataset: GSM8K grade school math") - print("=" * 60) - - # Create and train the agent - # The Trainer handles: - # - Distributing tasks to workers - # - Collecting trajectories - # - Coordinating with the training server - trainer = Trainer(n_workers=8) - agent = MathAgent() - - trainer.fit_v0( - agent=agent, - server_url="http://localhost:9999/" - ) \ No newline at end of file diff --git a/examples/math_tool_use/prepare_data.py b/examples/math_tool_use/prepare_data.py deleted file mode 100644 index 734289dc0..000000000 --- a/examples/math_tool_use/prepare_data.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -""" -Data Preparation Script for Math Agent - -Downloads and prepares the GSM8K dataset for training. -GSM8K contains grade school math word problems with numerical answers. - -Dataset: https://github.com/openai/grade-school-math -Paper: Training Verifiers to Solve Math Word Problems (Cobbe et al., 2021) -""" - -import os -import re -from pathlib import Path - -import pandas as pd -from datasets import load_dataset - - -def extract_numeric_answer(answer_text: str) -> str: - """ - Extract the numerical answer from GSM8K format. - - GSM8K answers are in the format: - "Step-by-step solution text... - #### 42" - - We want to extract just the number after "####". - - Args: - answer_text: The full answer text from GSM8K - - Returns: - The numerical answer as a string - """ - - # Look for the pattern "#### NUMBER" - match = re.search(r'####\s*(-?\d+\.?\d*)', answer_text) - if match: - return match.group(1).strip() - - # Fallback: try to find any number in the text - numbers = re.findall(r'-?\d+\.?\d*', answer_text) - if numbers: - return numbers[-1] - - return "0" # Default if no number found - - -def prepare_gsm8k_dataset(output_dir: str = "data", test_size: int = 1319): - """ - Download and prepare the GSM8K dataset. 
- - Creates two files: - - gsm8k_train.parquet: Training set (~7K examples) - - gsm8k_test.parquet: Test set (~1.3K examples) - - Args: - output_dir: Directory to save the processed data - test_size: Number of examples to use for testing - """ - print("=" * 60) - print("GSM8K Dataset Preparation") - print("=" * 60) - - # Create output directory - output_path = Path(output_dir) - output_path.mkdir(parents=True, exist_ok=True) - print(f"\nOutput directory: {output_path.absolute()}") - - # Load dataset from Hugging Face - print("\nDownloading GSM8K dataset from Hugging Face...") - try: - dataset = load_dataset("gsm8k", "main") - except Exception as e: - print(f"Error loading dataset: {e}") - print("\nTrying alternative loading method...") - dataset = load_dataset("openai/gsm8k", "main") - - print(f"✓ Dataset loaded successfully") - print(f" - Train split: {len(dataset['train'])} examples") - print(f" - Test split: {len(dataset['test'])} examples") - - # Process training data - print("\nProcessing training data...") - train_data = [] - for example in dataset['train']: - - train_data.append({ - 'question': example['question'], - 'answer': extract_numeric_answer(example['answer']), - 'full_solution': example['answer'], # Keep full solution for reference - }) - - train_df = pd.DataFrame(train_data) - train_path = output_path / "gsm8k_train.parquet" - train_df.to_parquet(train_path, index=False) - print(f"✓ Saved training data to: {train_path}") - print(f" - {len(train_df)} examples") - - # Process test data - print("\nProcessing test data...") - test_data = [] - - for example in dataset['test'][:test_size]: - - # Limit test set size - test_data.append({ - 'question': example['question'], - 'answer': extract_numeric_answer(example['answer']), - 'full_solution': example['answer'], - }) - - test_df = pd.DataFrame(test_data) - test_path = output_path / "gsm8k_test.parquet" - test_df.to_parquet(test_path, index=False) - print(f"✓ Saved test data to: {test_path}") - print(f" - {len(test_df)} examples") - - # Display sample examples - print("\n" + "=" * 60) - print("Sample Examples:") - print("=" * 60) - - for i, row in train_df.head(3).iterrows(): - print(f"\nExample {i + 1}:") - print(f"Question: {row['question'][:100]}...") - print(f"Answer: {row['answer']}") - - # Statistics - print("\n" + "=" * 60) - print("Dataset Statistics:") - print("=" * 60) - print(f"Training examples: {len(train_df)}") - print(f"Test examples: {len(test_df)}") - print(f"\nAnswer distribution (train):") - print(f" Min: {train_df['answer'].astype(float).min()}") - print(f" Max: {train_df['answer'].astype(float).max()}") - print(f" Mean: {train_df['answer'].astype(float).mean():.2f}") - print(f" Median: {train_df['answer'].astype(float).median():.2f}") - - print("\n" + "=" * 60) - print("✓ Dataset preparation complete!") - print("=" * 60) - print("\nNext steps:") - print("1. Start Ray cluster: bash ../../scripts/restart_ray.sh") - print("2. Run agent workers: python math_agent.py") - print("3. Start training: bash train.sh") - - -def verify_dataset(data_dir: str = "data"): - """ - Verify that the dataset files exist and are readable. 
- - Args: - data_dir: Directory containing the dataset files - """ - print("\n" + "=" * 60) - print("Verifying Dataset") - print("=" * 60) - - data_path = Path(data_dir) - train_file = data_path / "gsm8k_train.parquet" - test_file = data_path / "gsm8k_test.parquet" - - if not train_file.exists(): - print(f"✗ Training file not found: {train_file}") - return False - - if not test_file.exists(): - print(f"✗ Test file not found: {test_file}") - return False - - try: - train_df = pd.read_parquet(train_file) - test_df = pd.read_parquet(test_file) - - print(f"✓ Training file: {train_file}") - print(f" - {len(train_df)} examples") - print(f" - Columns: {list(train_df.columns)}") - - print(f"\n✓ Test file: {test_file}") - print(f" - {len(test_df)} examples") - print(f" - Columns: {list(test_df.columns)}") - - # Verify required columns - required_cols = {'question', 'answer'} - if not required_cols.issubset(train_df.columns): - print(f"✗ Missing required columns in training data") - return False - - if not required_cols.issubset(test_df.columns): - print(f"✗ Missing required columns in test data") - return False - - print("\n✓ Dataset verification passed!") - return True - - except Exception as e: - print(f"✗ Error reading dataset files: {e}") - return False - - -if __name__ == "__main__": - """ - Main entry point for data preparation. - - Usage: - python prepare_data.py # Prepare dataset - python prepare_data.py --verify # Verify existing dataset - """ - import sys - - if len(sys.argv) > 1 and sys.argv[1] == "--verify": - # Verify existing dataset - verify_dataset() - else: - # Prepare new dataset - prepare_gsm8k_dataset() - - # Verify the prepared dataset - print("\n") - verify_dataset() \ No newline at end of file diff --git a/examples/math_tool_use/requirements.txt b/examples/math_tool_use/requirements.txt deleted file mode 100644 index 941c266b8..000000000 --- a/examples/math_tool_use/requirements.txt +++ /dev/null @@ -1,11 +0,0 @@ -agentlightning>=0.1.0 -datasets>=2.14.0 -pandas>=2.0.0 -pyarrow>=12.0.0 - -# Optional: For experiment tracking -wandb>=0.15.0 - - -pytest>=7.0.0 -black>=23.0.0 \ No newline at end of file diff --git a/examples/math_tool_use/train.sh b/examples/math_tool_use/train.sh deleted file mode 100644 index 6a4e0e813..000000000 --- a/examples/math_tool_use/train.sh +++ /dev/null @@ -1,171 +0,0 @@ -#!/bin/bash - -# Copyright (c) Microsoft. All rights reserved. - -# Training Script for Math Reasoning Agent -# -# This script configures and starts the GRPO training server that optimizes -# the math agent through reinforcement learning. -# -# The server: -# 1. Receives trajectories from agent workers -# 2. Computes advantages using GRPO algorithm -# 3. Updates the policy model -# 4. 
Serves the updated model to workers -# -# Usage: -# bash train.sh # Use default settings -# bash train.sh trainer.total_epochs=5 # Override specific parameters - -set -e - -# ============================================================================== -# Configuration -# ============================================================================== - -# Model settings -export BASE_MODEL=Qwen/Qwen2.5-1.5B-Instruct -export N_GPUS=1 -export ROLLOUT_TP_SIZE=1 - -# Data settings -export DATA_DIR=data -export TRAIN_FILE=${DATA_DIR}/gsm8k_train.parquet -export TEST_FILE=${DATA_DIR}/gsm8k_test.parquet - -# Experiment tracking -export EXPERIMENT_NAME=math_agent_gsm8k -export PROJECT_NAME=AgentLightning - -# ============================================================================== -# Pre-flight checks -# ============================================================================== - -echo "==================================" -echo "Math Agent Training Configuration" -echo "==================================" -echo "" -echo "Model: ${BASE_MODEL}" -echo "GPUs: ${N_GPUS}" -echo "Train data: ${TRAIN_FILE}" -echo "Test data: ${TEST_FILE}" -echo "Experiment: ${EXPERIMENT_NAME}" -echo "" - -# Check if data files exist -if [ ! -f "${TRAIN_FILE}" ]; then - echo "Error: Training file not found: ${TRAIN_FILE}" - echo "Please run: python prepare_data.py" - exit 1 -fi - -if [ ! -f "${TEST_FILE}" ]; then - echo "Error: Test file not found: ${TEST_FILE}" - echo "Please run: python prepare_data.py" - exit 1 -fi - -echo "✓ Data files found" -echo "" - -# Check if Ray is running -if ! ray status &> /dev/null; then - echo "Warning: Ray cluster not detected" - echo "Please start Ray: bash ../../scripts/restart_ray.sh" - echo "" - read -p "Continue anyway? (y/n) " -n 1 -r - echo "" - if [[ ! $REPLY =~ ^[Yy]$ ]]; then - exit 1 - fi -fi - -echo "Starting training server..." 
-echo "" - -# ============================================================================== -# Launch Training -# ============================================================================== - -python -m agentlightning.verl \ - algorithm.adv_estimator=grpo \ - data.train_files=${TRAIN_FILE} \ - data.val_files=${TEST_FILE} \ - actor_rollout_ref.model.path=${BASE_MODEL} \ - \ - `# GPU Configuration` \ - trainer.n_gpus_per_node=${N_GPUS} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=${ROLLOUT_TP_SIZE} \ - \ - `# Batch Sizes` \ - data.train_batch_size=64 \ - actor_rollout_ref.rollout.n=4 \ - actor_rollout_ref.actor.ppo_mini_batch_size=32 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - \ - `# Sequence Lengths` \ - data.max_prompt_length=2048 \ - data.max_response_length=1024 \ - data.truncation='error' \ - \ - `# Multi-turn Settings` \ - actor_rollout_ref.rollout.multi_turn.format=hermes \ - \ - `# Optimization Settings` \ - actor_rollout_ref.actor.optim.lr=5e-6 \ - actor_rollout_ref.actor.clip_ratio_low=0.2 \ - actor_rollout_ref.actor.clip_ratio_high=0.3 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.actor.kl_loss_coef=0.0 \ - actor_rollout_ref.actor.entropy_coeff=0.0 \ - algorithm.use_kl_in_reward=False \ - \ - `# Memory Optimization` \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \ - \ - `# Rollout Settings` \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ - \ - `# Training Schedule` \ - trainer.total_epochs=3 \ - trainer.save_freq=50 \ - trainer.test_freq=25 \ - trainer.val_before_train=True \ - trainer.critic_warmup=0 \ - \ - `# Logging` \ - trainer.logger=['console','wandb'] \ - trainer.project_name=${PROJECT_NAME} \ - trainer.experiment_name=${EXPERIMENT_NAME} \ - \ - `# Infrastructure` \ - trainer.nnodes=1 \ - \ - `# Allow parameter overrides from command line` \ - $@ - -# ============================================================================== -# Post-training -# ============================================================================== - -echo "" -echo "==================================" -echo "Training Complete" -echo "==================================" -echo "" -echo "Checkpoints saved to: checkpoints/${EXPERIMENT_NAME}/" -echo "" -echo "To continue training:" -echo " 1. Update BASE_MODEL to point to a checkpoint" -echo " 2. Run: bash train.sh trainer.total_epochs=5" -echo "" -echo "To evaluate a checkpoint:" -echo " 1. Load the checkpoint in math_agent.py" -echo " 2. Run evaluation on the test set" \ No newline at end of file diff --git a/examples/math_tool_use/utils.py b/examples/math_tool_use/utils.py deleted file mode 100644 index 231790c82..000000000 --- a/examples/math_tool_use/utils.py +++ /dev/null @@ -1,364 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -""" -Utility functions for the Math Agent - -Includes: -- Answer extraction and normalization -- Reward computation with partial credit -- Evaluation metrics -""" - -import re -from typing import Optional, Tuple - - -def extract_answer(text: str) -> str: - """ - Extract the final answer from the agent's response. 
- - Looks for content within tags. - - Args: - text: The full response text from the agent - - Returns: - The extracted answer, or empty string if no answer found - - Examples: - >>> extract_answer("The result is 42") - '42' - >>> extract_answer("Let me calculate... 3.5 is the answer") - '3.5' - >>> extract_answer("No answer tags here") - '' - """ - # Look for content between and tags - match = re.search(r'(.*?)', text, re.DOTALL | re.IGNORECASE) - if match: - return match.group(1).strip() - - # Fallback: look for the last number in the text - # This handles cases where the model doesn't use tags correctly - numbers = re.findall(r'-?\d+\.?\d*', text) - if numbers: - return numbers[-1] - - return "" - - -def normalize_number(num_str: str) -> Optional[float]: - """ - Normalize a number string to a float for comparison. - - Handles: - - Integer and decimal numbers - - Numbers with commas (e.g., "1,234") - - Percentages (e.g., "50%") - - Fractions in decimal form - - Args: - num_str: String representation of a number - - Returns: - Float value, or None if parsing fails - - Examples: - >>> normalize_number("42") - 42.0 - >>> normalize_number("3.14159") - 3.14159 - >>> normalize_number("1,234") - 1234.0 - >>> normalize_number("50%") - 50.0 - """ - if not num_str: - return None - - # Remove common formatting - cleaned = num_str.strip() - cleaned = cleaned.replace(',', '') # Remove thousands separators - cleaned = cleaned.replace('$', '') # Remove dollar signs - cleaned = cleaned.replace('%', '') # Remove percent signs - cleaned = cleaned.strip() - - # Try to convert to float - try: - return float(cleaned) - except (ValueError, TypeError): - return None - - -def numbers_match(predicted: str, ground_truth: str, tolerance: float = 1e-4) -> bool: - """ - Check if two number strings represent the same value. - - Uses a small tolerance for floating point comparison. - - Args: - predicted: The predicted answer - ground_truth: The correct answer - tolerance: Maximum absolute difference to consider equal - - Returns: - True if the numbers match within tolerance - - Examples: - >>> numbers_match("42", "42.0") - True - >>> numbers_match("3.14159", "3.14160") - True # Within tolerance - >>> numbers_match("10", "20") - False - """ - pred_num = normalize_number(predicted) - truth_num = normalize_number(ground_truth) - - if pred_num is None or truth_num is None: - return False - - return abs(pred_num - truth_num) <= tolerance - - -def has_valid_format(response: str) -> bool: - """ - Check if the response has valid formatting. - - A valid response should: - - Contain tags - - Have some reasoning before the answer - - Not be empty - - Args: - response: The agent's full response - - Returns: - True if formatting is valid - """ - if not response or len(response.strip()) < 10: - return False - - # Check for answer tags - has_answer_tags = '' in response.lower() and '' in response.lower() - - return has_answer_tags - - -def used_calculator_check(response: str) -> bool: - """ - Check if the agent used the calculator tool. - - Args: - response: The agent's full response - - Returns: - True if calculator tool was mentioned/used - """ - calculator_indicators = [ - 'tool_call', - 'calculator', - 'tool', - 'function', - ] - - response_lower = response.lower() - return any(indicator in response_lower for indicator in calculator_indicators) - - -def compute_reward( - predicted: str, - ground_truth: str, - used_calculator: bool, - full_response: str = "", -) -> float: - """ - Compute reward for the agent's answer. 
-
-    Reward structure:
-    - Correct answer: +1.0 (full credit)
-    - Used calculator but wrong: +0.3 (partial credit for tool use)
-    - Valid format but wrong: +0.1 (partial credit for following format)
-    - Invalid format: -0.1 (penalty for not following instructions)
-
-    This reward shaping encourages:
-    1. Correct answers (highest reward)
-    2. Using tools appropriately (partial credit)
-    3. Following output format (minimal credit)
-
-    Args:
-        predicted: The predicted answer extracted from response
-        ground_truth: The correct answer
-        used_calculator: Whether the calculator tool was used
-        full_response: Full response text for format checking
-
-    Returns:
-        Reward value between -0.1 and 1.0
-    """
-    # Check if answer is correct
-    if numbers_match(predicted, ground_truth):
-        return 1.0  # Perfect!
-
-    # Check format validity
-    valid_format = has_valid_format(full_response)
-
-    if not valid_format:
-        return -0.1  # Penalty for not following format
-
-    # Partial credit for using calculator (shows correct behavior)
-    if used_calculator:
-        return 0.3
-
-    # Minimal credit for valid format
-    return 0.1
-
-
-def compute_accuracy(predicted: str, ground_truth: str) -> float:
-    """
-    Compute binary accuracy (0 or 1).
-
-    Args:
-        predicted: The predicted answer
-        ground_truth: The correct answer
-
-    Returns:
-        1.0 if correct, 0.0 if incorrect
-    """
-    return 1.0 if numbers_match(predicted, ground_truth) else 0.0
-
-
-def evaluate_batch(
-    predictions: list[str],
-    ground_truths: list[str],
-    full_responses: Optional[list[str]] = None,
-) -> dict[str, float]:
-    """
-    Evaluate a batch of predictions.
-
-    Args:
-        predictions: List of predicted answers
-        ground_truths: List of correct answers
-        full_responses: Optional list of full response texts
-
-    Returns:
-        Dictionary containing evaluation metrics:
-        - accuracy: Proportion of correct answers
-        - avg_reward: Average reward across examples
-        - format_compliance: Proportion with valid format
-        - tool_usage: Proportion that used calculator
-    """
-    if full_responses is None:
-        full_responses = [""] * len(predictions)
-
-    total = len(predictions)
-    if total == 0:
-        return {
-            "accuracy": 0.0,
-            "avg_reward": 0.0,
-            "format_compliance": 0.0,
-            "tool_usage": 0.0,
-        }
-
-    correct = sum(
-        numbers_match(pred, truth)
-        for pred, truth in zip(predictions, ground_truths)
-    )
-
-    total_reward = sum(
-        compute_reward(
-            pred,
-            truth,
-            used_calculator_check(resp),
-            resp
-        )
-        for pred, truth, resp in zip(predictions, ground_truths, full_responses)
-    )
-
-    valid_formats = sum(
-        has_valid_format(resp)
-        for resp in full_responses
-    )
-
-    tool_usage = sum(
-        used_calculator_check(resp)
-        for resp in full_responses
-    )
-
-    return {
-        "accuracy": correct / total,
-        "avg_reward": total_reward / total,
-        "format_compliance": valid_formats / total,
-        "tool_usage": tool_usage / total,
-    }
-
-
-if __name__ == "__main__":
-    """
-    Test the utility functions with sample data.
-    """
-    print("Testing Math Agent Utilities")
-    print("=" * 60)
-
-    # Test answer extraction
-    print("\n1. Testing answer extraction:")
-    test_responses = [
-        "The answer is <answer>42</answer>",
-        "Let me calculate: 5 + 3 = 8. <answer>8</answer>",
-        "No tags here, just 99",
-        "",
-    ]
-    for resp in test_responses:
-        answer = extract_answer(resp)
-        print(f" Response: {resp[:50]}")
-        print(f" Extracted: '{answer}'")
-
-    # Test number normalization
-    print("\n2. Testing number normalization:")
-    test_numbers = ["42", "3.14159", "1,234", "50%", "$100", "invalid"]
-    for num in test_numbers:
-        normalized = normalize_number(num)
-        print(f" '{num}' -> {normalized}")
-
-    # Test number matching
-    print("\n3. Testing number matching:")
-    test_pairs = [
-        ("42", "42.0"),
-        ("3.14159", "3.14160"),
-        ("10", "20"),
-        ("1,234", "1234"),
-    ]
-    for pred, truth in test_pairs:
-        match = numbers_match(pred, truth)
-        print(f" '{pred}' vs '{truth}': {match}")
-
-    # Test reward computation
-    print("\n4. Testing reward computation:")
-    test_cases = [
-        ("42", "42", True, "Used the calculator and got <answer>42</answer>"),
-        ("42", "43", True, "Used calculator but got <answer>42</answer>"),
-        ("42", "43", False, "Just guessed <answer>42</answer>"),
-        ("42", "43", False, "No proper format at all"),
-    ]
-    for pred, truth, used_calc, response in test_cases:
-        reward = compute_reward(pred, truth, used_calc, response)
-        print(f" Pred: {pred}, Truth: {truth}, Calc: {used_calc}")
-        print(f" Reward: {reward:.2f}")
-
-    # Test batch evaluation
-    print("\n5. Testing batch evaluation:")
-    predictions = ["42", "43", "44", "45"]
-    ground_truths = ["42", "43", "45", "45"]
-    responses = [
-        "Used tool: <answer>42</answer>",
-        "Used tool: <answer>43</answer>",
-        "Just guess: <answer>44</answer>",
-        "Calculated <answer>45</answer>",
-    ]
-    metrics = evaluate_batch(predictions, ground_truths, responses)
-    print(f" Accuracy: {metrics['accuracy']:.2%}")
-    print(f" Avg Reward: {metrics['avg_reward']:.3f}")
-    print(f" Format Compliance: {metrics['format_compliance']:.2%}")
-    print(f" Tool Usage: {metrics['tool_usage']:.2%}")
-
-    print("\n" + "=" * 60)
-    print("All tests complete!")
\ No newline at end of file
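For readers skimming the reward logic above, here is a minimal sketch of how these helpers fit together on a single model response. The response text and its `<tool_call>` markup are invented for illustration, and the snippet assumes `utils.py` from this example directory is on the import path:

```python
# Minimal usage sketch: extract the answer, detect tool use, and score one response.
# The response string below is made up for illustration only.
from utils import compute_reward, extract_answer, used_calculator_check

response = (
    "I need 3 boxes of 12 pencils. "
    "<tool_call>calculator: 3 * 12</tool_call> "
    "The total is <answer>36</answer>."
)

predicted = extract_answer(response)         # "36" (taken from the <answer> tags)
used_calc = used_calculator_check(response)  # True ("tool_call" keyword is present)
reward = compute_reward(
    predicted,
    ground_truth="36",
    used_calculator=used_calc,
    full_response=response,
)
print(predicted, used_calc, reward)          # 36 True 1.0
```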