Skip to content

Commit f243c42

Browse files
authored
Implemented constant op in cpp. (#2419)
For detailed info, check out: #2378
1 parent 108eee7 commit f243c42

File tree

8 files changed

+104
-44
lines changed

8 files changed

+104
-44
lines changed

forge/csrc/graph_lib/node_types.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -497,10 +497,10 @@ void TaggedNode::add_tags(const TagHints &other_tags) { this->hints.insert(other
497497

498498
const TagHints &TaggedNode::get_tags() const { return this->hints; }
499499

500-
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////
501-
// Calculations. This is temporary implementation in ops transition period. It will be deleted once all ops are //
502-
// migrated from python to cpp.
503-
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////
500+
/**
501+
* Calculations. This is temporary implementation in ops transition period. It will be deleted once all ops are
502+
* migrated from python to cpp.
503+
*/
504504

505505
at::Tensor OpType::eval(const std::vector<at::Tensor> &tensors) const { return new_op_.eval(*this, tensors); }
506506

forge/csrc/ops/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_library(ops
22
STATIC
33
op.cpp
44
op_abs.cpp
5+
op_constant.cpp
56
python_bindings.cpp)
67

78
target_link_libraries(ops PUBLIC coverage_config)

forge/csrc/ops/op.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ at::Tensor Op::eval(const graphlib::OpType &old_op_type, const std::vector<at::T
399399
switch (type_)
400400
{
401401
case OpType::Abs: return abs_eval(tensors);
402+
case OpType::Constant: return constant_eval(tensors);
402403
default: return base_eval(old_op_type, tensors);
403404
}
404405
}
@@ -409,6 +410,7 @@ std::tuple<graphlib::Shape, std::vector<graphlib::DimBroadcast>> Op::shape(
409410
switch (type_)
410411
{
411412
case OpType::Abs: return abs_shape(inputs);
413+
case OpType::Constant: return constant_shape(inputs);
412414
default: return base_shape(old_op_type, inputs);
413415
}
414416
}
@@ -424,6 +426,7 @@ tt::graphlib::NodeContext Op::backward(
424426
switch (type_)
425427
{
426428
case OpType::Abs: return abs_backward(context, operand, inputs, output, gradient);
429+
case OpType::Constant: return constant_backward(context, operand, inputs, output, gradient);
427430
default: return base_backward(old_op_type, context, operand, inputs, output, gradient);
428431
}
429432
}
@@ -440,6 +443,7 @@ void Op::decompose(
440443
switch (type_)
441444
{
442445
case OpType::Abs: return;
446+
case OpType::Constant: return;
443447
default: return base_decompose(old_op_type, dispatch, dc, inputs);
444448
}
445449
}
@@ -450,6 +454,7 @@ long Op::initial_flops_estimate(
450454
switch (type_)
451455
{
452456
case OpType::Abs: return abs_initial_flops_estimate(inputs);
457+
case OpType::Constant: return 0;
453458
default: return base_initial_flops_estimate(old_op_type, inputs);
454459
}
455460
}
@@ -459,6 +464,7 @@ bool Op::is_tm(const graphlib::OpType &old_op_type) const
459464
switch (type_)
460465
{
461466
case OpType::Abs: return false;
467+
case OpType::Constant: return false;
462468
default: return base_is_tm(old_op_type);
463469
}
464470
}
@@ -468,6 +474,7 @@ bool Op::is_eltwise(const graphlib::OpType &old_op_type) const
468474
switch (type_)
469475
{
470476
case OpType::Abs: return true;
477+
case OpType::Constant: return false;
471478
default: return base_is_eltwise(old_op_type);
472479
}
473480
}
@@ -477,6 +484,7 @@ bool Op::is_eltwise_unary(const graphlib::OpType &old_op_type) const
477484
switch (type_)
478485
{
479486
case OpType::Abs: return true;
487+
case OpType::Constant: return false;
480488
default: return base_is_eltwise_unary(old_op_type);
481489
}
482490
}
@@ -486,6 +494,7 @@ bool Op::is_eltwise_binary(const graphlib::OpType &old_op_type) const
486494
switch (type_)
487495
{
488496
case OpType::Abs: return false;
497+
case OpType::Constant: return false;
489498
default: return base_is_eltwise_binary(old_op_type);
490499
}
491500
}
@@ -494,6 +503,7 @@ bool Op::is_eltwise_nary(const graphlib::OpType &old_op_type) const
494503
switch (type_)
495504
{
496505
case OpType::Abs: return false;
506+
case OpType::Constant: return false;
497507
default: return base_is_eltwise_nary(old_op_type);
498508
}
499509
}

