diff --git a/src/autodiff/type-trees.md b/src/autodiff/type-trees.md
index 68cb78650..3476a5492 100644
--- a/src/autodiff/type-trees.md
+++ b/src/autodiff/type-trees.md
@@ -1,7 +1,35 @@
 # TypeTrees for Autodiff
 
 ## What are TypeTrees?
-Memory layout descriptors for Enzyme. Tell Enzyme exactly how types are structured in memory so it can compute derivatives efficiently.
+Memory layout descriptors for Enzyme. They tell Enzyme which "type" a range of bytes has, the main categories being Float, Integer, and Pointer. In Rust, memory is conceptually untyped, so it is possible to store a float into 4 bytes and later read the same bytes back as an integer. This is generally true in Rust even in the absence of `enum` or `union` types. We therefore cannot put TypeTree metadata directly on allocations. We also cannot accept Enzyme's default behaviour, which incorrectly assumes that LLVM-IR follows `strict aliasing` rules (known from C/C++). As a solution, we disable Enzyme's strict-aliasing behaviour and only generate TypeTree metadata in selected locations.
+
+## Where we generate TypeTrees
+The underlying idea is that memory "at rest" is untyped, but many uses interpret bytes in a way that we can communicate to Enzyme. For example, when we call a function, the memory passed to it is interpreted according to the function's signature, so we can add TypeTrees to the LLVM-IR function definitions. We currently only do that for the outermost differentiated functions (those that have a `#[autodiff]` macro on them), but we plan to extend it to all functions called from them. We currently also generate TypeTree information for all calls to mem{cpy|move|set}. Finally, we have started to add TypeTrees to the input or return values of certain instructions; for now that is mainly `extractvalue`.
+
+## How we add TypeTrees
+If we determine that a value has a meaningful type, we walk the MIR `Ty` of that value in the middle-end and generate a Rust TypeTree from it. In the codegen\_llvm backend we lower our Rust TypeTree to LLVM/Enzyme TypeTrees. We then attach them to one of three locations:
+
+Parameters in function definitions:
+```llvm
+define internal void @_RNvCs7tI50jyFEig_3foo1f(ptr align 8 "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %0, ptr align 8 "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %1, ptr align 8 "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %2) unnamed_addr #0 !dbg !1089 {
+```
+
+Arguments to calls:
+```llvm
+  call void @llvm.memcpy.p0.p0.i64(ptr align 8 "enzyme_type"="{[0]:Pointer, [0,0]:Pointer, [0,0,-1]:Float@double}" %6, ptr align 8 "enzyme_type"="{[0]:Pointer, [0,0]:Pointer, [0,0,-1]:Float@double}" %0, i64 24, i1 false), !dbg !669
+```
+
+Input or return values of instructions, via LLVM metadata:
+```llvm
+  %14 = extractvalue { ptr, i64 } %13, 0, !dbg !906, !enzyme_type !907
+  %15 = extractvalue { ptr, i64 } %13, 1, !dbg !906, !enzyme_type !910
+...
+!907 = !{!"Unknown", i32 -1, !908}
+!908 = !{!"Pointer", i32 -1, !909}
+!909 = !{!"Float@double"}
+!910 = !{!"Unknown", i32 0, !911, i32 1, !911, i32 2, !911, i32 3, !911, i32 4, !911, i32 5, !911, i32 6, !911, i32 7, !911}
+!911 = !{!"Integer"}
+```
 
 ## Structure
 ```rust
@@ -47,11 +75,12 @@ TypeTree(vec![Type {
 }])
 ```
 
-## Why Needed?
-- Enzyme can't deduce complex type layouts from LLVM IR
-- Prevents slow memory pattern analysis
-- Enables correct derivative computation for nested structures
-- Tells Enzyme which bytes are differentiable vs metadata
+## Why are they needed?
+- Many LLVM types are opaque (e.g. `ptr`), but Enzyme needs type information to compute correct derivatives.
+- They tell Enzyme which bytes are differentiable (e.g. the pointer to float within a slice) vs metadata (e.g. the integer length of a slice).
+- Enzyme can't deduce all types from LLVM IR, but can (to some extent) deduce them from usage (Type Analysis).
+- Debug builds have many variables with few uses, so Type Analysis (and thus compilation) often fails without extra TypeTrees.
+- Type Analysis is slow, so reading TypeTrees instead saves a lot of time.
 
 ## What Enzyme Does With This Information:
 
@@ -190,4 +219,4 @@ TypeTree(vec![
 **Arrays** use offset `-1` for efficiency:
 - `&[f32; 100]` has the same pattern repeated 100 times
 - Using -1 avoids listing 100 separate offsets
-- Generates: `{[-1]:Pointer, [-1,-1]:Float@float}`
\ No newline at end of file
+- Generates: `{[-1]:Pointer, [-1,-1]:Float@float}`
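
The claim in the new "What are TypeTrees?" paragraph, that Rust memory is untyped, can be illustrated with a minimal sketch. It is not part of the patch, and the function name `float_bytes_as_int` is hypothetical. The sketch stores an `f32` into four bytes and reads the same bytes back as a `u32`, which is why TypeTree metadata cannot simply be attached to the allocation itself:

```rust
/// Hypothetical helper (not from the patch): write an `f32` into a 4-byte
/// buffer, then read the very same bytes back as a `u32`.
fn float_bytes_as_int(x: f32) -> u32 {
    // The buffer is declared as plain bytes; nothing marks it as "float memory".
    let mut bytes = [0u8; 4];

    // SAFETY: `bytes` is 4 bytes long, which is exactly the size of `f32` and
    // `u32`. The unaligned accessors are used because a `[u8; 4]` is only
    // guaranteed to be 1-byte aligned.
    unsafe {
        bytes.as_mut_ptr().cast::<f32>().write_unaligned(x);
        bytes.as_ptr().cast::<u32>().read_unaligned()
    }
}

fn main() {
    let bits = float_bytes_as_int(1.0);
    // Reading the bytes as an integer yields the IEEE 754 bit pattern of 1.0.
    assert_eq!(bits, 1.0_f32.to_bits());
    println!("{bits:#010x}");
}
```

Because the same four bytes are a float on the write and an integer on the read, a type can only be stated for a particular use of the memory (a function parameter, a `memcpy` argument, an `extractvalue` result), which matches the three attachment points the patch describes.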