Skip to content

Commit 8880939

Browse files
committed
wip(workgraph): add node type + appropriate attributes
1 parent a99f733 commit 8880939

File tree

5 files changed

+86
-28
lines changed

5 files changed

+86
-28
lines changed

include/luisa/dsl/work_graph/work_graph.h

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,12 @@ struct WorkGraphNode {
2424
luisa::shared_ptr<const detail::FunctionBuilder> fn_builder;
2525

2626
luisa::string name;
27+
WorkGraphLaunchType node_type;
2728
const Type* input_record_type;
2829
luisa::vector<WorkGraphEdge> out_edges;
30+
uint3 threadgroup_dim = uint3(1, 1, 1);
31+
uint3 dispatch_dim = uint3(1, 1, 1);
32+
luisa::optional<uint> dispatch_grid_member = luisa::nullopt;
2933
bool defined = false;
3034
};
3135

@@ -60,7 +64,7 @@ class WorkGraphNodeOutput {
6064
uint _edge_index;
6165
};
6266

63-
template<typename T>
67+
template<WorkGraphLaunchType NodeType, typename T>
6468
class WorkGraphNode {
6569
public:
6670
explicit WorkGraphNode(WorkGraphBuilder* builder, uint node_index) noexcept : _builder(builder), _node_index(node_index) {}
@@ -84,10 +88,8 @@ class WorkGraphNode {
8488
return WorkGraphNodeOutput<EdgeRecord>(_builder, _node_index, e.source_output_index);
8589
}
8690

87-
template<typename InputRecord, typename Def>
88-
void define(const WorkGraphNodeKernel<InputRecord, Def>& kernel) noexcept {
89-
static_assert(std::is_same_v<T, InputRecord>, "type mismatch between work graph node and its definition");
90-
91+
template<typename Def>
92+
void define(const WorkGraphNodeKernel<T, Def>& kernel) noexcept {
9193
LUISA_ASSERT(!inner()->defined, "redefining node kernel is not allowed");
9294

9395
// yoink the function builder, make sure type of input record matches what we were declared with
@@ -111,6 +113,35 @@ class WorkGraphNode {
111113
edge->dest = inner()->index;
112114
}
113115

116+
// specifying NumThreads (threadgroup size) is not allowed for per-thread node launch
117+
void set_threadgroup_size(uint3 size) const requires (NodeType != WorkGraphLaunchType::THREAD) {
118+
inner()->threadgroup_dim = size;
119+
}
120+
121+
// for broadcasting nodes, this is either the static size of a dispatch, or the max size
122+
// (if it's dynamically sized using SV_DispatchGrid)
123+
void set_dispatch_size(uint3 size) const requires (NodeType == WorkGraphLaunchType::BROADCASTING) {
124+
inner()->dispatch_dim = size;
125+
}
126+
127+
// selects which field (by member index) of the InputRecord type is marked with SV_DispatchGrid
128+
void set_dispatch_grid_field(uint index) const requires (NodeType == WorkGraphLaunchType::BROADCASTING) {
129+
auto t = Type::of<T>();
130+
auto members = t->members();
131+
132+
LUISA_ASSERT(index < members.size(), "dispatch grid field index out of bounds");
133+
134+
auto member_type = members[index];
135+
LUISA_ASSERT(
136+
member_type == Type::of<uint>() ||
137+
member_type == Type::of<uint2>() ||
138+
member_type == Type::of<uint3>(),
139+
"dispatch grid field must be uint, uint2, or uint3"
140+
);
141+
142+
inner()->dispatch_grid_member = index;
143+
}
144+
114145
private:
115146
// invalidated by modifying graph
116147
[[nodiscard]] detail::WorkGraphNode* inner() const noexcept { return &detail::index_to_node(_builder, _node_index); }
@@ -139,8 +170,8 @@ class WorkGraph {
139170
class WorkGraphBuilder {
140171
public:
141172

142-
template<typename InputRecord>
143-
WorkGraphNode<InputRecord> add_node(luisa::string name) noexcept {
173+
template<WorkGraphLaunchType NodeType, typename InputRecord>
174+
WorkGraphNode<NodeType, InputRecord> add_node(luisa::string name) noexcept {
144175
const Type* input_record_type;
145176
if constexpr (std::is_same_v<InputRecord, WorkGraphEmptyRecord>) {
146177
input_record_type = nullptr;
@@ -155,10 +186,11 @@ class WorkGraphBuilder {
155186
nullptr,
156187

157188
std::move(name),
189+
NodeType,
158190
input_record_type
159191
);
160192

161-
return WorkGraphNode<InputRecord>(this, node_index);
193+
return WorkGraphNode<NodeType, InputRecord>(this, node_index);
162194
}
163195

164196
LUISA_DSL_API WorkGraph build() noexcept;

src/backends/common/hlsl/codegen_utils/work_graph_codegen.cpp

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,22 +59,47 @@ void CodegenUtility::GenerateNodeInputDecl(
5959

6060
// Helper to generate node shader attributes
6161
void CodegenUtility::GenerateNodeShaderAttributes(
62-
bool is_entry_point,
63-
luisa::string_view node_name,
62+
const luisa::compute::detail::WorkGraphNode &node,
6463
vstd::StringBuilder &result) {
6564

6665
// [Shader("node")] attribute (required for all work graph nodes)
6766
result << "[Shader(\"node\")]\n"sv;
6867

69-
// Entry point nodes need [NodeLaunch("broadcasting")]
70-
if (is_entry_point) {
71-
result << "[NodeLaunch(\"broadcasting\")]\n"sv;
68+
switch (node.node_type) {
69+
case WorkGraphLaunchType::BROADCASTING: {
70+
result << "[NodeLaunch(\"broadcasting\")]\n"sv;
71+
72+
luisa::string threadgroup_dim = luisa::format(
73+
"[NumThreads({}, {}, {})]",
74+
node.threadgroup_dim.x, node.threadgroup_dim.y, node.threadgroup_dim.z
75+
);
76+
result << threadgroup_dim << '\n';
77+
78+
luisa::string dispatch_properties;
79+
if (node.dispatch_grid_member) {
80+
dispatch_properties = luisa::format(
81+
"[NodeMaxDispatchGrid({}, {}, {})]",
82+
node.dispatch_dim.x, node.dispatch_dim.y, node.dispatch_dim.z
83+
);
84+
}
85+
else {
86+
dispatch_properties = luisa::format(
87+
"[NodeDispatchGrid({}, {}, {})]",
88+
node.dispatch_dim.x, node.dispatch_dim.y, node.dispatch_dim.z
89+
);
90+
}
91+
result << dispatch_properties << '\n';
92+
} break;
93+
case WorkGraphLaunchType::THREAD: {
94+
result << "[NodeLaunch(\"thread\")]\n"sv;
95+
} break;
7296
}
7397

7498
// Add comment with node name for debugging
75-
if (!node_name.empty()) {
76-
result << "// node name: "sv << node_name << "\n"sv;
99+
if (!node.name.empty()) {
100+
result << "// node name: "sv << node.name << "\n"sv;
77101
}
102+
78103
}
79104

80105
// Helper to generate work graph node function signature
@@ -270,7 +295,7 @@ void CodegenUtility::CodegenWorkGraphNode(const WorkGraph &work_graph, size_t no
270295
}
271296

272297
// Generate node shader attributes using helper
273-
GenerateNodeShaderAttributes(is_entry_point, node.name, result);
298+
GenerateNodeShaderAttributes(node, result);
274299

275300
// Generate node function signature
276301
// use actual name from frontend here, rather than custom_<i>, since node names are meaningful

src/backends/common/hlsl/hlsl_codegen.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class CodegenUtility {
7373
static void GenerateMaxRecordsAttribute(uint max_records, vstd::StringBuilder &result);
7474
void GenerateNodeOutputDecl(const Type *record_type, uint max_records, luisa::string_view var_name_prefix, int output_index, vstd::StringBuilder &result);
7575
void GenerateNodeInputDecl(const Type *record_type, luisa::string_view var_name, vstd::StringBuilder &result);
76-
static void GenerateNodeShaderAttributes(bool is_entry_point, luisa::string_view node_name, vstd::StringBuilder &result);
76+
static void GenerateNodeShaderAttributes(const luisa::compute::detail::WorkGraphNode& node, vstd::StringBuilder &result);
7777
void GenerateNodeFunctionSignature(Function node_func, const luisa::compute::detail::WorkGraphNode &node, const luisa::vector<luisa::compute::detail::WorkGraphNode> &all_nodes, vstd::StringBuilder &result);
7878
static void GenerateWorkGraphOutputCall(int output_index, luisa::string_view record_var_name, vstd::StringBuilder &result);
7979
void GenerateRecordStructDef(const Type *record_type, vstd::StringBuilder &result);

src/backends/common/hlsl/test/test_work_graph_codegen.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ WorkGraph create_simple_entry_to_consumer() {
4242
WorkGraphBuilder wg;
4343

4444
// Entry point node (always uses WorkGraphEmptyRecord)
45-
auto entry = wg.add_node<WorkGraphEmptyRecord>("entry_node");
45+
auto entry = wg.add_node<WorkGraphLaunchType::BROADCASTING, WorkGraphEmptyRecord>("entry_node");
4646
auto entry_output = entry.output<SimpleRecord>(16);
4747

4848
WorkGraphNodeKernel entry_kernel = [&]() {
@@ -53,7 +53,7 @@ WorkGraph create_simple_entry_to_consumer() {
5353
entry.define(entry_kernel);
5454

5555
// Consumer node (receives SimpleRecord)
56-
auto consumer = wg.add_node<SimpleRecord>("consumer_node");
56+
auto consumer = wg.add_node<WorkGraphLaunchType::THREAD, SimpleRecord>("consumer_node");
5757
WorkGraphNodeKernel consumer_kernel = [&](Var<SimpleRecord> input) {
5858
auto val = input->value;
5959
(void)val;
@@ -71,7 +71,7 @@ WorkGraph create_multi_output_node() {
7171
WorkGraphBuilder wg;
7272

7373
// Entry with multiple outputs
74-
auto entry = wg.add_node<WorkGraphEmptyRecord>("multi_out_entry");
74+
auto entry = wg.add_node<WorkGraphLaunchType::BROADCASTING, WorkGraphEmptyRecord>("multi_out_entry");
7575
auto output_a = entry.output<SimpleRecord>(8);
7676
auto output_b = entry.output<ComplexRecord>(4);
7777

@@ -90,15 +90,15 @@ WorkGraph create_multi_output_node() {
9090
entry.define(entry_kernel);
9191

9292
// Consumer A (receives SimpleRecord)
93-
auto consumer_a = wg.add_node<SimpleRecord>("consumer_a");
93+
auto consumer_a = wg.add_node<WorkGraphLaunchType::THREAD, SimpleRecord>("consumer_a");
9494
WorkGraphNodeKernel consumer_a_kernel = [&](Var<SimpleRecord> input) {
9595
auto val = input->value;
9696
(void)val;
9797
};
9898
consumer_a.define(consumer_a_kernel);
9999

100100
// Consumer B (receives ComplexRecord)
101-
auto consumer_b = wg.add_node<ComplexRecord>("consumer_b");
101+
auto consumer_b = wg.add_node<WorkGraphLaunchType::THREAD, ComplexRecord>("consumer_b");
102102
WorkGraphNodeKernel consumer_b_kernel = [&](Var<ComplexRecord> input) {
103103
auto id = input->id;
104104
auto pos = input->position;
@@ -119,7 +119,7 @@ WorkGraph create_chained_nodes() {
119119
WorkGraphBuilder wg;
120120

121121
// Node A: Entry point
122-
auto node_a = wg.add_node<WorkGraphEmptyRecord>("node_a");
122+
auto node_a = wg.add_node<WorkGraphLaunchType::BROADCASTING, WorkGraphEmptyRecord>("node_a");
123123
auto output_a = node_a.output<SimpleRecord>(32);
124124

125125
WorkGraphNodeKernel node_a_kernel = [&]() {
@@ -130,7 +130,7 @@ WorkGraph create_chained_nodes() {
130130
node_a.define(node_a_kernel);
131131

132132
// Node B: Middle node (processes and forwards)
133-
auto node_b = wg.add_node<SimpleRecord>("node_b");
133+
auto node_b = wg.add_node<WorkGraphLaunchType::THREAD, SimpleRecord>("node_b");
134134
auto output_b = node_b.output<SimpleRecord>(16);
135135

136136
WorkGraphNodeKernel node_b_kernel = [&](Var<SimpleRecord> input) {
@@ -141,7 +141,7 @@ WorkGraph create_chained_nodes() {
141141
node_b.define(node_b_kernel);
142142

143143
// Node C: Final node
144-
auto node_c = wg.add_node<SimpleRecord>("node_c");
144+
auto node_c = wg.add_node<WorkGraphLaunchType::THREAD, SimpleRecord>("node_c");
145145
WorkGraphNodeKernel node_c_kernel = [&](Var<SimpleRecord> input) {
146146
auto val = input->value;
147147
(void)val;
@@ -160,7 +160,7 @@ WorkGraph create_terminal_entry_node() {
160160
WorkGraphBuilder wg;
161161

162162
// Entry point that doesn't output anything
163-
auto entry = wg.add_node<WorkGraphEmptyRecord>("terminal_entry");
163+
auto entry = wg.add_node<WorkGraphLaunchType::BROADCASTING, WorkGraphEmptyRecord>("terminal_entry");
164164

165165
WorkGraphNodeKernel entry_kernel = [&]() {
166166
// Just do some work, no outputs

src/tests/test_work_graph.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ LUISA_STRUCT(ConsumerRecord, datum) {};
1919
WorkGraph describe_work_graph() {
2020
WorkGraphBuilder work_graph;
2121

22-
auto producer = work_graph.add_node<WorkGraphEmptyRecord>("producer");
22+
auto producer = work_graph.add_node<WorkGraphLaunchType::BROADCASTING, WorkGraphEmptyRecord>("producer");
23+
producer.set_threadgroup_size({64, 1, 1});
2324
auto producer_output = producer.output<ConsumerRecord>(16);
2425
WorkGraphNodeKernel producer_kernel = [&]() {
2526
Var<ConsumerRecord> out;
@@ -28,7 +29,7 @@ WorkGraph describe_work_graph() {
2829
producer.define(producer_kernel);
2930

3031

31-
auto consumer = work_graph.add_node<ConsumerRecord>("consumer");
32+
auto consumer = work_graph.add_node<WorkGraphLaunchType::THREAD, ConsumerRecord>("consumer");
3233
WorkGraphNodeKernel consumer_kernel = [&](Var<ConsumerRecord> input) {
3334
// do work
3435
};

0 commit comments

Comments
 (0)