Skip to content

Commit 4748584

Browse files
committed
basic work-graph codegen
1 parent 471e3a6 commit 4748584

File tree

7 files changed

+802
-65
lines changed

7 files changed

+802
-65
lines changed

src/backends/common/hlsl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ if (LUISA_COMPUTE_ENABLE_DX OR LUISA_COMPUTE_ENABLE_VULKAN)
5656
codegen_utils/constant.cpp
5757
codegen_utils/entry_points.cpp
5858
codegen_utils/function_codegen.cpp
59+
codegen_utils/work_graph_codegen.cpp
5960
codegen_utils/property.cpp
6061
codegen_utils/resource.cpp
6162
codegen_utils/type_system.cpp

src/backends/common/hlsl/builtin/hlsl_builtin.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
extern "C" {
2525
LC_HLSL_DECL_VARNAME(hlsl_header_bytes)
26+
LC_HLSL_DECL_VARNAME(work_graph_bytes)
2627
LC_HLSL_DECL_VARNAME(dx_linalg_bytes)
2728
LC_HLSL_DECL_VARNAME(hlsl_header_fallback_bytes)
2829
LC_HLSL_DECL_VARNAME(raytracing_header_bytes)
@@ -62,6 +63,7 @@ static HLSLCompressedHeader get_hlsl_builtin(luisa::string_view ss) {
6263
luisa::unordered_map<luisa::string_view, HLSLCompressedHeader> dict;
6364
Dict() {
6465
LC_HLSL_INSERT_VARNAME(hlsl_header_bytes, "hlsl_header")
66+
LC_HLSL_INSERT_VARNAME(work_graph_bytes, "work_graph")
6567
LC_HLSL_INSERT_VARNAME(spv_alias_bytes, "spv_alias")
6668
LC_HLSL_INSERT_VARNAME(dx_linalg_bytes, "dx_linalg")
6769
LC_HLSL_INSERT_VARNAME(hlsl_header_fallback_bytes, "hlsl_header_fallback")

src/backends/common/hlsl/codegen_utils/entry_points.cpp

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -447,9 +447,9 @@ CodegenResult CodegenUtility::WorkGraphCodegen(
447447
bool noRegister) {
448448
opt = CodegenStackData::Allocate(this);
449449
opt->noRegister = noRegister;
450-
opt->isRaster = true;
450+
opt->isWorkGraph = true;
451451
auto disposeOpt = vstd::scope_exit([&] {
452-
opt->isRaster = false;
452+
opt->isWorkGraph = false;
453453
CodegenStackData::DeAllocate(std::move(opt));
454454
});
455455
vstd::StringBuilder codegenData;
@@ -470,12 +470,48 @@ CodegenResult CodegenUtility::WorkGraphCodegen(
470470
static_cast<void>(vstd::to_string(custom_mask));
471471
finalResult << '\n';
472472

473+
// Add work graph builtin template
474+
finalResult << ReadInternalHLSLFile("work_graph"sv);
475+
finalResult << "\n"sv;
476+
473477
vstd::unordered_set<uint64_t> globalCallableMap;
474-
for (const auto &node : work_graph.nodes()) {
475-
CodegenWorkGraphNode(node, codegenData, globalCallableMap, false /* TODO: handle cbuffer-based arguments */);
478+
const auto& nodes = work_graph.nodes();
479+
const auto& entry_points = work_graph.entry_points();
480+
vstd::unordered_set<uint32_t> entry_point_set;
481+
for (auto ep : entry_points) {
482+
entry_point_set.emplace(ep);
483+
}
484+
for (size_t i = 0; i < nodes.size(); ++i) {
485+
bool is_entry_point = entry_point_set.contains(static_cast<uint32_t>(i));
486+
CodegenWorkGraphNode(work_graph, i, is_entry_point, codegenData, globalCallableMap, false);
476487
}
477488

478-
LUISA_ASSERT(false, "unimplemented");
489+
// Append the generated code for all nodes
490+
finalResult << codegenData;
491+
492+
// Post-process properties (generates struct definitions)
493+
PostprocessCodegenProperties(finalResult, false);
494+
495+
// Create the result
496+
vstd::vector<Type const *> recordTypes;
497+
recordTypes.reserve(nodes.size());
498+
for (auto &&node : nodes) {
499+
if (node.input_record_type != nullptr) {
500+
recordTypes.push_back(node.input_record_type);
501+
}
502+
}
503+
504+
auto result = CodegenResult(
505+
std::move(finalResult),
506+
std::move(opt->printer),
507+
{}, // No properties for work graphs (resource bindings are handled differently)
508+
opt->useTex2DBindless,
509+
opt->useTex3DBindless,
510+
opt->useBufferBindless,
511+
immutableHeaderSize,
512+
GetTypeMD5(recordTypes));
513+
514+
return result;
479515
}
480516

481517
void CodegenUtility::CodegenFunction(Function func, vstd::StringBuilder &result, bool cbufferNonEmpty, bool codegen_self) {

src/backends/common/hlsl/codegen_utils/function_codegen.cpp

Lines changed: 6 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include <luisa/core/dynamic_module.h>
1111
#include <luisa/core/logging.h>
1212
#include <luisa/ast/external_function.h>
13-
#include "../constant_printer.h"
1413

1514
// External declaration for shared variable from hlsl_codegen_util.cpp
1615
extern bool shown_buffer_warning;
@@ -1450,15 +1449,16 @@ void CodegenUtility::GetFunctionName(CallExpr const *expr, vstd::StringBuilder &
14501449
return;
14511450
case CallOp::WORK_GRAPH_OUTPUT: {
14521451
LUISA_ASSERT(opt->isWorkGraph, "Work Graph Output can only be used in work graph nodes");
1453-
str << "_work_graph_output(";
1454-
str << "_work_graph_output_";
1452+
str << "_work_graph_output("sv;
1453+
str << "_work_graph_output_"sv;
14551454
auto literal = static_cast<const LiteralExpr *>(args[0]);
14561455
auto output_index = luisa::get<uint>(literal->value());
1457-
str << output_index << ", ";
1456+
vstd::to_string(output_index, str);
1457+
str << ", "sv;
14581458
args[2]->accept(vis);
1459-
str << ", ";
1459+
str << ", "sv;
14601460
args[3]->accept(vis);
1461-
str << ")";
1461+
str << ")"sv;
14621462
} return;
14631463
case CallOp::DDX: {
14641464
if (opt->isRaster) {
@@ -1992,57 +1992,4 @@ o0=pixel(p,primId)"sv;
19921992
// value assignment
19931993
}
19941994

1995-
void CodegenUtility::CodegenWorkGraphNode(const compute::detail::WorkGraphNode& node, vstd::StringBuilder &result, vstd::unordered_set<uint64_t>& callableMap, bool cbufferNonEmpty) {
1996-
auto codegenOneFunc = [&](Function func) {
1997-
auto constants = func.constants();
1998-
for (auto &&i : constants) {
1999-
vstd::StringBuilder constValueName;
2000-
if (!GetConstName(i.hash(), i, constValueName)) continue;
2001-
result << "static const "sv;
2002-
GetTypeName(*i.type(), result, Usage::READ);
2003-
result << ' ' << constValueName << " = "sv;
2004-
CodegenConstantPrinter printer{*this, result};
2005-
i.decode(printer);
2006-
result << ";\n"sv;
2007-
}
2008-
#ifdef LUISA_ENABLE_IR
2009-
vstd::unordered_set<Variable> grad_vars;
2010-
// glob_variables_with_grad(func, grad_vars);
2011-
#endif
2012-
2013-
opt->funcType = CodegenStackData::FuncType::Callable;
2014-
GetFunctionDecl(func, result);
2015-
result << "{\n"sv;
2016-
{
2017-
2018-
StringStateVisitor vis(func, result, this);
2019-
vis.sharedVariables = &opt->sharedVariable;
2020-
vis.VisitFunction(
2021-
#ifdef LUISA_ENABLE_IR
2022-
grad_vars,
2023-
#endif
2024-
func);
2025-
}
2026-
result << "}\n"sv;
2027-
};
2028-
2029-
auto callable = [&](auto &&callable, Function func) -> void {
2030-
for (auto &&i : func.custom_callables()) {
2031-
if (callableMap.emplace(i->hash()).second) {
2032-
callable(callable, i->function());
2033-
}
2034-
}
2035-
codegenOneFunc(func);
2036-
};
2037-
2038-
auto node_func = node.fn_builder->function();
2039-
for (auto &&i : node_func.custom_callables()) {
2040-
if (callableMap.emplace(i->hash()).second) {
2041-
callable(callable, i->function());
2042-
}
2043-
}
2044-
2045-
result << luisa::format("// node name: {}\n", node.name);
2046-
result << luisa::format("void {} (\n", node_func.name());
2047-
}
20481995
}// namespace lc::hlsl

0 commit comments

Comments
 (0)