135 changes: 122 additions & 13 deletions forge/csrc/passes/lower_to_mlir.cpp
@@ -123,18 +123,12 @@ class AttributeMapper
add_op_mapping("avg_pool2d", "dilation", AttributeRemap(std::nullopt, TargetType::DenseI32ArrayAttr));

// conv2d_transpose
add_op_mapping("conv2d_transpose", "dilation", AttributeRemap(std::nullopt, TargetType::DenseI32ArrayAttr));
add_op_mapping("conv2d_transpose", "groups", AttributeRemap(std::nullopt, TargetType::I32Attr));
add_op_mapping(
"conv2d_transpose", "output_padding", AttributeRemap(std::nullopt, TargetType::DenseI32ArrayAttr));
add_op_mapping("conv2d_transpose", "padding", AttributeRemap(std::nullopt, TargetType::DenseI32ArrayAttr));
add_op_mapping("conv2d_transpose", "stride", AttributeRemap(std::nullopt, TargetType::DenseI32ArrayAttr));
add_op_mapping("conv2d_transpose", "groups", AttributeRemap("feature_groups_count", TargetType::I32Attr));
add_op_mapping("conv2d_transpose", "stride", AttributeRemap("window_strides", TargetType::DenseI32ArrayAttr));

// conv2d
add_op_mapping("conv2d", "dilation", AttributeRemap(std::nullopt, TargetType::DenseI32ArrayAttr));
add_op_mapping("conv2d", "groups", AttributeRemap(std::nullopt, TargetType::I32Attr));
add_op_mapping("conv2d", "padding", AttributeRemap(std::nullopt, TargetType::DenseI32ArrayAttr));
add_op_mapping("conv2d", "stride", AttributeRemap(std::nullopt, TargetType::DenseI32ArrayAttr));
add_op_mapping("conv2d", "groups", AttributeRemap("feature_groups_count", TargetType::I32Attr));
add_op_mapping("conv2d", "stride", AttributeRemap("window_strides", TargetType::DenseI32ArrayAttr));

// cumsum
add_op_mapping("cumsum", "dim", AttributeRemap(std::nullopt, TargetType::I64Attr));
@@ -579,6 +573,122 @@ class MLIRGenerator
return op.getOperation()->getResult(0);
}

mlir::Value emit_mlir_ttforge_convolution_op(tt::graphlib::Graph *graph, tt::graphlib::OpNode *op_node)
{
// Get Return Type
llvm::SmallVector<mlir::Type> return_types = get_mlir_type_range(op_node);

// Get Operands (Input, Weight, Bias)
auto all_operands = graph->data_operands(op_node);
TT_ASSERT(
all_operands.size() == 2 || all_operands.size() == 3,
"ConvolutionOp must have 2 or 3 operands (input, weight, [bias])");

llvm::SmallVector<mlir::Value> operands;
operands.push_back(symbolTable_.at(all_operands[0]->name()).first); // Input
operands.push_back(symbolTable_.at(all_operands[1]->name()).first); // Weight

// Handle optional bias
if (all_operands.size() == 3)
{
operands.push_back(symbolTable_.at(all_operands[2]->name()).first);
}
else
{
// Pass ttir.none for missing bias
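// (this lowering assumes the op's operand list is a fixed (input, weight, bias)
// triple, so an explicit none value stands in for the absent bias)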
operands.push_back(builder_.create<mlir::tt::ttir::NoneOp>(
get_tt_forge_operation_location(graph, op_node), builder_.getNoneType()));
}

// Build Attribute List
llvm::SmallVector<mlir::NamedAttribute> mlir_attributes;
const auto &op_attrs = op_node->op().attrs();

// Check if this is conv2d_transpose
bool is_transpose = op_node->op_as_string() == "conv2d_transpose";
mlir_attributes.push_back(builder_.getNamedAttr("is_transpose", builder_.getBoolAttr(is_transpose)));

// Create ConvolutionLayoutAttr
// Based on the issue description, the incoming Conv2dOp is channels-last (NHWC), with a matching kernel layout (HWIO).
auto layout_attr = mlir::tt::ttcore::ConvolutionLayoutAttr::get(
builder_.getContext(),
/*inputBatchDimension=*/builder_.getI32IntegerAttr(0),
/*inputFeatureDimension=*/builder_.getI32IntegerAttr(3),
/*inputSpatialDimensions=*/builder_.getDenseI32ArrayAttr({1, 2}),
/*kernelOutputFeatures=*/builder_.getI32IntegerAttr(3),
/*kernelInputFeatures=*/builder_.getI32IntegerAttr(2),
/*kernelSpatialDimensions=*/builder_.getDenseI32ArrayAttr({0, 1}),
/*outputBatchDimension=*/builder_.getI32IntegerAttr(0),
/*outputFeatureDimension=*/builder_.getI32IntegerAttr(3),
/*outputSpatialDimensions=*/builder_.getDenseI32ArrayAttr({1, 2}));
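// For example, an NHWC input of shape [1, 56, 56, 64] has batch dim 0, spatial dims
// {1, 2} (H, W) and feature dim 3, while an HWIO kernel of shape [3, 3, 64, 16] has
// spatial dims {0, 1}, input-feature dim 2 and output-feature dim 3, matching the
// indices above.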
mlir_attributes.push_back(builder_.getNamedAttr("layout", layout_attr));

// Padding, dilation, and (for transpose) output_padding are converted manually
// below, so exclude them from the automatic mapping.
std::set<std::string> handled_attrs = {"padding", "dilation"};
if (is_transpose)
{
handled_attrs.insert("output_padding");
}

// Automatically map the remaining attributes; groups and stride are renamed by the
// AttributeMapper (to feature_groups_count and window_strides).
for (const auto &[name, value] : op_attrs)
{
if (handled_attrs.count(name))
continue;

auto [mapped_name, target_type] = attr_mapper_.get_mapped_name_and_type(op_node->op_as_string(), name);
mlir_attributes.push_back(builder_.getNamedAttr(mapped_name, convert_to_mlir_attribute(value, target_type)));
}

// Manual Attribute: Padding
auto padding_it = op_attrs.find("padding");
TT_ASSERT(padding_it != op_attrs.end(), "ConvOp must have 'padding' attribute");
auto padding_vec = std::get<std::vector<int>>(padding_it->second);
std::vector<int32_t> mlir_padding;
if (padding_vec.size() == 2)
{
// Expand [pad_h, pad_w] to [top, bottom, left, right]
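// e.g. padding {1, 2} becomes {1, 1, 2, 2}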
mlir_padding = {padding_vec[0], padding_vec[0], padding_vec[1], padding_vec[1]};
}
else if (padding_vec.size() == 4)
{
mlir_padding.assign(padding_vec.begin(), padding_vec.end());
}
else
{
TT_ASSERT(false, "Unsupported padding vector size, expected 2 or 4");
}
mlir_attributes.push_back(builder_.getNamedAttr("padding", builder_.getDenseI32ArrayAttr(mlir_padding)));

// Manual Attribute: Dilation
auto dilation_it = op_attrs.find("dilation");
TT_ASSERT(dilation_it != op_attrs.end(), "ConvOp must have 'dilation' attribute");
auto dilation_vec = std::get<std::vector<int>>(dilation_it->second);
mlir_attributes.push_back(builder_.getNamedAttr(
"dilation", builder_.getDenseI32ArrayAttr(std::vector<int32_t>(dilation_vec.begin(), dilation_vec.end()))));

// Manual Attribute: Output Padding (for transpose only)
if (is_transpose)
{
auto out_pad_it = op_attrs.find("output_padding");
TT_ASSERT(out_pad_it != op_attrs.end(), "ConvTransposeOp must have 'output_padding' attribute");
auto out_pad_vec = std::get<std::vector<int>>(out_pad_it->second);
mlir_attributes.push_back(builder_.getNamedAttr(
"output_padding",
builder_.getDenseI32ArrayAttr(std::vector<int32_t>(out_pad_vec.begin(), out_pad_vec.end()))));
}
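// At this point mlir_attributes holds is_transpose, layout, the auto-mapped
// window_strides / feature_groups_count, padding as {top, bottom, left, right},
// dilation, and (for transpose only) output_padding.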

// Create the Operation
auto op = builder_.create<mlir::tt::ttir::ConvolutionOp>(
get_tt_forge_operation_location(graph, op_node),
mlir::TypeRange(return_types),
mlir::ValueRange(operands),
mlir_attributes);

return op.getOperation()->getResult(0);
}

// Get the TT-MLIR type for a TTForge operation.
llvm::SmallVector<mlir::Type> get_mlir_type_range(tt::graphlib::OpNode *op_node)
{
@@ -815,7 +925,7 @@ class MLIRGenerator
lowering_handler_map["clip"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ClampScalarOp>;
lowering_handler_map["concatenate"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ConcatOp>;
lowering_handler_map["constant_pad"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::PadOp>;
lowering_handler_map["conv2d"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::Conv2dOp>;
lowering_handler_map["conv2d"] = &MLIRGenerator::emit_mlir_ttforge_convolution_op;
lowering_handler_map["cosine"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::CosOp>;
lowering_handler_map["cumsum"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::CumSumOp>;
lowering_handler_map["divide"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::DivOp>;
@@ -850,8 +960,7 @@ class MLIRGenerator
lowering_handler_map["repeat_interleave"] =
&MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::RepeatInterleaveOp>;
lowering_handler_map["repeat"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::RepeatOp>;
lowering_handler_map["conv2d_transpose"] =
&MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ConvTranspose2dOp>;
lowering_handler_map["conv2d_transpose"] = &MLIRGenerator::emit_mlir_ttforge_convolution_op;
lowering_handler_map["reshape"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ReshapeOp>;
lowering_handler_map["select"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::IndexSelectOp>;
lowering_handler_map["sigmoid"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SigmoidOp>;
87 changes: 87 additions & 0 deletions forge/test/mlir/test_ops_tf.py
@@ -203,3 +203,90 @@
compiled_model = forge.compile(framework_model, sample_inputs=inputs)

verify(inputs, framework_model, compiled_model)

@pytest.mark.parametrize(
"batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w",
(
(1, 16, 64, 56, 56, 4, 4, 1, 1),  # Flipped output/input channels
(1, 64, 64, 56, 56, 3, 3, 1, 1),
(1, 128, 128, 28, 28, 3, 3, 2, 2),  # Stride > 1
(1, 128, 128, 14, 14, 3, 3, 1, 1),
(1, 256, 256, 14, 14, 3, 3, 2, 2),  # Stride > 1
(1, 256, 256, 7, 7, 3, 3, 1, 1),
(1, 64, 64, 8, 8, 3, 3, 1, 1),
(1, 64, 64, 16, 16, 3, 3, 1, 1),
(1, 256, 256, 7, 7, 3, 3, 1, 1),
(1, 64, 256, 28, 28, 1, 1, 2, 2),  # Flipped output/input channels, Stride > 1
),
)
@pytest.mark.parametrize(
"weights_dtype",
[
pytest.param(
tf.bfloat16, marks=pytest.mark.xfail(reason="dtypes are not properly lowered from tensorflow #281")
),
tf.float32,
],
)
@pytest.mark.parametrize(
"activations_dtype",
[
tf.bfloat16,
tf.float32,
],
)
@pytest.mark.parametrize("has_bias", [False, True], ids=["no_bias", "with_bias"])
@pytest.mark.push
def test_conv2d_transpose(
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter_height,
filter_width,
stride_h,
stride_w,
activations_dtype,
weights_dtype,
has_bias,
):
tf.random.set_seed(0)
if (
activations_dtype == tf.float32
and weights_dtype == tf.float32
and input_height == input_width == 28
and input_channels == output_channels == 256
):
pytest.skip("Circular buffer grows beyond maximum L1 size.")

# Padding logic for Conv2DTranspose is different, but 'same' and 'valid' are still supported.
# Same logic as test_conv2d, to keep parity between the two tests.
padding = "same" if stride_h == stride_w == 1 and filter_height % 2 == 1 else "valid"

class Conv2dTranspose(tf.keras.Model):
def __init__(self):
super().__init__()
self.conv2d_transpose = tf.keras.layers.Conv2DTranspose(
output_channels,  # the `filters` argument specifies output channels
(filter_height, filter_width),
strides=[stride_h, stride_w],
padding=padding,
data_format="channels_last",
dtype=weights_dtype,
use_bias=has_bias,
# Note: output_padding is not exercised here, but could be added.
# It defaults to None, which is fine for most cases.
)

def call(self, x):
return self.conv2d_transpose(x)

# Input shape uses input_channels
inputs = [tf.random.uniform((batch_size, input_height, input_width, input_channels), dtype=activations_dtype)]

framework_model = Conv2dTranspose()

compiled_model = forge.compile(framework_model, sample_inputs=inputs)

verify(inputs, framework_model, compiled_model)
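# Run just these cases with, e.g.:
#   pytest forge/test/mlir/test_ops_tf.py -k test_conv2d_transpose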
