@@ -57,8 +57,8 @@ def get_weights(self) -> np.ndarray[Any, Any]:
5757 rowmaj_weights = np .empty ((self .num_outputs , self .num_inputs ), dtype = self .dtype .numpy ())
5858
5959 layout = CoopVecMatrixLayout .training_optimal
60- self .weights .device .coopvec_convert_matrix_host (
61- weights_np , rowmaj_weights , src_layout = layout
60+ self .weights .device .convert_coop_vec_matrix (
61+ dst = rowmaj_weights , src = weights_np , src_layout = layout
6262 )
6363
6464 return rowmaj_weights
@@ -119,11 +119,13 @@ def model_init(self, module: Module, input_type: SlangType):
119119
120120 if self .use_coopvec :
121121 layout = CoopVecMatrixLayout .training_optimal
122- desc = device .coopvec_create_matrix_desc (fan_out , fan_in , layout , self .dtype .sgl (), 0 )
122+ desc = device .create_coop_vec_matrix_desc (
123+ rows = fan_out , cols = fan_in , layout = layout , element_type = self .dtype .sgl ()
124+ )
123125 weight_count = desc .size // self .dtype .size ()
124126
125127 params_np = np .zeros ((weight_count ,), dtype = self .dtype .numpy ())
126- device .coopvec_convert_matrix_host ( weights_np , params_np , dst_layout = layout )
128+ device .convert_coop_vec_matrix ( dst = params_np , src = weights_np , dst_layout = layout )
127129
128130 self .weights = Tensor .empty (device , (weight_count ,), str (self .dtype ))
129131 self .weights .storage .copy_from_numpy (params_np )
0 commit comments