diff --git a/feature_transformer.py b/feature_transformer.py index 7084c1f4..6169e64e 100644 --- a/feature_transformer.py +++ b/feature_transformer.py @@ -367,6 +367,9 @@ def backward(ctx, grad_output): return None, None, weight_grad, bias_grad +dft_stream_0 = cp.cuda.Stream() +dft_stream_1 = cp.cuda.Stream() + class DoubleFeatureTransformerSliceFunction(autograd.Function): @staticmethod @@ -418,31 +421,33 @@ def forward(ctx, feature_indices_0, feature_values_0, feature_indices_1, feature max_active_features = feature_indices_0.shape[1] output_size = weight.shape[1] - output0 = torch.empty(batch_size, output_size, dtype=torch.float32, device=device, requires_grad=True) - output1 = torch.empty(batch_size, output_size, dtype=torch.float32, device=device, requires_grad=True) - kernel = make_feature_transformer_slice_forward_kernel(max_active_features, output_size) - kernel( - grid=(batch_size,), - args=( - feature_indices_0.data_ptr(), - feature_values_0.data_ptr(), - weight.data_ptr(), - bias.data_ptr(), - output0.data_ptr() + + with dft_stream_0: + output0 = torch.empty(batch_size, output_size, dtype=torch.float32, device=device, requires_grad=True) + kernel( + grid=(batch_size,), + args=( + feature_indices_0.data_ptr(), + feature_values_0.data_ptr(), + weight.data_ptr(), + bias.data_ptr(), + output0.data_ptr() + ) ) - ) - kernel( - grid=(batch_size,), - args=( - feature_indices_1.data_ptr(), - feature_values_1.data_ptr(), - weight.data_ptr(), - bias.data_ptr(), - output1.data_ptr() + with dft_stream_1: + output1 = torch.empty(batch_size, output_size, dtype=torch.float32, device=device, requires_grad=True) + kernel( + grid=(batch_size,), + args=( + feature_indices_1.data_ptr(), + feature_values_1.data_ptr(), + weight.data_ptr(), + bias.data_ptr(), + output1.data_ptr() + ) ) - ) return output0, output1 @@ -465,27 +470,31 @@ def backward(ctx, grad_output_0, grad_output_1): bias_grad = torch.zeros(output_size, dtype=torch.float32, device=device) kernel = make_feature_transformer_slice_backward_kernel(max_active_features, output_size) - kernel( - grid=(batch_size,), - args=( - feature_indices_0.data_ptr(), - feature_values_0.data_ptr(), - weight_grad.data_ptr(), - bias_grad.data_ptr(), - grad_output_0.data_ptr() + + # We can do it in two independent streams because all the writes in the kernel are atomic + with dft_stream_0: + kernel( + grid=(batch_size,), + args=( + feature_indices_0.data_ptr(), + feature_values_0.data_ptr(), + weight_grad.data_ptr(), + bias_grad.data_ptr(), + grad_output_0.data_ptr() + ) ) - ) - kernel( - grid=(batch_size,), - args=( - feature_indices_1.data_ptr(), - feature_values_1.data_ptr(), - weight_grad.data_ptr(), - bias_grad.data_ptr(), - grad_output_1.data_ptr() + with dft_stream_1: + kernel( + grid=(batch_size,), + args=( + feature_indices_1.data_ptr(), + feature_values_1.data_ptr(), + weight_grad.data_ptr(), + bias_grad.data_ptr(), + grad_output_1.data_ptr() + ) ) - ) return None, None, None, None, weight_grad, bias_grad