Skip to content

Commit 80aec58

Browse files
lukstafi and claude committed
fix: restore cross-device synchronization in device_to_device copies
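When the source and destination are on different devices, wait for the source stream's writes to complete before scheduling the copy on the destination stream. Without this, asynchronous CUDA/Metal copies could read stale data. Addresses review feedback on PR #12.

This commit also renames kparam_source / Kparam_ptr / kparams to param_source / Param_ptr / params in backend_intf.ml and in the CUDA and Metal backends.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>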
1 parent 692d8c9 commit 80aec58

4 files changed

Lines changed: 42 additions & 35 deletions

File tree

arrayjit/lib/backend_intf.ml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -36,10 +36,10 @@ end
3636

3737
type merge_buffer_use = No | Copy [@@deriving sexp_of]
3838

39-
type kparam_source =
39+
type param_source =
4040
| Log_file_name
4141
| Merge_buffer
42-
| Kparam_ptr of Tnode.t
42+
| Param_ptr of Tnode.t
4343
| Static_idx of Indexing.static_symbol
4444
[@@deriving sexp_of]
4545

arrayjit/lib/backends.ml

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,10 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
8888
match Map.find src.ctx_arrays tn with
8989
| None -> false
9090
| Some s_arr -> (
91+
(* For cross-device copies, wait for the source stream's writes to complete
92+
before the destination stream reads the data. *)
93+
if not same_device then
94+
Backend.will_wait_for dst (Backend.all_work src.stream);
9195
match into_merge_buffer with
9296
| No -> (
9397
match Map.find dst.ctx_arrays tn with
@@ -122,6 +126,9 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
122126
("init_from_device: tensor node " ^ Tn.debug_name tn
123127
^ " already on same stream, for stream " ^ Backend.get_name src.stream)
124128
else (
129+
(* For cross-device copies, wait for source writes to complete. *)
130+
if not same_device then
131+
Backend.will_wait_for dst (Backend.all_work src.stream);
125132
match Map.find dst.ctx_arrays tn with
126133
| Some _ ->
127134
raise

arrayjit/lib/cuda_backend.ml

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -284,7 +284,7 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
284284
type code = {
285285
traced_store : Low_level.traced_store;
286286
ptx : Nvrtc.compile_to_ptx_result;
287-
kparams : (string * kparam_source) list;
287+
params : (string * param_source) list;
288288
bindings : Indexing.unit_bindings;
289289
name : string;
290290
}
@@ -294,7 +294,7 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
294294
traced_stores : Low_level.traced_store option array;
295295
ptx : Nvrtc.compile_to_ptx_result;
296296
bindings : Indexing.unit_bindings;
297-
kparams_and_names : ((string * kparam_source) list * string) option array;
297+
params_and_names : ((string * param_source) list * string) option array;
298298
}
299299
[@@deriving sexp_of]
300300

@@ -839,7 +839,7 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
839839
let procs = [| lowered |]
840840
end)) in
841841
let idx_params = Indexing.bound_symbols bindings in
842-
let kparams, proc_doc = Syntax.compile_proc ~name idx_params lowered in
842+
let params, proc_doc = Syntax.compile_proc ~name idx_params lowered in
843843
let cuda_includes =
844844
{|#include <cuda_fp16.h>
845845
#include <cuda_bf16.h>
@@ -861,21 +861,21 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
861861
~proc_doc
862862
in
863863
let ptx = cuda_to_ptx ~name source in
864-
{ traced_store; ptx; kparams; bindings; name }
864+
{ traced_store; ptx; params; bindings; name }
865865

866866
let%diagn2_sexp compile_batch ~names bindings lowereds =
867867
let module Syntax = C_syntax.C_syntax (Cuda_syntax_config (struct
868868
let procs = Array.filter_opt lowereds
869869
end)) in
870870
let idx_params = Indexing.bound_symbols bindings in
871-
let kparams_and_docs =
871+
let params_and_docs =
872872
Array.map2_exn names lowereds
873873
~f:
874874
(Option.map2 ~f:(fun name lowered ->
875-
let kparams, doc = Syntax.compile_proc ~name idx_params lowered in
876-
((kparams, name), doc)))
875+
let params, doc = Syntax.compile_proc ~name idx_params lowered in
876+
((params, name), doc)))
877877
in
878-
let all_proc_docs = List.filter_map (Array.to_list kparams_and_docs) ~f:(Option.map ~f:snd) in
878+
let all_proc_docs = List.filter_map (Array.to_list params_and_docs) ~f:(Option.map ~f:snd) in
879879
let final_doc = PPrint.(separate hardline all_proc_docs) in
880880
let cuda_includes =
881881
{|#include <cuda_fp16.h>
@@ -905,8 +905,8 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
905905
in
906906
let ptx = cuda_to_ptx ~name source in
907907
let traced_stores = Array.map lowereds ~f:(Option.map ~f:(fun l -> l.Low_level.traced_store)) in
908-
let kparams_and_names = Array.map kparams_and_docs ~f:(Option.map ~f:fst) in
909-
{ traced_stores; ptx; kparams_and_names; bindings }
908+
let params_and_names = Array.map params_and_docs ~f:(Option.map ~f:fst) in
909+
{ traced_stores; ptx; params_and_names; bindings }
910910

911911
let get_global_run_id =
912912
let next_id = ref 0 in
@@ -915,7 +915,7 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
915915
if !next_id < 0 then next_id := 0;
916916
!next_id
917917

918-
let link_proc ~prior_context ~name ~(kparams : (string * kparam_source) list) ~ctx_arrays
918+
let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx_arrays
919919
lowered_bindings run_module =
920920
let func = Cu.Module.get_function run_module ~name in
921921
let stream = prior_context.stream in
@@ -929,13 +929,13 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
929929
"on",
930930
stream_name,
931931
(log_id : int),
932-
(kparams : (string * kparam_source) list)];
932+
(params : (string * param_source) list)];
933933
let module S = Cu.Stream in
934934
let args : S.kernel_param list =
935935
(* TODO: should we prohibit or warn about local-only tensors that are in
936936
prior_context.ctx_arrays? *)
937-
List.map kparams ~f:(function
938-
| _name, Kparam_ptr tn ->
937+
List.map params ~f:(function
938+
| _name, Param_ptr tn ->
939939
let arr = Option.value_exn ~here:[%here] @@ Map.find ctx_arrays tn in
940940
S.Tensor arr
941941
| _name, Log_file_name -> S.Int log_id
@@ -985,7 +985,7 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
985985
List.map idx_params ~f:(fun s -> (s, ref 0))
986986
in
987987
let task =
988-
link_proc ~prior_context ~name:code.name ~kparams:code.kparams ~ctx_arrays lowered_bindings
988+
link_proc ~prior_context ~name:code.name ~params:code.params ~ctx_arrays lowered_bindings
989989
run_module
990990
in
991991
(lowered_bindings, task)
@@ -1000,11 +1000,11 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
10001000
let run_module = Cu.Module.load_data_ex code_batch.ptx (run_options ()) in
10011001
prior_context.stream.device.dev.set_builtins_in run_module;
10021002
let procs =
1003-
Array.mapi code_batch.kparams_and_names ~f:(fun i pns ->
1003+
Array.mapi code_batch.params_and_names ~f:(fun i pns ->
10041004
Option.value ~default:None
1005-
@@ Option.map2 pns ctx_arrays.(i) ~f:(fun (kparams, name) ctx_arrays ->
1005+
@@ Option.map2 pns ctx_arrays.(i) ~f:(fun (params, name) ctx_arrays ->
10061006
let task =
1007-
link_proc ~prior_context ~name ~kparams ~ctx_arrays lowered_bindings run_module
1007+
link_proc ~prior_context ~name ~params ~ctx_arrays lowered_bindings run_module
10081008
in
10091009
Some task))
10101010
in

arrayjit/lib/metal_backend.ml

Lines changed: 15 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -387,7 +387,7 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
387387
metal_source : string; (* Store source, compile during link if not already compiled *)
388388
compiled_code : Me.Library.t option array; (* Store compiled code per device *)
389389
func_name : string;
390-
kparams : (string * kparam_source) list;
390+
params : (string * param_source) list;
391391
bindings : Indexing.unit_bindings;
392392
traced_store : Low_level.traced_store;
393393
}
@@ -396,7 +396,7 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
396396
type code_batch = {
397397
metal_source : string; (* Store combined source *)
398398
compiled_code : Me.Library.t option array; (* Store compiled code per device *)
399-
funcs : (string * (string * kparam_source) list) option array; (* func_name * kparams *)
399+
funcs : (string * (string * param_source) list) option array; (* func_name * params *)
400400
bindings : Indexing.unit_bindings;
401401
traced_stores : Low_level.traced_store option array;
402402
}
@@ -671,7 +671,7 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
671671
end)) in
672672
let idx_params = Indexing.bound_symbols bindings in
673673
(* Add Metal address space qualifiers *)
674-
let kparams, proc_doc = Syntax.compile_proc ~name idx_params lowered in
674+
let params, proc_doc = Syntax.compile_proc ~name idx_params lowered in
675675
let metal_includes = {|#include <metal_stdlib>
676676
using namespace metal;|} in
677677
let source =
@@ -683,7 +683,7 @@ using namespace metal;|} in
683683
compiled_code = Array.create ~len:num_devs None;
684684
(* One slot per device *)
685685
func_name = name;
686-
kparams;
686+
params;
687687
bindings;
688688
traced_store = lowered.traced_store;
689689
}
@@ -697,8 +697,8 @@ using namespace metal;|} in
697697
Array.map2_exn names lowereds
698698
~f:
699699
(Option.map2 ~f:(fun name lowered ->
700-
let kparams, doc = Syntax.compile_proc ~name idx_params lowered in
701-
((name, kparams), doc)))
700+
let params, doc = Syntax.compile_proc ~name idx_params lowered in
701+
((name, params), doc)))
702702
in
703703
let all_proc_docs = List.filter_map (Array.to_list funcs_and_docs) ~f:(Option.map ~f:snd) in
704704
let final_doc = PPrint.(separate hardline all_proc_docs) in
@@ -720,7 +720,7 @@ using namespace metal;|} in
720720
}
721721

722722
let%debug4_sexp link_proc ~prior_context ~library ~func_name
723-
~(kparams : (string * kparam_source) list) ~lowered_bindings ~(ctx_arrays : buffer_ptr Tn.t_map)
723+
~(params : (string * param_source) list) ~lowered_bindings ~(ctx_arrays : buffer_ptr Tn.t_map)
724724
: Task.t =
725725
let stream = prior_context.stream in
726726
let device = stream.device.dev in
@@ -740,12 +740,12 @@ using namespace metal;|} in
740740
Me.ComputeCommandEncoder.set_compute_pipeline_state encoder pso;
741741

742742
(* Set arguments *)
743-
List.iteri kparams ~f:(fun index (_p_name, p_source) ->
743+
List.iteri params ~f:(fun index (_p_name, p_source) ->
744744
match p_source with
745-
| Kparam_ptr tn when Map.mem ctx_arrays tn ->
745+
| Param_ptr tn when Map.mem ctx_arrays tn ->
746746
let buffer = Map.find_exn ctx_arrays tn in
747747
Me.ComputeCommandEncoder.set_buffer encoder ~index buffer
748-
| Kparam_ptr tn when Tn.known_constant tn && Tn.is_hosted_force tn 48 ->
748+
| Param_ptr tn when Tn.known_constant tn && Tn.is_hosted_force tn 48 ->
749749
let buffer =
750750
Hashtbl.find_or_add stream.device.device_buffer_cache tn ~default:(fun () ->
751751
get_buffer_for_ptr device ~size_in_bytes:(Lazy.force tn.size_in_bytes)
@@ -754,9 +754,9 @@ using namespace metal;|} in
754754
@@ Lazy.force tn.array)
755755
in
756756
Me.ComputeCommandEncoder.set_buffer encoder ~index buffer
757-
| Kparam_ptr tn ->
757+
| Param_ptr tn ->
758758
failwith
759-
[%string "Kparam_ptr %{Tn.debug_name tn} not found in ctx_arrays for %{func_name}"]
759+
[%string "Param_ptr %{Tn.debug_name tn} not found in ctx_arrays for %{func_name}"]
760760
| Static_idx s ->
761761
let value = !(Indexing.find_exn lowered_bindings s) in
762762
let size = Ctypes.sizeof Ctypes.int in
@@ -803,7 +803,7 @@ using namespace metal;|} in
803803
List.map (Indexing.bound_symbols code.bindings) ~f:(fun s -> (s, ref 0))
804804
in
805805
let task =
806-
link_proc ~prior_context ~library ~func_name:code.func_name ~kparams:code.kparams
806+
link_proc ~prior_context ~library ~func_name:code.func_name ~params:code.params
807807
~lowered_bindings ~ctx_arrays
808808
in
809809
(lowered_bindings, task)
@@ -817,9 +817,9 @@ using namespace metal;|} in
817817

818818
let tasks =
819819
Array.mapi code_batch.funcs ~f:(fun i func_opt ->
820-
Option.bind func_opt ~f:(fun (func_name, kparams) ->
820+
Option.bind func_opt ~f:(fun (func_name, params) ->
821821
Option.map ctx_arrays_opts.(i) ~f:(fun ctx_arrays ->
822-
link_proc ~prior_context ~library ~func_name ~kparams ~lowered_bindings ~ctx_arrays)))
822+
link_proc ~prior_context ~library ~func_name ~params ~lowered_bindings ~ctx_arrays)))
823823
in
824824
(lowered_bindings, tasks)
825825
end

0 commit comments

Comments
 (0)