Skip to content

Commit f92df86

Browse files
authored
Add action and pipeline concepts (#180)
1 parent 11682d6 commit f92df86

File tree

5 files changed

+212
-0
lines changed

5 files changed

+212
-0
lines changed

src/AMSlib/wf/action.hpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#pragma once
2+
3+
#include "AMSError.hpp"
4+
5+
namespace ams
6+
{
7+
8+
struct EvalContext; // forward declaration
9+
10+
/// Base class for a single step in an AMS evaluation pipeline.
11+
///
12+
/// Actions mutate the shared EvalContext and may fail; failures are reported
13+
/// via AMSStatus so pipelines can short-circuit cleanly.
14+
class Action
15+
{
16+
public:
17+
virtual ~Action() = default;
18+
19+
/// Execute this action on the evaluation context.
20+
virtual AMSStatus run(EvalContext& ctx) = 0;
21+
22+
/// Human-readable name for debugging, logging, and tracing.
23+
virtual const char* name() const noexcept = 0;
24+
};
25+
26+
} // namespace ams

src/AMSlib/wf/pipeline.hpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <vector>
5+
6+
#include "AMSError.hpp" // AMSStatus
7+
#include "wf/action.hpp" // Action
8+
9+
namespace ams
10+
{
11+
12+
struct EvalContext;
13+
14+
/// A linear sequence of Actions executed in order.
15+
///
16+
/// If any Action fails, execution stops and the error is returned.
17+
class Pipeline
18+
{
19+
public:
20+
using ActionPtr = std::unique_ptr<Action>;
21+
22+
Pipeline() = default;
23+
24+
/// Append an Action to the pipeline.
25+
Pipeline& add(ActionPtr Act)
26+
{
27+
Actions.emplace_back(std::move(Act));
28+
return *this;
29+
}
30+
31+
/// Execute all actions in order; stops on first error.
32+
AMSStatus run(EvalContext& Ctx) const
33+
{
34+
for (const auto& Act : Actions) {
35+
if (auto St = Act->run(Ctx); !St) {
36+
return St;
37+
}
38+
}
39+
return {};
40+
}
41+
42+
/// Number of actions in the pipeline.
43+
size_t size() const noexcept { return Actions.size(); }
44+
45+
/// True if there are no actions.
46+
bool empty() const noexcept { return Actions.empty(); }
47+
48+
/// Remove all actions.
49+
void clear() noexcept { Actions.clear(); }
50+
51+
private:
52+
std::vector<ActionPtr> Actions;
53+
};
54+
55+
} // namespace ams

tests/AMSlib/wf/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,9 @@ ADD_WORKFLOW_UNIT_TEST(WORKFLOW::EVAL_CONTEXT eval_context)
5656

5757
BUILD_UNIT_TEST(pointwise pointwise_layout_transform.cpp)
5858
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::POINTWISE pointwise)
59+
60+
BUILD_UNIT_TEST(action action.cpp)
61+
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::ACTION action)
62+
63+
BUILD_UNIT_TEST(pipeline pipeline.cpp)
64+
ADD_WORKFLOW_UNIT_TEST(WORKFLOW::PIPELINE pipeline)

tests/AMSlib/wf/action.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#include "wf/action.hpp"
2+
3+
#include <catch2/catch_test_macros.hpp>
4+
#include <memory>
5+
#include <type_traits>
6+
7+
// Prefer the real EvalContext if available.
8+
// If your project uses a different header name, adjust accordingly.
9+
#include "wf/eval_context.hpp"
10+
11+
namespace ams
12+
{
13+
14+
namespace
15+
{
16+
class TestAction final : public Action
17+
{
18+
public:
19+
const char* name() const noexcept override { return "TestAction"; }
20+
21+
AMSStatus run(EvalContext& ctx) override
22+
{
23+
ctx.Threshold = ctx.Threshold.value_or(0.0f) + 1.0f;
24+
return {};
25+
}
26+
};
27+
} // namespace
28+
29+
CATCH_TEST_CASE("Action: abstract base class + virtual interface",
30+
"[wf][action]")
31+
{
32+
CATCH_STATIC_REQUIRE(std::is_abstract_v<Action>);
33+
CATCH_STATIC_REQUIRE(std::has_virtual_destructor_v<Action>);
34+
35+
EvalContext ctx{};
36+
ctx.Threshold = 0.0f;
37+
std::unique_ptr<Action> act = std::make_unique<TestAction>();
38+
39+
CATCH_REQUIRE(act->name() == std::string("TestAction"));
40+
41+
auto Err = act->run(ctx);
42+
CATCH_REQUIRE(Err);
43+
CATCH_REQUIRE(ctx.Threshold == 1.0f);
44+
45+
auto Err1 = act->run(ctx);
46+
CATCH_REQUIRE(Err1);
47+
CATCH_REQUIRE(ctx.Threshold == 2.0f);
48+
}
49+
50+
} // namespace ams

tests/AMSlib/wf/pipeline.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#include "wf/pipeline.hpp"
2+
3+
#include <catch2/catch_test_macros.hpp>
4+
#include <memory>
5+
#include <string>
6+
7+
#include "wf/eval_context.hpp"
8+
9+
namespace ams
10+
{
11+
12+
namespace
13+
{
14+
15+
class IncAction final : public Action
16+
{
17+
public:
18+
const char* name() const noexcept override { return "IncAction"; }
19+
20+
AMSStatus run(EvalContext& ctx) override
21+
{
22+
ctx.Threshold = ctx.Threshold.value_or(0.0f) + 1.0f;
23+
return {};
24+
}
25+
};
26+
27+
class FailAction final : public Action
28+
{
29+
public:
30+
const char* name() const noexcept override { return "FailAction"; }
31+
32+
AMSStatus run(EvalContext&) override
33+
{
34+
return AMS_MAKE_ERROR(AMSErrorType::Generic, "FailAction triggered");
35+
}
36+
};
37+
38+
} // namespace
39+
40+
CATCH_TEST_CASE("Pipeline runs actions in order and short-circuits on error",
41+
"[wf][pipeline]")
42+
{
43+
EvalContext Ctx{};
44+
Pipeline P;
45+
46+
// Two increments -> Threshold becomes 2, then FailAction stops the pipeline.
47+
P.add(std::make_unique<IncAction>())
48+
.add(std::make_unique<IncAction>())
49+
.add(std::make_unique<FailAction>())
50+
.add(std::make_unique<IncAction>()); // must NOT execute
51+
52+
Ctx.Threshold = 0.0f;
53+
54+
auto St = P.run(Ctx);
55+
CATCH_REQUIRE_FALSE(St);
56+
CATCH_REQUIRE(St.error().getType() == AMSErrorType::Generic);
57+
58+
// Only the first two IncAction should have run.
59+
CATCH_REQUIRE(Ctx.Threshold.value() == 2.0f);
60+
}
61+
62+
CATCH_TEST_CASE("Pipeline succeeds when all actions succeed", "[wf][pipeline]")
63+
{
64+
EvalContext Ctx{};
65+
Pipeline P;
66+
67+
P.add(std::make_unique<IncAction>()).add(std::make_unique<IncAction>());
68+
69+
Ctx.Threshold = 0.0f;
70+
auto St = P.run(Ctx);
71+
CATCH_REQUIRE(St);
72+
CATCH_REQUIRE(Ctx.Threshold.value() == 2.0f);
73+
}
74+
75+
} // namespace ams

0 commit comments

Comments
 (0)