Skip to content

Commit 5bd202e

Browse files
Fix shape mismatch caused by broadcast-to-TMS conversion when inserting on Cast edge
1 parent 9bfd4a7 commit 5bd202e

File tree

3 files changed

+32
-5
lines changed

3 files changed

+32
-5
lines changed

forge/csrc/passes/lower_to_mlir.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,10 @@ class AttributeMapper
170170

171171
add_op_mapping("upsample2d", "scale_factor", AttributeRemap(std::nullopt, TargetType::DenseI32ArrayAttr));
172172

173+
// Broadcast
174+
add_op_mapping(
175+
"broadcast", "broadcast_dimensions", AttributeRemap(std::nullopt, TargetType::DenseI64ArrayAttr));
176+
173177
// Add more default mappings here
174178
}
175179
};
@@ -811,6 +815,7 @@ class MLIRGenerator
811815
lowering_handler_map["avg_pool2d"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::AvgPool2dOp>;
812816
lowering_handler_map["batchnorm"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::BatchNormOp>;
813817
lowering_handler_map["bitwise_and"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::BitwiseAndOp>;
818+
lowering_handler_map["broadcast"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::BroadcastOp>;
814819
lowering_handler_map["cast"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::TypecastOp>;
815820
lowering_handler_map["clip"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ClampScalarOp>;
816821
lowering_handler_map["concatenate"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ConcatOp>;

forge/csrc/passes/pre_lowering_passes.cpp

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,37 @@ using EdgeType = graphlib::EdgeType;
1818

1919
void convert_broadcast_ops_to_tms(Graph *graph)
2020
{
21-
std::vector<Node *> broadcast_ops = graph->nodes(
22-
[](Node *node) -> bool
21+
// Determines whether any consumer node is eltwise-unary, in which case the broadcast can propagate.
22+
// If no consumer is unary, the broadcast is lowered to TMs and inserted on the consumer edges.
23+
24+
std::vector<graphlib::Node *> broadcast_ops = graph->nodes(
25+
[](graphlib::Node *node) -> bool
2326
{
2427
graphlib::OpNode *op = dynamic_cast<graphlib::OpNode *>(node);
2528
return op and op->op_type() == ops::OpType::Broadcast;
2629
});
2730

2831
for (Node *node : broadcast_ops)
2932
{
30-
graphlib::OpNode *op = node->as<graphlib::OpNode>();
31-
ops::Op op_type = op->op();
33+
graphlib::OpNode *op = dynamic_cast<graphlib::OpNode *>(node);
34+
if (not op)
35+
continue;
36+
37+
graphlib::OpType op_type = op->op_type();
38+
bool has_unary_consumer = false;
39+
for (graphlib::Edge user_edge : graph->user_data_edges(node))
40+
{
41+
graphlib::OpNode *consumer_op =
42+
dynamic_cast<graphlib::OpNode *>(graph->node_by_id(user_edge.consumer_node_id));
43+
if (consumer_op && consumer_op->is_eltwise_unary())
44+
{
45+
has_unary_consumer = true;
46+
break;
47+
}
48+
}
49+
if (has_unary_consumer)
50+
continue;
51+
3252
constexpr bool remove_node = true;
3353
graphlib::bypass_node(
3454
graph,

forge/forge/op/tm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,8 +387,10 @@ def Broadcast(name: str, operandA: Tensor, dim: int, shape: int) -> Tensor:
387387
Tensor
388388
Forge tensor
389389
"""
390+
broadcast_dimensions = [1] * len(operandA.shape.dims)
391+
broadcast_dimensions[dim] = shape
390392

391-
return op(OpType.Broadcast, name, operandA, dim=dim, size=shape).get_tensor()
393+
return op("broadcast", name, operandA, broadcast_dimensions=broadcast_dimensions, dim=dim, size=shape).get_tensor()
392394

393395

394396
def Repeat(name: str, operandA: Tensor, repeats: List[int]) -> Tensor:

0 commit comments

Comments
 (0)