forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMetalShaderLibrary.h
More file actions
217 lines (199 loc) · 7.52 KB
/
MetalShaderLibrary.h
File metadata and controls
217 lines (199 loc) · 7.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
#pragma once
#ifdef __OBJC__
// Objective-C++ translation units: pull in Metal and alias the handle types
// to the corresponding Metal protocol-qualified `id` types.
#include <Metal/Metal.h>
using MTLLibrary_t = id<MTLLibrary>;
using MTLFunction_t = id<MTLFunction>;
using MTLComputePipelineState_t = id<MTLComputePipelineState>;
using MTLComputeCommandEncoder_t = id<MTLComputeCommandEncoder>;
#else
// Every other translation unit: Metal.h is not included, so the same names
// are provided as opaque stand-ins (`void` / `void*`) that let the
// declarations below compile.
using MTLCompileOptions = void;
using MTLLibrary_t = void*;
using MTLFunction_t = void*;
using MTLComputePipelineState_t = void*;
using MTLComputeCommandEncoder_t = void*;
#endif
#include <c10/core/Scalar.h>
#include <c10/util/OptionalArrayRef.h>
#include <functional>
#include <optional>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
// Forward declarations of `at::TensorBase` and `at::TensorIteratorBase`.
// This header only names these types by reference in function signatures,
// so their full definitions are not needed here.
namespace at {
class TensorBase;
struct TensorIteratorBase;
} // namespace at
namespace at::native::mps {
namespace detail {
// Detection primary: selected whenever `T::size_type*` is not a well-formed
// type (no nested `size_type`, or one a pointer cannot be formed to).
template <typename T, typename = void>
struct has_size_type_impl : std::false_type {};

// Selected when `T` has a nested `size_type` and `T::size_type*` is valid.
template <typename T>
struct has_size_type_impl<T, std::void_t<typename T::size_type*>>
    : std::true_type {};

// Trait: `value` is true iff `T` declares a nested `size_type` member type.
// Used below to tell container-like arguments apart from plain values.
template <typename T>
class has_size_type {
 public:
  static constexpr bool value = has_size_type_impl<T>::value;
};

// Variable-template shorthand for `has_size_type<T>::value`.
template <typename T>
constexpr bool has_size_type_v = has_size_type<T>::value;
} // namespace detail
// Returns `gpuAddress` of respective `id<MTLBuffer>` plus storage offset
// for the given tensor, as an untyped pointer.
// Declaration only; the definition is not part of this header.
void* get_tensor_gpu_address(const at::TensorBase&);
// Wraps a compute pipeline state together with its Metal function and
// exposes helpers for binding arguments and dispatching work.
// All non-template members are only declared here; their definitions are
// not part of this header.
class MetalKernelFunction {
 public:
  MetalKernelFunction(MTLComputePipelineState_t cps_, MTLFunction_t f_);
  ~MetalKernelFunction();
  // Non-copyable. The class stores raw handles and has a user-provided
  // destructor, so both copy operations are deleted: a copy (or an
  // implicitly generated copy assignment) would leave two objects holding
  // the same handles.
  MetalKernelFunction(const MetalKernelFunction&) = delete;
  MetalKernelFunction& operator=(const MetalKernelFunction&) = delete;

  // Shader properties
  uint64_t getMaxThreadsPerThreadgroup() const;
  uint64_t getThreadExecutionWidth() const;
  uint64_t getStaticThreadGroupMemoryLength() const;

  // Runs `f`; the methods below are meant to be invoked from inside it.
  void runCommandBlock(std::function<void(void)> f);

  // Methods below should be called from runCommandBlock function
  void startEncoding();
  // Binds tensor `t` as the argument at index `idx`.
  void setArg(unsigned idx, const at::TensorBase& t);
  // Binds `size` bytes starting at `ptr` as the argument at index `idx`.
  void setArg(unsigned idx, const void* ptr, uint64_t size);
  void setErrorBufferIndex(unsigned idx);

