[Neuron] Add AWS Neuron (Trainium/Inferentia) as an officially supported device #13289

Draft
JingyaHuang wants to merge 10 commits into huggingface:main from JingyaHuang:add-neuron-backend

Conversation

Contributor

@JingyaHuang JingyaHuang commented Mar 19, 2026

What does this PR do?

This PR adds AWS Neuron (Trainium/Inferentia) as an officially supported compute backend in Diffusers, on par with existing backends like CUDA, MPS, XPU, and MLU.

Changes

  • import_utils.py — adds is_torch_neuronx_available() detection, following the existing pattern for optional backends.
  • torch_utils.py — registers "neuron" in all backend dispatch tables (BACKEND_SUPPORTS_TRAINING, BACKEND_EMPTY_CACHE, BACKEND_DEVICE_COUNT, BACKEND_MANUAL_SEED, etc.) and adds a randn_tensor workaround since Neuron/XLA does not support creating random tensors directly on device (falls back to CPU).
  • utils/__init__.py — exports is_torch_neuronx_available.
  • pipeline_utils.py — adds two new DiffusionPipeline methods:
    • enable_neuron_compile(model_names, cache_dir, fullgraph) — wraps pipeline nn.Module components with torch.compile(backend="neuron") for whole-graph NEFF compilation. Supports optional NEFF caching via TORCH_NEURONX_NEFF_CACHE_DIR.
    • neuron_warmup(*args, **kwargs) — runs a single dummy forward pass to trigger upfront neuronx-cc compilation before timed inference.
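
The availability check added to import_utils.py follows the pattern Diffusers already uses for other optional backends. A minimal sketch of that pattern (illustrative only, not the PR's exact code):

```python
import importlib.util


def is_torch_neuronx_available() -> bool:
    """Detect the torch_neuronx package without importing it, using the
    importlib-based check common to optional-backend detection."""
    return importlib.util.find_spec("torch_neuronx") is not None
```

Callers can then gate Neuron-specific code paths (device dispatch, the randn_tensor CPU fallback) behind this check, the same way the existing backend checks are used.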

Usage

  • Eager mode
```python
import torch
import torch_neuronx  # noqa: F401 — registers torch.neuron

from diffusers import AutoPipelineForText2Image

# Load and move to the Neuron device
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=torch.bfloat16,
    variant="fp16",
)
pipe = pipe.to(torch.neuron.current_device())

# Warmup
pipe(prompt="warmup", height=512, width=512, num_inference_steps=1, guidance_scale=0.0)

# Inference
image = pipe(
    prompt="a golden retriever surfing a wave, photorealistic",
    height=512,
    width=512,
    num_inference_steps=1,
    guidance_scale=0.0,
).images[0]

image.save("output.png")
```
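
  • Compile mode

Based on the method descriptions above, compile mode could look roughly like the sketch below. It requires a Trainium/Inferentia instance with torch_neuronx installed; the component names passed to model_names and the cache path are assumptions for illustration.

```python
import torch
import torch_neuronx  # noqa: F401 — registers torch.neuron

from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=torch.bfloat16,
    variant="fp16",
)
pipe = pipe.to(torch.neuron.current_device())

# Wrap pipeline components with torch.compile(backend="neuron") and
# persist NEFF artifacts so later runs skip recompilation.
pipe.enable_neuron_compile(
    model_names=["unet", "vae"],  # assumption: which components to compile
    cache_dir="/tmp/neff-cache",  # assumption: any writable path
    fullgraph=True,
)

# Run one dummy forward pass to trigger upfront neuronx-cc compilation
pipe.neuron_warmup(prompt="warmup", height=512, width=512, num_inference_steps=1, guidance_scale=0.0)

image = pipe(
    prompt="a golden retriever surfing a wave, photorealistic",
    height=512,
    width=512,
    num_inference_steps=1,
    guidance_scale=0.0,
).images[0]
```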

Next Steps

  • Enable torch.compile on the Neuron device
  • Add tensor-parallel support for memory-bound devices like Neuron
  • Make Diffusers compatible with the NKI kernels library to further boost performance on Neuron in compile mode

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-actions github-actions bot added lora examples size/M PR with diff < 200 LOC and removed size/M PR with diff < 200 LOC labels Apr 9, 2026
@github-actions github-actions bot added size/M PR with diff < 200 LOC and removed size/M PR with diff < 200 LOC labels Apr 10, 2026
