Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xla/hlo/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ xla_cc_test(
srcs = ["hlo_instruction_test.cc"],
deps = [
":hlo",
"//xla:literal_util",
"//xla:printer",
"//xla:shape_util",
"//xla:side_effect_util",
Expand Down
8 changes: 8 additions & 0 deletions xla/hlo/ir/hlo_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1685,6 +1685,14 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
operand);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAsyncUpdate(
const Shape& shape, absl::Span<HloInstruction* const> operands) {
HloInstruction* prev_async = operands[0];
return std::make_unique<HloAsyncInstruction>(
HloOpcode::kAsyncUpdate, shape, operands,
Cast<HloAsyncInstruction>(prev_async)->async_wrapped_opcode());
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAsyncDone(
const Shape& shape, HloInstruction* operand) {
return std::make_unique<HloAsyncInstruction>(HloOpcode::kAsyncDone, shape,
Expand Down
2 changes: 2 additions & 0 deletions xla/hlo/ir/hlo_instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,8 @@ class HloInstruction {
absl::string_view async_execution_thread = kMainExecutionThread);
static std::unique_ptr<HloInstruction> CreateAsyncUpdate(
const Shape& shape, HloInstruction* operand);
static std::unique_ptr<HloInstruction> CreateAsyncUpdate(
const Shape& shape, absl::Span<HloInstruction* const> operands);
static std::unique_ptr<HloInstruction> CreateAsyncDone(
const Shape& shape, HloInstruction* operand);

Expand Down
49 changes: 49 additions & 0 deletions xla/hlo/ir/hlo_instruction_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ limitations under the License.
#include "xla/hlo/ir/stack_frames.h"
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
#include "xla/hlo/transforms/simplifiers/hlo_dce.h"
#include "xla/literal_util.h"
#include "xla/printer.h"
#include "xla/service/hlo.pb.h"
#include "xla/shape.h"
Expand Down Expand Up @@ -399,6 +400,54 @@ ENTRY main {
TF_EXPECT_OK(module->schedule().Verify());
}

TEST_F(HloInstructionTest, CreateVariadicAsyncUpdate) {
constexpr absl::string_view kHlo = R"(
HloModule main

ENTRY main {
arg.0 = s32[] parameter(0)
call-start.0 = ((s32[]), s32[], s32[]) call-start(arg.0), to_apply={
arg.0 = s32[] parameter(0)
ROOT abs.0 = abs(arg.0)
}, async_execution_thread="thread"
ROOT call-done.0 = s32[] call-done(call-start.0)
})";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(kHlo));

HloComputation* entry = module->entry_computation();
HloInstruction* async_done = entry->root_instruction();
HloInstruction* async_start = async_done->async_chain_start();

// Test 1 operand case
std::unique_ptr<HloInstruction> update1 = HloInstruction::CreateAsyncUpdate(
async_start->shape(), std::vector<HloInstruction*>{async_start});
EXPECT_EQ(update1->opcode(), HloOpcode::kAsyncUpdate);
EXPECT_EQ(update1->operand_count(), 1);

// Test 2 operands case
HloInstruction* const_op = entry->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(42)));

std::unique_ptr<HloInstruction> update2 = HloInstruction::CreateAsyncUpdate(
async_start->shape(),
std::vector<HloInstruction*>{async_start, const_op});
EXPECT_EQ(update2->opcode(), HloOpcode::kAsyncUpdate);
EXPECT_EQ(update2->operand_count(), 2);
EXPECT_EQ(update2->operand(1), const_op);

// Test Cloning
std::unique_ptr<HloInstruction> clone1 =
update1->CloneWithNewOperands(update1->shape(), update1->operands());
EXPECT_EQ(clone1->operand_count(), 1);

std::unique_ptr<HloInstruction> clone2 =
update2->CloneWithNewOperands(update2->shape(), update2->operands());
EXPECT_EQ(clone2->operand_count(), 2);
EXPECT_EQ(clone2->operand(1), const_op);
}

TEST_F(HloInstructionTest, CloneImplCollectivePermuteOp) {
constexpr absl::string_view kHlo = R"(
HloModule main
Expand Down
19 changes: 13 additions & 6 deletions xla/hlo/ir/hlo_instructions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,18 @@ HloAsyncInstruction::HloAsyncInstruction(
HloOpcode opcode, const Shape& shape,
absl::Span<HloInstruction* const> operands, HloOpcode async_wrapped_opcode)
: HloInstruction(opcode, shape) {
CHECK(opcode == HloOpcode::kAsyncStart || operands.size() == 1);
CHECK(opcode == HloOpcode::kAsyncStart || opcode == HloOpcode::kAsyncUpdate ||
operands.size() == 1);
for (auto operand : operands) {
AppendOperand(operand);
}

if (opcode == HloOpcode::kAsyncUpdate || opcode == HloOpcode::kAsyncDone) {
HloAsyncInstruction* prev = Cast<HloAsyncInstruction>(operands[0]);
prev->async_chain_next_ = this;
// AppendComputation(prev->async_wrapped_computation());
}

// Drop 'async' from async-{start/update/done} to get the suffix.
absl::string_view suffix = HloOpcodeString(opcode).substr(5);
absl::string_view wrapped_name = HloOpcodeString(async_wrapped_opcode);
Expand Down Expand Up @@ -363,7 +370,7 @@ void HloAsyncInstruction::UpdateAsyncChain() {
}
};
auto update_operand_chain = [this]() {
CHECK_EQ(this->operand_count(), 1);
CHECK_GE(this->operand_count(), 1);
CHECK(this->operand(0)->opcode() == HloOpcode::kAsyncStart ||
this->operand(0)->opcode() == HloOpcode::kAsyncUpdate);
Cast<HloAsyncInstruction>(this->mutable_operand(0))->async_chain_next_ =
Expand Down Expand Up @@ -425,8 +432,8 @@ bool HloAsyncInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction> HloAsyncInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
return std::make_unique<HloAsyncInstruction>(opcode(), shape,
new_operands[0]);
return std::make_unique<HloAsyncInstruction>(opcode(), shape, new_operands,
async_wrapped_opcode());
}

HloAsyncStartInstruction::HloAsyncStartInstruction(
Expand All @@ -443,8 +450,8 @@ HloAsyncStartInstruction::HloAsyncStartInstruction(

HloInstruction* HloAsyncStartInstruction::AddCallOperand(
HloInstruction* new_operand) {
CHECK_EQ(operand_count(),
async_wrapped_computation()->parameter_instructions().size());
CHECK_GE(async_wrapped_computation()->parameter_instructions().size(),
operand_count());
const int64_t param_no = operand_count();
std::string param_name = StrCat("param_", param_no);
HloInstruction* called_computation_parameter =
Expand Down
2 changes: 1 addition & 1 deletion xla/hlo/ir/hlo_instructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,12 @@ class HloAsyncInstruction : public HloInstruction {

void UpdateAsyncChain();

protected:
// Helper to constructs async-{start,update,done}.
HloAsyncInstruction(HloOpcode opcode, const Shape& shape,
absl::Span<HloInstruction* const> operands,
HloOpcode async_wrapped_opcode);

protected:
// Updates all future instructions in the async chain to match the shape of
// the current instruction.
void UpdateChainShapes();
Expand Down
2 changes: 1 addition & 1 deletion xla/hlo/ir/hlo_opcode.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ namespace xla {
V(kAsinh, "asinh", 1) \
V(kAsyncDone, "async-done", 1) \
V(kAsyncStart, "async-start", kHloOpcodeIsVariadic) \
V(kAsyncUpdate, "async-update", 1) \
V(kAsyncUpdate, "async-update", kHloOpcodeIsVariadic) \
V(kAtan2, "atan2", 2) \
V(kAtanh, "atanh", 1) \
V(kBatchNormGrad, "batch-norm-grad", 5) \
Expand Down
21 changes: 9 additions & 12 deletions xla/hlo/parser/hlo_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2113,27 +2113,24 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT
return nullptr;
}
}
// TODO(phui): move these checks to the verifier
// async-{update,done} expect their one singular operand to be the
// previous async op.
if (opcode == HloOpcode::kAsyncUpdate ||
opcode == HloOpcode::kAsyncDone) {
if (operands.size() != 1 || !operands[0]->IsAsynchronous() ||
if (operands.empty() || !operands[0]->IsAsynchronous() ||
operands[0]->opcode() == HloOpcode::kAsyncDone) {
TokenError(
"AsyncUpdate and AsyncDone expect a single async op as their "
"operand.");
"AsyncUpdate and AsyncDone expect a single AsyncStart or "
"AsyncUpdate op as their first operand.");
return nullptr;
}
}
// For AsyncUpdate, the operand and the result should have the same shape.
if (opcode == HloOpcode::kAsyncUpdate) {
if (operands[0]->shape() != *shape) {
TokenError(
"AsyncUpdate expects the op shape to be the same as the operand "
"shape.");
return nullptr;
}
if (opcode == HloOpcode::kAsyncDone && operands.size() != 1) {
TokenError("AsyncDone expects exactly one operand");
return nullptr;
}

optional<std::string> async_execution_thread;
attrs["async_execution_thread"] = {/*required=*/false, AttrTy::kString,
&async_execution_thread};
Expand Down Expand Up @@ -2238,7 +2235,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT
}
if (opcode == HloOpcode::kAsyncUpdate) {
return builder->AddInstruction(
HloInstruction::CreateAsyncUpdate(*shape, operands[0]));
HloInstruction::CreateAsyncUpdate(*shape, operands));
}
return builder->AddInstruction(
HloInstruction::CreateAsyncDone(*shape, operands[0]));
Expand Down
Loading
Loading