Skip to content

Commit 4eb99a1

Browse files
committed
nll libtorch
1 parent 0a2984e commit 4eb99a1

7 files changed

Lines changed: 88 additions & 34 deletions

File tree

bridge/include/bridge.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ bridge_tensor_t conv2d(
7272
int padding
7373
);
7474

75+
bridge_tensor_t nll_loss(
76+
bridge_tensor_t input,
77+
bridge_tensor_t target,
78+
bridge_tensor_t weight,
79+
int ignoreIndex,
80+
int reduction
81+
);
82+
7583
bridge_tensor_t matmul(bridge_tensor_t a, bridge_tensor_t b);
7684

7785
bridge_tensor_t max_pool2d(

bridge/lib/bridge.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,39 @@ extern "C" bridge_tensor_t add_two_arrays(bridge_tensor_t a, bridge_tensor_t b)
279279
return torch_to_bridge(output);
280280
}
281281

282+
extern "C" bridge_tensor_t nll_loss(
283+
bridge_tensor_t input,
284+
bridge_tensor_t target,
285+
bridge_tensor_t weight,
286+
int ignoreIndex,
287+
int reduction
288+
) {
289+
// Convert bridge_tensor_t to torch::Tensor
290+
torch::Tensor t_input = bridge_to_torch(input);
291+
torch::Tensor t_target = bridge_to_torch(target);
292+
torch::Tensor t_weight = bridge_to_torch(weight);
293+
294+
// Map reduction int to string
295+
std::string reduction_str;
296+
switch (reduction) {
297+
case 0: reduction_str = "none"; break;
298+
case 1: reduction_str = "mean"; break;
299+
case 2: reduction_str = "sum"; break;
300+
default: reduction_str = "mean"; break; // fallback default
301+
}
302+
303+
torch::Tensor output = torch::nn::functional::nll_loss(
304+
t_input,
305+
t_target,
306+
torch::nn::functional::NLLLossFuncOptions()
307+
.weight(t_weight)
308+
.ignore_index(ignoreIndex)
309+
.reduction(reduction_str)
310+
);
311+
312+
return torch_to_bridge(output);
313+
}
314+
282315
// extern "C" bridge_tensor_t capture_webcam_bridge(int cam_index) {
283316
// torch::Tensor image = capture_webcam(cam_index);
284317
// return torch_to_bridge(image);

lib/Autograd.chpl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,13 +1084,12 @@ record nllLossOp : serializable {
10841084
var target: shared BaseTensorResource(?);
10851085
var weight: shared BaseTensorResource(?);
10861086
var ignoreIndex: int;
1087-
var red: bool;
10881087
var reduction: string;
10891088

10901089
proc children do return (input,target,weight);
10911090

10921091
proc forward() do
1093-
return ndarray.nllLoss(input.array,target.array,weight.array,ignoreIndex,red,reduction);
1092+
return ndarray.nllLoss(input.array,target.array,weight.array,ignoreIndex,reduction);
10941093

10951094
proc spec : GradOpSpec do return new dict(("operation", "nllLoss"));
10961095
}

lib/Bridge.chpl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,13 @@ module Bridge {
8787
in a: bridge_tensor_t,
8888
in b: bridge_tensor_t): bridge_tensor_t;
8989

