@@ -639,16 +639,22 @@ impl Drop for CuTensor {
639639/// ([sys::cutensorTensorDescriptor_t]).
640640///
641641/// Automatically destroyed on drop.
642+ ///
643+ /// Holds an [`Arc<CudaStream>`] cloned from the originating [`CuTensor`] so
644+ /// the CUDA context outlives the descriptor. Without this, dropping the
645+ /// [`CuTensor`] (and its last stream reference) before the descriptor would
646+ /// leave [`Drop`] calling into a torn-down context.
642647pub struct TensorDescriptor {
643648 pub ( crate ) desc : sys:: cutensorTensorDescriptor_t ,
649+ pub ( crate ) stream : Arc < CudaStream > ,
644650}
645651
646652impl TensorDescriptor {
647653 /// Creates a new tensor descriptor.
648654 ///
649655 /// # Arguments
650656 ///
651- /// * `handle` - The cuTENSOR handle
657+ /// * `handle` - The cuTENSOR handle (its stream is retained)
652658 /// * `extent` - Size of each dimension
653659 /// * `stride` - Stride (in elements) of each dimension
654660 /// * `data_type` - Element data type
@@ -665,6 +671,9 @@ impl TensorDescriptor {
665671 stride. len( ) ,
666672 "extent and stride must have the same length"
667673 ) ;
674+ let stream = handle. stream . clone ( ) ;
675+ let ctx = stream. context ( ) ;
676+ ctx. record_err ( ctx. bind_to_thread ( ) ) ;
668677 let desc = unsafe {
669678 result:: create_tensor_descriptor (
670679 handle. handle ,
@@ -675,19 +684,26 @@ impl TensorDescriptor {
675684 alignment,
676685 ) ?
677686 } ;
678- Ok ( Self { desc } )
687+ Ok ( Self { desc, stream } )
679688 }
680689
681690 /// Returns the underlying descriptor pointer.
682691 pub fn desc ( & self ) -> sys:: cutensorTensorDescriptor_t {
683692 self . desc
684693 }
694+
695+ /// Returns the [`Arc<CudaStream>`] retained by this descriptor.
696+ pub fn stream ( & self ) -> & Arc < CudaStream > {
697+ & self . stream
698+ }
685699}
686700
687701impl Drop for TensorDescriptor {
688702 fn drop ( & mut self ) {
689703 let desc = std:: mem:: replace ( & mut self . desc , std:: ptr:: null_mut ( ) ) ;
690704 if !desc. is_null ( ) {
705+ let ctx = self . stream . context ( ) ;
706+ ctx. record_err ( ctx. bind_to_thread ( ) ) ;
691707 unsafe { result:: destroy_tensor_descriptor ( desc) } . unwrap ( ) ;
692708 }
693709 }
@@ -698,12 +714,17 @@ impl Drop for TensorDescriptor {
698714///
699715/// Created by [`OperationDescriptor::new_contraction`] or
700716/// [`OperationDescriptor::new_reduction`]. Automatically destroyed on drop.
717+ ///
718+ /// Holds an [`Arc<CudaStream>`] cloned from the originating [`CuTensor`] so
719+ /// the CUDA context outlives the descriptor.
701720pub struct OperationDescriptor {
702721 pub ( crate ) desc : sys:: cutensorOperationDescriptor_t ,
722+ pub ( crate ) stream : Arc < CudaStream > ,
703723}
704724
705725impl OperationDescriptor {
706- /// Creates a contraction operation descriptor.
726+ /// Creates a contraction operation descriptor. The handle's stream is
727+ /// retained so the CUDA context outlives this descriptor.
707728 ///
708729 /// # Safety
709730 ///
@@ -725,6 +746,9 @@ impl OperationDescriptor {
725746 mode_d : & [ i32 ] ,
726747 compute_desc : sys:: cutensorComputeDescriptor_t ,
727748 ) -> Result < Self , CutensorError > {
749+ let stream = handle. stream . clone ( ) ;
750+ let ctx = stream. context ( ) ;
751+ ctx. record_err ( ctx. bind_to_thread ( ) ) ;
728752 let desc = result:: create_contraction (
729753 handle. handle ,
730754 desc_a. desc ,
@@ -740,10 +764,11 @@ impl OperationDescriptor {
740764 mode_d. as_ptr ( ) ,
741765 compute_desc,
742766 ) ?;
743- Ok ( Self { desc } )
767+ Ok ( Self { desc, stream } )
744768 }
745769
746- /// Creates a reduction operation descriptor.
770+ /// Creates a reduction operation descriptor. The handle's stream is
771+ /// retained so the CUDA context outlives this descriptor.
747772 ///
748773 /// # Safety
749774 ///
@@ -763,6 +788,9 @@ impl OperationDescriptor {
763788 op_reduce : Operator ,
764789 compute_desc : sys:: cutensorComputeDescriptor_t ,
765790 ) -> Result < Self , CutensorError > {
791+ let stream = handle. stream . clone ( ) ;
792+ let ctx = stream. context ( ) ;
793+ ctx. record_err ( ctx. bind_to_thread ( ) ) ;
766794 let desc = result:: create_reduction (
767795 handle. handle ,
768796 desc_a. desc ,
@@ -776,40 +804,55 @@ impl OperationDescriptor {
776804 op_reduce. into ( ) ,
777805 compute_desc,
778806 ) ?;
779- Ok ( Self { desc } )
807+ Ok ( Self { desc, stream } )
780808 }
781809
782810 /// Returns the underlying descriptor pointer.
783811 pub fn desc ( & self ) -> sys:: cutensorOperationDescriptor_t {
784812 self . desc
785813 }
814+
815+ /// Returns the [`Arc<CudaStream>`] retained by this descriptor.
816+ pub fn stream ( & self ) -> & Arc < CudaStream > {
817+ & self . stream
818+ }
786819}
787820
788821impl Drop for OperationDescriptor {
789822 fn drop ( & mut self ) {
790823 let desc = std:: mem:: replace ( & mut self . desc , std:: ptr:: null_mut ( ) ) ;
791824 if !desc. is_null ( ) {
825+ let ctx = self . stream . context ( ) ;
826+ ctx. record_err ( ctx. bind_to_thread ( ) ) ;
792827 unsafe { result:: destroy_operation_descriptor ( desc) } . unwrap ( ) ;
793828 }
794829 }
795830}
796831
797832/// RAII wrapper around a cuTENSOR plan preference
798833/// ([sys::cutensorPlanPreference_t]).
834+ ///
835+ /// Holds an [`Arc<CudaStream>`] cloned from the originating [`CuTensor`] so
836+ /// the CUDA context outlives the descriptor.
799837pub struct PlanPreference {
800838 pub ( crate ) pref : sys:: cutensorPlanPreference_t ,
839+ pub ( crate ) stream : Arc < CudaStream > ,
801840}
802841
803842impl PlanPreference {
804- /// Creates a new plan preference.
843+ /// Creates a new plan preference. The handle's stream is retained so
844+ /// the CUDA context outlives this descriptor.
805845 pub fn new (
806846 handle : & CuTensor ,
807847 algo : Algorithm ,
808848 jit_mode : JitMode ,
809849 ) -> Result < Self , CutensorError > {
850+ let stream = handle. stream . clone ( ) ;
851+ let ctx = stream. context ( ) ;
852+ ctx. record_err ( ctx. bind_to_thread ( ) ) ;
810853 let pref =
811854 unsafe { result:: create_plan_preference ( handle. handle , algo. into ( ) , jit_mode. into ( ) ) ? } ;
812- Ok ( Self { pref } )
855+ Ok ( Self { pref, stream } )
813856 }
814857
815858 /// Estimates the workspace size for a given operation.
@@ -833,45 +876,67 @@ impl PlanPreference {
833876 pub fn pref ( & self ) -> sys:: cutensorPlanPreference_t {
834877 self . pref
835878 }
879+
880+ /// Returns the [`Arc<CudaStream>`] retained by this descriptor.
881+ pub fn stream ( & self ) -> & Arc < CudaStream > {
882+ & self . stream
883+ }
836884}
837885
838886impl Drop for PlanPreference {
839887 fn drop ( & mut self ) {
840888 let pref = std:: mem:: replace ( & mut self . pref , std:: ptr:: null_mut ( ) ) ;
841889 if !pref. is_null ( ) {
890+ let ctx = self . stream . context ( ) ;
891+ ctx. record_err ( ctx. bind_to_thread ( ) ) ;
842892 unsafe { result:: destroy_plan_preference ( pref) } . unwrap ( ) ;
843893 }
844894 }
845895}
846896
847897/// RAII wrapper around a cuTENSOR execution plan ([sys::cutensorPlan_t]).
898+ ///
899+ /// Holds an [`Arc<CudaStream>`] cloned from the originating [`CuTensor`] so
900+ /// the CUDA context outlives the descriptor.
848901pub struct ContractionPlan {
849902 pub ( crate ) plan : sys:: cutensorPlan_t ,
903+ pub ( crate ) stream : Arc < CudaStream > ,
850904}
851905
852906impl ContractionPlan {
853- /// Creates a new execution plan.
907+ /// Creates a new execution plan. The handle's stream is retained so the
908+ /// CUDA context outlives this descriptor.
854909 pub fn new (
855910 handle : & CuTensor ,
856911 op_desc : & OperationDescriptor ,
857912 pref : & PlanPreference ,
858913 workspace_size : u64 ,
859914 ) -> Result < Self , CutensorError > {
915+ let stream = handle. stream . clone ( ) ;
916+ let ctx = stream. context ( ) ;
917+ ctx. record_err ( ctx. bind_to_thread ( ) ) ;
860918 let plan =
861919 unsafe { result:: create_plan ( handle. handle , op_desc. desc , pref. pref , workspace_size) ? } ;
862- Ok ( Self { plan } )
920+ Ok ( Self { plan, stream } )
863921 }
864922
865923 /// Returns the underlying plan pointer.
866924 pub fn plan ( & self ) -> sys:: cutensorPlan_t {
867925 self . plan
868926 }
927+
928+ /// Returns the [`Arc<CudaStream>`] retained by this descriptor.
929+ pub fn stream ( & self ) -> & Arc < CudaStream > {
930+ & self . stream
931+ }
869932}
870933
871934impl Drop for ContractionPlan {
872935 fn drop ( & mut self ) {
873936 let plan = std:: mem:: replace ( & mut self . plan , std:: ptr:: null_mut ( ) ) ;
874937 if !plan. is_null ( ) {
938+ let ctx = self . stream . context ( ) ;
939+ ctx. record_err ( ctx. bind_to_thread ( ) ) ;
875940 unsafe { result:: destroy_plan ( plan) } . unwrap ( ) ;
876941 }
877942 }
0 commit comments