Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions src/autodiff/type-trees.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,35 @@
# TypeTrees for Autodiff

## What are TypeTrees?
Memory layout descriptors for Enzyme. Tell Enzyme exactly how types are structured in memory so it can compute derivatives efficiently.
Memory layout descriptors for Enzyme. They tell Enzyme what "type" bytes are, with the main categories being Float, Integer, or Pointer. In Rust, memory is conceptually untyped, so it is possible to store a float into 4 bytes, and later read the bytes back as an integer. This is generally true in Rust even in the absence of `enum` or `union` types. We therefore can not directly put typetree metadata on allocations. We can also not accept Enzyme's default behaviour, which incorrectly assumes that LLVM-IR follows `strict aliasing` rules (known from C/C++). As a solution, we disable Enzyme's strict-aliasing behaviour and only generate TypeTree metadata in selected locations.

## Where we generate TypeTree
The underlying idea is that memory "at rest" is untyped, but plenty of usages interprete bytes in a way that we can communicate to Enzyme. For example, when we call a function, the memory passed to it is interpreted according to the function's signature, so we can add TypeTrees to the LLVM-IR function definitions. We currently only do that for the outermost functions differentiated (those that have a `#[autodiff]` macro on them), but we plan to extend it to all functions which are called from them. We currently also generate TypeTree information for all calls to mem{cpy|move|set}. Finally, we started to add TypeTrees to the input or return values of certain instructions, for now that mainly is `extractvalue`.

## How we add TypeTrees
If we determined that a value has a meaningfull type, then we walk the MIR `Ty` of that value in the middle-end and generate a Rust TypeTree out of it. In the codegen\_llvm backend we lower our Rust TypeTree to LLVM/Enzyme TypeTrees. We then attach them to one of three locations:

Parameters in function definitions:
```llvm
define internal void @_RNvCs7tI50jyFEig_3foo1f(ptr align 8 "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %0, ptr align 8 "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %1, ptr align 8 "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %2) unnamed_addr #0 !dbg !1089 {
```

Argument to calls:
```llvm
call void @llvm.memcpy.p0.p0.i64(ptr align 8 "enzyme_type"="{[0]:Pointer, [0,0]:Pointer, [0,0,-1]:Float@double}" %6, ptr align 8 "enzyme_type"="{[0]:Pointer, [0,0]:Pointer, [0,0,-1]:Float@double}" %0, i64 24, i1 false), !dbg !669
```

Input or return values of instructions, via debug metadata:
```llvm
%14 = extractvalue { ptr, i64 } %13, 0, !dbg !906, !enzyme_type !907
%15 = extractvalue { ptr, i64 } %13, 1, !dbg !906, !enzyme_type !910
...
!907 = !{!"Unknown", i32 -1, !908}
!908 = !{!"Pointer", i32 -1, !909}
!909 = !{!"Float@double"}
!910 = !{!"Unknown", i32 0, !911, i32 1, !911, i32 2, !911, i32 3, !911, i32 4, !911, i32 5, !911, i32 6, !911, i32 7, !911}
!911 = !{!"Integer"}
```

## Structure
```rust
Expand Down Expand Up @@ -47,11 +75,12 @@ TypeTree(vec![Type {
}])
```

## Why Needed?
- Enzyme can't deduce complex type layouts from LLVM IR
- Prevents slow memory pattern analysis
- Enables correct derivative computation for nested structures
- Tells Enzyme which bytes are differentiable vs metadata
## Why are they needed?
- Plenty of LLVM types are opaque (e.g. `ptr`), but types are needed to compute the correct derivatives.
- They tell Enzyme which bytes are differentiable (e.g. the pointer to float within a slice) vs metadata (e.g. the integer length of a slice)
- Enzyme can't deduce all types from LLVM IR, but can (to some extend) deduce them from usage (Type Analysis).
- Debug builds have a lot of variables with little usage, so Type Analysis (and thus compilation) often fails without extra TypeTrees.
- Type analysis is slow, just reading TypeTrees therefore saves a lot of time.

## What Enzyme Does With This Information:

Expand Down Expand Up @@ -190,4 +219,4 @@ TypeTree(vec![
**Arrays** use offset `-1` for efficiency:
- `&[f32; 100]` has the same pattern repeated 100 times
- Using -1 avoids listing 100 separate offsets
- Generates: `{[-1]:Pointer, [-1,-1]:Float@float}`
- Generates: `{[-1]:Pointer, [-1,-1]:Float@float}`
Loading