Skip to content

Commit bfd3b81

Browse files
Unify kernel and firmware JIT build deduplication into JitBuildCache (#37452)
### Ticket #31189 ### Problem description Multi-thread concurrency handling for JIT build for the same target is handled by HashLookup for kernels, and mutex+set for firmware at call sites. This logic can be absorbed into jit build and the call sites can be cleaned up. ### What's changed Replace two separate concurrency/deduplication mechanisms (HashLookup for kernels, mutex+set for firmware) with a single JitBuildCache class in the jit_build module. Callers now use jit_build_once(hash, fn) which guarantees exactly-once execution per hash, with concurrent callers blocking until the build completes. Fixed a subtle bug in the old HashLookup::clear() where a thread waiting in wait_for_bin_generated() could sleep forever if clear() was called concurrently. ### Checklist - [ ] [![All post-commit tests](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml/badge.svg?branch=ruizhang/jit_build_cache)](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml?query=branch:ruizhang/jit_build_cache) - [ ] [![Blackhole Post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml/badge.svg?branch=ruizhang/jit_build_cache)](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml?query=branch:ruizhang/jit_build_cache) - [ ] [![cpp-unit-tests](https://github.com/tenstorrent/tt-metal/actions/workflows/tt-metal-l2-nightly.yaml/badge.svg?branch=ruizhang/jit_build_cache)](https://github.com/tenstorrent/tt-metal/actions/workflows/tt-metal-l2-nightly.yaml?query=branch:ruizhang/jit_build_cache) (failures in tools test is due to [recent change](#37147 (comment))) - [ ] New/Existing tests provide coverage for changes --------- Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent ad0d780 commit bfd3b81

File tree

12 files changed

+135
-92
lines changed

12 files changed

+135
-92
lines changed

tests/tt_metal/tt_metal/api/test_kernel_compile_cache.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
#include <tt-metalium/core_coord.hpp>
1515
#include <tt-metalium/data_types.hpp>
16-
#include "detail/kernel_cache.hpp"
1716
#include <tt-metalium/device.hpp>
1817
#include "device_fixture.hpp"
1918
#include <tt-metalium/hal.hpp>
@@ -35,7 +34,7 @@ TEST_F(MeshDeviceFixture, TensixTestEquivalentDataMovementKernelsWithDifferentPr
3534

3635
for (const auto& mesh_device : this->devices_) {
3736
auto* device = mesh_device->get_devices()[0];
38-
detail::ClearKernelCache();
37+
jit_build_cache_clear();
3938

4039
DataMovementConfig config_riscv_0 = {.processor = DataMovementProcessor::RISCV_0};
4140
DataMovementConfig config_riscv_1 = {.processor = DataMovementProcessor::RISCV_1};

tests/tt_metal/tt_metal/test_compile_program.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
#include "hal_types.hpp"
1919
#include "jit_build/build.hpp"
20-
#include "tt_metal/detail/kernel_cache.hpp"
2120
#include "tt_metal/jit_build/build_env_manager.hpp"
2221
#include <umd/device/types/arch.hpp>
2322

@@ -38,7 +37,7 @@ struct KernelCacheStatus {
3837

3938
void ClearKernelCache(const std::string& kernel_root_path) {
4039
std::filesystem::remove_all(kernel_root_path);
41-
detail::HashLookup::inst().clear();
40+
jit_build_cache_clear();
4241
}
4342

4443
// This assumes binaries are written to specific location: kernel_compile_outpath / kernel_name / hash

tests/tt_metal/tt_metal/test_compile_sets_kernel_binaries.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
#include "impl/program/program_impl.hpp"
2626
#include "impl/kernels/kernel.hpp"
2727
#include "tt_memory.h"
28-
#include "tt_metal/detail/kernel_cache.hpp"
2928
#include "tt_metal/jit_build/build_env_manager.hpp"
3029
#include <umd/device/types/arch.hpp>
3130

@@ -184,7 +183,7 @@ TEST_F(CompileSetsKernelBinariesFixture, CompileSetsKernelBinaries) {
184183
kernel_name);
185184
}
186185
}
187-
detail::ClearKernelCache();
186+
jit_build_cache_clear();
188187
std::vector<Program> new_programs;
189188
for (int i = 0; i < num_devices_; i++) {
190189
auto& device = devices_[i];

tt_metal/detail/kernel_cache.hpp

Lines changed: 0 additions & 68 deletions
This file was deleted.

tt_metal/impl/context/metal_context.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "dispatch/topology.hpp"
3333
#include "dispatch/dispatch_core_common.hpp"
3434
#include "profiler/profiler_state_manager.hpp"
35+
#include "jit_build/build.hpp"
3536
#include "jit_build/build_env_manager.hpp"
3637
#include "llrt/get_platform_architecture.hpp"
3738
#include "llrt/llrt.hpp"
@@ -273,19 +274,13 @@ void MetalContext::initialize(
273274
BuildEnvManager::get_instance().add_build_env(device_id, num_hw_cqs_);
274275
// fw_build_key is a combination of build_key and fw_compile_hash
275276
// If fw_compile_hash changes, the fw_build_key will change and FW will be rebuilt
276-
// if it's not already in firmware_built_keys_
277277
// Combine build_key and fw_compile_hash using XOR to create unique firmware build key
278278
// Uses full 64-bit fw_compile_hash for proper change detection
279279
uint64_t fw_build_key =
280280
BuildEnvManager::get_instance().get_device_build_env(device_id).build_key() ^ fw_compile_hash;
281281

282-
{
283-
std::lock_guard<std::mutex> lock(firmware_built_keys_mutex_);
284-
if (!firmware_built_keys_.contains(fw_build_key)) {
285-
BuildEnvManager::get_instance().build_firmware(device_id);
286-
firmware_built_keys_.insert(fw_build_key);
287-
}
288-
}
282+
jit_build_once(
283+
fw_build_key, [device_id] { BuildEnvManager::get_instance().build_firmware(device_id); });
289284

290285
// Clear the entire launch message ring buffer on ethernet cores before application firmware is
291286
// activated. This is required since ethernet cores context switch between application and routing

tt_metal/impl/context/metal_context.hpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,6 @@ class MetalContext {
201201
size_t worker_l1_unreserved_start_ = 0;
202202
size_t fw_compile_hash_ = 0; // To check if FW recompilation is needed
203203

204-
// Used to track which FW has been built already
205-
std::unordered_set<uint64_t> firmware_built_keys_;
206-
std::mutex firmware_built_keys_mutex_;
207-
208204
// Mutex to protect control_plane_ for thread-safe access
209205
std::mutex control_plane_mutex_;
210206

tt_metal/impl/program/program.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@
6565
#include "sub_device_types.hpp"
6666
#include "tile.hpp"
6767
#include "tt_memory.h"
68-
#include "tt_metal/detail/kernel_cache.hpp"
6968
#include "tt_metal/impl/debug/inspector/inspector.hpp"
7069
#include "tt_metal/impl/dispatch/data_collection.hpp"
7170
#include "tt_metal/impl/dispatch/device_command.hpp"
@@ -190,7 +189,7 @@ size_t KernelCompileHash(const std::shared_ptr<Kernel>& kernel, JitBuildOptions&
190189

191190
namespace experimental {
192191

193-
void ClearKernelCache() { detail::HashLookup::inst().clear(); }
192+
void ClearKernelCache() { jit_build_cache_clear(); }
194193

195194
} // namespace experimental
196195

@@ -1487,17 +1486,15 @@ void detail::ProgramImpl::compile(IDevice* device, bool force_slow_dispatch) {
14871486
bool is_mock = tt::tt_metal::MetalContext::instance().get_cluster().get_target_device_type() ==
14881487
tt::TargetDevice::Mock;
14891488

1490-
if (detail::HashLookup::inst().add(kernel_hash)) {
1489+
jit_build_once(kernel_hash, [&] {
14911490
if (!is_mock) {
14921491
GenerateBinaries(device, build_options, kernel);
14931492
} else {
14941493
// Create empty stub binaries for mock devices
14951494
std::vector<const ll_api::memory*> empty_binaries(kernel->expected_num_binaries(), nullptr);
14961495
kernel->set_binaries(build_env.build_key(), std::move(empty_binaries));
14971496
}
1498-
detail::HashLookup::inst().add_generated_bin(kernel_hash);
1499-
}
1500-
detail::HashLookup::inst().wait_for_bin_generated(kernel_hash);
1497+
});
15011498

15021499
Inspector::program_kernel_compile_finished(this, device, kernel, build_options);
15031500
},

tt_metal/jit_build/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ set(JIT_BUILD_SRCS
44
${CMAKE_CURRENT_SOURCE_DIR}/data_format.cpp
55
${CMAKE_CURRENT_SOURCE_DIR}/depend.cpp
66
${CMAKE_CURRENT_SOURCE_DIR}/genfiles.cpp
7+
${CMAKE_CURRENT_SOURCE_DIR}/jit_build_cache.cpp
78
${CMAKE_CURRENT_SOURCE_DIR}/kernel_args.cpp
89
${CMAKE_CURRENT_SOURCE_DIR}/jit_build_options.cpp
910
${CMAKE_CURRENT_SOURCE_DIR}/jit_build_utils.cpp

tt_metal/jit_build/build.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#include "build.hpp"
66

7+
#include "jit_build_cache.hpp"
8+
79
#include <algorithm>
810
#include <array>
911
#include <atomic>
@@ -748,4 +750,10 @@ void sync_build_steps(std::vector<std::shared_future<void>>& events) {
748750
}
749751
}
750752

753+
void jit_build_once(size_t hash, const std::function<void()>& build_fn) {
754+
JitBuildCache::inst().build_once(hash, build_fn);
755+
}
756+
757+
void jit_build_cache_clear() { JitBuildCache::inst().clear(); }
758+
751759
} // namespace tt::tt_metal

tt_metal/jit_build/build.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,4 +168,13 @@ void jit_link_additional_processor(
168168
void launch_build_step(const std::function<void()>& build_func, std::vector<std::shared_future<void>>& events);
169169
void sync_build_steps(std::vector<std::shared_future<void>>& events);
170170

171+
// Execute build_fn exactly once for a given hash.
172+
// Concurrent callers with the same hash block until the build completes.
173+
// Returns immediately if hash was already built.
174+
// If build_fn throws, subsequent callers will retry.
175+
void jit_build_once(size_t hash, const std::function<void()>& build_fn);
176+
177+
// Clear the JIT build cache so that subsequent jit_build_once() calls re-execute.
178+
void jit_build_cache_clear();
179+
171180
} // namespace tt::tt_metal

0 commit comments

Comments
 (0)