Skip to content

Commit 4bb55d2

Browse files
committed
Define all the pointer-like address structs outside storage type
1 parent ec9ba8e commit 4bb55d2

4 files changed

Lines changed: 155 additions & 129 deletions

File tree

source/standard-modules/neural/bindless-storage.slang

Lines changed: 112 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,115 @@
11
implementing neural;
22

3+
/**
4+
Bindless address type with pointer-like semantics.
5+
Wraps a buffer handle and base index to provide array-like access.
6+
*/
7+
public struct BindlessAddress<T> : IPointerLikeAddress<T>
8+
where T : __BuiltinFloatingPointType
9+
where T.Differential == T
10+
{
11+
public typealias Differential = BindlessAddress<T.Differential>;
12+
13+
internal RWStructuredBuffer<T>.Handle handle;
14+
internal uint baseIndex;
15+
16+
public __init(RWStructuredBuffer<T>.Handle handle)
17+
{
18+
this.handle = handle;
19+
this.baseIndex = 0;
20+
}
21+
22+
public __subscript(uint index)->T
23+
{
24+
[nonmutating]
25+
get { return handle[baseIndex + index]; }
26+
27+
[mutating]
28+
set { handle[baseIndex + index] = newValue; }
29+
}
30+
31+
[require(cuda_glsl_hlsl_metal_spirv, sm_6_6)]
32+
public void atomicAdd(uint index, T value)
33+
{
34+
__atomic_add(handle[baseIndex + index], value);
35+
}
36+
37+
public This getOffset(int elements)
38+
{
39+
uint newBaseIndex = baseIndex + elements;
40+
41+
This address = This(handle);
42+
address.baseIndex = newBaseIndex;
43+
return address;
44+
}
45+
}
46+
47+
public struct PointerAddress<T> : IPointerLikeAddress<T>
48+
where T : __BuiltinFloatingPointType
49+
where T.Differential == T
50+
{
51+
public typealias Differential = PointerAddress<T.Differential>;
52+
53+
T* ptr;
54+
55+
public __init(T* ptr)
56+
{
57+
this.ptr = ptr;
58+
}
59+
60+
public __subscript(uint index)->T
61+
{
62+
[nonmutating]
63+
get { return ptr[index]; }
64+
65+
[mutating]
66+
set { ptr[index] = newValue; }
67+
}
68+
69+
public This getOffset(int elements)
70+
{
71+
return This(ptr + elements);
72+
}
73+
74+
[require(cuda_glsl_hlsl_metal_spirv, sm_6_6)]
75+
public void atomicAdd(uint index, T value)
76+
{
77+
__atomic_add(ptr[index], value);
78+
}
79+
}
80+
81+
// We currently don't support UserPointer as an `IDifferentiablePtrType`, the issue is tracked in
82+
// https://github.com/shader-slang/slang/issues/8834.
83+
// So we define an internal extension for now, once we can resolve the issue, we can make it public.
84+
internal extension<T> Ptr<T, Access.ReadWrite, AddressSpace.Device> : IPointerLikeAddress<T>
85+
where T : __BuiltinFloatingPointType
86+
where T.Differential == T
87+
{
88+
internal typealias Differential = Ptr<T.Differential, Access.ReadWrite, AddressSpace.Device>;
89+
90+
internal __init(Ptr<T, Access.ReadWrite, AddressSpace.Device> ptr)
91+
{
92+
this = ptr;
93+
}
94+
95+
internal __subscript(uint index)->T
96+
{
97+
[nonmutating]
98+
get { return this[index]; }
99+
100+
[mutating]
101+
set { this[index] = newValue; }
102+
}
103+
104+
internal This getOffset(int elements)
105+
{
106+
return This(this + elements);
107+
}
108+
109+
[require(cuda_glsl_hlsl_metal_spirv, sm_6_6)]
110+
internal void atomicAdd(uint index, T value) {__atomic_add(this[index], value);}
111+
}
112+
3113
/**
4114
Bindless buffer storage implementation using buffer handles.
5115
Provides pointer-like addressing through buffer handles, enabling more flexible
@@ -14,49 +124,7 @@ public struct BindlessBufferStorage<T> : IStorage<T>
14124
where T : __BuiltinFloatingPointType
15125
where T.Differential == T
16126
{
17-
/**
18-
Bindless address type with pointer-like semantics.
19-
Wraps a buffer handle and base index to provide array-like access.
20-
*/
21-
public struct BindlessAddress : IPointerLikeAddress<T>
22-
{
23-
public typealias Differential =
24-
BindlessBufferStorage<T.Differential>.BindlessAddress;
25-
26-
internal RWStructuredBuffer<T>.Handle handle;
27-
internal uint baseIndex;
28-
29-
public __init(RWStructuredBuffer<T>.Handle handle)
30-
{
31-
this.handle = handle;
32-
this.baseIndex = 0;
33-
}
34-
35-
public __subscript(uint index)->T
36-
{
37-
[nonmutating]
38-
get { return handle[baseIndex + index]; }
39-
40-
[mutating]
41-
set { handle[baseIndex + index] = newValue; }
42-
}
43-
44-
[require(cuda_glsl_hlsl_metal_spirv, sm_6_6)]
45-
public void atomicAdd(uint index, T value)
46-
{
47-
__atomic_add(handle[baseIndex + index], value);
48-
}
49-
50-
public Address getOffset(int elements)
51-
{
52-
uint newBaseIndex = baseIndex + elements;
53-
54-
Address address = Address(handle);
55-
address.baseIndex = newBaseIndex;
56-
return address;
57-
}
58-
}
59-
public typealias Address = BindlessAddress;
127+
public typealias Address = BindlessAddress<T>;
60128
public typealias Differential = BindlessBufferStorage<T.Differential>;
61129

