Skip to content

Commit bbc9d9e

Browse files
ericastorcopybara-github
authored andcommitted
Add next_value support to the proc interpreter
PiperOrigin-RevId: 597982784
1 parent c06d44f commit bbc9d9e

7 files changed

+368
-70
lines changed

xls/interpreter/BUILD

+3-1
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ cc_library(
110110
"@com_google_absl//absl/container:flat_hash_map",
111111
"@com_google_absl//absl/status",
112112
"@com_google_absl//absl/status:statusor",
113+
"@com_google_absl//absl/strings",
114+
"@com_google_absl//absl/strings:str_format",
113115
"@com_google_absl//absl/types:span",
114116
"//xls/common/status:ret_check",
115117
"//xls/common/status:status_macros",
@@ -320,7 +322,6 @@ cc_library(
320322
"@com_google_absl//absl/strings:str_format",
321323
"//xls/common/logging",
322324
"//xls/ir",
323-
"//xls/ir:channel",
324325
"//xls/ir:elaboration",
325326
"//xls/ir:events",
326327
"//xls/ir:value",
@@ -335,6 +336,7 @@ cc_library(
335336
deps = [
336337
":channel_queue",
337338
":proc_evaluator",
339+
"@com_google_absl//absl/status",
338340
"@com_google_absl//absl/status:statusor",
339341
"//xls/common/status:matchers",
340342
"//xls/ir",

xls/interpreter/proc_evaluator.cc

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "xls/common/logging/logging.h"
2222
#include "xls/ir/node.h"
2323
#include "xls/ir/nodes.h"
24+
#include "xls/ir/proc.h"
2425

2526
namespace xls {
2627

xls/interpreter/proc_evaluator_test_base.cc

+246-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include "gmock/gmock.h"
2121
#include "gtest/gtest.h"
22+
#include "absl/status/status.h"
2223
#include "absl/status/statusor.h"
2324
#include "xls/common/status/matchers.h"
2425
#include "xls/interpreter/channel_queue.h"
@@ -36,8 +37,10 @@ namespace xls {
3637
namespace {
3738

3839
using status_testing::IsOkAndHolds;
39-
using testing::ElementsAre;
40-
using testing::Optional;
40+
using status_testing::StatusIs;
41+
using ::testing::ElementsAre;
42+
using ::testing::HasSubstr;
43+
using ::testing::Optional;
4144

4245
TEST_P(ProcEvaluatorTestBase, EmptyProc) {
4346
auto package = CreatePackage();
@@ -543,6 +546,247 @@ TEST_P(ProcEvaluatorTestBase, ConditionalSendProc) {
543546
EXPECT_THAT(queue.Read(), Optional(Value(UBits(4, 32))));
544547
}
545548

549+
TEST_P(ProcEvaluatorTestBase, UnconditionalNextProc) {
550+
if (!GetParam().SupportsNextValue()) {
551+
GTEST_SKIP() << "Evaluator does not support next_value";
552+
}
553+
554+
// Create an output-only proc which increments its counter value each
555+
// iteration, using explicit next_value nodes.
556+
Package package(TestName());
557+
XLS_ASSERT_OK_AND_ASSIGN(
558+
Channel * channel,
559+
package.CreateStreamingChannel("counter_out", ChannelOps::kSendOnly,
560+
package.GetBitsType(32)));
561+
562+
ProcBuilder pb("counter", /*token_name=*/"tok", &package);
563+
BValue counter = pb.StateElement("counter", Value(UBits(0, 32)));
564+
BValue send = pb.Send(channel, pb.GetTokenParam(), counter);
565+
BValue incremented_counter = pb.Add(counter, pb.Literal(UBits(1, 32)));
566+
pb.Next(/*param=*/counter, /*value=*/incremented_counter);
567+
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build(send));
568+
569+
std::unique_ptr<ChannelQueueManager> queue_manager =
570+
GetParam().CreateQueueManager(&package);
571+
std::unique_ptr<ProcEvaluator> evaluator = GetParam().CreateEvaluator(
572+
FindProc("counter", &package), queue_manager.get());
573+
574+
ChannelQueue& queue = queue_manager->GetQueue(channel);
575+
XLS_ASSERT_OK_AND_ASSIGN(
576+
ChannelInstance * channel_instance,
577+
queue_manager->elaboration().GetUniqueInstance(channel));
578+
579+
std::unique_ptr<ProcContinuation> continuation = evaluator->NewContinuation(
580+
queue_manager->elaboration().GetUniqueInstance(proc).value());
581+
582+
EXPECT_THAT(evaluator->Tick(*continuation),
583+
IsOkAndHolds(TickResult{
584+
.execution_state = TickExecutionState::kSentOnChannel,
585+
.channel_instance = channel_instance,
586+
.progress_made = true}));
587+
EXPECT_THAT(
588+
evaluator->Tick(*continuation),
589+
IsOkAndHolds(TickResult{.execution_state = TickExecutionState::kCompleted,
590+
.channel_instance = std::nullopt,
591+
.progress_made = true}));
592+
EXPECT_EQ(queue.GetSize(), 1);
593+
EXPECT_THAT(queue.Read(), Optional(Value(UBits(0, 32))));
594+
595+
EXPECT_THAT(evaluator->Tick(*continuation),
596+
IsOkAndHolds(TickResult{
597+
.execution_state = TickExecutionState::kSentOnChannel,
598+
.channel_instance = channel_instance,
599+
.progress_made = true}));
600+
EXPECT_THAT(
601+
evaluator->Tick(*continuation),
602+
IsOkAndHolds(TickResult{.execution_state = TickExecutionState::kCompleted,
603+
.channel_instance = std::nullopt,
604+
.progress_made = true}));
605+
EXPECT_EQ(queue.GetSize(), 1);
606+
EXPECT_THAT(queue.Read(), Optional(Value(UBits(1, 32))));
607+
608+
EXPECT_THAT(evaluator->Tick(*continuation),
609+
IsOkAndHolds(TickResult{
610+
.execution_state = TickExecutionState::kSentOnChannel,
611+
.channel_instance = channel_instance,
612+
.progress_made = true}));
613+
EXPECT_THAT(
614+
evaluator->Tick(*continuation),
615+
IsOkAndHolds(TickResult{.execution_state = TickExecutionState::kCompleted,
616+
.channel_instance = std::nullopt,
617+
.progress_made = true}));
618+
EXPECT_EQ(queue.GetSize(), 1);
619+
EXPECT_THAT(queue.Read(), Optional(Value(UBits(2, 32))));
620+
}
621+
622+
TEST_P(ProcEvaluatorTestBase, ConditionalNextProc) {
623+
if (!GetParam().SupportsNextValue()) {
624+
GTEST_SKIP() << "Evaluator does not support next_value";
625+
}
626+
627+
// Create an output-only proc which increments its counter value only every
628+
// other iteration.
629+
Package package(TestName());
630+
XLS_ASSERT_OK_AND_ASSIGN(
631+
Channel * channel,
632+
package.CreateStreamingChannel("slow_counter_out", ChannelOps::kSendOnly,
633+
package.GetBitsType(32)));
634+
635+
ProcBuilder pb("slow_counter", /*token_name=*/"tok", &package);
636+
BValue counter = pb.StateElement("counter", Value(UBits(0, 32)));
637+
BValue iteration = pb.StateElement("iteration", Value(UBits(0, 32)));
638+
BValue send = pb.Send(channel, pb.GetTokenParam(), counter);
639+
BValue incremented_counter = pb.Add(counter, pb.Literal(UBits(1, 32)));
640+
BValue odd_iteration = pb.Eq(pb.BitSlice(iteration, /*start=*/0, /*width=*/1),
641+
pb.Literal(UBits(1, 1)));
642+
pb.Next(/*param=*/counter, /*value=*/incremented_counter,
643+
/*pred=*/odd_iteration);
644+
pb.Next(/*param=*/iteration,
645+
/*value=*/pb.Add(iteration, pb.Literal(UBits(1, 32))));
646+
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build(send));
647+
648+
std::unique_ptr<ChannelQueueManager> queue_manager =
649+
GetParam().CreateQueueManager(&package);
650+
std::unique_ptr<ProcEvaluator> evaluator = GetParam().CreateEvaluator(
651+
FindProc("slow_counter", &package), queue_manager.get());
652+
653+
ChannelQueue& queue = queue_manager->GetQueue(channel);
654+
XLS_ASSERT_OK_AND_ASSIGN(
655+
ChannelInstance * channel_instance,
656+
queue_manager->elaboration().GetUniqueInstance(channel));
657+
658+
std::unique_ptr<ProcContinuation> continuation = evaluator->NewContinuation(
659+
queue_manager->elaboration().GetUniqueInstance(proc).value());
660+
EXPECT_THAT(evaluator->Tick(*continuation),
661+
IsOkAndHolds(TickResult{
662+
.execution_state = TickExecutionState::kSentOnChannel,
663+
.channel_instance = channel_instance,
664+
.progress_made = true}));
665+
EXPECT_THAT(
666+
evaluator->Tick(*continuation),
667+
IsOkAndHolds(TickResult{.execution_state = TickExecutionState::kCompleted,
668+
.channel_instance = std::nullopt,
669+
.progress_made = true}));
670+
EXPECT_EQ(queue.GetSize(), 1);
671+
EXPECT_THAT(queue.Read(), Optional(Value(UBits(0, 32))));
672+
673+
EXPECT_THAT(evaluator->Tick(*continuation),
674+
IsOkAndHolds(TickResult{
675+
.execution_state = TickExecutionState::kSentOnChannel,
676+
.channel_instance = channel_instance,
677+
.progress_made = true}));
678+
EXPECT_THAT(
679+
evaluator->Tick(*continuation),
680+
IsOkAndHolds(TickResult{.execution_state = TickExecutionState::kCompleted,
681+
.channel_instance = std::nullopt,
682+
.progress_made = true}));
683+
EXPECT_EQ(queue.GetSize(), 1);
684+
EXPECT_THAT(queue.Read(), Optional(Value(UBits(0, 32))));
685+
686+
EXPECT_THAT(evaluator->Tick(*continuation),
687+
IsOkAndHolds(TickResult{
688+
.execution_state = TickExecutionState::kSentOnChannel,
689+
.channel_instance = channel_instance,
690+
.progress_made = true}));
691+
EXPECT_THAT(
692+
evaluator->Tick(*continuation),
693+
IsOkAndHolds(TickResult{.execution_state = TickExecutionState::kCompleted,
694+
.channel_instance = std::nullopt,
695+
.progress_made = true}));
696+
EXPECT_EQ(queue.GetSize(), 1);
697+
EXPECT_THAT(queue.Read(), Optional(Value(UBits(1, 32))));
698+
699+
EXPECT_THAT(evaluator->Tick(*continuation),
700+
IsOkAndHolds(TickResult{
701+
.execution_state = TickExecutionState::kSentOnChannel,
702+
.channel_instance = channel_instance,
703+
.progress_made = true}));
704+
EXPECT_THAT(
705+
evaluator->Tick(*continuation),
706+
IsOkAndHolds(TickResult{.execution_state = TickExecutionState::kCompleted,
707+
.channel_instance = std::nullopt,
708+
.progress_made = true}));
709+
EXPECT_EQ(queue.GetSize(), 1);
710+
EXPECT_THAT(queue.Read(), Optional(Value(UBits(1, 32))));
711+
712+
EXPECT_THAT(evaluator->Tick(*continuation),
713+
IsOkAndHolds(TickResult{
714+
.execution_state = TickExecutionState::kSentOnChannel,
715+
.channel_instance = channel_instance,
716+
.progress_made = true}));
717+
EXPECT_THAT(
718+
evaluator->Tick(*continuation),
719+
IsOkAndHolds(TickResult{.execution_state = TickExecutionState::kCompleted,
720+
.channel_instance = std::nullopt,
721+
.progress_made = true}));
722+
EXPECT_EQ(queue.GetSize(), 1);
723+
EXPECT_THAT(queue.Read(), Optional(Value(UBits(2, 32))));
724+
}
725+
726+
TEST_P(ProcEvaluatorTestBase, CollidingNextValuesProc) {
727+
if (!GetParam().SupportsNextValue()) {
728+
GTEST_SKIP() << "Evaluator does not support next_value";
729+
}
730+
731+
// Create an output-only proc which increments its counter value only every
732+
// other iteration - but also tries to set the counter value to a different
733+
// value.
734+
Package package(TestName());
735+
XLS_ASSERT_OK_AND_ASSIGN(
736+
Channel * channel,
737+
package.CreateStreamingChannel("slow_counter_out", ChannelOps::kSendOnly,
738+
package.GetBitsType(32)));
739+
740+
ProcBuilder pb("slow_counter", /*token_name=*/"tok", &package);
741+
BValue counter = pb.StateElement("counter", Value(UBits(0, 32)));
742+
BValue iteration = pb.StateElement("iteration", Value(UBits(0, 32)));
743+
BValue send = pb.Send(channel, pb.GetTokenParam(), counter);
744+
BValue incremented_counter = pb.Add(counter, pb.Literal(UBits(1, 32)));
745+
BValue odd_iteration = pb.Eq(pb.BitSlice(iteration, /*start=*/0, /*width=*/1),
746+
pb.Literal(UBits(1, 1)));
747+
pb.Next(/*param=*/counter, /*value=*/incremented_counter,
748+
/*pred=*/odd_iteration);
749+
pb.Next(/*param=*/counter, /*value=*/pb.Literal(UBits(0, 32)));
750+
pb.Next(/*param=*/iteration,
751+
/*value=*/pb.Add(iteration, pb.Literal(UBits(1, 32))));
752+
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build(send));
753+
754+
std::unique_ptr<ChannelQueueManager> queue_manager =
755+
GetParam().CreateQueueManager(&package);
756+
std::unique_ptr<ProcEvaluator> evaluator = GetParam().CreateEvaluator(
757+
FindProc("slow_counter", &package), queue_manager.get());
758+
759+
ChannelQueue& queue = queue_manager->GetQueue(channel);
760+
XLS_ASSERT_OK_AND_ASSIGN(
761+
ChannelInstance * channel_instance,
762+
queue_manager->elaboration().GetUniqueInstance(channel));
763+
764+
std::unique_ptr<ProcContinuation> continuation = evaluator->NewContinuation(
765+
queue_manager->elaboration().GetUniqueInstance(proc).value());
766+
EXPECT_THAT(evaluator->Tick(*continuation),
767+
IsOkAndHolds(TickResult{
768+
.execution_state = TickExecutionState::kSentOnChannel,
769+
.channel_instance = channel_instance,
770+
.progress_made = true}));
771+
EXPECT_THAT(
772+
evaluator->Tick(*continuation),
773+
IsOkAndHolds(TickResult{.execution_state = TickExecutionState::kCompleted,
774+
.channel_instance = std::nullopt,
775+
.progress_made = true}));
776+
EXPECT_EQ(queue.GetSize(), 1);
777+
EXPECT_THAT(queue.Read(), Optional(Value(UBits(0, 32))));
778+
779+
EXPECT_THAT(evaluator->Tick(*continuation),
780+
IsOkAndHolds(TickResult{
781+
.execution_state = TickExecutionState::kSentOnChannel,
782+
.channel_instance = channel_instance,
783+
.progress_made = true}));
784+
EXPECT_THAT(evaluator->Tick(*continuation),
785+
StatusIs(absl::StatusCode::kAlreadyExists,
786+
HasSubstr("Multiple active next values for param "
787+
"\"counter\" in a single activation")));
788+
}
789+
546790
TEST_P(ProcEvaluatorTestBase, OneToTwoDemux) {
547791
// Build a proc which acts as a one-to-two demux. Data channels are streaming,
548792
// and the selector is a single-value channel.

xls/interpreter/proc_evaluator_test_base.h

+9-4
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
#include <functional>
1919
#include <memory>
20+
#include <utility>
2021

21-
#include "gmock/gmock.h"
2222
#include "gtest/gtest.h"
2323
#include "xls/interpreter/channel_queue.h"
2424
#include "xls/interpreter/proc_evaluator.h"
@@ -33,9 +33,11 @@ class ProcEvaluatorTestParam {
3333
std::function<std::unique_ptr<ProcEvaluator>(Proc*, ChannelQueueManager*)>
3434
evaluator_factory,
3535
std::function<std::unique_ptr<ChannelQueueManager>(Package*)>
36-
queue_manager_factory)
37-
: evaluator_factory_(evaluator_factory),
38-
queue_manager_factory_(queue_manager_factory) {}
36+
queue_manager_factory,
37+
bool supports_next_value = false)
38+
: evaluator_factory_(std::move(evaluator_factory)),
39+
queue_manager_factory_(std::move(queue_manager_factory)),
40+
supports_next_value_(supports_next_value) {}
3941
ProcEvaluatorTestParam() = default;
4042

4143
std::unique_ptr<ChannelQueueManager> CreateQueueManager(
@@ -48,11 +50,14 @@ class ProcEvaluatorTestParam {
4850
return evaluator_factory_(proc, queue_manager);
4951
}
5052

53+
bool SupportsNextValue() const { return supports_next_value_; }
54+
5155
private:
5256
std::function<std::unique_ptr<ProcEvaluator>(Proc*, ChannelQueueManager*)>
5357
evaluator_factory_;
5458
std::function<std::unique_ptr<ChannelQueueManager>(Package*)>
5559
queue_manager_factory_;
60+
const bool supports_next_value_;
5661
};
5762

5863
// A suite of test which can be run against arbitrary ProcEvaluator

0 commit comments

Comments
 (0)