@@ -8,6 +8,7 @@ use libc::{c_char, c_int, c_void, size_t};
 use llvm::{
     LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols,
 };
+use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
 use rustc_codegen_ssa::back::link::ensure_removed;
 use rustc_codegen_ssa::back::versioned_llvm_target;
 use rustc_codegen_ssa::back::write::{
@@ -28,7 +29,7 @@ use rustc_session::config::{
 use rustc_span::InnerSpan;
 use rustc_span::symbol::sym;
 use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo, TlsModel};
-use tracing::debug;
+use tracing::{debug, trace};
 
 use crate::back::lto::ThinBuffer;
 use crate::back::owned_target_machine::OwnedTargetMachine;
@@ -517,9 +518,38 @@ pub(crate) unsafe fn llvm_optimize(
     config: &ModuleConfig,
     opt_level: config::OptLevel,
     opt_stage: llvm::OptStage,
+    skip_size_increasing_opts: bool,
 ) -> Result<(), FatalError> {
-    let unroll_loops =
-        opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
+    // Enzyme:
+    // The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized
+    // source code. However, benchmarks show that optimizations increasing the code size
+    // tend to reduce AD performance. Therefore deactivate them before AD, then differentiate the code
+    // and finally re-optimize the module, now with all optimizations available.
+    // FIXME(ZuseZ4): In a future update we could figure out how to only optimize individual functions getting
+    // differentiated.
+
+    let unroll_loops;
+    let vectorize_slp;
+    let vectorize_loop;
+
+    // When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
+    // optimizations until after differentiation. FIXME(ZuseZ4): Before shipping on nightly,
+    // we should make this more granular, or at least check that the user has at least one autodiff
+    // call in their code, to justify altering the compilation pipeline.
+    if skip_size_increasing_opts && cfg!(llvm_enzyme) {
+        unroll_loops = false;
+        vectorize_slp = false;
+        vectorize_loop = false;
+    } else {
+        unroll_loops =
+            opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
+        vectorize_slp = config.vectorize_slp;
+        vectorize_loop = config.vectorize_loop;
+    }
+    trace!(
+        "Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}",
+        unroll_loops, vectorize_slp, vectorize_loop
+    );
     let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed();
     let pgo_gen_path = get_pgo_gen_path(config);
     let pgo_use_path = get_pgo_use_path(config);
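
Note: the hunk above only chooses flag values; the flags are consumed further down when the pass pipeline is built. As a self-contained illustration of that selection logic (`OptLevel` here is a simplified stand-in for `rustc_session::config::OptLevel`, and the boolean parameters stand in for `cfg!(llvm_enzyme)` and the `ModuleConfig` fields; none of this is patch code):

```rust
#[derive(PartialEq)]
enum OptLevel { No, Less, Default, Aggressive, Size, SizeMin }

/// Returns `(unroll_loops, vectorize_slp, vectorize_loop)`.
/// Size-increasing passes are disabled only when BOTH the caller requests it
/// AND rustc itself was built with Enzyme support.
fn pass_flags(
    skip_size_increasing_opts: bool, // caller intends to run autodiff later
    built_with_enzyme: bool,         // stand-in for cfg!(llvm_enzyme)
    opt_level: OptLevel,
    cfg_slp: bool,                   // config.vectorize_slp
    cfg_loop: bool,                  // config.vectorize_loop
) -> (bool, bool, bool) {
    if skip_size_increasing_opts && built_with_enzyme {
        (false, false, false)
    } else {
        // Same default as before the patch: unroll unless optimizing for size.
        let unroll = opt_level != OptLevel::Size && opt_level != OptLevel::SizeMin;
        (unroll, cfg_slp, cfg_loop)
    }
}
```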
@@ -583,8 +613,8 @@ pub(crate) unsafe fn llvm_optimize(
         using_thin_buffers,
         config.merge_functions,
         unroll_loops,
-        config.vectorize_slp,
-        config.vectorize_loop,
+        vectorize_slp,
+        vectorize_loop,
         config.no_builtins,
         config.emit_lifetime_markers,
         sanitizer_options.as_ref(),
@@ -606,6 +636,83 @@ pub(crate) unsafe fn llvm_optimize(
     result.into_result().map_err(|()| llvm_err(dcx, LlvmError::RunLlvmPasses))
 }
 
+pub(crate) fn differentiate(
+    module: &ModuleCodegen<ModuleLlvm>,
+    cgcx: &CodegenContext<LlvmCodegenBackend>,
+    diff_items: Vec<AutoDiffItem>,
+    config: &ModuleConfig,
+) -> Result<(), FatalError> {
+    for item in &diff_items {
+        trace!("{}", item);
+    }
+
+    let llmod = module.module_llvm.llmod();
+    let llcx = &module.module_llvm.llcx;
+    let diag_handler = cgcx.create_dcx();
+
+    // Before dumping the module, we want all the tt to become part of the module.
+    for item in diff_items.iter() {
+        let name = CString::new(item.source.clone()).unwrap();
+        let fn_def: Option<&llvm::Value> =
+            unsafe { llvm::LLVMGetNamedFunction(llmod, name.as_ptr()) };
+        let fn_def = match fn_def {
+            Some(x) => x,
+            None => {
+                return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
+                    src: item.source.clone(),
+                    target: item.target.clone(),
+                    error: "could not find source function".to_owned(),
+                }));
+            }
+        };
+        let target_name = CString::new(item.target.clone()).unwrap();
+        debug!("target name: {:?}", &target_name);
+        let fn_target: Option<&llvm::Value> =
+            unsafe { llvm::LLVMGetNamedFunction(llmod, target_name.as_ptr()) };
+        let fn_target = match fn_target {
+            Some(x) => x,
+            None => {
+                return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
+                    src: item.source.clone(),
+                    target: item.target.clone(),
+                    error: "could not find target function".to_owned(),
+                }));
+            }
+        };
+
+        crate::builder::generate_enzyme_call(llmod, llcx, fn_def, fn_target, item.attrs.clone());
+    }
+
+    // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
+
+    if let Some(opt_level) = config.opt_level {
+        let opt_stage = match cgcx.lto {
+            Lto::Fat => llvm::OptStage::PreLinkFatLTO,
+            Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
+            _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
+            _ => llvm::OptStage::PreLinkNoLTO,
+        };
+        // This is our second opt call, so now we run all opts,
+        // to make sure we get the best performance.
+        let skip_size_increasing_opts = false;
+        trace!("running Module Optimization after differentiation");
+        unsafe {
+            llvm_optimize(
+                cgcx,
+                diag_handler.handle(),
+                module,
+                config,
+                opt_level,
+                opt_stage,
+                skip_size_increasing_opts,
+            )?
+        };
+    }
+    trace!("done with differentiate()");
+
+    Ok(())
+}
+
 // Unsafe due to LLVM calls.
 pub(crate) unsafe fn optimize(
     cgcx: &CodegenContext<LlvmCodegenBackend>,
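
For orientation: each `AutoDiffItem` names a `source` function and a user-declared `target` placeholder whose body Enzyme synthesizes. The hunk's core pattern is "look up both symbols by name, fail loudly if either is missing, then rewrite". A minimal sketch of that pattern with LLVM replaced by a plain map (all types below are hypothetical stand-ins, not rustc or LLVM APIs):

```rust
use std::collections::HashMap;

// Hypothetical stand-ins for AutoDiffItem and an LLVM module's symbol table.
struct Item { source: String, target: String }
struct Module { functions: HashMap<String, String> }

fn prepare(module: &Module, items: &[Item]) -> Result<(), String> {
    for item in items {
        // Both symbols must already exist in the module, mirroring the two
        // LLVMGetNamedFunction lookups and the PrepareAutoDiff errors above.
        let _src = module.functions.get(&item.source)
            .ok_or_else(|| format!("could not find source function `{}`", item.source))?;
        let _tgt = module.functions.get(&item.target)
            .ok_or_else(|| format!("could not find target function `{}`", item.target))?;
        // The real code now calls crate::builder::generate_enzyme_call, which
        // emits the Enzyme invocation that fills in the target's body.
    }
    Ok(())
}
```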
@@ -628,14 +735,29 @@ pub(crate) unsafe fn optimize(
         unsafe { llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()) };
     }
 
+    // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
+
     if let Some(opt_level) = config.opt_level {
         let opt_stage = match cgcx.lto {
             Lto::Fat => llvm::OptStage::PreLinkFatLTO,
             Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
             _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
             _ => llvm::OptStage::PreLinkNoLTO,
         };
-        return unsafe { llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) };
+
+        // If we know that we will later run AD, then we disable vectorization and loop unrolling
+        let skip_size_increasing_opts = cfg!(llvm_enzyme);
+        return unsafe {
+            llvm_optimize(
+                cgcx,
+                dcx,
+                module,
+                config,
+                opt_level,
+                opt_stage,
+                skip_size_increasing_opts,
+            )
+        };
     }
     Ok(())
 }
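
Taken together, the patch splits optimization around differentiation: `optimize` runs first with `skip_size_increasing_opts = cfg!(llvm_enzyme)` so the IR handed to Enzyme stays compact, `differentiate` then generates the derivatives, and its internal `llvm_optimize` call re-runs the full pipeline with the flag forced to `false`. At the source level, the `diff_items` originate from the experimental `#[autodiff]` attribute of the same PR series; the sketch below is approximate (nightly-only, and the exact argument syntax and generated signature may differ):

```rust
#![feature(autodiff)]
use std::autodiff::autodiff;

// `square` becomes the AutoDiffItem `source`; `d_square` is the `target`
// whose body Enzyme synthesizes during `differentiate`. The signature that
// gets generated for `d_square` depends on the mode/activity arguments.
#[autodiff(d_square, Reverse, Active, Active)]
fn square(x: f64) -> f64 {
    x * x
}
```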