Skip to content

Commit 6d74871

Browse files
Fix shape mismatch caused by broadcast-to-TMs conversion when inserting on Cast edge
1 parent d4f1dec commit 6d74871

File tree

3 files changed

+32
-5
lines changed

3 files changed

+32
-5
lines changed

forge/csrc/passes/lower_to_mlir.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,10 @@ class AttributeMapper
167167
add_op_mapping("update_cache", "batch_offset", AttributeRemap(std::nullopt, TargetType::I32Attr));
168168
add_op_mapping("fill_cache", "batch_offset", AttributeRemap(std::nullopt, TargetType::I32Attr));
169169

170+
// Broadcast
171+
add_op_mapping(
172+
"broadcast", "broadcast_dimensions", AttributeRemap(std::nullopt, TargetType::DenseI64ArrayAttr));
173+
170174
// Add more default mappings here
171175
}
172176
};
@@ -801,6 +805,7 @@ class MLIRGenerator
801805
lowering_handler_map["avg_pool2d"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::AvgPool2dOp>;
802806
lowering_handler_map["batchnorm"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::BatchNormOp>;
803807
lowering_handler_map["bitwise_and"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::BitwiseAndOp>;
808+
lowering_handler_map["broadcast"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::BroadcastOp>;
804809
lowering_handler_map["cast"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::TypecastOp>;
805810
lowering_handler_map["clip"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ClampScalarOp>;
806811
lowering_handler_map["concatenate"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ConcatOp>;

forge/csrc/passes/pre_lowering_passes.cpp

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,39 @@ using NodeType = graphlib::NodeType;
1616
using Edge = graphlib::Edge;
1717
using EdgeType = graphlib::EdgeType;
1818

19-
void convert_broadcast_ops_to_tms(Graph *graph)
19+
void convert_broadcast_ops_to_tms(graphlib::Graph *graph)
2020
{
21-
std::vector<Node *> broadcast_ops = graph->nodes(
22-
[](Node *node) -> bool
21+
// A broadcast op is left in the graph when any of its consumers is an eltwise unary op.
22+
// Otherwise, the broadcast is lowered to TMs and inserted on the consumer edge.
23+
24+
std::vector<graphlib::Node *> broadcast_ops = graph->nodes(
25+
[](graphlib::Node *node) -> bool
2326
{
2427
graphlib::OpNode *op = dynamic_cast<graphlib::OpNode *>(node);
2528
return op and op->new_op_type() == ops::OpType::Broadcast;
2629
});
2730

2831
for (Node *node : broadcast_ops)
2932
{
30-
graphlib::OpNode *op = node->as<graphlib::OpNode>();
33+
graphlib::OpNode *op = dynamic_cast<graphlib::OpNode *>(node);
34+
if (not op)
35+
continue;
36+
3137
graphlib::OpType op_type = op->op_type();
38+
bool has_unary_consumer = false;
39+
for (graphlib::Edge user_edge : graph->user_data_edges(node))
40+
{
41+
graphlib::OpNode *consumer_op =
42+
dynamic_cast<graphlib::OpNode *>(graph->node_by_id(user_edge.consumer_node_id));
43+
if (consumer_op && consumer_op->is_eltwise_unary())
44+
{
45+
has_unary_consumer = true;
46+
break;
47+
}
48+
}
49+
if (has_unary_consumer)
50+
continue;
51+
3252
constexpr bool remove_node = true;
3353
graphlib::bypass_node(
3454
graph,

forge/forge/op/tm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,8 +381,10 @@ def Broadcast(name: str, operandA: Tensor, dim: int, shape: int) -> Tensor:
381381
Tensor
382382
Forge tensor
383383
"""
384+
broadcast_dimensions = [1] * len(operandA.shape.dims)
385+
broadcast_dimensions[dim] = shape
384386

385-
return op("broadcast", name, operandA, dim=dim, size=shape).get_tensor()
387+
return op("broadcast", name, operandA, broadcast_dimensions=broadcast_dimensions, dim=dim, size=shape).get_tensor()
386388

387389

388390
def Repeat(name: str, operandA: Tensor, repeats: List[int]) -> Tensor:

0 commit comments

Comments
 (0)