Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions compiler/rustc_ast/src/expand/typetree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub enum Kind {
Anything,
Integer,
Pointer,
RustSlice,
Half,
Float,
Double,
Expand Down Expand Up @@ -57,6 +58,9 @@ impl TypeTree {
}
Self(ints)
}
pub fn add_indirection(self) -> Self {
Self(vec![Type { offset: 0, size: 1, kind: Kind::Pointer, child: self }])
}
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, StableHash)]
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_llvm/src/asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
let value = if output_types.len() == 1 {
result
} else {
self.extract_value(result, op_idx[&idx] as u64)
self.extract_value(result, op_idx[&idx] as u64, None)
};
let value =
llvm_fixup_output(self, value, reg.reg_class(), &place.layout, instance);
Expand Down
23 changes: 16 additions & 7 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::iter;
use std::ops::Deref;

use rustc_ast::expand::typetree::FncTree;
use rustc_ast::expand::typetree::TypeTree;
pub(crate) mod autodiff;
pub(crate) mod gpu_offload;

Expand Down Expand Up @@ -631,7 +632,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
let name = format!("llvm.{}{oop_str}.with.overflow", if signed { 's' } else { 'u' });

let res = self.call_intrinsic(name, &[self.type_ix(width)], &[lhs, rhs]);
(self.extract_value(res, 0), self.extract_value(res, 1))
(self.extract_value(res, 0, None), self.extract_value(res, 1, None))
}

fn from_immediate(&mut self, val: Self::Value) -> Self::Value {
Expand Down Expand Up @@ -1175,7 +1176,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
// vs. copying a struct with mixed types requires different derivative handling.
// The TypeTree tells Enzyme exactly what memory layout to expect.
if let Some(tt) = tt {
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, self.cx().tcx, memcpy, tt);
}
}

Expand Down Expand Up @@ -1257,9 +1258,17 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
}
}

fn extract_value(&mut self, agg_val: &'ll Value, idx: u64) -> &'ll Value {
fn extract_value(&mut self, agg_val: &'ll Value, idx: u64, tt: Option<TypeTree>) -> &'ll Value {
assert_eq!(idx as c_uint as u64, idx);
unsafe { llvm::LLVMBuildExtractValue(self.llbuilder, agg_val, idx as c_uint, UNNAMED) }
let extract = unsafe { llvm::LLVMBuildExtractValue(self.llbuilder, agg_val, idx as c_uint, UNNAMED) };
if let Some(tt) = tt {
let fnc_tree = FncTree {
args: vec![],
ret: tt,
};
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, self.cx().tcx, extract, fnc_tree);
}
extract
}

fn insert_value(&mut self, agg_val: &'ll Value, elt: &'ll Value, idx: u64) -> &'ll Value {
Expand All @@ -1279,7 +1288,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
unsafe {
llvm::LLVMSetCleanup(landing_pad, llvm::TRUE);
}
(self.extract_value(landing_pad, 0), self.extract_value(landing_pad, 1))
(self.extract_value(landing_pad, 0, None), self.extract_value(landing_pad, 1, None))
}

fn filter_landing_pad(&mut self, pers_fn: &'ll Value) {
Expand Down Expand Up @@ -1380,8 +1389,8 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
llvm::FALSE, // SingleThreaded
);
llvm::LLVMSetWeak(value, weak.to_llvm_bool());
let val = self.extract_value(value, 0);
let success = self.extract_value(value, 1);
let val = self.extract_value(value, 0, None);
let success = self.extract_value(value, 1, None);
(val, success)
}
}
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
builder: &mut Builder<'_, 'll, 'tcx>,
tcx: TyCtxt<'tcx>,
cx: &SimpleCx<'ll>,
fn_to_diff: &'ll Value,
outer_name: &str,
Expand Down Expand Up @@ -379,7 +380,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
);

if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() {
crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree);
crate::typetree::add_tt(cx.llmod, cx.llcx, tcx, fn_to_diff, fnc_tree);
}

