@@ -284,7 +284,7 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
284284 type code = {
285285 traced_store : Low_level .traced_store ;
286286 ptx : Nvrtc .compile_to_ptx_result ;
287- kparams : (string * kparam_source ) list ;
287+ params : (string * param_source ) list ;
288288 bindings : Indexing .unit_bindings ;
289289 name : string ;
290290 }
@@ -294,7 +294,7 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
294294 traced_stores : Low_level .traced_store option array ;
295295 ptx : Nvrtc .compile_to_ptx_result ;
296296 bindings : Indexing .unit_bindings ;
297- kparams_and_names : ((string * kparam_source ) list * string ) option array ;
297+ params_and_names : ((string * param_source ) list * string ) option array ;
298298 }
299299 [@@ deriving sexp_of ]
300300
@@ -839,7 +839,7 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
839839 let procs = [| lowered |]
840840 end )) in
841841 let idx_params = Indexing. bound_symbols bindings in
842- let kparams , proc_doc = Syntax. compile_proc ~name idx_params lowered in
842+ let params , proc_doc = Syntax. compile_proc ~name idx_params lowered in
843843 let cuda_includes =
844844 {| #include < cuda_fp16.h>
845845#include < cuda_bf16.h>
@@ -861,21 +861,21 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
861861 ~proc_doc
862862 in
863863 let ptx = cuda_to_ptx ~name source in
864- { traced_store; ptx; kparams ; bindings; name }
864+ { traced_store; ptx; params ; bindings; name }
865865
866866 let % diagn2_sexp compile_batch ~names bindings lowereds =
867867 let module Syntax = C_syntax. C_syntax (Cuda_syntax_config (struct
868868 let procs = Array. filter_opt lowereds
869869 end )) in
870870 let idx_params = Indexing. bound_symbols bindings in
871- let kparams_and_docs =
871+ let params_and_docs =
872872 Array. map2_exn names lowereds
873873 ~f:
874874 (Option. map2 ~f: (fun name lowered ->
875- let kparams , doc = Syntax. compile_proc ~name idx_params lowered in
876- ((kparams , name), doc)))
875+ let params , doc = Syntax. compile_proc ~name idx_params lowered in
876+ ((params , name), doc)))
877877 in
878- let all_proc_docs = List. filter_map (Array. to_list kparams_and_docs ) ~f: (Option. map ~f: snd) in
878+ let all_proc_docs = List. filter_map (Array. to_list params_and_docs ) ~f: (Option. map ~f: snd) in
879879 let final_doc = PPrint. (separate hardline all_proc_docs) in
880880 let cuda_includes =
881881 {| #include < cuda_fp16.h>
@@ -905,8 +905,8 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
905905 in
906906 let ptx = cuda_to_ptx ~name source in
907907 let traced_stores = Array. map lowereds ~f: (Option. map ~f: (fun l -> l.Low_level. traced_store)) in
908- let kparams_and_names = Array. map kparams_and_docs ~f: (Option. map ~f: fst) in
909- { traced_stores; ptx; kparams_and_names ; bindings }
908+ let params_and_names = Array. map params_and_docs ~f: (Option. map ~f: fst) in
909+ { traced_stores; ptx; params_and_names ; bindings }
910910
911911 let get_global_run_id =
912912 let next_id = ref 0 in
@@ -915,7 +915,7 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
915915 if ! next_id < 0 then next_id := 0 ;
916916 ! next_id
917917
918- let link_proc ~prior_context ~name ~(kparams : (string * kparam_source ) list ) ~ctx_arrays
918+ let link_proc ~prior_context ~name ~(params : (string * param_source ) list ) ~ctx_arrays
919919 lowered_bindings run_module =
920920 let func = Cu.Module. get_function run_module ~name in
921921 let stream = prior_context.stream in
@@ -929,13 +929,13 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
929929 " on" ,
930930 stream_name,
931931 (log_id : int ),
932- (kparams : (string * kparam_source ) list )];
932+ (params : (string * param_source ) list )];
933933 let module S = Cu. Stream in
934934 let args : S.kernel_param list =
935935 (* TODO: should we prohibit or warn about local-only tensors that are in
936936 prior_context.ctx_arrays? *)
937- List. map kparams ~f: (function
938- | _name , Kparam_ptr tn ->
937+ List. map params ~f: (function
938+ | _name , Param_ptr tn ->
939939 let arr = Option. value_exn ~here: [% here] @@ Map. find ctx_arrays tn in
940940 S. Tensor arr
941941 | _name , Log_file_name -> S. Int log_id
@@ -985,7 +985,7 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
985985 List. map idx_params ~f: (fun s -> (s, ref 0 ))
986986 in
987987 let task =
988- link_proc ~prior_context ~name: code.name ~kparams : code.kparams ~ctx_arrays lowered_bindings
988+ link_proc ~prior_context ~name: code.name ~params : code.params ~ctx_arrays lowered_bindings
989989 run_module
990990 in
991991 (lowered_bindings, task)
@@ -1000,11 +1000,11 @@ module Fresh : Ir.Backend_impl.Lowered_backend = struct
10001000 let run_module = Cu.Module. load_data_ex code_batch.ptx (run_options () ) in
10011001 prior_context.stream.device.dev.set_builtins_in run_module;
10021002 let procs =
1003- Array. mapi code_batch.kparams_and_names ~f: (fun i pns ->
1003+ Array. mapi code_batch.params_and_names ~f: (fun i pns ->
10041004 Option. value ~default: None
1005- @@ Option. map2 pns ctx_arrays.(i) ~f: (fun (kparams , name ) ctx_arrays ->
1005+ @@ Option. map2 pns ctx_arrays.(i) ~f: (fun (params , name ) ctx_arrays ->
10061006 let task =
1007- link_proc ~prior_context ~name ~kparams ~ctx_arrays lowered_bindings run_module
1007+ link_proc ~prior_context ~name ~params ~ctx_arrays lowered_bindings run_module
10081008 in
10091009 Some task))
10101010 in
0 commit comments