Skip to content

Conversation

@zackangelo
Copy link
Contributor

Make candle ops public. Not sure if we're okay with this being part of the public API, but wanted to put this change up for feedback.

My primary use case is writing in-place versions of some of these ops in our inference code, e.g.:

    use candle::backend::BackendStorage;
    use candle::cuda::{kernel_name, kernels, CudaDType, SlicePtrOrNull, WrapErr};
    use candle::op::UnaryOpT;
    use candle::{builder_arg, Layout, Result, WithDType};
    use std::marker::PhantomData;

    use cudarc::driver::{DeviceRepr, LaunchConfig, PushKernelArg};

    /// In-place adapter around a unary op.
    ///
    /// `U` selects the op and `T` the element type. `T` is carried only through
    /// `PhantomData`, so the wrapper stores nothing beyond the borrowed op.
    pub struct UnaryInPlace<'a, U, T>
    where
        U: UnaryOpT,
        T: DeviceRepr + WithDType + CudaDType,
    {
        /// Borrowed unary op this wrapper is built around.
        pub op: &'a U,
        /// Zero-sized marker recording the element type `T`.
        pub _dtype: PhantomData<T>,
    }

    impl<'a, U, T> UnaryInPlace<'a, U, T>
    where
        U: UnaryOpT,
        T: DeviceRepr + WithDType + CudaDType,
    {
        /// Wraps a borrowed op.
        ///
        /// No value of `T` is stored, so the element type has to be named by the
        /// caller or inferred from context.
        pub fn new(op: &'a U) -> Self {
            let _dtype = PhantomData;
            Self { op, _dtype }
        }
    }

    impl<'a, U: UnaryOpT, T: DeviceRepr + WithDType + CudaDType> candle::InplaceOp1
        for UnaryInPlace<'a, U, T>
    {
        /// Name reported for this custom op.
        fn name(&self) -> &'static str {
            "inplace_op"
        }

        /// No CPU implementation is provided; this always returns an error.
        fn cpu_fwd(&self, _: &mut candle::CpuStorage, _: &Layout) -> Result<()> {
            candle::bail!("inplace is only supported on cuda")
        }

        /// Launches the unary kernel named by `U::KERNEL` for dtype `T`, passing
        /// the same device buffer as both input and output.
        ///
        /// # Errors
        ///
        /// Returns an error if the storage does not hold elements of type `T`,
        /// if the layout is not contiguous, if the element count does not fit
        /// in a `u32`, or if loading or launching the kernel fails.
        #[cfg(feature = "cuda")]
        fn cuda_fwd(&self, storage: &mut candle::CudaStorage, layout: &Layout) -> Result<()> {
            let dev = storage.device().clone();
            let src = storage.as_cuda_slice_mut::<T>()?;
            let shape = layout.shape();
            let dims = shape.dims();
            let el_count = shape.elem_count();

            // The strided path of candle's unary kernels reads `inp[strided_i]`
            // but writes `out[i]`, so aliasing input and output is only correct
            // when the layout is contiguous.
            if !layout.is_contiguous() {
                candle::bail!("inplace unary op requires a contiguous layout")
            }

            // `for_num_elems` takes a `u32`; a plain `as` cast would silently
            // truncate and launch over fewer elements than the kernel is told
            // to process.
            let Ok(launch_elems) = u32::try_from(el_count) else {
                candle::bail!("inplace unary op: {el_count} elements do not fit in a u32 launch size")
            };
            let cfg = LaunchConfig::for_num_elems(launch_elems);

            let ds = SlicePtrOrNull::params_from_layout(&dev, layout)?;
            // Skip ahead to the first element addressed by the layout.
            let src = &src.slice(layout.start_offset()..);
            let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), &kernels::UNARY)?;

            let mut builder = func.builder();
            //TODO theoretical occupancy low, check these args
            builder_arg!(builder, el_count);
            builder_arg!(builder, dims.len());
            ds.builder_arg(&mut builder);
            builder.arg(src); //src and dst are same for in place
            builder.arg(src);
            // SAFETY: the kernel receives the element count of `layout` and a view
            // that starts at the layout's offset into a buffer we hold mutably;
            // input and output alias on purpose, which the contiguity check above
            // is meant to make sound.
            unsafe { builder.launch(cfg) }.w()?;

            Ok(())
        }
    }

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant