We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent c57013a commit d97e790Copy full SHA for d97e790
1 file changed
openequivariance_extjax/src/libjax_tp_jit.cpp
@@ -19,24 +19,16 @@ using json = json11::Json;
19
#include <cuda_runtime.h>
20
21
#include "backend/backend_cuda.hpp"
22
- #include "group_mm_cuda.hpp"
23
using JITKernel = CUJITKernel;
24
using GPU_Allocator = CUDA_Allocator;
25
-
26
- template<typename T>
27
- using GroupMM = GroupMMCUDA<T>;
28
using stream_t = cudaStream_t;
29
#endif
30
31
#ifdef HIP_BACKEND
32
#include "backend/backend_hip.hpp"
33
- #include "group_mm_hip.hpp"
34
using JITKernel = HIPJITKernel;
35
using GPU_Allocator = HIP_Allocator;
36
37
38
- using GroupMM = GroupMMHIP<T>;
39
- using stream_t = hipStream_t;
+ using stream_t = hipStream_t;
40
41
42
#include "tensorproducts.hpp"
0 commit comments