Skip to content

Commit 9a9b270

Browse files
authored
Use refactored coop vec API (#39)
1 parent 2cdfe37 commit 9a9b270

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

experiments/neuralnetwork/neuralnetworks/components/LinearLayer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ def get_weights(self) -> np.ndarray[Any, Any]:
5757
rowmaj_weights = np.empty((self.num_outputs, self.num_inputs), dtype=self.dtype.numpy())
5858

5959
layout = CoopVecMatrixLayout.training_optimal
60-
self.weights.device.coopvec_convert_matrix_host(
61-
weights_np, rowmaj_weights, src_layout=layout
60+
self.weights.device.convert_coop_vec_matrix(
61+
dst=rowmaj_weights, src=weights_np, src_layout=layout
6262
)
6363

6464
return rowmaj_weights
@@ -119,11 +119,13 @@ def model_init(self, module: Module, input_type: SlangType):
119119

120120
if self.use_coopvec:
121121
layout = CoopVecMatrixLayout.training_optimal
122-
desc = device.coopvec_create_matrix_desc(fan_out, fan_in, layout, self.dtype.sgl(), 0)
122+
desc = device.create_coop_vec_matrix_desc(
123+
rows=fan_out, cols=fan_in, layout=layout, element_type=self.dtype.sgl()
124+
)
123125
weight_count = desc.size // self.dtype.size()
124126

125127
params_np = np.zeros((weight_count,), dtype=self.dtype.numpy())
126-
device.coopvec_convert_matrix_host(weights_np, params_np, dst_layout=layout)
128+
device.convert_coop_vec_matrix(dst=params_np, src=weights_np, dst_layout=layout)
127129

128130
self.weights = Tensor.empty(device, (weight_count,), str(self.dtype))
129131
self.weights.storage.copy_from_numpy(params_np)

0 commit comments

Comments
 (0)