90+
extern "nll_loss" proc nllLoss(
91+
in input: bridge_tensor_t,
92+
in target: ndarray(1,eltType),
93+
in weight: ndarray(1, eltType),
94+
in ignoreIndex: int(32),
95+
in reduction: int(32): bridge_tensor_t;
96+
9097
extern "split_loop" proc splitLoop(idx: int(64), n: int(64)): void;
9198

9299
extern "split_loop_filler" proc splitLoopFiller(n: int(64),ret: c_ptr(int(64))): void;

lib/DynamicTensor.chpl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -733,14 +733,13 @@ proc type dynamicTensor.nllLoss(
733733
target: dynamicTensor(eltType),
734734
weight: dynamicTensor(eltType),
735735
ignoreIndex: int = -1,
736-
red: bool = true,
737736
reduction: string = "mean"
738737
) {
739738
for param rankIn in 2..2 {
740739
if input.checkRank(rankIn) {
741740
for param rank in 1..1 {
742741
if target.checkRank(rankIn) && weight.checkRank(rank) {
743-
return staticTensor.nllLoss(input.forceRank(rankIn),target.forceRank(rank),weight.forceRank(rank),ignoreIndex,red,reduction);
742+
return staticTensor.nllLoss(input.forceRank(rankIn),target.forceRank(rank),weight.forceRank(rank),ignoreIndex,reduction);
744743
}
745744
}
746745
}
@@ -751,7 +750,6 @@ proc type dynamicTensor.nllLoss(
751750
input: dynamicTensor(?eltType),
752751
target: dynamicTensor(eltType),
753752
ignoreIndex: int = -1,
754-
red: bool = true,
755753
reduction: string = "mean"
756754
) {
757755
param inRank: int = 2;
@@ -762,7 +760,7 @@ proc type dynamicTensor.nllLoss(
762760
var stInput: staticTensor(inRank,eltType) = input.forceRank(inRank);
763761
var stTarget: staticTensor(targetRank,eltType) = target.forceRank(targetRank);
764762
var weights: staticTensor(1,eltType) = staticTensor.ones(eltType,3);
765-
return staticTensor.nllLoss(stInput,stTarget,weights,ignoreIndex,red,reduction);
763+
return staticTensor.nllLoss(stInput,stTarget,weights,ignoreIndex,reduction);
766764
}
767765
}
768766

lib/NDArray.chpl

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2140,36 +2140,46 @@ proc type ndarray.nllLoss(
21402140
target: ndarray(1,eltType),
21412141
weight: ndarray(1, eltType),
21422142
ignoreIndex: int = -1,
2143-
red: bool = true,
21442143
reduction: string = "mean"
21452144
): ndarray(1,eltType) {
2146-
const (N,C) = input.shape;
2147-
assert(target.shape[0] == N, "Target shape must match batch size.");
2148-
assert(weight.shape[0] == C, "Weights shape must match number of classes.");
2145+
int reduction_int = 1;
2146+
if reduction == "sum" then reduction_int = 2;
2147+
if reduction == "none" then reduction_int = 0;
2148+
2149+
return Bridge.nllLoss(
2150+
input: Bridge.tensorHandle(eltType),
2151+
target: Bridge.tensorHandle(eltType),
2152+
weight: Bridge.tensorHandle(eltType),
2153+
ignoreIndex,
2154+
reduction
2155+
) : ndarray(rank,eltType);
2156+
// const (N,C) = input.shape;
2157+
// assert(target.shape[0] == N, "Target shape must match batch size.");
2158+
// assert(weight.shape[0] == C, "Weights shape must match number of classes.");
21492159

2150-
const dom = util.domainFromShape(N);
2151-
var loss = new ndarray(dom, eltType);
2152-
ref x = input.data;
2153-
ref y = target.data;
2154-
ref w = weight.data;
2155-
ref lossD = loss.data;
2156-
var wynSum: real = 0.0;
2157-
2158-
forall n in 0..<N with (+ reduce wynSum) {
2159-
const yn: int = y[n]:int;
2160-
if yn == ignoreIndex {
2161-
lossD[n] = 0.0;
2162-
}
2163-
else {
2164-
lossD[n] = -w[yn]*x[n,yn];
2165-
wynSum += w[yn];
2166-
}
2167-
}
2160+
// const dom = util.domainFromShape(N);
2161+
// var loss = new ndarray(dom, eltType);
2162+
// ref x = input.data;
2163+
// ref y = target.data;
2164+
// ref w = weight.data;
2165+
// ref lossD = loss.data;
2166+
// var wynSum: real = 0.0;
2167+
2168+
// forall n in 0..<N with (+ reduce wynSum) {
2169+
// const yn: int = y[n]:int;
2170+
// if yn == ignoreIndex {
2171+
// lossD[n] = 0.0;
2172+
// }
2173+
// else {
2174+
// lossD[n] = -w[yn]*x[n,yn];
2175+
// wynSum += w[yn];
2176+
// }
2177+
// }
21682178

2169-
if !red then return loss;
2170-
if reduction == "mean" then return loss.sum(0) / wynSum;
2171-
if reduction == "sum" then return loss.sum(0);
2172-
halt("Invalid reduction mode: " + reduction);
2179+
// if !red then return loss;
2180+
// if reduction == "mean" then return loss.sum(0) / wynSum;
2181+
// if reduction == "sum" then return loss.sum(0);
2182+
// halt("Invalid reduction mode: " + reduction);
21732183
}
21742184

21752185
module ndarrayRandom {

lib/StaticTensor.chpl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -517,10 +517,9 @@ proc type staticTensor.nllLoss(
517517
target: staticTensor(1,eltType),
518518
weight: staticTensor(1,eltType),
519519
ignoreIndex: int = -1,
520-
red: bool = true,
521520
reduction: string = "mean"
522521
) {
523-
var ctx = new nllLossOp(input.meta,target.meta,weight.meta,ignoreIndex,red,reduction);
522+
var ctx = new nllLossOp(input.meta,target.meta,weight.meta,ignoreIndex,reduction);
524523
return tensorFromCtx(1,eltType,ctx);
525524
}
526525

0 commit comments

Comments
 (0)