Skip to content

Commit f581e40

Browse files
MSL shader support
Co-authored-by: Isaac Marovitz <[email protected]>
1 parent 990d03b commit f581e40

File tree

9 files changed

+1090
-278
lines changed

9 files changed

+1090
-278
lines changed

XenosRecomp/CMakeLists.txt

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,13 @@ if (WIN32)
44
option(XENOS_RECOMP_DXIL "Generate DXIL shader cache" ON)
55
endif()
66

7+
if (APPLE)
8+
option(XENOS_RECOMP_AIR "Generate Metal AIR shader cache" ON)
9+
endif()
10+
711
set(SMOLV_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../thirdparty/smol-v/source")
812

9-
add_executable(XenosRecomp
13+
add_executable(XenosRecomp
1014
constant_table.h
1115
dxc_compiler.cpp
1216
dxc_compiler.h
@@ -30,13 +34,6 @@ target_precompile_headers(XenosRecomp PRIVATE pch.h)
3034

3135
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
3236
target_compile_options(XenosRecomp PRIVATE -Wno-switch -Wno-unused-variable -Wno-null-arithmetic -fms-extensions)
33-
34-
include(CheckCXXSymbolExists)
35-
check_cxx_symbol_exists(_LIBCPP_VERSION version LIBCPP)
36-
if(LIBCPP)
37-
# Allows using std::execution
38-
target_compile_options(XenosRecomp PRIVATE -fexperimental-library)
39-
endif()
4037
endif()
4138

4239
if (WIN32)
@@ -51,3 +48,8 @@ if (XENOS_RECOMP_DXIL)
5148
target_compile_definitions(XenosRecomp PRIVATE XENOS_RECOMP_DXIL)
5249
target_link_libraries(XenosRecomp PRIVATE Microsoft::DXIL)
5350
endif()
51+
52+
if (XENOS_RECOMP_AIR)
53+
target_compile_definitions(XenosRecomp PRIVATE XENOS_RECOMP_AIR)
54+
target_sources(XenosRecomp PRIVATE air_compiler.cpp air_compiler.h)
55+
endif()

XenosRecomp/air_compiler.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#include "air_compiler.h"

#include <spawn.h>
#include <sys/wait.h>
#include <unistd.h>

#include <cerrno>
#include <climits>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iterator>
#include <string_view>
7+
8+
/// Scope guard for a temporary file: remembers a filesystem path and unlinks
/// it when the guard goes out of scope.
///
/// The guard never creates the file itself. The result of unlink() is
/// ignored, so guarding a path that was never created is harmless.
struct TemporaryPath
{
    /// Path that is removed when the guard is destroyed.
    const std::string path;

    explicit TemporaryPath(std::string_view path) : path(path) {}

    // Non-copyable: a copy would unlink the same path a second time when it
    // is destroyed, which could remove a file re-created in the meantime.
    TemporaryPath(const TemporaryPath&) = delete;
    TemporaryPath& operator=(const TemporaryPath&) = delete;

    ~TemporaryPath()
    {
        unlink(path.c_str());
    }
};
19+
20+
/// Runs an external program and blocks until it terminates.
///
/// @param argv  Null-terminated argument vector. argv[0] is used both as the
///              path of the executable to spawn and as its first argument.
/// @return 0 when the child exited successfully, the child's exit code when
///         it exited with a failure, 128 + signal number when it was
///         terminated by a signal, or -1 when it could not be spawned or
///         waited for.
static int executeCommand(const char** argv)
{
    pid_t pid;
    // posix_spawn() takes `char* const[]`; the cast only drops the constness of the pointees.
    if (posix_spawn(&pid, argv[0], nullptr, nullptr, const_cast<char**>(argv), nullptr) != 0)
        return -1;

    // A signal delivered to this process can interrupt waitpid(); retry
    // instead of reporting a failure for a child that is still running.
    int status = 0;
    pid_t waited;
    do
    {
        waited = waitpid(pid, &status, 0);
    } while (waited == -1 && errno == EINTR);

    if (waited == -1)
        return -1;

    // Decode the wait status so callers see the real exit code rather than
    // the packed status word (where exit code 1 shows up as 256).
    if (WIFEXITED(status))
        return WEXITSTATUS(status);

    if (WIFSIGNALED(status))
        return 128 + WTERMSIG(status);

    return -1;
}
32+
33+
std::vector<uint8_t> AirCompiler::compile(const std::string& shaderSource)
34+
{
35+
// Save source to a location on disk for the compiler to read.
36+
char sourcePathTemplate[PATH_MAX] = "/tmp/xenos_metal_XXXXXX.metal";
37+
const int sourceFd = mkstemps(sourcePathTemplate, 6);
38+
if (sourceFd == -1)
39+
{
40+
fmt::println("Failed to create temporary file for shader source: {}", strerror(errno));
41+
std::exit(1);
42+
}
43+
44+
const TemporaryPath sourcePath(sourcePathTemplate);
45+
const TemporaryPath irPath(sourcePath.path + ".ir");
46+
const TemporaryPath metalLibPath(sourcePath.path + ".metallib");
47+
48+
const ssize_t sourceWritten = write(sourceFd, shaderSource.data(), shaderSource.size());
49+
close(sourceFd);
50+
if (sourceWritten < 0)
51+
{
52+
fmt::println("Failed to write shader source to disk: {}", strerror(errno));
53+
std::exit(1);
54+
}
55+
56+
const char* compileCommand[] = {
57+
"/usr/bin/xcrun", "-sdk", "macosx", "metal", "-o", irPath.path.c_str(), "-c", sourcePath.path.c_str(), "-Wno-unused-variable", "-frecord-sources", "-gline-tables-only", "-D__air__",
58+
#ifdef UNLEASHED_RECOMP
59+
"-DUNLEASHED_RECOMP",
60+
#endif
61+
nullptr
62+
};
63+
if (const int compileStatus = executeCommand(compileCommand); compileStatus != 0)
64+
{
65+
fmt::println("Metal compiler exited with status: {}", compileStatus);
66+
fmt::println("Generated source:\n{}", shaderSource);
67+
std::exit(1);
68+
}
69+
70+
const char* linkCommand[] = { "/usr/bin/xcrun", "-sdk", "macosx", "metallib", "-o", metalLibPath.path.c_str(), irPath.path.c_str(), nullptr };
71+
if (const int linkStatus = executeCommand(linkCommand); linkStatus != 0)
72+
{
73+
fmt::println("Metal linker exited with status: {}", linkStatus);
74+
fmt::println("Generated source:\n{}", shaderSource);
75+
std::exit(1);
76+
}
77+
78+
std::ifstream libStream(metalLibPath.path, std::ios::binary);
79+
std::vector<uint8_t> data((std::istreambuf_iterator(libStream)), std::istreambuf_iterator<char>());
80+
return data;
81+
}

XenosRecomp/air_compiler.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#pragma once
2+
3+
#include <cstdint>
#include <string>
#include <vector>

/// Front end for turning Metal shader source into a metallib binary.
class AirCompiler
{
public:
    /// Compiles and links the given Metal source into a metallib.
    ///
    /// @param shaderSource  Complete Metal source text of one shader.
    /// @return The bytes of the resulting Metal library. The result must not
    ///         be discarded.
    [[nodiscard]] static std::vector<uint8_t> compile(const std::string& shaderSource);
};

XenosRecomp/dxc_compiler.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ IDxcBlob* DxcCompiler::compile(const std::string& shaderSource, bool compilePixe
3434
target = L"-T vs_6_0";
3535
}
3636

37+
if (!compileLibrary)
38+
{
39+
args[argCount++] = L"-E shaderMain";
40+
}
41+
3742
args[argCount++] = target;
3843
args[argCount++] = L"-HV 2021";
3944
args[argCount++] = L"-all-resources-bound";

XenosRecomp/main.cpp

Lines changed: 110 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
1+
#include <deque>
2+
#include <mutex>
3+
#include <thread>
4+
15
#include "shader.h"
26
#include "shader_recompiler.h"
37
#include "dxc_compiler.h"
48

9+
#ifdef XENOS_RECOMP_AIR
10+
#include "air_compiler.h"
11+
#endif
12+
513
static std::unique_ptr<uint8_t[]> readAllBytes(const char* filePath, size_t& fileSize)
614
{
715
FILE* file = fopen(filePath, "rb");
@@ -26,9 +34,43 @@ struct RecompiledShader
2634
uint8_t* data = nullptr;
2735
IDxcBlob* dxil = nullptr;
2836
std::vector<uint8_t> spirv;
37+
std::vector<uint8_t> air;
2938
uint32_t specConstantsMask = 0;
3039
};
3140

41+
/// Recompiles a single shader and stores every enabled output format in it.
///
/// Runs the ShaderRecompiler on `shader.data`, then compiles the generated
/// source to DXIL (when XENOS_RECOMP_DXIL is defined), to AIR (when
/// XENOS_RECOMP_AIR is defined) and always to SPIR-V, which is stored
/// smol-v encoded in `shader.spirv`. The recompiler and DXC compiler objects
/// are thread_local, so each calling thread uses its own instances.
///
/// @param shader      Shader to process; `data` is read, and `specConstantsMask`,
///                    `dxil`, `air` and `spirv` are written.
/// @param include     Passed through to ShaderRecompiler::recompile.
/// @param progress    Shared counter of finished shaders, incremented once per call.
/// @param numShaders  Total number of shaders, used only for progress output.
void recompileShader(RecompiledShader& shader, const std::string_view include, std::atomic<uint32_t>& progress, uint32_t numShaders)
{
    // Reset the per-thread recompiler so no state leaks over from the previous shader.
    thread_local ShaderRecompiler recompiler;
    recompiler = {};
    recompiler.recompile(shader.data, include);

    shader.specConstantsMask = recompiler.specConstantsMask;

    thread_local DxcCompiler dxcCompiler;

#ifdef XENOS_RECOMP_DXIL
    shader.dxil = dxcCompiler.compile(recompiler.out, recompiler.isPixelShader, recompiler.specConstantsMask != 0, false);
    assert(shader.dxil != nullptr);
    assert(*(reinterpret_cast<uint32_t *>(shader.dxil->GetBufferPointer()) + 1) != 0 && "DXIL was not signed properly!");
#endif

#ifdef XENOS_RECOMP_AIR
    shader.air = AirCompiler::compile(recompiler.out);
#endif

    IDxcBlob* spirv = dxcCompiler.compile(recompiler.out, recompiler.isPixelShader, false, true);
    assert(spirv != nullptr);

    // Only inspected by the assert below, hence unused when NDEBUG is defined.
    [[maybe_unused]] const bool result = smolv::Encode(spirv->GetBufferPointer(), spirv->GetBufferSize(), shader.spirv, smolv::kEncodeFlagStripDebugInfo);
    assert(result);

    spirv->Release();

    // `currentProgress` is the 1-based count of finished shaders, so the last
    // shader is the one for which it equals `numShaders`.
    const size_t currentProgress = ++progress;
    if ((currentProgress % 10) == 0 || currentProgress == numShaders)
        fmt::println("Recompiling shaders... {}%", currentProgress / float(numShaders) * 100.0f);
}
73+
3274
int main(int argc, char** argv)
3375
{
3476
#ifndef XENOS_RECOMP_INPUT
@@ -71,6 +113,7 @@ int main(int argc, char** argv)
71113
{
72114
std::vector<std::unique_ptr<uint8_t[]>> files;
73115
std::map<XXH64_hash_t, RecompiledShader> shaders;
116+
std::map<XXH64_hash_t, std::string> shaderFilenames;
74117

75118
for (auto& file : std::filesystem::recursive_directory_iterator(input))
76119
{
@@ -99,6 +142,7 @@ int main(int argc, char** argv)
99142
{
100143
shader.first->second.data = fileData.get() + i;
101144
foundAny = true;
145+
shaderFilenames[hash] = file.path().string();
102146
}
103147

104148
i += dataSize;
@@ -113,38 +157,42 @@ int main(int argc, char** argv)
113157
files.emplace_back(std::move(fileData));
114158
}
115159

116-
std::atomic<uint32_t> progress = 0;
117-
118-
std::for_each(std::execution::par_unseq, shaders.begin(), shaders.end(), [&](auto& hashShaderPair)
119-
{
120-
auto& shader = hashShaderPair.second;
121-
122-
thread_local ShaderRecompiler recompiler;
123-
recompiler = {};
124-
recompiler.recompile(shader.data, include);
125-
126-
shader.specConstantsMask = recompiler.specConstantsMask;
127-
128-
thread_local DxcCompiler dxcCompiler;
129-
130-
#ifdef XENOS_RECOMP_DXIL
131-
shader.dxil = dxcCompiler.compile(recompiler.out, recompiler.isPixelShader, recompiler.specConstantsMask != 0, false);
132-
assert(shader.dxil != nullptr);
133-
assert(*(reinterpret_cast<uint32_t *>(shader.dxil->GetBufferPointer()) + 1) != 0 && "DXIL was not signed properly!");
134-
#endif
135-
136-
IDxcBlob* spirv = dxcCompiler.compile(recompiler.out, recompiler.isPixelShader, false, true);
137-
assert(spirv != nullptr);
138-
139-
bool result = smolv::Encode(spirv->GetBufferPointer(), spirv->GetBufferSize(), shader.spirv, smolv::kEncodeFlagStripDebugInfo);
140-
assert(result);
160+
std::mutex shaderQueueMutex;
161+
std::deque<XXH64_hash_t> shaderQueue;
162+
for (const auto& [hash, _] : shaders)
163+
{
164+
shaderQueue.emplace_back(hash);
165+
}
141166

142-
spirv->Release();
167+
const uint32_t numThreads = std::max(std::thread::hardware_concurrency(), 1u);
168+
fmt::println("Recompiling shaders with {} threads", numThreads);
143169

144-
size_t currentProgress = ++progress;
145-
if ((currentProgress % 10) == 0 || (currentProgress == shaders.size() - 1))
146-
fmt::println("Recompiling shaders... {}%", currentProgress / float(shaders.size()) * 100.0f);
170+
std::atomic<uint32_t> progress = 0;
171+
std::vector<std::thread> threads;
172+
threads.reserve(numThreads);
173+
for (uint32_t i = 0; i < numThreads; i++)
174+
{
175+
threads.emplace_back([&]
176+
{
177+
while (true)
178+
{
179+
XXH64_hash_t shaderHash;
180+
{
181+
std::lock_guard lock(shaderQueueMutex);
182+
if (shaderQueue.empty()) {
183+
return;
184+
}
185+
shaderHash = shaderQueue.front();
186+
shaderQueue.pop_front();
187+
}
188+
recompileShader(shaders[shaderHash], include, progress, shaders.size());
189+
}
147190
});
191+
}
192+
for (auto& thread : threads)
193+
{
194+
thread.join();
195+
}
148196

149197
fmt::println("Creating shader cache...");
150198

@@ -154,18 +202,32 @@ int main(int argc, char** argv)
154202

155203
std::vector<uint8_t> dxil;
156204
std::vector<uint8_t> spirv;
205+
std::vector<uint8_t> air;
157206

158207
for (auto& [hash, shader] : shaders)
159208
{
160-
f.println("\t{{ 0x{:X}, {}, {}, {}, {}, {} }},",
161-
hash, dxil.size(), (shader.dxil != nullptr) ? shader.dxil->GetBufferSize() : 0, spirv.size(), shader.spirv.size(), shader.specConstantsMask);
209+
const std::string& fullFilename = shaderFilenames[hash];
210+
std::string filename = fullFilename;
211+
size_t shaderPos = filename.find("shader");
212+
if (shaderPos != std::string::npos) {
213+
filename = filename.substr(shaderPos);
214+
// Prevent bad escape sequences in Windows shader path.
215+
std::replace(filename.begin(), filename.end(), '\\', '/');
216+
}
217+
f.println("\t{{ 0x{:X}, {}, {}, {}, {}, {}, {}, {}, \"{}\" }},",
218+
hash, dxil.size(), (shader.dxil != nullptr) ? shader.dxil->GetBufferSize() : 0,
219+
spirv.size(), shader.spirv.size(), air.size(), shader.air.size(), shader.specConstantsMask, filename);
162220

163221
if (shader.dxil != nullptr)
164222
{
165223
dxil.insert(dxil.end(), reinterpret_cast<uint8_t *>(shader.dxil->GetBufferPointer()),
166224
reinterpret_cast<uint8_t *>(shader.dxil->GetBufferPointer()) + shader.dxil->GetBufferSize());
167225
}
168-
226+
227+
#ifdef XENOS_RECOMP_AIR
228+
air.insert(air.end(), shader.air.begin(), shader.air.end());
229+
#endif
230+
169231
spirv.insert(spirv.end(), shader.spirv.begin(), shader.spirv.end());
170232
}
171233

@@ -189,6 +251,22 @@ int main(int argc, char** argv)
189251
f.println("const size_t g_dxilCacheDecompressedSize = {};", dxil.size());
190252
#endif
191253

254+
#ifdef XENOS_RECOMP_AIR
255+
fmt::println("Compressing AIR cache...");
256+
257+
std::vector<uint8_t> airCompressed(ZSTD_compressBound(air.size()));
258+
airCompressed.resize(ZSTD_compress(airCompressed.data(), airCompressed.size(), air.data(), air.size(), level));
259+
260+
f.print("const uint8_t g_compressedAirCache[] = {{");
261+
262+
for (auto data : airCompressed)
263+
f.print("{},", data);
264+
265+
f.println("}};");
266+
f.println("const size_t g_airCacheCompressedSize = {};", airCompressed.size());
267+
f.println("const size_t g_airCacheDecompressedSize = {};", air.size());
268+
#endif
269+
192270
fmt::println("Compressing SPIRV cache...");
193271

194272
std::vector<uint8_t> spirvCompressed(ZSTD_compressBound(spirv.size()));

XenosRecomp/pch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#ifdef _WIN32
4+
#define NOMINMAX
45
#include <Windows.h>
56
#endif
67

0 commit comments

Comments
 (0)