Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion shardy/dialect/mpmd/ir/fragment_execution_rules.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ bool ParseFragmentOrigin(llvm::cl::Option& opt, llvm::StringRef& arg,
if (!arg.consume_front("(")) {
return false;
}
if (!arg.consumeInteger(10, origin.transpose_count)) {
if (arg.consumeInteger(10, origin.transpose_count)) {
return opt.error("Expected a transpose count");
}
if (!arg.consume_front(")")) {
Expand Down Expand Up @@ -187,6 +187,7 @@ void SetFragmentInfo(FragmentOp fragment, const FragmentInfo& metadata,
namespace llvm::cl {

using ::mlir::mpmd::FragmentMergeRule;
using ::mlir::mpmd::FragmentScheduleRule;

template class basic_parser<FragmentMergeRule>;

Expand Down Expand Up @@ -228,4 +229,39 @@ void parser<FragmentMergeRule>::printOptionDiff(const Option& opt,

void parser<FragmentMergeRule>::anchor() {}

template class basic_parser<FragmentScheduleRule>;

// Parses a fragment schedule rule string of the form
// "FragmentScheduleRule(ordered_fragments=[<fragment1>-><fragment2>...])"
// <fragment>s are FragmentInfo strings.
bool parser<FragmentScheduleRule>::parse(Option& opt, StringRef, StringRef arg,
FragmentScheduleRule& value) {
if (!arg.consume_front(FragmentScheduleRule::kFragmentScheduleRulePrefix)) {
return opt.error("Expected '" +
FragmentScheduleRule::kFragmentScheduleRulePrefix + "'");
}
while (!arg.starts_with("]")) {
if (mlir::mpmd::ParseFragmentInfo(opt, arg,
value.ordered_fragments.emplace_back())) {
return true; // opt.error was called inside ParseFragmentInfo
}
if (!arg.consume_front("->")) {
break;
}
}
if (!arg.consume_front("])")) {
return opt.error("Expected '])'");
}
return false;
}

void parser<FragmentScheduleRule>::printOptionDiff(
const Option& opt, const FragmentScheduleRule& value,
const OptVal& defaultValue, size_t globalWidth) const {
printOptionName(opt, globalWidth);
outs() << "= " << value << "\n";
}

void parser<FragmentScheduleRule>::anchor() {}

} // namespace llvm::cl
38 changes: 38 additions & 0 deletions shardy/dialect/mpmd/ir/fragment_execution_rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/PatternMatch.h"
Expand Down Expand Up @@ -142,6 +143,27 @@ struct FragmentMergeRule {

using FragmentMergeRules = std::vector<FragmentMergeRule>;

// Describes a rule for scheduling fragments. A rule is defined by an ordered
// sequence of fragments. This ordering dictates the execution order of the
// fragments on a given mesh.
struct FragmentScheduleRule {
// The sequence of fragments to be scheduled. The order of fragments in this
// vector defines their execution order.
std::vector<FragmentInfo> ordered_fragments;

static constexpr llvm::StringRef kFragmentScheduleRulePrefix =
"FragmentScheduleRule(ordered_fragments=[";

friend llvm::raw_ostream& operator<<(llvm::raw_ostream& os,
const FragmentScheduleRule& rule) {
os << kFragmentScheduleRulePrefix;
llvm::interleave(rule.ordered_fragments, os, "->");
return os << "])";
}
};

using FragmentScheduleRules = std::vector<FragmentScheduleRule>;

// Returns the fragment info of a fragment op.
FragmentInfo GetFragmentInfo(FragmentOp fragment);

Expand Down Expand Up @@ -169,6 +191,22 @@ class parser<mlir::mpmd::FragmentMergeRule>
void anchor() override;
};

extern template class basic_parser<mlir::mpmd::FragmentScheduleRule>;

template <>
class parser<mlir::mpmd::FragmentScheduleRule>
: public basic_parser<mlir::mpmd::FragmentScheduleRule> {
public:
parser(Option& opt) : basic_parser(opt) {}
bool parse(Option& opt, StringRef argName, StringRef arg,
mlir::mpmd::FragmentScheduleRule& value);
StringRef getValueName() const override { return "fragment-schedule-rule"; }
void printOptionDiff(const Option& opt,
const mlir::mpmd::FragmentScheduleRule& value,
const OptVal& defaultValue, size_t globalWidth) const;
void anchor() override;
};

} // namespace llvm::cl

#endif // SHARDY_DIALECT_MPMD_IR_FRAGMENT_EXECUTION_RULES_H_
181 changes: 181 additions & 0 deletions shardy/dialect/mpmd/ir/fragment_execution_rules_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <vector>

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
Expand Down Expand Up @@ -69,6 +70,48 @@ FragmentMergeRule MakeFragmentMergeRule(
return {sources, target};
}

FragmentScheduleRule MakeFragmentScheduleRule(
const std::vector<FragmentInfo>& ordered_fragments) {
return {ordered_fragments};
}

// LLVM's command line classes (OptionCategory, opt) store StringRef arguments
// directly without copying the underlying string data. When these objects are
// created with temporary string literals in test functions, the backing strings
// go out of scope after the test completes, leaving dangling pointers in the
// static GlobalParser->RegisteredOptionCategories.
//
// The functions below use static storage to ensure string literals have static
// storage duration, avoiding the need for manual cleanup. The parser helper
// functions encapsulate opt/parser creation and provide a clean interface for
// tests without exposing StringRef lifetime concerns.

llvm::cl::OptionCategory& getTestOptionCategory() {
static llvm::cl::OptionCategory category("Test Options");
return category;
}

bool parseFragmentMergeRule(llvm::StringRef rule_str, FragmentMergeRule& rule) {
static llvm::cl::opt<FragmentMergeRule> rule_opt(
"fragment-merge-rule",
llvm::cl::desc("Fragment merge rule for testing parser functionality"),
llvm::cl::cat(getTestOptionCategory()));
static llvm::cl::parser<FragmentMergeRule> parser(rule_opt);

return parser.parse(rule_opt, "test-rule", rule_str, rule);
}

bool parseFragmentScheduleRule(llvm::StringRef rule_str,
FragmentScheduleRule& rule) {
static llvm::cl::opt<FragmentScheduleRule> rule_opt(
"fragment-schedule-rule",
llvm::cl::desc("Fragment schedule rule for testing parser functionality"),
llvm::cl::cat(getTestOptionCategory()));
static llvm::cl::parser<FragmentScheduleRule> parser(rule_opt);

return parser.parse(rule_opt, "test-rule", rule_str, rule);
}

TEST(GetFragmentInfoTest, GetFragmentInfo) {
const std::string kProgram = R"mlir(
!mesh_1_tensor_4_8_f32 = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>>
Expand Down Expand Up @@ -259,6 +302,144 @@ TEST(FragmentMergeRule, PrintFragmentMergeRule) {
"\"f1\"(123),\"f2\"(456)],stage=1,is_weight_gradient=false))"));
}

TEST(FragmentMergeRuleParser, ParseValidRule) {
FragmentMergeRule expected_rule = MakeFragmentMergeRule(
{MakeFragmentInfo({MakeFragmentOrigin("f1", 123)},
/*stage_id=*/1,
/*call_counter=*/std::nullopt,
/*is_weight_gradient=*/false),
MakeFragmentInfo({MakeFragmentOrigin("f2", 456)},
/*stage_id=*/1,
/*call_counter=*/std::nullopt,
/*is_weight_gradient=*/true)},
MakeFragmentInfo(
{MakeFragmentOrigin("f1", 123), MakeFragmentOrigin("f2", 456)},
/*stage_id=*/1,
/*call_counter=*/std::nullopt,
/*is_weight_gradient=*/false));
// We first construct the rule and print it to a string. Then we parse that
// string to ensure that the printed form of a rule is directly compatible
// with the format the parser expects.
std::string rule_str;
llvm::raw_string_ostream os(rule_str);
os << expected_rule;

FragmentMergeRule rule;
bool result = parseFragmentMergeRule(rule_str, rule);

EXPECT_FALSE(result);

ASSERT_EQ(rule.sources.size(), 2);
ExpectFragmentInfoEq(rule.sources[0], expected_rule.sources[0]);
ExpectFragmentInfoEq(rule.sources[1], expected_rule.sources[1]);
ExpectFragmentInfoEq(rule.target, expected_rule.target);
}

struct InvalidRuleTestParams {
std::string test_name;
std::string invalid_rule_str;
};

class FragmentMergeRuleParserInvalidSyntaxTest
: public ::testing::TestWithParam<InvalidRuleTestParams> {};

TEST_P(FragmentMergeRuleParserInvalidSyntaxTest, ParseInvalidRule) {
const auto& params = GetParam();
FragmentMergeRule rule;
bool result = parseFragmentMergeRule(params.invalid_rule_str, rule);

EXPECT_TRUE(result);
}

INSTANTIATE_TEST_SUITE_P(
FragmentMergeRuleParser, FragmentMergeRuleParserInvalidSyntaxTest,
testing::Values(
InvalidRuleTestParams{"MissingPrefix",
"sources=[FragmentInfo(origins=[\"f1\"(123)])]"},
InvalidRuleTestParams{
"MissingSources",
"FragmentMergeRule(target=FragmentInfo(origins=[\"f1\"(123)]))"},
InvalidRuleTestParams{"MissingTarget",
"FragmentMergeRule(sources=[FragmentInfo(origins="
"[\"f1\"(123)])])"}),
[](const testing::TestParamInfo<
FragmentMergeRuleParserInvalidSyntaxTest::ParamType>& info) {
return info.param.test_name;
});

TEST(FragmentScheduleRule, PrintFragmentScheduleRule) {
FragmentScheduleRule rule = MakeFragmentScheduleRule(
{MakeFragmentInfo({MakeFragmentOrigin("f1", 123)}, /*stage_id=*/1),
MakeFragmentInfo({MakeFragmentOrigin("f2", 456)}, /*stage_id=*/1)});
std::string str;
llvm::raw_string_ostream os(str);
os << rule;
EXPECT_THAT(str, Eq("FragmentScheduleRule(ordered_fragments=["
"FragmentInfo(origins=[\"f1\"(123)],stage=1,is_weight_"
"gradient=false)->"
"FragmentInfo(origins=[\"f2\"(456)],stage=1,is_weight_"
"gradient=false)])"));
}

TEST(FragmentScheduleRuleParser, ParseValidRule) {
FragmentScheduleRule expected_rule = MakeFragmentScheduleRule(
{MakeFragmentInfo({MakeFragmentOrigin("f1", 123)},
/*stage_id=*/1,
/*call_counter=*/std::nullopt,
/*is_weight_gradient=*/false),
MakeFragmentInfo({MakeFragmentOrigin("f2", 456)},
/*stage_id=*/1,
/*call_counter=*/std::nullopt,
/*is_weight_gradient=*/true)});
// We first construct the rule and print it to a string. Then we parse that
// string to ensure that the printed form of a rule is directly compatible
// with the format the parser expects.
std::string rule_str;
llvm::raw_string_ostream os(rule_str);
os << expected_rule;

FragmentScheduleRule rule;
bool result = parseFragmentScheduleRule(rule_str, rule);

EXPECT_FALSE(result);

ASSERT_EQ(rule.ordered_fragments.size(), 2);
ExpectFragmentInfoEq(rule.ordered_fragments[0],
expected_rule.ordered_fragments[0]);
ExpectFragmentInfoEq(rule.ordered_fragments[1],
expected_rule.ordered_fragments[1]);
}

class FragmentScheduleRuleParserInvalidSyntaxTest
: public ::testing::TestWithParam<InvalidRuleTestParams> {};

TEST_P(FragmentScheduleRuleParserInvalidSyntaxTest, ParseInvalidRule) {
const auto& params = GetParam();
FragmentScheduleRule rule;
bool result = parseFragmentScheduleRule(params.invalid_rule_str, rule);

EXPECT_TRUE(result);
}

INSTANTIATE_TEST_SUITE_P(
FragmentScheduleRuleParser, FragmentScheduleRuleParserInvalidSyntaxTest,
testing::Values(
InvalidRuleTestParams{"MissingPrefix",
"[FragmentInfo(origins=[\"f1\"(123)])->"
"FragmentInfo(origins=[\"f2\"(456)])])"},
InvalidRuleTestParams{
"MissingArrow",
"FragmentScheduleRule(ordered_fragments=[FragmentInfo(origins=["
"\"f1\"(123)]) FragmentInfo(origins=[\"f2\"(456)])])"},
InvalidRuleTestParams{
"MissingClosingBrackets",
"FragmentScheduleRule(ordered_fragments=[FragmentInfo(origins=["
"\"f1\"(123)])->FragmentInfo(origins=[\"f2\"(456)])"}),
[](const testing::TestParamInfo<
FragmentScheduleRuleParserInvalidSyntaxTest::ParamType>& info) {
return info.param.test_name;
});

TEST(FragmentInfoMapInfoTest, IsEqual) {
FragmentInfo info1 = MakeFragmentInfo({MakeFragmentOrigin("f1", 123)});
FragmentInfo info2 = MakeFragmentInfo({MakeFragmentOrigin("f1", 123)});
Expand Down
Loading