Skip to content

gpu offload host code generation #142097

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 17 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
294 changes: 283 additions & 11 deletions compiler/rustc_codegen_llvm/src/back/lto.rs

Large diffs are not rendered by default.

155 changes: 95 additions & 60 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ impl<'a, 'll> SBuilder<'a, 'll> {
};
call
}

}

impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
Expand Down Expand Up @@ -118,6 +119,40 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
}
bx
}

pub(crate) fn my_alloca2(&mut self, ty: &'ll Type, align: Align, name: &str) -> &'ll Value {
let val = unsafe {
let alloca = llvm::LLVMBuildAlloca(self.llbuilder, ty, UNNAMED);
llvm::LLVMSetAlignment(alloca, align.bytes() as c_uint);
// Cast to default addrspace if necessary
llvm::LLVMBuildPointerCast(self.llbuilder, alloca, self.cx.type_ptr(), UNNAMED)
};
if name != "" {
let name = std::ffi::CString::new(name).unwrap();
unsafe {llvm::set_value_name(val, &name.as_bytes())};
}
val
}

pub(crate) fn inbounds_gep(
&mut self,
ty: &'ll Type,
ptr: &'ll Value,
indices: &[&'ll Value],
) -> &'ll Value {
unsafe {
llvm::LLVMBuildGEPWithNoWrapFlags(
self.llbuilder,
ty,
ptr,
indices.as_ptr(),
indices.len() as c_uint,
UNNAMED,
GEPNoWrapFlags::InBounds,
)
}
}

}

/// Empty string, to be used where LLVM expects an instruction name, indicating
Expand Down Expand Up @@ -1261,7 +1296,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
unsafe {
llvm::LLVMBuildCleanupRet(self.llbuilder, funclet.cleanuppad(), unwind)
.expect("LLVM does not have support for cleanupret");
}
}
}

fn catch_pad(&mut self, parent: &'ll Value, args: &[&'ll Value]) -> Funclet<'ll> {
Expand Down Expand Up @@ -1631,14 +1666,14 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
debug!(
"type mismatch in function call of {:?}. \
Expected {:?} for param {}, got {:?}; injecting bitcast",
llfn, expected_ty, i, actual_ty
llfn, expected_ty, i, actual_ty
);
self.bitcast(actual_val, expected_ty)
} else {
actual_val
}
})
.collect();
.collect();

Cow::Owned(casted_args)
}
Expand Down Expand Up @@ -1791,48 +1826,48 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
let is_indirect_call = unsafe { llvm::LLVMRustIsNonGVFunctionPointerTy(llfn) };
if self.tcx.sess.is_sanitizer_cfi_enabled()
&& let Some(fn_abi) = fn_abi
&& is_indirect_call
{
if let Some(fn_attrs) = fn_attrs
&& fn_attrs.no_sanitize.contains(SanitizerSet::CFI)
&& is_indirect_call
{
return;
}
if let Some(fn_attrs) = fn_attrs
&& fn_attrs.no_sanitize.contains(SanitizerSet::CFI)
{
return;
}

let mut options = cfi::TypeIdOptions::empty();
if self.tcx.sess.is_sanitizer_cfi_generalize_pointers_enabled() {
options.insert(cfi::TypeIdOptions::GENERALIZE_POINTERS);
}
if self.tcx.sess.is_sanitizer_cfi_normalize_integers_enabled() {
options.insert(cfi::TypeIdOptions::NORMALIZE_INTEGERS);
}
let mut options = cfi::TypeIdOptions::empty();
if self.tcx.sess.is_sanitizer_cfi_generalize_pointers_enabled() {
options.insert(cfi::TypeIdOptions::GENERALIZE_POINTERS);
}
if self.tcx.sess.is_sanitizer_cfi_normalize_integers_enabled() {
options.insert(cfi::TypeIdOptions::NORMALIZE_INTEGERS);
}

let typeid = if let Some(instance) = instance {
cfi::typeid_for_instance(self.tcx, instance, options)
} else {
cfi::typeid_for_fnabi(self.tcx, fn_abi, options)
};
let typeid_metadata = self.cx.typeid_metadata(typeid).unwrap();
let dbg_loc = self.get_dbg_loc();

// Test whether the function pointer is associated with the type identifier.
let cond = self.type_test(llfn, typeid_metadata);
let bb_pass = self.append_sibling_block("type_test.pass");
let bb_fail = self.append_sibling_block("type_test.fail");
self.cond_br(cond, bb_pass, bb_fail);

self.switch_to_block(bb_fail);
if let Some(dbg_loc) = dbg_loc {
self.set_dbg_loc(dbg_loc);
}
self.abort();
self.unreachable();
let typeid = if let Some(instance) = instance {
cfi::typeid_for_instance(self.tcx, instance, options)
} else {
cfi::typeid_for_fnabi(self.tcx, fn_abi, options)
};
let typeid_metadata = self.cx.typeid_metadata(typeid).unwrap();
let dbg_loc = self.get_dbg_loc();

// Test whether the function pointer is associated with the type identifier.
let cond = self.type_test(llfn, typeid_metadata);
let bb_pass = self.append_sibling_block("type_test.pass");
let bb_fail = self.append_sibling_block("type_test.fail");
self.cond_br(cond, bb_pass, bb_fail);

self.switch_to_block(bb_fail);
if let Some(dbg_loc) = dbg_loc {
self.set_dbg_loc(dbg_loc);
}
self.abort();
self.unreachable();

self.switch_to_block(bb_pass);
if let Some(dbg_loc) = dbg_loc {
self.set_dbg_loc(dbg_loc);
self.switch_to_block(bb_pass);
if let Some(dbg_loc) = dbg_loc {
self.set_dbg_loc(dbg_loc);
}
}
}
}

