diff --git a/candle-metal-kernels/src/kernel.rs b/candle-metal-kernels/src/kernel.rs index b05eac7fa8..f941e30232 100644 --- a/candle-metal-kernels/src/kernel.rs +++ b/candle-metal-kernels/src/kernel.rs @@ -2,10 +2,13 @@ use crate::source::{ AFFINE, BINARY, CAST, CONV, FILL, INDEXING, MLX_GEMM, MLX_SORT, QUANTIZED, RANDOM, REDUCE, SDPA, SORT, TERNARY, UNARY, }; +use crate::utils::get_env_bool; use crate::{ - ComputePipeline, ConstantValues, Device, Function, Library, MTLCompileOptions, MTLMathMode, - MetalKernelError, Source, + ComputePipeline, ConstantValues, Device, Function, Library, MTLCompileOptions, + MTLMathFloatingPointFunctions, MTLMathMode, MetalKernelError, Source, }; +use objc2::available; +use objc2::rc::Retained; use std::collections::HashMap; use std::sync::RwLock; @@ -113,9 +116,7 @@ impl Kernels { } else { let lib = { let source_content = self.get_library_source(source); - let compile_options = MTLCompileOptions::new(); - //unsafe { compile_options.setEnableLogging(true) }; - compile_options.setMathMode(MTLMathMode::Fast); + let compile_options = get_compile_options(); device .new_library_with_source(source_content, Some(&compile_options)) .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))? @@ -176,3 +177,26 @@ impl Kernels { self.load_pipeline_with_constants(device, source, name, None) } } + +fn get_compile_options() -> Retained { + let compile_options = MTLCompileOptions::new(); + //unsafe { compile_options.setEnableLogging(true) }; + + let fast_math_enabled = get_env_bool("CANDLE_METAL_ENABLE_FAST_MATH", true); + // Ref availability: + // https://developer.apple.com/documentation/metal/mtlcompileoptions/mathmode + if available!(macos = 15, ios = 18) { + if fast_math_enabled { + compile_options.setMathMode(MTLMathMode::Fast); + compile_options.setMathFloatingPointFunctions(MTLMathFloatingPointFunctions::Fast); + } else { + compile_options.setMathMode(MTLMathMode::Relaxed); + compile_options.setMathFloatingPointFunctions(MTLMathFloatingPointFunctions::Precise); + } + } else { + // For older OS versions we use the old api + #[allow(deprecated)] + compile_options.setFastMathEnabled(fast_math_enabled); + } + compile_options +} diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index d278c2f8a1..827d2837b0 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -16,7 +16,7 @@ use metal::{ BlitCommandEncoder, Buffer, CommandQueue, ComputeCommandEncoder, ComputePipeline, ConstantValues, Device, Function, Library, MTLResourceOptions, Value, }; -use objc2_metal::{MTLCompileOptions, MTLMathMode, MTLSize}; +use objc2_metal::{MTLCompileOptions, MTLMathFloatingPointFunctions, MTLMathMode, MTLSize}; use source::Source; pub use utils::BufferOffset; use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider}; diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs index 20a1fff681..1ad647d79d 100644 --- a/candle-metal-kernels/src/utils.rs +++ b/candle-metal-kernels/src/utils.rs @@ -1,5 +1,6 @@ use crate::metal::{Buffer, CommandBuffer, ComputeCommandEncoder, ComputePipeline}; -use objc2_metal::MTLSize; +use crate::MTLSize; +use std::ffi::OsStr; use std::ops::Deref; use std::sync::{RwLockReadGuard, RwLockWriteGuard}; @@ -236,3 +237,14 @@ impl<'a, T> From> for RwLockGuard<'a, T> { RwLockGuard::Write(g) } } + +fn is_truthy(s: String) -> bool { + match s.as_str() { + "true" | "t" | "yes" | "y" | "1" => true, + _ => false, + } +} + +pub(crate) fn get_env_bool>(key: K, default: bool) -> bool { + std::env::var(key).map(is_truthy).unwrap_or(default) +}