let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,9 @@ impl<'ll> OffloadKernelDims<'ll> {
builder: &mut Builder<'_, 'll, 'tcx>,
arr: &'ll Value,
) -> &'ll Value {
let x = builder.extract_value(arr, 0);
let y = builder.extract_value(arr, 1);
let z = builder.extract_value(arr, 2);
let x = builder.extract_value(arr, 0, None);
let y = builder.extract_value(arr, 1, None);
let z = builder.extract_value(arr, 2, None);

let xy = builder.mul(x, y);
builder.mul(xy, z)
Expand Down
17 changes: 11 additions & 6 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use crate::errors::{
AutoDiffWithoutEnable, AutoDiffWithoutLto, IntrinsicSignatureMismatch, IntrinsicWrongArch,
OffloadWithoutEnable, OffloadWithoutFatLTO, UnknownIntrinsic,
};
use crate::intrinsic::ty::typetree::fnc_typetrees;
use crate::intrinsic::ty::typetree::{fnc_typetrees, typetree_from_ty};
use crate::llvm::{self, Type, Value};
use crate::type_of::LayoutLlvmExt;
use crate::va_arg::emit_va_arg;
Expand Down Expand Up @@ -776,6 +776,7 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
self.extract_value(
args[0].immediate(),
fn_args.const_at(2).to_leaf().to_i32() as u64,
None,
)
}

Expand Down Expand Up @@ -1059,7 +1060,7 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
&[],
&[llvtable, vtable_byte_offset, typeid],
);
self.extract_value(type_checked_load, 0)
self.extract_value(type_checked_load, 0, None)
}