forge/csrc/ops/op.hpp

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,12 @@ class Op
202202

203203
const std::string &as_string() const;
204204

205-
/* ------------------------------------------------------------*
206-
* Calculations segment. Derived classes must implement these. *
207-
* ------------------------------------------------------------*/
205+
/* ----------------------------------------------------*
206+
* Calculations segment. All ops must implement these. *
207+
* ----------------------------------------------------*/
208208

209209
at::Tensor eval(const graphlib::OpType &old_op_type, const std::vector<at::Tensor> &tensors) const;
210+
210211
std::tuple<graphlib::Shape, std::vector<graphlib::DimBroadcastTrampoline>> shape(
211212
const graphlib::OpType &old_op_type, const std::vector<std::vector<std::uint32_t>> &inputs) const;
212213

@@ -218,6 +219,16 @@ class Op
218219
const tt::graphlib::NodeContext &output,
219220
const tt::graphlib::NodeContext &gradient) const;
220221

222+
bool is_tm(const graphlib::OpType &old_op_type) const;
223+
bool is_eltwise(const graphlib::OpType &old_op_type) const;
224+
bool is_eltwise_unary(const graphlib::OpType &old_op_type) const;
225+
bool is_eltwise_binary(const graphlib::OpType &old_op_type) const;
226+
bool is_eltwise_nary(const graphlib::OpType &old_op_type) const;
227+
228+
/* --------------------------*
229+
* Optional implementations. *
230+
* --------------------------*/
231+
221232
void decompose(
222233
const graphlib::OpType &old_op_type,
223234
const char *dispatch,
@@ -227,18 +238,13 @@ class Op
227238
long initial_flops_estimate(
228239
const graphlib::OpType &old_op_type, const std::vector<std::vector<std::uint32_t>> &inputs) const;
229240

230-
bool is_tm(const graphlib::OpType &old_op_type) const;
231-
bool is_eltwise(const graphlib::OpType &old_op_type) const;
232-
bool is_eltwise_unary(const graphlib::OpType &old_op_type) const;
233-
bool is_eltwise_binary(const graphlib::OpType &old_op_type) const;
234-
bool is_eltwise_nary(const graphlib::OpType &old_op_type) const;
235-
236241
private:
237242
/* ------------------------------------------------------------*
238243
* Base - common for all ops that are not yet migrated to cpp. *
239244
* ------------------------------------------------------------*/
240245

241246
at::Tensor base_eval(const graphlib::OpType &old_op_type, const std::vector<at::Tensor> &tensors) const;
247+
242248
std::tuple<graphlib::Shape, std::vector<graphlib::DimBroadcastTrampoline>> base_shape(
243249
const graphlib::OpType &old_op_type, const std::vector<std::vector<std::uint32_t>> &inputs) const;
244250

@@ -250,6 +256,12 @@ class Op
250256
const tt::graphlib::NodeContext &output,
251257
const tt::graphlib::NodeContext &gradient) const;
252258

