Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions tests/ttnn/unit_tests/operations/sdpa/test_sdpa_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ def fa_rand(*shape):
return normal_1 + normal_2 * bernoulli


def is_watcher_enabled():
return os.environ.get("TT_METAL_WATCHER") is not None


def create_sliding_window_mask_prefill(b, nh, seq_len, sliding_window=0, is_causal=True):
"""
Create attention mask for sliding window attention in prefill mode.
Expand Down Expand Up @@ -568,7 +564,6 @@ def run_test_sdpa_with_attention_sink(
# ---------------------------------------------------------------------------


@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled")
@pytest.mark.parametrize("dtype", [ttnn.bfloat8_b], ids=["bfp8"])
@pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG], ids=["dram_interleaved"])
@pytest.mark.parametrize("q_chunk_size", [128], ids=["q128"])
Expand Down Expand Up @@ -598,7 +593,6 @@ def test_sdpa_tt(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dtype, me
)


@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled")
@pytest.mark.parametrize("dtype", [ttnn.bfloat16], ids=["bf16"])
@pytest.mark.parametrize("q_chunk_size", [128], ids=["q128"])
@pytest.mark.parametrize("k_chunk_size", [128], ids=["k128"])
Expand All @@ -615,7 +609,6 @@ def test_sdpa_noncausal(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dt
run_sdpa_noncausal(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dtype, rmse_threshold=rmse_threshold)


@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled")
@pytest.mark.parametrize("dtype", [ttnn.bfloat8_b], ids=["bfp8"])
@pytest.mark.parametrize("q_chunk_size", [128], ids=["q128"])
@pytest.mark.parametrize("k_chunk_size", [128], ids=["k128"])
Expand All @@ -635,7 +628,6 @@ def test_sdpa_tt_with_program_cache(device, b, nh, nkv, s, d, q_chunk_size, k_ch
assert device.num_program_cache_entries() == 1


@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled")
@pytest.mark.parametrize("dtype", [ttnn.bfloat8_b], ids=["bfp8"])
@pytest.mark.parametrize("q_chunk_size", [256], ids=["q256"])
@pytest.mark.parametrize("k_chunk_size", [256], ids=["k256"])
Expand Down
5 changes: 3 additions & 2 deletions ttnn/cpp/ttnn/kernel/dataflow/generate_reduce_scaler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

// Tile is assumed to have 16-bit elements
// Scaler is assumed to be a 16-bit value double packed into a u32
template <bool half_tile = false>
FORCE_INLINE void generate_reduce_scaler(const uint32_t cb_id, const uint32_t scaler) {
cb_reserve_back(cb_id, 1);

constexpr uint32_t num_zeros_reads = 2048 / MEM_ZEROS_SIZE;
constexpr uint32_t num_zeros_reads = (half_tile ? 1024 : 2048) / MEM_ZEROS_SIZE;
static_assert(num_zeros_reads > 0, "num_zeros_reads must be greater than 0");
uint64_t zeros_noc_addr = get_noc_addr(MEM_ZEROS_BASE);
uint32_t write_addr = get_write_ptr(cb_id);
Expand All @@ -27,7 +28,7 @@ FORCE_INLINE void generate_reduce_scaler(const uint32_t cb_id, const uint32_t sc
noc_async_read_barrier();

if (scaler != 0) {
for (int k = 0; k < 4; ++k) {
for (int k = 0; k < (half_tile ? 2 : 4); ++k) {
uint32_t idx = k << 7;
for (int j = 0; j < 8; ++j) {
ptr[idx + j] = scaler;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ void kernel_main() {
constexpr uint32_t cb_out_l = tt::CBIndex::c_18;

// generate and send scaler to compute
// These helper functions respect tile size of CBs (ie. no need for special handling of tiny tiles)
generate_reduce_scaler(cb_identity_scale_in, identity_scalar_packed);
generate_reduce_scaler(cb_zero_in, zero_scalar_packed);
constexpr bool is_half_tile = (get_tile_size(cb_identity_scale_in) < 2 * tt::constants::TILE_HW);
generate_reduce_scaler<is_half_tile>(cb_identity_scale_in, identity_scalar_packed);
generate_reduce_scaler<is_half_tile>(cb_zero_in, zero_scalar_packed);
generate_bcast_col_scalar(cb_col_identity, identity_scalar_packed);

if (k_chunk_start == window_start_chunk && window_start_unaligned > 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1011,7 +1011,7 @@ SdpaDecodeProgramFactory::cached_program_t SdpaDecodeProgramFactory::create(
for (auto core : core_group_idle) {
log_debug(tt::LogOp, "Setting core {} to idle", core);
// reader runtime args
std::vector<uint32_t> reader_rt_args = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
std::vector<uint32_t> reader_rt_args = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};

// writer runtime args
std::vector<uint32_t> writer_rt_args = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
Expand Down
Loading