// Emits KCFI operand bundles.
Expand All @@ -1847,31 +1882,31 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
let kcfi_bundle = if self.tcx.sess.is_sanitizer_kcfi_enabled()
&& let Some(fn_abi) = fn_abi
&& is_indirect_call
{
if let Some(fn_attrs) = fn_attrs
&& fn_attrs.no_sanitize.contains(SanitizerSet::KCFI)
{
return None;
}
if let Some(fn_attrs) = fn_attrs
&& fn_attrs.no_sanitize.contains(SanitizerSet::KCFI)
{
return None;
}

let mut options = kcfi::TypeIdOptions::empty();
if self.tcx.sess.is_sanitizer_cfi_generalize_pointers_enabled() {
options.insert(kcfi::TypeIdOptions::GENERALIZE_POINTERS);
}
if self.tcx.sess.is_sanitizer_cfi_normalize_integers_enabled() {
options.insert(kcfi::TypeIdOptions::NORMALIZE_INTEGERS);
}
let mut options = kcfi::TypeIdOptions::empty();
if self.tcx.sess.is_sanitizer_cfi_generalize_pointers_enabled() {
options.insert(kcfi::TypeIdOptions::GENERALIZE_POINTERS);
}
if self.tcx.sess.is_sanitizer_cfi_normalize_integers_enabled() {
options.insert(kcfi::TypeIdOptions::NORMALIZE_INTEGERS);
}

let kcfi_typeid = if let Some(instance) = instance {
kcfi::typeid_for_instance(self.tcx, instance, options)
} else {
kcfi::typeid_for_fnabi(self.tcx, fn_abi, options)
};

let kcfi_typeid = if let Some(instance) = instance {
kcfi::typeid_for_instance(self.tcx, instance, options)
Some(llvm::OperandBundleBox::new("kcfi", &[self.const_u32(kcfi_typeid)]))
} else {
kcfi::typeid_for_fnabi(self.tcx, fn_abi, options)
None
};

Some(llvm::OperandBundleBox::new("kcfi", &[self.const_u32(kcfi_typeid)]))
} else {
None
};
kcfi_bundle
}

Expand Down
16 changes: 14 additions & 2 deletions compiler/rustc_codegen_llvm/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ impl<'ll, CX: Borrow<SCx<'ll>>> BackendTypes for GenericCx<'ll, CX> {
type DIVariable = &'ll llvm::debuginfo::DIVariable;
}

impl<'ll> CodegenCx<'ll, '_> {
impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
pub(crate) fn const_array(&self, ty: &'ll Type, elts: &[&'ll Value]) -> &'ll Value {
let len = u64::try_from(elts.len()).expect("LLVMConstArray2 elements len overflow");
unsafe { llvm::LLVMConstArray2(ty, elts.as_ptr(), len) }
}

pub(crate) fn const_bytes(&self, bytes: &[u8]) -> &'ll Value {
bytes_in_context(self.llcx, bytes)
bytes_in_context(self.llcx(), bytes)
}

pub(crate) fn const_get_elt(&self, v: &'ll Value, idx: u64) -> &'ll Value {
Expand All @@ -119,6 +119,10 @@ impl<'ll> CodegenCx<'ll, '_> {
r
}
}

pub(crate) fn const_null(&self, t: &'ll Type) -> &'ll Value {
unsafe { llvm::LLVMConstNull(t) }
}
}

impl<'ll, 'tcx> ConstCodegenMethods for CodegenCx<'ll, 'tcx> {
Expand Down Expand Up @@ -373,6 +377,14 @@ pub(crate) fn bytes_in_context<'ll>(llcx: &'ll llvm::Context, bytes: &[u8]) -> &
}
}

pub(crate) fn named_struct<'ll>(
ty: &'ll Type,
elts: &[&'ll Value],
) -> &'ll Value {
let len = c_uint::try_from(elts.len()).expect("LLVMConstStructInContext elements len overflow");
unsafe { llvm::LLVMConstNamedStruct(ty, elts.as_ptr(), len) }
}

fn struct_in_context<'ll>(
llcx: &'ll llvm::Context,
elts: &[&'ll Value],
Expand Down
15 changes: 15 additions & 0 deletions compiler/rustc_codegen_llvm/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,21 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
unsafe { llvm::LLVMConstInt(ty, n, llvm::False) }
}

pub(crate) fn get_const_i32(&self, n: u64) -> &'ll Value {
let ty = unsafe { llvm::LLVMInt32TypeInContext(self.llcx()) };
unsafe { llvm::LLVMConstInt(ty, n, llvm::False) }
}

