@@ -2140,36 +2140,46 @@ proc type ndarray.nllLoss(
21402140 target: ndarray(1 ,eltType),
21412141 weight: ndarray(1 , eltType),
21422142 ignoreIndex: int = - 1 ,
2143- red: bool = true ,
21442143 reduction: string = " mean"
21452144): ndarray(1 ,eltType) {
2146- const (N,C) = input.shape;
2147- assert(target.shape[ 0 ] == N, " Target shape must match batch size." );
2148- assert(weight.shape[ 0 ] == C, " Weights shape must match number of classes." );
2145+ int reduction_int = 1 ;
2146+ if reduction == " sum" then reduction_int = 2 ;
2147+ if reduction == " none" then reduction_int = 0 ;
2148+
2149+ return Bridge.nllLoss(
2150+ input: Bridge.tensorHandle(eltType),
2151+ target: Bridge.tensorHandle(eltType),
2152+ weight: Bridge.tensorHandle(eltType),
2153+ ignoreIndex,
2154+ reduction
2155+ ) : ndarray(rank,eltType);
2156+ // const (N,C) = input.shape;
2157+ // assert(target.shape[0] == N, "Target shape must match batch size.");
2158+ // assert(weight.shape[0] == C, "Weights shape must match number of classes.");
21492159
2150- const dom = util.domainFromShape(N);
2151- var loss = new ndarray(dom, eltType);
2152- ref x = input.data;
2153- ref y = target.data;
2154- ref w = weight.data;
2155- ref lossD = loss.data;
2156- var wynSum: real = 0.0 ;
2157-
2158- forall n in 0 ..< N with (+ reduce wynSum) {
2159- const yn: int = y[ n] : int ;
2160- if yn == ignoreIndex {
2161- lossD[ n] = 0.0 ;
2162- }
2163- else {
2164- lossD[ n] = - w[ yn] * x[ n,yn] ;
2165- wynSum += w[ yn] ;
2166- }
2167- }
2160+ // const dom = util.domainFromShape(N);
2161+ // var loss = new ndarray(dom, eltType);
2162+ // ref x = input.data;
2163+ // ref y = target.data;
2164+ // ref w = weight.data;
2165+ // ref lossD = loss.data;
2166+ // var wynSum: real = 0.0;
2167+
2168+ // forall n in 0..<N with (+ reduce wynSum) {
2169+ // const yn: int = y[n]:int;
2170+ // if yn == ignoreIndex {
2171+ // lossD[n] = 0.0;
2172+ // }
2173+ // else {
2174+ // lossD[n] = -w[yn]*x[n,yn];
2175+ // wynSum += w[yn];
2176+ // }
2177+ // }
21682178
2169- if !red then return loss;
2170- if reduction == " mean" then return loss.sum(0 ) / wynSum;
2171- if reduction == " sum" then return loss.sum(0 );
2172- halt(" Invalid reduction mode: " + reduction);
2179+ // if !red then return loss;
2180+ // if reduction == "mean" then return loss.sum(0) / wynSum;
2181+ // if reduction == "sum" then return loss.sum(0);
2182+ // halt("Invalid reduction mode: " + reduction);
21732183}
21742184
21752185module ndarrayRandom {
0 commit comments