From d42b750d88aa017282687b1f176c34046a2fb880 Mon Sep 17 00:00:00 2001
From: RiyaP-QA
Date: Tue, 28 Apr 2026 20:15:59 +0530
Subject: [PATCH 1/2] Update PrivateUse1 tutorial to document Python backend approach

Add new section covering _setup_privateuseone_for_python_backend() with
tensor subclass pattern, torch.library op registration, hooks/device guard
customization, end-to-end NumPy-backed example, comparison table, and
limitations. Also add hooks registration warning to existing C++ section.

Fixes pytorch/pytorch#179010
Fixes pytorch/pytorch#179008
---
 advanced_source/privateuseone.rst | 390 +++++++++++++++++++++++++++++-
 index.rst                         |   4 +-
 2 files changed, 385 insertions(+), 9 deletions(-)

diff --git a/advanced_source/privateuseone.rst b/advanced_source/privateuseone.rst
index 5b5b37c20e2..c587562b30a 100644
--- a/advanced_source/privateuseone.rst
+++ b/advanced_source/privateuseone.rst
@@ -3,8 +3,7 @@ Facilitating New Backend Integration by PrivateUse1
 
 In this tutorial we will walk through some necessary steps to integrate a new backend
 living outside ``pytorch/pytorch`` repo by ``PrivateUse1``. Note that this tutorial assumes that
-you already have a basic understanding of PyTorch.
-you are an advanced user of PyTorch.
+you already have a basic understanding of PyTorch and are an advanced user of PyTorch.
 
 .. note::
 
@@ -12,6 +11,11 @@ you are an advanced user of PyTorch.
    and other parts will not be covered. At the same time, not all the modules involved in this tutorial
    are required, and you can choose the modules that are helpful to you according to your actual needs.
 
+.. tip::
+
+   If you want to integrate a backend entirely in Python without writing any C++,
+   see `Python Backend Approach (Simplified)`_ below.
+
 What is PrivateUse1?
 --------------------
 
@@ -191,7 +195,7 @@ of new backend additional metadata named ``backend_meta_`` in class ``TensorImpl
 
 1. Inherit the ``BackendMeta`` class to implement ``CustomBackendMetadata`` corresponding to the new backend
    and various fields of the new backend can be customized in the class.
-2. Implement the serialization and deserialization functions of the new backend, the function signatures are 
+2. Implement the serialization and deserialization functions of the new backend, the function signatures are
    ``void(const at::Tensor&, std::unordered_map<std::string, bool>&)``.
 3. Call the ``TensorBackendMetaRegistry`` macro to complete dynamic registration.
 
@@ -228,6 +232,7 @@ and the next thing to do is to improve usability, which mainly involves the following aspects.
 1. Register new backend module to Pytorch.
 2. Rename PrivateUse1 to a custom name for the new backend.
 3. Generate methods and properties related to the new backend.
+4. Register PrivateUse1HooksInterface for the new backend.
 
 Register new backend module to Pytorch
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -292,18 +297,389 @@ Then, you can use the following methods and properties:
 
    torch.Tensor.is_npu
    torch.Tensor.npu
    torch.Storage.is_npu
    ...
 
+Register PrivateUse1HooksInterface for the new backend
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+If your backend uses autograd (i.e., calls ``.backward()``), you **must** register
+``PrivateUse1HooksInterface``. Without it, backward passes will fail.
+
+.. warning::
+
+   Calling ``.backward()`` without registering ``PrivateUse1HooksInterface`` will raise::
+
+      RuntimeError: Please register PrivateUse1HooksInterface
+      by `RegisterPrivateUse1HooksInterface` first.
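+
+For example, with forward operators registered but no hooks, only the
+``backward()`` call fails. A minimal sketch (``"foo"`` is a hypothetical
+renamed backend whose required ops are already registered, as in the Python
+example later in this tutorial):
+
+.. code-block:: python
+
+   x = torch.randn(2, 2).to("foo")  # forward dispatch typically works without hooks
+   x.requires_grad = True
+   y = x.sum()
+   y.backward()                     # raises the RuntimeError above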
+
+In C++, register hooks by calling:
+
+.. code-block:: cpp
+
+    #include <ATen/detail/PrivateUse1HooksInterface.h>
+
+    struct MyHooks : at::PrivateUse1HooksInterface {
+        bool isAvailable() const override { return true; }
+        bool hasPrimaryContext(c10::DeviceIndex device_index) const override { return true; }
+        bool isBuilt() const override { return true; }
+    };
+
+    at::RegisterPrivateUse1HooksInterface(new MyHooks());
+
+In Python, register hooks using:
+
+.. code-block:: python
+
+   torch._C._acc.register_python_privateuseone_hook(hook_instance)
+
+.. note::
+
+   If you use ``_setup_privateuseone_for_python_backend()``, hooks are registered
+   automatically. See `Python Backend Approach (Simplified)`_ for details.
+
+
+Python Backend Approach (Simplified)
+------------------------------------
+
+Starting from PyTorch 2.10, ``_setup_privateuseone_for_python_backend()`` simplifies
+backend creation to a single Python call. This approach is ideal for:
+
+* Rapid prototyping without C++ compilation
+* Backends where compute is handled in Python (NumPy, JAX, custom accelerators with Python bindings)
+* Educational purposes and experimentation
+
+.. warning::
+
+   This API is experimental and may change without notice. The function is
+   private (prefixed with ``_``).
+
+Quick Setup
+^^^^^^^^^^^
+
+The simplest way to set up a Python backend is a single call:
+
+.. code-block:: python
+
+   from torch.utils.backend_registration import _setup_privateuseone_for_python_backend
+
+   _setup_privateuseone_for_python_backend("mybackend")
+
+**Parameters:**
+
+* ``rename`` (``str | None``) -- custom backend name; defaults to ``"privateuseone"``
+* ``backend_module`` (``object | None``) -- module registered as ``torch.<rename>``; defaults to ``_DummyBackendModule``
+* ``hook`` (``object | None``) -- hooks implementation; defaults to ``_DummyPrivateUse1Hook``
+* ``device_guard`` (``object | None``) -- device guard implementation; defaults to ``_DummyDeviceGuard``
+
+**What it does internally (in order):**
+
+1. ``torch.utils.rename_privateuse1_backend(rename)``
+2. ``torch.utils.generate_methods_for_privateuse1_backend()``
+3. ``torch._register_device_module(rename, backend_module)``
+4. ``torch._C._acc.register_python_privateuseone_hook(hook)``
+5. ``torch._C._acc.register_python_privateuseone_device_guard(device_guard)``
+
+Defining a Custom Tensor Subclass
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To store backend-specific data (e.g., NumPy arrays), define a tensor subclass:
+
+.. code-block:: python
+
+   import torch
+   import numpy as np
+
+   class MyDeviceTensor(torch.Tensor):
+       @staticmethod
+       def __new__(cls, size, dtype, raw_data=None, requires_grad=False):
+           # Create an empty tensor on the PrivateUse1 device
+           res = torch._C._acc.create_empty_tensor(size, dtype)
+           res.__class__ = MyDeviceTensor
+           return res
+
+       def __init__(self, size, dtype, raw_data=None, requires_grad=False):
+           # Store the backend-specific data
+           self.raw_data = raw_data
+
+   def wrap(arr, shape, dtype):
+       """Wrap a NumPy array as a MyDeviceTensor."""
+       return MyDeviceTensor(shape, dtype, arr)
+
+   def unwrap(tensor):
+       """Extract the raw NumPy array from a MyDeviceTensor."""
+       return tensor.raw_data
+
+.. note::
+
+   ``torch._C._acc.create_empty_tensor`` is a private helper provided for the
+   Python backend approach. It creates an empty tensor on the PrivateUse1 device.
+
+Registering Ops with torch.library
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Use ``@torch.library.impl`` to register operator implementations for your backend:
+
+.. code-block:: python
+
+   @torch.library.impl("aten::add.Tensor", "privateuseone")
+   def add(t1, t2):
+       out = unwrap(t1) + unwrap(t2)
+       return wrap(out, out.shape, torch.float32)
+
+.. note::
+
+   The second argument to ``@torch.library.impl`` is the **dispatch key** name,
+   which is always ``"privateuseone"`` regardless of what you passed to ``rename``.
+   The renamed name (e.g., ``"npy"``) is only used for user-facing APIs like
+   ``tensor.to("npy")`` and ``tensor.device.type``.
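+
+For example, assuming the backend was renamed ``"npy"`` and reusing the
+``wrap``/``unwrap`` helpers from the subclass sketch above, a hypothetical
+extra operator is still registered under the ``"privateuseone"`` key:
+
+.. code-block:: python
+
+   @torch.library.impl("aten::neg", "privateuseone")  # dispatch key, not "npy"
+   def neg(t):
+       # Negate the backing NumPy array and re-wrap it.
+       out = -unwrap(t)
+       return wrap(out, out.shape, torch.float32)
+
+   # User-facing code, by contrast, only ever sees the renamed device string:
+   # tensor.to("npy"), tensor.device.type == "npy", and so on.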
+
+**Required Operators:**
+
+The following operators must be registered for basic functionality and autograd support:
+
+.. list-table::
+   :header-rows: 1
+   :widths: 35 65
+
+   * - Operator
+     - Purpose
+   * - ``aten::empty_strided``
+     - Tensor allocation
+   * - ``aten::empty.memory_format``
+     - Tensor allocation (alternate path)
+   * - ``aten::_copy_from``
+     - Data transfer between CPU and backend
+   * - ``aten::detach``
+     - Autograd detach
+   * - ``aten::view``
+     - View creation
+   * - ``aten::as_strided``
+     - Strided tensor access
+   * - ``aten::ones_like``
+     - Gradient initialization during backward
+   * - ``aten::expand``
+     - Broadcasting
+   * - ``aten::sum``
+     - Reduction (if using ``.sum().backward()``)
+   * - ``aten::add.Tensor``
+     - Addition
+   * - ``aten::mul.Tensor``
+     - Multiplication
+
+Customizing Hooks and Device Guard
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+For more control, you can provide custom hooks and device guard implementations:
+
+.. code-block:: python
+
+   class MyHook(torch._C._acc.PrivateUse1Hooks):
+       def is_available(self) -> bool:
+           return True
+
+       def has_primary_context(self, dev_id) -> bool:
+           return True
+
+       def is_built(self) -> bool:
+           return True
+
+   class MyGuard(torch._C._acc.DeviceGuard):
+       def type_(self):
+           return torch._C._autograd.DeviceType.PrivateUse1
+
+   _setup_privateuseone_for_python_backend(
+       "mybackend", hook=MyHook(), device_guard=MyGuard()
+   )
+
+.. important::
+
+   **Hooks registration is required for autograd backward.** Without registering
+   hooks, calling ``backward()`` will fail with::
+
+      RuntimeError: Please register PrivateUse1HooksInterface
+      by `RegisterPrivateUse1HooksInterface` first.
+
+   The ``_setup_privateuseone_for_python_backend()`` function handles this
+   automatically by registering a default hook.
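+
+After setup, you can sanity-check the registration. This is a minimal sketch;
+``torch._C._get_privateuse1_backend_name()`` is a private helper, so treat it
+as an implementation detail:
+
+.. code-block:: python
+
+   # The rename is global and one-shot; it should now report "mybackend".
+   assert torch._C._get_privateuse1_backend_name() == "mybackend"
+
+   # The generated convenience attributes follow the new name
+   # (compare ``torch.Tensor.is_npu`` earlier in this tutorial).
+   print(hasattr(torch.empty(0), "is_mybackend"))  # True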
+
+End-to-End Example with Autograd
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Here is a complete, self-contained example using NumPy as the backend:
+
+.. code-block:: python
+
+   import numpy as np
+   import torch
+   from torch.utils.backend_registration import _setup_privateuseone_for_python_backend
+
+   # Step 1: Set up the backend
+   _setup_privateuseone_for_python_backend("npy")
+
+   # Step 2: Define the tensor subclass
+   class NpyTensor(torch.Tensor):
+       @staticmethod
+       def __new__(cls, size, dtype, raw_data=None, requires_grad=False):
+           res = torch._C._acc.create_empty_tensor(size, dtype)
+           res.__class__ = NpyTensor
+           return res
+
+       def __init__(self, size, dtype, raw_data=None, requires_grad=False):
+           self.raw_data = raw_data
+
+   def wrap(arr, shape, dtype):
+       return NpyTensor(shape, dtype, arr)
+
+   def unwrap(tensor):
+       return tensor.raw_data
+
+   # Step 3: Register all required operators
+   @torch.library.impl("aten::add.Tensor", "privateuseone")
+   def add(t1, t2):
+       out = unwrap(t1) + unwrap(t2)
+       return wrap(out, out.shape, torch.float32)
+
+   @torch.library.impl("aten::mul.Tensor", "privateuseone")
+   def mul(t1, t2):
+       out = unwrap(t1) * unwrap(t2)
+       return wrap(out, out.shape, torch.float32)
+
+   @torch.library.impl("aten::sum", "privateuseone")
+   def sum_impl(*args, **kwargs):
+       ans = unwrap(args[0]).sum()
+       return wrap(ans, ans.shape, torch.float32)
+
+   @torch.library.impl("aten::detach", "privateuseone")
+   def detach(self):
+       return wrap(unwrap(self), self.shape, torch.float32)
+
+   @torch.library.impl("aten::ones_like", "privateuseone")
+   def ones_like(self, *, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None):
+       ans = np.ones_like(unwrap(self))
+       return wrap(ans, ans.shape, torch.float32)
+
+   @torch.library.impl("aten::expand", "privateuseone")
+   def expand(self, size, *, implicit=False):
+       ans = np.broadcast_to(self.raw_data, size)
+       return wrap(ans, ans.shape, torch.float32)
+
+   @torch.library.impl("aten::empty_strided", "privateuseone")
+   def empty_strided(size, stride, *, dtype=None, layout=None, device=None, pin_memory=None):
+       # Stride and dtype arguments are ignored in this NumPy-backed demo.
+       out = np.empty(size, dtype=np.float32)
+       return wrap(out, out.shape, torch.float32)
+
+   @torch.library.impl("aten::empty.memory_format", "privateuseone")
+   def empty_memory_format(size, *, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None):
+       ans = np.empty(size, dtype=np.float32)
+       return wrap(ans, ans.shape, torch.float32)
+
+   @torch.library.impl("aten::_copy_from", "privateuseone")
+   def copy_from(a, b):
+       # Copy ``a`` into ``b``; this demo only handles copies onto the backend.
+       if a.device.type == "npy":
+           npy_data = unwrap(a)
+       else:
+           npy_data = a.numpy()
+       b.raw_data = npy_data
+       return b
+
+   @torch.library.impl("aten::view", "privateuseone")
+   def view(a, size):
+       # Reshape the backing array to the requested size.
+       ans = unwrap(a).reshape(size)
+       return wrap(ans, ans.shape, a.dtype)
+
+   @torch.library.impl("aten::as_strided", "privateuseone")
+   def as_strided(self, size, stride, storage_offset=None):
+       # PyTorch strides are in elements; NumPy strides are in bytes.
+       byte_strides = [s * self.raw_data.itemsize for s in stride]
+       ans = np.lib.stride_tricks.as_strided(self.raw_data, size, byte_strides)
+       return wrap(ans, ans.shape, torch.float32)
+
+   # Step 4: Use the backend with autograd
+   a = torch.randn(2, 2).to("npy")
+   b = torch.randn(2, 2).to("npy")
+
+   a.requires_grad = True
+   b.requires_grad = True
+
+   c = (a + b).sum()
+   c.backward()
+
+   print(np.allclose(a.grad.raw_data, np.ones((2, 2))))  # True
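+
+To push a second gradient through the same registrations, here is a short,
+hedged continuation of the script above. Resetting ``.grad`` first avoids
+in-place accumulation, which would additionally require ``aten::add_.Tensor``:
+
+.. code-block:: python
+
+   # Reset the gradients so the new ones are assigned rather than accumulated.
+   a.grad = None
+   b.grad = None
+
+   d = (a * b).sum()
+   d.backward()
+
+   # d(a*b)/da = b, so a's gradient should now equal b's values.
+   print(np.allclose(a.grad.raw_data, b.raw_data))  # True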
+
+Comparison: C++ vs Python Approach
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. list-table::
+   :header-rows: 1
+   :widths: 25 35 40
+
+   * - Aspect
+     - C++ Approach
+     - Python Approach
+   * - Device guard
+     - Manual C++ + ``C10_REGISTER_GUARD_IMPL``
+     - Auto-generated dummy or Python subclass
+   * - Generator
+     - Manual C++ registration
+     - Not needed
+   * - Op registration
+     - ``TORCH_LIBRARY_IMPL`` in C++
+     - ``@torch.library.impl`` in Python
+   * - Compilation
+     - C++ build toolchain required
+     - No compilation
+   * - Performance
+     - Native speed
+     - Python interpreter overhead
+   * - Streams/Events
+     - Full support
+     - Not supported
+   * - Multi-device
+     - Supported
+     - Single device (index 0)
+   * - torch.compile
+     - Supported
+     - Not yet supported
+   * - Best for
+     - Production accelerators
+     - Prototyping, Python-native compute
+
+Limitations
+^^^^^^^^^^^
+
+The Python backend approach has the following limitations:
+
+* **No stream/event support** -- Asynchronous execution is not available
+* **Single device only** -- Only device index 0 is supported
+* **Python interpreter overhead** -- Every op dispatch goes through Python
+* **No torch.compile / dynamo support** -- JIT compilation is not available
+* **Experimental API** -- May change without notice in future releases
+
+See Also
+^^^^^^^^
+
+.. seealso::
+
+   * `torch_openreg <https://github.com/pytorch/pytorch/tree/main/test/cpp_extensions/open_registration_extension/torch_openreg>`_ -- C++ backend reference implementation (in-tree)
+   * `pytorch_open_registration_example `_ -- Original C++ example by @ptrblck
+   * `test_privateuseone_python_backend.py `_ -- Working test for the Python approach
+   * `PR #157859 <https://github.com/pytorch/pytorch/pull/157859>`_ -- Original PR that added ``_setup_privateuseone_for_python_backend()``
+
+
 Future Work
 -----------
 
 The improvement of the ``PrivateUse1`` mechanism is still in progress, so the integration method of
 ``PrivateUse1`` of the new module will be added in turn. Here are a few items that we are actively working on:
 
-* Add the integration method of ``distributed collective communication``.
+* Improve ``torch.compile`` support for the Python backend approach.
+* Expand the integration method of ``distributed collective communication``.
 * Add the integration method of ``benchmark timer``.
 
 Conclusion
 ----------
 
-This tutorial walked you through the process of integrating new backends into PyTorch via ``PrivateUse1``, including but not limited to
-operator registration, generator registration, device guard registration, and so on. At the same time, some methods are introduced
-to improve the user experience.
+This tutorial walked you through two approaches for integrating new backends into
+PyTorch via ``PrivateUse1``:
+
+* The **C++ approach** for production accelerators, covering operator registration,
+  generator registration, device guard registration, and more.
+* The **Python-only approach** using ``_setup_privateuseone_for_python_backend()``
+  for rapid prototyping without any C++ compilation.
+
+Both approaches allow you to rename the backend, generate convenience methods, and
+integrate with PyTorch's autograd system.
diff --git a/index.rst b/index.rst
index ba30992900c..7001396d774 100644
--- a/index.rst
+++ b/index.rst
@@ -444,10 +444,10 @@ Welcome to PyTorch Tutorials
 
 .. customcarditem::
    :header: Facilitating New Backend Integration by PrivateUse1
-   :card_description: Learn how to integrate a new backend living outside of the pytorch/pytorch repo and maintain it to keep in sync with the native PyTorch backend.
+   :card_description: Learn how to integrate a new backend via PrivateUse1 using either C++ (production) or Python-only (prototyping) approaches.
   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
    :link: advanced/privateuseone.html
-   :tags: Extending-PyTorch,Frontend-APIs,C++
+   :tags: Extending-PyTorch,Frontend-APIs,C++,Python
 
 .. customcarditem::
    :header: Custom Function Tutorial: Double Backward

From 7a1e6897e6e1d6dc37a6b52a00f2f90a97d61620 Mon Sep 17 00:00:00 2001
From: RiyaP-QA
Date: Tue, 28 Apr 2026 22:40:48 +0530
Subject: [PATCH 2/2] Fix broken link: correct pytorch_open_registration_example author

---
 advanced_source/privateuseone.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/advanced_source/privateuseone.rst b/advanced_source/privateuseone.rst
index c587562b30a..9e8634b9f10 100644
--- a/advanced_source/privateuseone.rst
+++ b/advanced_source/privateuseone.rst
@@ -655,7 +655,7 @@ See Also
 .. seealso::
 
    * `torch_openreg <https://github.com/pytorch/pytorch/tree/main/test/cpp_extensions/open_registration_extension/torch_openreg>`_ -- C++ backend reference implementation (in-tree)
-   * `pytorch_open_registration_example `_ -- Original C++ example by @ptrblck
+   * `pytorch_open_registration_example <https://github.com/bdhirsh/pytorch_open_registration_example>`_ -- Original C++ example by @bdhirsh
    * `test_privateuseone_python_backend.py `_ -- Working test for the Python approach
    * `PR #157859 <https://github.com/pytorch/pytorch/pull/157859>`_ -- Original PR that added ``_setup_privateuseone_for_python_backend()``