pub(crate) fn get_const_i16(&self, n: u64) -> &'ll Value {
let ty = unsafe { llvm::LLVMInt16TypeInContext(self.llcx()) };
unsafe { llvm::LLVMConstInt(ty, n, llvm::False) }
}

pub(crate) fn get_const_i8(&self, n: u64) -> &'ll Value {
let ty = unsafe { llvm::LLVMInt8TypeInContext(self.llcx()) };
unsafe { llvm::LLVMConstInt(ty, n, llvm::False) }
}

pub(crate) fn get_function(&self, name: &str) -> Option<&'ll Value> {
let name = SmallCStr::new(name);
unsafe { llvm::LLVMGetNamedFunction((**self).borrow().llmod, name.as_ptr()) }
Expand Down
7 changes: 5 additions & 2 deletions compiler/rustc_codegen_llvm/src/declare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
)
}
}

}

impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
Expand Down Expand Up @@ -215,7 +216,9 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {

llfn
}
}

impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
/// Declare a global with an intention to define it.
///
/// Use this function when you intend to define a global. This function will
Expand All @@ -234,13 +237,13 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
///
/// Use this function when you intend to define a global without a name.
pub(crate) fn define_private_global(&self, ty: &'ll Type) -> &'ll Value {
unsafe { llvm::LLVMRustInsertPrivateGlobal(self.llmod, ty) }
unsafe { llvm::LLVMRustInsertPrivateGlobal(self.llmod(), ty) }
}

/// Gets declared value by name.
pub(crate) fn get_declared_value(&self, name: &str) -> Option<&'ll Value> {
debug!("get_declared_value(name={:?})", name);
unsafe { llvm::LLVMRustGetNamedValue(self.llmod, name.as_c_char_ptr(), name.len()) }
unsafe { llvm::LLVMRustGetNamedValue(self.llmod(), name.as_c_char_ptr(), name.len()) }
}

/// Gets defined or externally defined (AvailableExternally linkage) value by
Expand Down
10 changes: 9 additions & 1 deletion compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use libc::{c_char, c_uint};

use super::MetadataKindId;
use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value};
use crate::llvm::Bool;
use crate::llvm::{Bool, Builder};

#[link(name = "llvm-wrapper", kind = "static")]
unsafe extern "C" {
Expand All @@ -32,6 +32,14 @@ unsafe extern "C" {
index: c_uint,
kind: AttributeKind,
);
pub(crate) fn LLVMRustPositionBefore<'a>(B: &'a Builder<'_>, I: &'a Value);
pub(crate) fn LLVMRustPositionAfter<'a>(B: &'a Builder<'_>, I: &'a Value);
pub(crate) fn LLVMRustGetFunctionCall(
F: &Value,
name: *const c_char,
NameLen: libc::size_t,
) -> Option<&Value>;

}

unsafe extern "C" {
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1139,6 +1139,11 @@ unsafe extern "C" {
Count: c_uint,
Packed: Bool,
) -> &'a Value;
pub(crate) fn LLVMConstNamedStruct<'a>(
StructTy: &'a Type,
ConstantVals: *const &'a Value,
Count: c_uint,
) -> &'a Value;
pub(crate) fn LLVMConstVector(ScalarConstantVals: *const &Value, Size: c_uint) -> &Value;

// Constant expressions
Expand Down Expand Up @@ -1209,6 +1214,7 @@ unsafe extern "C" {
) -> &'a BasicBlock;

// Operations on instructions
pub(crate) fn LLVMGetInstructionParent(Inst: &Value) -> &BasicBlock;
pub(crate) fn LLVMIsAInstruction(Val: &Value) -> Option<&Value>;
pub(crate) fn LLVMGetFirstBasicBlock(Fn: &Value) -> &BasicBlock;
pub(crate) fn LLVMGetOperand(Val: &Value, Index: c_uint) -> Option<&Value>;
Expand Down Expand Up @@ -2554,6 +2560,7 @@ unsafe extern "C" {

pub(crate) fn LLVMRustSetDataLayoutFromTargetMachine<'a>(M: &'a Module, TM: &'a TargetMachine);

pub(crate) fn LLVMRustPositionBuilderPastAllocas<'a>(B: &Builder<'a>, Fn: &'a Value);
pub(crate) fn LLVMRustPositionBuilderAtStart<'a>(B: &Builder<'a>, BB: &'a BasicBlock);

pub(crate) fn LLVMRustSetModulePICLevel(M: &Module);
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_codegen_ssa/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ pub struct ModuleConfig {
pub emit_lifetime_markers: bool,
pub llvm_plugins: Vec<String>,
pub autodiff: Vec<config::AutoDiff>,
pub offload: Vec<config::Offload>,
}

impl ModuleConfig {
Expand Down Expand Up @@ -270,6 +271,7 @@ impl ModuleConfig {
emit_lifetime_markers: sess.emit_lifetime_markers(),
llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]),
autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]),
offload: if_regular!(sess.opts.unstable_opts.offload.clone(), vec![]),
}
}

Expand Down
Loading