62130
// Following method will not be needed for bindless storage
@@ -66,45 +134,11 @@ public struct BindlessBufferStorage<T> : IStorage<T>
66134
public static Address getOffset(Address base, int elements) { return base.getOffset(elements); }
67135
}
68136

69-
// [require(cpp_cuda_metal_spirv)]
70137
public struct PointerStorage<T> : IStorage<T>
71138
where T : __BuiltinFloatingPointType
72139
where T.Differential == T
73140
{
74-
public struct PointerAddress : IPointerLikeAddress<T>
75-
{
76-
public typealias Differential =
77-
PointerStorage<T.Differential>.PointerAddress;
78-
79-
T* ptr;
80-
81-
public __init(T* ptr)
82-
{
83-
this.ptr = ptr;
84-
}
85-
86-
public __subscript(uint index)->T
87-
{
88-
[nonmutating]
89-
get { return ptr[index]; }
90-
91-
[mutating]
92-
set { ptr[index] = newValue; }
93-
}
94-
95-
public Address getOffset(int elements)
96-
{
97-
return Address(ptr + elements);
98-
}
99-
100-
[require(cuda_glsl_hlsl_metal_spirv, sm_6_6)]
101-
public void atomicAdd(uint index, T value)
102-
{
103-
__atomic_add(ptr[index], value);
104-
}
105-
}
106-
107-
public typealias Address = PointerAddress;
141+
public typealias Address = PointerAddress<T>;
108142
public typealias Differential = PointerStorage<T.Differential>;
109143

110144
// Following method will not be needed for pointer storage

source/standard-modules/neural/inline-vector.slang

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -194,11 +194,10 @@ public struct InlineVector<T, int N> : IVector<T, N>
194194

195195
// Linear transformation without bias (Bindless storage)
196196
[BackwardDerivative(linearTransformBwd)]
197-
public OutputVector linearTransform<int OutputSize, BindlessStorage, OutputVector>(
198-
BindlessStorage.Address weightAddress)
199-
where BindlessStorage : IStorage<T>
200-
where BindlessStorage.Address : IPointerLikeAddress<T>
201-
where BindlessStorage.Address.Differential : IPointerLikeAddress<T.Differential>
197+
public OutputVector linearTransform<int OutputSize, Address, OutputVector>(
198+
Address weightAddress)
199+
where Address : IPointerLikeAddress<T>
200+
where Address.Differential : IPointerLikeAddress<T.Differential>
202201
where OutputVector : IVector<T, OutputSize>
203202
{
204203
var output = OutputVector();
@@ -220,16 +219,15 @@ public struct InlineVector<T, int N> : IVector<T, N>
220219

221220
// Linear transformation with bias (Bindless storage)
222221
[BackwardDerivative(linearTransformBwd)]
223-
public OutputVector linearTransform<int OutputSize, BindlessStorage, OutputVector>(
224-
BindlessStorage.Address weightAddress,
225-
BindlessStorage.Address biasAddress)
226-
where BindlessStorage : IStorage<T>
227-
where BindlessStorage.Address : IPointerLikeAddress<T>
228-
where BindlessStorage.Address.Differential : IPointerLikeAddress<T.Differential>
222+
public OutputVector linearTransform<int OutputSize, Address, OutputVector>(
223+
Address weightAddress,
224+
Address biasAddress)
225+
where Address : IPointerLikeAddress<T>
226+
where Address.Differential : IPointerLikeAddress<T.Differential>
229227
where OutputVector : IVector<T, OutputSize>
230228
{
231229
// Reuse the unbias matmul method
232-
OutputVector output = this.linearTransform<OutputSize, BindlessStorage, OutputVector>(weightAddress);
230+
OutputVector output = this.linearTransform<OutputSize, Address, OutputVector>(weightAddress);
233231

234232
[ForceUnroll]
235233
for (int i = 0; i < OutputSize; i++)
@@ -239,13 +237,12 @@ public struct InlineVector<T, int N> : IVector<T, N>
239237
}
240238

241239
// Backward of linear transformation without bias (Bindless storage)
242-
static public void linearTransformBwd<int OutputSize, BindlessStorage, OutputVector>(
240+
static public void linearTransformBwd<int OutputSize, Address, OutputVector>(
243241
inout DifferentialPair<This> dthis,
244-
DifferentialPtrPair<BindlessStorage.Address> dparameters,
242+
DifferentialPtrPair<Address> dparameters,
245243
OutputVector.Differential doutput)
246-
where BindlessStorage : IStorage<T>
247-
where BindlessStorage.Address : IPointerLikeAddress<T>
248-
where BindlessStorage.Address.Differential : IPointerLikeAddress<T.Differential>
244+
where Address : IPointerLikeAddress<T>
245+
where Address.Differential : IPointerLikeAddress<T.Differential>
249246
where OutputVector : IVector<T, OutputSize>
250247
where OutputVector.Differential : IVector<T.Differential, OutputSize>
251248
{
@@ -283,18 +280,17 @@ public struct InlineVector<T, int N> : IVector<T, N>
283280
}
284281

285282
// Backward of linear transformation with bias (Bindless storage)
286-
static public void linearTransformBwd<int OutputSize, BindlessStorage, OutputVector>(
283+
static public void linearTransformBwd<int OutputSize, Address, OutputVector>(
287284
inout DifferentialPair<This> dthis,
288-
DifferentialPtrPair<BindlessStorage.Address> dWeightAddress,
289-
DifferentialPtrPair<BindlessStorage.Address> dBiasAddress,
285+
DifferentialPtrPair<Address> dWeightAddress,
286+
DifferentialPtrPair<Address> dBiasAddress,
290287
OutputVector.Differential doutput)
291-
where BindlessStorage : IStorage<T>
292-
where BindlessStorage.Address : IPointerLikeAddress<T>
293-
where BindlessStorage.Address.Differential : IPointerLikeAddress<T.Differential>
288+
where Address : IPointerLikeAddress<T>
289+
where Address.Differential : IPointerLikeAddress<T.Differential>
294290
where OutputVector : IVector<T, OutputSize>
295291
{
296292
// Reuse the unbias backward method
297-
linearTransformBwd<OutputSize, BindlessStorage, OutputVector>(dthis, dWeightAddress, doutput);
293+
linearTransformBwd<OutputSize, Address, OutputVector>(dthis, dWeightAddress, doutput);
298294

299295
let biasOffset = dBiasAddress.d.getOffset(0);
300296
// dBias = dOutput

source/standard-modules/neural/ivector.slang

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,9 @@ public interface IVector<T, int N> : IDifferentiable
135135
- `OutputVector` must conform to `IVector<T, OutputSize>`
136136
*/
137137
[Differentiable]
138-
public OutputVector linearTransform<int OutputSize, BindlessStorage, OutputVector>(
139-
BindlessStorage.Address weightAddress)
140-
where BindlessStorage : IStorage<T>
141-
where BindlessStorage.Address : IPointerLikeAddress<T>
142-
where BindlessStorage.Address.Differential : IPointerLikeAddress<T.Differential>
138+
public OutputVector linearTransform<int OutputSize, Address, OutputVector>(Address weightAddress)
139+
where Address : IPointerLikeAddress<T>
140+
where Address.Differential : IPointerLikeAddress<T.Differential>
143141
where OutputVector : IVector<T, OutputSize>;
144142

145143
/**
@@ -165,11 +163,9 @@ public interface IVector<T, int N> : IDifferentiable
165163
- `OutputVector` must conform to `IVector<T, OutputSize>`
166164
*/
167165
[Differentiable]
168-
public OutputVector linearTransform<int OutputSize, BindlessStorage, OutputVector>(
169-
BindlessStorage.Address weightAddress,
170-
BindlessStorage.Address biasAddress)
171-
where BindlessStorage : IStorage<T>
172-
where BindlessStorage.Address : IPointerLikeAddress<T>
173-
where BindlessStorage.Address.Differential : IPointerLikeAddress<T.Differential>
166+
public OutputVector linearTransform<int OutputSize, Address, OutputVector>(
167+
Address weightAddress, Address biasAddress)
168+
where Address : IPointerLikeAddress<T>
169+
where Address.Differential : IPointerLikeAddress<T.Differential>
174170
where OutputVector : IVector<T, OutputSize>;
175171
}

0 commit comments

Comments
 (0)