Skip to content

Commit 74134fa

Browse files
committed
hold Arc<CudaStream> in cutensor descriptor wrappers
1 parent c522547 commit 74134fa

1 file changed

Lines changed: 75 additions & 10 deletions

File tree

src/cutensor/safe.rs

Lines changed: 75 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -639,16 +639,22 @@ impl Drop for CuTensor {
639639
/// ([sys::cutensorTensorDescriptor_t]).
640640
///
641641
/// Automatically destroyed on drop.
642+
///
643+
/// Holds an [`Arc<CudaStream>`] cloned from the originating [`CuTensor`] so
644+
/// the CUDA context outlives the descriptor. Without this, dropping the
645+
/// [`CuTensor`] (and its last stream reference) before the descriptor would
646+
/// leave [`Drop`] calling into a torn-down context.
642647
pub struct TensorDescriptor {
643648
pub(crate) desc: sys::cutensorTensorDescriptor_t,
649+
pub(crate) stream: Arc<CudaStream>,
644650
}
645651

646652
impl TensorDescriptor {
647653
/// Creates a new tensor descriptor.
648654
///
649655
/// # Arguments
650656
///
651-
/// * `handle` - The cuTENSOR handle
657+
/// * `handle` - The cuTENSOR handle (its stream is retained)
652658
/// * `extent` - Size of each dimension
653659
/// * `stride` - Stride (in elements) of each dimension
654660
/// * `data_type` - Element data type
@@ -665,6 +671,9 @@ impl TensorDescriptor {
665671
stride.len(),
666672
"extent and stride must have the same length"
667673
);
674+
let stream = handle.stream.clone();
675+
let ctx = stream.context();
676+
ctx.record_err(ctx.bind_to_thread());
668677
let desc = unsafe {
669678
result::create_tensor_descriptor(
670679
handle.handle,
@@ -675,19 +684,26 @@ impl TensorDescriptor {
675684
alignment,
676685
)?
677686
};
678-
Ok(Self { desc })
687+
Ok(Self { desc, stream })
679688
}
680689

681690
/// Returns the underlying descriptor pointer.
682691
pub fn desc(&self) -> sys::cutensorTensorDescriptor_t {
683692
self.desc
684693
}
694+
695+
/// Returns the [`Arc<CudaStream>`] retained by this descriptor.
696+
pub fn stream(&self) -> &Arc<CudaStream> {
697+
&self.stream
698+
}
685699
}
686700

687701
impl Drop for TensorDescriptor {
688702
fn drop(&mut self) {
689703
let desc = std::mem::replace(&mut self.desc, std::ptr::null_mut());
690704
if !desc.is_null() {
705+
let ctx = self.stream.context();
706+
ctx.record_err(ctx.bind_to_thread());
691707
unsafe { result::destroy_tensor_descriptor(desc) }.unwrap();
692708
}
693709
}
@@ -698,12 +714,17 @@ impl Drop for TensorDescriptor {
698714
///
699715
/// Created by [`OperationDescriptor::new_contraction`] or
700716
/// [`OperationDescriptor::new_reduction`]. Automatically destroyed on drop.
717+
///
718+
/// Holds an [`Arc<CudaStream>`] cloned from the originating [`CuTensor`] so
719+
/// the CUDA context outlives the descriptor.
701720
pub struct OperationDescriptor {
702721
pub(crate) desc: sys::cutensorOperationDescriptor_t,
722+
pub(crate) stream: Arc<CudaStream>,
703723
}
704724

705725
impl OperationDescriptor {
706-
/// Creates a contraction operation descriptor.
726+
/// Creates a contraction operation descriptor. The handle's stream is
727+
/// retained so the CUDA context outlives this descriptor.
707728
///
708729
/// # Safety
709730
///
@@ -725,6 +746,9 @@ impl OperationDescriptor {
725746
mode_d: &[i32],
726747
compute_desc: sys::cutensorComputeDescriptor_t,
727748
) -> Result<Self, CutensorError> {
749+
let stream = handle.stream.clone();
750+
let ctx = stream.context();
751+
ctx.record_err(ctx.bind_to_thread());
728752
let desc = result::create_contraction(
729753
handle.handle,
730754
desc_a.desc,
@@ -740,10 +764,11 @@ impl OperationDescriptor {
740764
mode_d.as_ptr(),
741765
compute_desc,
742766
)?;
743-
Ok(Self { desc })
767+
Ok(Self { desc, stream })
744768
}
745769

746-
/// Creates a reduction operation descriptor.
770+
/// Creates a reduction operation descriptor. The handle's stream is
771+
/// retained so the CUDA context outlives this descriptor.
747772
///
748773
/// # Safety
749774
///
@@ -763,6 +788,9 @@ impl OperationDescriptor {
763788
op_reduce: Operator,
764789
compute_desc: sys::cutensorComputeDescriptor_t,
765790
) -> Result<Self, CutensorError> {
791+
let stream = handle.stream.clone();
792+
let ctx = stream.context();
793+
ctx.record_err(ctx.bind_to_thread());
766794
let desc = result::create_reduction(
767795
handle.handle,
768796
desc_a.desc,
@@ -776,40 +804,55 @@ impl OperationDescriptor {
776804
op_reduce.into(),
777805
compute_desc,
778806
)?;
779-
Ok(Self { desc })
807+
Ok(Self { desc, stream })
780808
}
781809

782810
/// Returns the underlying descriptor pointer.
783811
pub fn desc(&self) -> sys::cutensorOperationDescriptor_t {
784812
self.desc
785813
}
814+
815+
/// Returns the [`Arc<CudaStream>`] retained by this descriptor.
816+
pub fn stream(&self) -> &Arc<CudaStream> {
817+
&self.stream
818+
}
786819
}
787820

788821
impl Drop for OperationDescriptor {
789822
fn drop(&mut self) {
790823
let desc = std::mem::replace(&mut self.desc, std::ptr::null_mut());
791824
if !desc.is_null() {
825+
let ctx = self.stream.context();
826+
ctx.record_err(ctx.bind_to_thread());
792827
unsafe { result::destroy_operation_descriptor(desc) }.unwrap();
793828
}
794829
}
795830
}
796831

797832
/// RAII wrapper around a cuTENSOR plan preference
798833
/// ([sys::cutensorPlanPreference_t]).
834+
///
835+
/// Holds an [`Arc<CudaStream>`] cloned from the originating [`CuTensor`] so
836+
/// the CUDA context outlives the plan preference.
799837
pub struct PlanPreference {
800838
pub(crate) pref: sys::cutensorPlanPreference_t,
839+
pub(crate) stream: Arc<CudaStream>,
801840
}
802841

803842
impl PlanPreference {
804-
/// Creates a new plan preference.
843+
/// Creates a new plan preference. The handle's stream is retained so
844+
/// the CUDA context outlives this plan preference.
805845
pub fn new(
806846
handle: &CuTensor,
807847
algo: Algorithm,
808848
jit_mode: JitMode,
809849
) -> Result<Self, CutensorError> {
850+
let stream = handle.stream.clone();
851+
let ctx = stream.context();
852+
ctx.record_err(ctx.bind_to_thread());
810853
let pref =
811854
unsafe { result::create_plan_preference(handle.handle, algo.into(), jit_mode.into())? };
812-
Ok(Self { pref })
855+
Ok(Self { pref, stream })
813856
}
814857

815858
/// Estimates the workspace size for a given operation.
@@ -833,45 +876,67 @@ impl PlanPreference {
833876
pub fn pref(&self) -> sys::cutensorPlanPreference_t {
834877
self.pref
835878
}
879+
880+
/// Returns the [`Arc<CudaStream>`] retained by this plan preference.
881+
pub fn stream(&self) -> &Arc<CudaStream> {
882+
&self.stream
883+
}
836884
}
837885

838886
impl Drop for PlanPreference {
839887
fn drop(&mut self) {
840888
let pref = std::mem::replace(&mut self.pref, std::ptr::null_mut());
841889
if !pref.is_null() {
890+
let ctx = self.stream.context();
891+
ctx.record_err(ctx.bind_to_thread());
842892
unsafe { result::destroy_plan_preference(pref) }.unwrap();
843893
}
844894
}
845895
}
846896

847897
/// RAII wrapper around a cuTENSOR execution plan ([sys::cutensorPlan_t]).
898+
///
899+
/// Holds an [`Arc<CudaStream>`] cloned from the originating [`CuTensor`] so
900+
/// the CUDA context outlives the plan.
848901
pub struct ContractionPlan {
849902
pub(crate) plan: sys::cutensorPlan_t,
903+
pub(crate) stream: Arc<CudaStream>,
850904
}
851905

852906
impl ContractionPlan {
853-
/// Creates a new execution plan.
907+
/// Creates a new execution plan. The handle's stream is retained so the
908+
/// CUDA context outlives this plan.
854909
pub fn new(
855910
handle: &CuTensor,
856911
op_desc: &OperationDescriptor,
857912
pref: &PlanPreference,
858913
workspace_size: u64,
859914
) -> Result<Self, CutensorError> {
915+
let stream = handle.stream.clone();
916+
let ctx = stream.context();
917+
ctx.record_err(ctx.bind_to_thread());
860918
let plan =
861919
unsafe { result::create_plan(handle.handle, op_desc.desc, pref.pref, workspace_size)? };
862-
Ok(Self { plan })
920+
Ok(Self { plan, stream })
863921
}
864922

865923
/// Returns the underlying plan pointer.
866924
pub fn plan(&self) -> sys::cutensorPlan_t {
867925
self.plan
868926
}
927+
928+
/// Returns the [`Arc<CudaStream>`] retained by this plan.
929+
pub fn stream(&self) -> &Arc<CudaStream> {
930+
&self.stream
931+
}
869932
}
870933

871934
impl Drop for ContractionPlan {
872935
fn drop(&mut self) {
873936
let plan = std::mem::replace(&mut self.plan, std::ptr::null_mut());
874937
if !plan.is_null() {
938+
let ctx = self.stream.context();
939+
ctx.record_err(ctx.bind_to_thread());
875940
unsafe { result::destroy_plan(plan) }.unwrap();
876941
}
877942
}

0 commit comments

Comments
 (0)