259+
bool base_is_tm(const graphlib::OpType &old_op_type) const;
260+
bool base_is_eltwise(const graphlib::OpType &old_op_type) const;
261+
bool base_is_eltwise_unary(const graphlib::OpType &old_op_type) const;
262+
bool base_is_eltwise_binary(const graphlib::OpType &old_op_type) const;
263+
bool base_is_eltwise_nary(const graphlib::OpType &old_op_type) const;
264+
253265
void base_decompose(
254266
const graphlib::OpType &old_op_type,
255267
const char *dispatch,
@@ -259,12 +271,6 @@ class Op
259271
long base_initial_flops_estimate(
260272
const graphlib::OpType &old_op_type, const std::vector<std::vector<std::uint32_t>> &inputs) const;
261273

262-
bool base_is_tm(const graphlib::OpType &old_op_type) const;
263-
bool base_is_eltwise(const graphlib::OpType &old_op_type) const;
264-
bool base_is_eltwise_unary(const graphlib::OpType &old_op_type) const;
265-
bool base_is_eltwise_binary(const graphlib::OpType &old_op_type) const;
266-
bool base_is_eltwise_nary(const graphlib::OpType &old_op_type) const;
267-
268274
/* -----------------------------*
269275
* Ops specific implementation. *
270276
* -----------------------------*/
@@ -274,6 +280,7 @@ class Op
274280
* -------------*/
275281

276282
at::Tensor abs_eval(const std::vector<at::Tensor> &tensors) const;
283+
277284
std::tuple<graphlib::Shape, std::vector<graphlib::DimBroadcastTrampoline>> abs_shape(
278285
const std::vector<std::vector<std::uint32_t>> &inputs) const;
279286

@@ -286,6 +293,22 @@ class Op
286293

287294
long abs_initial_flops_estimate(const std::vector<std::vector<std::uint32_t>> &inputs) const;
288295

296+
/* ------------------*
297+
* OpType::Constant. *
298+
* ------------------*/
299+
300+
at::Tensor constant_eval(const std::vector<at::Tensor> &tensors) const;
301+
302+
std::tuple<graphlib::Shape, std::vector<graphlib::DimBroadcastTrampoline>> constant_shape(
303+
const std::vector<std::vector<std::uint32_t>> &inputs) const;
304+
305+
tt::graphlib::NodeContext constant_backward(
306+
tt::autograd::autograd_context &context,
307+
int operand,
308+
const std::vector<tt::graphlib::NodeContext> &inputs,
309+
const tt::graphlib::NodeContext &output,
310+
const tt::graphlib::NodeContext &gradient) const;
311+
289312
private:
290313
OpType type_;
291314
Attrs attrs_;

forge/csrc/ops/op_constant.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
#include <ATen/ops/zeros.h>
6+
7+
#include "autograd/autograd.hpp"
8+
#include "graph_lib/shape.hpp"
9+
#include "op.hpp"
10+
#include "torch/extension.h"
11+
#include "torch/torch.h"
12+
#include "utils/assert.hpp"
13+
14+
namespace tt
15+
{
16+
namespace ops
17+
{
18+
at::Tensor Op::constant_eval(const std::vector<at::Tensor> &tensors) const
19+
{
20+
TT_DBG_ASSERT(type_ == OpType::Constant, "Wrong op type.");
21+
TT_DBG_ASSERT(tensors.size() == 0, "Constant eval should not have any operands");
22+
TT_DBG_ASSERT(attrs().size() == 1, "Constant eval should contain 1 attr.");
23+
24+
return torch::tensor({attr_as<float>("c")});
25+
}
26+
27+
std::tuple<graphlib::Shape, std::vector<graphlib::DimBroadcast>> Op::constant_shape(
28+
const std::vector<std::vector<std::uint32_t>> &in_shapes) const
29+
{
30+
TT_DBG_ASSERT(type_ == OpType::Constant, "Wrong op type.");
31+
TT_DBG_ASSERT(in_shapes.size() == 0, "Constant should not have any operands");
32+
33+
return std::make_tuple(graphlib::Shape::create({1}), std::vector<graphlib::DimBroadcast>{});
34+
}
35+
36+
tt::graphlib::NodeContext Op::constant_backward(
37+
tt::autograd::autograd_context &context,
38+
int operand,
39+
const std::vector<tt::graphlib::NodeContext> &inputs,
40+
const tt::graphlib::NodeContext &output,
41+
const tt::graphlib::NodeContext &gradient) const
42+
{
43+
TT_DBG_ASSERT(type_ == OpType::Constant, "Wrong op type.");
44+
TT_THROW("OpType::Constant does not have backward.");
45+
__builtin_unreachable();
46+
}
47+
48+
} // namespace ops
49+
} // namespace tt

forge/forge/compile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1352,7 +1352,7 @@ def generate_graph(
13521352
continue
13531353

13541354
elif tensor.src_op.op_type == "constant":
1355-
constant_value = tensor.src_op.attrs[0]
1355+
constant_value = tensor.src_op.named_attrs["c"]
13561356
constant = create_constant_input(
13571357
graph,
13581358
"constant_" + str(port_index) + "_" + graph.get_node_name(output),

forge/forge/op/constant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,4 @@ def Constant(name: str, *, constant: float) -> Tensor:
2222
Tensor
2323
Forge tensor
2424
"""
25-
return op("constant", name, attrs=(constant,)).get_tensor()
25+
return op("constant", name, **{"c": constant}).get_tensor()

forge/forge/op/eval/forge/constant.py

Lines changed: 0 additions & 23 deletions
This file was deleted.

0 commit comments

Comments
 (0)