@@ -126,13 +126,22 @@ record dynamicTensor : serializable {
126126 return this ;
127127 }
128128
129- proc array (param rank: int ) ref : ndarray(rank,eltType) do
129+ inline proc ref rankedArray (param rank: int ) ref : ndarray(rank,eltType) do
130130 return (this .meta.borrow() : borrowed BaseTensorResource(eltType, rank)).array;
131131
132- proc grad (param rank: int ) ref : ndarray(rank,eltType) do
132+ inline proc rankedArray(param rank: int ): ndarray(rank,eltType) do
133+ return (this .meta.borrow() : borrowed BaseTensorResource(eltType, rank)).array;
134+
135+ inline proc ref rankedGradArray(param rank: int ) ref : ndarray(rank,eltType) do
136+ return (this .meta.borrow() : borrowed BaseTensorResource(eltType, rank)).grad;
137+
138+ inline proc rankedGradArray(param rank: int ): ndarray(rank,eltType) do
133139 return (this .meta.borrow() : borrowed BaseTensorResource(eltType, rank)).grad;
134140
135- proc data (param rank: int ) ref : [] eltType do
141+ inline proc ref rankedData(param rank: int ) ref : [] eltType do
142+ return (this .meta.borrow() : borrowed BaseTensorResource(eltType, rank)).data;
143+
144+ inline proc rankedData(param rank: int ): [] eltType do
136145 return (this .meta.borrow() : borrowed BaseTensorResource(eltType, rank)).data;
137146
138147
@@ -156,14 +165,23 @@ record dynamicTensor : serializable {
156165 }
157166}
158167
159- operator : (in t: dynamicTensor(?eltType), type toType): dynamicTensor(toType) {
168+ operator : (in t: dynamicTensor(?eltType), type toType): dynamicTensor(toType)
169+ where isNumericType(toType) {
160170 if eltType == toType then return t;
161171 for param rank in 1 ..maxRank do
162172 if t.checkRank(rank) then
163173 return (t.forceRank(rank) : toType).eraseRank();
164174 halt(" Could not identify rank for this: " , t);
165175}
166176
177+ operator : (in t: dynamicTensor(?eltType), type toType: ndarray(?rank,?toEltType)): ndarray(rank,toEltType)
178+ where isNumericType(eltType) && isNumericType(toEltType) {
179+ if eltType == toEltType then
180+ return t.toNDArray(rank);
181+ else
182+ return t.toNDArray(rank) : toEltType;
183+ }
184+
167185proc type dynamicTensor.detachMode() param : bool {
168186 return defaultDetachedMode;
169187}
0 commit comments