Skip to content

Commit b15a558

Browse files
committed
Add attention
1 parent 9b5e2a4 commit b15a558

5 files changed

Lines changed: 101 additions & 0 deletions

File tree

lib/Autograd.chpl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,24 @@ record batchNormOp : serializable {
11011101
proc spec : GradOpSpec do return new dict(("operation","BatchNorm"));
11021102
}
11031103

1104+
// Operation record for multi-head attention: bundles the input features,
// the query/key/value projection weights and the attention hyper-parameters.
//
// Field order is significant: the record is constructed positionally through
// its compiler-generated initializer (eltType, features, q_weight, k_weight,
// v_weight, num_heads, embed_dim).
record multiheadAttentionOp : serializable {
    type eltType = real;

    // Operands of the attention computation.
    var features: shared BaseTensorResource(?);
    var q_weight: shared BaseTensorResource(eltType, ?);
    var k_weight: shared BaseTensorResource(eltType, ?);
    var v_weight: shared BaseTensorResource(eltType, ?);

    // Hyper-parameters, forwarded unchanged to ndarray.multiheadAttention.
    var num_heads: int;
    var embed_dim: int;

    // Tuple of the operands this operation depends on.
    proc children {
        return (features, q_weight, k_weight, v_weight);
    }

    // Runs attention on the arrays currently held by the operands.
    proc forward() {
        return ndarray.multiheadAttention(
            features.array,
            q_weight.array,
            k_weight.array,
            v_weight.array,
            num_heads,
            embed_dim
        );
    }

    // Descriptor naming this operation.
    proc spec : GradOpSpec {
        return new dict(("operation", "MultiHeadAttention"));
    }

