Commit 929b5f8
Rollup merge of rust-lang#130060 - EnzymeAD:enzyme-cg-llvm, r=davidtwco
Autodiff Upstreaming - rustc_codegen_llvm changes

Now that the autodiff/Enzyme backend is merged, this is an upstream PR for the `rustc_codegen_llvm` changes. It also includes small changes to three files under `compiler/rustc_ast`, which overlap with my frontend PR (rust-lang#129458). Here I only include minimal definitions of structs and enums to be able to build this backend code. The same goes for the minimal changes to `compiler/rustc_codegen_ssa`; the majority of changes there will be in another PR, once either this or the frontend gets merged.

We currently have 68 files left to merge: 19 in the frontend PR, 21 (+3 from the frontend) in this PR, and then ~30 in the middle-end. This PR is large because it includes two of my three large files (~800 loc each). I could also first only upstream enzyme_ffi.rs, but I think people might want to see some use of these bindings in the same PR?

To already highlight the things which reviewers might want to discuss:

1) `enzyme_ffi.rs`: I do have a fallback module to make sure that we don't link rustc against Enzyme when we build rustc without autodiff support.

2) `add_panic_msg_to_global` was a pain to write and I currently can't even use it. Enzyme writes gradients into shadow memory. Pass in one float scalar? We'll allocate and return an extra float telling you how this float affected the output. Pass in a slice of floats? We'll let you allocate the vector and pass in a mutable reference to a float slice; we'll then write the gradient into that slice. It should be at least as large as your original slice, so we check that and panic if not. Currently we panic silently, but I already generate a nicer panic message with this function. I just don't know how to print it to the user yet. I discussed this with a few rustc devs and the best we could come up with (for now) was to look for mangled panic calls in the IR and pick one, which works surprisingly reliably. If someone knows a good way to clean this up and print the panic message I'm all in; otherwise I can remove the code that writes the nicer panic message and keep the silent panic, since it's enough for soundness. Especially since this PR is already a bit larger.

3) `SanitizeHWAddress`: When differentiating C++, Enzyme can use TBAA to "understand" enums/unions, but for Rust we don't have this information. LLVM might do speculative loads which (without TBAA) confuse Enzyme, so we disable those with this attribute. This attribute is only set during the first opt run before Enzyme differentiates code. We then remove it again once we are done with autodiff and run the opt pipeline a second time. Since enums are everywhere in Rust, support for them is crucial, but if this looks too cursed I can remove these ~100 lines and keep them in my fork for now; we can then discuss them separately to make this PR simpler?

4) Duplicated llvm-opt runs: Differentiating already optimized code (and being able to do additional optimizations on the fly, e.g. for GPU code) is _the_ reason why Enzyme is so fast, so the compile time is acceptable for autodiff users: https://enzyme.mit.edu/talks/Publications/ (There are also algorithmic issues in Enzyme core which are more serious than running opt twice).

5) I assume that if we merge these minimal cg_ssa changes here already, I also need to fix the other backends (GCC and Cranelift) to have dummy implementations, correct?

6) *I'm happy to split this PR up further if reviewers have recommendations on how to.*

For the full implementation, see: rust-lang#129175

Tracking:
- rust-lang#124509
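
The shadow-memory convention from point 2 can be illustrated with a hand-written derivative. This is only a sketch of the calling convention described above, not code generated by Enzyme; the names `d_square` and `d_sum_of_squares` are hypothetical, and accumulating into the shadow slice with `+=` is an assumption of this sketch.

```rust
/// Scalar case: the derivative comes back as an extra return value.
/// Returns `(x * x, d(x * x)/dx)`.
fn d_square(x: f64) -> (f64, f64) {
    (x * x, 2.0 * x)
}

/// Slice case: the caller allocates the shadow slice `dx`, and the
/// gradient of `sum(x_i^2)` is written into it.
fn d_sum_of_squares(x: &[f64], dx: &mut [f64]) -> f64 {
    // The shadow must be at least as large as the primal input,
    // otherwise writing the gradient would go out of bounds.
    assert!(
        dx.len() >= x.len(),
        "shadow slice too small: {} < {}",
        dx.len(),
        x.len()
    );
    let mut out = 0.0;
    for (xi, dxi) in x.iter().zip(dx.iter_mut()) {
        out += xi * xi;
        // Assumption of this sketch: gradients are accumulated,
        // so the caller zero-initializes the shadow first.
        *dxi += 2.0 * xi;
    }
    out
}
```

The length check mirrors the "panic if the shadow is too small" rule from the description; here it carries a readable message, which is the part the PR cannot yet surface to the user.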
2 parents f4f0faf + 3d2e36c commit 929b5f8

13 files changed: +561 -27 lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

+3 -16

@@ -6,7 +6,6 @@
 use std::fmt::{self, Display, Formatter};
 use std::str::FromStr;
 
-use crate::expand::typetree::TypeTree;
 use crate::expand::{Decodable, Encodable, HashStable_Generic};
 use crate::ptr::P;
 use crate::{Ty, TyKind};
@@ -79,10 +78,6 @@ pub struct AutoDiffItem {
     /// The name of the function being generated
     pub target: String,
     pub attrs: AutoDiffAttrs,
-    /// Describe the memory layout of input types
-    pub inputs: Vec<TypeTree>,
-    /// Describe the memory layout of the output type
-    pub output: TypeTree,
 }
 #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
 pub struct AutoDiffAttrs {
@@ -262,22 +257,14 @@ impl AutoDiffAttrs {
         !matches!(self.mode, DiffMode::Error | DiffMode::Source)
     }
 
-    pub fn into_item(
-        self,
-        source: String,
-        target: String,
-        inputs: Vec<TypeTree>,
-        output: TypeTree,
-    ) -> AutoDiffItem {
-        AutoDiffItem { source, target, inputs, output, attrs: self }
+    pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
+        AutoDiffItem { source, target, attrs: self }
     }
 }
 
 impl fmt::Display for AutoDiffItem {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(f, "Differentiating {} -> {}", self.source, self.target)?;
-        write!(f, " with attributes: {:?}", self.attrs)?;
-        write!(f, " with inputs: {:?}", self.inputs)?;
-        write!(f, " with output: {:?}", self.output)
+        write!(f, " with attributes: {:?}", self.attrs)
     }
 }
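
To see what the trimmed-down `Display` impl prints once the `TypeTree` fields are gone, the following standalone sketch mirrors the shape of `AutoDiffItem` after this change. The `AutoDiffAttrs` here is a hypothetical one-field stand-in (the real struct has more fields and derives more traits); only `into_item` and the `Display` body follow the diff.

```rust
use std::fmt;

/// Hypothetical, reduced stand-in for the real `AutoDiffAttrs`.
#[derive(Debug)]
struct AutoDiffAttrs {
    mode: String,
}

/// Same three fields as `AutoDiffItem` after this commit.
struct AutoDiffItem {
    source: String,
    target: String,
    attrs: AutoDiffAttrs,
}

impl AutoDiffAttrs {
    /// Mirrors the simplified two-argument `into_item`.
    fn into_item(self, source: String, target: String) -> AutoDiffItem {
        AutoDiffItem { source, target, attrs: self }
    }
}

impl fmt::Display for AutoDiffItem {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Differentiating {} -> {}", self.source, self.target)?;
        write!(f, " with attributes: {:?}", self.attrs)
    }
}
```

Formatting an item built from `source = "f"` and `target = "df"` yields a single line starting with `Differentiating f -> df with attributes:` followed by the `Debug` output of the attributes.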

compiler/rustc_codegen_gcc/src/lib.rs

+9

@@ -93,6 +93,7 @@ use gccjit::{CType, Context, OptimizationLevel};
 #[cfg(feature = "master")]
 use gccjit::{TargetInfo, Version};
 use rustc_ast::expand::allocator::AllocatorKind;
+use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
 use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule};
 use rustc_codegen_ssa::back::write::{
     CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryFn,
@@ -439,6 +440,14 @@ impl WriteBackendMethods for GccCodegenBackend {
     ) -> Result<ModuleCodegen<Self::Module>, FatalError> {
         back::write::link(cgcx, dcx, modules)
     }
+    fn autodiff(
+        _cgcx: &CodegenContext<Self>,
+        _module: &ModuleCodegen<Self::Module>,
+        _diff_fncs: Vec<AutoDiffItem>,
+        _config: &ModuleConfig,
+    ) -> Result<(), FatalError> {
+        unimplemented!()
+    }
 }
 
 /// This is the entrypoint for a hot plugged rustc_codegen_gccjit
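
The stub above follows a common pattern: once a required method is added to a shared backend trait, every backend must provide it, and those without support panic if it is ever reached. A minimal sketch of that pattern, with a hypothetical `Backend` trait and hypothetical `LlvmLike`/`GccLike` types standing in for the real `WriteBackendMethods` implementors:

```rust
/// Hypothetical, heavily reduced stand-in for `WriteBackendMethods`.
trait Backend {
    /// Returns how many items were differentiated.
    fn autodiff(diff_items: Vec<String>) -> Result<usize, String>;
}

struct LlvmLike;
struct GccLike;

impl Backend for LlvmLike {
    fn autodiff(diff_items: Vec<String>) -> Result<usize, String> {
        // A real backend would generate derivative code per item here.
        Ok(diff_items.len())
    }
}

impl Backend for GccLike {
    fn autodiff(_diff_items: Vec<String>) -> Result<usize, String> {
        // No autodiff support in this backend: reaching this is a bug
        // in the caller, so the stub panics instead of returning an error.
        unimplemented!()
    }
}
```

The design choice is that the stub keeps the trait uniform across backends while making accidental use fail loudly rather than silently producing undifferentiated code.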

compiler/rustc_codegen_llvm/messages.ftl

+4

@@ -1,3 +1,5 @@
+codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto
+
 codegen_llvm_copy_bitcode = failed to copy bitcode to object file: {$err}
 
 codegen_llvm_dynamic_linking_with_lto =
@@ -47,6 +49,8 @@ codegen_llvm_parse_bitcode_with_llvm_err = failed to parse bitcode for LTO modul
 codegen_llvm_parse_target_machine_config =
     failed to parse target machine config to target machine: {$error}
 
+codegen_llvm_prepare_autodiff = failed to prepare autodiff: src: {$src}, target: {$target}, {$error}
+codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare autodiff: {$llvm_err}, src: {$src}, target: {$target}, {$error}
 codegen_llvm_prepare_thin_lto_context = failed to prepare thin LTO context
 codegen_llvm_prepare_thin_lto_context_with_llvm_err = failed to prepare thin LTO context: {$llvm_err}
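
The `prepare_autodiff` message is filled from three values: the source name, the target name, and a free-form error string. As a rough sketch of how such an error could be produced when a function lookup fails, the code below replaces the LLVM module with a plain `HashSet` of function names; `PrepareAutoDiff` as a standalone struct and `check_item` are hypothetical names for illustration, and the real code goes through rustc's diagnostics machinery instead of `Display`.

```rust
use std::collections::HashSet;
use std::fmt;

/// Hypothetical plain-struct version of the error payload.
#[derive(Debug, PartialEq)]
struct PrepareAutoDiff {
    src: String,
    target: String,
    error: String,
}

impl fmt::Display for PrepareAutoDiff {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same wording as `codegen_llvm_prepare_autodiff`.
        write!(
            f,
            "failed to prepare autodiff: src: {}, target: {}, {}",
            self.src, self.target, self.error
        )
    }
}

/// Checks that both functions exist in the (mock) module.
fn check_item(
    module_fns: &HashSet<&str>,
    src: &str,
    target: &str,
) -> Result<(), PrepareAutoDiff> {
    let fail = |error: &str| PrepareAutoDiff {
        src: src.to_owned(),
        target: target.to_owned(),
        error: error.to_owned(),
    };
    if !module_fns.contains(src) {
        return Err(fail("could not find source function"));
    }
    if !module_fns.contains(target) {
        return Err(fail("could not find target function"));
    }
    Ok(())
}
```

Both names are included in every error so that a missing target can be traced back to the source function it was meant to differentiate.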

compiler/rustc_codegen_llvm/src/back/lto.rs

+8 -1

@@ -604,7 +604,14 @@ pub(crate) fn run_pass_manager(
     debug!("running the pass manager");
     let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO };
     let opt_level = config.opt_level.unwrap_or(config::OptLevel::No);
-    unsafe { write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) }?;
+
+    // If this rustc version was build with enzyme/autodiff enabled, and if users applied the
+    // `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
+    let first_run = true;
+    debug!("running llvm pm opt pipeline");
+    unsafe {
+        write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?;
+    }
     debug!("lto done");
     Ok(())
 }

compiler/rustc_codegen_llvm/src/back/write.rs

+128 -6

@@ -8,6 +8,7 @@ use libc::{c_char, c_int, c_void, size_t};
 use llvm::{
     LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols,
 };
+use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
 use rustc_codegen_ssa::back::link::ensure_removed;
 use rustc_codegen_ssa::back::versioned_llvm_target;
 use rustc_codegen_ssa::back::write::{
@@ -28,7 +29,7 @@ use rustc_session::config::{
 use rustc_span::symbol::sym;
 use rustc_span::{BytePos, InnerSpan, Pos, SpanData, SyntaxContext};
 use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo, TlsModel};
-use tracing::debug;
+use tracing::{debug, trace};
 
 use crate::back::lto::ThinBuffer;
 use crate::back::owned_target_machine::OwnedTargetMachine;
@@ -530,9 +531,38 @@ pub(crate) unsafe fn llvm_optimize(
     config: &ModuleConfig,
     opt_level: config::OptLevel,
     opt_stage: llvm::OptStage,
+    skip_size_increasing_opts: bool,
 ) -> Result<(), FatalError> {
-    let unroll_loops =
-        opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
+    // Enzyme:
+    // The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized
+    // source code. However, benchmarks show that optimizations increasing the code size
+    // tend to reduce AD performance. Therefore deactivate them before AD, then differentiate the code
+    // and finally re-optimize the module, now with all optimizations available.
+    // FIXME(ZuseZ4): In a future update we could figure out how to only optimize individual functions getting
+    // differentiated.
+
+    let unroll_loops;
+    let vectorize_slp;
+    let vectorize_loop;
+
+    // When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
+    // optimizations until after differentiation. FIXME(ZuseZ4): Before shipping on nightly,
+    // we should make this more granular, or at least check that the user has at least one autodiff
+    // call in their code, to justify altering the compilation pipeline.
+    if skip_size_increasing_opts && cfg!(llvm_enzyme) {
+        unroll_loops = false;
+        vectorize_slp = false;
+        vectorize_loop = false;
+    } else {
+        unroll_loops =
+            opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
+        vectorize_slp = config.vectorize_slp;
+        vectorize_loop = config.vectorize_loop;
+    }
+    trace!(
+        "Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}",
+        unroll_loops, vectorize_slp, vectorize_loop
+    );
     let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed();
     let pgo_gen_path = get_pgo_gen_path(config);
     let pgo_use_path = get_pgo_use_path(config);
@@ -596,8 +626,8 @@ pub(crate) unsafe fn llvm_optimize(
             using_thin_buffers,
             config.merge_functions,
             unroll_loops,
-            config.vectorize_slp,
-            config.vectorize_loop,
+            vectorize_slp,
+            vectorize_loop,
             config.no_builtins,
             config.emit_lifetime_markers,
             sanitizer_options.as_ref(),
@@ -619,6 +649,83 @@ pub(crate) unsafe fn llvm_optimize(
     result.into_result().map_err(|()| llvm_err(dcx, LlvmError::RunLlvmPasses))
 }
 
+pub(crate) fn differentiate(
+    module: &ModuleCodegen<ModuleLlvm>,
+    cgcx: &CodegenContext<LlvmCodegenBackend>,
+    diff_items: Vec<AutoDiffItem>,
+    config: &ModuleConfig,
+) -> Result<(), FatalError> {
+    for item in &diff_items {
+        trace!("{}", item);
+    }
+
+    let llmod = module.module_llvm.llmod();
+    let llcx = &module.module_llvm.llcx;
+    let diag_handler = cgcx.create_dcx();
+
+    // Before dumping the module, we want all the tt to become part of the module.
+    for item in diff_items.iter() {
+        let name = CString::new(item.source.clone()).unwrap();
+        let fn_def: Option<&llvm::Value> =
+            unsafe { llvm::LLVMGetNamedFunction(llmod, name.as_ptr()) };
+        let fn_def = match fn_def {
+            Some(x) => x,
+            None => {
+                return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
+                    src: item.source.clone(),
+                    target: item.target.clone(),
+                    error: "could not find source function".to_owned(),
+                }));
+            }
+        };
+        let target_name = CString::new(item.target.clone()).unwrap();
+        debug!("target name: {:?}", &target_name);
+        let fn_target: Option<&llvm::Value> =
+            unsafe { llvm::LLVMGetNamedFunction(llmod, target_name.as_ptr()) };
+        let fn_target = match fn_target {
+            Some(x) => x,
+            None => {
+                return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
+                    src: item.source.clone(),
+                    target: item.target.clone(),
+                    error: "could not find target function".to_owned(),
+                }));
+            }
+        };
+
+        crate::builder::generate_enzyme_call(llmod, llcx, fn_def, fn_target, item.attrs.clone());
+    }
+
+    // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
+
+    if let Some(opt_level) = config.opt_level {
+        let opt_stage = match cgcx.lto {
+            Lto::Fat => llvm::OptStage::PreLinkFatLTO,
+            Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
+            _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
+            _ => llvm::OptStage::PreLinkNoLTO,
+        };
+        // This is our second opt call, so now we run all opts,
+        // to make sure we get the best performance.
+        let skip_size_increasing_opts = false;
+        trace!("running Module Optimization after differentiation");
+        unsafe {
+            llvm_optimize(
+                cgcx,
+                diag_handler.handle(),
+                module,
+                config,
+                opt_level,
+                opt_stage,
+                skip_size_increasing_opts,
+            )?
+        };
+    }
+    trace!("done with differentiate()");
+
+    Ok(())
+}
+
 // Unsafe due to LLVM calls.
 pub(crate) unsafe fn optimize(
     cgcx: &CodegenContext<LlvmCodegenBackend>,
@@ -641,14 +748,29 @@ pub(crate) unsafe fn optimize(
         unsafe { llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()) };
     }
 
+    // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
+
     if let Some(opt_level) = config.opt_level {
         let opt_stage = match cgcx.lto {
             Lto::Fat => llvm::OptStage::PreLinkFatLTO,
             Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
             _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
             _ => llvm::OptStage::PreLinkNoLTO,
         };
-        return unsafe { llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) };
+
+        // If we know that we will later run AD, then we disable vectorization and loop unrolling
+        let skip_size_increasing_opts = cfg!(llvm_enzyme);
+        return unsafe {
+            llvm_optimize(
+                cgcx,
+                dcx,
+                module,
+                config,
+                opt_level,
+                opt_stage,
+                skip_size_increasing_opts,
+            )
+        };
     }
     Ok(())
 }
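
The flag selection inside `llvm_optimize` can be read as a pure function of the config, the opt level, and the new boolean. The sketch below isolates just that decision so the two opt runs can be compared side by side. `OptLevel`, `ModuleConfig`, and `SizeOpts` are simplified stand-ins, `select_size_opts` is a hypothetical name, and the `enzyme_enabled` parameter replaces the compile-time `cfg!(llvm_enzyme)` check so both branches can be exercised.

```rust
/// Simplified stand-in for `config::OptLevel` (only variants used here).
#[derive(Clone, Copy, Debug, PartialEq)]
enum OptLevel {
    Default,
    Size,
    SizeMin,
}

/// Simplified stand-in for the two `ModuleConfig` fields read here.
struct ModuleConfig {
    vectorize_slp: bool,
    vectorize_loop: bool,
}

/// The three size-increasing optimizations toggled by the diff.
#[derive(Debug, PartialEq)]
struct SizeOpts {
    unroll_loops: bool,
    vectorize_slp: bool,
    vectorize_loop: bool,
}

fn select_size_opts(
    config: &ModuleConfig,
    opt_level: OptLevel,
    skip_size_increasing_opts: bool,
    enzyme_enabled: bool, // stands in for `cfg!(llvm_enzyme)`
) -> SizeOpts {
    if skip_size_increasing_opts && enzyme_enabled {
        // First run before differentiation: postpone all three.
        SizeOpts { unroll_loops: false, vectorize_slp: false, vectorize_loop: false }
    } else {
        // Second run, or a compiler built without Enzyme: old behaviour.
        SizeOpts {
            unroll_loops: opt_level != OptLevel::Size && opt_level != OptLevel::SizeMin,
            vectorize_slp: config.vectorize_slp,
            vectorize_loop: config.vectorize_loop,
        }
    }
}
```

With this shape, the first run corresponds to `skip_size_increasing_opts = true` and the post-differentiation run in `differentiate` to `false`; without Enzyme the flag has no effect, which matches the `else` branch in the diff.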
