Skip to content

[Lux] Multithreading for JIT code #31

Open
@mratsim

Description

@mratsim

This issue tracks multithreading solutions for JIT code.

Description

At the moment, Lux only targets Nim, so it can use OpenMP for threading.

In the future, Lux will probably add a JIT solution via LLVM IR. This will reduce code size, enable code generation specialized for particular sizes, and allow targeting new architectures that would otherwise require complex C extensions.

Example: CUDA introduces `__global__` and magic variables such as `blockIdx.x * blockDim.x + threadIdx.x`, which require some gymnastics for Nim to generate code without throwing "undefined" errors.

Unfortunately, when doing JIT on the CPU we lose OpenMP support, because OpenMP is implemented in Clang and lowered to OpenMP runtime library calls in the LLVM IR (see the example below). So we need an alternative solution.

Solutions to explore

  1. Reuse Nim's threadpool library.

  2. Implement a threading library from scratch

  3. Wrap a C/C++ library (note that C++ will cause issues with cpuinfo on some compilers, because cpuinfo uses C99).

  4. Wait for the OpenMP IR to be merged into LLVM; see:

OpenMP code transformation

from https://stackoverflow.com/questions/52285368/how-does-llvm-translate-openmp-multi-threaded-code-with-runtime-library-calls

This OpenMP code

/* Reference input: a dynamically scheduled parallel loop with a (+) reduction
 * on r. Kept exactly as written because the LLVM IR further down appears to be
 * Clang's output for this very snippet (e.g. `i` is declared at function scope,
 * which is why the IR forwards a pointer to %i into the outlined function). */
extern float foo( void );
int main () {
    int i;
    float r = 0.0;
    /* Each thread accumulates foo() into a private copy of r; the private
     * copies are combined into the shared r when the loop finishes. */
    #pragma omp parallel for schedule(dynamic) reduction(+:r)
    for ( i = 0; i < 10; i ++ ) {
        r += foo();
    }
    /* No explicit return: since C99, reaching the end of main returns 0. */
}

is transformed into

extern float foo( void );
int main () {
    static int zero = 0;
    auto int gtid;
    auto float r = 0.0;
    __kmpc_begin( & loc3, 0 );
    // The gtid is not actually required in this example so could be omitted;
    // We show its initialization here because it is often required for calls into
    // the runtime and should be locally cached like this.
    gtid = __kmpc_global thread num( & loc3 );
    __kmpc_fork call( & loc7, 1, main_7_parallel_3, & r );
    __kmpc_end( & loc0 );
    return 0;
}

/* Thread-private storage for the reduction variable r (one float). */
struct main_10_reduction_t_5 { float r_10_rpr; };

/* Lock passed to __kmpc_reduce_nowait / __kmpc_end_reduce_nowait. */
static kmp_critical_name lck = { 0 };
static ident_t loc10; // loc10.flags should contain KMP_IDENT_ATOMIC_REDUCE bit set
                      // if compiler has generated an atomic reduction.
/* Outlined body of the parallel region: the microtask handed to
 * __kmpc_fork_call. gtid and btid are thread ids supplied by the runtime;
 * r_7_shp points at the shared r in main. main_10_reduce_5 (the combiner
 * passed to __kmpc_reduce_nowait) and loc7 are not defined in this excerpt. */
void main_7_parallel_3( int *gtid, int *btid, float *r_7_shp ) {
    auto int i_7_pr;
    auto int lower, upper, liter, incr;
    auto struct main_10_reduction_t_5 reduce;
    reduce.r_10_rpr = 0.F;
    liter = 0;
    // Register the loop with the dispatcher: schedule kind 35 (presumably
    // kmp_sch_dynamic_chunked, matching schedule(dynamic)), iterations 0..9
    // inclusive, stride 1, chunk size 1.
    __kmpc_dispatch_init_4( & loc7,*gtid, 35, 0, 9, 1, 1 );
    // Keep requesting chunks [lower, upper] until the dispatcher returns 0.
    // The inner loop ignores incr, which is only correct because the stride is 1.
    while ( __kmpc_dispatch_next_4( & loc7, *gtid, & liter, & lower, & upper, & incr
      ) ) {
        for( i_7_pr = lower; upper >= i_7_pr; i_7_pr ++ )
          reduce.r_10_rpr += foo();
    }
    // Combine this thread's partial sum (1 variable, 4 bytes) into *r_7_shp;
    // the return value selects how the combination is performed.
    switch( __kmpc_reduce_nowait( & loc10, *gtid, 1, 4, & reduce, main_10_reduce_5, &lck ) ) {
        case 1:
           // Plain (non-atomic) add, then finish with __kmpc_end_reduce_nowait.
           *r_7_shp += reduce.r_10_rpr;
           __kmpc_end_reduce_nowait( & loc10, *gtid, & lck );
           break;
        case 2:
           // Atomic add into the shared variable; no end call on this path.
           __kmpc_atomic_float4_add( & loc10, *gtid, r_7_shp, reduce.r_10_rpr );
           break;
        // NOTE(review): any other result (presumably 0) does nothing here; this
        // thread's contribution is assumed to have been combined by the runtime
        // via main_10_reduce_5 — confirm against the libomp documentation.
        default:;
    }
}

