
Commit 4350656

Add multi-output support for topk
1 parent f57530e commit 4350656

File tree

13 files changed: +348, -24 lines


forge/csrc/ops/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -87,6 +87,7 @@ add_library(ops
     op_unsqueeze.cpp
     op_update_cache.cpp
     op_upsample_2d.cpp
+    op_topk.cpp
     op_where.cpp
     python_bindings.cpp)

forge/csrc/ops/op.cpp

Lines changed: 14 additions & 0 deletions

@@ -125,6 +125,7 @@ class NewToOldOpType
         mapping_[OpType::Unsqueeze] = "unsqueeze";
         mapping_[OpType::UpdateCache] = "update_cache";
         mapping_[OpType::Upsample2d] = "upsample2d";
+        mapping_[OpType::TopK] = "topk";
         mapping_[OpType::Where] = "where";
     }

@@ -227,6 +228,7 @@ class OldToNewOpType
         mapping_["unsqueeze"] = OpType::Unsqueeze;
         mapping_["update_cache"] = OpType::UpdateCache;
         mapping_["upsample2d"] = OpType::Upsample2d;
+        mapping_["topk"] = OpType::TopK;
         mapping_["where"] = OpType::Where;
     }

@@ -394,6 +396,7 @@ at::Tensor Op::eval(const graphlib::OpType &old_op_type, const std::vector<at::T
         case OpType::Unsqueeze: return unsqueeze::eval(old_op_type, *this, tensors);
         case OpType::UpdateCache: return update_cache::eval(old_op_type, *this, tensors);
         case OpType::Upsample2d: return upsample_2d::eval(old_op_type, *this, tensors);
+        case OpType::TopK: return topk::eval(old_op_type, *this, tensors);
         case OpType::Where: return where::eval(old_op_type, *this, tensors);
         default: TT_ASSERT(false, "Unknown OpType."); unreachable();
     } // clang-format on

@@ -489,6 +492,7 @@ std::tuple<graphlib::Shape, std::vector<graphlib::DimBroadcast>> Op::shape(
         case OpType::Unsqueeze: return unsqueeze::shape(old_op_type, *this, inputs);
         case OpType::UpdateCache: return update_cache::shape(old_op_type, *this, inputs);
         case OpType::Upsample2d: return upsample_2d::shape(old_op_type, *this, inputs);
+        case OpType::TopK: return topk::shape(old_op_type, *this, inputs);
         case OpType::Where: return where::shape(old_op_type, *this, inputs);
         default: TT_ASSERT(false, "Unknown OpType."); unreachable();
     } // clang-format on

@@ -589,6 +593,7 @@ tt::graphlib::NodeContext Op::backward(
         case OpType::Unsqueeze: return unsqueeze::backward(old_op_type, *this, context, operand, inputs, output, gradient);
         case OpType::UpdateCache: return update_cache::backward(old_op_type, *this, context, operand, inputs, output, gradient);
         case OpType::Upsample2d: return upsample_2d::backward(old_op_type, *this, context, operand, inputs, output, gradient);
+        case OpType::TopK: return topk::backward(old_op_type, *this, context, operand, inputs, output, gradient);
         case OpType::Where: return where::backward(old_op_type, *this, context, operand, inputs, output, gradient);
         default: TT_ASSERT(false, "Unknown OpType."); unreachable();
     } // clang-format on

@@ -706,6 +711,7 @@ void Op::decompose_initial(
         case OpType::Unsqueeze: return;
         case OpType::UpdateCache: return;
         case OpType::Upsample2d: return;
+        case OpType::TopK: return;
         case OpType::Where: return where::decompose_initial(old_op_type, *this, dc, inputs);
         default: TT_ASSERT(false, "Unknown OpType."); unreachable();
     } // clang-format on

@@ -802,6 +808,7 @@ void Op::decompose_post_optimize(
         case OpType::Unsqueeze: return;
         case OpType::UpdateCache: return;
         case OpType::Upsample2d: return;
+        case OpType::TopK: return;
         case OpType::Where: return where::decompose_post_optimize(old_op_type, *this, dc, inputs);
         default: TT_ASSERT(false, "Unknown OpType."); unreachable();
     } // clang-format on

@@ -899,6 +906,7 @@ void Op::decompose_post_autograd(
         case OpType::Unsqueeze: return;
         case OpType::UpdateCache: return;
         case OpType::Upsample2d: return;
+        case OpType::TopK: return;
         case OpType::Where: return where::decompose_post_autograd(old_op_type, *this, dc, inputs);
         default: TT_ASSERT(false, "Unknown OpType."); unreachable();
     } // clang-format on

@@ -994,6 +1002,7 @@ long Op::initial_flops_estimate(
         case OpType::Unsqueeze: return 0;
         case OpType::UpdateCache: return 0;
         case OpType::Upsample2d: return 0;
+        case OpType::TopK: return 0;
         case OpType::Where: return where::initial_flops_estimate(old_op_type, *this, inputs);
         default: TT_ASSERT(false, "Unknown OpType."); unreachable();
     } // clang-format on

@@ -1088,6 +1097,7 @@ bool Op::is_tm(const graphlib::OpType &old_op_type) const
         case OpType::Unsqueeze: return true;
         case OpType::UpdateCache: return false;
         case OpType::Upsample2d: return false;
+        case OpType::TopK: return false;
         case OpType::Where: return false;
         default: TT_ASSERT(false, "Unknown OpType."); unreachable();
     }

@@ -1182,6 +1192,7 @@ bool Op::is_eltwise(const graphlib::OpType &old_op_type) const
         case OpType::Unsqueeze: return false;
         case OpType::UpdateCache: return false;
         case OpType::Upsample2d: return false;
+        case OpType::TopK: return false;
         case OpType::Where: return true;
         default: TT_ASSERT(false, "Unknown OpType."); unreachable();
     }

@@ -1276,6 +1287,7 @@ bool Op::is_eltwise_unary(const graphlib::OpType &old_op_type) const
         case OpType::Unsqueeze: return false;
         case OpType::UpdateCache: return false;
         case OpType::Upsample2d: return false;
+        case OpType::TopK: return false;
         case OpType::Where: return false;
         default: TT_ASSERT(false, "Unknown OpType."); unreachable();
     }

@@ -1370,6 +1382,7 @@ bool Op::is_eltwise_binary(const graphlib::OpType &old_op_type) const
         case OpType::Unsqueeze: return false;
         case OpType::UpdateCache: return false;
         case OpType::Upsample2d: return false;
+        case OpType::TopK: return false;
         case OpType::Where: return false;
         default: TT_ASSERT(false, "Unknown OpType."); unreachable();
     }

@@ -1463,6 +1476,7 @@ bool Op::is_eltwise_nary(const graphlib::OpType &old_op_type) const
         case OpType::Unsqueeze: return false;
         case OpType::UpdateCache: return false;
         case OpType::Upsample2d: return false;
+        case OpType::TopK: return false;
         case OpType::Where: return true;
         default: TT_ASSERT(false, "Unknown OpType."); unreachable();
     }

forge/csrc/ops/op.hpp

Lines changed: 1 addition & 0 deletions

@@ -129,6 +129,7 @@ enum class OpType : uint32_t
     Unsqueeze,
     UpdateCache,
     Upsample2d,
+    TopK,
     Where,
 };

forge/csrc/ops/op_interface.hpp

Lines changed: 1 addition & 0 deletions

@@ -166,6 +166,7 @@ DECLARE_OP_INTERFACE(transpose);
 DECLARE_OP_INTERFACE(unsqueeze);
 DECLARE_OP_INTERFACE(update_cache);
 DECLARE_OP_INTERFACE(upsample_2d);
+DECLARE_OP_INTERFACE(topk);
 DECLARE_OP_INTERFACE(where);

 #undef DECLARE_OP_INTERFACE

forge/csrc/ops/op_topk.cpp

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <optional>
+#include <tuple>
+#include <vector>
+
+#include "autograd/autograd.hpp"
+#include "graph_lib/node_types.hpp"
+#include "graph_lib/shape.hpp"
+#include "op.hpp"
+#include "op_interface.hpp"
+#include "ops/op_common.hpp"
+#include "torch/extension.h" // Needed for c++ to/from python type conversion.
+#include "torch/torch.h"
+#include "utils/assert.hpp"
+
+namespace tt
+{
+namespace ops
+{
+namespace topk
+{
+using namespace graphlib;
+
+// Attributes expected:
+//  - k: int (required)
+//  - dim: int (required)
+//  - largest: bool (optional; default true)
+//  - sorted: bool (optional; default true)
+
+at::Tensor eval(const graphlib::OpType &old_op_type, const Op &op, const std::vector<at::Tensor> &tensors)
+{
+    TT_DBG_ASSERT(op.type() == OpType::TopK, "Wrong op type.");
+    TT_ASSERT(tensors.size() == 1, "TopK should have one input tensor");
+
+    const int64_t k = static_cast<int64_t>(op.attr_as<int>("k"));
+    const int64_t dim = static_cast<int64_t>(op.attr_as<int>("dim"));
+    const bool largest = op.has_attr("largest") ? op.attr_as<bool>("largest") : true;
+    const bool sorted = op.has_attr("sorted") ? op.attr_as<bool>("sorted") : true;
+
+    // torch::topk returns a tuple (values, indices). Our infra is single-output; return values for now.
+    auto result = torch::topk(tensors[0], k, dim, largest, sorted);
+    at::Tensor values = std::get<0>(result);
+    // at::Tensor indices = std::get<1>(result); // kept for future multi-output support
+
+    return values;
+}
+
+std::tuple<Shape, std::vector<DimBroadcast>> shape(
+    const graphlib::OpType &old_op_type, const Op &op, const std::vector<std::vector<std::uint32_t>> &in_shapes)
+{
+    TT_DBG_ASSERT(op.type() == OpType::TopK, "Wrong op type.");
+    TT_ASSERT(in_shapes.size() == 1, "TopK should have one input shape");
+
+    const auto &input = in_shapes[0];
+    TT_ASSERT(!input.empty(), "TopK input must have rank >= 1");
+
+    const int dim = op.attr_as<int>("dim");
+    TT_ASSERT(dim >= -static_cast<int>(input.size()) && dim < static_cast<int>(input.size()), "TopK dim out of range");
+
+    const int pos_dim = dim < 0 ? dim + static_cast<int>(input.size()) : dim;
+    std::vector<uint32_t> out_shape = input;
+    out_shape[pos_dim] = static_cast<uint32_t>(op.attr_as<int>("k"));
+
+    return {Shape::create(out_shape), {}};
+}
+
+// No autograd for now.
+
+tt::graphlib::NodeContext backward(
+    const graphlib::OpType &old_op_type,
+    const Op &op,
+    autograd::autograd_context &ac,
+    int operand,
+    const std::vector<NodeContext> &inputs,
+    const NodeContext &output,
+    const NodeContext &gradient)
+{
+    TT_DBG_ASSERT(op.type() == OpType::TopK, "Wrong op type.");
+    TT_THROW(false, "TopK does not have backward.");
+    unreachable();
+}
+
+} // namespace topk
+} // namespace ops
+} // namespace tt
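
For reference, eval() defers to torch::topk, which matches Python's torch.topk semantics: two outputs (values, indices), with the output shape equal to the input shape except that the selected dim is replaced by k, and negative dims counted from the back. That is exactly the rule topk::shape implements. A minimal sketch in plain PyTorch (the tensors and names here are illustrative, not part of the commit):

import torch

x = torch.tensor([[3.0, 1.0, 4.0, 1.0, 5.0],
                  [9.0, 2.0, 6.0, 5.0, 3.0]])

# Same defaults as the C++ attributes: largest=True, sorted=True.
values, indices = torch.topk(x, k=2, dim=-1, largest=True, sorted=True)
# values:  [[5., 4.], [9., 6.]]  (two largest per row, descending)
# indices: [[4, 2],   [0, 2]]    (their positions along dim=-1)

# Shape rule mirrored by topk::shape: dim=-1 normalizes to pos_dim=1 and
# out_shape[pos_dim] becomes k, so (2, 5) -> (2, 2).
assert values.shape == (2, 2) and indices.shape == (2, 2)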

forge/forge/compile.py

Lines changed: 28 additions & 11 deletions

@@ -1136,6 +1136,9 @@ def generate_graph(
     input_names_known = False
     inputs, _, _ = flatten_inputs(inputs)

+    # Track counts to ensure unique output names per base name
+    output_name_counts: Dict[str, int] = {}
+
     for out in all_subgraph_outputs:
         module = output_to_module_map[out]
         assert module is not None
@@ -1164,9 +1167,14 @@ def generate_graph(
             raise RuntimeError("Untraced output tensor encountered")

         else:
+            base_name = module_name + ".output_" + out.src_op.name
+            count = output_name_counts.get(base_name, 0)
+            unique_name = base_name if count == 0 else f"{base_name}_{count}"
+            output_name_counts[base_name] = count + 1
+
             outq = create_output(
                 graph,
-                module_name + ".output_" + out.src_op.name,
+                unique_name,
                 out.shape.get_pytorch_shape(),
                 out.data_format,
                 module.is_loss,
@@ -1177,8 +1185,10 @@ def generate_graph(

     recorded_parameters = {}

-    while pending_tensors:
+    # Map to ensure we create an op node only once per source op name
+    op_node_by_name: Dict[str, int] = {}

+    while pending_tensors:
         tensor, output, port_index, operand_broadcast, subgraph_idx = pending_tensors.popleft()

         if tensor in visited_tensors:
@@ -1333,15 +1343,21 @@ def generate_graph(
         tags = {}
         if tensor.src_layer is not None:
             tags["layer"] = tensor.src_layer
-        op = create_op_node(
-            graph,
-            tensor.src_op.name,
-            tensor.src_op.cpp_op_type,
-            tensor.shape.get_pytorch_shape(),
-            tensor.data_format,
-            subgraph_idx,
-            tags,
-        )
+        # Reuse the same op node if we already created one for this src_op.name
+        existing = op_node_by_name.get(tensor.src_op.name)
+        if existing is not None:
+            op = existing
+        else:
+            op = create_op_node(
+                graph,
+                tensor.src_op.name,
+                tensor.src_op.cpp_op_type,
+                tensor.shape.get_pytorch_shape(),
+                tensor.data_format,
+                subgraph_idx,
+                tags,
+            )
+            op_node_by_name[tensor.src_op.name] = op

         visited_tensors[tensor] = op
         if return_intermediate and tensor.has_value():
@@ -1364,6 +1380,7 @@ def generate_graph(
             if output_tensor in module_output_tensor_to_node
         ]
         module_targets = [module_target_tensor_to_node[target_tensor] for target_tensor in target_tensors]
+
         out_requires_grad = [
             output_tensor.requires_grad
            for output_tensor in all_subgraph_outputs
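
Distilled from the loop above, the output-naming scheme behaves like the following self-contained sketch (unique_output_name is an illustrative helper, not a function in the commit):

from typing import Dict

def unique_output_name(base_name: str, counts: Dict[str, int]) -> str:
    # The first output for a base name keeps it; repeats get _1, _2, ...
    count = counts.get(base_name, 0)
    counts[base_name] = count + 1
    return base_name if count == 0 else f"{base_name}_{count}"

counts: Dict[str, int] = {}
# Two outputs traced back to the same src_op (e.g. topk values and indices)
# no longer collide on the output queue name:
assert unique_output_name("model.output_topk0", counts) == "model.output_topk0"
assert unique_output_name("model.output_topk0", counts) == "model.output_topk0_1"

Together with the op_node_by_name map, which reuses a single op node per src_op.name, this lets several graph outputs hang off one source op, the multi-output case TopK introduces.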

forge/forge/op/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -72,3 +72,4 @@
 from .kv_cache import FillCache, UpdateCache
 from .misc import CumSum
 import forge.op.loss
+from .topk import TopK
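
The topk wrapper module itself (forge/forge/op/topk.py) is outside this excerpt, so its Python signature is not shown here. As context for why the generate_graph() changes are needed, a plain PyTorch module of the kind Forge traces can return both topk tensors from one source op; a minimal sketch:

import torch

class TopKModule(torch.nn.Module):
    # Returning both tensors yields two graph outputs from a single source
    # op, which is the case the unique-name and node-reuse changes handle.
    def forward(self, x: torch.Tensor):
        values, indices = torch.topk(x, k=3, dim=-1)
        return values, indices

values, indices = TopKModule()(torch.randn(2, 8))
assert values.shape == (2, 3) and indices.shape == (2, 3)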