  // By-value overload for scalars and plain structs: enabled for integral
  // types, `float`, and trivially copyable class types that do NOT look like
  // a container (no nested `size_type`). Forwards the address of the local
  // copy and `sizeof(T)` to the pointer/size overload.
  template <
      typename T,
      typename = std::enable_if_t<
          std::is_integral_v<T> || std::is_same_v<T, float> ||
          (std::is_class_v<T> && std::is_trivially_copyable_v<T> &&
           !detail::has_size_type_v<T>)>>
  inline void setArg(unsigned idx, const T val) {
    setArg(idx, &val, sizeof(T));
  }

  // Container overload: enabled for types with a nested `size_type`.
  // Forwards `values.data()` and the total byte count
  // (`size() * sizeof(value_type)`) to the pointer/size overload, so the
  // container must provide `data()`, `size()` and `value_type`.
  template <
      typename Container,
      typename = std::enable_if_t<detail::has_size_type_v<Container>>>
  inline void setArg(unsigned idx, const Container& values) {
    setArg(
        idx,
        values.data(),
        values.size() * sizeof(typename Container::value_type));
  }

  // Dispatch over a single extent `length`; `groupSize` is optional.
  // NOTE(review): the group size used when `groupSize` is empty is decided
  // by the out-of-line definition — confirm there.
  void dispatch(
      uint64_t length,
      std::optional<uint64_t> groupSize = std::nullopt);
  // Dispatch over a list of extents, with an optional list of group sizes.
  void dispatch(
      c10::ArrayRef<uint64_t> length,
      c10::OptionalArrayRef<uint64_t> groupSize = std::nullopt);

 private:
  MTLComputePipelineState_t cps;
  MTLFunction_t func;
  // Starts out null; presumably set by startEncoding() — verify in the
  // implementation file.
  MTLComputeCommandEncoder_t encoder = nullptr;
};
// Holds Metal shader source text and provides lookup of pipeline states,
// Metal functions and kernel wrappers by function name, plus helpers that
// run unary/binary/ternary kernels over a TensorIteratorBase.
// Members without an inline body are only declared here; their definitions
// are not part of this header.
class MetalShaderLibrary {
 public:
  // The constructors store the shader source; `nparams` and
  // `compile_options` keep their in-class defaults (0 / nullptr) unless
  // supplied.
  MetalShaderLibrary(std::string src) : shaderSource(std::move(src)) {}
  MetalShaderLibrary(std::string src, unsigned nparams_)
      : shaderSource(std::move(src)), nparams(nparams_) {}
  MetalShaderLibrary(
      std::string src,
      unsigned nparams_,
      MTLCompileOptions* compile_options_)
      : shaderSource(std::move(src)),
        nparams(nparams_),
        compile_options(compile_options_) {}
  // Non-copyable: both copy operations are deleted (the object holds raw
  // library handles and a cache of unique_ptr-owned kernel functions).
  MetalShaderLibrary(const MetalShaderLibrary&) = delete;
  MetalShaderLibrary& operator=(const MetalShaderLibrary&) = delete;
  virtual ~MetalShaderLibrary();

  std::vector<std::string> getFunctionNames();
  std::shared_ptr<MetalKernelFunction> getKernelFunction(
      const std::string& name);
  // Returns a raw pointer to the kernel function for use in C APIs
  MetalKernelFunction* getCachedKernelFunctionPtr(const std::string& name);

  // Pipeline state for `fname`, looked up in the library returned by the
  // parameterless getLibrary().
  inline MTLComputePipelineState_t getPipelineStateForFunc(
      const std::string& fname) {
    return getLibraryPipelineState(getLibrary(), fname).first;
  }
  // Same lookup, but in the library returned by getLibrary(params).
  MTLComputePipelineState_t getPipelineStateForFunc(
      const std::string& fname,
      const std::initializer_list<std::string>& params) {
    return getLibraryPipelineState(getLibrary(params), fname).first;
  }
  // Metal function object for `fname` (second element of the same
  // pipeline-state/function pair used above).
  inline MTLFunction_t getMTLFunction(const std::string& fname) {
    return getLibraryPipelineState(getLibrary(), fname).second;
  }
  MTLFunction_t getMTLFunction(
      const std::string& fname,
      const std::initializer_list<std::string>& params) {
    return getLibraryPipelineState(getLibrary(params), fname).second;
  }

