13 changes: 11 additions & 2 deletions AGENTS.md
@@ -13,7 +13,7 @@ directories, but must be read manually when working from the project root:
- `optimum/neuron/models/inference/AGENTS.md` — any inference model work
- `optimum/neuron/models/inference/backend/modules/attention/AGENTS.md` — attention or NKI kernel work
- `optimum/neuron/models/inference/<model>/AGENTS.md` — model-specific work (gemma3, llama, qwen3, etc.)
- `optimum/neuron/cache/AGENTS.md` — cache subsystem work (cleanup, Hub sync, registry)
- `optimum/neuron/cache/AGENTS.md` — cache subsystem work (bucket storage, fetch/sync, cleanup)
- `optimum/neuron/vllm/AGENTS.md` — vLLM integration work
- `tests/AGENTS.md` — test infrastructure, fixtures, and cache management

@@ -101,7 +101,7 @@ Use [NxDI](https://github.com/aws-neuron/neuronx-distributed-inference) for neur
For the full porting checklist and test guidance, see [optimum/neuron/models/inference/AGENTS.md](optimum/neuron/models/inference/AGENTS.md).

## Cache Management
Compiled models are cached locally and synced to the HF Hub. See [optimum/neuron/cache/AGENTS.md](optimum/neuron/cache/AGENTS.md) for the full cache architecture, entry states, cleanup logic, and CLI commands. Test helpers live in [tests/conftest.py](tests/conftest.py). Relevant env vars: `NEURON_CC_FLAGS`, `NEURON_COMPILE_CACHE_URL`, `NEURON_RT_VISIBLE_CORES`.
Compiled NEFFs are cached locally and synced to HF Storage Buckets (default: `aws-neuron/optimum-neuron-neff-cache`). The `hub_neuronx_cache` context manager handles fetch on enter and sync on exit. Bucket operations run in an isolated subprocess via `uv` to avoid `huggingface_hub` version conflicts. See [optimum/neuron/cache/AGENTS.md](optimum/neuron/cache/AGENTS.md) for the full cache architecture, bucket layout, entry states, cleanup logic, and CLI commands. Relevant env vars: `NEURON_CACHE_BUCKET`, `NEURON_COMPILE_CACHE_URL`, `NEURON_CC_FLAGS`.
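The fetch-on-enter / sync-on-exit contract described above can be sketched as a plain context manager. This is an illustrative stand-in, not the real `hub_neuronx_cache` implementation: the `bucket_cache` name and the `fetch`/`sync` callbacks are hypothetical.

```python
from contextlib import contextmanager


@contextmanager
def bucket_cache(model_id, task, *, fetch, sync):
    """Fetch cached NEFFs before compilation, sync new ones after."""
    fetch(model_id, task)  # pre-warm the local compiler cache on enter
    try:
        yield
    finally:
        sync(model_id, task)  # upload newly compiled artifacts on exit


# Usage with stub callbacks that just record the call order:
events = []
with bucket_cache(
    "my-org/model",
    "text-generation",
    fetch=lambda m, t: events.append(("fetch", m)),
    sync=lambda m, t: events.append(("sync", m)),
):
    events.append(("compile", "my-org/model"))
```

The `finally` clause mirrors the important property that newly compiled artifacts are synced even if later export steps raise.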

## CI/CD Workflows (Summary)

@@ -135,6 +135,15 @@ All test workflows follow the same pattern:
- TRN1 training: `pytest -m "is_trainium_test" tests/training/`
4. Check model-specific AGENTS.md if you touched a model directory.

## Commit Policy

- **Atomic commits**: one logical change per commit. Split refactoring from new feature additions whenever possible.
- **Tests passing**: each commit should leave tests green. Exceptions are acceptable within a branch, but not ideal.
- **Conventional commits**: use `feat(scope):`, `fix(scope):`, `refactor(scope):`, `test(scope):`, `docs(scope):`, `chore(scope):`, etc.
- **No file lists in commit body**: don't enumerate changed files — that's redundant with git metadata. Focus on the *why*.
- **Clean history**: present a consistent, rebased commit queue before pushing.
- **Rebase via script**: when squashing fixups, write a `GIT_SEQUENCE_EDITOR` shell script that transforms the rebase todo (mark fixups, reorder lines). Save it to a file and let the user run `GIT_SEQUENCE_EDITOR=/path/to/script.sh git rebase -i <base>`. Never run interactive rebases directly. If a fixup touches files from multiple parent commits, split it in a separate pass first.
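The todo transformation can be illustrated in Python as well, since `GIT_SEQUENCE_EDITOR` accepts any executable (a shell wrapper around a script like this also works). The `mark_fixups` helper and the sample subjects below are hypothetical, not part of the repository:

```python
import re
import sys


def mark_fixups(todo: str) -> str:
    """Turn `pick <sha> fixup! <subject>` lines into `fixup` lines placed
    directly after the commit whose subject they target."""
    picks, fixups = [], []
    for line in todo.splitlines():
        m = re.match(r"pick (\S+) fixup! (.*)", line)
        if m:
            fixups.append((m.group(1), m.group(2)))
        elif line.strip() and not line.startswith("#"):
            picks.append(line)
    out = []
    for line in picks:
        out.append(line)
        parts = line.split(" ", 2)
        subject = parts[2] if len(parts) == 3 else ""
        for sha, target in fixups:
            if subject.startswith(target):
                out.append(f"fixup {sha} fixup! {target}")
    return "\n".join(out) + "\n"


if __name__ == "__main__" and len(sys.argv) > 1:
    # git passes the path of the rebase todo file as the first argument
    path = sys.argv[1]
    with open(path) as f:
        todo = f.read()
    with open(path, "w") as f:
        f.write(mark_fixups(todo))
```

Note this simple version assumes each fixup targets a single parent commit, matching the policy above of splitting multi-parent fixups in a separate pass first.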

## Troubleshooting

- `ruff: command not found`: activate venv first.
120 changes: 73 additions & 47 deletions optimum/commands/neuron/cache.py
@@ -12,75 +12,96 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the command line related to dealing with the Neuron cache repo."""
"""Defines the command line related to dealing with the Neuron cache."""

from argparse import ArgumentParser

from ...neuron.cache.bucket_cache import fetch_cache
from ...neuron.cache.bucket_utils import get_cache_bucket, set_cache_bucket_in_hf_home
from ...neuron.cache.cleanup import cleanup_local_cache, get_local_cache_status
from ...neuron.cache.hub_cache import select_hub_cached_entries, synchronize_hub_cache
from ...neuron.utils.cache_utils import (
CACHE_REPO_NAME,
HF_HOME_CACHE_REPO_FILE,
create_custom_cache_repo,
set_custom_cache_repo_name_in_hf_home,
)
from ...neuron.cache.hub_cache import select_hub_cached_entries
from ...neuron.utils.import_utils import is_package_available
from ...neuron.utils.instance import SUPPORTED_INSTANCE_TYPES
from ...neuron.utils.require_utils import requires_torch_neuronx
from ...utils import logging
from ..base import BaseOptimumCLICommand, CommandInfo


logger = logging.get_logger()


class CreateCustomCacheRepoCommand(BaseOptimumCLICommand):
class CreateCacheBucketCommand(BaseOptimumCLICommand):
@staticmethod
def parse_args(parser: ArgumentParser):
parser.add_argument(
"-n",
"--name",
type=str,
default=CACHE_REPO_NAME,
help="The name of the repo that will be used as a remote cache for the compilation files.",
default=None,
help="The bucket ID (e.g. 'my-org/my-cache'). Defaults to the configured cache bucket.",
)
parser.add_argument(
"--public",
action="store_true",
help="If set, the created repo will be public. By default the cache repo is private.",
help="If set, the created bucket will be public. By default the cache bucket is private.",
)

def run(self):
repo_url = create_custom_cache_repo(repo_id=self.args.name, private=not self.args.public)
from ...neuron.cache.bucket_cache import _call_server

bucket_id = self.args.name or get_cache_bucket()
if not bucket_id:
logger.error("No bucket ID specified and no default bucket configured.")
return

# Verify bucket connectivity via the server (auto-starts if needed)
try:
_call_server("ping")
logger.info(f"Cache bucket server ready for: {bucket_id}")
except Exception as e:
logger.error(f"Failed to start bucket server: {e}")
return

set_cache_bucket_in_hf_home(bucket_id)
public_or_private = "public" if self.args.public else "private"
logger.info(f"Neuron cache created on the Hugging Face Hub: {repo_url.repo_id} [{public_or_private}].")
logger.info(f"Neuron cache name set locally to {repo_url.repo_id} in {HF_HOME_CACHE_REPO_FILE}.")
logger.info(f"Neuron cache bucket set to {bucket_id} [{public_or_private}].")
Comment on lines +32 to +66 — Copilot AI (Apr 1, 2026):

CreateCacheBucketCommand is described as creating a bucket (and exposes a --public flag), but the implementation only pings the local bucket server and writes the bucket ID to HF_HOME. This is misleading UX and makes --public a no-op. Either actually create/verify the bucket via the server (e.g. proxy a create_repo(repo_type="bucket") operation) or rename the command/help to reflect that it only sets configuration.


class SetCustomCacheRepoCommand(BaseOptimumCLICommand):
class SetCacheBucketCommand(BaseOptimumCLICommand):
@staticmethod
def parse_args(parser: "ArgumentParser"):
parser.add_argument("name", type=str, help="The name of the repo to use as remote cache.")
parser.add_argument("name", type=str, help="The bucket ID to use as remote cache (e.g. 'my-org/my-cache').")

def run(self):
set_custom_cache_repo_name_in_hf_home(self.args.name)
logger.info(f"Neuron cache name set locally to {self.args.name} in {HF_HOME_CACHE_REPO_FILE}.")
set_cache_bucket_in_hf_home(self.args.name)
logger.info(f"Neuron cache bucket set locally to {self.args.name}.")


class SynchronizeRepoCommand(BaseOptimumCLICommand):
class FetchCommand(BaseOptimumCLICommand):
@staticmethod
def parse_args(parser: "ArgumentParser"):
parser.add_argument("--repo_id", type=str, default=None, help="The name of the repo to use as remote cache.")
parser.add_argument(
"--cache_dir", type=str, default=None, help="The cache directory that contains the compilation files."
"model_id",
type=str,
help="The model_id to pre-warm cache for.",
)
parser.add_argument(
"--task",
type=str,
default=None,
help="The task to fetch cache for (e.g. 'text-generation').",
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The local cache directory to download to.",
)

@requires_torch_neuronx
def run(self):
synchronize_hub_cache(cache_path=self.args.cache_dir, cache_repo_id=self.args.repo_id)
fetch_cache(model_id=self.args.model_id, task=self.args.task, cache_dir=self.args.cache_dir)


class LookupRepoCommand(BaseOptimumCLICommand):
class LookupCommand(BaseOptimumCLICommand):
@staticmethod
def parse_args(parser: "ArgumentParser"):
parser.add_argument(
@@ -92,7 +113,7 @@ def parse_args(parser: "ArgumentParser"):
"--task",
type=str,
default=None,
help="The optional task to lookup cached versions for models supporting multiple tasks.",
dacorvo marked this conversation as resolved.
help="The task to lookup cache for (e.g. 'text-generation').",
)
parser.add_argument(
"--instance_type",
@@ -121,13 +142,11 @@ def parse_args(parser: "ArgumentParser"):
type=int,
help="Only look for cached models supporting at least the specified sequence length.",
)
parser.add_argument("--repo_id", type=str, default=None, help="The name of the repo to use as remote cache.")

def _list_entries(self):
entries = select_hub_cached_entries(
self.args.model_id,
task=self.args.task,
dacorvo marked this conversation as resolved.
cache_repo_id=self.args.repo_id,
instance_type=self.args.instance_type,
batch_size=self.args.batch_size,
sequence_length=self.args.sequence_length,
@@ -138,24 +157,31 @@ def _list_entries(self):
if n_entries == 0:
print(f"No cached entries found for {self.args.model_id}.")
return
# Prepare output table data
title = f"Cached entries for {self.args.model_id}"
columns = ["batch size", "sequence length", "tensor parallel", "dtype", "instance type"]
rows = []
for entry in entries:
rows.append(
(
str(entry["batch_size"]),
str(entry["sequence_length"]),
str(entry.get("tp_degree", entry.get("tensor_parallel_size"))),
str(entry.get("torch_dtype", entry.get("dtype"))),
str(entry["target"]),
str(entry.get("batch_size", "?")),
str(entry.get("sequence_length", "?")),
str(entry.get("tp_degree", entry.get("tensor_parallel_size", "?"))),
str(entry.get("torch_dtype", entry.get("dtype", "?"))),
str(entry.get("target", "?")),
)
)
# Remove duplicates (might happen if the same arch was compiled several times with different models and sync'ed afterwards)

def _sort_key(row):
def _int_or(val):
try:
return (0, int(val))
except (ValueError, TypeError):
return (1, val)

return (_int_or(row[2]), _int_or(row[0]), _int_or(row[1]), row[3])

rows = list(set(rows))
# Sort by tensor parallel size, then batch size, sequence length, dtype
rows = sorted(rows, key=lambda x: (int(x[2]), int(x[0]), int(x[1]), x[3]))
rows = sorted(rows, key=_sort_key)
if is_package_available("rich", "14.1.0"):
from rich.console import Console
from rich.table import Table
@@ -231,23 +257,23 @@ class CustomCacheRepoCommand(BaseOptimumCLICommand):
SUBCOMMANDS = (
CommandInfo(
name="create",
help="Create a model repo on the Hugging Face Hub to store Neuron X compilation files.",
subcommand_class=CreateCustomCacheRepoCommand,
help="Create a storage bucket on the Hugging Face Hub for Neuron compilation files.",
subcommand_class=CreateCacheBucketCommand,
),
CommandInfo(
name="set",
help="Set the name of the Neuron cache repo to use locally.",
subcommand_class=SetCustomCacheRepoCommand,
help="Set the name of the Neuron cache bucket to use locally.",
subcommand_class=SetCacheBucketCommand,
),
CommandInfo(
name="synchronize",
help="Synchronize the neuronx compiler cache with a hub cache repo.",
subcommand_class=SynchronizeRepoCommand,
name="fetch",
help="Pre-warm the local cache by downloading MODULE dirs for a model from the bucket.",
subcommand_class=FetchCommand,
),
CommandInfo(
name="lookup",
help="Lookup the neuronx compiler hub cache for the specified model id. Tip: install rich for a nicer display",
subcommand_class=LookupRepoCommand,
help="Lookup cached export configurations for the specified model id.",
subcommand_class=LookupCommand,
),
CommandInfo(
name="status",
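The `_sort_key` helper in the lookup command above sorts numeric fields numerically while tolerating the `"?"` placeholders introduced for missing entry fields. A standalone sketch of the same idea (the sample rows are illustrative):

```python
def int_or(val):
    """Sort key that orders parseable integers numerically, and pushes
    non-numeric placeholders like "?" after all numbers."""
    try:
        return (0, int(val))
    except (ValueError, TypeError):
        return (1, val)


def sort_key(row):
    # row = (batch size, sequence length, tensor parallel, dtype, instance)
    # Sort by tensor parallel size, then batch size, sequence length, dtype.
    return (int_or(row[2]), int_or(row[0]), int_or(row[1]), row[3])


rows = [
    ("4", "4096", "8", "bf16", "trn1"),
    ("1", "128", "2", "bf16", "trn1"),
    ("?", "?", "?", "fp16", "trn1"),
    ("2", "128", "2", "fp32", "trn1"),
]
# set() removes duplicate rows; sorted() restores a deterministic order.
rows = sorted(set(rows), key=sort_key)
```

The two-element tuples avoid the `TypeError` that a plain `int(x)` key would raise on `"?"`, which is exactly what the old `lambda`-based key in the removed line would have hit.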
67 changes: 33 additions & 34 deletions optimum/exporters/neuron/convert.py
@@ -27,17 +27,13 @@
import torch
dacorvo marked this conversation as resolved.
from transformers import PreTrainedModel

from optimum.neuron.cache.entries.multi_model import MultiModelCacheEntry
from optimum.neuron.cache.entries.single_model import SingleModelCacheEntry
from optimum.neuron.cache.traced import cache_traced_neuron_artifacts
from optimum.neuron.utils import (
DiffusersPretrainedConfig,
is_neuronx_available,
store_compilation_config,
)

from ...exporters.error_utils import OutputMatchError, ShapeError
from ...neuron.utils.cache_utils import get_model_name_or_path
from ...neuron.utils.system import get_neuron_major
from ...neuron.utils.version_utils import get_neuronxcc_version
from ...utils import (
@@ -346,7 +342,6 @@ def export_models(

failed_models = []
total_compilation_time = 0
compile_configs = {}
for i, model_name in enumerate(models_and_neuron_configs.keys()):
logger.info(f"***** Compiling {model_name} *****")
submodel, sub_neuron_config = models_and_neuron_configs[model_name]
@@ -372,6 +367,9 @@
compiler_workdir=compiler_workdir,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
model_name_or_path=model_name_or_path,
task=task,
disable_neuron_cache=disable_neuron_cache,
**compiler_kwargs,
)
compilation_time = time.time() - start_time
@@ -411,26 +409,9 @@
output_hidden_states=getattr(sub_neuron_config, "output_hidden_states", False),
)
model_config.save_pretrained(output_path.parent)
compile_configs[model_name] = model_config

logger.info(f"[Total compilation Time] {np.round(total_compilation_time, 2)} seconds.")

# cache neuronx model
if not disable_neuron_cache and is_neuronx_available():
model_id = get_model_name_or_path(model_config) if model_name_or_path is None else model_name_or_path
if len(compile_configs) == 1:
# FIXME: this is overly complicated just to pass the config
cache_config = list(compile_configs.values())[0]
cache_entry = SingleModelCacheEntry(model_id=model_id, task=task, config=cache_config)
else:
try:
cache_entry = MultiModelCacheEntry(model_id=model_id, configs=compile_configs)
except NotImplementedError:
logger.warning(f"Cache indexing is not supported for {model_id}.")
cache_entry = None
if cache_entry is not None:
cache_traced_neuron_artifacts(neuron_dir=output_dir, cache_entry=cache_entry)

# remove models failed to export
for i, model_name in failed_models:
output_file_names.pop(model_name)
@@ -451,6 +432,9 @@ def export(
auto_cast_type: str = "bf16",
disable_fast_relayout: bool = False,
disable_fallback: bool = False,
model_name_or_path: str | None = None,
task: str | None = None,
disable_neuron_cache: bool = False,
) -> tuple[list[str], list[str]]:
if is_neuronx_available():
return export_neuronx(
@@ -463,6 +447,9 @@
instance_type=instance_type,
auto_cast=auto_cast,
auto_cast_type=auto_cast_type,
model_name_or_path=model_name_or_path,
task=task,
disable_neuron_cache=disable_neuron_cache,
)
else:
raise RuntimeError(
@@ -480,6 +467,9 @@ def export_neuronx(
optlevel: str = "2",
auto_cast: str | None = None,
auto_cast_type: str = "bf16",
model_name_or_path: str | None = None,
task: str | None = None,
disable_neuron_cache: bool = False,
) -> tuple[list[str], list[str]]:
"""
Exports a PyTorch model to a serialized TorchScript module compiled by neuronx-cc compiler.
@@ -559,18 +549,27 @@
)
inline_weights_to_neff = True

# Start trace
trace_neuronx(
model=checked_model,
config=config,
dummy_inputs=dummy_inputs_tuple,
compiler_args=compiler_args,
output=output,
tensor_parallel_size=config.tensor_parallel_size,
aliases=aliases,
inline_weights_to_neff=inline_weights_to_neff,
compiler_workdir=compiler_workdir,
)
# Start trace (wrapped in cache context for NEFF fetch/sync)
def _do_trace():
trace_neuronx(
model=checked_model,
config=config,
dummy_inputs=dummy_inputs_tuple,
compiler_args=compiler_args,
output=output,
tensor_parallel_size=config.tensor_parallel_size,
aliases=aliases,
inline_weights_to_neff=inline_weights_to_neff,
compiler_workdir=compiler_workdir,
)

if not disable_neuron_cache and model_name_or_path:
from optimum.neuron.cache.hub_cache import hub_neuronx_cache

with hub_neuronx_cache(model_id=model_name_or_path, task=task):
Copilot AI (Apr 1, 2026):

hub_neuronx_cache supports uploading an export record when export_config is provided, but export_neuronx() never passes export_config nor sets ctx.export_config inside the context. As a result, lookup_cache()/select_hub_cached_entries will remain empty even after successful compiles. Consider populating export_config from the Neuron config / compiler args (using the yielded CacheContext) so advisory lookup works as intended.

Suggested change:
-        with hub_neuronx_cache(model_id=model_name_or_path, task=task):
+        with hub_neuronx_cache(model_id=model_name_or_path, task=task) as ctx:
+            # Populate export_config on the cache context so that successful compiles
+            # can be recorded and later used for advisory cache lookups.
+            export_config: dict[str, Any] = {
+                "compiler_args": compiler_args,
+            }
+            # Add compiler metadata when available.
+            if "NEURON_COMPILER_TYPE" in globals():
+                export_config["compiler_type"] = NEURON_COMPILER_TYPE
+            if "NEURON_COMPILER_VERSION" in globals():
+                export_config["compiler_version"] = NEURON_COMPILER_VERSION
+            # Include runtime information and a serializable view of the Neuron config.
+            export_config["neuron_runtime"] = f"neuronx-{get_neuron_major()}"
+            if hasattr(config, "to_dict"):
+                export_config["neuron_config"] = config.to_dict()
+            elif hasattr(config, "__dict__"):
+                export_config["neuron_config"] = {
+                    k: v for k, v in config.__dict__.items() if not k.startswith("_")
+                }
+            # Only set export_config if the context supports it.
+            if hasattr(ctx, "export_config"):
+                ctx.export_config = export_config
_do_trace()
else:
_do_trace()

del model_or_path
return config.inputs, config.outputs
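The `_do_trace` if/else above could equivalently select a no-op context with `contextlib.nullcontext`, avoiding the nested function. A hedged sketch of that alternative shape; `cache_context` and `compile_model` are stand-ins for `hub_neuronx_cache` and `trace_neuronx`, not the real APIs:

```python
from contextlib import contextmanager, nullcontext


@contextmanager
def cache_context(model_id, task):
    # stand-in for a hub_neuronx_cache-style fetch/sync wrapper
    yield


def export(model_id=None, task=None, disable_neuron_cache=False,
           compile_model=lambda: "neff"):
    # Pick the caching context only when caching is enabled and a model id
    # is known; otherwise fall back to a do-nothing context.
    ctx = (cache_context(model_id, task)
           if model_id and not disable_neuron_cache
           else nullcontext())
    with ctx:
        return compile_model()
```

Either shape is fine; `nullcontext` just keeps a single call site for the trace.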