The corresponding LLVM IR that Clang emits for the original OpenMP code:

[...]
; main(): initialise r to 0.0, then hand the whole parallel loop to the OpenMP
; runtime with a single __kmpc_fork_call.
; Function Attrs: nounwind uwtable
define dso_local i32 @main() local_unnamed_addr #0 {
entry:
  %i = alloca i32, align 4
  %r = alloca float, align 4
  %0 = bitcast i32* %i to i8*
  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %0) #4
  %1 = bitcast float* %r to i8*
  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %1) #4
  store float 0.000000e+00, float* %r, align 4, !tbaa !2
  ; Fork the team: argc = 2 shared pointers (%i, %r) are forwarded to
  ; @.omp_outlined., which is cast to the runtime's variadic microtask type.
  call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* nonnull @0, i32 2, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i32*, float*)* @.omp_outlined. to void (i32*, i32*, ...)*), i32* nonnull %i, float* nonnull %r) #4
  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %1) #4
  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %0) #4
  ret i32 0
}
[...]

; Outlined body of the parallel region, run by every thread of the team.
; Arguments: global thread id, bound thread id, and pointers to the shared i
; and r. %i is marked readnone: the loop uses its own induction variable.
; Function Attrs: norecurse nounwind uwtable
define internal void @.omp_outlined.(i32* noalias nocapture readonly %.global_tid., i32* noalias nocapture readnone %.bound_tid., i32* nocapture readnone dereferenceable(4) %i, float* nocapture dereferenceable(4) %r) #2 {
entry:
  ; Chunk bounds / stride / last-iteration flag filled in by the dispatcher,
  ; %r1 = this thread's private copy of r, and the reduction pointer list.
  %.omp.lb = alloca i32, align 4
  %.omp.ub = alloca i32, align 4
  %.omp.stride = alloca i32, align 4
  %.omp.is_last = alloca i32, align 4
  %r1 = alloca float, align 4
  %.omp.reduction.red_list = alloca [1 x i8*], align 8
  %0 = bitcast i32* %.omp.lb to i8*
  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %0) #4
  store i32 0, i32* %.omp.lb, align 4, !tbaa !6
  %1 = bitcast i32* %.omp.ub to i8*
  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %1) #4
  store i32 9, i32* %.omp.ub, align 4, !tbaa !6
  %2 = bitcast i32* %.omp.stride to i8*
  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %2) #4
  store i32 1, i32* %.omp.stride, align 4, !tbaa !6
  %3 = bitcast i32* %.omp.is_last to i8*
  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %3) #4
  store i32 0, i32* %.omp.is_last, align 4, !tbaa !6
  %4 = bitcast float* %r1 to i8*
  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %4) #4
  store float 0.000000e+00, float* %r1, align 4, !tbaa !2
  %5 = load i32, i32* %.global_tid., align 4, !tbaa !6
  ; Register the loop: schedule kind 35 (presumably kmp_sch_dynamic_chunked,
  ; for schedule(dynamic)), iterations 0..9 inclusive, stride 1, chunk 1.
  tail call void @__kmpc_dispatch_init_4(%struct.ident_t* nonnull @0, i32 %5, i32 35, i32 0, i32 9, i32 1, i32 1) #4
  ; Request the first chunk; a result of 0 means there is no work for this thread.
  %6 = call i32 @__kmpc_dispatch_next_4(%struct.ident_t* nonnull @0, i32 %5, i32* nonnull %.omp.is_last, i32* nonnull %.omp.lb, i32* nonnull %.omp.ub, i32* nonnull %.omp.stride) #4
  %tobool14 = icmp eq i32 %6, 0
  br i1 %tobool14, label %omp.dispatch.end, label %omp.dispatch.body

omp.dispatch.cond.loopexit:                       ; preds = %omp.inner.for.body, %omp.dispatch.body
  ; Current chunk done: request the next one (0 = loop exhausted).
  %7 = call i32 @__kmpc_dispatch_next_4(%struct.ident_t* nonnull @0, i32 %5, i32* nonnull %.omp.is_last, i32* nonnull %.omp.lb, i32* nonnull %.omp.ub, i32* nonnull %.omp.stride) #4
  %tobool = icmp eq i32 %7, 0
  br i1 %tobool, label %omp.dispatch.end, label %omp.dispatch.body

