Skip to content

Commit 19f1fe3

Browse files
authored
Merge pull request #27 from Iainmon/iain-changes
Waiting for response/fix from Chapel compiler regarding tuples of `ref`s.
2 parents 8cd7f87 + 7dbc0d8 commit 19f1fe3

7 files changed

Lines changed: 176 additions & 16 deletions

File tree

iain_dict_test.chpl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
use OrderedDict;
2+
3+
4+
5+
writeln("Hello, world!");
6+
7+
var d = new dict(
8+
("one", 1),
9+
("two", 2),
10+
("three", 3)
11+
);
12+
13+
14+
for (k,v) in zip(d.keys(), d.values()) {
15+
writeln(k, " => ", v);
16+
}
17+
18+
19+
20+
// increment the value for all
21+
for (k,v) in zip(d.keys(),d.values()) {
22+
v += 1;
23+
}
24+
25+
26+
for (k,v) in zip(d.keys(), d.values()) {
27+
writeln(k, " => ", v);
28+
}

lib/Autograd.chpl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,17 @@ record reluOp : serializable {
278278
proc spec : GradOpSpec do return new dict(("operation","ReLU"));
279279
}
280280

281+
record squareOp : serializable {
282+
var input: shared BaseTensorResource(?);
283+
284+
proc children do return (input,);
285+
286+
proc forward() do
287+
return input.array.square();
288+
289+
proc spec : GradOpSpec do return new dict(("operation","Square"));
290+
}
291+
281292
record expOp : serializable {
282293
var input: shared BaseTensorResource(?);
283294

lib/DynamicTensor.chpl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,15 @@ proc dynamicTensor.relu(): dynamicTensor(eltType) {
244244
return new dynamicTensor(eltType);
245245
}
246246

247+
proc dynamicTensor.square(): dynamicTensor(eltType) {
248+
for param rank in 1..maxRank {
249+
if this.checkRank(rank) then
250+
return this.forceRank(rank).square().eraseRank();
251+
}
252+
halt("Could not determine rank in dynamicTensor.square.");
253+
return new dynamicTensor(eltType);
254+
}
255+
247256
proc dynamicTensor.gelu(): dynamicTensor(eltType) {
248257
for param rank in 1..maxRank {
249258
if this.checkRank(rank) then

lib/NDArray.chpl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,18 @@ record ndarray : serializable {
582582
return rl;
583583
}
584584

585+
inline proc square() {
586+
const ref thisData = data;
587+
const dom = this.domain;
588+
var rl = new ndarray(dom,eltType);
589+
ref rlD = rl.data;
590+
forall i in dom.every() {
591+
const x = thisData[i];
592+
rlD[i] = x * x;
593+
}
594+
return rl;
595+
}
596+
585597
inline proc gelu() {
586598
const ref thisData = data;
587599
const dom = this.domain;

lib/Network.chpl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ proc modelFromSpecFile(path: string, type dtype=real(32), targetLocales: [] loca
311311
return moduleFromSpec(ms,dtype,targetLocales,inputShape);
312312
}
313313

314-
proc loadModel(specFile: string, weightsFolder: string, type dtype = real(32),debug = false): owned Module(dtype) {
314+
proc loadModel(specFile: string, weightsFolder: string, type dtype = real(32),param debug = false): owned Module(dtype) {
315315
var model: owned Module(dtype) = modelFromSpecFile(specFile, dtype, empty_locales, optional.empty(1*int));
316316

317317
model.loadPyTorchDump(weightsFolder,dtype = dtype, debug = debug);

lib/OrderedDict.chpl

Lines changed: 109 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@ record dict : serializable {
1919
this.order = order;
2020
}
2121

22-
proc init(table: map(?keyType,?valType)) {
22+
proc init(in table: map(?keyType,?valType)) {
2323
var ks = new list(keyType);
24-
var tbl: map(keyType,valType) = table;
25-
for k in tbl.keys() do
24+
var tbl = new map(keyType,valType);
25+
for k in table.keys() {
2626
ks.pushBack(k);
27-
this.init(tbl,ks);
27+
tbl.addOrReplace(k,table[k]);
28+
}
29+
this.init(table,ks);
2830
}
2931

3032
proc init(type keyType, type valType) {
@@ -59,18 +61,111 @@ record dict : serializable {
5961
return table[key];
6062
}
6163

62-
iter keys(): keyType do
63-
for i in 0..<order.size do
64-
yield order[i];
6564

66-
iter values(): valType do
67-
for k in keys() do
68-
yield table[k];
65+
proc ref getKey(i: int) ref do
66+
return order[i];
67+
68+
proc const ref getKey(i: int) const ref do
69+
return order[i];
70+
71+
iter keys() do
72+
for i in 0..<this.size do
73+
yield this.getKey(i);
74+
75+
proc ref getNVal(i: int) ref throws {
76+
return table[order[i]];
77+
}
78+
79+
proc const ref getNVal(i: int) const ref throws {
80+
return table[order[i]];
81+
}
82+
83+
iter ref values() ref do
84+
for i in 0..<this.size do
85+
yield this.getNVal(i);
86+
87+
iter const ref values() const ref do
88+
for i in 0..<this.size do
89+
yield this.getNVal(i);
90+
91+
// iter values() do
92+
// for i in 0..<this.size do
93+
// yield this.getVal(i);
94+
95+
proc ref this(k: keyType) ref throws {
96+
if !table.contains(k) then
97+
throw new Error("Key not found: " + k:string);
98+
return table[k];
99+
}
100+
101+
// iter values() ref : valType do
102+
// for k in keys() do
103+
// yield table[k];
104+
105+
106+
107+
108+
// iter these() {
109+
// for (k,v) in zip(this.keys(),this.values()) {
110+
// yield (k,v);
111+
// }
112+
// }
113+
114+
115+
// iter these() ref where !isClassType(valType) {
116+
// for k in order {
117+
// ref v = table[k];
118+
// yield (k,v);
119+
// }
120+
// }
121+
122+
// iter these() const ref where !isClassType(valType) do
123+
124+
// for k in order {
125+
// const ref v = table[k];
126+
// yield (k,v);
127+
// }
128+
129+
// iter these() where !isClassType(valType) do
130+
// for k in order {
131+
// yield (k,table[k]);
132+
// }
133+
134+
// iter these() where isSharedClassType(valType) {
135+
// for k in order {
136+
// yield (k,table[k]);
137+
// }
138+
// }
139+
140+
// // iter these() where isClassType(valType) {
141+
// // compilerError(valType:string);
142+
// // for k in order {
143+
// // yield (k,table[k]);
144+
// // }
145+
// // }
146+
147+
// iter these() where isClassType(valType) && !isSharedClassType(valType) {
148+
// compilerError(valType:string);
149+
// for k in order {
150+
// yield (k,table[k]);
151+
// }
152+
// }
153+
154+
// // iter these() ref {
155+
// // for k in order {
156+
// // ref v = table[k];
157+
// // yield (k,v);
158+
// // }
159+
// // }
160+
161+
162+
// // iter these() const ref {
163+
// // for k in order {
164+
// // const ref v = table[k];
165+
// // yield (k,v);
166+
// // }
167+
// // }
69168

70-
iter these() do
71-
for k in order {
72-
yield (k,table[k]);
73-
}
74169

75170
proc ref insert(in key: keyType, in value: valType) {
76171
if !order.contains(key) then

lib/StaticTensor.chpl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ operator :(in t: staticTensor(?rank,?eltType), type toType): staticTensor(rank,t
9595
return new staticTensor(b);
9696
}
9797

98-
proc tensorFromCtx(param rank: int, type eltType, ctx): staticTensor(rank,eltType) {
98+
proc tensorFromCtx(param rank: int, type eltType, ctx: ?ctxType): staticTensor(rank,eltType) {
9999
var newMeta = new owned TensorResource(eltType,rank,ctx);
100100
newMeta.forward();
101101
return new staticTensor(newMeta);
@@ -177,6 +177,11 @@ proc staticTensor.relu() {
177177
return tensorFromCtx(rank,eltType,ctx);
178178
}
179179

180+
proc staticTensor.square() {
181+
var ctx = new squareOp(meta);
182+
return tensorFromCtx(rank,eltType,ctx);
183+
}
184+
180185
proc staticTensor.gelu() {
181186
var t = new staticTensor(rank,eltType);
182187
on this.device {

0 commit comments

Comments
 (0)