-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgfx_utility.cpp
More file actions
170 lines (154 loc) · 6.11 KB
/
Copy pathgfx_utility.cpp
File metadata and controls
170 lines (154 loc) · 6.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
/*!
\file gfx_utility.cpp
\author Sho Ikeda
\brief GFX utility implementations for GPU resource management
\copyright Copyright (c) 2026 Advanced Micro Devices, Inc. All Rights Reserved.
SPDX-License-Identifier: MIT
*/
#include "gfx_utility.hpp"
// Standard C++ library
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <filesystem>
#include <initializer_list>
#include <iostream>
#include <memory>
#include <span>
#include <string>
#include <string_view>
#include <vector>
// GFX
#include "gfx.h"
#include "gfx_window.h"
// Example
#include "utility.hpp"
namespace ex {
/*!
  \brief Creates a window and a GFX context on top of it, returned as a shared pointer.

  The window is created with the arguments 1280, 720, "MiniDXNN" and
  kGfxCreateWindowFlag_HideWindow. The returned pointer's deleter destroys the
  context first and only then drops its reference to the window.

  \param [in] enableDebugShader When true, kGfxCreateContextFlag_EnableShaderDebugging
              is added to the context creation flags.
  \return Shared pointer owning a heap copy of the created GfxContext handle.

  \note Neither creation call is checked here. If a handle is invalid, its
        deleter prints an error and calls std::abort() when it eventually runs.
*/
auto createGfxContext(const bool enableDebugShader) -> std::shared_ptr<GfxContext>
{
  // Window: keep a heap copy of the handle so a shared_ptr deleter can destroy it.
  GfxWindow createdWindow = gfxCreateWindow(1280, 720, "MiniDXNN", kGfxCreateWindowFlag_HideWindow);
  std::shared_ptr<GfxWindow> windowOwner(new GfxWindow{createdWindow}, [](GfxWindow* windowHandle)
  {
    if (windowHandle == nullptr) {
      return;
    }
    if (!*windowHandle) {
      // An invalid handle cannot be destroyed; treated as fatal.
      std::cerr << "[ERROR] Failed to destroy the gfx window.\n";
      std::abort();
    }
    GfxAssertTrue{}(gfxDestroyWindow(*windowHandle), "Destroying the window failed.");
    delete windowHandle;
  });

  // Context creation flags
  GfxCreateContextFlags contextFlags = kGfxCreateContextFlag_EnableExperimentalShaders;
  if (enableDebugShader) {
    contextFlags |= kGfxCreateContextFlag_EnableShaderDebugging;
  }

  // Context: its deleter holds a reference to the window so the window
  // cannot be destroyed before the context is.
  GfxContext createdContext = gfxCreateContext(*windowOwner, contextFlags);
  std::shared_ptr<GfxContext> contextOwner(new GfxContext{createdContext}, [keepAlive = windowOwner](GfxContext* contextHandle) mutable
  {
    if (contextHandle != nullptr) {
      if (!*contextHandle) {
        // An invalid handle cannot be destroyed; treated as fatal.
        std::cerr << "[ERROR] Failed to destroy the gfx context.\n";
        std::abort();
      }
      GfxAssertTrue{}(gfxDestroyContext(*contextHandle), "Destroying the context failed.");
      delete contextHandle;
    }
    // Drop the window reference explicitly right after the context is gone,
    // instead of waiting for the deleter object itself to be destroyed.
    keepAlive.reset();
  });
  return contextOwner;
}
/*!
  \brief Creates a GFX program and returns it wrapped in a shared pointer.

  \param [in] context GFX context handle used to create, and later destroy, the program.
  \param [in] fileName Program file name forwarded to gfxCreateProgram. It does
              not need to be NUL-terminated; a terminated copy is made internally.
  \param [in] dirPath Directory path forwarded to gfxCreateProgram as a string.
  \param [in] includePathList Options whose data() pointers are forwarded as the
              include path array.
  \return Shared pointer owning a heap copy of the created GfxProgram handle.
          Its deleter calls gfxDestroyProgram only when the handle is valid.

  \note NOTE(review): the deleter keeps only a copy of the context handle, so
        the program presumably has to be released before the context is
        destroyed. Confirm against callers.
*/
auto createGfxProgram(GfxContext context, const std::string_view fileName, const std::filesystem::path& dirPath, const std::span<const OptionString> includePathList) -> std::shared_ptr<GfxProgram>
{
  // std::string_view::data() is not guaranteed to be NUL-terminated, so make
  // owned, terminated copies of everything that is passed on as a C string.
  const std::string fileNameStr{fileName};
  const std::string dirPathStr = dirPath.string();
  constexpr const char* shaderModel = "6_10";

  // Flatten the include options into the `const char*` array the C API takes.
  // Assumes OptionString::data() yields a NUL-terminated string. TODO confirm.
  std::vector<const char*> pathList(includePathList.size());
  std::ranges::transform(includePathList, pathList.begin(), [](const OptionString& option) -> const char*
  {
    return option.data();
  });

  GfxProgram program = gfxCreateProgram(context, fileNameStr.c_str(), dirPathStr.c_str(), shaderModel, pathList.data(), static_cast<std::uint32_t>(pathList.size()));
  std::shared_ptr<GfxProgram> sharedProgram(new GfxProgram{program}, [context](GfxProgram* ptr)
  {
    if (ptr != nullptr) {
      // An invalid handle (failed creation) has nothing to destroy.
      if (*ptr) {
        GfxAssertTrue{}(gfxDestroyProgram(context, *ptr), "Destroying the program failed.");
      }
      delete ptr;
    }
  });
  return sharedProgram;
}
/*!
  \brief Creates a GFX compute kernel and returns it wrapped in a shared pointer.

  \param [in] context GFX context handle used to create, and later destroy, the kernel.
  \param [in] program Program handle the kernel is created from.
  \param [in] entryPoint Entry point name forwarded to gfxCreateComputeKernel. It
              does not need to be NUL-terminated; a terminated copy is made internally.
  \param [in] definitionList Options whose data() pointers are forwarded as the
              definition array.
  \return Shared pointer owning a heap copy of the created GfxKernel handle.
          Its deleter calls gfxDestroyKernel only when the handle is valid.

  \note NOTE(review): the deleter keeps only a copy of the context handle, so
        the kernel presumably has to be released before the context is
        destroyed. Confirm against callers.
*/
auto createGfxComputeKernel(GfxContext context, GfxProgram program, const std::string_view entryPoint, const std::span<const OptionString> definitionList) -> std::shared_ptr<GfxKernel>
{
  // std::string_view::data() is not guaranteed to be NUL-terminated, so pass
  // an owned, terminated copy to the C API instead of the raw view pointer.
  const std::string entryPointStr{entryPoint};

  // Flatten the definitions into the `const char*` array the C API takes.
  // Assumes OptionString::data() yields a NUL-terminated string. TODO confirm.
  std::vector<const char*> defList(definitionList.size());
  std::ranges::transform(definitionList, defList.begin(), [](const OptionString& option) -> const char*
  {
    return option.data();
  });

  GfxKernel kernel = gfxCreateComputeKernel(context, program, entryPointStr.c_str(), defList.data(), static_cast<std::uint32_t>(defList.size()));
  std::shared_ptr<GfxKernel> sharedKernel(new GfxKernel{kernel}, [context](GfxKernel* ptr)
  {
    if (ptr != nullptr) {
      // An invalid handle (failed creation) has nothing to destroy.
      if (*ptr) {
        GfxAssertTrue{}(gfxDestroyKernel(context, *ptr), "Destroying the kernel failed.");
      }
      delete ptr;
    }
  });
  return sharedKernel;
}
/*!
  \brief Binds a kernel and its parameters, dispatches it, and waits via gfxFinish.

  \param [in] context GFX context handle the commands are issued on.
  \param [in] program Program handle that receives the buffer/int/float bindings.
  \param [in] kernel Kernel handle to bind and dispatch.
  \param [in] threadGroupSize Value cast to std::uint32_t and used as the first
              dispatch dimension; the other two dimensions are 1.
  \param [in] bufferList Buffer bindings (m_name / m_value) set on the program.
  \param [in] intList std::int32_t parameter bindings set on the program.
  \param [out] execTimeInMs When it holds a reference, a timestamp query wraps
               the dispatch and the measured duration is written through it.
  \param [in] floatList float parameter bindings set on the program.

  \note NOTE(review): binding names are passed on as `m_name.data()`. This
        assumes they are NUL-terminated. Confirm against the binding types.
*/
auto runKernel(GfxContext context, GfxProgram program, GfxKernel kernel, const size_t threadGroupSize, std::span<const BufferBindingDataT> bufferList, std::initializer_list<IntBindingDataT> intList, OptionalRef<float> execTimeInMs, std::initializer_list<FloatBindingDataT> floatList) -> void
{
  // Select the kernel, then push every named binding into the program.
  GfxAssertTrue{}(gfxCommandBindKernel(context, kernel), "Binding the kernel failed.");
  std::ranges::for_each(bufferList, [&](const BufferBindingDataT& binding)
  {
    GfxAssertTrue{}(gfxProgramSetBuffer(context, program, binding.m_name.data(), binding.m_value), "");
  });
  std::ranges::for_each(intList, [&](const IntBindingDataT& binding)
  {
    GfxAssertTrue{}(gfxProgramSetParameter<std::int32_t>(context, program, binding.m_name.data(), binding.m_value), "");
  });
  std::ranges::for_each(floatList, [&](const FloatBindingDataT& binding)
  {
    GfxAssertTrue{}(gfxProgramSetParameter<float>(context, program, binding.m_name.data(), binding.m_value), "");
  });

  // A timestamp query exists only when the caller asked for the execution time.
  std::shared_ptr<GfxTimestampQuery> query;
  if (execTimeInMs.has_value()) {
    GfxTimestampQuery createdQuery = gfxCreateTimestampQuery(context);
    query.reset(new GfxTimestampQuery{createdQuery}, [context](GfxTimestampQuery* queryHandle)
    {
      if (queryHandle == nullptr) {
        return;
      }
      if (*queryHandle) {
        GfxAssertTrue{}(gfxDestroyTimestampQuery(context, *queryHandle), "");
      }
      delete queryHandle;
    });
  }
  const bool isTimed = (query != nullptr);

  // Begin the measurement (if requested) and dispatch the kernel.
  if (isTimed) {
    GfxAssertTrue{}(gfxCommandBeginTimestampQuery(context, *query), "");
  }
  const auto dispatchX = static_cast<std::uint32_t>(threadGroupSize);
  GfxAssertTrue{}(gfxCommandDispatch(context, dispatchX, 1, 1), "Dispatching the command failed.");

  // Close the measurement before waiting.
  if (isTimed) {
    GfxAssertTrue{}(gfxCommandEndTimestampQuery(context, *query), "");
    GfxAssertTrue{}(gfxCommandResolveTimestamp(context), "");
  }

  // Wait for the kernel completion.
  GfxAssertTrue{}(gfxFinish(context), "");

  // Read the measured duration back and hand it to the caller.
  if (isTimed) {
    GfxAssertTrue{}(gfxCommandUpdateTimestamp(context), "");
    const float elapsedMs = gfxTimestampQueryGetDuration(context, *query);
    execTimeInMs->get() = elapsedMs;
  }

  // Release the query explicitly before returning.
  query.reset();
}
/*!
  \brief Convenience overload taking the buffer bindings as an initializer list.

  Views the elements of \a bufferList through a std::span and forwards all
  arguments unchanged to the span-based runKernel overload.

  \param [in] context GFX context handle the commands are issued on.
  \param [in] program Program handle that receives the bindings.
  \param [in] kernel Kernel handle to bind and dispatch.
  \param [in] threadGroupSize Forwarded unchanged.
  \param [in] bufferList Buffer bindings; viewed, not copied.
  \param [in] intList Forwarded unchanged.
  \param [out] execTimeInMs Forwarded unchanged.
  \param [in] floatList Forwarded unchanged.
*/
auto runKernel(GfxContext context, GfxProgram program, GfxKernel kernel, const size_t threadGroupSize, std::initializer_list<BufferBindingDataT> bufferList, std::initializer_list<IntBindingDataT> intList, OptionalRef<float> execTimeInMs, std::initializer_list<FloatBindingDataT> floatList) -> void
{
  // The span type makes overload resolution pick the span-based overload.
  const std::span<const BufferBindingDataT> bufferView{bufferList.begin(), bufferList.size()};
  runKernel(context, program, kernel, threadGroupSize, bufferView, intList, execTimeInMs, floatList);
}
} /* namespace ex */