diff --git a/.gitmodules b/.gitmodules index 83613d7542..dcc19dc3ad 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "candle-flash-attn-v3/cutlass"] url = https://github.com/NVIDIA/cutlass.git path = candle-flash-attn-v3/cutlass +[submodule "candle-flash-mla/cutlass"] + path = candle-flash-mla/cutlass + url = https://github.com/NVIDIA/cutlass diff --git a/Cargo.toml b/Cargo.toml index c9c3bf0d8f..7f8245b92e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ exclude = [ "candle-kernels", "candle-metal-kernels", "candle-onnx", + "candle-flash-mla", ] resolver = "2" @@ -38,6 +39,7 @@ candle = { path = "./candle-core", package = "candle-core", version = "0.8.0" } candle-datasets = { path = "./candle-datasets", version = "0.8.0" } candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.0" } candle-flash-attn-v3 = { path = "./candle-flash-attn-v3", version = "0.8.0" } +candle-flash-mla = { path = "./candle-flash-mla", version = "0.8.0" } candle-kernels = { path = "./candle-kernels", version = "0.8.0" } candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.0" } candle-nn = { path = "./candle-nn", version = "0.8.0" } @@ -50,7 +52,7 @@ fancy-regex = "0.13.0" gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] } hf-hub = { version = "0.3.3", package = "candle-hf-hub" } half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } -float8 = { version = "0.1.2", features = ["num-traits", "rand_distr"] } +float8 = { version = "0.2.0", features = ["num-traits", "rand_distr"] } hound = "3.5.1" image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] } imageproc = { version = "0.24.0", default-features = false } @@ -61,8 +63,8 @@ memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] } num_cpus = "1.15.0" num-traits = "0.2.15" parquet = { version = "51.0.0" } -rand = "0.8.5" -rand_distr = "0.4.3" +rand = "0.9.0" +rand_distr = "0.5" rayon = "1.7.0" 
safetensors = "0.4.1" serde = { version = "1.0.171", features = ["derive"] } diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 271230ed4a..721127afff 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -3159,6 +3159,9 @@ impl BackendStorage for CpuStorage { (Self::F64(src), Self::F64(dst)) => { copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) } + (Self::F8E4M3(src), Self::F8E4M3(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } (_, dst) => { return Err(Error::DTypeMismatchBinaryOp { lhs: self.dtype(), @@ -3182,6 +3185,9 @@ impl BackendStorage for CpuStorage { (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::F8E4M3(src), Self::F8E4M3(dst)) => { + copy_strided_src_(src, dst, dst_offset, src_l) + } (_, dst) => { // This should be covered by the dtype check above. 
return Err(Error::DTypeMismatchBinaryOp { @@ -3540,15 +3546,15 @@ impl BackendDevice for CpuDevice { use rand::prelude::*; let elem_count = shape.elem_count(); - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); match dtype { DType::U8 | DType::U32 | DType::I16 | DType::I32 | DType::I64 => { Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()) } DType::BF16 => { let mut data = Vec::with_capacity(elem_count); - let uniform = - rand::distributions::Uniform::new(bf16::from_f64(min), bf16::from_f64(max)); + let uniform = rand::distr::Uniform::new(bf16::from_f64(min), bf16::from_f64(max)) + .map_err(Error::wrap)?; for _i in 0..elem_count { data.push(rng.sample::(uniform)) } @@ -3556,8 +3562,8 @@ impl BackendDevice for CpuDevice { } DType::F16 => { let mut data = Vec::with_capacity(elem_count); - let uniform = - rand::distributions::Uniform::new(f16::from_f64(min), f16::from_f64(max)); + let uniform = rand::distr::Uniform::new(f16::from_f64(min), f16::from_f64(max)) + .map_err(Error::wrap)?; for _i in 0..elem_count { data.push(rng.sample::(uniform)) } @@ -3566,7 +3572,8 @@ impl BackendDevice for CpuDevice { DType::F8E4M3 => { let mut data = Vec::with_capacity(elem_count); let uniform = - rand::distributions::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max)); + rand::distr::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max)) + .map_err(Error::wrap)?; for _i in 0..elem_count { data.push(rng.sample::(uniform)) } @@ -3574,7 +3581,8 @@ impl BackendDevice for CpuDevice { } DType::F32 => { let mut data = Vec::with_capacity(elem_count); - let uniform = rand::distributions::Uniform::new(min as f32, max as f32); + let uniform = + rand::distr::Uniform::new(min as f32, max as f32).map_err(Error::wrap)?; for _i in 0..elem_count { data.push(rng.sample::(uniform)) } @@ -3582,7 +3590,7 @@ impl BackendDevice for CpuDevice { } DType::F64 => { let mut data = Vec::with_capacity(elem_count); - let uniform = rand::distributions::Uniform::new(min, max); + 
let uniform = rand::distr::Uniform::new(min, max).map_err(Error::wrap)?; for _i in 0..elem_count { data.push(rng.sample::(uniform)) } @@ -3595,7 +3603,7 @@ impl BackendDevice for CpuDevice { use rand::prelude::*; let elem_count = shape.elem_count(); - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); match dtype { DType::U8 | DType::U32 | DType::I16 | DType::I32 | DType::I64 => { Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()) diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index 08335257c6..b054567e30 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -792,7 +792,7 @@ impl PthTensors { /// # Arguments /// * `path` - Path to the pth file. /// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file -/// contains multiple objects and the state_dict is the one we are interested in. +/// contains multiple objects and the `state_dict` is the one we are interested in. pub fn read_all_with_key>( path: P, key: Option<&str>, diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 8591aa259d..4a11419e0d 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -32,6 +32,22 @@ use half::{bf16, f16}; pub use k_quants::GgmlType; +fn as_t_slice(data: Cow<'_, [u8]>) -> &[T] { + let size = std::mem::size_of::(); + assert_eq!( + data.len() % size, + 0, + "Data length must be a multiple of T's size" + ); + let ptr = data.as_ptr(); + assert_eq!( + (ptr as usize) % std::mem::align_of::(), + 0, + "Data pointer must be aligned to T's alignment" + ); + unsafe { std::slice::from_raw_parts(ptr as *const T, data.len() / size) } +} + pub struct QTensor { storage: QStorage, shape: Shape, @@ -63,6 +79,46 @@ pub enum QStorage { } impl QStorage { + pub fn from_data(data: Cow<'_, [u8]>, device: &Device, dtype: GgmlDType) -> Result { + match device { + Device::Cpu => Ok(Self::Cpu(dtype.from_data(data))), + Device::Metal(d) => match dtype 
{ + GgmlDType::F32 => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::F16 => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q4_0 => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q4_1 => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q5_0 => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q5_1 => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q8_0 => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q8_1 => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q2K => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q3K => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q4K => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q5K => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q6K => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q8K => metal::load_quantized(d, as_t_slice::(data)), + GgmlDType::BF16 => metal::load_quantized(d, as_t_slice::(data)), + }, + Device::Cuda(d) => match dtype { + GgmlDType::F32 => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::F16 => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q4_0 => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q4_1 => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q5_0 => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q5_1 => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q8_0 => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q8_1 => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q2K => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q3K => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q4K => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q5K => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q6K => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::Q8K => cuda::load_quantized(d, as_t_slice::(data)), + GgmlDType::BF16 => cuda::load_quantized(d, 
as_t_slice::(data)), + }, + } + } + fn block_size(&self) -> usize { match self { QStorage::Cpu(storage) => storage.block_size(), @@ -267,6 +323,27 @@ impl GgmlDType { Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]), } } + + pub fn from_data(&self, data: Cow<'_, [u8]>) -> Box { + match self { + Self::F32 => Box::new(as_t_slice::(data).to_vec()), + Self::F16 => Box::new(as_t_slice::(data).to_vec()), + Self::Q4_0 => Box::new(as_t_slice::(data).to_vec()), + Self::Q4_1 => Box::new(as_t_slice::(data).to_vec()), + Self::Q5_0 => Box::new(as_t_slice::(data).to_vec()), + Self::Q5_1 => Box::new(as_t_slice::(data).to_vec()), + Self::Q8_0 => Box::new(as_t_slice::(data).to_vec()), + Self::Q8_1 => Box::new(as_t_slice::(data).to_vec()), + Self::Q2K => Box::new(as_t_slice::(data).to_vec()), + Self::Q3K => Box::new(as_t_slice::(data).to_vec()), + Self::Q4K => Box::new(as_t_slice::(data).to_vec()), + Self::Q5K => Box::new(as_t_slice::(data).to_vec()), + Self::Q6K => Box::new(as_t_slice::(data).to_vec()), + Self::Q8K => Box::new(as_t_slice::(data).to_vec()), + Self::BF16 => Box::new(as_t_slice::(data).to_vec()), + } + } + /// The type size for blocks in bytes. 
pub fn type_size(&self) -> usize { use k_quants::*; diff --git a/candle-flash-attn-v3/build.rs b/candle-flash-attn-v3/build.rs index 65f84fd394..d33f2937cf 100644 --- a/candle-flash-attn-v3/build.rs +++ b/candle-flash-attn-v3/build.rs @@ -202,6 +202,9 @@ fn main() -> Result<()> { command.arg("-DCUTLASS_DEBUG_TRACE_LEVEL=0"); command.arg("-DNDEBUG"); + // https://github.com/EricLBuehler/mistral.rs/issues/941 + command.arg("-D_USE_MATH_DEFINES"); + if let Some(ccbin_path) = &ccbin_env { command.arg("-allow-unsupported-compiler"); command.args(["-ccbin", ccbin_path]); diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs index bda00909c7..e02ca7ef53 100644 --- a/candle-flash-attn/build.rs +++ b/candle-flash-attn/build.rs @@ -95,6 +95,8 @@ fn main() -> Result<()> { builder = builder.arg("-D_USE_MATH_DEFINES"); } } + // https://github.com/EricLBuehler/mistral.rs/issues/941 + builder = builder.arg("-D_USE_MATH_DEFINES"); // https://github.com/EricLBuehler/mistral.rs/issues/286 // https://github.com/huggingface/candle-flash-attn-v1/pull/2 diff --git a/candle-flash-mla/.gitignore b/candle-flash-mla/.gitignore new file mode 100644 index 0000000000..fc378cabab --- /dev/null +++ b/candle-flash-mla/.gitignore @@ -0,0 +1,7 @@ +.idea +target +Cargo.lock +.venv +hkernel/build/* +__pycache__ +*.egg-info \ No newline at end of file diff --git a/candle-flash-mla/.gitmodules b/candle-flash-mla/.gitmodules new file mode 100644 index 0000000000..2b822e9a55 --- /dev/null +++ b/candle-flash-mla/.gitmodules @@ -0,0 +1,3 @@ +[submodule "cutlass"] + path = cutlass + url = https://github.com/NVIDIA/cutlass.git \ No newline at end of file diff --git a/candle-flash-mla/Cargo.toml b/candle-flash-mla/Cargo.toml new file mode 100644 index 0000000000..f8623d0dc5 --- /dev/null +++ b/candle-flash-mla/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "candle-flash-mla" +version = "0.8.0" +edition = "2021" + +description = "Flash MLA layer for the candle ML framework." 
+keywords = ["blas", "tensor", "machine-learning"] +categories = ["science"] +license = "MIT OR Apache-2.0" +readme = "README.md" + +[dependencies] +candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.0" } +half = { version = "2.3.1", features = ["num-traits"] } + +[build-dependencies] +anyhow = { version = "1", features = ["backtrace"] } +num_cpus = "1.15.0" +rayon = "1.7.0" + +[dev-dependencies] +anyhow = { version = "1", features = ["backtrace"] } +candle-nn = { path = "../candle-nn", features = ["cuda"] } +rstest = "0.23" \ No newline at end of file diff --git a/candle-flash-mla/LICENSE-APACHE b/candle-flash-mla/LICENSE-APACHE new file mode 100644 index 0000000000..f49a4e16e6 --- /dev/null +++ b/candle-flash-mla/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/candle-flash-mla/LICENSE-MIT b/candle-flash-mla/LICENSE-MIT new file mode 100644 index 0000000000..995061f1ee --- /dev/null +++ b/candle-flash-mla/LICENSE-MIT @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 DeepSeek + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/candle-flash-mla/README.md b/candle-flash-mla/README.md new file mode 100644 index 0000000000..5fc04f030a --- /dev/null +++ b/candle-flash-mla/README.md @@ -0,0 +1,3 @@ +# Candle Flash MLA + +Flash MLA layer for NVIDIA Hopper GPUs (`sm90a` architecture) and the Candle framework. diff --git a/candle-flash-mla/build.rs b/candle-flash-mla/build.rs new file mode 100644 index 0000000000..d21eed276e --- /dev/null +++ b/candle-flash-mla/build.rs @@ -0,0 +1,279 @@ +// build.rs +use anyhow::{anyhow, Context, Result}; +use rayon::prelude::*; +use std::path::PathBuf; +use std::str::FromStr; + +const CUDA_NVCC_FLAGS: Option<&'static str> = option_env!("CUDA_NVCC_FLAGS"); + +const KERNEL_FILES: &[&str] = &["flash_api.cu", "flash_fwd_mla_bf16_sm90.cu"]; + +fn main() -> Result<()> { + // Use RAYON_NUM_THREADS or else default to the number of physical CPUs + let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else( + |_| num_cpus::get_physical(), + |s| usize::from_str(&s).unwrap_or_else(|_| num_cpus::get_physical()), + ); + // limit to 16 CPUs so that we do not use too much RAM on large servers + let num_cpus = num_cpus.min(16); + + rayon::ThreadPoolBuilder::new() + .num_threads(num_cpus) + .build_global() + .unwrap(); + + // Telling Cargo that if any of these files changes, rebuild. + println!("cargo:rerun-if-changed=build.rs"); + println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); + println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN"); + + for file in KERNEL_FILES { + println!("cargo:rerun-if-changed=hkernel/{file}"); + } + println!("cargo:rerun-if-changed=kernels/**.h"); + println!("cargo:rerun-if-changed=kernels/**.hpp"); + println!("cargo:rerun-if-changed=kernels/**.cpp"); + + let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?); + // You can optionally allow an environment variable to cache the compiled artifacts. + // If not found, we compile into the standard OUT_DIR. 
+ let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") { + Err(_) => out_dir.clone(), + Ok(build_dir) => { + let path = PathBuf::from(build_dir); + path.canonicalize().map_err(|_| { + anyhow!( + "Directory doesn't exist: {} (the current directory is {})", + path.display(), + std::env::current_dir().unwrap().display() + ) + })? + } + }; + + // Ensure we set CUDA_INCLUDE_DIR for our crates that might rely on it. + set_cuda_include_dir()?; + + // If set, pass along the custom compiler for NVCC + let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN").ok(); + + // Determine the GPU architecture we’re targeting, e.g. 90 for `sm_90`. + let compute_cap = compute_cap()?; + // assert compute cap is sm90 + // TODO TODO TODO + // assert!(compute_cap == 90, "Compute capability must be 90 (90a)"); + + // Our final library name + let out_file = build_dir.join("libflashattentionmla.a"); + + // Construct the list of (input_file -> output_object_file) + let kernel_dir = PathBuf::from("hkernel"); + let cu_files: Vec<(PathBuf, PathBuf)> = KERNEL_FILES + .iter() + .map(|f| { + let mut obj_file = out_dir.join(f); + obj_file.set_extension("o"); + (kernel_dir.join(f), obj_file) + }) + .collect(); + + // Decide whether to skip recompile if outputs are up to date. + // This is a simplistic approach, + // so feel free to refine if you need more robust up-to-date checks. 
+ let out_modified = out_file + .metadata() + .and_then(|m| m.modified()) + .ok() + .unwrap_or_else(|| std::time::SystemTime::UNIX_EPOCH); + let should_compile = !out_file.exists() + || cu_files.iter().any(|(input, _)| { + let input_modified = input + .metadata() + .and_then(|m| m.modified()) + .unwrap_or(std::time::SystemTime::UNIX_EPOCH); + input_modified.duration_since(out_modified).is_ok() // True if input_modified >= out_modified + }); + + if should_compile { + // 1) Compile each .cu/.cpp -> .o + cu_files + .par_iter() + .try_for_each(|(input, obj)| -> Result<()> { + let mut command = std::process::Command::new("nvcc"); + + // Optimization and standard + command.arg("-O3"); + command.arg("-std=c++17"); + + // GPU architecture, hard code sm_90a instead of sm90 + command.arg(format!("--gpu-architecture={}", "sm_90a")); + + // Compile to object file + command.arg("-c"); + command.args(["-o", obj.to_str().unwrap()]); + + // Default stream per-thread + command.args(["--default-stream", "per-thread"]); + + // Include path + command.arg("-Icutlass/include"); + + // Undefine CUDA “no half/bfloat” macros + command.arg("-U__CUDA_NO_HALF_OPERATORS__"); + command.arg("-U__CUDA_NO_HALF_CONVERSIONS__"); + command.arg("-U__CUDA_NO_BFLOAT16_OPERATORS__"); + command.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__"); + command.arg("-U__CUDA_NO_BFLOAT162_OPERATORS__"); + command.arg("-U__CUDA_NO_BFLOAT162_CONVERSIONS__"); + + // Enable relaxed/extended lambda and fast math + command.arg("--expt-relaxed-constexpr"); + command.arg("--expt-extended-lambda"); + command.arg("--use_fast_math"); + + // PTXAS options: verbose output, register usage info, etc. 
+ command.arg("--ptxas-options=-v"); + command.arg("--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"); + + // Additional debug/performance flags + command.arg("-lineinfo"); + command.arg("-DCUTLASS_DEBUG_TRACE_LEVEL=0"); + command.arg("-DNDEBUG"); + + if let Some(ccbin_path) = &ccbin_env { + command.arg("-allow-unsupported-compiler"); + command.args(["-ccbin", ccbin_path]); + } + + // Add the source file + command.arg(input); + + // https://github.com/EricLBuehler/mistral.rs/issues/286 + if let Some(cuda_nvcc_flags_env) = CUDA_NVCC_FLAGS { + command.arg("--compiler-options"); + command.arg(cuda_nvcc_flags_env); + } + + let output = command + .spawn() + .with_context(|| format!("Failed to spawn nvcc for {input:?}"))? + .wait_with_output() + .with_context(|| format!("Failed during nvcc invocation for {input:?}"))?; + + if !output.status.success() { + return Err(anyhow!( + "nvcc error:\nCommand: {:?}\nstdout:\n{}\nstderr:\n{}", + command, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + )); + } + + Ok(()) + })?; + + // 2) Create static library from the .o files + let obj_files = cu_files + .iter() + .map(|(_, obj)| obj.clone()) + .collect::>(); + + let mut command = std::process::Command::new("nvcc"); + command.arg("--lib"); + command.args(["-o", out_file.to_str().unwrap()]); + command.args(obj_files); + + let output = command + .spawn() + .context("Failed spawning nvcc to archive .o files")? 
+ .wait_with_output() + .context("Failed during nvcc archive step")?; + + if !output.status.success() { + return Err(anyhow!( + "nvcc error (archiving):\nCommand: {:?}\nstdout:\n{}\nstderr:\n{}", + command, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + )); + } + } + + // Finally, instruct cargo to link your library + println!("cargo:rustc-link-search={}", build_dir.display()); + println!("cargo:rustc-link-lib=static=flashattentionmla"); + + // Link required system libs + println!("cargo:rustc-link-lib=dylib=cudart"); + println!("cargo:rustc-link-lib=dylib=stdc++"); + + Ok(()) +} + +/// This function attempts to find a CUDA toolkit root that contains `include/cuda.h`, +/// and prints that path as `CUDA_INCLUDE_DIR`. +fn set_cuda_include_dir() -> Result<()> { + // Adapted from cudarc build.rs + let env_vars = [ + "CUDA_PATH", + "CUDA_ROOT", + "CUDA_TOOLKIT_ROOT_DIR", + "CUDNN_LIB", + ]; + let env_vars = env_vars + .into_iter() + .filter_map(|v| std::env::var(v).ok()) + .map(Into::::into); + + let common_roots = [ + "/usr", + "/usr/local/cuda", + "/opt/cuda", + "/usr/lib/cuda", + "C:/Program Files/NVIDIA GPU Computing Toolkit", + "C:/CUDA", + ]; + let candidates = env_vars.chain(common_roots.into_iter().map(Into::into)); + + let root = candidates + .filter(|path| path.join("include").join("cuda.h").is_file()) + .next() + .ok_or_else(|| anyhow!("Cannot find a valid CUDA root with include/cuda.h"))?; + + println!( + "cargo:rustc-env=CUDA_INCLUDE_DIR={}", + root.join("include").display() + ); + Ok(()) +} + +/// Determine the compute capability we should target. +/// If the user sets `CUDA_COMPUTE_CAP` we trust that. +/// Otherwise, we attempt to parse it from `nvidia-smi`. 
+fn compute_cap() -> Result { + if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { + let cc = compute_cap_str + .parse::() + .context("Failed to parse CUDA_COMPUTE_CAP")?; + Ok(cc) + } else { + // parse from nvidia-smi + let output = std::process::Command::new("nvidia-smi") + .args(["--query-gpu=compute_cap", "--format=csv"]) + .output() + .context("Failed to run nvidia-smi. Make sure it's in PATH.")?; + let stdout = String::from_utf8_lossy(&output.stdout); + let mut lines = stdout.lines(); + if lines.next().unwrap_or("") != "compute_cap" { + return Err(anyhow!("Unexpected output from nvidia-smi: {stdout}")); + } + if let Some(cap_line) = lines.next() { + // e.g. "9.0" -> "90" + let cc_str = cap_line.trim().replace('.', ""); + let cc = cc_str.parse::()?; + Ok(cc) + } else { + Err(anyhow!("nvidia-smi did not return a compute_cap line")) + } + } +} diff --git a/candle-flash-mla/cutlass b/candle-flash-mla/cutlass new file mode 160000 index 0000000000..afa1772203 --- /dev/null +++ b/candle-flash-mla/cutlass @@ -0,0 +1 @@ +Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008 diff --git a/candle-flash-mla/hkernel/flash_api.cpp b/candle-flash-mla/hkernel/flash_api.cpp new file mode 100644 index 0000000000..5a1cb8e0ed --- /dev/null +++ b/candle-flash-mla/hkernel/flash_api.cpp @@ -0,0 +1,203 @@ +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp + +#include +#include +#include +#include + +#include + +#include "flash_mla.h" +#include "static_switch.h" + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +std::vector +get_mla_metadata( + at::Tensor &seqlens_k, + const int num_heads_per_head_k, + const int num_heads_k +) { + // This should match the logic in the MLA kernel. + static constexpr int block_size_m = 64; + static constexpr int block_size_n = 64; + static constexpr int fixed_overhead_num_blocks = 5; + + CHECK_DEVICE(seqlens_k); + TORCH_CHECK(seqlens_k.is_contiguous()); + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32); + + int batch_size = seqlens_k.size(0); + int *seqlens_k_ptr = seqlens_k.data_ptr(); + auto options = seqlens_k.options(); + + auto dprops = at::cuda::getCurrentDeviceProperties(); + int sm_count = dprops->multiProcessorCount; + int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m); + + auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options); + auto num_splits = torch::empty({batch_size + 1}, options); + int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); + int *num_splits_ptr = num_splits.data_ptr(); + + at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + Mla_metadata_params params = {}; + params.seqlens_k_ptr = seqlens_k_ptr; + params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr; + params.num_splits_ptr = num_splits_ptr; + params.batch_size = batch_size; + params.block_size_n = block_size_n; + params.fixed_overhead_num_blocks = fixed_overhead_num_blocks; + params.num_sm_parts = num_sm_parts; + get_mla_metadata_func(params, stream); + + return {tile_scheduler_metadata, num_splits}; +} + +std::vector +mha_fwd_kvcache_mla( + at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x 
head_size + c10::optional &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v + const int head_size_v, + const at::Tensor &seqlens_k, // batch_size + const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq + const float softmax_scale, + bool is_causal, + const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize + const at::Tensor &num_splits // batch_size + 1 +) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90); + + at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache; + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kBFloat16); + TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + CHECK_DEVICE(block_table); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + + const auto sizes = q.sizes(); + const int batch_size = sizes[0]; + const int seqlen_q_ori = sizes[1]; + const int num_heads_ori = sizes[2]; + const int head_size = sizes[3]; + TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32"); + + const int max_num_blocks_per_seq = block_table.size(1); + const int num_blocks = kcache.size(0); + const int page_block_size = kcache.size(1); + const int num_heads_k = kcache.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be postive"); + TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of 
heads in key/value must divide number of heads in query"); + + if (seqlen_q_ori == 1) { is_causal = false; } + + const int ngroups = num_heads_ori / num_heads_k; + const int seqlen_q = seqlen_q_ori * ngroups; + const int num_heads = num_heads_k; + q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3) + .reshape({batch_size, seqlen_q, num_heads, head_size}); + + int head_size_k = head_size; + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); + if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); } + CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); + + + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); + CHECK_DEVICE(seqlens_k); + CHECK_CONTIGUOUS(seqlens_k); + CHECK_SHAPE(seqlens_k, batch_size); + + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts); + at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + + Flash_fwd_mla_params params = {}; + // Set the sizes. + params.b = batch_size; + params.seqlen_q = seqlen_q; + params.cu_seqlens_k = seqlens_k.data_ptr(); + params.h = num_heads; + params.h_h_k_ratio = num_heads / num_heads_k; + params.ngroups = ngroups; + params.is_causal = is_causal; + params.d = head_size; + params.d_v = head_size_v; + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = kcache.data_ptr(); + params.v_ptr = vcache.data_ptr(); + params.o_ptr = out.data_ptr(); + params.softmax_lse_ptr = softmax_lse.data_ptr(); + // All stride are in elements, not bytes. 
+ params.q_batch_stride = q.stride(0); + params.k_batch_stride = kcache.stride(0); + params.v_batch_stride = vcache.stride(0); + params.o_batch_stride = out.stride(0); + params.q_row_stride = q.stride(-3); + params.k_row_stride = kcache.stride(-3); + params.v_row_stride = vcache.stride(-3); + params.o_row_stride = out.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = kcache.stride(-2); + params.v_head_stride = vcache.stride(-2); + params.o_head_stride = out.stride(-2); + + params.block_table = block_table.data_ptr(); + params.block_table_batch_stride = block_table.stride(0); + params.page_block_size = page_block_size; + + TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); + TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); + CHECK_DEVICE(tile_scheduler_metadata); + CHECK_CONTIGUOUS(tile_scheduler_metadata); + params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); + params.num_sm_parts = tile_scheduler_metadata.size(0); + TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); + CHECK_DEVICE(num_splits); + CHECK_CONTIGUOUS(num_splits); + params.num_splits_ptr = num_splits.data_ptr(); + + at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat)); + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_ptr = out_accum.data_ptr(); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + TORCH_CHECK(head_size == 576); + run_mha_fwd_splitkv_mla(params, stream); + + out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3) + .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v}); + softmax_lse = softmax_lse.view({batch_size, num_heads_k, 
seqlen_q_ori, ngroups}).transpose(2, 3) + .reshape({batch_size, num_heads_ori, seqlen_q_ori}); + + return {out, softmax_lse}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "FlashMLA"; + m.def("get_mla_metadata", &get_mla_metadata); + m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla); +} diff --git a/candle-flash-mla/hkernel/flash_api.cu b/candle-flash-mla/hkernel/flash_api.cu new file mode 100644 index 0000000000..54c8051a40 --- /dev/null +++ b/candle-flash-mla/hkernel/flash_api.cu @@ -0,0 +1,46 @@ +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp + +#include "flash_fwd_mla_kernel.h" +#include "flash_mla.h" +#include "static_switch.h" + +#include + +#include +#include + +#define CEIL_DIV(x, y) (((x) + (y) - 1) / (y)) + +extern "C" void get_mla_metadata( + int32_t* seqlens_k_ptr, + int32_t* tile_scheduler_metadata_ptr, // [num_sm_parts, TileSchedulerMetaDataSize] + int32_t* num_splits_ptr, // [batch_size + 1] + const int batch_size, + const int num_sm_parts, + const cudaStream_t stream +) { + // This should match the logic in the MLA kernel. 
+ // static constexpr int block_size_m = 64; MOVED TO lib.rs + static constexpr int block_size_n = 64; + static constexpr int fixed_overhead_num_blocks = 5; + + Mla_metadata_params params = {}; + params.seqlens_k_ptr = seqlens_k_ptr; + params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr; + params.num_splits_ptr = num_splits_ptr; + params.batch_size = batch_size; + params.block_size_n = block_size_n; + params.fixed_overhead_num_blocks = fixed_overhead_num_blocks; + params.num_sm_parts = num_sm_parts; + get_mla_metadata_func(params, stream); + + return; +} + +extern "C" void mha_fwd_kvcache_mla( + Flash_fwd_mla_params params, + const cudaStream_t stream +) { + assert(params.d == 576); + run_mha_fwd_splitkv_mla(params, stream); +} diff --git a/candle-flash-mla/hkernel/flash_fwd_mla_bf16_sm90.cu b/candle-flash-mla/hkernel/flash_fwd_mla_bf16_sm90.cu new file mode 100644 index 0000000000..35691f2862 --- /dev/null +++ b/candle-flash-mla/hkernel/flash_fwd_mla_bf16_sm90.cu @@ -0,0 +1,3 @@ +#include "flash_fwd_mla_kernel.h" + +template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); diff --git a/candle-flash-mla/hkernel/flash_fwd_mla_kernel.h b/candle-flash-mla/hkernel/flash_fwd_mla_kernel.h new file mode 100644 index 0000000000..55f681112b --- /dev/null +++ b/candle-flash-mla/hkernel/flash_fwd_mla_kernel.h @@ -0,0 +1,679 @@ +#pragma once + +#include +#include +#include +#include + +using namespace cute; + +#include "named_barrier.h" +#include "utils.h" +#include "softmax.h" +#include "static_switch.h" +#include "flash_mla.h" + + +template +constexpr auto getSmemLayoutK() { + constexpr int headSizeBytes = sizeof(PrecType) * DIM; + constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2; + + if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) { + return GMMA::Layout_K_SW128_Atom{}; + } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) { + return GMMA::Layout_K_SW64_Atom{}; + } else { + return 
GMMA::Layout_K_SW32_Atom{}; + } +} + +template +struct Flash_fwd_kernel_traits_mla { + using Element = elem_type; + using ElementAccum = float; + using index_t = int64_t; + + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + static constexpr int kNWarpsS = 4; + static constexpr int kNThreadsS = kNWarpsS * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim; + static_assert(kHeadDimV % 32 == 0); + static_assert(kHeadDimV <= kHeadDim); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = decltype(make_tiled_mma( + cute::GMMA::ss_op_selector, Int, Int>, + GMMA::Major::K, GMMA::Major::K>(), + Layout, _1, _1>>{})); + + static constexpr int AtomLayoutNO = kNThreads / kNThreadsS; + using TiledMmaO = decltype(make_tiled_mma( + cute::GMMA::rs_op_selector, Int, Int>, + GMMA::Major::K, GMMA::Major::MN>(), + Layout, Int, _1>>{})); + + using SmemLayoutQ = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); + + using SmemLayoutK = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); + + using SmemLayoutV = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); + using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + + using SmemLayoutP = Layout, Int, _1, Int>>; + using SmemLayoutRow = Layout>, Stride<_1, _2>>; + + using SmemLayoutAtomO = decltype(composition( + Swizzle{}, + Layout, Int>, Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); 
+ static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL; + static constexpr int kNThreadsLoad = kNThreads - kNThreadsS; + static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + + using GmemLayoutAtom = Layout< + Shape, Int>, + Stride, _1>>; + using GmemTiledCopy = decltype(make_tiled_copy( + Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + + using GmemLayoutAtomO = Layout< + Shape, Int>, + Stride, _1>>; + using GmemTiledCopyO = decltype(make_tiled_copy( + Copy_Atom, Element>{}, + GmemLayoutAtomO{}, + Layout>{})); // Val layout, 8 vals per store + + static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum); + static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum; + using GmemLayoutAtomOaccum = Layout< + Shape, Int>, + Stride, _1>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy( + Copy_Atom, ElementAccum>{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store +}; + +namespace flash { + +using namespace cute; + +template +struct SharedStorageMLA { + union { + struct { + cute::array_aligned> smem_q; + cute::array_aligned * 2> smem_k; // Double buffer + cute::array_aligned> smem_p; + cute::array_aligned> smem_scale; + }; + struct { + cute::array_aligned> smem_max; + cute::array_aligned> smem_sum; + cute::array_aligned> smem_o; + }; + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, + SharedStorage &shared_storage, AccO tOrO, Softmax softmax) { + constexpr int kBlockM = Kernel_traits::kBlockM; + 
constexpr int kHeadDimV = Kernel_traits::kHeadDimV; + constexpr int kNThreadsS = Kernel_traits::kNThreadsS; + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + const int tidx = threadIdx.x; + + typename Kernel_traits::TiledMmaO tiled_mma_o; + auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); + + // Epilogue + + const int split_offset = __ldg(params.num_splits_ptr + bidb); + + Tensor lse = softmax.template normalize_softmax_lse(tOrO, params.scale_softmax); + + using ElementO = std::conditional_t; + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum + >; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = flash::convert_type(tOrO); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + __syncthreads(); + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + + Tensor gOaccum = 
make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)), + Shape>{}, Stride<_1>{}); + + using GmemTiledCopyO = std::conditional_t; + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + if (tidx >= kNThreadsS) { return; } + + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1) + Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM + ); +} + 
+template +__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params ¶ms, + const int bidb, const int bidh, const int m_block, + const int n_split_idx, const int seqlen_k, + const int n_block_min, const int n_block_max, const bool NoSplit, + SharedStorage &shared_storage) { + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kHeadDimV = Kernel_traits::kHeadDimV; + constexpr int kNThreads = Kernel_traits::kNThreads; + constexpr int kNThreadsS = Kernel_traits::kNThreadsS; + static_assert(kNThreads == 256 and kNThreadsS == 128); + using Element = typename Kernel_traits::Element; + using index_t = typename Kernel_traits::index_t; + + const int tidx = threadIdx.x; + int n_block = n_block_max - 1; + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}); + Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}); + + Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); + Tensor tPsP = sP(_, tidx % kNThreadsS, _, _); + Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS); + Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS); + Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tRow_sumsRow_sum = sRow_sum(_, 
tidx % kNThreadsS); + + typename Kernel_traits::TiledMmaO tiled_mma_o; + auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); + Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N) + Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) + clear(tOrO); + + flash::Softmax<2 * size<1>(tOrO)> softmax; + + int warp_group_idx = cutlass::canonical_warp_group_idx(); + if (warp_group_idx == 0) { + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + + if (n_block % 2 == 1) { + // Double buffer for sK + constexpr int sK_offset = size(sK); + tSrK.data() = tSrK.data() + sK_offset / 8; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = !Is_causal ? 
1 : cute::ceil_div(kBlockM, kBlockN) + 1; +#pragma unroll 1 + for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) { + __syncthreads(); + + Tensor tSrS = partition_fragment_C(tiled_mma, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) + flash::gemm(tiled_mma, tSrQ, tSrK, tSrS); + + const bool is_masking_step = masking_step > 0; + const bool is_first_masking_step = masking_step == n_masking_steps; + + if (is_masking_step) { + Tensor cS = make_identity_tensor(Shape, Int>{}); + Tensor tScS = thr_mma.partition_C(cS); +#pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + if constexpr (!Is_causal) { // Just masking based on col + if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY; + } else { + // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups + // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups + int row = int(get<0>(tScS(i))); + int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; + if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY; + } + } + } + + // We have key_padding_mask so we'll need to Check_inf + Tensor scale_o = is_first_masking_step + ? softmax.template softmax(tSrS, params.scale_softmax_log2) + : is_masking_step ? + softmax.template softmax(tSrS, params.scale_softmax_log2) + : softmax.template softmax(tSrS, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(tSrS); + cute::copy(rP, tPsP); + cute::copy(scale_o, tScale_osScale_o); + + cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SReady)); + + flash::rescale_o(tOrO, scale_o); + + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); + + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? 
size(sK) : -size(sK); + tSrK.data() = tSrK.data() + sK_offset / 8; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + + cute::copy(softmax.row_max, tRow_maxsRow_max); + cute::copy(softmax.row_sum, tRow_sumsRow_sum); + cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); + } else { + const int *block_table = params.block_table + bidb * params.block_table_batch_stride; + int cur_block_table = __ldg(&block_table[n_block]); + + const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q; + auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS); + Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, + params.seqlen_q - m_block * kBlockM); + + const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K; + auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS); + Tensor tKgK = gmem_thr_copy_K.partition_S(gK); + Tensor tKsK = gmem_thr_copy_K.partition_D(sK); + Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> 
(blk_n,blk_k) + Tensor tKcK = gmem_thr_copy_K.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); + + if (n_block % 2 == 1) { + // Double buffer for sK + constexpr int sK_offset = size(sK); + tKsK.data() = tKsK.data() + sK_offset; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + + // We need to clear the sK smem tiles because K is V. + const index_t offset_k = cur_block_table * params.k_batch_stride; + tKgK.data() = tKgK.data() + offset_k; + flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK, + seqlen_k - n_block * kBlockN); + tKgK.data() = tKgK.data() + -offset_k; + cute::cp_async_fence(); + + if (n_block - 1 >= n_block_min) { + cur_block_table = __ldg(&block_table[n_block - 1]); + } + +#pragma unroll 1 + for (; n_block >= n_block_min; --n_block) { + flash::cp_async_wait<0>(); + __syncthreads(); + + if (n_block - 1 >= n_block_min) { + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); + tKsK.data() = tKsK.data() + sK_offset; + + const index_t offset_k = cur_block_table * params.k_batch_stride; + tKgK.data() = tKgK.data() + offset_k; + flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK); + tKgK.data() = tKgK.data() + -offset_k; + cute::cp_async_fence(); + } + + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SReady)); + + if (n_block - 2 >= n_block_min) { + cur_block_table = __ldg(&block_table[n_block - 2]); + } + + typename Kernel_traits::TiledMma tiled_mma; + auto tSrS_layout = partition_fragment_C(tiled_mma, Shape, Int>{}).layout(); + Tensor rP = make_tensor(tSrS_layout); + Tensor scale_o = make_tensor(Shape<_2>{}); + cute::copy(tScale_osScale_o, scale_o); + cute::copy(tPsP, rP); + + flash::rescale_o(tOrO, scale_o); + + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); + + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? 
size(sK) : -size(sK); + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); + cute::copy(tRow_maxsRow_max, softmax.row_max); + cute::copy(tRow_sumsRow_sum, softmax.row_sum); + } + + if (NoSplit) + store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax); + else + store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax); +} + +template +__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1) +flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) { + constexpr int kBlockN = Kernel_traits::kBlockN; + const int m_block = blockIdx.x; + const int bidh = blockIdx.y; + const int partition_idx = blockIdx.z; + + extern __shared__ char shared_memory[]; + auto &shared_storage = *reinterpret_cast(shared_memory); + + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; + int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); + int begin_idx = tile_scheduler_metadata.x; + int begin_seqlen = tile_scheduler_metadata.y; + int end_idx = tile_scheduler_metadata.z; + int end_seqlen = tile_scheduler_metadata.w; + if (begin_idx >= params.b) return; + int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); + +#pragma unroll 1 + for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) { + const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0; + const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id); + const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0; + const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); + const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN); + if (batch_id > begin_idx) { + __syncthreads(); // Barrier between two tiles. 
+ } + flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void __launch_bounds__(256, 1, 1) +flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) { + constexpr int kNThreads = 128; + + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + const int hs = params.h * params.seqlen_q; + const int batch_idx = bidx / hs; + const int hs_idx = bidx % hs; + + const int split_offset = __ldg(params.num_splits_ptr + batch_idx); + const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset; + FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits); + if (actual_num_splits == 1) return; + + __shared__ ElementAccum sLseScale[kMaxSplits]; + + const index_t row_offset_lseaccum = split_offset * hs + hs_idx; + const index_t row_offset_lse = bidx; + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum), + Shape>{}, make_stride(hs)); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape<_1>{}, Stride<_1>{}); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0) { + constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32); + + float local_lse[kNLsePerThread]; + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + tidx; + local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY; + } + + float max_lse = -INFINITY; + for (int i = 0; i < kNLsePerThread; ++i) max_lse = max(max_lse, local_lse[i]); + for (int offset = 16; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset)); + max_lse = max_lse == -INFINITY ? 
0.0f : max_lse; // In case all local LSEs are -inf + + float sum_lse = 0; + for (int i = 0; i < kNLsePerThread; ++i) sum_lse = sum_lse + expf(local_lse[i] - max_lse); + for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset); + + float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : logf(sum_lse) + max_lse; + if (tidx == 0) gLSE(0) = global_lse; + + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + tidx; + if (split < actual_num_splits) sLseScale[split] = expf(local_lse[i] - global_lse); + } + } + __syncthreads(); + + static_assert(kHeadDimV % kNThreads == 0); + constexpr int Elements = kHeadDimV / kNThreads; + const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape>{}, Stride<_1>{}); + using GmemTiledCopyOaccum = decltype(make_tiled_copy( + Copy_Atom, ElementAccum>{}, + Layout>>{}, + Layout>>{})); + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + Tensor tOrO = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + for (int split = 0; split < actual_num_splits; ++split) { + cute::copy(tOgOaccum, tOrOaccum); + ElementAccum lse_scale = sLseScale[split]; + for (int i = 0; i < size(tOrO); ++i) { + tOrO(i) += lse_scale * tOrOaccum(i); + } + tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV; + } + + Tensor rO = flash::convert_type(tOrO); + const int head_idx = (bidx - batch_idx * hs) / params.seqlen_q; + const int row = bidx - batch_idx * hs - head_idx * params.seqlen_q; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride; + Tensor gO = make_tensor(make_gmem_ptr(o_ptr + 
tidx * Elements), Shape(rO))::value>>{}, Stride<_1>{}); + cute::copy(rO, gO); +} + +} // namespace flash + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { + FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN); + const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + auto kernel = &flash::flash_fwd_splitkv_mla_kernel; + constexpr size_t smem_size = sizeof(SharedStorage); + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + kernel<<>>(params); + }); + CHECK_CUDA_KERNEL_LAUNCH(); + + dim3 grid_combine(params.b * params.h * params.seqlen_q); + MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] { + auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel< + typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>; + combine_kernel<<>>(params); + }); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +template +void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { + static_assert(Headdim == 576); + FLASH_ASSERT(params.d_v == 512); + FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV + using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>; + run_flash_splitkv_fwd_mla>(params, stream); +} + +static constexpr int MaxBatchSize = 4096; + +__global__ void __launch_bounds__(256, 1, 1) +get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { + int *seqlens_k_ptr = params.seqlens_k_ptr; + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; + int *num_splits_ptr = params.num_splits_ptr; + int batch_size = params.batch_size; + int block_size_n = params.block_size_n; + int fixed_overhead_num_blocks = 
params.fixed_overhead_num_blocks; + int num_sm_parts = params.num_sm_parts; + + __shared__ int num_blocks_shared[MaxBatchSize]; + __shared__ int num_splits_shared[MaxBatchSize]; + + int total_num_blocks = 0; + for (int i = threadIdx.x; i < batch_size; i += 32) { + int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n); + total_num_blocks += num_blocks + fixed_overhead_num_blocks; + num_blocks_shared[i] = num_blocks; + } + for (int offset = 16; offset >= 1; offset /= 2) { + total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset); + } + __syncwarp(); + + if (threadIdx.x == 0) { + int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; + + int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; + num_splits_shared[0] = 0; + for (int i = 0; i < num_sm_parts; ++i) { + int tile_scheduler_metadata0[4], tile_scheduler_metadata1; + tile_scheduler_metadata0[0] = now_idx; + tile_scheduler_metadata0[1] = now_block * block_size_n; + tile_scheduler_metadata1 = now_n_split_idx; + int remain_payload = payload; + while (now_idx < batch_size) { + int num_blocks = num_blocks_shared[now_idx]; + int now_remain_blocks = num_blocks - now_block; + if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) { + cum_num_splits += now_n_split_idx + 1; + num_splits_shared[now_idx + 1] = cum_num_splits; + remain_payload -= now_remain_blocks + fixed_overhead_num_blocks; + ++now_idx; + now_block = 0; + now_n_split_idx = 0; + } else { + if (remain_payload - fixed_overhead_num_blocks > 0) { + now_block += remain_payload - fixed_overhead_num_blocks; + ++now_n_split_idx; + remain_payload = 0; + } + break; + } + } + tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1; + tile_scheduler_metadata0[3] = now_block > 0 ? 
now_block * block_size_n : seqlens_k_ptr[now_idx - 1]; + *reinterpret_cast(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast(tile_scheduler_metadata0); + tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1; + } + FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0); + } + __syncwarp(); + + for (int i = threadIdx.x; i <= batch_size; i += 32) { + num_splits_ptr[i] = num_splits_shared[i]; + } +} + +void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) { + FLASH_ASSERT(params.batch_size < MaxBatchSize); + get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params); + CHECK_CUDA_KERNEL_LAUNCH(); +} diff --git a/candle-flash-mla/hkernel/flash_mla.h b/candle-flash-mla/hkernel/flash_mla.h new file mode 100644 index 0000000000..e56ab9ceb8 --- /dev/null +++ b/candle-flash-mla/hkernel/flash_mla.h @@ -0,0 +1,63 @@ +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +extern "C" struct Flash_fwd_mla_params { + using index_t = int64_t; + + int b, seqlen_q, d, d_v; + int h, h_h_k_ratio, ngroups; + bool is_causal; + float scale_softmax, scale_softmax_log2; + int *__restrict__ cu_seqlens_k; + + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + void *__restrict__ o_ptr; + void *__restrict__ softmax_lse_ptr; + + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t o_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t o_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + index_t o_head_stride; + + int *__restrict__ block_table; + index_t block_table_batch_stride; + int page_block_size; + + int *__restrict__ tile_scheduler_metadata_ptr; + int num_sm_parts; + int *__restrict__ num_splits_ptr; + + void *__restrict__ softmax_lseaccum_ptr; + void *__restrict__ 
oaccum_ptr; +}; + +static constexpr int TileSchedulerMetaDataSize = 8; +// [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _] + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); + +struct Mla_metadata_params { + int *__restrict__ seqlens_k_ptr; + int *__restrict__ tile_scheduler_metadata_ptr; + int *__restrict__ num_splits_ptr; + int batch_size; + int block_size_n; + int fixed_overhead_num_blocks; + int num_sm_parts; +}; + +void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream); diff --git a/candle-flash-mla/hkernel/named_barrier.h b/candle-flash-mla/hkernel/named_barrier.h new file mode 100644 index 0000000000..cefa936ca7 --- /dev/null +++ b/candle-flash-mla/hkernel/named_barrier.h @@ -0,0 +1,15 @@ +#pragma once + +#include "cutlass/barrier.h" + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Enumerates the reserved named barriers to avoid potential conflicts + +enum class NamedBarriers { + SReady = 1, + SoftmaxReady = 2, +}; + +} // flash diff --git a/candle-flash-mla/hkernel/softmax.h b/candle-flash-mla/hkernel/softmax.h new file mode 100644 index 0000000000..4ab6ae9c6c --- /dev/null +++ b/candle-flash-mla/hkernel/softmax.h @@ -0,0 +1,197 @@ +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h + +#pragma once + +#include + +#include +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + 
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++){ + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ auto scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? 
scale : float(M_LOG2E)); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + // The following macro will disable the use of fma. + // See: https://github.com/pytorch/pytorch/issues/121558 for more details + // This macro is set in PyTorch and not FlashAttention + #ifdef UNFUSE_FMA + tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled); + #else + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + #endif + } + } + return tensor; +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. 
+ tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +template +__forceinline__ __device__ void rescale_o(Tensor0 &acc_o, Tensor1 &scale_o) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + #pragma unroll + for (int mi = 0; mi < size(scale_o); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale_o(mi); } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax { + + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + + __forceinline__ __device__ Softmax() {}; + + template + __forceinline__ __device__ TensorT softmax(Tensor0 &acc_s, float softmax_scale_log2) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + TensorT scale_o; + clear(scale_o); + if (Is_first) { + flash::template reduce_max(scores, row_max); + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + flash::reduce_sum(scores, row_sum); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + flash::template reduce_max(scores, row_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 
0.0f : row_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + scale_o(mi) = scores_scale; + row_sum(mi) *= scores_scale; + } + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + flash::reduce_sum(scores, row_sum); + } + return scale_o; + }; + + template + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT lse = make_fragment_like(row_sum); + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + float scale = !Is_dropout ? 
inv_sum : inv_sum * rp_dropout; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + return lse; + }; +}; + +} // namespace flash diff --git a/candle-flash-mla/hkernel/static_switch.h b/candle-flash-mla/hkernel/static_switch.h new file mode 100644 index 0000000000..f156adcca5 --- /dev/null +++ b/candle-flash-mla/hkernel/static_switch.h @@ -0,0 +1,65 @@ +#pragma once + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while(0) + +#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) + + +#define FLASH_ASSERT(cond) \ + do { \ + if (not (cond)) { \ + fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \ + exit(1); \ + } \ + } while(0) + + +#define FLASH_DEVICE_ASSERT(cond) \ + do { \ + if (not (cond)) { \ + printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ + } while(0) + + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + + +#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) 
\ + [&] { \ + if (NUM_SPLITS <= 32) { \ + constexpr static int NAME = 32; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 64) { \ + constexpr static int NAME = 64; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 96) { \ + constexpr static int NAME = 96; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 128) { \ + constexpr static int NAME = 128; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 160) { \ + constexpr static int NAME = 160; \ + return __VA_ARGS__(); \ + } else { \ + FLASH_ASSERT(false); \ + } \ + }() diff --git a/candle-flash-mla/hkernel/utils.h b/candle-flash-mla/hkernel/utils.h new file mode 100644 index 0000000000..3b8dd52759 --- /dev/null +++ b/candle-flash-mla/hkernel/utils.h @@ -0,0 +1,238 @@ +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h + +#pragma once + +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? 
x : y; } +}; + +template <> +struct MaxOp { +// This is slightly faster +__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template +static __device__ __forceinline__ T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { + constexpr bool Is_RS = !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); } + warpgroup_fence_operand(tCrC); + if constexpr (arrive) { + warpgroup_arrive(); + } + if constexpr (zero_init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } else { + // 
cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + if constexpr (commit) { + warpgroup_commit_batch(); + } + if constexpr (wg_wait >= 0) { warpgroup_wait(); } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) { + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = acc_layout; + if constexpr (!Transposed) { + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); + } else { + return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); + } + + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + if constexpr (!Transposed) { + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + } else { + return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); + } + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. +// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) +// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) { + using X = Underscore; + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); + if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) { + auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{}); // ((2, N / 16)) + return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); + } else { + static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1); + static_assert(decltype(stride<0, 0>(acc_layout))::value == 1); + static_assert(decltype(stride<0, 1>(acc_layout))::value == 2); + auto l = logical_divide(get<0, 2>(acc_layout), Tile>>{}); // (((2, 2), N / 32)) + // This combines the first two modes (<0, 0> and <0, 1>) into one mode. + // Will require register shuffling later to be correct. + return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)), + get<1>(acc_layout), + coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); // ((4, 2, 2), MMA_M, N / 32 * MMA_N) + // This combination is right but doesn't work with register shuffling. 
+ // return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)), + // get<1>(acc_layout), + // coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); + } + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ auto convert_type(Tensor const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast *>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. 
+// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE +void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const &S, + Tensor &D, Tensor const &identity_MN, + Tensor const &predicate_K, const int max_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/candle-flash-mla/src/ffi.rs b/candle-flash-mla/src/ffi.rs new file mode 100644 index 0000000000..35c9907572 --- /dev/null +++ b/candle-flash-mla/src/ffi.rs @@ -0,0 +1,63 @@ +use core::ffi::{c_int, c_void}; + +use candle::cuda::cudarc::driver::sys::CUstream; + +#[repr(C)] +pub struct FlashFwdMlaParams { + pub b: c_int, + pub seqlen_q: c_int, + pub d: c_int, + pub d_v: c_int, + pub h: c_int, + pub h_h_k_ratio: c_int, + pub ngroups: c_int, + pub is_causal: bool, + pub scale_softmax: f32, + pub scale_softmax_log2: f32, + pub cu_seqlens_k: *mut c_int, + + 
pub q_ptr: *mut c_void, + pub k_ptr: *mut c_void, + pub v_ptr: *mut c_void, + pub o_ptr: *mut c_void, + pub softmax_lse_ptr: *mut c_void, + + pub q_batch_stride: i64, + pub k_batch_stride: i64, + pub v_batch_stride: i64, + pub o_batch_stride: i64, + pub q_row_stride: i64, + pub k_row_stride: i64, + pub v_row_stride: i64, + pub o_row_stride: i64, + pub q_head_stride: i64, + pub k_head_stride: i64, + pub v_head_stride: i64, + pub o_head_stride: i64, + + pub block_table: *mut c_int, + pub block_table_batch_stride: i64, + pub page_block_size: c_int, + + pub tile_scheduler_metadata_ptr: *mut c_int, + pub num_sm_parts: c_int, + pub num_splits_ptr: *mut c_int, + + pub softmax_lseaccum_ptr: *mut c_void, + pub oaccum_ptr: *mut c_void, +} + +pub const TILE_SCHEDULER_METADATA_SIZE: usize = 8; + +extern "C" { + pub(crate) fn get_mla_metadata( + seqlens_k_ptr: *mut c_int, + tile_scheduler_metadata_ptr: *mut c_int, + num_splits_ptr: *mut c_int, + batch_size: c_int, + num_sm_parts: c_int, + stream: CUstream, + ); + + pub(crate) fn mha_fwd_kvcache_mla(params: FlashFwdMlaParams, stream: CUstream); +} diff --git a/candle-flash-mla/src/lib.rs b/candle-flash-mla/src/lib.rs new file mode 100644 index 0000000000..54c5c60a56 --- /dev/null +++ b/candle-flash-mla/src/lib.rs @@ -0,0 +1,317 @@ +mod ffi; + +use std::f32; + +use candle::backend::BackendStorage; +use candle::cuda::cudarc; +use candle::cuda_backend::cudarc::driver::DevicePtr; +use candle::cuda_backend::WrapErr; +use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor}; +use half::bf16; + +pub struct FlashAttn { + pub softmax_scale: f32, + pub block_table: Tensor, + pub cache_seqlens: Tensor, + pub head_size_v: usize, + pub seqlen_q_ori: usize, + pub ngroups: usize, + pub num_heads_per_head_k: usize, +} + +impl FlashAttn { + fn cuda_fwd_t< + T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, + >( + &self, + q: &candle::CudaStorage, + q_l: &Layout, + k_c_k_pe_cache: &candle::CudaStorage, + 
k_c_k_pe_cache_l: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + let dev = q.device(); + let (b_sz, seqlen_q, num_heads, head_size_q) = q_l.shape().dims4()?; + + let out_shape = Shape::from_dims(&[b_sz, seqlen_q, num_heads, self.head_size_v]); + let out_l = Layout::contiguous(&out_shape); + + let q = q.as_cuda_slice::()?; + let k_c_k_pe_cache = k_c_k_pe_cache.as_cuda_slice::()?; + let q = q.slice(q_l.start_offset()..); + let k_c_k_pe_cache = k_c_k_pe_cache.slice(k_c_k_pe_cache_l.start_offset()..); + + let v_l = k_c_k_pe_cache_l; + let v = &k_c_k_pe_cache; + + let q_stride = q_l.stride(); + let k_stride = k_c_k_pe_cache_l.stride(); + let v_stride = v_l.stride(); + let o_stride = out_l.stride(); + + let q_rank = q_stride.len(); + let k_rank = k_stride.len(); + let v_rank = v_stride.len(); + + if q_rank != 4 || k_rank != 4 || v_rank != 4 { + candle::bail!( + "flash-attn expects input tensors of rank 4 (q: {q_rank}, k: {k_rank}, v: {v_rank})" + ) + } + if q_stride[q_rank - 1] != 1 { + candle::bail!("the last dim of q must be contiguous {q_stride:?}") + } + if k_stride[k_rank - 1] != 1 { + candle::bail!("the last dim of k must be contiguous {k_stride:?}") + } + if v_stride[v_rank - 1] != 1 { + candle::bail!("the last dim of v must be contiguous {v_stride:?}") + } + + if self.block_table.dtype() != DType::I32 { + candle::bail!("block_table must be i32"); + } + + if self.block_table.stride()[self.block_table.stride().len() - 1] != 1 { + candle::bail!("block_table must have contiguous last dim"); + } + + let max_num_blocks_per_seq = self.block_table.dim(1)?; + let num_blocks = k_c_k_pe_cache_l.dim(0)?; + let page_block_size = k_c_k_pe_cache_l.dim(1)?; + let num_heads_k = k_c_k_pe_cache_l.dim(2)?; + + if head_size_q % 8 != 0 { + candle::bail!("only supports q/k head sizes that are a multiple of 8") + } + if self.head_size_v % 32 != 0 { + candle::bail!("only supports v head sizes that are a multiple of 32") + } + if num_heads % num_heads_k != 0 { + 
candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}") + } + + let candle::Storage::Cuda(block_table) = &*self.block_table.storage_and_layout().0 else { + candle::bail!("block_table must be CUDA") + }; + let block_table = block_table + .as_cuda_slice::()? + .slice(self.block_table.layout().start_offset()..); + + let candle::Storage::Cuda(cache_seqlens) = &*self.cache_seqlens.storage_and_layout().0 + else { + candle::bail!("cache_seqlens must be CUDA") + }; + let cache_seqlens = cache_seqlens + .as_cuda_slice::()? + .slice(self.cache_seqlens.layout().start_offset()..); + + let is_causal = self.seqlen_q_ori != 1; + + let num_heads = num_heads_k; + let head_size_k = head_size_q; + + if q_l.dims() != [b_sz, seqlen_q, num_heads, head_size_q] { + candle::bail!( + "Expected q shape {:?}, got {:?} instead.", + [b_sz, seqlen_q, num_heads, head_size_q], + q_l.dims() + ); + } + if k_c_k_pe_cache_l.dims() != [num_blocks, page_block_size, num_heads_k, head_size_k] { + candle::bail!( + "Expected k shape {:?}, got {:?} instead.", + [num_blocks, page_block_size, num_heads_k, head_size_k], + k_c_k_pe_cache_l.dims() + ); + } + if self.block_table.dims() != [b_sz, max_num_blocks_per_seq] { + candle::bail!( + "Expected block_table shape {:?}, got {:?} instead.", + [b_sz, max_num_blocks_per_seq], + self.block_table.dims() + ); + } + if self.cache_seqlens.dims() != [b_sz] { + candle::bail!( + "Expected cache_seqlens shape {:?}, got {:?} instead.", + [b_sz], + self.cache_seqlens.dims() + ); + } + + // This should match the logic in the MLA kernel. + let block_size_m = 64usize; + let sm_count = dev + .attribute( + cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, + ) + .w()? 
as usize; + let num_sm_parts = sm_count + / num_heads_k + / (self.seqlen_q_ori * self.num_heads_per_head_k).div_ceil(block_size_m); + + let tile_scheduler_metadata = + unsafe { dev.alloc::(num_sm_parts * ffi::TILE_SCHEDULER_METADATA_SIZE) }.w()?; + let num_splits = unsafe { dev.alloc::(b_sz + 1) }.w()?; + + unsafe { + ffi::get_mla_metadata( + (*cache_seqlens.device_ptr()) as *mut core::ffi::c_int, + (*tile_scheduler_metadata.device_ptr()) as *mut core::ffi::c_int, + (*num_splits.device_ptr()) as *mut core::ffi::c_int, + b_sz as i32, + num_sm_parts as i32, + *dev.cu_stream(), + ); + } + + let dst = unsafe { dev.alloc::(b_sz * seqlen_q * num_heads * self.head_size_v) }.w()?; + let softmax_lse = unsafe { dev.alloc::(b_sz * num_heads * seqlen_q) }.w()?; + + let dst_accum = unsafe { + dev.alloc::((b_sz + num_sm_parts) * seqlen_q * num_heads * self.head_size_v) + } + .w()?; + let softmax_lse_accum = + unsafe { dev.alloc::((b_sz + num_sm_parts) * num_heads * seqlen_q) }.w()?; + + // Expect: + if head_size_q != 576 { + candle::bail!("Expected head_size_q to be 576, got {head_size_q}"); + } + if self.head_size_v != 512 { + candle::bail!("Expected head_size_v to be 512, got {}", self.head_size_v); + } + if page_block_size != 64 { + candle::bail!("Expected page_block_size to be 64, got {page_block_size}"); + } + + let params = ffi::FlashFwdMlaParams { + b: b_sz as i32, + seqlen_q: seqlen_q as i32, + cu_seqlens_k: (*cache_seqlens.device_ptr()) as *mut core::ffi::c_int, + h: num_heads as i32, + h_h_k_ratio: (num_heads / num_heads_k) as i32, + ngroups: self.ngroups as i32, + is_causal, + d: head_size_q as i32, + d_v: self.head_size_v as i32, + scale_softmax: self.softmax_scale, + scale_softmax_log2: self.softmax_scale * f32::consts::LOG2_E, + q_ptr: (*q.device_ptr()) as *mut core::ffi::c_void, + k_ptr: (*k_c_k_pe_cache.device_ptr()) as *mut core::ffi::c_void, + v_ptr: (*v.device_ptr()) as *mut core::ffi::c_void, + o_ptr: (*dst.device_ptr()) as *mut core::ffi::c_void, + 
softmax_lse_ptr: (*softmax_lse.device_ptr()) as *mut core::ffi::c_void, + q_batch_stride: q_stride[0] as i64, + k_batch_stride: k_stride[0] as i64, + v_batch_stride: v_stride[0] as i64, + o_batch_stride: o_stride[0] as i64, + q_row_stride: q_stride[q_stride.len() - 3] as i64, + k_row_stride: k_stride[k_stride.len() - 3] as i64, + v_row_stride: v_stride[v_stride.len() - 3] as i64, + o_row_stride: o_stride[o_stride.len() - 3] as i64, + q_head_stride: q_stride[q_stride.len() - 2] as i64, + k_head_stride: k_stride[k_stride.len() - 2] as i64, + v_head_stride: v_stride[v_stride.len() - 2] as i64, + o_head_stride: o_stride[o_stride.len() - 2] as i64, + block_table: (*block_table.device_ptr()) as *mut core::ffi::c_int, + block_table_batch_stride: self.block_table.stride()[0] as i64, + page_block_size: page_block_size as i32, + tile_scheduler_metadata_ptr: (*tile_scheduler_metadata.device_ptr()) + as *mut core::ffi::c_int, + num_sm_parts: num_sm_parts as i32, + num_splits_ptr: (*num_splits.device_ptr()) as *mut core::ffi::c_int, + oaccum_ptr: (*dst_accum.device_ptr()) as *mut core::ffi::c_void, + softmax_lseaccum_ptr: (*softmax_lse_accum.device_ptr()) as *mut core::ffi::c_void, + }; + + unsafe { ffi::mha_fwd_kvcache_mla(params, *dev.cu_stream()) } + + let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev.clone()); + Ok((dst, out_shape)) + } +} + +impl candle::CustomOp2 for FlashAttn { + fn name(&self) -> &'static str { + "flash-attn" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle::bail!("no cpu support for flash-attn") + } + + fn cuda_fwd( + &self, + q: &candle::CudaStorage, + q_l: &Layout, + k_c_k_pe_cache: &candle::CudaStorage, + k_c_k_pe_cache_l: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + match q.dtype() { + candle::DType::BF16 => { + self.cuda_fwd_t::(q, q_l, k_c_k_pe_cache, k_c_k_pe_cache_l) + } + dt => candle::bail!("flash-mla is only supported for bf16 
({dt:?})"), + } + } +} + +/// FlashMLA layer. +/// +/// This implements MLA attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// +/// # Arguments +/// +/// * `q`: (batch_size, seq_len_q, num_heads_q, head_dim). +/// * `k_c_k_pe_cache`: (num_blocks, page_block_size, num_heads_k, head_dim). +/// * `block_table`: (batch_size, max_num_blocks_per_seq), i32. +/// * `cache_seqlens`: (batch_size), i32 +/// * `softmax_scale: The scale of QK^T before applying softmax. +/// * `head_size_v`: v_head_dim in the config +/// +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size_v)`. +pub fn flash_attn_mla( + q: &Tensor, + k_c_k_pe_cache: &Tensor, + block_table: Tensor, + cache_seqlens: Tensor, + softmax_scale: f32, + head_size_v: usize, +) -> Result { + let (b_sz, seqlen_q_ori, num_heads, head_size) = q.shape().dims4()?; + + let num_heads_k = k_c_k_pe_cache.dim(2)?; + let ngroups = num_heads / num_heads_k; + + let seqlen_q = seqlen_q_ori * ngroups; + let num_heads_per_head_k = num_heads / num_heads_k; + + let q = q + .reshape((b_sz, seqlen_q_ori, num_heads_k, ngroups, head_size))? + .transpose(2, 3)? + .reshape((b_sz, seqlen_q, num_heads_k, head_size))?; + + let op = FlashAttn { + softmax_scale, + block_table, + cache_seqlens, + head_size_v, + seqlen_q_ori, + ngroups, + num_heads_per_head_k, + }; + + let out = q.apply_op2(k_c_k_pe_cache, op)?; + + out.reshape((b_sz, seqlen_q_ori, ngroups, num_heads_k, head_size_v))? + .transpose(2, 3)? 
+ .reshape((b_sz, seqlen_q_ori, num_heads, head_size_v)) +} diff --git a/candle-flash-mla/tests/flash_mla_tests.rs b/candle-flash-mla/tests/flash_mla_tests.rs new file mode 100644 index 0000000000..8ae18dc4ef --- /dev/null +++ b/candle-flash-mla/tests/flash_mla_tests.rs @@ -0,0 +1,128 @@ +use anyhow::Result; +use candle::{DType, Device, IndexOp, Tensor, D}; +use candle_flash_mla; +use rstest::rstest; + +pub trait RepeatInterleaveOp { + fn repeat_interleave(&self, repeats: usize, dim: usize) -> candle::Result; +} + +impl RepeatInterleaveOp for Tensor { + fn repeat_interleave(&self, repeats: usize, dim: usize) -> candle::Result { + #[allow(clippy::cast_possible_truncation)] + let indices = Tensor::new( + (0..self.dim(dim)?) + .flat_map(|i| vec![i as u32; repeats]) + .collect::>(), + self.device(), + )?; + self.index_select(&indices, dim) + } +} + +fn sdpa( + q: &Tensor, + k: &Tensor, + v: &Tensor, + h_q: usize, + h_kv: usize, + softmax_scale: f32, +) -> Result { + let in_dtype = q.dtype(); + let q = q.to_dtype(DType::F32)?; + let k = k.to_dtype(DType::F32)?; + let v = v.to_dtype(DType::F32)?; + + let k = k.repeat_interleave(h_q / h_kv, 0)?; + let v = v.repeat_interleave(h_q / h_kv, 0)?; + + let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?; + let att = candle_nn::ops::softmax(&att, D::Minus1)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?; + Ok(output) +} + +#[rstest( + b => [128], + s_k => [4096, 8192], + h_q => [16, 32, 64, 128], // TP = 8, 4, 2, 1 + // s_q => [1, 2], // MTP = 1, 2 + s_q => [1], +)] +fn flash_mla_param(b: usize, s_k: usize, h_q: usize, s_q: usize) -> Result<()> { + let device = Device::new_cuda(0)?; + + let h_kv = 1; + let d = 576; + let dv = 512; + + let cache_seqlens_vec = vec![s_k as i32; b]; + let cache_seqlens = Tensor::new(cache_seqlens_vec.clone(), &device)?; + let max_seqlen = cache_seqlens.max(0)?.to_scalar::()? 
as usize; + let max_seqlen_pad = max_seqlen.div_ceil(256) * 256; + + let q = Tensor::randn(0., 1., (b, s_q, h_q, d), &device)?.to_dtype(DType::BF16)?; + let block_size = 64; + let block_table = Tensor::arange(0i32, (b * max_seqlen_pad / block_size) as i32, &device)? + .reshape((b, max_seqlen_pad / block_size))?; + let blocked_k = Tensor::randn( + 0., + 1., + (block_table.elem_count(), block_size, h_kv, d), + &device, + )? + .to_dtype(DType::BF16)?; + let blocked_v = blocked_k.narrow(D::Minus1, 0, dv)?.copy()?; + + let softmax_scale = 1. / (q.dim(D::Minus1)? as f32).sqrt(); + + let out_flash = candle_flash_mla::flash_attn_mla( + &q, + &blocked_k, + block_table, + cache_seqlens, + softmax_scale, + dv, + )?; + + let truth = { + let mut out = Vec::new(); + for i in 0..b { + let begin = i * max_seqlen_pad; + let end = begin + cache_seqlens_vec[i] as usize; + + let q = q.i(i)?.transpose(0, 1)?; + let k = blocked_k + .reshape(((), h_kv, d))? + .i(begin..end)? + .transpose(0, 1)?; + let v = blocked_v + .reshape(((), h_kv, dv))? + .i(begin..end)? + .transpose(0, 1)?; + + let res = sdpa(&q, &k, &v, h_q, h_kv, softmax_scale)?; + out.push(res.transpose(0, 1)?); + } + + Tensor::stack(&out, 0)? + }; + + assert_eq!(out_flash.dims(), truth.dims()); + + { + let out_flash = out_flash.to_dtype(DType::F64)?; + let truth = truth.to_dtype(DType::F64)?; + + let cos_diff = 1. + - 2. * (&out_flash * &truth)?.sum_all()?.to_scalar::()? + / (out_flash.sqr()? + truth.sqr()?)? + .sum_all()? + .to_scalar::()? 
+ .max(1e-12); + assert!(cos_diff < 1e-5, "{cos_diff} > {}", 1e-5); + } + + Ok(()) +} diff --git a/candle-kernels/build.rs b/candle-kernels/build.rs index bc8e2b9861..f975640212 100644 --- a/candle-kernels/build.rs +++ b/candle-kernels/build.rs @@ -1,6 +1,6 @@ fn main() { std::env::set_var("NVCC_PREPEND_FLAGS", "-D_USE_MATH_DEFINES"); - + println!("cargo:rerun-if-changed=build.rs"); println!("cargo:rerun-if-changed=src/compatibility.cuh"); println!("cargo:rerun-if-changed=src/cuda_utils.cuh"); diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index e2a1ba54cd..aa8701821d 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -201,7 +201,7 @@ impl From> for MetalKernelError { } type Libraries = HashMap; -type Pipelines = HashMap<(&'static str, Option), ComputePipelineState>; +type Pipelines = HashMap<(String, Option), ComputePipelineState>; #[derive(Debug)] pub struct Kernels { @@ -282,7 +282,7 @@ impl Kernels { &self, device: &Device, source: Source, - name: &'static str, + name: &str, constants: Option, ) -> Result { let func = self @@ -299,11 +299,11 @@ impl Kernels { &self, device: &Device, source: Source, - name: &'static str, + name: &str, constants: Option, ) -> Result { let mut pipelines = self.pipelines.write()?; - let key = (name, constants); + let key = (name.to_string(), constants); if let Some(pipeline) = pipelines.get(&key) { Ok(pipeline.clone()) } else { @@ -311,13 +311,13 @@ impl Kernels { let func = self.load_function( device, source, - name, + &name, constants.as_ref().map(|c| c.function_constant_values()), )?; let pipeline = device .new_compute_pipeline_state_with_function(&func) .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?; - pipelines.insert((name, constants), pipeline.clone()); + pipelines.insert((name.to_string(), constants), pipeline.clone()); Ok(pipeline) } @@ -330,7 +330,7 @@ impl Kernels { &self, device: &Device, source: Source, - name: &'static str, + 
name: &str, ) -> Result { self.load_pipeline_with_constants(device, source, name, None) } @@ -1775,174 +1775,201 @@ pub fn call_sdpa_full( kernels: &Kernels, q_offset: usize, q_shape: &[usize], + q_strides: &[usize], q_buffer: &Buffer, k_offset: usize, + k_shape: &[usize], + k_strides: &[usize], k_buffer: &Buffer, v_offset: usize, v_buffer: &Buffer, + v_strides: &[usize], + mask_type: Option, + mask_buffer: Option<&Buffer>, + m_strides: Option<&[usize]>, output: &Buffer, - alpha: f32, - softcapping: f32, + o_strides: &[usize], + scale: f32, + do_causal: bool, itype: SdpaDType, ) -> Result<(), MetalKernelError> { #[derive(Debug)] #[repr(C)] - struct MLXFastAttentionParams { - m: i32, - n: i32, - k: i32, - - ldq: i32, // ldq == ldo - ldk: i32, - ldv: i32, - lds: i32, - ldo: i32, + struct AttnParams { + b: i32, + h: i32, + d: i32, + ql: i32, + kl: i32, + gqa_factor: i32, + scale: f32, + nq: i32, + nk: i32, + nq_aligned: i32, + nk_aligned: i32, + ql_rem: i32, + kl_rem: i32, + ql_off: i32, + q_strides: [i64; 3], + k_strides: [i64; 3], + v_strides: [i64; 3], + o_strides: [i64; 3], + } - tiles_n: i32, - tiles_m: i32, + #[derive(Debug)] + #[repr(C)] + struct AttnMaskParams { + m_strides: [i64; 3], + } - batch_stride_q: i32, - batch_stride_k: i32, - batch_stride_v: i32, - batch_stride_o: i32, + const WM: usize = 4; + const WN: usize = 1; - swizzle_log: i32, - gemm_n_iterations_aligned: i32, - gemm_k_iterations_aligned: i32, - gemm_sv_m_block_iterations: i32, + const BQ: usize = 32; + let bd = q_shape[q_shape.len() - 1]; + let bk = if bd < 128 { 32 } else { 16 }; - batch_ndim: i32, - alpha: f32, - softcapping: f32, - } + let b = q_shape[0]; + let h = q_shape[1]; + let d = q_shape[3]; + let gqa_factor = q_shape[1] / k_shape[1]; - let bk = q_shape.last().unwrap(); + let ql = q_shape[2]; + let kl = k_shape[2]; - const BN: usize = 16; - const BM: usize = 16; - const WM: usize = 2; - const WN: usize = 2; + let align_q = (ql % BQ) == 0; + let align_k = (kl % bk) == 0; + let 
has_mask = mask_buffer.is_some(); - let name = match (bk, itype) { - (32, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_half", - (64, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_half", - (96, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_half", - (128, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_half", - (256, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_half", - (32, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_float", - (64, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_float", - (96, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_float", - (128, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_float", - (256, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_float", - (other, SdpaDType::F16 | SdpaDType::F32) => { - return Err(MetalKernelError::SdpaHeadSizeMismatch { - variation: "full", - got: *other, - expected: vec![32, 64, 96, 128, 256], - }) - } - (_, SdpaDType::BF16) => { - return Err(MetalKernelError::SdpaHeadDTypeMismatch { - variation: "full", - got: SdpaDType::BF16, - }) - } + let itype_repr = match itype { + SdpaDType::BF16 => "bfloat16", + SdpaDType::F16 => "float16", + SdpaDType::F32 => "float32", }; + let mask_repr = match mask_type { + Some(SdpaDType::BF16) => "bfloat16", + Some(SdpaDType::F16) => "float16", + Some(SdpaDType::F32) => "float32", + None => itype_repr, + }; + let name = + format!("steel_attention_{itype_repr}_bq{BQ}_bk{bk}_bd{bd}_wm{WM}_wn{WN}_mask{mask_repr}"); - let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?; + let constants = Some(ConstantValues::new(vec![ + (200, Value::Bool(/* align_Q */ align_q)), + (201, Value::Bool(/* align_K */ align_k)), + (300, Value::Bool(/* has_mask */ has_mask)), + (301, Value::Bool(/* do_causal */ do_causal)), + ])); + + let pipeline = kernels.load_pipeline_with_constants(device, Source::Sdpa, 
&name, constants)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); - // q = (bs, qhead, seq, hidden) - // k/v = (bs, kv_head, seq, hidden) - - let qseq = q_shape[q_shape.len() - 2]; - - let m = q_shape[q_shape.len() - 2]; - let n = m; - let k = q_shape[q_shape.len() - 1]; - let bs_out = q_shape[0] * q_shape[1]; - - let batch_shape = [q_shape[0] * q_shape[1]]; - let dk = q_shape[q_shape.len() - 1]; - let ldq = dk; - let ldk = dk; - let ldv = dk; - let lds = BN; - let ldo = dk; - - let tn = 1; - let tm = (m + BM - 1) / BM; - - let b_stride_q = dk * qseq; - let b_stride_k = dk * qseq; - let b_stride_v = dk * qseq; - let b_stride_o = dk * qseq; - let swizzle_log = 0; - let gemm_n_iterations_aligned = (n + BN - 1) / BN; - let gemm_k_iterations_aligned = (k + bk - 1) / bk; - let gemm_sv_m_block_iterations = (m + BM - 1) / BM; - let batch_ndim = batch_shape.len(); - - let alpha = if softcapping != 1. { - alpha / softcapping - } else { - alpha + let nq = (ql + BQ - 1) / BQ; + let nk = (kl + bk - 1) / bk; + + let nq_aligned = ql / BQ; + let nk_aligned = kl / bk; + + let params = AttnParams { + b: b as i32, + h: h as i32, + d: d as i32, + ql: ql as i32, + kl: kl as i32, + gqa_factor: gqa_factor as i32, + scale, + nq: nq as i32, + nk: nk as i32, + nq_aligned: nq_aligned as i32, + nk_aligned: nk_aligned as i32, + ql_rem: (ql - nq_aligned * BQ) as i32, + kl_rem: (kl - nk_aligned * bk) as i32, + ql_off: (kl - ql) as i32, + q_strides: [ + q_strides[0] as i64, + q_strides[1] as i64, + q_strides[2] as i64, + ], + k_strides: [ + k_strides[0] as i64, + k_strides[1] as i64, + k_strides[2] as i64, + ], + v_strides: [ + v_strides[0] as i64, + v_strides[1] as i64, + v_strides[2] as i64, + ], + o_strides: [ + o_strides[0] as i64, + o_strides[1] as i64, + o_strides[2] as i64, + ], }; - let params = MLXFastAttentionParams { - m: m as i32, - n: n as i32, - k: k as i32, - ldq: ldq as i32, - ldk: ldk as i32, 
- ldv: ldv as i32, - lds: lds as i32, - ldo: ldo as i32, - tiles_n: tn, - tiles_m: tm as i32, - batch_stride_q: b_stride_q as i32, - batch_stride_k: b_stride_k as i32, - batch_stride_v: b_stride_v as i32, - batch_stride_o: b_stride_o as i32, - swizzle_log, - gemm_n_iterations_aligned: gemm_n_iterations_aligned as i32, - gemm_k_iterations_aligned: gemm_k_iterations_aligned as i32, - gemm_sv_m_block_iterations: gemm_sv_m_block_iterations as i32, - batch_ndim: batch_ndim as i32, - alpha, - softcapping, - }; - let batch_strides = [b_stride_q, b_stride_k, b_stride_v, b_stride_o]; + impl EncoderParam for AttnParams { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_bytes( + position, + core::mem::size_of::() as u64, + &data as *const AttnParams as *const c_void, + ); + } + } - impl EncoderParam for MLXFastAttentionParams { + impl EncoderParam for AttnMaskParams { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { encoder.set_bytes( position, - core::mem::size_of::() as u64, - &data as *const MLXFastAttentionParams as *const c_void, + core::mem::size_of::() as u64, + &data as *const AttnMaskParams as *const c_void, ); } } - set_params!( - encoder, - ( - (q_buffer, q_offset), - (k_buffer, k_offset), - (v_buffer, v_offset), - output, - params, - &batch_shape[..], - &batch_strides[..] 
- ) - ); + if let Some(mask) = mask_buffer { + let mask_strides = m_strides.unwrap(); + let mask_params = AttnMaskParams { + m_strides: [ + mask_strides[0] as i64, + mask_strides[1] as i64, + mask_strides[2] as i64, + ], + }; + encoder.use_resource(mask, metal::MTLResourceUsage::Read); + + set_params!( + encoder, + ( + (q_buffer, q_offset), + (k_buffer, k_offset), + (v_buffer, v_offset), + output, + params, + mask_params, + mask + ) + ); + } else { + set_params!( + encoder, + ( + (q_buffer, q_offset), + (k_buffer, k_offset), + (v_buffer, v_offset), + output, + params + ) + ); + } let grid_dims = MTLSize { - width: 1, - height: tm as u64, - depth: bs_out as u64, + width: nq as u64, + height: h as u64, + depth: b as u64, }; let group_dims = MTLSize { width: 32, @@ -1954,6 +1981,7 @@ pub fn call_sdpa_full( encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(grid_dims, group_dims); + Ok(()) } diff --git a/candle-metal-kernels/src/scaled_dot_product_attention.metal b/candle-metal-kernels/src/scaled_dot_product_attention.metal index 09f727dce2..207a2c2346 100644 --- a/candle-metal-kernels/src/scaled_dot_product_attention.metal +++ b/candle-metal-kernels/src/scaled_dot_product_attention.metal @@ -5,9 +5,13 @@ using namespace metal; +#define STEEL_CONST static constant constexpr const +#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") + #if defined(__HAVE_BFLOAT__) typedef bfloat bfloat16_t; +typedef half float16_t; #else @@ -621,111 +625,55 @@ template } } -// ============ "mlx/backend/metal/kernels/steel/defines.h" - -#define STEEL_CONST static constant constexpr const -#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") - -// ============ "mlx/backend/metal/kernels/steel/gemm/transforms.h" - -template -struct TransformNone { - static METAL_FUNC OutT apply(InT x) { - return static_cast(x); - } - - static METAL_FUNC OutT apply(InT x, OutT) { - 
return static_cast(x); - } -}; - -template -struct TransformAdd { - TransformAdd(const float, const float) {} - - static METAL_FUNC OutT apply(InT x) { - return static_cast(x); - } - - static METAL_FUNC OutT apply(InT x, OutT c) { - return static_cast(x) + c; - } -}; - -template -struct TransformAxpby { - const float alpha; - const float beta; - - TransformAxpby(const float alpha_, const float beta_) - : alpha(alpha_), beta(beta_) {} - - static METAL_FUNC OutT apply(InT x) { - return static_cast(x); - } - - METAL_FUNC OutT apply(InT x, OutT c) const { - return static_cast(x * alpha + (beta * c)); - } -}; - -template -struct AccumHelper { - typedef float accum_type; -}; +// ============ "mlx/backend/metal/kernels/utils.h" -struct BlockSwizzle { - static METAL_FUNC int2 - swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { - const int tid_x = (tid.x) >> swizzle_log; - const int tid_y = - ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); - return int2(tid_x, tid_y); - } +template +struct Limits { + static const constant U max = metal::numeric_limits::max(); + static const constant U min = metal::numeric_limits::min(); + static const constant U finite_max = metal::numeric_limits::max(); + static const constant U finite_min = metal::numeric_limits::min(); }; -// ============ "mlx/backend/metal/kernels/utils.h" +#define instantiate_default_limit(type) \ + template <> \ + struct Limits { \ + static constexpr constant type max = metal::numeric_limits::max(); \ + static constexpr constant type min = metal::numeric_limits::min(); \ + static constexpr constant type finite_max = \ + metal::numeric_limits::max(); \ + static constexpr constant type finite_min = \ + metal::numeric_limits::min(); \ + }; -typedef half float16_t; +instantiate_default_limit(uint8_t); +instantiate_default_limit(uint16_t); +instantiate_default_limit(uint32_t); +instantiate_default_limit(uint64_t); +instantiate_default_limit(int8_t); 
+instantiate_default_limit(int16_t); +instantiate_default_limit(int32_t); +instantiate_default_limit(int64_t); + +#define instantiate_float_limit(type) \ + template <> \ + struct Limits { \ + static constexpr constant type max = \ + metal::numeric_limits::infinity(); \ + static constexpr constant type min = \ + -metal::numeric_limits::infinity(); \ + static constexpr constant type finite_max = \ + metal::numeric_limits::max(); \ + static constexpr constant type finite_min = \ + -metal::numeric_limits::max(); \ + }; -METAL_FUNC ulong2 elem_to_loc_broadcast( - uint elem, - constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, - int ndim) { - ulong loc_a{0}; - ulong loc_b{0}; - for (int i = ndim - 1; i >= 0 && elem > 0; --i) { - int pos_in_dim = (elem % shape[i]); - elem /= shape[i]; - loc_a += pos_in_dim * a_strides[i]; - loc_b += pos_in_dim * b_strides[i]; - } - return ulong2(loc_a, loc_b); -} +instantiate_float_limit(half); +instantiate_float_limit(float); +instantiate_float_limit(bfloat16_t); -METAL_FUNC ulong3 elem_to_loc_broadcast( - uint elem, - constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, - constant const size_t* c_strides, - int ndim) { - ulong loc_a{0}; - ulong loc_b{0}; - ulong loc_c{0}; - for (int i = ndim - 1; i >= 0 && elem > 0; --i) { - int pos_in_dim = (elem % shape[i]); - elem /= shape[i]; - loc_a += pos_in_dim * a_strides[i]; - loc_b += pos_in_dim * b_strides[i]; - loc_c += pos_in_dim * c_strides[i]; - } - return ulong3(loc_a, loc_b, loc_c); -} -// ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.metal" +// ============ "mlx/backend/metal/kernels/steel/attn/loader.h" template < typename T, @@ -738,7 +686,7 @@ template < short n_reads = (BCOLS * BROWS) / (tgp_size), short TCOLS = BCOLS / n_reads, short TROWS = tgp_size / TCOLS> -struct BlockLoaderFA { +struct BlockLoader { STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; 
STEEL_CONST short vec_size = n_reads; @@ -760,7 +708,7 @@ struct BlockLoaderFA { }; /* Constructor */ - METAL_FUNC BlockLoaderFA( + METAL_FUNC BlockLoader( const device T* src_, const int src_ld_, threadgroup T* dst_, @@ -774,6 +722,18 @@ struct BlockLoaderFA { dst(dst_ + bi * dst_ld + bj), src(src_ + bi * src_ld + bj) {} + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]); + } + } + } + /* Load from device memory into threadgroup memory - without bound checking */ METAL_FUNC void load_unsafe() const { STEEL_PRAGMA_UNROLL @@ -835,243 +795,926 @@ struct BlockLoaderFA { METAL_FUNC void next() { src += tile_stride; } - METAL_FUNC void next(short n) { - src += n * tile_stride; - } }; -template -struct LoopAlignment {}; +template +struct CShape { + STEEL_CONST int kRows = R; + STEEL_CONST int kCols = C; +}; template < typename T, - typename U, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - short lda_tgp, - short ldb_tgp, - typename AccumType = float, - typename Epilogue = TransformNone> -struct BlockMMAFA { - // Warp tile simdgroup matrix strides along M - STEEL_CONST short TM_stride = 8 * WM; - // Warp tile simdgroup matrix strides along M - STEEL_CONST short TN_stride = 8 * WN; - - // Warp tile size along M - STEEL_CONST short TM = BM / TM_stride; - // Warp tile size along N - STEEL_CONST short TN = BN / TN_stride; - - // Strides of A, B along reduction axis - STEEL_CONST short simd_stride_a = { - transpose_a ? TM_stride : TM_stride * lda_tgp}; - STEEL_CONST short simd_stride_b = { - transpose_b ? TN_stride * ldb_tgp : TN_stride}; - - // Jump between elements - STEEL_CONST short jump_a = {transpose_a ? 
lda_tgp : 1}; - STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1}; - - STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8}; - STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp}; - - // Simdgroup matrices - simdgroup_matrix Asimd[TM]; - simdgroup_matrix Bsimd[TN]; - simdgroup_matrix results[TM * TN] = { - simdgroup_matrix(0)}; - - // Offsets within threadgroup - const short tm; - const short tn; + short BROWS, + short BCOLS, + short kDstStrRow, + short kDstStrCol, + short reduction_dim, + short tgp_size, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoaderT { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; - short sm; - short sn; + // Leading dimension for src + const int src_ld; + const int tile_stride; - ushort sid; - ushort slid; + // Thread location indices + const short thread_idx; + const short bi; + const short bj; - short As_offset; - short Bs_offset; + // threadgroup and device memory + threadgroup T* dst; + const device T* src; /* Constructor */ - METAL_FUNC BlockMMAFA( + METAL_FUNC BlockLoaderT( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, ushort simd_group_id [[simdgroup_index_in_threadgroup]], ushort simd_lane_id [[thread_index_in_simdgroup]]) - : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { - // Determine thread position in simdgroup matrix - short qid = simd_lane_id / 4; - slid = simd_lane_id; - sid = simd_group_id; - - sm = (qid & 4) + (simd_lane_id / 2) % 4; - sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; - - // Determine thread and simdgroup offset - As_offset = - transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp); - Bs_offset = - transpose_b ? 
((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn)); - } - - /* (BM, BK) X (BK, BN) multiply accumulate function */ - METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { - // Adjust for simdgroup and thread location - As += As_offset; - Bs += Bs_offset; + : src_ld(src_ld_), + tile_stride(reduction_dim ? BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * kDstStrRow + bj * kDstStrCol), + src(src_ + bi * src_ld + bj) {} - // Iterate over BK in blocks of 8 + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { STEEL_PRAGMA_UNROLL - for (short kk = 0; kk < BK; kk += 8) { - simdgroup_barrier(mem_flags::mem_none); - - // Load elements from threadgroup A as simdgroup matrices + for (short i = 0; i < BROWS; i += TROWS) { STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - Asimd[i].thread_elements()[0] = - static_cast(As[i * simd_stride_a + 0]); - Asimd[i].thread_elements()[1] = - static_cast(As[i * simd_stride_a + jump_a]); + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = + op.apply(dst[i * kDstStrRow + j * kDstStrCol]); } + } + } - simdgroup_barrier(mem_flags::mem_none); - - // Load elements from threadgroup B as simdgroup matrices + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - Bsimd[j].thread_elements()[0] = - static_cast(Bs[j * simd_stride_b + 0]); - Bsimd[j].thread_elements()[1] = - static_cast(Bs[j * simd_stride_b + jump_b]); + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j]; } + } + } - simdgroup_barrier(mem_flags::mem_none); + /* Load from device memory into threadgroup 
memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); - // Multiply and accumulate into result simdgroup matrices + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { + for (short i = 0; i < BROWS; i += TROWS) { STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - short j_serp = (i % 2) ? (TN - 1 - j) : j; - - simdgroup_multiply_accumulate( - results[i * TN + j_serp], - Asimd[i], - Bsimd[j_serp], - results[i * TN + j_serp]); + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = T(0); } } - - // Progress to next simdgroup tile - As += tile_stride_a; - Bs += tile_stride_b; + return; } - } - METAL_FUNC void rescale_output(const threadgroup float* Corrections) { - // Loop over all simdgroup tiles + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - short row = sm + tm + i * TM_stride; - float scale_value = Corrections[row]; - + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread auto& accum = results[i * TN + j].thread_elements(); - // int offset = (i * TM_stride) * ldc + (j * TN_stride); - accum[0] *= scale_value; - accum[1] *= scale_value; + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); } - } - } - - /* Store results from simdgroup_matrix results into device memory */ - METAL_FUNC void store_result(device U* C, const int ldc) const { - // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + tn + sn; - // Loop over all simdgroup tiles - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { + // Read valid indices into tmp_val 
STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldc + (j * TN_stride); + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } - // Apply epilogue - U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])}; + // Zero out uneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); + } - // Write out C - C[offset] = outs[0]; - C[offset + 1] = outs[1]; + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j]; } } } - METAL_FUNC void store_result_to_tgp_memory( - threadgroup U* C, - const int ldc, - short2 dst_tile_dims) const { - // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn); - dst_tile_dims -= short2(tn + sn, sm + tm); + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } +}; - STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldc + (j * TN_stride); +// ============ "mlx/backend/metal/kernels/steel/utils/type_traits.h" - // Apply epilogue and output C - if (j * TN_stride < dst_tile_dims.x) { - C[offset] = Epilogue::apply(accum[0]); - } +template +struct make_void { + typedef void type; +}; - if (j * TN_stride + 1 < dst_tile_dims.x) { - C[offset + 1] = Epilogue::apply(accum[1]); - } - } - } - } - } +template +using void_t = typename make_void::type; - METAL_FUNC void - store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const { - // Adjust for simdgroup and thread location - 
C += (sm + tm) * ldc + (tn + sn); - dst_tile_dims -= short2(tn + sn, sm + tm); +template +struct pointer_element {}; - STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldc + (j * TN_stride); +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; - // Apply epilogue and output C - if (j * TN_stride < dst_tile_dims.x) { - C[offset] = Epilogue::apply(accum[0]); - } +template +using pointer_element_t = typename pointer_element>::type; - if (j * TN_stride + 1 < dst_tile_dims.x) { - C[offset + 1] = Epilogue::apply(accum[1]); - } - } - } +// ============ "mlx/backend/metal/kernels/steel/utils/integral_constant.h" + +/////////////////////////////////////////////////////////////////////////////// +// Integral constant with casting +/////////////////////////////////////////////////////////////////////////////// + +template +using Int = integral_constant; + +/////////////////////////////////////////////////////////////////////////////// +// Binary Operators on Integral constants +/////////////////////////////////////////////////////////////////////////////// + +#define integral_const_binop(__op__, __operator__) \ + template \ + METAL_FUNC constexpr auto __operator__( \ + integral_constant, integral_constant) { \ + constexpr auto res = tv __op__ uv; \ + return integral_constant{}; \ + } + +integral_const_binop(+, operator+); +integral_const_binop(-, operator-); +integral_const_binop(*, operator*); +integral_const_binop(/, operator/); + +integral_const_binop(==, operator==); +integral_const_binop(!=, operator!=); 
+integral_const_binop(<, operator<); +integral_const_binop(>, operator>); +integral_const_binop(<=, operator<=); +integral_const_binop(>=, operator>=); + +integral_const_binop(&&, operator&&); +integral_const_binop(||, operator||); + +#undef integral_const_binop + +/////////////////////////////////////////////////////////////////////////////// +// Reduction operators +/////////////////////////////////////////////////////////////////////////////// + +template +METAL_FUNC constexpr T sum(T x) { + return x; +} + +template +METAL_FUNC constexpr auto sum(T x, Us... us) { + return x + sum(us...); +} + +// ============ "mlx/backend/metal/kernels/steel/gemm/transforms.h" + +template +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT) { + return static_cast(x); + } +}; + +template +struct TransformAdd { + TransformAdd(const float, const float) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT c) { + return static_cast(x) + c; + } +}; + +template +struct TransformAxpby { + const float alpha; + const float beta; + + TransformAxpby(const float alpha_, const float beta_) + : alpha(alpha_), beta(beta_) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + METAL_FUNC OutT apply(InT x, OutT c) const { + return static_cast(x * alpha + (beta * c)); + } +}; + +template +struct AccumHelper { + typedef float accum_type; +}; + +struct BlockSwizzle { + static METAL_FUNC int2 + swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { + const int tid_x = (tid.x) >> swizzle_log; + const int tid_y = + ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); + return int2(tid_x, tid_y); + } +}; + +// ============ "mlx/backend/metal/kernels/steel/attn/mma.h" + +template +struct Shape2D { + RInt r; + CInt c; + + Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {} +}; + +template +struct Layout2D { 
+ Shape shape; + Layout layout; +}; + +template +struct BaseMMAFrag { + static_assert( + kFragRows_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); + static_assert( + kFragCols_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); +}; + +template +struct BaseMMAFrag { + STEEL_CONST int kFragRows = 8; + STEEL_CONST int kFragCols = 8; + + STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST int kElemRows = 1; + STEEL_CONST int kElemCols = 2; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + + typedef metal::simdgroup_matrix mat_type; + typedef metal::vec frag_type; + typedef metal::vec row_frag_type; + typedef metal::vec col_frag_type; + + template + using dtype_mat_t = typename metal::simdgroup_matrix; + + template + using dtype_frag_t = typename metal::vec; + + METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id + [[thread_index_in_simdgroup]]) { + const short qid = simd_lane_id / 4; + const short fm = (qid & 4) + ((simd_lane_id / 2) % 4); + const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + return short2{fn, fm}; + } + + template + METAL_FUNC static constexpr void + load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[i * str_x.value + j * str_y.value]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void load_safe( + thread frag_type& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + 
i) < lim_x && (off_y + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[(off_x + i) * str_x + (off_y + j) * str_y.value]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + + template + METAL_FUNC static constexpr void + store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * str_x + j * str_y.value] = static_cast(src[i * kElemCols + j]); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void store_safe( + const thread frag_type& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[(off_x + i) * str_x + (off_y + j) * str_y.value] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template + METAL_FUNC static constexpr void mma( + thread frag_type& D, + thread dtype_frag_t& A, + thread dtype_frag_t& B, + thread dtype_frag_t& C) { + mat_type D_mat; + dtype_mat_t A_mat; + dtype_mat_t B_mat; + dtype_mat_t C_mat; + + reinterpret_cast&>(A_mat.thread_elements()) = A; + reinterpret_cast&>(B_mat.thread_elements()) = B; + reinterpret_cast&>(C_mat.thread_elements()) = C; + + mma(D_mat, A_mat, B_mat, C_mat); + + D = reinterpret_cast(D_mat.thread_elements()); + } + + template + METAL_FUNC static constexpr void mma( + thread mat_type& D, + thread dtype_mat_t& A, + thread dtype_mat_t& B, + thread dtype_mat_t& C) { + simdgroup_multiply_accumulate(D, A, B, C); + } + + template + METAL_FUNC static constexpr void row_reduce( + thread const 
frag_type& inp_vals, + thread T* reduced_vals) { + T thr_reduce = Op::apply(inp_vals.x, inp_vals.y); + + T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); + qgr_reduce = Op::apply(thr_reduce, qgr_reduce); + + T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); + sgr_reduce = Op::apply(qgr_reduce, sgr_reduce); + + reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce); + } + + template + METAL_FUNC static constexpr void row_bin_op( + thread frag_type& inp_vals, + thread T* row_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + inp_vals[i * kElemCols + j] = + Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); + } + } + } +}; + +template < + typename T, + int kTileRows_, + int kTileCols_, + class MMAFrag_ = BaseMMAFrag> +struct MMATile { + using MMAFrag_t = MMAFrag_; + using elem_type = T; + STEEL_CONST int kFragRows = MMAFrag_t::kFragRows; + STEEL_CONST int kFragCols = MMAFrag_t::kFragCols; + STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag; + + STEEL_CONST int kTileRows = kTileRows_; + STEEL_CONST int kTileCols = kTileCols_; + + STEEL_CONST int kRows = kTileRows * kFragRows; + STEEL_CONST int kCols = kTileCols * kFragCols; + + STEEL_CONST int kNumFrags = kTileRows * kTileCols; + STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag; + + STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows; + STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols; + + typedef typename MMAFrag_t::mat_type mat_type; + typedef typename MMAFrag_t::frag_type frag_type; + + frag_type val_frags[kNumFrags]; // = {frag_type(0)}; + + METAL_FUNC MMATile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC constexpr const 
thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC mat_type mat_at(const short i, const short j) { + mat_type val_mat; + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < kElemsPerFrag; ++ii) { + val_mat.thread_elements()[ii] = frag_at(i, j)[ii]; + } + return val_mat; + } + + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_frags); + } + + template + METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::template row_reduce( + frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::template row_bin_op( + frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]); + } + } + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &( + src[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &( + dst[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < 
kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::load_safe( + frag_at(i, j), + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::store_safe( + frag_at(i, j), + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } +}; + +template < + typename Dtype, + typename Atype, + typename Btype, + typename Ctype, + int M, + int N, + int K, + class MMAFragD, + class MMAFragA, + class MMAFragB, + class MMAFragC> +METAL_FUNC void tile_matmad( + thread MMATile& D, + thread MMATile& A, + thread MMATile& B, + thread MMATile& C) { + STEEL_PRAGMA_UNROLL + for (short m = 0; m < M; ++m) { + STEEL_PRAGMA_UNROLL + for (short n = 0; n < N; ++n) { + short m_serp = m; //(n % 2) ? (M - 1 - m) : m; + short n_serp = (m % 2) ? 
(N - 1 - n) : n; + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < K; ++k) { + MMAFragD::mma( + D.frag_at(m_serp, n_serp), + A.frag_at(m_serp, k), + B.frag_at(k, n_serp), + C.frag_at(m_serp, n_serp)); + } + } + } +} + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType = float, + typename Epilogue = TransformNone> +struct BlockMMA { + // MMAFrag size + STEEL_CONST short kFragSize = 8; + using MMAFrag_acc_t = BaseMMAFrag; + + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = kFragSize * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = kFragSize * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / TM_stride; + // Warp tile size along N + STEEL_CONST short TN = BN / TN_stride; + + // Threadgroup A strides + STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M + STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K + + // Threadgroup B strides + STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K + STEEL_CONST short B_str_n = transpose_b ? 
ldb_tgp : 1; // N + + // Threadgroup strides along K + STEEL_CONST short tile_stride_a = kFragSize * A_str_k; + STEEL_CONST short tile_stride_b = kFragSize * B_str_k; + + // Simdgroup matrices + MMATile Atile; + MMATile Btile; + MMATile Ctile; + + // Offsets within threadgroup + short sm; + short sn; + + short As_offset; + short Bs_offset; + + /* Constructor */ + METAL_FUNC BlockMMA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) { + // Determine thread position in simdgroup matrix + short tm = kFragSize * (simd_group_id / WN); + short tn = kFragSize * (simd_group_id % WN); + + short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + sm = simd_coord.y; + sn = simd_coord.x; + + // Determine thread and simdgroup offset + As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K + Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N + + sm += tm; + sn += tn; + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + + // Iterate over BK in blocks of kFragSize + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += kFragSize) { + simdgroup_barrier(mem_flags::mem_none); + + Atile.template load(As); + + simdgroup_barrier(mem_flags::mem_none); + + Btile.template load(Bs); + + simdgroup_barrier(mem_flags::mem_none); + + tile_matmad(Ctile, Atile, Btile, Ctile); + + // Progress to next simdgroup tile + As += tile_stride_a; + Bs += tile_stride_b; + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device U* D, const int ldd) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + + Ctile.template store(D, 
ldd); + } + + METAL_FUNC void + store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + Ctile.template store_safe(D, ldd, dst_tile_dims); + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]); + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue( + const device U* C, + const int ldc, + const int fdc, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) { + accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue_safe( + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; 
i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + // Read C + U c_elems[kelems] = {0}; + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + c_elems[k] = C[offset_c + k * fdc]; + } + } + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + accum[k] = epilogue_op.apply(accum[k], c_elems[k]); + } + } } } @@ -1084,8 +1727,10 @@ struct BlockMMAFA { const int fdc, thread const Epilogue& epilogue_op) const { // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn) * fdc; - D += (sm + tm) * ldd + tn + sn; + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; // Loop over all simdgroup tiles STEEL_PRAGMA_UNROLL @@ -1093,18 +1738,15 @@ struct BlockMMAFA { STEEL_PRAGMA_UNROLL for (short j = 0; j < TN; j++) { // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); + thread const auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; int offset_d = (i * TM_stride) * ldd + (j * TN_stride); // Apply epilogue - U outs[2] = { - epilogue_op.apply(accum[0], C[offset_c]), - epilogue_op.apply(accum[1], C[offset_c + fdc])}; - - // Write out D - D[offset_d] = outs[0]; - D[offset_d + 1] = outs[1]; + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } } } } @@ -1118,9 +1760,14 @@ struct BlockMMAFA { short2 dst_tile_dims, thread const Epilogue& epilogue_op) const { // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn) * fdc; - D += (sm + tm) * ldd + tn + sn; - dst_tile_dims -= 
short2(tn + sn, sm + tm); + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; STEEL_PRAGMA_UNROLL for (int i = 0; i < TM; i++) { @@ -1128,556 +1775,547 @@ struct BlockMMAFA { STEEL_PRAGMA_UNROLL for (int j = 0; j < TN; j++) { // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); + thread const auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; int offset_d = (i * TM_stride) * ldd + (j * TN_stride); - // Apply epilogue and output C - if (j * TN_stride < dst_tile_dims.x) { - D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]); - } - - if (j * TN_stride + 1 < dst_tile_dims.x) { - D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + D[offset_d + k] = + epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } } } } } } +}; - METAL_FUNC void clear_results() { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - results[i * TN + j] = simdgroup_matrix(0); - } - } +// ============ "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h" + +struct AttnParams { + int B; ///< Batch Size + int H; ///< Heads + int D; ///< Head Dim + + int qL; ///< Query Sequence Length + int kL; ///< Key Sequence Length + + int gqa_factor; ///< Group Query factor + float scale; ///< Attention scale + + int NQ; ///< Number of query blocks + int NK; ///< Number of key/value blocks + + int NQ_aligned; ///< Number of full query blocks + int NK_aligned; ///< Number of full key/value blocks + + int qL_rem; ///< Remainder in last query block + int kL_rem; ///< Remainder in last key/value block + int qL_off; ///< Offset in query 
sequence start + + int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1) + int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1) + int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1) + int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1) +}; + +struct AttnMaskParams { + int64_t M_strides[3]; ///< Mask strides (B, H, qL, kL = 1) +}; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +constant bool align_Q [[function_constant(200)]]; +constant bool align_K [[function_constant(201)]]; + +constant bool has_mask [[function_constant(300)]]; +constant bool do_causal [[function_constant(301)]]; + +template +struct TransformScale { + T scale; + METAL_FUNC TransformScale(T scale_) : scale(scale_) {} + + METAL_FUNC T apply(T x) const { + return scale * x; + } +}; + +struct MaxOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return metal::max(x, y); + } +}; + +struct SumOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x + y; + } +}; + +struct MulOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x * y; + } +}; + +struct SubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x - y; + } +}; + +struct ExpSubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return fast::exp2(x - y); + } +}; + +struct DivOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x / y; } }; +// clang-format off template < typename T, - typename U, - int BM, - int BN, + int BQ, int BK, + int BD, int WM, int WN, - bool transpose_q, - bool transpose_k, - bool transpose_v, - bool MN_aligned, - bool K_aligned, - typename AccumType = typename AccumHelper::accum_type, - typename Epilogue = TransformNone> -struct FastAttentionKernel { - STEEL_CONST short tgp_padding = 16 / sizeof(T); - STEEL_CONST short float_padding = 16 / 
sizeof(float); - STEEL_CONST short tgp_mem_size_q = - transpose_q ? BK * (BM + tgp_padding) : BM * (BK + tgp_padding); - STEEL_CONST short tgp_mem_size_k = - transpose_k ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding); - STEEL_CONST short tgp_mem_size_v = - transpose_v ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding); - STEEL_CONST short tgp_mem_size_s = BM * (BN + tgp_padding); - - // maxes, rowsums, rescale - STEEL_CONST short tgp_mem_size_corrections = - 4 * (BM * sizeof(float) + float_padding); - - STEEL_CONST bool share_kv_smem = transpose_k != transpose_v; - - STEEL_CONST short tgp_mem_size = share_kv_smem - ? tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s + - tgp_mem_size_corrections - : tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s + - tgp_mem_size_corrections + tgp_mem_size_v; - - STEEL_CONST short tgp_size = WM * WN * 32; - - static_assert(transpose_q == false, "Expected Q not transposed."); - static_assert(transpose_k == true, "Expected K transposed."); - static_assert(transpose_v == false, "Expected V not transposed."); - static_assert(tgp_mem_size <= 32768, "Excessive tgp memory requested."); - - using loader_q_t = BlockLoaderFA< - T, - transpose_q ? BK : BM, - transpose_q ? BM : BK, - transpose_q ? BM + tgp_padding : BK + tgp_padding, - !transpose_q, - tgp_size>; - - using loader_k_t = BlockLoaderFA< - T, - transpose_k ? BN : BK, - transpose_k ? BK : BN, - transpose_k ? BK + tgp_padding : BN + tgp_padding, - transpose_k, - tgp_size>; - - using loader_v_t = BlockLoaderFA< - T, - transpose_v ? BK : BN, - transpose_v ? BN : BK, - transpose_v ? BN + tgp_padding : BK + tgp_padding, - transpose_v, - tgp_size>; - - using mma_qk_t = BlockMMAFA< - T, - U, - BM, - BN, - BK, - WM, - WN, - transpose_q, - transpose_k, - transpose_q ? BM + tgp_padding : BK + tgp_padding, - transpose_k ? 
BK + tgp_padding : BN + tgp_padding, - AccumType, - Epilogue>; - - using mma_sv_t = BlockMMAFA< - T, - U, - BM, - BK, - BN, - WM, - WN, - false, - transpose_v, - BN + tgp_padding, - BK + tgp_padding, - AccumType, - Epilogue>; - - /* Main kernel function */ - template - static METAL_FUNC void gemm_loop( - threadgroup T* As [[threadgroup(0)]], - threadgroup T* Bs [[threadgroup(1)]], - const int gemm_k_iterations, - thread loader_k_t& loader_b, - thread mma_qk_t& mma_op, - thread const short& tgp_bm, - thread const short& tgp_bn, - LoopAlignment l = {}) { - // Appease the compiler - (void)l; - (void)tgp_bm; - - short2 tile_dims_B = transpose_k ? short2(BK, tgp_bn) : short2(tgp_bn, BK); - - // not valid for gemm_k_iterations > 1 (so, BK == d_k) - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (N_aligned) { - loader_b.load_unsafe(); - } else { - loader_b.load_safe(tile_dims_B); - } + typename MaskType = float, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device T* O [[buffer(3)]], + const constant AttnParams* params [[buffer(4)]], + const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]], + const device MaskType* mask [[buffer(6), function_constant(has_mask)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on - threadgroup_barrier(mem_flags::mem_threadgroup); + // Pacifying compiler + (void)lid; - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - } + // Move to correct block + ulong3 tidl{tid.x, tid.y, tid.z}; + + Q += tidl.z * params->Q_strides[0] + // Batch + tidl.y * params->Q_strides[1] + // Head + tidl.x * BQ * 
params->Q_strides[2]; // Seqeunce + + ulong kv_head_idx = int(tid.y) / params->gqa_factor; + K += tidl.z * params->K_strides[0] + // Batch + kv_head_idx * params->K_strides[1]; // Head + + V += tidl.z * params->V_strides[0] + // Batch + kv_head_idx * params->V_strides[1]; // Head + + O += tidl.z * params->O_strides[0] + // Batch + tidl.y * params->O_strides[1] + // Head + tidl.x * BQ * params->O_strides[2]; // Seqeunce + + if (has_mask) { + mask += tidl.z * mask_params->M_strides[0] + // Batch + tidl.y * mask_params->M_strides[1]; // Head + } + + // Prepare threadgroup memory + constexpr short padQ = 16 / sizeof(T); + constexpr short padK = 16 / sizeof(T); + constexpr short padV = 16 / sizeof(T); + + constexpr short LDQ_tgp = BD + padQ; + constexpr short LDK_tgp = BK + padK; + constexpr short LDV_tgp = BD + padV; + + constexpr short tgp_mem_0 = (BK + padK) * (BD); + constexpr short tgp_mem_1 = BK * (BD + padV); + constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? tgp_mem_0 : tgp_mem_1; + + threadgroup T Q_smem[BQ * (BD + padQ)]; + threadgroup T KV_smem[tgp_mem_s]; + + threadgroup T* Qs = Q_smem; + threadgroup T* Ks = KV_smem; + threadgroup T* Vs = KV_smem; + + // Prepare block loaders + using QBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BQ, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ LDQ_tgp, + /* short kDstStrCol = */ 1, + /* short reduction_dim = */ 1, + /* short tgp_size = */ WM * WN * 32>; + + // K is loaded in transposed + using KBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BK, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ 1, + /* short kDstStrCol = */ LDK_tgp, + /* short reduction_dim = */ 0, + /* short tgp_size = */ WM * WN * 32>; + + using VBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BK, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ LDV_tgp, + /* short kDstStrCol = */ 1, + /* short reduction_dim = */ 0, + /* short tgp_size = */ WM * WN * 32>; + + 
QBlockLoader loader_q( + Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id); + KBlockLoader loader_k( + K, params->K_strides[2], Ks, simd_group_id, simd_lane_id); + VBlockLoader loader_v( + V, params->V_strides[2], Vs, simd_group_id, simd_lane_id); + + TransformScale ts(static_cast(params->scale * 1.44269504089)); + + // Prepare MMA tiles + constexpr short kFragSize = 8; // MMAFrag size + using MMAFrag_acc_t = BaseMMAFrag; + + constexpr int kNWarps = WM * WN; + static_assert( + BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0, + "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence."); + + // Q seq frags per warp + constexpr int TQ = BQ / (kNWarps * kFragSize); + // KV sequence frags (all warps load the same frags) + constexpr int TK = BK / kFragSize; + // HeadDim frags (all warps load the same frags) + constexpr int TD = BD / kFragSize; + + static_assert(TQ == 1, "Check TQ"); + + MMATile Qtile; + MMATile Ktile; + MMATile Stile; + MMATile Vtile; + MMATile Otile; + + Otile.clear(); + + // Prepare mma tile offsets + const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + const short sm = simd_coord.y; + const short sn = simd_coord.x; + const short tm = kFragSize * TQ * simd_group_id; + + const short Qs_offset = (tm + sm) * LDQ_tgp + sn; + const short Ks_offset = sm * LDK_tgp + sn; + const short Vs_offset = sm * LDV_tgp + sn; + + constexpr short Qs_tile_stride = kFragSize; + constexpr short Ks_tile_stride = kFragSize * LDK_tgp; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load Q blocks apply scale + if (!align_Q && int(tid.x) == (params->NQ_aligned)) { + loader_q.load_safe(short2(BD, params->qL_rem)); + } else { + loader_q.load_unsafe(); + } + loader_q.apply_inplace_op(ts); + + // Init row reduction variables + constexpr short kRowsPT = decltype(Stile)::kRowsPerThread; + + AccumType max_score[kRowsPT]; + AccumType sum_score[kRowsPT] = {0}; + + // Init to -Inf + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 
kRowsPT; ++i) { + max_score[i] = Limits::min; + } + + int kb_lim = params->NK; + + if (do_causal) { + int q_max = (tid.x + 1) * BQ + params->qL_off; + kb_lim = (q_max + BK - 1) / BK; } - static METAL_FUNC void initialize_corrections( - threadgroup float* C, - uint simd_lane_id, - uint simd_group_id) { - if (simd_group_id == 0) { - threadgroup float* maxes = C; - threadgroup float* sums = C + (BM + float_padding); - threadgroup float* o_rescale = sums + (BM + float_padding); - threadgroup float* output_rescale = o_rescale + (BM + float_padding); - - if (simd_lane_id < BM) { - maxes[simd_lane_id] = -INFINITY; // m_i - sums[simd_lane_id] = 0.f; // l_i - o_rescale[simd_lane_id] = 1.f; // li * exp(mi - mi_new) - output_rescale[simd_lane_id] = 1.f; // 1.0 / l_i + // Loop over KV seq length + for (int kb = 0; kb < kb_lim; kb++) { + // Load K block and apply scale + threadgroup_barrier(mem_flags::mem_threadgroup); + if (!align_K && kb == (params->NK_aligned)) { + loader_k.load_safe(short2(BD, params->kL_rem)); + } else { + loader_k.load_unsafe(); + } + + // Do S = Q @ K.T + Stile.clear(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_UNROLL + for (short dd = 0; dd < TD; dd++) { + simdgroup_barrier(mem_flags::mem_none); + + Qtile.template load( + &Qs[Qs_offset + dd * Qs_tile_stride]); + Ktile.template load( + &Ks[Ks_offset + dd * Ks_tile_stride]); + + simdgroup_barrier(mem_flags::mem_none); + + tile_matmad(Stile, Qtile, Ktile, Stile); + } + + // Mask out length sequence + if (!align_K && kb == (params->NK_aligned)) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = -metal::numeric_limits::infinity(); + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + short col_pos = sn + (j * stile_t::kFragCols); + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { + if 
((col_pos + jj) >= params->kL_rem) { + Stile.frag_at(i, j)[jj] = neg_inf; + } + } + } } } - } - static METAL_FUNC void rescale_ss( - threadgroup T* Ss, - threadgroup float* Corrections, - uint simd_group_id, - uint simd_lane_id, - short2 local_blocks, - float alpha, - float softcapping) { - if (simd_group_id == 0) { - short row_offset = BM + float_padding; - threadgroup float* maxes = Corrections; - threadgroup float* sums = Corrections + row_offset; - threadgroup float* o_rescale = sums + row_offset; - threadgroup float* output_scales = o_rescale + row_offset; - - if (simd_lane_id < uint(local_blocks.y)) { - float m_i_old = maxes[simd_lane_id]; - float l_i_old = sums[simd_lane_id]; - - float m_i_new = m_i_old; - float l_i_new = l_i_old; - - short offset = simd_lane_id * (BN + tgp_padding); - - float m_ij = -INFINITY; - - for (short j = 0; j < local_blocks.x; j++) { - float val = alpha * float(Ss[offset + j]); - if (softcapping != 1.) { - val = precise::tanh(val); - val = val * softcapping; + // Mask out if causal + if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = -metal::numeric_limits::infinity(); + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + const int row_pos = + tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows); + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + const int col_pos = kb * BK + sn + (j * stile_t::kFragCols); + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { + if (row_pos < (col_pos + jj)) { + Stile.frag_at(i, j)[jj] = neg_inf; + } } - m_ij = max(m_ij, val); } + } + } + + // Other masking as needed + if (has_mask) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = -metal::numeric_limits::infinity(); - m_i_new = max(m_ij, m_i_new); + constexpr bool is_bool 
= is_same_v; + using melem_t = typename metal::conditional_t; - float rowsum = 0.f; // lij + using MMAFrag_mask_t = BaseMMAFrag; + using frag_t = typename MMAFrag_mask_t::frag_type; - for (short j = 0; j < local_blocks.x; j++) { - float val = alpha * float(Ss[offset + j]); - if (softcapping != 1.) { - val = precise::tanh(val); - val = val * softcapping; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows); + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + const int col_pos = kb * BK + sn + (j * stile_t::kFragCols); + + frag_t mfrag; + + MMAFrag_mask_t::load_safe( + mfrag, + mask, + int(mask_params->M_strides[2]), + Int<1>{}, + params->qL, + params->kL, + row_pos, + col_pos); + + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) { + if constexpr (is_bool) { + Stile.frag_at(i, j)[jj] = + mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf; + } else { + Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]); + } } - float P_i_j = exp(val - m_ij); - rowsum += P_i_j; - P_i_j = P_i_j * exp(m_ij - m_i_new); - Ss[offset + j] = T(P_i_j); } - - l_i_new = - exp(m_i_old - m_i_new) * l_i_old + exp(m_ij - m_i_new) * rowsum; - maxes[simd_lane_id] = m_i_new; - sums[simd_lane_id] = l_i_new; - float rescale = l_i_old * exp(m_i_old - m_i_new); - o_rescale[simd_lane_id] = rescale; - output_scales[simd_lane_id] = 1.0 / l_i_new; } } - } - /* Main kernel function */ - static METAL_FUNC void run( - const device T* Q [[buffer(0)]], - const device T* K [[buffer(1)]], - const device T* V [[buffer(2)]], - device U* O [[buffer(3)]], - const constant MLXFastAttentionParams* params [[buffer(4)]], - threadgroup T* Qs [[threadgroup(0)]], - threadgroup T* Ks [[threadgroup(1)]], - threadgroup T* Ss [[threadgroup(2)]], - threadgroup T* Vs [[threadgroup(3)]], - threadgroup float* Corrections [[threadgroup(4)]], - uint simd_lane_id 
[[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - // Pacifying compiler - (void)lid; - - const int tid_y = ((tid.y) << params->swizzle_log) + - ((tid.x) & ((1 << params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> params->swizzle_log; - - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load V blocks + if (!align_K && kb == (params->NK_aligned)) { + loader_v.load_safe(short2(BD, params->kL_rem)); + } else { + loader_v.load_unsafe(); } - threadgroup_barrier(mem_flags::mem_none); - - // Find block in Q, O; and head in K, V. - const int c_row = tid_y * BM; - - Q += transpose_q ? c_row : c_row * params->ldq; - thread loader_q_t loader_q(Q, params->ldq, Qs, simd_group_id, simd_lane_id); - - short tgp_bm = min(BM, params->M - c_row); - short2 tile_dims_Q = transpose_q ? short2(tgp_bm, BK) : short2(BK, tgp_bm); - - loader_q.load_safe(tile_dims_Q); - - initialize_corrections(Corrections, simd_lane_id, simd_group_id); - - O += c_row * params->ldo; - - // Prepare threadgroup mma operation - thread mma_qk_t mma_qk_op(simd_group_id, simd_lane_id); - thread mma_sv_t mma_softmax_sv_op(simd_group_id, simd_lane_id); - thread loader_k_t loader_k(K, params->ldk, Ks, simd_group_id, simd_lane_id); - thread loader_v_t loader_v(V, params->ldv, Vs, simd_group_id, simd_lane_id); - - for (short n_block = 0; n_block < params->gemm_n_iterations_aligned; - n_block++) { - short c_col = BN; - - // Prepare threadgroup loading operations - short gemm_k_iterations = params->gemm_k_iterations_aligned; - short tgp_bn_qk = min(BN, params->N - c_col * n_block); - threadgroup_barrier(mem_flags::mem_none); - - /////////////////////////////////////////////////////////////////////////////// - { // Loop over K - unaligned case - - if (tgp_bm == BM && tgp_bn_qk == BN) { - gemm_loop( - Qs, 
- Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); - } else if (tgp_bn_qk == BN) { - gemm_loop( - Qs, - Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); - - } else if (tgp_bm == BM) { - gemm_loop( - Qs, - Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); + // Do softmax - } else { - gemm_loop( - Qs, - Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); - } - } + // Temp variables + AccumType new_max[kRowsPT]; + AccumType factor[kRowsPT]; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + new_max[i] = max_score[i]; + } + + // Row max + Stile.template row_reduce(new_max); + + // exp(Si - rowmax(Si)) + Stile.template row_bin_op(new_max); - mma_qk_op.store_result_to_tgp_memory( - Ss, BN + tgp_padding, short2(BN, BM)); + // Factor exp(rowmax(Si) - rowmax(Si-1)) + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + factor[i] = fast::exp2(max_score[i] - new_max[i]); + } - threadgroup_barrier(mem_flags::mem_threadgroup); + // Save max for next iteration + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = new_max[i]; + } - rescale_ss( - Ss, - Corrections, - simd_group_id, - simd_lane_id, - short2(tgp_bn_qk, tgp_bm), - params->alpha, - params->softcapping); + // Row Sum + AccumType sum_score_tmp[kRowsPT] = {0}; + Stile.template row_reduce(sum_score_tmp); - loader_v.load_safe(short2(BK, tgp_bn_qk)); + // Update norm + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i]; + } - threadgroup_barrier(mem_flags::mem_threadgroup); + // Update O + Otile.template row_bin_op(factor); - threadgroup float* o_scales = Corrections + 2 * (BM + float_padding); - mma_softmax_sv_op.rescale_output(o_scales); + // Load V into registers + threadgroup_barrier(mem_flags::mem_threadgroup); - mma_softmax_sv_op.mma(Ss, Vs); + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + 
STEEL_PRAGMA_UNROLL + for (short id = 0; id < TD; id++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + if constexpr (BD == 128) { + simdgroup_barrier(mem_flags::mem_none); + } - threadgroup float* final_output_scales = - Corrections + 3 * (BM + float_padding); + const short kk = ik * kFragSize; + const short dd = id * kFragSize; - mma_softmax_sv_op.rescale_output(final_output_scales); + Vtile.template load( + &Vs[Vs_offset + kk * LDV_tgp + dd]); - loader_v.next(); - loader_k.next(BN); + if constexpr (BD == 128) { + simdgroup_barrier(mem_flags::mem_none); + } - mma_qk_op.clear_results(); + MMAFrag_acc_t::mma( + Otile.frag_at(iq, id), + Stile.frag_at(iq, ik), + Vtile.frag_at(0, 0), + Otile.frag_at(iq, id)); + } + } } - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_softmax_sv_op.store_result_safe(O, params->ldo, short2(BK, tgp_bm)); + // Prepare for next iteration + loader_k.next(); + loader_v.next(); } -}; -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_q, - bool transpose_k, - bool transpose_v, - bool MN_aligned, - bool K_aligned> -[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void attention( - const device T* Q [[buffer(0)]], - const device T* K [[buffer(1)]], - const device T* V [[buffer(2)]], - device T* O [[buffer(3)]], - const constant MLXFastAttentionParams* params [[buffer(4)]], - const constant int* batch_shape [[buffer(6)]], - const constant size_t* batch_strides [[buffer(7)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - using attention_kernel = FastAttentionKernel< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_q, - transpose_k, - transpose_v, - MN_aligned, - K_aligned>; - - // Adjust for batch - if (params->batch_ndim > 1) { - const constant size_t* Q_bstrides = batch_strides; - const constant size_t* 
KV_bstrides = batch_strides + params->batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, Q_bstrides, KV_bstrides, params->batch_ndim); - - Q += batch_offsets.x; - K += batch_offsets.y; - V += batch_offsets.y; + // Normalize output + Otile.template row_bin_op(sum_score); + threadgroup_barrier(mem_flags::mem_none); + // Store results + O += (tm + sm) * params->O_strides[2] + sn; + + if (!align_Q && int(tid.x) == (params->NQ_aligned)) { + auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm)); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + Otile.template store_safe(O, params->O_strides[2], dst_tile_dims); } else { - Q += params->batch_stride_q * tid.z; - K += params->batch_stride_k * tid.z; - V += params->batch_stride_v * tid.z; - } - - // same shape as input - O += params->batch_stride_o * tid.z; - threadgroup T Qs[attention_kernel::tgp_mem_size_q]; - threadgroup T Ss[attention_kernel::tgp_mem_size_s]; - threadgroup float Corrections[attention_kernel::tgp_mem_size_corrections]; - - if (attention_kernel::share_kv_smem) { - threadgroup T Ks[attention_kernel::tgp_mem_size_k]; - threadgroup T* Vs = Ks; //[attention_kernel::tgp_mem_size_v]; - attention_kernel::run( - Q, - K, - V, - O, - params, - Qs, - Ks, - Ss, - Vs, - Corrections, - simd_lane_id, - simd_group_id, - tid, - lid); - } else { - threadgroup T Ks[attention_kernel::tgp_mem_size_k]; - threadgroup T Vs[attention_kernel::tgp_mem_size_v]; - attention_kernel::run( - Q, - K, - V, - O, - params, - Qs, - Ks, - Ss, - Vs, - Corrections, - simd_lane_id, - simd_group_id, - tid, - lid); + Otile.template store(O, params->O_strides[2]); } } // clang-format off // SDPA full instantiations -#define instantiate_fast_inference_self_attention_kernel( \ - itype, otype, bm, bn, bk, wm, wn) \ - template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \ - "_itype_" #itype)]] [[kernel]] void \ - attention( \ - const device itype* Q [[buffer(0)]], \ - 
const device itype* K [[buffer(1)]], \ - const device itype* V [[buffer(2)]], \ - device otype* O [[buffer(3)]], \ - const constant MLXFastAttentionParams* params [[buffer(4)]], \ - const constant int* batch_shape [[buffer(5)]], \ - const constant size_t* batch_strides [[buffer(6)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); - -instantiate_fast_inference_self_attention_kernel( - float, - float, - 16, - 16, - 32, - 2, - 2); -instantiate_fast_inference_self_attention_kernel( - float, - float, - 16, - 16, - 64, - 2, - 2); -instantiate_fast_inference_self_attention_kernel( - float, - float, - 16, - 16, - 96, - 2, - 2); -instantiate_fast_inference_self_attention_kernel( - float, - float, - 16, - 16, - 128, - 2, - 2); -instantiate_fast_inference_self_attention_kernel( - float, - float, - 16, - 16, - 256, - 2, - 2); -instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 32, 2, 2); -instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2); -instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 96, 2, 2); -instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2); -instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2); + +// Instantiate a templated kernel. +// Extra args are used as template parameters: +// e.g. instantiate_kernel(binary_int, binary, a, b) -> +// [[host_name(binary_int)]] [kernel] binary +#define instantiate_kernel(name, func, ...) 
\ + template [[host_name( \ + name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>; + +#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \ + instantiate_kernel( \ + "steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \ + "_wm" #wm "_wn" #wn "_mask" #mname, \ + attention, dtype, bq, bk, bd, wm, wn, mtype, float) + +#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \ + instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype) + +#define instantiate_attn_mask_helper(iname, itype) \ + instantiate_attn_shapes_helper(iname, itype, iname, itype) \ + instantiate_attn_shapes_helper(iname, itype, bool_, bool) + +instantiate_attn_mask_helper(float16, half); +instantiate_attn_mask_helper(bfloat16, bfloat16_t); +instantiate_attn_mask_helper(float32, float); // SDPA vector instantiations #define instantiate_sdpa_vector(type, head_dim) \ diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 6cd3b15a68..285f7f2dc0 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -1549,6 +1549,8 @@ impl Module for Identity { struct Sdpa { scale: f32, softcapping: f32, + mask: Option, + do_causal: bool, } impl candle::CustomOp3 for Sdpa { @@ -1585,6 +1587,8 @@ impl candle::CustomOp3 for Sdpa { let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?]; let elem_count: usize = out_dims.iter().product(); + let out_shape = Shape::from_dims(&out_dims); + let out_layout = Layout::contiguous(out_shape.clone()); let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?; @@ -1608,18 +1612,16 @@ impl candle::CustomOp3 for Sdpa { let q_seq = q_l.dim(2)?; let mut implementation_supports_use_case = q_head == k_head; - let supported_head_dim = + let supported_full_head_dim = q_head == 64 || q_head == 80 || q_head == 128; + let supported_vector_head_dim = q_head == 32 || q_head 
== 64 || q_head == 96 || q_head == 128 || q_head == 256; - const SDPA_FULL_THRESHOLD: usize = 2; - - let supports_sdpa_full = - q_seq >= SDPA_FULL_THRESHOLD && supported_head_dim && q_head == k_head; - let supports_sdpa_vector = q_seq == 1 && supported_head_dim; + let supports_sdpa_full = supported_full_head_dim; + let supports_sdpa_vector = q_seq == 1 && supported_vector_head_dim; implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector; - if !supported_head_dim { + if !(supported_vector_head_dim || supported_full_head_dim) { candle::bail!( "Meta SDPA does not support q head dim {q_head}: q dims {:?}, k dims {:?}, v dims {:?}.", q_l.dims(), @@ -1726,27 +1728,70 @@ impl candle::CustomOp3 for Sdpa { .map_err(candle::Error::wrap)?; } } else if supports_sdpa_full { - if q_l.dim(2)? != k_l.dim(2)? { - candle::bail!( - "query and key sequence length must be equal if using full metal sdpa" - ) + command_buffer.set_label("full_attention"); + if self.softcapping != 1. { + candle::bail!("SDPA full requires softcapping to be disabled (1.0)"); } - command_buffer.set_label("full_attention"); + let mask_s_l = self.mask.as_ref().map(|m| m.storage_and_layout()); + + let (mask_type, mask_buffer, mask_strides) = if let Some(mask) = &self.mask { + let (mask_s, mask_l) = mask_s_l.as_ref().unwrap(); + + let mask_buffer = match &**mask_s { + candle::Storage::Metal(m) => m.buffer(), + _ => candle::bail!("Expected metal device for mask"), + }; + + let mask_type = match mask.dtype() { + DType::BF16 => SdpaDType::BF16, + DType::F16 => SdpaDType::F16, + DType::F32 => SdpaDType::F32, + other => candle::bail!("unsupported sdpa type {other:?}"), + }; + if mask_type != itype { + candle::bail!("Mask type {mask_type:?} must match q type {itype:?}"); + } + + if mask_l.dims() != [q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, k_l.dim(2)?] 
{ + candle::bail!( + "Mask shape must be {:?} (bs, qheads, qseq, kseq), got {:?}", + [q_l.dim(0)?, q_head, q_l.dim(2)?, k_l.dim(2)?], + mask_l.dims() + ); + } + + ( + Some(mask_type), + Some(mask_buffer), + Some(mask_l.stride().to_vec()), + ) + } else { + (None, None, None) + }; + candle_metal_kernels::call_sdpa_full( q.device().device(), &command_buffer, q.device().kernels(), q_l.start_offset(), q_l.dims(), + q_l.stride(), q.buffer(), k_l.start_offset(), + k_l.dims(), + k_l.stride(), k.buffer(), v_l.start_offset(), v.buffer(), + v_l.stride(), + mask_type, + mask_buffer, + mask_strides.as_deref(), &output, + out_layout.stride(), self.scale, - self.softcapping, + self.do_causal, itype, ) .map_err(candle::Error::wrap)?; @@ -1755,7 +1800,7 @@ impl candle::CustomOp3 for Sdpa { } let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, q.dtype()); - Ok((newstorage, Shape::from_dims(&out_dims))) + Ok((newstorage, out_shape)) } } @@ -1767,13 +1812,15 @@ impl candle::CustomOp3 for Sdpa { /// - `q`: (bs, qhead, seq, hidden) /// - `k`: (bs, kv_head, kv_seq, hidden) /// - `k`: (bs, kv_head, kv_seq, v_hidden) +/// - `mask`: (bs, qhead, seq, kv_seq) +/// - `do_causal`: Apply causal masking. If this is true, the mask does not need to be provided. /// - `scale` is applied before softmax. /// - If `softcapping` != 1.0: /// - Computation is: softmax(tanh(qk^T*scale/cap)*cap)v /// /// **Output shape:** (bs, qhead, seq, v_hidden) /// -/// **Supported head dims:** 32, 64, 96, 128, 256. +/// Note: For Grouped Query Attention and Multi-Query Attention, the k and v inputs should not be pre-tiled to match q. /// /// ## On Metal: /// - If `seq` == 1: @@ -1781,13 +1828,32 @@ impl candle::CustomOp3 for Sdpa { /// - Supports `seq` != `kv_seq` (cross attn. 
support) /// - Supports GQA when `qhead` is a multiple of `kv_head` /// - Otherwise: -/// - Use an alternate kernel -/// - Requires `seq` == `kv_seq` -/// - GQA is not supported (requires `qhead` == `kv_head`) -pub fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, softcapping: f32) -> Result { - q.apply_op3_no_bwd(k, v, &Sdpa { scale, softcapping }) +/// - Masking is supported +/// - Supports `seq` != `kv_seq` (cross attn. support) +/// - Supports GQA when `qhead` is a multiple of `kv_head` +/// - Softcapping is not supported. +pub fn sdpa( + q: &Tensor, + k: &Tensor, + v: &Tensor, + mask: Option<&Tensor>, + do_causal: bool, + scale: f32, + softcapping: f32, +) -> Result { + q.apply_op3_no_bwd( + k, + v, + &Sdpa { + scale, + softcapping, + mask: mask.cloned(), + do_causal, + }, + ) } +#[allow(unused)] struct MulAndAct { act: Activation, } diff --git a/candle-nn/tests/sdpa.rs b/candle-nn/tests/sdpa.rs index 058387669c..46f628d520 100644 --- a/candle-nn/tests/sdpa.rs +++ b/candle-nn/tests/sdpa.rs @@ -25,7 +25,7 @@ mod metal_sdpa_tests { att.matmul(&v.clone())? }; - let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, None, false, scale as f32, 1.)?; assert_eq!(ground_truth.shape(), sdpa_output.shape()); @@ -33,7 +33,7 @@ mod metal_sdpa_tests { .sum_all()? .to_scalar()?; - assert!(error <= 0.0005, "{}", error); + assert!(error <= 0.002, "{}", error); Ok(()) } @@ -63,7 +63,7 @@ mod metal_sdpa_tests { att.matmul(&v.clone())? }; - let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, None, false, scale as f32, 1.)?; assert_eq!(ground_truth.shape(), sdpa_output.shape()); @@ -101,7 +101,7 @@ mod metal_sdpa_tests { att.matmul(&v.clone())? 
}; - let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, None, false, scale as f32, 1.)?; assert_eq!(ground_truth.shape(), sdpa_output.shape()); @@ -115,9 +115,8 @@ mod metal_sdpa_tests { } #[test] - fn sdpa_full_softcapping() -> candle::Result<()> { + fn sdpa_full_masked() -> candle::Result<()> { use candle::{DType, Device, Tensor}; - use std::ops::{Div, Mul}; // Allow vectorized, seqlen = 1 const BS: usize = 4; @@ -125,7 +124,6 @@ mod metal_sdpa_tests { const L: usize = 4; const DK: usize = 64; const H: usize = 3; - const SOFTCAP: f64 = 50.; let scale: f64 = f64::from(DK as u32).sqrt().recip(); let device = Device::new_metal(0)?; @@ -133,20 +131,16 @@ mod metal_sdpa_tests { let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let mask = Tensor::randn(0f32, 1f32, (BS, H, R, L), &device)?; let ground_truth = { let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; - let att = candle_nn::ops::softmax_last_dim( - &att.to_dtype(DType::F32)? - .div(SOFTCAP)? - .tanh()? - .mul(SOFTCAP)?, - )? - .to_dtype(q.dtype())?; + let att = candle_nn::ops::softmax_last_dim(&(att.to_dtype(DType::F32)? + &mask)?)? + .to_dtype(q.dtype())?; att.matmul(&v.clone())? }; - let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?; + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, Some(&mask), false, scale as f32, 1.)?; assert_eq!(ground_truth.shape(), sdpa_output.shape()); @@ -154,7 +148,7 @@ mod metal_sdpa_tests { .sum_all()? .to_scalar()?; - assert!(error <= 0.0004, "{}", error); + assert!(error <= 0.006, "{}", error); Ok(()) } @@ -191,7 +185,8 @@ mod metal_sdpa_tests { att.matmul(&v.clone())? 
}; - let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?; + let sdpa_output = + candle_nn::ops::sdpa(&q, &k, &v, None, false, scale as f32, SOFTCAP as f32)?; assert_eq!(ground_truth.shape(), sdpa_output.shape()); @@ -236,7 +231,8 @@ mod metal_sdpa_tests { att.matmul(&v.clone())? }; - let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?; + let sdpa_output = + candle_nn::ops::sdpa(&q, &k, &v, None, false, scale as f32, SOFTCAP as f32)?; assert_eq!(ground_truth.shape(), sdpa_output.shape()); @@ -274,7 +270,7 @@ mod metal_sdpa_tests { att.matmul(&v.clone())? }; - let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, None, false, scale as f32, 1.)?; assert_eq!(ground_truth.shape(), sdpa_output.shape()); @@ -406,7 +402,12 @@ mod metal_sdpa_tests { // Using cat is faster than a broadcast as it avoids going through a potentially // strided copy. // https://github.com/huggingface/candle/pull/2043 - Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim)) + Tensor::cat(&vec![&xs; n_rep], 2)?.reshape(( + b_sz, + n_kv_head * n_rep, + seq_len, + head_dim, + )) } } @@ -432,7 +433,7 @@ mod metal_sdpa_tests { .to_dtype(q.dtype())?; att.matmul(&v_aligned.clone())? }; - let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, None, false, scale as f32, 1.)?; assert_eq!(ground_truth.shape(), sdpa_output.shape()); let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? .sum_all()?