diff --git a/ynnpack/subgraph/dot.cc b/ynnpack/subgraph/dot.cc index 82799631db0..76b90c8b1c5 100644 --- a/ynnpack/subgraph/dot.cc +++ b/ynnpack/subgraph/dot.cc @@ -69,11 +69,25 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, return [type, kernel_flags, transposed_a, pack_b, num_k_dims]( slinky::raw_buffer a, slinky::raw_buffer b, - slinky::raw_buffer init_c, slinky::raw_buffer c) -> index_t { + slinky::raw_buffer init_c, slinky::raw_buffer c, + const slinky::raw_buffer& reduction_bounds) -> index_t { // If the dot has fewer than 3 reduction dimensions, we use this dummy // dimension instead. const slinky::dim& dummy_dim = slinky::dim::broadcast(); + const slinky::dim& r_k1 = reduction_bounds.dim(0); + const slinky::dim& r_k2 = + num_k_dims >= 2 ? reduction_bounds.dim(1) : dummy_dim; + const slinky::dim& r_k3 = + num_k_dims >= 3 ? reduction_bounds.dim(2) : dummy_dim; + + index_t k1_min = r_k1.min(); + index_t k1_extent = r_k1.extent(); + index_t k2_min = num_k_dims >= 2 ? r_k2.min() : 0; + index_t k2_extent = num_k_dims >= 2 ? r_k2.extent() : 1; + index_t k3_min = num_k_dims >= 3 ? r_k3.min() : 0; + index_t k3_extent = num_k_dims >= 3 ? r_k3.extent() : 1; + // Learn what we need to know about m, n, k1, k2, k3 before slicing them. const int a_k1_dim = transposed_a ? 1 : 0; const slinky::dim& init_c_m = init_c.dim(1); @@ -94,17 +108,17 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, const int b_type_element_count = type_element_count(type.b); const index_t tile_k = b_k1i.extent() * b_type_element_count; - // If a is transposed, then the k dimension has been reshaped to have tile_k - // values in each element. + // If a is transposed, then the k dimension has been reshaped to have + // `tile_k` values in each element. const index_t a_tile_k = a_k1i.extent(); const index_t a_stride_m = a_m.stride(); const index_t a_stride_k3 = a_k3.stride(); const index_t a_stride_k2 = a_k2.stride(); const index_t a_stride_k1 = a_k1o.stride() / a_tile_k; - const index_t k1 = (a_k1o.extent() * a_tile_k) & ~(tile_k - 1); - const index_t k1_tail = (a_k1o.extent() * a_tile_k) & (tile_k - 1); - const index_t k2 = a_k2.extent(); - const index_t k3 = a_k3.extent(); + const index_t k1 = k1_extent & ~(tile_k - 1); + const index_t k1_tail = k1_extent & (tile_k - 1); + const index_t k2 = k2_extent; + const index_t k3 = k3_extent; const index_t block_n = pack_b ? b_ni.extent() : c_n.extent(); const index_t b_stride_k3 = b_k3.stride(); const index_t b_stride_k2 = b_k2.stride(); @@ -126,6 +140,15 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, const index_t c_stride_n = c_n.stride(); index_t init_c_stride_m = init_c_m.stride(); + a.base = offset_bytes(a.base, + (k3_min - a_k3.min()) * a_stride_k3 + + (k2_min - a_k2.min()) * a_stride_k2 + + (k1_min - a_k1o.min() * a_tile_k) * a_stride_k1); + b.base = + offset_bytes(b.base, (k3_min - b_k3.min()) * b_stride_k3 + + (k2_min - b_k2.min()) * b_stride_k2 + + (k1_min - b_k1o.min() * tile_k) * b_stride_k1); + // Find a kernel that is compatible with the packed data we have, and // matches whether A is transposed or not. dot_shape shape; @@ -151,17 +174,11 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, const index_t block_k = kernel.block_k; assert(a_k1i.min() == 0); - assert(a_k1o.min() == 0); assert(a_tile_k == 1 || a_k1i.stride() == a.elem_size); - assert(a_k2.min() == 0); - assert(a_k3.min() == 0); assert(b_k1i.min() == 0); assert(b_k1i.extent() == 1 || b_k1i.stride() == b.elem_size); assert(b_ni.min() == 0); assert(b_ni.extent() == 1 || b_ni.stride() == b.elem_size * b_k1i.extent()); - assert(b_k1o.min() == 0); - assert(b_k2.min() == 0); - assert(b_k3.min() == 0); assert(!init_c_m.is_folded()); assert(!init_c_n.is_folded()); assert(!c_m.is_folded()); @@ -176,18 +193,30 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, assert(!b_k2.is_folded()); assert(!b_k3.is_folded()); - if (init_c.base && init_c.base != c.base && c_n.extent() > 1) { - if (init_c_n.stride() == 0) { - // The initializer is broadcasted in the n dimension, which the kernel - // cannot handle. We need to copy it to the output, and update the - // initializer to point to the output. - slinky::copy(init_c, c); - init_c_stride_m = c_stride_m; - init_c = c; - } else { - assert(init_c_n.stride() == c_stride_n); + bool init_output = true; + for (int i = 0; i < reduction_bounds.rank; ++i) { + if (reduction_bounds.dim(i).min() != 0) { + init_output = false; + break; } } + if (init_output) { + if (init_c.base && init_c.base != c.base && c_n.extent() > 1) { + if (init_c_n.stride() == 0) { + // The initializer is broadcasted in the n dimension, which the kernel + // cannot handle. We need to copy it to the output, and update the + // initializer to point to the output. + slinky::copy(init_c, c); + init_c_stride_m = c_stride_m; + init_c = c; + } else { + assert(init_c_n.stride() == c_stride_n); + } + } + } else { + init_c_stride_m = c_stride_m; + init_c = c; + } // `for_each_element` below handles the batch dimensions, we handle the loop // over m, and the kernel handles the rest (n, k1, k2, k3). We need to slice @@ -273,8 +302,8 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, auto loops = schedule_dot(cache_sizes, c_m.extent(), c_n.extent(), k_tail, block_m, block_n, block_k, a.elem_size, b.elem_size, loops_storage); - // Dot kernels can't handle k1 not aligned to tile_k. We handle that here - // by making a padded copy of the unaligned elements and calling the + // Dot kernels can't handle k1 not aligned to tile_k. We handle that + // here by making a padded copy of the unaligned elements and calling the // kernel again. // // We do this padding+kernel call once for each value of k3, k2, which @@ -362,7 +391,7 @@ auto make_pack_impl(int elem_count) { (void)output_ki; input.slice(0, output_no.min() * block_n / elem_count); - input.slice(0); + input.slice(0, slinky::in_bounds{output_ko.min() * tile_k}); output.slice({0, 1, 2, 3}); // Depending on the strides of the input, we might use either an interleave @@ -373,12 +402,17 @@ auto make_pack_impl(int elem_count) { transpose ? input_n.stride() : input_k.stride(); // We need the extent of the intersection of the input and output bounds. - assert(output_ko.min() == 0); - const index_t k = std::min(output_ko.end() * tile_k, input_k.end()); + const index_t k = + std::max(0, std::min(output_ko.end() * tile_k, input_k.end()) - + output_ko.min() * tile_k); assert(input_n.min() * elem_count <= output_no.min() * block_n); - const index_t n = - (std::min(output_no.end() * block_n, input_n.end() * elem_count) - - output_no.begin() * block_n); + // For sub-byte datatypes (e.g. int4), Slinky's buffer extents represent + // physical bytes, not logical elements. We must multiply `input_n.end()` by + // `elem_count` to convert the available input bounds into logical elements + // before intersecting with the output bounds. + const index_t n = std::max( + 0, (std::min(output_no.end() * block_n, input_n.end() * elem_count) - + output_no.min() * block_n)); packer p(transpose, elem_size * 8 / elem_count, tile_k, block_n); @@ -468,10 +502,8 @@ uint32_t define_pack_b(ynn_subgraph_t subgraph, const dot_type& type, slinky::var ni = dims[1]; slinky::var ko = dims[2]; slinky::var no = dims[3]; - func_input.bounds = { - slinky::point((no * block_n + ni) / element_count), - slinky::point((ko * tile_k + ki) * element_count), - }; + func_input.bounds = {slinky::point(no * block_n + ni), + slinky::point(ko * tile_k + ki)}; for (size_t i = 4; i < dims.size(); ++i) { func_input.bounds.push_back(slinky::point(dims[i])); } @@ -549,11 +581,12 @@ auto make_transpose_a_impl(int m_dim) { (void)output_m; // We need the intersection of the input and output bounds. - const index_t m = - std::min(output_m.end(), input_m.end()) - output_m.begin(); + const index_t m = std::max( + 0, std::min(output_m.end(), input_m.end()) - output_m.min()); assert(input_k.min() <= output_ko.min() * tile_k); - assert(output_ko.min() == 0); - const index_t k = std::min(output_ko.end() * tile_k, input_k.end()); + const index_t k = + std::max(0, std::min(output_ko.end() * tile_k, input_k.end()) - + output_ko.min() * tile_k); // We're transposing columns of the input to rows of the output, but // doing tile_k of them at a time. @@ -563,7 +596,7 @@ auto make_transpose_a_impl(int m_dim) { const index_t input_m_stride = input_m.stride(); const index_t output_ko_stride = output_ko.stride(); - input.slice(0); + input.slice(0, slinky::in_bounds{output_ko.min() * tile_k}); input.slice(m_dim - 1, output_m.min()); output.slice({0, 1, static_cast(m_dim + 1)}); @@ -656,10 +689,10 @@ uint32_t define_transpose_a(ynn_subgraph& subgraph, index_t tile_k, return output.id; } -std::tuple choose_split_factors( +std::tuple choose_split_factors( ynn_runtime& runtime, slinky::expr m, slinky::expr n, slinky::expr k, - slinky::expr block_n) { - // We can only return a scalar from a slinky expression, so we pack the two + slinky::expr block_n, int32_t k_alignment) { + // We can only return a scalar from a slinky expression, so we pack the // splits into one integer. auto impl = [](const slinky::call* op, slinky::eval_context& ctx) { index_t m = evaluate(op->args[0], ctx); @@ -717,13 +750,18 @@ std::tuple choose_split_factors( }; slinky::expr splits = slinky::call::make(impl, {m, n, k, block_n}); - // Extract the two splits from the single index_t result. + // Extract the splits from the single index_t result. splits = runtime.globals.get(splits, "dot_splits"); slinky::expr split_m = splits / 65536; slinky::expr split_n = splits % 65536; + slinky::expr split_k = k; split_m = runtime.globals.get(split_m, "split_m"); split_n = runtime.globals.get(split_n, "split_n"); - return {split_n, split_m}; + // Ensure the K split is a multiple of tile_k so that the microkernels can + // process full blocks. + split_k = slinky::align_up(split_k, k_alignment); + split_k = runtime.globals.get(split_k, "split_k"); + return {split_n, split_m, split_k}; } void learn_shape_from_b(dot_shape& shape, size_t num_k_dims, @@ -1004,8 +1042,9 @@ ynn_status define_dot(ynn_subgraph& subgraph, size_t num_k_dims, ? consistent_block_n : std::max(YNN_CACHE_LINE_SIZE / b_elem_size, unpacked_kernel.block_n); - node.create = [consistent_arithmetic, pack_b, transpose_a, block_n_unpacked]( - const ynn_node& node, ynn_runtime& runtime) { + node.create = [consistent_arithmetic, pack_b, transpose_a, block_n_unpacked, + tile_k = kernel.tile_k](const ynn_node& node, + ynn_runtime& runtime) { const ynn_node::dot& op = std::get(node.op); const size_t num_k_dims = op.num_k_dims; const ynn_runtime_value& input_a = runtime.value(node.inputs[0]); @@ -1023,14 +1062,67 @@ ynn_status define_dot(ynn_subgraph& subgraph, size_t num_k_dims, } output.make_buffer(runtime); - std::vector dims = runtime.globals.make_dims(output.rank()); - slinky::var j = dims[0]; + std::vector output_dims = + runtime.globals.make_dims(output.rank()); + slinky::var j = output_dims[0]; + + slinky::buffer_expr_ptr reduction_buffer = slinky::buffer_expr::make( + runtime.globals.symbols, "reduction", num_k_dims, 0); + + std::vector all_dims; + std::vector reduction_dims; + std::vector all_extents; + + for (int i = 0; i < output.rank(); ++i) { + all_dims.push_back(output_dims[i]); + all_extents.push_back(output.extent(i)); + } + + int reduction_dim = 0; + for (size_t d = 0; d < num_k_dims; ++d) { + slinky::var r_dim = runtime.globals.make_reduction_dim(reduction_dim); + all_dims.push_back(r_dim); + reduction_dims.push_back(r_dim); + + const int a_k_dim = transpose_a ? 1 : 0; + slinky::expr k_extent = input_a.extent(a_k_dim + d); + if (transpose_a && d == 0) { + // When A is transposed, its K1 dimension is split into blocks of size + // tile_k. The logical extent of the reduction dimension should be the + // total number of elements, so we multiply the number of blocks by the + // block size. + k_extent *= tile_k; + } + all_extents.push_back(k_extent); + + reduction_buffer->dim(reduction_dim).bounds = + slinky::min_extent(0, k_extent); + reduction_buffer->dim(reduction_dim).stride = 0; + reduction_buffer->dim(reduction_dim).fold_factor = slinky::dim::unfolded; + ++reduction_dim; + } // A: We need all of the k dims, i is elementwise. const int num_a_k_dims = num_k_dims + (transpose_a ? 1 : 0); slinky::box_expr a_bounds(std::min(input_a.rank(), num_a_k_dims)); - for (size_t i = 0; i < a_bounds.size(); ++i) { - a_bounds[i] = all_bounds(input_a.physical_extent(i)); + if (transpose_a) { + a_bounds[0] = all_bounds(tile_k); + for (size_t d = 1; d < a_bounds.size(); ++d) { + int k_idx = d - 1; + if (k_idx == 0) { + // Since the reduction dimension represents the total number of + // elements, we need to divide by the block size (tile_k) to get the + // corresponding block index for the transposed A buffer. + a_bounds[d] = + slinky::point(slinky::simplify(reduction_dims[k_idx] / tile_k)); + } else { + a_bounds[d] = slinky::point(reduction_dims[k_idx]); + } + } + } else { + for (size_t d = 0; d < a_bounds.size(); ++d) { + a_bounds[d] = slinky::point(reduction_dims[d]); + } } // B: We need all of the k dims, j is elementwise. j has been split into @@ -1039,34 +1131,38 @@ ynn_status define_dot(ynn_subgraph& subgraph, size_t num_k_dims, slinky::box_expr b_bounds(num_b_k_dims + 1); b_bounds[0] = all_bounds(packed_b.physical_extent(0)); // ki b_bounds[1] = all_bounds(packed_b.physical_extent(1)); // ni - b_bounds[2] = all_bounds(packed_b.physical_extent(2)); // ko + if (pack_b) { + b_bounds[2] = slinky::point(slinky::simplify(reduction_dims[0] / tile_k)); + } else { + b_bounds[2] = slinky::point(reduction_dims[0]); + } // When we split a packed dimension, the inner part of the split remains // packed, but the outer part is not. b_bounds[3] = slinky::point(j) / packed_b.physical_extent(1); for (size_t i = 4; i < num_b_k_dims + 1; ++i) { - b_bounds[i] = all_bounds(packed_b.physical_extent(i)); + b_bounds[i] = slinky::point(reduction_dims[i - 3]); } // C: Elementwise slinky::box_expr c_bounds; if (input_c.rank() >= 1) { c_bounds.push_back( - elementwise_bounds(dims[0], input_c.physical_extent(0))); + elementwise_bounds(output_dims[0], input_c.physical_extent(0))); } // Batch dims are elementwise too. - for (size_t i = 1; i < dims.size(); ++i) { + for (size_t i = 1; i < output_dims.size(); ++i) { if (i + num_a_k_dims - 1 < input_a.rank()) { a_bounds.push_back(elementwise_bounds( - dims[i], input_a.physical_extent(i + num_a_k_dims - 1))); + output_dims[i], input_a.physical_extent(i + num_a_k_dims - 1))); } if (i >= 2 && i + 2 + num_k_dims - 1 < packed_b.rank()) { b_bounds.push_back(elementwise_bounds( - dims[i], packed_b.physical_extent(i + 2 + num_k_dims - 1))); + output_dims[i], packed_b.physical_extent(i + 2 + num_k_dims - 1))); } if (i < input_c.rank()) { c_bounds.push_back( - elementwise_bounds(dims[i], input_c.physical_extent(i))); + elementwise_bounds(output_dims[i], input_c.physical_extent(i))); } } @@ -1088,7 +1184,9 @@ ynn_status define_dot(ynn_subgraph& subgraph, size_t num_k_dims, {{input_a.buffer, std::move(a_bounds)}, {packed_b.buffer, std::move(b_bounds)}, {input_c.buffer, std::move(c_bounds)}}, - {{output.buffer, dims}}, std::move(attrs)); + {{output.buffer, output_dims}, + {reduction_buffer, std::move(reduction_dims)}}, + std::move(attrs)); slinky::expr block_n = pack_b ? packed_b.extent(1) : block_n_unpacked; slinky::expr n = output.extent(0); @@ -1100,9 +1198,13 @@ ynn_status define_dot(ynn_subgraph& subgraph, size_t num_k_dims, k *= packed_b.extent(3 + d); } - slinky::expr split_n, split_m; - std::tie(split_n, split_m) = - choose_split_factors(runtime, m, n, k, block_n); + int k_alignment = pack_b ? tile_k : 1; + if (transpose_a) { + k_alignment = std::max(k_alignment, tile_k); + } + slinky::expr split_n, split_m, split_k; + std::tie(split_n, split_m, split_k) = + choose_split_factors(runtime, m, n, k, block_n, k_alignment); if (slinky::prove_true(n <= block_n)) { // We know n is smaller than the side of the area we want to compute, @@ -1120,15 +1222,31 @@ ynn_status define_dot(ynn_subgraph& subgraph, size_t num_k_dims, } } - slinky::expr splits[] = {split_n, split_m}; - auto sched = - runtime.make_schedule(dims, output.physical_extents(), - output.buffer->elem_size(), splits, loop_order); + // If output is rank >= 2, we want to split n, m, and k. Otherwise, we only + // split n and k (e.g. fully-connected layers). + std::vector splits; + splits.push_back(split_n); + if (output.rank() >= 2) { + splits.push_back(split_m); + for (size_t i = 2; i < output.rank(); ++i) { + splits.push_back({}); + } + } + splits.push_back(split_k); + + auto sched = runtime.make_schedule( + all_dims, all_extents, output.buffer->elem_size(), splits, loop_order); // We want to use exactly these loop splits for two innermost dot loops. - for (size_t i = 0; i < std::min(2, sched->loop_splits.size()); - i++) { - sched->loop_splits[i].step_is_required = true; + for (size_t dim_idx = 0; dim_idx < std::min(output_dims.size(), 2); + ++dim_idx) { + slinky::var sym = output_dims[dim_idx]; + for (size_t i = 0; i < sched->loop_splits.size(); ++i) { + if (sched->loop_splits[i].var == sym) { + sched->loop_splits[i].step_is_required = true; + break; + } + } } // Schedule the output buffer to be stored at the same level as it's