    // NOTE(review): no backward() is defined on this record -- confirm
    // whether gradients are expected to flow through attention.
}
1121+
11041122
record dropoutOp : serializable {
11051123
param rank: int;
11061124
type eltType;

lib/DynamicTensor.chpl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,28 @@ proc type dynamicTensor.batchnorm(
634634
return new dynamicTensor(eltType);
635635
}
636636

637+
// Rank-erased entry point for attention.
//
// `features` must pass `checkRank(3)`; it is forced to rank 3 and each of
// the three weights is forced to rank 2 before delegating to
// staticTensor.multiheadAttention. The result is rank-erased again.
// Halts when `features.checkRank(3)` is false.
proc type dynamicTensor.multiheadAttention(
    features: dynamicTensor(?eltType),
    q_weight: dynamicTensor(eltType),
    k_weight: dynamicTensor(eltType),
    v_weight: dynamicTensor(eltType),
    num_heads: int,
    embed_dim: int
): dynamicTensor(eltType) {
    // Guard clause: only rank-3 features are handled.
    if !features.checkRank(3) then
        halt("Could not determine rank in dynamicTensor.multiheadAttention");

    // Recover the static ranks, in the same order as the argument list.
    const rankedFeatures = features.forceRank(3);
    const rankedQ = q_weight.forceRank(2);
    const rankedK = k_weight.forceRank(2);
    const rankedV = v_weight.forceRank(2);

    return staticTensor.multiheadAttention(
        rankedFeatures,
        rankedQ,
        rankedK,
        rankedV,
        num_heads,
        embed_dim
    ).eraseRank();
}
658+
637659
proc dynamicTensor.softmax(): dynamicTensor(eltType) {
638660
for param rank in 1..maxRank {
639661
if this.checkRank(rank) then

lib/NDArray.chpl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1982,6 +1982,29 @@ proc type ndarray.matmul(a: ndarray(?aRank,?eltType),b: ndarray(?bRank,eltType))
19821982
return prod;
19831983
}
19841984

1985+
// Scaled dot-product self-attention on a rank-3 `features` array.
//
// Supports 1 head as of now: `num_heads` is accepted so the signature is
// already in its final shape, but it is NOT used -- the computation is always
// single-head, so the per-head width equals `embed_dim`.
//
// Arguments:
//   features   rank-3 input array
//   q_weight   rank-2 query projection
//   k_weight   rank-2 key projection
//   v_weight   rank-2 value projection
//   num_heads  currently ignored (see above)
//   embed_dim  embedding width; only used for the 1/sqrt scaling factor
//
// Returns softmax(Q * K^T / sqrt(embed_dim)) * V as a rank-3 array, where
// Q, K, V are `features` multiplied by the respective weight.
//
// NOTE(review): the previous revision read features.shape[0] as "seq_len"
// and features.shape[1] as "batch_size" but never used either value; the
// intended axis layout should be confirmed against callers.
proc type ndarray.multiheadAttention(
    features: ndarray(3, ?eltType),
    q_weight: ndarray(2, eltType),
    k_weight: ndarray(2, eltType),
    v_weight: ndarray(2, eltType),
    num_heads: int,
    embed_dim: int
): ndarray(3, eltType) {
    // Single head: the head dimension is the whole embedding dimension.
    const head_dim = embed_dim;

    // Project the input into query, key and value spaces.
    const q = ndarray.matmul(features, q_weight);
    const k = ndarray.matmul(features, k_weight);
    const v = ndarray.matmul(features, v_weight);

    // Attention weights: Q * K^T (last two axes of K swapped), scaled by
    // 1/sqrt(head_dim), then softmax along axis 2.
    const z = (ndarray.matmul(q, k.permute(0, 2, 1)) / Math.sqrt(head_dim))._softmax(axis=2);

    // Weighted sum of the values.
    return ndarray.matmul(z, v);
}
2007+
19852008
proc type ndarray.batchNormTrain(
19862009
features: ndarray(?rank,?eltType),
19872010
weight: ndarray(1,eltType),

lib/Network.chpl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,6 +1080,32 @@ class BatchNorm : Module(?) {
10801080
}
10811081
}
10821082

1083+
// Attention layer holding separate query, key and value projection
// parameters, each created as an embed_dim x embed_dim tensor of ones.
class MultiheadAttention : Module(?) {
    // Learnable projections.
    var q_weight: owned Parameter(eltType);
    var k_weight: owned Parameter(eltType);
    var v_weight: owned Parameter(eltType);

    // Hyper-parameters passed through to Tensor.multiheadAttention.
    var num_heads: int;
    var embed_dim: int;

    // Builds the layer for the given embedding width and head count.
    //
    // NOTE(review): `eltType` is accepted here but is not forwarded to a
    // super.init call, and the weights are created without referring to
    // it -- confirm this is intended.
    proc init(type eltType = real, embed_dim: int, num_heads: int) {
        // Fields are initialized in declaration order.
        this.q_weight = new Parameter(Tensor.ones(embed_dim, embed_dim));
        this.k_weight = new Parameter(Tensor.ones(embed_dim, embed_dim));
        this.v_weight = new Parameter(Tensor.ones(embed_dim, embed_dim));

        this.num_heads = num_heads;
        this.embed_dim = embed_dim;
    }

    // Applies attention to `input` using the current parameter values.
    override proc forward(input: Tensor(eltType)): Tensor(eltType) {
        return Tensor.multiheadAttention(
            input,
            q_weight.data,
            k_weight.data,
            v_weight.data,
            num_heads,
            embed_dim
        );
    }

    // Registers the three projections under the names
    // "query", "key" and "value".
    override proc setup() {
        addModule("query", q_weight);
        addModule("key", k_weight);
        addModule("value", v_weight);
    }
}
1108+
10831109
class AdaptiveAvgPool2D : Module(?) {
10841110
// only handles square pooling
10851111
var outputSize: int;

lib/StaticTensor.chpl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,18 @@ proc type staticTensor.batchNorm(
453453
return tensorFromCtx(featureRank, eltType, ctx);
454454
}
455455

456+
// Rank-typed attention: wraps the operands' `meta` resources and the
// hyper-parameters in a multiheadAttentionOp and builds the rank-3 result
// tensor from that operation via tensorFromCtx.
proc type staticTensor.multiheadAttention(
    features: staticTensor(3, ?eltType),
    q_weight: staticTensor(2, eltType),
    k_weight: staticTensor(2, eltType),
    v_weight: staticTensor(2, eltType),
    num_heads: int,
    embed_dim: int
): staticTensor(3, eltType) {
    // Arguments are positional and follow the record's field order.
    var attnCtx = new multiheadAttentionOp(
        eltType,
        features.meta,
        q_weight.meta,
        k_weight.meta,
        v_weight.meta,
        num_heads,
        embed_dim
    );
    return tensorFromCtx(3, eltType, attnCtx);
}
467+
456468
// proc matvec(mat: staticTensor(2,?eltType),vec: staticTensor(1,eltType)): staticTensor(1,eltType) {
457469
// const (n,) = vec.array.domain.shape;
458470
// const (m,_n) = mat.array.domain.shape;

0 commit comments

Comments
 (0)