71 changes: 52 additions & 19 deletions MaxCode/ARCHITECTURE.md
@@ -33,49 +33,82 @@ execute the `migration_agent` and `evaluation_agent`, respectively.

### 4. ADK Tools

`tools/migration_tool.py`, `tools/evaluation_tool.py`, and
`tools/verification_tool.py` define ADK `FunctionTool`s that wrap specific
Python functions for code conversion, quality verification, config generation,
data generation, and testing.

### 5. Migration Pipeline

For **directory inputs**, `PrimaryAgent` uses `MergeAgent`
(`agents/migration/merge_agent.py`) to preprocess the repository before
conversion. The merge step:
- Discovers all nn.Module files and builds an import dependency graph
- Filters infrastructure files (fused kernels, CUDA wrappers, etc.)
- Merges model files into a single file in topological order
- Discovers and merges utility files separately
- Filters infrastructure classes from merged output

For **single-file inputs**, the existing direct conversion path is used.
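The topological-order merge described above can be sketched with Python's standard `graphlib`; the file names and import graph here are illustrative, not taken from the repository:

```python
from graphlib import TopologicalSorter

# Hypothetical import graph: each file maps to the files it imports.
import_graph = {
    "model.py":         {"attention.py", "mlp.py", "fused_kernels.py"},
    "attention.py":     {"utils.py"},
    "mlp.py":           {"utils.py"},
    "utils.py":         set(),
    "fused_kernels.py": set(),
}

# Infrastructure files (fused kernels, CUDA wrappers, ...) are filtered out.
infrastructure = {"fused_kernels.py"}

# static_order() yields dependencies before their dependents, so merging
# in this order keeps every definition above its first use.
order = [f for f in TopologicalSorter(import_graph).static_order()
         if f not in infrastructure]
print(order)  # 'utils.py' comes first, 'model.py' last
```

The real `MergeAgent` additionally discovers `nn.Module` files and merges utility files separately; this sketch only shows the ordering and filtering step.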

After conversion, `migration_tool.convert_code` automatically runs
`VerificationAgent` (`agents/migration/verification_agent.py`) to produce
a completeness scorecard (AST-based, no LLM). The verification tool is
also available standalone via `tools/verification_tool.py`.
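An AST-based completeness check of this kind can be sketched as follows; the source snippets and scorecard fields are illustrative, and the real `VerificationAgent` output format may differ:

```python
import ast

def public_symbols(source: str) -> set[str]:
    """Top-level class and function names defined in a module's source."""
    tree = ast.parse(source)
    return {node.name for node in tree.body
            if isinstance(node, (ast.ClassDef, ast.FunctionDef))}

# Toy stand-ins for the original and converted modules.
pytorch_src = "class Attention: ...\nclass MLP: ...\ndef init_weights(): ..."
jax_src = "class Attention: ...\ndef init_weights(): ..."

expected = public_symbols(pytorch_src)
found = public_symbols(jax_src)
scorecard = {
    "coverage": len(expected & found) / len(expected),
    "missing": sorted(expected - found),
}
print(scorecard)  # coverage 2/3, missing ['MLP']
```

Because it only parses source text, a check like this needs no LLM call and no ability to execute either module.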

### 6. ADK Agent Orchestration

The `migration_agent` orchestrates the end-to-end migration and validation
workflow by calling tools in sequence:
1. **`migration_tool.convert_code`**: Merges, converts, and verifies
PyTorch code to JAX using `PrimaryAgent` (which delegates to
`MergeAgent` for directories). Copies the original source code and
saves results to a timestamped output directory.
2. **`verification_tool.verify_conversion`** (optional): Standalone
quality verification with completeness and correctness scores.
3. **`evaluation_tool.generate_model_configs`**: Generates configuration
files from the original PyTorch code.
4. **`evaluation_tool.generate_oracle_data`**: Generates oracle data
(.pkl files) from the PyTorch code using the generated configurations.
5. **`evaluation_tool.run_equivalence_tests`**: Generates test scripts
that compare JAX outputs against PyTorch oracle data, and then runs these
tests using `subprocess`.
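A self-contained sketch of the oracle-data contract and the `subprocess`-driven test run; the dictionary keys match the text above, while the paths and values are fabricated for illustration:

```python
import pickle
import subprocess
import sys
import tempfile
from pathlib import Path

tmp = Path(tempfile.mkdtemp())

# Oracle data: the dictionary shape the evaluation tools pass around.
oracle = {
    "input": [[1.0, 2.0]],
    "output": [[2.0, 4.0]],
    "state_dict": {"linear.weight": [[2.0, 0.0], [0.0, 2.0]]},
}
pkl_path = tmp / "oracle.pkl"
pkl_path.write_bytes(pickle.dumps(oracle))

# A stand-in for a generated equivalence-test script.
test_script = tmp / "test_equivalence.py"
test_script.write_text(
    "import pickle\n"
    f"data = pickle.load(open({str(pkl_path)!r}, 'rb'))\n"
    "assert set(data) == {'input', 'output', 'state_dict'}\n"
    "print('PASS')\n"
)

# Run the generated test the same way run_equivalence_tests does: via subprocess.
result = subprocess.run([sys.executable, str(test_script)],
                        capture_output=True, text=True)
assert result.returncode == 0
print(result.stdout.strip())  # PASS
```

The real generated tests additionally load the JAX model and compare outputs numerically; this sketch only exercises the file contract and the subprocess harness.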

The result is a destination directory containing the migrated JAX code, a
`mapping.json` file, a `verification_scorecard.json`, and an `evaluation`
subdirectory with configurations, oracle data, and test scripts.

## Summary

The overall flow for migration is:

```
Gemini CLI -> mcp_server:primary_agent_server -> adk_agents:migration_agent ->
  1. tools:migration_tool:convert_code
     (Merge -> Convert -> Validate/Repair -> Verify)
  2. tools:verification_tool:verify_conversion (optional, standalone)
  3. tools:evaluation_tool:generate_model_configs (Config Gen)
  4. tools:evaluation_tool:generate_oracle_data (Data Gen)
  5. tools:evaluation_tool:run_equivalence_tests (Test Gen & Run)
```

The internal flow within `convert_code` for directory inputs:

```
MergeAgent.run(repo_dir) # Preprocessing: discover, filter, merge
-> PrimaryAgent._convert_file() # LLM conversion (model + utils)
-> PrimaryAgent._fill_missing() # Gap-filling pass
-> PrimaryAgent._validate() # Validation + repair loop
-> VerificationAgent.verify() # Quality scorecard
```

## Agent Structure and Extension

The project separates agent implementation logic from ADK agent/tool
definitions:

* **`agents/<domain>/`**: Contains agent classes with core implementation logic (e.g., `agents/migration/primary_agent.py`).
* **`tools/`**: Contains ADK `FunctionTool` wrappers that call agent logic or other Python functions (e.g., `tools/migration_tool.py`).
* **`agents/<domain>/`**: Contains agent classes with core implementation logic (e.g., `agents/migration/primary_agent.py`, `agents/migration/merge_agent.py`, `agents/migration/verification_agent.py`).
* **`tools/`**: Contains ADK `FunctionTool` wrappers that call agent logic or other Python functions (e.g., `tools/migration_tool.py`, `tools/verification_tool.py`).
* **`mcp_server/adk_agents.py`**: Defines the ADK agent hierarchy, instructions, and tool mappings.
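The separation of agent logic from tool definitions can be illustrated with a simplified stand-in for the ADK wrapper; this `FunctionTool` is a sketch, not the real `google.adk` API, and `convert_code` is stubbed:

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class FunctionTool:
    """Simplified stand-in for the ADK FunctionTool wrapper."""
    func: Callable

    def run(self, **kwargs):
        # The real ADK handles schema extraction and agent routing;
        # here we just forward the call to the wrapped function.
        return self.func(**kwargs)

def convert_code(source_path: str) -> dict:
    """Core agent logic would live under agents/migration/; stubbed here."""
    return {"migrated_path": source_path + ".jax", "status": "ok"}

convert_tool = FunctionTool(func=convert_code)
print(convert_tool.run(source_path="model.py"))
# {'migrated_path': 'model.py.jax', 'status': 'ok'}
```

The point of the pattern is that `tools/` files stay thin: they only adapt plain functions to the ADK interface, while the behavior lives in the agent classes.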

### How to Add a New Capability
38 changes: 36 additions & 2 deletions MaxCode/README.md
@@ -3,6 +3,27 @@
This extension provides development tools for the MaxCode project,
including tools for AI-powered code migration between ML frameworks.

## Quick Start

Want to try MaxCode without the full Gemini CLI setup? The standalone demo
converts a PyTorch repo to JAX in five steps:

```bash
cd MaxCode/examples/demo
pip install -r requirements.txt
export GOOGLE_API_KEY=<your-key> # Windows CMD: set GOOGLE_API_KEY=<your-key>

python step1_clone_repo.py # Clone a PyTorch repo from GitHub
python step2_populate_rag.py # Build the RAG reference database
python step3_merge.py # Merge model + utility files (MergeAgent)
python step4_convert.py # Convert to JAX with validation + repair
python step4_convert_maxtext.py # Or: convert to MaxText (YAML + layers + ckpt converter)
python step5_verify.py # Verify conversion quality (VerificationAgent)
```

See [examples/demo/README.md](examples/demo/README.md) for full setup
instructions and details on what each step does.

## Prerequisites

This extension uses the Google AI API, which requires an API key. You can get
Expand Down Expand Up @@ -196,6 +217,19 @@ dev-server run_evaluation_workflow --prompt "Run equivalence tests for migration

## Architecture

The migration pipeline: **Clone -> Index -> Merge -> Convert -> Verify**.

Step 4 supports two conversion targets:
- **JAX/Flax** (`step4_convert.py`) — single-file JAX translation with validation and repair.
- **MaxText** (`step4_convert_maxtext.py`) — produces a YAML config overlay, an optional JAX layers file, and a checkpoint converter for the MaxText TPU stack.

Key agents in `agents/migration/`:
- **MergeAgent** — Pure-logic preprocessing: file discovery, filtering, import
graph analysis, and merging (no LLM calls).
- **PrimaryAgent** — Orchestrates conversion: routes to model or utility
conversion agents, fills missing components, validates and repairs.
- **VerificationAgent** — Post-processing quality scoring: AST-based
completeness + optional LLM-based correctness.

For more details on the project architecture and agent structure, see
[ARCHITECTURE.md](ARCHITECTURE.md).
8 changes: 4 additions & 4 deletions MaxCode/agents/evaluation/test_generation_agent.py
@@ -10,7 +10,7 @@
1. The JAX model can be instantiated and run with dummy inputs.
2. The JAX model, when loaded with weights from the PyTorch model, produces numerically close outputs given the same inputs.

You should use `absl.flags` to allow the user to specify the path to a pickle file containing a dictionary with keys 'input', 'output', and 'state_dict', generated from the original PyTorch model.

The JAX code is:
```python
@@ -31,7 +31,7 @@
* Run `model.init()` and `model.apply()` with the dummy input.
* Assert that the output has the expected shape and contains no NaNs.
4. **Test 2: `test_equivalence()`**:
* Load the pickle file specified by `_PICKLE_PATH.value`. The pickle file contains a dictionary with keys 'input', 'output', and 'state_dict'.
* If PyTorch input is a tuple, use the first element.
* Convert PyTorch input tensor to a Numpy array using `.detach().numpy()`. If it's 4D (NCHW), transpose it to NHWC format for JAX `(0, 2, 3, 1)`.
* Instantiate the JAX model.
@@ -40,8 +40,8 @@
* Linear weights: PyTorch `(Out, In)` -> Flax `(In, Out)`. Transpose with `(1, 0)`.
* Copy biases and other parameters without transpose.
* The JAX params structure may be nested, e.g., `{{'params': {{'Conv_0': {{'kernel': ..., 'bias': ...}}}}}}`. Map PyTorch weights to the correct Flax names and structure.
* Run JAX `model.apply()` using the converted parameters and input.
* Assert `np.testing.assert_allclose(jax_output, torch_output.detach().numpy(), atol=1e-5)` to check for numerical equivalence.
5. Include `absltest.main()` runner block.
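The weight-layout and input-layout conventions above can be checked with plain NumPy; torch and JAX tensors are stood in by arrays, and the shapes are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))          # batch of inputs
w_torch = rng.normal(size=(2, 3))    # PyTorch Linear weight: (Out, In)
b = np.zeros(2)

y_torch = x @ w_torch.T + b          # PyTorch convention: y = x W^T + b
kernel = w_torch.transpose(1, 0)     # Flax Dense kernel: (In, Out)
y_jax = x @ kernel + b               # Flax convention: y = x W + b

# Same tolerance the generated tests use for the final output.
np.testing.assert_allclose(y_jax, y_torch, atol=1e-5)

# NCHW -> NHWC input transpose for conv models:
nchw = np.zeros((1, 3, 8, 8))
nhwc = nchw.transpose(0, 2, 3, 1)
assert nhwc.shape == (1, 8, 8, 3)
```

Biases and other non-matrix parameters carry over without any transpose, which is why only the Linear weight and the 4D input need reshaping here.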

Output **only** the Python code block for the test file. Do not include any text before or after the code.

This file was deleted.