forked from huggingface/kernels
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrelu.mm
More file actions
104 lines (83 loc) · 4.09 KB
/
relu.mm
File metadata and controls
104 lines (83 loc) · 4.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#include <ATen/mps/MPSStream.h>
#include <torch/torch.h>
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
// Include the auto-generated header with embedded metallib
#ifdef EMBEDDED_METALLIB_HEADER
#include EMBEDDED_METALLIB_HEADER
#else
#error "EMBEDDED_METALLIB_HEADER not defined"
#endif
// Recovers the Metal buffer backing an MPS tensor's storage.
//
// NOTE(review): this assumes that on the MPS backend tensor.storage().data()
// is actually the id<MTLBuffer> object pointer itself (not a pointer to the
// element bytes) — this is the convention used by PyTorch's own MPS extension
// samples; confirm against the ATen MPS allocator if upgrading PyTorch.
// __builtin_bit_cast reinterprets the pointer without any ARC retain/release
// traffic, so the returned reference is borrowed: it is only valid while the
// tensor keeps its storage alive.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}
// Encodes and commits the element-wise ReLU Metal kernel on the current
// PyTorch MPS stream.
//
// @param input  Contiguous float/half MPS tensor to read (validated by relu()).
// @param output Tensor of the same shape/dtype that receives the result.
// @return A reference to `output` for call chaining.
//
// Throws (via TORCH_CHECK) if the stream, library, function, or pipeline
// state cannot be created.
torch::Tensor &dispatchReluKernel(torch::Tensor const &input,
                                  torch::Tensor &output) {
  @autoreleasepool {
    at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
    TORCH_CHECK(stream, "Failed to get MPS stream");
    id<MTLDevice> device = stream->device();

    // Use NSUInteger, not int: numel() returns int64_t, so a 32-bit int
    // would overflow for tensors with >2^31 elements, and MTLSize fields
    // (as well as maxTotalThreadsPerThreadgroup below) are NSUInteger —
    // avoiding a signed/unsigned comparison.
    const NSUInteger numThreads = static_cast<NSUInteger>(input.numel());

    // Empty tensor: nothing to compute, and dispatching a zero-sized grid
    // is not a valid Metal dispatch.
    if (numThreads == 0) {
      return output;
    }

    // Load the embedded Metal library from memory (metallib bytes are
    // compiled into the binary via EMBEDDED_METALLIB_HEADER).
    NSError *error = nil;
    id<MTLLibrary> customKernelLibrary =
        EMBEDDED_METALLIB_NAMESPACE::createLibrary(device, &error);
    TORCH_CHECK(customKernelLibrary,
                "Failed to create Metal library from embedded data: ",
                error.localizedDescription.UTF8String);

    // The kernel is specialized per dtype; relu() has already rejected
    // anything other than float/half.
    std::string kernel_name =
        std::string("relu_forward_kernel_") +
        (input.scalar_type() == torch::kFloat ? "float" : "half");
    id<MTLFunction> customReluFunction = [customKernelLibrary
        newFunctionWithName:[NSString
                                stringWithUTF8String:kernel_name.c_str()]];
    TORCH_CHECK(customReluFunction,
                "Failed to create function state object for ",
                kernel_name.c_str());

    id<MTLComputePipelineState> reluPSO =
        [device newComputePipelineStateWithFunction:customReluFunction
                                              error:&error];
    TORCH_CHECK(reluPSO, error.localizedDescription.UTF8String);

    // Use stream->commandEncoder() to properly integrate with PyTorch's
    // MPS encoder lifecycle (kernel coalescing). Creating encoders directly
    // via [commandBuffer computeCommandEncoder] bypasses this and crashes
    // when the kernel is called twice in sequence.
    dispatch_sync(stream->queue(), ^() {
      id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
      TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");
      [computeEncoder setComputePipelineState:reluPSO];
      // Honor storage_offset so views into a larger storage address the
      // correct bytes within the backing MTLBuffer.
      [computeEncoder setBuffer:getMTLBufferStorage(input)
                         offset:input.storage_offset() * input.element_size()
                        atIndex:0];
      [computeEncoder setBuffer:getMTLBufferStorage(output)
                         offset:output.storage_offset() * output.element_size()
                        atIndex:1];

      // One thread per element, capped at the pipeline's hardware limit.
      MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
      NSUInteger threadGroupSize = reluPSO.maxTotalThreadsPerThreadgroup;
      if (threadGroupSize > numThreads) {
        threadGroupSize = numThreads;
      }
      MTLSize threadgroupSize = MTLSizeMake(threadGroupSize, 1, 1);
      // dispatchThreads: supports non-uniform threadgroups, so the grid
      // need not be a multiple of the threadgroup size.
      [computeEncoder dispatchThreads:gridSize
                threadsPerThreadgroup:threadgroupSize];
    });
    stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE);
  }
  return output;
}
// Public entry point: out = relu(input), computed on the MPS device.
//
// Validates devices, dtypes, shapes, and layout before handing off to
// dispatchReluKernel(). All failures throw via TORCH_CHECK.
//
// @param out   Destination tensor; same shape, dtype, and device as input.
// @param input Source tensor; must be a contiguous float/half MPS tensor.
void relu(torch::Tensor &out, torch::Tensor const &input) {
  TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor");
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  // The kernel indexes both buffers linearly (one thread per element), so
  // a non-contiguous output would be silently written in the wrong order.
  TORCH_CHECK(out.is_contiguous(), "output must be contiguous");
  TORCH_CHECK(input.scalar_type() == torch::kFloat ||
                  input.scalar_type() == torch::kHalf,
              "Unsupported data type: ", input.scalar_type());
  TORCH_CHECK(input.sizes() == out.sizes(),
              "Tensors must have the same shape. Got input shape: ",
              input.sizes(), " and output shape: ", out.sizes());
  TORCH_CHECK(input.scalar_type() == out.scalar_type(),
              "Tensors must have the same data type. Got input dtype: ",
              input.scalar_type(), " and output dtype: ", out.scalar_type());
  TORCH_CHECK(input.device() == out.device(),
              "Tensors must be on the same device. Got input device: ",
              input.device(), " and output device: ", out.device());
  dispatchReluKernel(input, out);
}