  static MetalShaderLibrary& getBundledLibrary();

  void exec_unary_kernel(
      TensorIteratorBase& iter,
      const std::string& name,
      const std::optional<c10::Scalar> alpha = std::nullopt,
      const std::optional<c10::ScalarType> scalar_arg_type = std::nullopt);
  // `ilp_threshold` lets callers tune when the dense ILP variant kicks in
  // (numel >= threshold). When unspecified, the default is the same 256K
  // crossover used by the unary path, but only for floating-point output;
  // non-float outputs get UINT32_MAX (i.e. ILP off by default). Comparison
  // and other ops with different memory-bandwidth profiles can override.
  // `natural_output_dtype` is the dtype the kernel naturally produces (its
  // registered DTYPEO). Defaults to `iter.common_dtype()`, which is right for
  // arithmetic kernels where DTYPEO==compute precision. Comparison kernels
  // produce bool and must pass `kBool` so the output-cast fallback allocates
  // the right temp.
  void exec_binary_kernel(
      TensorIteratorBase& iter,
      const std::string& name,
      const std::optional<c10::Scalar> alpha = std::nullopt,
      const std::optional<c10::ScalarType> scalar_arg_type = std::nullopt,
      const std::optional<c10::ScalarType> natural_output_dtype = std::nullopt,
      const std::optional<uint32_t> ilp_threshold = std::nullopt);
  void exec_ternary_kernel(TensorIteratorBase& iter, const std::string& name);
  // Variants taking an extra by-value `params` object and the name of its
  // type. Declared only; the template definitions are not in this header.
  template <typename T>
  void exec_unary_kernel_with_params(
      TensorIteratorBase& iter,
      const std::string& name,
      T params,
      const std::string& params_type_name);
  template <typename T>
  void exec_binary_kernel_with_params(
      TensorIteratorBase& iter,
      const std::string& name,
      T params,
      const std::string& params_type_name);

 protected:
  // Overridable library accessors; the inline lookups above go through them.
  virtual MTLLibrary_t getLibrary();
  virtual MTLLibrary_t getLibrary(
      const std::initializer_list<std::string>& params);
  MTLLibrary_t library = nullptr;

 private:
  std::pair<MTLComputePipelineState_t, MTLFunction_t> getLibraryPipelineState(
      MTLLibrary_t lib,
      const std::string& fname);
  MTLLibrary_t compileLibrary(const std::string& src);
  std::string shaderSource;
  unsigned nparams = 0;
  MTLCompileOptions* compile_options = nullptr;
  // NOTE(review): presumably caches of compiled libraries and of
  // pipeline-state/function pairs, keyed by string — confirm the key scheme
  // in the implementation file.
  std::unordered_map<std::string, MTLLibrary_t> libMap;
  std::unordered_map<
      std::string,
      std::pair<MTLComputePipelineState_t, MTLFunction_t>>
      cplMap;
  // Cache for kernel functions returned by getCachedKernelFunctionPtr
  std::unordered_map<std::string, std::unique_ptr<MetalKernelFunction>>
      kernelCache;
};
class DynamicMetalShaderLibrary : public MetalShaderLibrary {
public:
DynamicMetalShaderLibrary(const std::string& src) : MetalShaderLibrary(src) {
// Compile right away
getLibrary();
}
~DynamicMetalShaderLibrary() override;
};
// Shader library variant constructed from something other than shader source
// text. Both constructors are explicit and defined outside this header.
class PrecompiledMetalShaderLibrary : public MetalShaderLibrary {
 public:
  // NOTE(review): presumably `path` names a precompiled Metal library on
  // disk — confirm in the implementation file.
  explicit PrecompiledMetalShaderLibrary(const std::string& path);

  // NOTE(review): presumably `data` holds the bytes of a precompiled Metal
  // library — confirm in the implementation file.
  explicit PrecompiledMetalShaderLibrary(std::vector<uint8_t> data);

  ~PrecompiledMetalShaderLibrary() override;
};
} // namespace at::native::mps