omp.dispatch.body:                                ; preds = %entry, %omp.dispatch.cond.loopexit
  ; Skip empty chunks (lb > ub).
  %8 = load i32, i32* %.omp.lb, align 4, !tbaa !6
  %9 = load i32, i32* %.omp.ub, align 4, !tbaa !6, !llvm.mem.parallel_loop_access !8
  %cmp12 = icmp sgt i32 %8, %9
  br i1 %cmp12, label %omp.dispatch.cond.loopexit, label %omp.inner.for.body

omp.inner.for.body:                               ; preds = %omp.dispatch.body, %omp.inner.for.body
  ; r1 += foo() for each iteration in [lb, ub]; ub is reloaded every iteration.
  %.omp.iv.013 = phi i32 [ %add4, %omp.inner.for.body ], [ %8, %omp.dispatch.body ]
  %call = call float @foo() #4, !llvm.mem.parallel_loop_access !8
  %10 = load float, float* %r1, align 4, !tbaa !2, !llvm.mem.parallel_loop_access !8
  %add3 = fadd float %call, %10
  store float %add3, float* %r1, align 4, !tbaa !2, !llvm.mem.parallel_loop_access !8
  %add4 = add nsw i32 %.omp.iv.013, 1
  %11 = load i32, i32* %.omp.ub, align 4, !tbaa !6, !llvm.mem.parallel_loop_access !8
  %cmp = icmp slt i32 %.omp.iv.013, %11
  br i1 %cmp, label %omp.inner.for.body, label %omp.dispatch.cond.loopexit, !llvm.loop !8

omp.dispatch.end:                                 ; preds = %omp.dispatch.cond.loopexit, %entry
  ; Build the one-element reduction list {&r1} and let the runtime choose
  ; how this thread's partial sum is combined into the shared r.
  %12 = bitcast [1 x i8*]* %.omp.reduction.red_list to float**
  store float* %r1, float** %12, align 8
  %13 = bitcast [1 x i8*]* %.omp.reduction.red_list to i8*
  %14 = call i32 @__kmpc_reduce_nowait(%struct.ident_t* nonnull @1, i32 %5, i32 1, i64 8, i8* nonnull %13, void (i8*, i8*)* nonnull @.omp.reduction.reduction_func, [8 x i32]* nonnull @.gomp_critical_user_.reduction.var) #4
  switch i32 %14, label %.omp.reduction.default [
    i32 1, label %.omp.reduction.case1
    i32 2, label %.omp.reduction.case2
  ]

.omp.reduction.case1:                             ; preds = %omp.dispatch.end
  ; Case 1: plain (non-atomic) load/add/store into the shared r, then
  ; __kmpc_end_reduce_nowait.
  %15 = load float, float* %r, align 4, !tbaa !2
  %16 = load float, float* %r1, align 4, !tbaa !2
  %add5 = fadd float %15, %16
  store float %add5, float* %r, align 4, !tbaa !2
  call void @__kmpc_end_reduce_nowait(%struct.ident_t* nonnull @1, i32 %5, [8 x i32]* nonnull @.gomp_critical_user_.reduction.var) #4
  br label %.omp.reduction.default

.omp.reduction.case2:                             ; preds = %omp.dispatch.end
  ; Case 2: lock-free add into r. The float is handled as its i32 bit pattern
  ; so that it can be updated with a monotonic compare-and-swap loop.
  %17 = bitcast float* %r to i32*
  %atomic-load = load atomic i32, i32* %17 monotonic, align 4, !tbaa !2
  %18 = load float, float* %r1, align 4, !tbaa !2
  br label %atomic_cont

atomic_cont:                                      ; preds = %atomic_cont, %.omp.reduction.case2
  ; Retry with the freshly observed value until the cmpxchg succeeds.
  %19 = phi i32 [ %atomic-load, %.omp.reduction.case2 ], [ %23, %atomic_cont ]
  %20 = bitcast i32 %19 to float
  %add7 = fadd float %18, %20
  %21 = bitcast float %add7 to i32
  %22 = cmpxchg i32* %17, i32 %19, i32 %21 monotonic monotonic
  %23 = extractvalue { i32, i1 } %22, 0
  %24 = extractvalue { i32, i1 } %22, 1
  br i1 %24, label %.omp.reduction.default, label %atomic_cont

.omp.reduction.default:                           ; preds = %atomic_cont, %.omp.reduction.case1, %omp.dispatch.end
  ; Common exit (also taken directly for any other __kmpc_reduce_nowait result):
  ; end the lifetimes of the private locals and return.
  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %4) #4
  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %3) #4
  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %2) #4
  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %1) #4
  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %0) #4
  ret void
}
[...]

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions