1+ #include < deque>
2+ #include < mutex>
3+ #include < thread>
4+
15#include " shader.h"
26#include " shader_recompiler.h"
37#include " dxc_compiler.h"
48
9+ #ifdef XENOS_RECOMP_AIR
10+ #include " air_compiler.h"
11+ #endif
12+
513static std::unique_ptr<uint8_t []> readAllBytes (const char * filePath, size_t & fileSize)
614{
715 FILE* file = fopen (filePath, " rb" );
@@ -26,9 +34,43 @@ struct RecompiledShader
2634 uint8_t * data = nullptr ;
2735 IDxcBlob* dxil = nullptr ;
2836 std::vector<uint8_t > spirv;
37+ std::vector<uint8_t > air;
2938 uint32_t specConstantsMask = 0 ;
3039};
3140
41+ void recompileShader (RecompiledShader& shader, const std::string_view include, std::atomic<uint32_t >& progress, uint32_t numShaders)
42+ {
43+ thread_local ShaderRecompiler recompiler;
44+ recompiler = {};
45+ recompiler.recompile (shader.data , include);
46+
47+ shader.specConstantsMask = recompiler.specConstantsMask ;
48+
49+ thread_local DxcCompiler dxcCompiler;
50+
51+ #ifdef XENOS_RECOMP_DXIL
52+ shader.dxil = dxcCompiler.compile (recompiler.out , recompiler.isPixelShader , recompiler.specConstantsMask != 0 , false );
53+ assert (shader.dxil != nullptr );
54+ assert (*(reinterpret_cast <uint32_t *>(shader.dxil ->GetBufferPointer ()) + 1 ) != 0 && " DXIL was not signed properly!" );
55+ #endif
56+
57+ #ifdef XENOS_RECOMP_AIR
58+ shader.air = AirCompiler::compile (recompiler.out );
59+ #endif
60+
61+ IDxcBlob* spirv = dxcCompiler.compile (recompiler.out , recompiler.isPixelShader , false , true );
62+ assert (spirv != nullptr );
63+
64+ bool result = smolv::Encode (spirv->GetBufferPointer (), spirv->GetBufferSize (), shader.spirv , smolv::kEncodeFlagStripDebugInfo );
65+ assert (result);
66+
67+ spirv->Release ();
68+
69+ size_t currentProgress = ++progress;
70+ if ((currentProgress % 10 ) == 0 || (currentProgress == numShaders - 1 ))
71+ fmt::println (" Recompiling shaders... {}%" , currentProgress / float (numShaders) * 100 .0f );
72+ }
73+
3274int main (int argc, char ** argv)
3375{
3476#ifndef XENOS_RECOMP_INPUT
@@ -71,6 +113,7 @@ int main(int argc, char** argv)
71113 {
72114 std::vector<std::unique_ptr<uint8_t []>> files;
73115 std::map<XXH64_hash_t, RecompiledShader> shaders;
116+ std::map<XXH64_hash_t, std::string> shaderFilenames;
74117
75118 for (auto & file : std::filesystem::recursive_directory_iterator (input))
76119 {
@@ -99,6 +142,7 @@ int main(int argc, char** argv)
99142 {
100143 shader.first ->second .data = fileData.get () + i;
101144 foundAny = true ;
145+ shaderFilenames[hash] = file.path ().string ();
102146 }
103147
104148 i += dataSize;
@@ -113,38 +157,42 @@ int main(int argc, char** argv)
113157 files.emplace_back (std::move (fileData));
114158 }
115159
116- std::atomic<uint32_t > progress = 0 ;
117-
118- std::for_each (std::execution::par_unseq, shaders.begin (), shaders.end (), [&](auto & hashShaderPair)
119- {
120- auto & shader = hashShaderPair.second ;
121-
122- thread_local ShaderRecompiler recompiler;
123- recompiler = {};
124- recompiler.recompile (shader.data , include);
125-
126- shader.specConstantsMask = recompiler.specConstantsMask ;
127-
128- thread_local DxcCompiler dxcCompiler;
129-
130- #ifdef XENOS_RECOMP_DXIL
131- shader.dxil = dxcCompiler.compile (recompiler.out , recompiler.isPixelShader , recompiler.specConstantsMask != 0 , false );
132- assert (shader.dxil != nullptr );
133- assert (*(reinterpret_cast <uint32_t *>(shader.dxil ->GetBufferPointer ()) + 1 ) != 0 && " DXIL was not signed properly!" );
134- #endif
135-
136- IDxcBlob* spirv = dxcCompiler.compile (recompiler.out , recompiler.isPixelShader , false , true );
137- assert (spirv != nullptr );
138-
139- bool result = smolv::Encode (spirv->GetBufferPointer (), spirv->GetBufferSize (), shader.spirv , smolv::kEncodeFlagStripDebugInfo );
140- assert (result);
160+ std::mutex shaderQueueMutex;
161+ std::deque<XXH64_hash_t> shaderQueue;
162+ for (const auto & [hash, _] : shaders)
163+ {
164+ shaderQueue.emplace_back (hash);
165+ }
141166
142- spirv->Release ();
167+ const uint32_t numThreads = std::max (std::thread::hardware_concurrency (), 1u );
168+ fmt::println (" Recompiling shaders with {} threads" , numThreads);
143169
144- size_t currentProgress = ++progress;
145- if ((currentProgress % 10 ) == 0 || (currentProgress == shaders.size () - 1 ))
146- fmt::println (" Recompiling shaders... {}%" , currentProgress / float (shaders.size ()) * 100 .0f );
170+ std::atomic<uint32_t > progress = 0 ;
171+ std::vector<std::thread> threads;
172+ threads.reserve (numThreads);
173+ for (uint32_t i = 0 ; i < numThreads; i++)
174+ {
175+ threads.emplace_back ([&]
176+ {
177+ while (true )
178+ {
179+ XXH64_hash_t shaderHash;
180+ {
181+ std::lock_guard lock (shaderQueueMutex);
182+ if (shaderQueue.empty ()) {
183+ return ;
184+ }
185+ shaderHash = shaderQueue.front ();
186+ shaderQueue.pop_front ();
187+ }
188+ recompileShader (shaders[shaderHash], include, progress, shaders.size ());
189+ }
147190 });
191+ }
192+ for (auto & thread : threads)
193+ {
194+ thread.join ();
195+ }
148196
149197 fmt::println (" Creating shader cache..." );
150198
@@ -154,18 +202,32 @@ int main(int argc, char** argv)
154202
155203 std::vector<uint8_t > dxil;
156204 std::vector<uint8_t > spirv;
205+ std::vector<uint8_t > air;
157206
158207 for (auto & [hash, shader] : shaders)
159208 {
160- f.println (" \t {{ 0x{:X}, {}, {}, {}, {}, {} }}," ,
161- hash, dxil.size (), (shader.dxil != nullptr ) ? shader.dxil ->GetBufferSize () : 0 , spirv.size (), shader.spirv .size (), shader.specConstantsMask );
209+ const std::string& fullFilename = shaderFilenames[hash];
210+ std::string filename = fullFilename;
211+ size_t shaderPos = filename.find (" shader" );
212+ if (shaderPos != std::string::npos) {
213+ filename = filename.substr (shaderPos);
214+ // Prevent bad escape sequences in Windows shader path.
215+ std::replace (filename.begin (), filename.end (), ' \\ ' , ' /' );
216+ }
217+ f.println (" \t {{ 0x{:X}, {}, {}, {}, {}, {}, {}, {}, \" {}\" }}," ,
218+ hash, dxil.size (), (shader.dxil != nullptr ) ? shader.dxil ->GetBufferSize () : 0 ,
219+ spirv.size (), shader.spirv .size (), air.size (), shader.air .size (), shader.specConstantsMask , filename);
162220
163221 if (shader.dxil != nullptr )
164222 {
165223 dxil.insert (dxil.end (), reinterpret_cast <uint8_t *>(shader.dxil ->GetBufferPointer ()),
166224 reinterpret_cast <uint8_t *>(shader.dxil ->GetBufferPointer ()) + shader.dxil ->GetBufferSize ());
167225 }
168-
226+
227+ #ifdef XENOS_RECOMP_AIR
228+ air.insert (air.end (), shader.air .begin (), shader.air .end ());
229+ #endif
230+
169231 spirv.insert (spirv.end (), shader.spirv .begin (), shader.spirv .end ());
170232 }
171233
@@ -189,6 +251,22 @@ int main(int argc, char** argv)
189251 f.println (" const size_t g_dxilCacheDecompressedSize = {};" , dxil.size ());
190252#endif
191253
254+ #ifdef XENOS_RECOMP_AIR
255+ fmt::println (" Compressing AIR cache..." );
256+
257+ std::vector<uint8_t > airCompressed (ZSTD_compressBound (air.size ()));
258+ airCompressed.resize (ZSTD_compress (airCompressed.data (), airCompressed.size (), air.data (), air.size (), level));
259+
260+ f.print (" const uint8_t g_compressedAirCache[] = {{" );
261+
262+ for (auto data : airCompressed)
263+ f.print (" {}," , data);
264+
265+ f.println (" }};" );
266+ f.println (" const size_t g_airCacheCompressedSize = {};" , airCompressed.size ());
267+ f.println (" const size_t g_airCacheDecompressedSize = {};" , air.size ());
268+ #endif
269+
192270 fmt::println (" Compressing SPIRV cache..." );
193271
194272 std::vector<uint8_t > spirvCompressed (ZSTD_compressBound (spirv.size ()));
0 commit comments