fn va_start(&mut self, va_list: &'ll Value) {
Expand Down Expand Up @@ -1171,7 +1172,7 @@ fn autocast<'ll>(
iter::zip(bx.struct_element_types(src_ty), bx.struct_element_types(dest_ty))
.enumerate()
{
let elt = bx.extract_value(val, idx as u64);
let elt = bx.extract_value(val, idx as u64, None);
let casted_elt = autocast(bx, elt, src_element_ty, dest_element_ty);
ret = bx.insert_value(ret, casted_elt, idx as u64);
}
Expand Down Expand Up @@ -1632,7 +1633,7 @@ fn codegen_gnu_try<'ll, 'tcx>(
let vals = bx.landing_pad(lpad_ty, bx.eh_personality(), 1);
let tydesc = bx.const_null(bx.type_ptr());
bx.add_clause(vals, tydesc);
let ptr = bx.extract_value(vals, 0);
let ptr = bx.extract_value(vals, 0, None);
let catch_ty = bx.type_func(&[bx.type_ptr(), bx.type_ptr()], bx.type_void());
bx.call(catch_ty, None, None, catch_func, &[data, ptr], None, None);
bx.ret(bx.const_bool(true));
Expand Down Expand Up @@ -1793,6 +1794,7 @@ fn codegen_autodiff<'ll, 'tcx>(
// Build body
generate_enzyme_call(
bx,
tcx,
bx.cx,
fn_to_diff,
&diff_symbol,
Expand Down Expand Up @@ -1916,8 +1918,11 @@ fn get_args_from_tuple<'ll, 'tcx>(
let field = tuple_place.project_field(bx, tuple_index);
let llvm_ty = field.layout.llvm_type(bx.cx);
let pair_val = bx.load(llvm_ty, field.val.llval, field.val.align);
result.push(bx.extract_value(pair_val, 0));
result.push(bx.extract_value(pair_val, 1));

let extract_ty = field.layout.ty;
let tt = typetree_from_ty(bx.tcx(), extract_ty);
result.push(bx.extract_value(pair_val, 0, Some(tt.clone())));
result.push(bx.extract_value(pair_val, 1, Some(tt)));
tuple_index += 1;
}
PassMode::Indirect { .. } => {
Expand Down
21 changes: 20 additions & 1 deletion compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ unsafe extern "C" {
NameLen: libc::size_t,
) -> Option<&Value>;

pub(crate) fn LLVMRustIsIntrinsicCall(V: &Value) -> bool;
pub(crate) safe fn LLVMRustSupportsEnzymeMD(V: &Value) -> bool;
pub(crate) fn LLVMRustSetEnzymeTypeMD(
v: &Value,
md: &Value,
);
}

unsafe extern "C" {
Expand Down Expand Up @@ -96,7 +102,7 @@ pub(crate) mod Enzyme_AD {
use rustc_session::filesearch;

use super::{CConcreteType, CTypeTreeRef, Context};
use crate::llvm::{EnzymeTypeTree, LLVMRustVersionMajor};
use crate::llvm::{EnzymeTypeTree, LLVMRustVersionMajor, Value};

type EnzymeSetCLBoolFn = unsafe extern "C" fn(*mut c_void, u8);
type EnzymeSetCLStringFn = unsafe extern "C" fn(*mut c_void, *const c_char);
Expand All @@ -114,6 +120,7 @@ pub(crate) mod Enzyme_AD {
unsafe extern "C" fn(CTypeTreeRef, *const i64, usize, CConcreteType, &Context);
type EnzymeTypeTreeToStringFn = unsafe extern "C" fn(CTypeTreeRef) -> *const c_char;
type EnzymeTypeTreeToStringFreeFn = unsafe extern "C" fn(*const c_char);
type EnzymeTypeTreeToMDFn = unsafe extern "C" fn(CTypeTreeRef, &Context) -> Option<& Value>;

#[allow(non_snake_case)]
pub(crate) struct EnzymeWrapper {
Expand All @@ -128,6 +135,7 @@ pub(crate) mod Enzyme_AD {
EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn,
EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn,
EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn,
EnzymeTypeTreeToMD: EnzymeTypeTreeToMDFn,

EnzymePrintPerf: *mut c_void,
EnzymePrintActivity: *mut c_void,
Expand Down Expand Up @@ -292,10 +300,19 @@ pub(crate) mod Enzyme_AD {
unsafe { (self.EnzymeTypeTreeToString)(tree) }
}

pub(crate) fn tree_to_cstr(&self, tree: *mut EnzymeTypeTree) -> &std::ffi::CStr {
let c_str = self.tree_to_string(tree);
unsafe { std::ffi::CStr::from_ptr(c_str) }
}

pub(crate) fn tree_to_string_free(&self, ch: *const c_char) {
unsafe { (self.EnzymeTypeTreeToStringFree)(ch) }
}

pub(crate) fn tree_to_md<'a>(&'a self, tree: *mut EnzymeTypeTree, ctx: &'a Context) -> Option<&'a Value> {
unsafe { (self.EnzymeTypeTreeToMD)(tree, ctx) }
}

pub(crate) fn get_max_type_depth(&self) -> usize {
unsafe { std::ptr::read::<u32>(self.EnzymeMaxTypeDepth as *const u32) as usize }
}
Expand Down Expand Up @@ -379,6 +396,7 @@ pub(crate) mod Enzyme_AD {
EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn,
EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn,
EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn,
EnzymeTypeTreeToMD: EnzymeTypeTreeToMDFn,
EnzymeSetCLBool: EnzymeSetCLBoolFn,
EnzymeSetCLString: EnzymeSetCLStringFn,
);
Expand Down Expand Up @@ -410,6 +428,7 @@ pub(crate) mod Enzyme_AD {
EnzymeTypeTreeInsertEq,
EnzymeTypeTreeToString,
EnzymeTypeTreeToStringFree,
EnzymeTypeTreeToMD,
EnzymePrintPerf,
EnzymePrintActivity,
EnzymePrintType,
Expand Down
15 changes: 15 additions & 0 deletions compiler/rustc_codegen_llvm/src/llvm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,21 @@ pub(crate) fn CreateAttrStringValue<'ll>(
)
}
}
pub(crate) fn CreateAttrStringValueFromCStr<'ll>(
llcx: &'ll Context,
attr: &std::ffi::CStr,
value: &std::ffi::CStr,
) -> &'ll Attribute {
unsafe {
LLVMCreateStringAttribute(
llcx,
(*attr).as_ptr(),
(*attr).to_bytes().len() as c_uint,
(*value).as_ptr(),
(*value).to_bytes().len() as c_uint,
)
}
}

pub(crate) fn CreateAttrString<'ll>(llcx: &'ll Context, attr: &str) -> &'ll Attribute {
unsafe {
Expand Down
Loading
Loading