Commit 3b9b850

Add tips & tricks for custom diffs article (#144)
* Add tips & tricks for custom diffs article
* Change memory access terminology
* Clarify that autodiff does handle stepwise functions, but you can still provide custom derivatives if you want.
* Add example snippets, and explain how Slang provides derivatives for max() and clamp()
* Fix slang errors, de-duplicate finite difference approximation example
* One more .p for primal
1 parent 09afd22 commit 3b9b850

File tree: 3 files changed, +137 −1 lines changed

_data/documentation.yaml

Lines changed: 6 additions & 1 deletion
@@ -41,6 +41,10 @@ articles:
     description: "Things to know when using Slang to compile to the WGSL."
     link_url: "https://docs.shader-slang.org/en/latest/external/slang/docs/user-guide/a2-03-wgsl-target-specific.html"
     link_label: "WGSL Functionalities"
+  - title: "Autodiff Tips & Tricks: When to Use Custom Derivatives"
+    description: "Guidance on when defining custom derivatives can help you out"
+    link_url: "https://docs.shader-slang.org/en/latest/autodiff-tips-custom-diffs.md"
+    link_label: "Autodiff Tips & Tricks: Custom Derivatives"

 tutorials:
   - title: "Write Your First Slang Shader"
@@ -82,4 +86,5 @@ tutorials:
   - title: "Slang Automatic Differentiation Tutorial - 2"
     description: "Learn how to use Slang’s automatic differentiation feature to compute the gradient of a shader."
     link_url: "https://docs.shader-slang.org/en/latest/auto-diff-tutorial-2.html"
-    link_label: "Automatic Differentiation Tutorial - 2"
+    link_label: "Automatic Differentiation Tutorial - 2"
docs/autodiff-tips-custom-diffs.md

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
---
title: "Autodiff Tips and Tricks: Custom Derivatives"
layout: page
description: "Autodiff Tips and Tricks: Custom Derivatives"
permalink: "/docs/autodiff-tips-custom-diffs"
intro_image_absolute: true
intro_image_hide_on_mobile: false
---

# Autodiff Tips and Tricks: When to Use Custom Derivatives

Slang's automatic differentiation (autodiff) system is powerful and can handle a wide range of functions. However, there are specific scenarios where defining a custom derivative is beneficial, offering more control and accuracy. This guide outlines the major cases where you might want to write your own custom derivative rather than relying solely on Slang's autodiff capabilities.

## 1. Opaque Functions

Opaque functions are those whose internal operations are not visible or accessible to Slang's autodiff system. This often occurs with functions that interact with external systems or hardware, such as:

* **Hardware-accelerated texture filtering:** When performing operations like texture filtering, the underlying mathematical process may not be exposed in a way that allows Slang to compute a derivative automatically. In such cases, you need to define a custom derivative that approximates or accurately represents the sensitivity of the output to the input.

For these functions, since the autodiff system cannot "see" inside them to compute the derivative, providing a custom derivative lets you still incorporate them into your differentiable computations.

```slang
[BackwardDerivative(externalFunction_bwd)]
float externalFunction(float x)
{
    // This calls an external library or hardware function
    // that Slang cannot differentiate automatically.
    return someExternalLibraryCall(x);
}

void externalFunction_bwd(inout DifferentialPair<float> x, float dOut)
{
    // Approximate the local derivative using central finite differences.
    const float epsilon = 1e-6;
    float fwd = someExternalLibraryCall(x.p + epsilon);
    float bwd = someExternalLibraryCall(x.p - epsilon);
    float derivative = (fwd - bwd) / (2.0 * epsilon);

    // Chain rule: scale the local derivative by the incoming gradient.
    x = diffPair(x.p, derivative * dOut);
}
```

This example shows one way to handle an opaque function that calls an external library. Since Slang cannot see inside the external function, we approximate its local derivative with central finite differences and scale the result by the incoming gradient `dOut` to satisfy the chain rule.
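
Once the custom derivative is attached, the function can participate in autodiff like any other differentiable function. Below is a minimal, hypothetical driver sketch; `lossFn` and `computeGradient` are illustrative names, not part of the article's code:

```slang
[Differentiable]
float lossFn(float x)
{
    // externalFunction_bwd is invoked automatically during the backward pass.
    return externalFunction(x) * 2.0;
}

void computeGradient()
{
    DifferentialPair<float> x = diffPair(3.0, 0.0);
    bwd_diff(lossFn)(x, 1.0); // seed the output gradient with 1.0
    float dLossDx = x.d;      // gradient of lossFn with respect to x
}
```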

## 2. Buffer Accesses

Functions whose output depends on values retrieved from memory based on an input (e.g., accessing an RWStructuredBuffer on the GPU, or reading from a raw pointer CPU-side) introduce side effects that automatic differentiation struggles to handle, including race conditions and ambiguous derivative write-back locations. Additionally, the lookup index itself is non-continuous. Custom derivatives are therefore often necessary to accurately represent "change" at these points, potentially involving subgradients or specific approximations.

```slang
RWStructuredBuffer<float> myBuffer;
RWStructuredBuffer<Atomic<float>> gradientBuffer; // Global buffer for gradients

[BackwardDerivative(loadFloat_bwd)]
float loadFloat(uint idx)
{
    return myBuffer[idx];
}

void loadFloat_bwd(uint idx, float dOut)
{
    // Safely accumulate the gradient into the global buffer using atomic addition.
    gradientBuffer[idx] += dOut;
}
```

This example shows one way to handle differentiation of a buffer access. The `loadFloat` function reads from a buffer, and its custom derivative `loadFloat_bwd` accumulates gradients into a global buffer using atomic operations, avoiding race conditions when multiple threads write to the same location.
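
A sketch of how this might be driven follows; `weightedValue` and `backpropExample` are hypothetical wrappers, not part of the article's code. The non-differentiable index passes through unchanged, while the buffer element's gradient lands in `gradientBuffer`:

```slang
[Differentiable]
float weightedValue(uint idx, float w)
{
    return loadFloat(idx) * w;
}

void backpropExample(uint idx)
{
    DifferentialPair<float> w = diffPair(0.5, 0.0);
    bwd_diff(weightedValue)(idx, w, 1.0);
    // w.d now holds myBuffer[idx], and loadFloat_bwd has atomically
    // accumulated w.p (= 0.5) into gradientBuffer[idx].
}
```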

## 3. Numerically Unstable Functions

Numerical stability refers to how well a computation preserves accuracy when faced with small changes in input, or when intermediate values become very large or very small. Some complex mathematical functions can be numerically unstable, leading to issues like:

* **Division by zero:** Complex expressions can inadvertently produce denominators approaching zero, resulting in undefined or infinite derivatives.
* **Gradient explosion or vanishing:** In deep learning contexts, gradients can become extremely large (explosion) or extremely small (vanishing), hindering effective optimization.

By defining a custom derivative, you can implement more robust numerical methods that mitigate these issues, ensuring that your derivatives are well behaved and do not lead to computational errors or poor training performance. This might involve re-parameterizations or specialized derivative formulas.

```slang
[BackwardDerivative(safeDivide_bwd)]
float safeDivide(float numerator, float denominator)
{
    const float epsilon = 1e-8;
    return numerator / max(denominator, epsilon);
}

void safeDivide_bwd(inout DifferentialPair<float> numerator, inout DifferentialPair<float> denominator, float dOut)
{
    const float epsilon = 1e-8;
    const float maxGradient = 1e6; // Prevent gradient explosion

    // Use the same stabilized denominator as the forward pass.
    float denomStable = max(denominator.p, epsilon);

    // Clamp gradients to prevent explosion when the denominator is very small.
    float dNumerator = clamp(dOut / denomStable, -maxGradient, maxGradient);
    float dDenominator = clamp(-dOut * numerator.p / (denomStable * denomStable), -maxGradient, maxGradient);

    numerator = diffPair(numerator.p, dNumerator);
    denominator = diffPair(denominator.p, dDenominator);
}
```

This example shows how to handle division that could be numerically unstable. The `safeDivide` function clamps the denominator away from zero to prevent division by zero, and its custom derivative uses the same stabilized denominator and clamps the resulting gradients to prevent explosion when the denominator is very small.
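
As a quick sanity check (hypothetical `checkSafeDivide`, assuming the definitions above), backpropagating through `safeDivide` near a zero denominator yields a large but bounded gradient:

```slang
void checkSafeDivide()
{
    DifferentialPair<float> num = diffPair(1.0, 0.0);
    DifferentialPair<float> den = diffPair(1e-12, 0.0); // effectively zero
    bwd_diff(safeDivide)(num, den, 1.0);
    // Unclamped, dOut / denomStable would be 1e8; num.d is capped at 1e6.
}
```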

## Mixing Custom and Automatic Differentiation

One of the key strengths of Slang's autodiff system is its flexibility. You are not forced to choose between entirely custom derivatives and entirely automatic differentiation: Slang allows you to **mix custom and automatic differentiation**.

This means you can address just the parts of your function stack that truly need custom derivatives (e.g., the opaque or numerically unstable sections) while still leveraging Slang's powerful autodiff for the rest of your computations. This hybrid approach offers the best of both worlds: the convenience and efficiency of automatic differentiation where it's most effective, and the precision and control of custom derivatives where they are necessary. A sketch of this pattern follows below.
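
In the hypothetical sketch below (reusing `someExternalLibraryCall` from the opaque-function example; `smoothPart` and `mixedFunction` are illustrative names), the custom backward derivative handles the opaque step itself but delegates the differentiable step to autodiff via `bwd_diff()`:

```slang
[Differentiable]
float smoothPart(float y)
{
    return sin(y) * y; // Slang can differentiate this automatically.
}

[BackwardDerivative(mixedFunction_bwd)]
float mixedFunction(float x)
{
    // An opaque step followed by an autodiff-friendly step.
    return smoothPart(someExternalLibraryCall(x));
}

void mixedFunction_bwd(inout DifferentialPair<float> x, float dOut)
{
    float y = someExternalLibraryCall(x.p);

    // Let autodiff produce smoothPart's derivative.
    DifferentialPair<float> yPair = diffPair(y, 0.0);
    bwd_diff(smoothPart)(yPair, dOut);

    // Handle the opaque step with finite differences, as before.
    const float epsilon = 1e-6;
    float dExternal = (someExternalLibraryCall(x.p + epsilon) -
                       someExternalLibraryCall(x.p - epsilon)) / (2.0 * epsilon);

    // Chain rule: combine the two pieces.
    x = diffPair(x.p, yPair.d * dExternal);
}
```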

For examples of this in practice, take a look at some of the [experiments](https://github.com/shader-slang/slangpy-samples/tree/main/experiments) in our SlangPy samples repository. In particular, you can see a user-defined custom derivative function invoking `bwd_diff()` to make use of automatic differentiation for the functions it calls out to in the [differentiable splatting experiment](https://github.com/shader-slang/slangpy-samples/blob/main/experiments/diff-splatting/diffsplatting2d.slang#L512).

## Approximating Derivatives for Inherently Undifferentiable Functions

While some functions are mathematically discontinuous or opaque, making them undifferentiable in a strict sense, it is often still possible and desirable to define a custom derivative that approximates their behavior or provides a useful "gradient" for optimization purposes. This is particularly crucial in machine learning and computational graphics, where such functions are common. Here's how you might approach creating a custom derivative for something that seems inherently undifferentiable:

* **Subgradients for Discontinuous Functions**:
  For functions with sharp corners or jumps (like ReLU, absolute value, or step functions), the derivative is undefined at specific points. Automatic differentiation systems like Slang's handle these cases using established conventions, but you may want custom behavior. A subgradient is not a unique value but rather a set of possible "slopes" at the non-differentiable point. For example, for ReLU:
  * If input > 0, the derivative is 1.
  * If input < 0, the derivative is 0.
  * If input = 0, the subgradient can be any value between 0 and 1 (inclusive).

  When implementing a custom derivative, you typically pick a specific value from this subgradient set (e.g., 0 or 0.5) to ensure a deterministic derivative for your autodiff system; see the `myRelu` sketch after this list.

  Slang's autodiff system includes built-in conventions for handling many discontinuous functions that would otherwise be problematic for differentiation. For example, Slang can automatically generate derivatives for `max()` even though it is not differentiable at the point where `x = y`: when the inputs are equal, Slang assigns half the gradient to each input, ensuring that gradients flow through both branches of the computation. Similarly, for `clamp()`, when the input equals the minimum or maximum bound, Slang propagates the gradient to the input `x` rather than to the `min` or `max` parameters, a choice that favors the primary input. Similar conventions exist for other discontinuous functions like `min()`, `abs()`, and conditional operations. These conventions are designed to provide sensible gradients that work well in practice, even when the mathematical derivative is undefined at certain points.

* **Finite Difference Approximation**:
  If you have no analytical way to determine the derivative, you can resort to numerical approximation using finite differences. This method evaluates the function at two closely spaced points and calculates the slope between them. While computationally more expensive and potentially less accurate than an analytical derivative, finite differences can provide a workable approximation for opaque or highly complex functions where direct differentiation is impossible. You would implement this calculation in your custom derivative function; see the example in the "Opaque Functions" section above for a practical implementation of this technique.
* **Surrogate Gradients (for Hard Discontinuities)**:
  For functions with truly "hard" discontinuities (e.g., a direct lookup in a table based on an index, or strict "if-else" branching that completely changes the computation), a subgradient or finite difference might not be suitable. In such cases, you might use a surrogate gradient. This involves replacing the undifferentiable part of the computation with a differentiable approximation solely for the purpose of backpropagation.
  For example, if you have a `floor()` operation (which is not differentiable), you might, during the backward pass, pretend it was an identity function (`x`) or a smoothed version of it to allow gradients to flow through. The forward pass still uses the exact `floor()` operation, but the backward pass uses your custom, differentiable surrogate; see the `surrogateFloor` sketch after this list. This allows the optimization algorithm to still "feel" a gradient and make progress, even if it is not the true mathematical derivative.
* **Domain-Specific Knowledge and Heuristics**:
  Sometimes the best "derivative" for an undifferentiable function comes from understanding the underlying problem and applying domain-specific heuristics. For instance, in rendering, if a function represents a decision boundary (e.g., inside/outside an object), your custom derivative might encode how sensitive that decision is to small changes in input, even if the boundary itself is sharp. This often involves defining a "gradient" that pushes the input towards the desired outcome based on prior knowledge.
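
Here is the `myRelu` sketch referenced above: a minimal, hypothetical custom derivative that picks 0.5 as the subgradient at the non-differentiable point:

```slang
[BackwardDerivative(myRelu_bwd)]
float myRelu(float x)
{
    return max(x, 0.0);
}

void myRelu_bwd(inout DifferentialPair<float> x, float dOut)
{
    // Subgradient choice: 1 for x > 0, 0 for x < 0, and 0.5 exactly at x == 0
    // (any value in [0, 1] would be a valid subgradient there).
    float slope = x.p > 0.0 ? 1.0 : (x.p < 0.0 ? 0.0 : 0.5);
    x = diffPair(x.p, slope * dOut);
}
```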
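
And the `surrogateFloor` sketch referenced above: the forward pass stays exact, while the hypothetical backward pass substitutes an identity surrogate so gradients keep flowing:

```slang
[BackwardDerivative(surrogateFloor_bwd)]
float surrogateFloor(float x)
{
    return floor(x); // exact, non-differentiable forward pass
}

void surrogateFloor_bwd(inout DifferentialPair<float> x, float dOut)
{
    // Surrogate: treat floor(x) as the identity function, whose derivative is 1.
    x = diffPair(x.p, dOut);
}
```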

When choosing an approximation method, consider the trade-offs between accuracy, computational cost, and how well the approximated derivative guides your optimization or analysis. The goal is to provide a "signal" to the autodiff system that allows it to effectively propagate changes and facilitate whatever computation or optimization you are performing.

docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ Slang Documentation
    SPIR-V Specific Functionalities <external/slang/docs/user-guide/a2-01-spirv-target-specific>
    Metal Specific Functionalities <external/slang/docs/user-guide/a2-02-metal-target-specific>
    WGSL Specific Functionalities <external/slang/docs/user-guide/a2-03-wgsl-target-specific>
+   Autodiff Tips & Tricks <autodiff-tips-custom-diffs>

 .. toctree::
    :caption: Tutorials
