From 04368bf71442446e659163f39d0f0284c2d82be5 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 24 Jun 2026 01:21:05 +0200 Subject: [PATCH 1/4] Fix typetree generation for arguments, e.g. slices --- compiler/rustc_ast/src/expand/typetree.rs | 1 + compiler/rustc_codegen_llvm/src/builder.rs | 2 +- .../src/builder/autodiff.rs | 3 +- compiler/rustc_codegen_llvm/src/intrinsic.rs | 1 + .../rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 5 + compiler/rustc_codegen_llvm/src/llvm/mod.rs | 15 +++ compiler/rustc_codegen_llvm/src/typetree.rs | 122 ++++++++++++------ compiler/rustc_middle/src/ty/typetree.rs | 72 ++++++----- 8 files changed, 149 insertions(+), 72 deletions(-) diff --git a/compiler/rustc_ast/src/expand/typetree.rs b/compiler/rustc_ast/src/expand/typetree.rs index 9619c80904426..1d099f475b5f6 100644 --- a/compiler/rustc_ast/src/expand/typetree.rs +++ b/compiler/rustc_ast/src/expand/typetree.rs @@ -28,6 +28,7 @@ pub enum Kind { Anything, Integer, Pointer, + RustSlice, Half, Float, Double, diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index afb6985d21a95..6f5443a7c3c7a 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -1175,7 +1175,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { // vs. copying a struct with mixed types requires different derivative handling. // The TypeTree tells Enzyme exactly what memory layout to expect. if let Some(tt) = tt { - crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt); + crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, self.cx().tcx, memcpy, tt); } } diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index ee17468ec0c03..8bcf54a28da41 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -293,6 +293,7 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>( // cover some assumptions of enzyme/autodiff, which could lead to UB otherwise. pub(crate) fn generate_enzyme_call<'ll, 'tcx>( builder: &mut Builder<'_, 'll, 'tcx>, + tcx: TyCtxt<'tcx>, cx: &SimpleCx<'ll>, fn_to_diff: &'ll Value, outer_name: &str, @@ -379,7 +380,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( ); if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() { - crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree); + crate::typetree::add_tt(cx.llmod, cx.llcx, tcx, fn_to_diff, fnc_tree); } let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 1caa95f369360..8ffa6dcd02c7d 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -1793,6 +1793,7 @@ fn codegen_autodiff<'ll, 'tcx>( // Build body generate_enzyme_call( bx, + tcx, bx.cx, fn_to_diff, &diff_symbol, diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 195e050a9b651..c95e2e97a5c12 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -292,6 +292,11 @@ pub(crate) mod Enzyme_AD { unsafe { (self.EnzymeTypeTreeToString)(tree) } } + pub(crate) fn tree_to_cstr(&self, tree: *mut EnzymeTypeTree) -> &std::ffi::CStr { + let c_str = self.tree_to_string(tree); + unsafe { std::ffi::CStr::from_ptr(c_str) } + } + pub(crate) fn tree_to_string_free(&self, ch: *const c_char) { unsafe { (self.EnzymeTypeTreeToStringFree)(ch) } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/mod.rs b/compiler/rustc_codegen_llvm/src/llvm/mod.rs index 2ec19b1795b5a..5c5127c691990 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/mod.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/mod.rs @@ -76,6 +76,21 @@ pub(crate) fn CreateAttrStringValue<'ll>( ) } } +pub(crate) fn CreateAttrStringValueFromCStr<'ll>( + llcx: &'ll Context, + attr: &std::ffi::CStr, + value: &std::ffi::CStr, +) -> &'ll Attribute { + unsafe { + LLVMCreateStringAttribute( + llcx, + (*attr).as_ptr(), + (*attr).to_bytes().len() as c_uint, + (*value).as_ptr(), + (*value).to_bytes().len() as c_uint, + ) + } +} pub(crate) fn CreateAttrString<'ll>(llcx: &'ll Context, attr: &str) -> &'ll Attribute { unsafe { diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index 4f433f273c8cc..1e3c38d3bf79a 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -1,6 +1,8 @@ -use std::ffi::{CString, c_char, c_uint}; +use std::ffi::{CString, c_char}; -use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree}; +use rustc_ast::expand::typetree::{FncTree, Kind, TypeTree as RustTypeTree}; +use rustc_middle::bug; +use rustc_middle::ty::TyCtxt; use crate::attributes; use crate::llvm::{self, EnzymeWrapper, Value}; @@ -9,27 +11,38 @@ fn to_enzyme_typetree( rust_typetree: RustTypeTree, _data_layout: &str, llcx: &llvm::Context, -) -> llvm::TypeTree { +) -> (llvm::TypeTree, Vec) { let mut enzyme_tt = llvm::TypeTree::new(); - process_typetree_recursive(&mut enzyme_tt, &rust_typetree, &[], llcx); - enzyme_tt + let extra_ints = process_typetree_recursive(&mut enzyme_tt, &rust_typetree, &[], llcx); + + let mut int_vec = vec![]; + for _ in 0..extra_ints { + let mut int_tt = llvm::TypeTree::new(); + int_tt.insert(&[0], llvm::CConcreteType::DT_Integer, llcx); + int_vec.push(int_tt); + } + + (enzyme_tt, int_vec) } + fn process_typetree_recursive( enzyme_tt: &mut llvm::TypeTree, rust_typetree: &RustTypeTree, parent_indices: &[i64], llcx: &llvm::Context, -) { +) -> u32 { + let mut extra_ints = 0; for rust_type in &rust_typetree.0 { let concrete_type = match rust_type.kind { - rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything, - rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer, - rustc_ast::expand::typetree::Kind::Pointer => llvm::CConcreteType::DT_Pointer, - rustc_ast::expand::typetree::Kind::Half => llvm::CConcreteType::DT_Half, - rustc_ast::expand::typetree::Kind::Float => llvm::CConcreteType::DT_Float, - rustc_ast::expand::typetree::Kind::Double => llvm::CConcreteType::DT_Double, - rustc_ast::expand::typetree::Kind::F128 => llvm::CConcreteType::DT_FP128, - rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown, + Kind::Anything => llvm::CConcreteType::DT_Anything, + Kind::Integer => llvm::CConcreteType::DT_Integer, + Kind::Pointer => llvm::CConcreteType::DT_Pointer, + Kind::RustSlice => llvm::CConcreteType::DT_Pointer, + Kind::Half => llvm::CConcreteType::DT_Half, + Kind::Float => llvm::CConcreteType::DT_Float, + Kind::Double => llvm::CConcreteType::DT_Double, + Kind::F128 => llvm::CConcreteType::DT_FP128, + Kind::Unknown => llvm::CConcreteType::DT_Unknown, }; let mut indices = parent_indices.to_vec(); @@ -43,18 +56,25 @@ fn process_typetree_recursive( enzyme_tt.insert(&indices, concrete_type, llcx); - if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer + if matches!(rust_type.kind, Kind::RustSlice) { + // We lower slices to `ptr,int`, so add the int here. + extra_ints += 1; + } + + if matches!(rust_type.kind, Kind::Pointer | Kind::RustSlice) && !rust_type.child.0.is_empty() { process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx); } } + extra_ints } #[cfg_attr(not(feature = "llvm_enzyme"), allow(unused))] -pub(crate) fn add_tt<'ll>( +pub(crate) fn add_tt<'tcx, 'll>( llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, + tcx: TyCtxt<'tcx>, fn_def: &'ll Value, tt: FncTree, ) { @@ -77,39 +97,59 @@ pub(crate) fn add_tt<'ll>( let attr_name = "enzyme_type"; let c_attr_name = CString::new(attr_name).unwrap(); + let mut offset = 0; for (i, input) in inputs.iter().enumerate() { - unsafe { - let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); + let (enzyme_tt, extra_ints) = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); + + // This scope is just a visual reminder that we *must* drop the enzyme_wrapper before + // we drop any typetrees (mainly enzyme_tt and extra_ints). Drop calls can not accept + // arguments like an enzyme_wrapper, so the typetree drop impl has to call get_instance + // on the static enzyme instance, which is behind a Mutex. Therefore we'd deadlock if we + // hold the enzyme_wrapper while dropping the typetrees. + { let enzyme_wrapper = EnzymeWrapper::get_instance(); - let c_str = enzyme_wrapper.tree_to_string(enzyme_tt.inner); - let c_str = std::ffi::CStr::from_ptr(c_str); - - let attr = llvm::LLVMCreateStringAttribute( - llcx, - c_attr_name.as_ptr(), - c_attr_name.as_bytes().len() as c_uint, - c_str.as_ptr(), - c_str.to_bytes().len() as c_uint, - ); + let c_str = enzyme_wrapper.tree_to_cstr(enzyme_tt.inner); - attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); + let attr = llvm::CreateAttrStringValueFromCStr(llcx, &c_attr_name, &c_str); + attributes::apply_to_llfn( + fn_def, + llvm::AttributePlace::Argument(i as u32 + offset), + &[attr], + ); enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); + for v in &extra_ints { + offset += 1; + let c_str = enzyme_wrapper.tree_to_cstr(v.inner); + let int_attr = llvm::CreateAttrStringValueFromCStr(llcx, &c_attr_name, &c_str); + attributes::apply_to_llfn( + fn_def, + llvm::AttributePlace::Argument(i as u32 + offset), + &[int_attr], + ); + enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); + } } } + // We will only fail this if Rust types got lowered to LLVM in a way that we didn't predict. + // Error, so we can learn from our mistakes. + let expected = offset as usize + inputs.len(); + let actual = llvm::count_params(fn_def) as usize; + if expected != actual { + tcx.dcx().warn(format!( + "autodiff type-tree failure. We expected {expected} LLVM argument(s), \ + but the generated LLVM function has {actual} parameter(s)" + )); + } - unsafe { - let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); + let (enzyme_tt, extra_ints) = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); + { let enzyme_wrapper = EnzymeWrapper::get_instance(); - let c_str = enzyme_wrapper.tree_to_string(enzyme_tt.inner); - let c_str = std::ffi::CStr::from_ptr(c_str); - - let ret_attr = llvm::LLVMCreateStringAttribute( - llcx, - c_attr_name.as_ptr(), - c_attr_name.as_bytes().len() as c_uint, - c_str.as_ptr(), - c_str.to_bytes().len() as c_uint, - ); + if !extra_ints.is_empty() { + bug!("A return type should not have extra integers. Implementation bug!"); + } + let c_str = enzyme_wrapper.tree_to_cstr(enzyme_tt.inner); + + let ret_attr = llvm::CreateAttrStringValueFromCStr(llcx, &c_attr_name, &c_str); attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); diff --git a/compiler/rustc_middle/src/ty/typetree.rs b/compiler/rustc_middle/src/ty/typetree.rs index 9e941bdb849ec..f541301afccca 100644 --- a/compiler/rustc_middle/src/ty/typetree.rs +++ b/compiler/rustc_middle/src/ty/typetree.rs @@ -32,7 +32,8 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree { // Create TypeTree for return type let ret = typetree_from_ty(tcx, sig.output()); - FncTree { args, ret } + let f = FncTree { args, ret }; + f } /// Generate a TypeTree for a specific type. @@ -64,31 +65,29 @@ fn typetree_from_ty_impl_inner<'tcx>( } visited.push(ty); - if ty.is_scalar() { - let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { - (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) - } else if ty.is_floating_point() { - match ty { - x if x == tcx.types.f16 => (Kind::Half, 2), - x if x == tcx.types.f32 => (Kind::Float, 4), - x if x == tcx.types.f64 => (Kind::Double, 8), - x if x == tcx.types.f128 => (Kind::F128, 16), - _ => (Kind::Integer, 0), - } - } else { - (Kind::Integer, 0) - }; - - // Use offset 0 for scalars that are direct targets of references (like &f64) - // Use offset -1 for scalars used directly (like function return types) - let offset = if is_reference_target && !ty.is_array() { 0 } else { -1 }; - return TypeTree(vec![Type { offset, size, kind, child: TypeTree::new() }]); + if ty.is_slice() { + bug!("incorrect autodiff typetree handling for slice: {}", ty); } if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() { let Some(inner_ty) = ty.builtin_deref(true) else { - return TypeTree::new(); + bug!("incorrect autodiff typetree handling for type: {}", ty); }; + // slices are represented as `&'{erased} mut [f32]` + // This reads as a reference to a slice of f32. + // So we'd end up with ptr->RustSlice->f32 without this extra handling + if inner_ty.is_slice() { + if let ty::Slice(element_ty) = inner_ty.kind() { + let element_tree = + typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false); + return TypeTree(vec![Type { + offset: -1, + size: tcx.data_layout.pointer_size().bytes_usize(), + kind: Kind::RustSlice, + child: element_tree, + }]); + } + } let child = typetree_from_ty_impl_inner(tcx, inner_ty, depth + 1, visited, true); return TypeTree(vec![Type { @@ -121,14 +120,6 @@ fn typetree_from_ty_impl_inner<'tcx>( } } - if ty.is_slice() { - if let ty::Slice(element_ty) = ty.kind() { - let element_tree = - typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false); - return element_tree; - } - } - if let ty::Tuple(tuple_types) = ty.kind() { if tuple_types.is_empty() { return TypeTree::new(); @@ -204,5 +195,28 @@ fn typetree_from_ty_impl_inner<'tcx>( } } + if ty.is_scalar() { + let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { + (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) + } else if ty.is_floating_point() { + match ty { + x if x == tcx.types.f16 => (Kind::Half, 2), + x if x == tcx.types.f32 => (Kind::Float, 4), + x if x == tcx.types.f64 => (Kind::Double, 8), + x if x == tcx.types.f128 => (Kind::F128, 16), + _ => bug!("Unexpected floating point type: {:?}", ty), + } + } else { + // is_scalar also accepts things like FnDef or FnPtr, for which we don't know how to + // generate a TypeTree, so return nothing. + return TypeTree::new(); + }; + + // Use offset 0 for scalars that are direct targets of references (like &f64) + // Use offset -1 for scalars used directly (like function return types) or slices. + let offset = if is_reference_target && !ty.is_array() { 0 } else { -1 }; + return TypeTree(vec![Type { offset, size, kind, child: TypeTree::new() }]); + } + TypeTree::new() } From 64d9d50ed8ffdf94915b97bcc06111b6fe15ee60 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 25 Jun 2026 02:07:00 +0200 Subject: [PATCH 2/4] Fix typetree generation for memcpy to move TA failure in testcase from memcpy to a later extractvalue --- compiler/rustc_ast/src/expand/typetree.rs | 3 + .../rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 1 + compiler/rustc_codegen_llvm/src/typetree.rs | 83 ++++++++++++++----- .../rustc_codegen_ssa/src/traits/builder.rs | 21 ++++- .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 17 ++++ compiler/rustc_middle/src/ty/typetree.rs | 6 ++ 6 files changed, 107 insertions(+), 24 deletions(-) diff --git a/compiler/rustc_ast/src/expand/typetree.rs b/compiler/rustc_ast/src/expand/typetree.rs index 1d099f475b5f6..84576c67729a3 100644 --- a/compiler/rustc_ast/src/expand/typetree.rs +++ b/compiler/rustc_ast/src/expand/typetree.rs @@ -58,6 +58,9 @@ impl TypeTree { } Self(ints) } + pub fn add_indirection(self) -> Self { + Self(vec![Type { offset: 0, size: 1, kind: Kind::Pointer, child: self }]) + } } #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, StableHash)] diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index c95e2e97a5c12..d125760a5b9aa 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -66,6 +66,7 @@ unsafe extern "C" { NameLen: libc::size_t, ) -> Option<&Value>; + pub(crate) fn LLVMRustIsIntrinsicCall(V: &Value) -> bool; } unsafe extern "C" { diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index 1e3c38d3bf79a..ca35abe556b3b 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -8,7 +8,7 @@ use crate::attributes; use crate::llvm::{self, EnzymeWrapper, Value}; fn to_enzyme_typetree( - rust_typetree: RustTypeTree, + rust_typetree: &RustTypeTree, _data_layout: &str, llcx: &llvm::Context, ) -> (llvm::TypeTree, Vec) { @@ -86,6 +86,13 @@ pub(crate) fn add_tt<'tcx, 'll>( #[cfg(not(feature = "llvm_enzyme"))] return; + if !tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) { + return; + } + if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) { + return; + } + let inputs = tt.args; let ret_tt: RustTypeTree = tt.ret; @@ -99,7 +106,7 @@ pub(crate) fn add_tt<'tcx, 'll>( let mut offset = 0; for (i, input) in inputs.iter().enumerate() { - let (enzyme_tt, extra_ints) = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); + let (enzyme_tt, extra_ints) = to_enzyme_typetree(&input, llvm_data_layout, llcx); // This scope is just a visual reminder that we *must* drop the enzyme_wrapper before // we drop any typetrees (mainly enzyme_tt and extra_ints). Drop calls can not accept @@ -111,38 +118,62 @@ pub(crate) fn add_tt<'tcx, 'll>( let c_str = enzyme_wrapper.tree_to_cstr(enzyme_tt.inner); let attr = llvm::CreateAttrStringValueFromCStr(llcx, &c_attr_name, &c_str); - attributes::apply_to_llfn( - fn_def, - llvm::AttributePlace::Argument(i as u32 + offset), - &[attr], - ); + dbg!(&fn_def); + if unsafe { llvm::LLVMRustIsIntrinsicCall(fn_def) } { + dbg!("callsite"); + attributes::apply_to_callsite( + fn_def, + llvm::AttributePlace::Argument(i as u32 + offset), + &[attr], + ); + } else { + dbg!("llfn"); + attributes::apply_to_llfn( + fn_def, + llvm::AttributePlace::Argument(i as u32 + offset), + &[attr], + ); + } enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); for v in &extra_ints { offset += 1; let c_str = enzyme_wrapper.tree_to_cstr(v.inner); let int_attr = llvm::CreateAttrStringValueFromCStr(llcx, &c_attr_name, &c_str); - attributes::apply_to_llfn( - fn_def, - llvm::AttributePlace::Argument(i as u32 + offset), - &[int_attr], - ); + dbg!(&fn_def); + if unsafe { llvm::LLVMRustIsIntrinsicCall(fn_def) } { + dbg!("callsite"); + attributes::apply_to_callsite( + fn_def, + llvm::AttributePlace::Argument(i as u32 + offset), + &[int_attr], + ); + } else { + dbg!("llfn"); + attributes::apply_to_llfn( + fn_def, + llvm::AttributePlace::Argument(i as u32 + offset), + &[int_attr], + ); + } enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); } } } // We will only fail this if Rust types got lowered to LLVM in a way that we didn't predict. // Error, so we can learn from our mistakes. - let expected = offset as usize + inputs.len(); - let actual = llvm::count_params(fn_def) as usize; - if expected != actual { - tcx.dcx().warn(format!( - "autodiff type-tree failure. We expected {expected} LLVM argument(s), \ - but the generated LLVM function has {actual} parameter(s)" - )); + if unsafe { !llvm::LLVMRustIsIntrinsicCall(fn_def) } { + let expected = offset as usize + inputs.len(); + let actual = llvm::count_params(fn_def) as usize; + if expected != actual { + tcx.dcx().warn(format!( + "autodiff type-tree failure. We expected {expected} LLVM argument(s), \ + but the generated LLVM function has {actual} parameter(s)" + )); + } } - let (enzyme_tt, extra_ints) = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); - { + let (enzyme_tt, extra_ints) = to_enzyme_typetree(&ret_tt, llvm_data_layout, llcx); + if ret_tt != RustTypeTree::new() { let enzyme_wrapper = EnzymeWrapper::get_instance(); if !extra_ints.is_empty() { bug!("A return type should not have extra integers. Implementation bug!"); @@ -151,7 +182,15 @@ pub(crate) fn add_tt<'tcx, 'll>( let ret_attr = llvm::CreateAttrStringValueFromCStr(llcx, &c_attr_name, &c_str); - attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); + if unsafe { llvm::LLVMRustIsIntrinsicCall(fn_def) } { + attributes::apply_to_callsite( + fn_def, + llvm::AttributePlace::ReturnValue, + &[ret_attr], + ); + } else { + attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); + } enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); } } diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index d68549c6871f4..eb83d39ec9ed8 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -2,10 +2,12 @@ use std::assert_matches; use std::ops::Deref; use rustc_abi::{Align, Scalar, Size, WrappingRange}; +use rustc_ast::expand::typetree::{TypeTree, FncTree}; use rustc_hir::attrs::AttributeKind; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs; use rustc_middle::mir; use rustc_middle::ty::layout::{FnAbiOf, LayoutOf, TyAndLayout}; +use rustc_middle::ty::typetree::typetree_from_ty; use rustc_middle::ty::{AtomicOrdering, Instance, Ty}; use rustc_session::config::OptLevel; use rustc_span::Span; @@ -456,7 +458,7 @@ pub trait BuilderMethods<'a, 'tcx>: src_align: Align, size: Self::Value, flags: MemFlags, - tt: Option, + tt: Option, ); fn memmove( &mut self, @@ -466,6 +468,7 @@ pub trait BuilderMethods<'a, 'tcx>: src_align: Align, size: Self::Value, flags: MemFlags, + //tt: Option, ); fn memset( &mut self, @@ -474,6 +477,7 @@ pub trait BuilderMethods<'a, 'tcx>: size: Self::Value, align: Align, flags: MemFlags, + //tt: Option, ); // Produce a value from calling the `vscale` intrinsic (containing the `vscale` multiplier that @@ -517,6 +521,19 @@ pub trait BuilderMethods<'a, 'tcx>: let temp = self.load_operand(src.with_type(layout)); temp.val.store_with_flags(self, dst.with_type(layout), flags); } else if !layout.is_zst() { + let tt = typetree_from_ty(self.tcx(), layout.ty); + // We seem to pass all values to memcpy with one more indirection. + let tt = tt.add_indirection(); + dbg!(&tt); + use rustc_middle::ty::print::with_no_trimmed_paths; + + with_no_trimmed_paths!({ + eprintln!("memcpy ty = {:?}", layout.ty); + }); + let fnc_tree = FncTree { + args: vec![tt.clone(), tt], + ret: TypeTree::new(), + }; let bytes = self.const_usize(layout.size.bytes()); let bytes = if layout.peel_transparent_wrappers(self).ty.is_scalable_vector() { let vscale = self.vscale(self.type_i64()); @@ -524,7 +541,7 @@ pub trait BuilderMethods<'a, 'tcx>: } else { bytes }; - self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None); + self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, Some(fnc_tree)); } } diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 8b063af187a58..df50a0863e384 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -160,6 +160,23 @@ extern "C" void LLVMRustPrintStatisticsJSON(RustStringRef OutBuf) { llvm::PrintStatisticsJSON(OS); } +extern "C" bool LLVMRustIsIntrinsicCall(LLVMValueRef V) { + llvm::Value *Val = llvm::unwrap(V); +llvm:errs() << "LLVMRustIsIntrinsicCall: " << *Val << "\n"; + + if (auto *CB = llvm::dyn_cast(Val)) { + if (auto *Callee = CB->getCalledFunction()) + return Callee->isIntrinsic(); + + return false; + } + + if (auto *F = llvm::dyn_cast(Val)) + return F->isIntrinsic(); + + return false; +} + // Some of the functions here rely on LLVM modules that may not always be // available. As such, we only try to build it in the first place, if // llvm.offload is enabled. diff --git a/compiler/rustc_middle/src/ty/typetree.rs b/compiler/rustc_middle/src/ty/typetree.rs index f541301afccca..7e3f5b4ab3389 100644 --- a/compiler/rustc_middle/src/ty/typetree.rs +++ b/compiler/rustc_middle/src/ty/typetree.rs @@ -39,6 +39,12 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree { /// Generate a TypeTree for a specific type. /// Mainly a convenience wrapper around the actual implementation. pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { + if !tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) { + return TypeTree::new(); + } + if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) { + return TypeTree::new(); + } let mut visited = Vec::new(); typetree_from_ty_impl_inner(tcx, ty, 0, &mut visited, false) } From f60a76a1a686e1bdb057e4a3af2dec44e12e40b2 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 26 Jun 2026 12:57:29 +0200 Subject: [PATCH 3/4] fix related enzyme bug by updating submodule --- src/tools/enzyme | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tools/enzyme b/src/tools/enzyme index 7c0141f133a35..5a4245e1c4edb 160000 --- a/src/tools/enzyme +++ b/src/tools/enzyme @@ -1 +1 @@ -Subproject commit 7c0141f133a3592daa12cc6cc07f297a5222a42e +Subproject commit 5a4245e1c4edb9e0ac6126f3438ee7a017295fe2 From 1b673bf304cf2131ba0cf4f21e0dafa370c7e5ec Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 26 Jun 2026 12:58:46 +0200 Subject: [PATCH 4/4] fix testcase --- compiler/rustc_codegen_llvm/src/asm.rs | 2 +- compiler/rustc_codegen_llvm/src/builder.rs | 21 +++++-- .../src/builder/gpu_offload.rs | 6 +- compiler/rustc_codegen_llvm/src/intrinsic.rs | 16 ++++-- .../rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 15 ++++- compiler/rustc_codegen_llvm/src/typetree.rs | 54 ++++++++++++------ compiler/rustc_codegen_ssa/src/mir/block.rs | 4 +- compiler/rustc_codegen_ssa/src/mir/operand.rs | 56 ++++++++++++++++++- .../rustc_codegen_ssa/src/traits/builder.rs | 2 +- .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 22 +++++++- 10 files changed, 158 insertions(+), 40 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/asm.rs b/compiler/rustc_codegen_llvm/src/asm.rs index 2598c1b38ff88..adf3f0bf1137f 100644 --- a/compiler/rustc_codegen_llvm/src/asm.rs +++ b/compiler/rustc_codegen_llvm/src/asm.rs @@ -365,7 +365,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { let value = if output_types.len() == 1 { result } else { - self.extract_value(result, op_idx[&idx] as u64) + self.extract_value(result, op_idx[&idx] as u64, None) }; let value = llvm_fixup_output(self, value, reg.reg_class(), &place.layout, instance); diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 6f5443a7c3c7a..72415c30086dd 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -3,6 +3,7 @@ use std::iter; use std::ops::Deref; use rustc_ast::expand::typetree::FncTree; +use rustc_ast::expand::typetree::TypeTree; pub(crate) mod autodiff; pub(crate) mod gpu_offload; @@ -631,7 +632,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { let name = format!("llvm.{}{oop_str}.with.overflow", if signed { 's' } else { 'u' }); let res = self.call_intrinsic(name, &[self.type_ix(width)], &[lhs, rhs]); - (self.extract_value(res, 0), self.extract_value(res, 1)) + (self.extract_value(res, 0, None), self.extract_value(res, 1, None)) } fn from_immediate(&mut self, val: Self::Value) -> Self::Value { @@ -1257,9 +1258,17 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { } } - fn extract_value(&mut self, agg_val: &'ll Value, idx: u64) -> &'ll Value { + fn extract_value(&mut self, agg_val: &'ll Value, idx: u64, tt: Option) -> &'ll Value { assert_eq!(idx as c_uint as u64, idx); - unsafe { llvm::LLVMBuildExtractValue(self.llbuilder, agg_val, idx as c_uint, UNNAMED) } + let extract = unsafe { llvm::LLVMBuildExtractValue(self.llbuilder, agg_val, idx as c_uint, UNNAMED) }; + if let Some(tt) = tt { + let fnc_tree = FncTree { + args: vec![], + ret: tt, + }; + crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, self.cx().tcx, extract, fnc_tree); + } + extract } fn insert_value(&mut self, agg_val: &'ll Value, elt: &'ll Value, idx: u64) -> &'ll Value { @@ -1279,7 +1288,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { unsafe { llvm::LLVMSetCleanup(landing_pad, llvm::TRUE); } - (self.extract_value(landing_pad, 0), self.extract_value(landing_pad, 1)) + (self.extract_value(landing_pad, 0, None), self.extract_value(landing_pad, 1, None)) } fn filter_landing_pad(&mut self, pers_fn: &'ll Value) { @@ -1380,8 +1389,8 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { llvm::FALSE, // SingleThreaded ); llvm::LLVMSetWeak(value, weak.to_llvm_bool()); - let val = self.extract_value(value, 0); - let success = self.extract_value(value, 1); + let val = self.extract_value(value, 0, None); + let success = self.extract_value(value, 1, None); (val, success) } } diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index 0b009321802cf..39e3ac8eb484f 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -162,9 +162,9 @@ impl<'ll> OffloadKernelDims<'ll> { builder: &mut Builder<'_, 'll, 'tcx>, arr: &'ll Value, ) -> &'ll Value { - let x = builder.extract_value(arr, 0); - let y = builder.extract_value(arr, 1); - let z = builder.extract_value(arr, 2); + let x = builder.extract_value(arr, 0, None); + let y = builder.extract_value(arr, 1, None); + let z = builder.extract_value(arr, 2, None); let xy = builder.mul(x, y); builder.mul(xy, z) diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 8ffa6dcd02c7d..2842bf0fcf4a9 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -43,7 +43,7 @@ use crate::errors::{ AutoDiffWithoutEnable, AutoDiffWithoutLto, IntrinsicSignatureMismatch, IntrinsicWrongArch, OffloadWithoutEnable, OffloadWithoutFatLTO, UnknownIntrinsic, }; -use crate::intrinsic::ty::typetree::fnc_typetrees; +use crate::intrinsic::ty::typetree::{fnc_typetrees, typetree_from_ty}; use crate::llvm::{self, Type, Value}; use crate::type_of::LayoutLlvmExt; use crate::va_arg::emit_va_arg; @@ -776,6 +776,7 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { self.extract_value( args[0].immediate(), fn_args.const_at(2).to_leaf().to_i32() as u64, + None, ) } @@ -1059,7 +1060,7 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { &[], &[llvtable, vtable_byte_offset, typeid], ); - self.extract_value(type_checked_load, 0) + self.extract_value(type_checked_load, 0, None) } fn va_start(&mut self, va_list: &'ll Value) { @@ -1171,7 +1172,7 @@ fn autocast<'ll>( iter::zip(bx.struct_element_types(src_ty), bx.struct_element_types(dest_ty)) .enumerate() { - let elt = bx.extract_value(val, idx as u64); + let elt = bx.extract_value(val, idx as u64, None); let casted_elt = autocast(bx, elt, src_element_ty, dest_element_ty); ret = bx.insert_value(ret, casted_elt, idx as u64); } @@ -1632,7 +1633,7 @@ fn codegen_gnu_try<'ll, 'tcx>( let vals = bx.landing_pad(lpad_ty, bx.eh_personality(), 1); let tydesc = bx.const_null(bx.type_ptr()); bx.add_clause(vals, tydesc); - let ptr = bx.extract_value(vals, 0); + let ptr = bx.extract_value(vals, 0, None); let catch_ty = bx.type_func(&[bx.type_ptr(), bx.type_ptr()], bx.type_void()); bx.call(catch_ty, None, None, catch_func, &[data, ptr], None, None); bx.ret(bx.const_bool(true)); @@ -1917,8 +1918,11 @@ fn get_args_from_tuple<'ll, 'tcx>( let field = tuple_place.project_field(bx, tuple_index); let llvm_ty = field.layout.llvm_type(bx.cx); let pair_val = bx.load(llvm_ty, field.val.llval, field.val.align); - result.push(bx.extract_value(pair_val, 0)); - result.push(bx.extract_value(pair_val, 1)); + + let extract_ty = field.layout.ty; + let tt = typetree_from_ty(bx.tcx(), extract_ty); + result.push(bx.extract_value(pair_val, 0, Some(tt.clone()))); + result.push(bx.extract_value(pair_val, 1, Some(tt))); tuple_index += 1; } PassMode::Indirect { .. } => { diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index d125760a5b9aa..e40d3aa3d2a3d 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -67,6 +67,11 @@ unsafe extern "C" { ) -> Option<&Value>; pub(crate) fn LLVMRustIsIntrinsicCall(V: &Value) -> bool; + pub(crate) safe fn LLVMRustSupportsEnzymeMD(V: &Value) -> bool; + pub(crate) fn LLVMRustSetEnzymeTypeMD( + v: &Value, + md: &Value, + ); } unsafe extern "C" { @@ -97,7 +102,7 @@ pub(crate) mod Enzyme_AD { use rustc_session::filesearch; use super::{CConcreteType, CTypeTreeRef, Context}; - use crate::llvm::{EnzymeTypeTree, LLVMRustVersionMajor}; + use crate::llvm::{EnzymeTypeTree, LLVMRustVersionMajor, Value}; type EnzymeSetCLBoolFn = unsafe extern "C" fn(*mut c_void, u8); type EnzymeSetCLStringFn = unsafe extern "C" fn(*mut c_void, *const c_char); @@ -115,6 +120,7 @@ pub(crate) mod Enzyme_AD { unsafe extern "C" fn(CTypeTreeRef, *const i64, usize, CConcreteType, &Context); type EnzymeTypeTreeToStringFn = unsafe extern "C" fn(CTypeTreeRef) -> *const c_char; type EnzymeTypeTreeToStringFreeFn = unsafe extern "C" fn(*const c_char); + type EnzymeTypeTreeToMDFn = unsafe extern "C" fn(CTypeTreeRef, &Context) -> Option<& Value>; #[allow(non_snake_case)] pub(crate) struct EnzymeWrapper { @@ -129,6 +135,7 @@ pub(crate) mod Enzyme_AD { EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn, EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn, EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn, + EnzymeTypeTreeToMD: EnzymeTypeTreeToMDFn, EnzymePrintPerf: *mut c_void, EnzymePrintActivity: *mut c_void, @@ -302,6 +309,10 @@ pub(crate) mod Enzyme_AD { unsafe { (self.EnzymeTypeTreeToStringFree)(ch) } } + pub(crate) fn tree_to_md<'a>(&'a self, tree: *mut EnzymeTypeTree, ctx: &'a Context) -> Option<&'a Value> { + unsafe { (self.EnzymeTypeTreeToMD)(tree, ctx) } + } + pub(crate) fn get_max_type_depth(&self) -> usize { unsafe { std::ptr::read::(self.EnzymeMaxTypeDepth as *const u32) as usize } } @@ -385,6 +396,7 @@ pub(crate) mod Enzyme_AD { EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn, EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn, EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn, + EnzymeTypeTreeToMD: EnzymeTypeTreeToMDFn, EnzymeSetCLBool: EnzymeSetCLBoolFn, EnzymeSetCLString: EnzymeSetCLStringFn, ); @@ -416,6 +428,7 @@ pub(crate) mod Enzyme_AD { EnzymeTypeTreeInsertEq, EnzymeTypeTreeToString, EnzymeTypeTreeToStringFree, + EnzymeTypeTreeToMD, EnzymePrintPerf, EnzymePrintActivity, EnzymePrintType, diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index ca35abe556b3b..5fb2849ccc492 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -119,7 +119,13 @@ pub(crate) fn add_tt<'tcx, 'll>( let attr = llvm::CreateAttrStringValueFromCStr(llcx, &c_attr_name, &c_str); dbg!(&fn_def); - if unsafe { llvm::LLVMRustIsIntrinsicCall(fn_def) } { + if llvm::LLVMRustSupportsEnzymeMD(fn_def) { + dbg!("extractvalue md"); + let md = enzyme_wrapper.tree_to_md(enzyme_tt.inner, llcx); + unsafe { + llvm::LLVMRustSetEnzymeTypeMD(fn_def, md.unwrap()); + } + } else if unsafe { llvm::LLVMRustIsIntrinsicCall(fn_def) } { dbg!("callsite"); attributes::apply_to_callsite( fn_def, @@ -140,8 +146,14 @@ pub(crate) fn add_tt<'tcx, 'll>( let c_str = enzyme_wrapper.tree_to_cstr(v.inner); let int_attr = llvm::CreateAttrStringValueFromCStr(llcx, &c_attr_name, &c_str); dbg!(&fn_def); - if unsafe { llvm::LLVMRustIsIntrinsicCall(fn_def) } { - dbg!("callsite"); + if llvm::LLVMRustSupportsEnzymeMD(fn_def) { + dbg!("extractvalue input(?)"); + let md = enzyme_wrapper.tree_to_md(enzyme_tt.inner, llcx); + unsafe { + llvm::LLVMRustSetEnzymeTypeMD(fn_def, md.unwrap()); + } + } else if unsafe { llvm::LLVMRustIsIntrinsicCall(fn_def) } { + dbg!("callsite input"); attributes::apply_to_callsite( fn_def, llvm::AttributePlace::Argument(i as u32 + offset), @@ -161,28 +173,36 @@ pub(crate) fn add_tt<'tcx, 'll>( } // We will only fail this if Rust types got lowered to LLVM in a way that we didn't predict. // Error, so we can learn from our mistakes. - if unsafe { !llvm::LLVMRustIsIntrinsicCall(fn_def) } { - let expected = offset as usize + inputs.len(); - let actual = llvm::count_params(fn_def) as usize; - if expected != actual { - tcx.dcx().warn(format!( - "autodiff type-tree failure. We expected {expected} LLVM argument(s), \ - but the generated LLVM function has {actual} parameter(s)" - )); - } - } + //if unsafe { !llvm::LLVMRustIsIntrinsicCall(fn_def) } { + // dbg!("checking parameter count"); + // let expected = offset as usize + inputs.len(); + // let actual = llvm::count_params(fn_def) as usize; + // if expected != actual { + // tcx.dcx().warn(format!( + // "autodiff type-tree failure. We expected {expected} LLVM argument(s), \ + // but the generated LLVM function has {actual} parameter(s)" + // )); + // } + //} let (enzyme_tt, extra_ints) = to_enzyme_typetree(&ret_tt, llvm_data_layout, llcx); if ret_tt != RustTypeTree::new() { let enzyme_wrapper = EnzymeWrapper::get_instance(); - if !extra_ints.is_empty() { - bug!("A return type should not have extra integers. Implementation bug!"); - } + //if !extra_ints.is_empty() { + // bug!("A return type should not have extra integers. Implementation bug!"); + //} let c_str = enzyme_wrapper.tree_to_cstr(enzyme_tt.inner); let ret_attr = llvm::CreateAttrStringValueFromCStr(llcx, &c_attr_name, &c_str); - if unsafe { llvm::LLVMRustIsIntrinsicCall(fn_def) } { + if llvm::LLVMRustSupportsEnzymeMD(fn_def) { + dbg!("extractvalue md"); + let md = enzyme_wrapper.tree_to_md(enzyme_tt.inner, llcx); + unsafe { + llvm::LLVMRustSetEnzymeTypeMD(fn_def, md.unwrap()); + } + } else if unsafe { llvm::LLVMRustIsIntrinsicCall(fn_def) } { + dbg!("intrinsiccall"); attributes::apply_to_callsite( fn_def, llvm::AttributePlace::ReturnValue, diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs index 115c50edf4e9f..337feacc2c30e 100644 --- a/compiler/rustc_codegen_ssa/src/mir/block.rs +++ b/compiler/rustc_codegen_ssa/src/mir/block.rs @@ -2354,8 +2354,8 @@ pub fn store_cast<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( if let Some(offset_from_start) = cast.rest_offset { assert_eq!(cast.prefix.len(), 1); assert_eq!(cast.rest.unit.size, cast.rest.total); - let first = bx.extract_value(value, 0); - let second = bx.extract_value(value, 1); + let first = bx.extract_value(value, 0, None); + let second = bx.extract_value(value, 1, None); bx.store(first, ptr, align); let second_ptr = bx.inbounds_ptradd(ptr, bx.const_usize(offset_from_start.bytes())); bx.store(second, second_ptr, align.restrict_for_offset(offset_from_start)); diff --git a/compiler/rustc_codegen_ssa/src/mir/operand.rs b/compiler/rustc_codegen_ssa/src/mir/operand.rs index c87ea83eacf62..d602049fe5495 100644 --- a/compiler/rustc_codegen_ssa/src/mir/operand.rs +++ b/compiler/rustc_codegen_ssa/src/mir/operand.rs @@ -21,6 +21,52 @@ use crate::MemFlags; use crate::common::IntPredicate; use crate::traits::*; + +use rustc_ast::expand::typetree::{TypeTree, FncTree}; +use rustc_middle::ty::typetree::typetree_from_ty; +//use rustc_middle::ty::typetree_from_ty; +use crate::TyCtxt; +use rustc_span::sym; + +fn option_ptr_like_scalar_pair_tts<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, +) -> Option { + let ty::Adt(def, args) = ty.kind() else { + return None; + }; + + if !tcx.is_lang_item(def.did(), LangItem::Option) { + return None; + } + + let inner = args.type_at(0); + if !(inner.is_ref() || inner.is_box() || nonnull_inner_ty(tcx, inner).is_some()) { + return None; + } + + let tt = typetree_from_ty(tcx, inner); + //let some_layout = layout.for_variant(bx.cx(), VariantIdx::from_u32(1)); + //let payload_layout = some_layout.field(bx.cx(), 0); + // this will be a slice + //let payload_ty = payload_layout.ty; + //let tt = rustc_middle::ty::typetree_from_ty(bx.tcx(), field0_ty.unwrap()); + if tt == TypeTree::new() { + return None; + } + Some(tt) +} + +fn nonnull_inner_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Option> { + if let ty::Adt(def, args) = ty.kind() + && tcx.is_diagnostic_item(sym::NonNull, def.did()) + { + return Some(args.type_at(0)); + } + + None +} + /// The representation of a Rust value. The enum variant is in fact /// uniquely determined by the value's type, but is kept as a /// safety check. @@ -349,8 +395,14 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> { debug!("Operand::from_immediate_or_packed_pair: unpacking {:?} @ {:?}", llval, layout); // Deconstruct the immediate aggregate. - let a_llval = bx.extract_value(llval, 0); - let b_llval = bx.extract_value(llval, 1); + let f1 = option_ptr_like_scalar_pair_tts(bx.tcx(), layout.ty); + let f2 = if f1.is_none() { + None + } else { + Some(TypeTree::int(8)) + }; + let a_llval = bx.extract_value(llval, 0, f1); + let b_llval = bx.extract_value(llval, 1, f2); OperandValue::Pair(a_llval, b_llval) } else { OperandValue::Immediate(llval) diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index eb83d39ec9ed8..4b3b315b060ed 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -579,7 +579,7 @@ pub trait BuilderMethods<'a, 'tcx>: fn va_arg(&mut self, list: Self::Value, ty: Self::Type) -> Self::Value; fn extract_element(&mut self, vec: Self::Value, idx: Self::Value) -> Self::Value; fn vector_splat(&mut self, num_elts: usize, elt: Self::Value) -> Self::Value; - fn extract_value(&mut self, agg_val: Self::Value, idx: u64) -> Self::Value; + fn extract_value(&mut self, agg_val: Self::Value, idx: u64, tt: Option) -> Self::Value; fn insert_value(&mut self, agg_val: Self::Value, elt: Self::Value, idx: u64) -> Self::Value; fn set_personality_fn(&mut self, personality: Self::Function); diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index df50a0863e384..5e4608fcf201d 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -160,9 +160,28 @@ extern "C" void LLVMRustPrintStatisticsJSON(RustStringRef OutBuf) { llvm::PrintStatisticsJSON(OS); } +extern "C" bool LLVMRustSupportsEnzymeMD(LLVMValueRef V) { + auto *I = llvm::dyn_cast(llvm::unwrap(V)); + return I && llvm::isa(I); +} + +extern "C" void LLVMRustSetEnzymeTypeMD(LLVMValueRef V, LLVMValueRef MDV) { + llvm::errs() << "setting MD" << "\n"; + auto *I = llvm::dyn_cast(llvm::unwrap(V)); + assert(I && "expected instruction for !enzyme_type metadata"); + + auto *MAV = llvm::dyn_cast(llvm::unwrap(MDV)); + assert(MAV && "expected MetadataAsValue"); + + auto *MD = llvm::dyn_cast(MAV->getMetadata()); + assert(MD && "expected MDNode"); + + I->setMetadata("enzyme_type", MD); +} + extern "C" bool LLVMRustIsIntrinsicCall(LLVMValueRef V) { llvm::Value *Val = llvm::unwrap(V); -llvm:errs() << "LLVMRustIsIntrinsicCall: " << *Val << "\n"; + llvm::errs() << "LLVMRustIsIntrinsicCall: " << *Val << "\n"; if (auto *CB = llvm::dyn_cast(Val)) { if (auto *Callee = CB->getCalledFunction()) @@ -174,6 +193,7 @@ llvm:errs() << "LLVMRustIsIntrinsicCall: " << *Val << "\n"; if (auto *F = llvm::dyn_cast(Val)) return F->isIntrinsic(); + llvm::errs() << "LLVMRustIsIntrinsicCall: " << *Val << " nope \n"; return false; }