@@ -488,27 +488,31 @@ namespace nv::execution
488488 template <class Env , class Variant >
489489 struct stream_enqueue_receiver
490490 {
491- Env* env_;
492- Variant* variant_;
493- queue::task_base* task_;
494- queue::producer producer_;
495-
496- public:
497491 using receiver_concept = STDEXEC::receiver_t ;
498492
493+ explicit stream_enqueue_receiver (Env const * env,
494+ Variant* variant,
495+ queue::task_base* task,
496+ queue::producer producer)
497+ : env_(env)
498+ , variant_(variant)
499+ , task_(task)
500+ , producer_(producer)
501+ {}
502+
499503 template <class ... Args>
500504 STDEXEC_ATTRIBUTE (host, device)
501505 void set_value (Args&&... args) noexcept
502506 {
503- variant_->template emplace <decayed_tuple_t <set_value_t , Args...>>(set_value_t (),
504- static_cast <Args&&>(
505- args)...);
507+ using tuple_t = decayed_tuple_t <set_value_t , Args...>;
508+ variant_->template emplace <tuple_t >(set_value_t (), static_cast <Args&&>(args)...);
506509 producer_ (task_);
507510 }
508511
509512 STDEXEC_ATTRIBUTE (host, device) void set_stopped () noexcept
510513 {
511- variant_->template emplace <decayed_tuple_t <set_stopped_t >>(set_stopped_t ());
514+ using tuple_t = decayed_tuple_t <set_stopped_t >;
515+ variant_->template emplace <tuple_t >(set_stopped_t ());
512516 producer_ (task_);
513517 }
514518
@@ -519,14 +523,13 @@ namespace nv::execution
519523 if constexpr (__decays_to<Error, std::exception_ptr>)
520524 {
521525 // What is `exception_ptr` but death pending
522- variant_-> template emplace < decayed_tuple_t <set_error_t , cudaError_t>>(STDEXEC::set_error,
523- cudaErrorUnknown);
526+ using tuple_t = decayed_tuple_t <set_error_t , cudaError_t>;
527+ variant_-> template emplace < tuple_t >(STDEXEC::set_error, cudaErrorUnknown);
524528 }
525529 else
526530 {
527- variant_->template emplace <decayed_tuple_t <set_error_t , Error>>(set_error_t (),
528- static_cast <Error&&>(
529- err));
531+ using tuple_t = decayed_tuple_t <set_error_t , Error>;
532+ variant_->template emplace <tuple_t >(set_error_t (), static_cast <Error&&>(err));
530533 }
531534 producer_ (task_);
532535 }
@@ -536,15 +539,11 @@ namespace nv::execution
536539 return *env_;
537540 }
538541
539- stream_enqueue_receiver (Env* env,
540- Variant* variant,
541- queue::task_base* task,
542- queue::producer producer)
543- : env_(env)
544- , variant_(variant)
545- , task_(task)
546- , producer_(producer)
547- {}
542+ private:
543+ Env const * env_;
544+ Variant* variant_;
545+ queue::task_base* task_;
546+ queue::producer producer_;
548547 };
549548
550549 template <class Receiver , class ... Args, class Tag >
@@ -558,16 +557,10 @@ namespace nv::execution
558557 template <class Receiver , class Variant >
559558 struct continuation_task : queue::task_base
560559 {
561- Receiver rcvr_;
562- Variant* variant_;
563- cudaStream_t stream_{};
564- std::pmr::memory_resource* pinned_resource_{};
565- cudaError_t status_{cudaSuccess};
566-
567- continuation_task (Receiver rcvr,
568- Variant* variant,
569- cudaStream_t stream,
570- std::pmr::memory_resource* pinned_resource) noexcept
560+ explicit continuation_task (Receiver rcvr,
561+ Variant* variant,
562+ cudaStream_t stream,
563+ std::pmr::memory_resource* pinned_resource) noexcept
571564 : rcvr_{rcvr}
572565 , variant_{variant}
573566 , stream_{stream}
@@ -606,6 +599,18 @@ namespace nv::execution
606599 status_ = STDEXEC_LOG_CUDA_API (cudaMemsetAsync (this ->atom_next_ , 0 , ptr_size, stream_));
607600 }
608601 }
602+
603+ cudaError_t status () const noexcept
604+ {
605+ return status_;
606+ }
607+
608+ private:
609+ Receiver rcvr_;
610+ Variant* variant_;
611+ cudaStream_t stream_{};
612+ std::pmr::memory_resource* pinned_resource_{};
613+ cudaError_t status_{cudaSuccess};
609614 };
610615
611616 template <class Env >
@@ -695,6 +700,7 @@ namespace nv::execution
695700 }
696701 }
697702
703+ [[nodiscard]]
698704 auto make_env () const noexcept -> env_t
699705 {
700706 return make_stream_env (get_env (rcvr_), get_stream_provider ());
@@ -738,10 +744,12 @@ namespace nv::execution
738744 stream_provider stream_provider_;
739745 };
740746
741- template <class OuterReceiver >
747+ template <class OpState , class Env = decltype (__declval<OpState&>().make_env()) >
742748 struct propagate_receiver : stream_receiver_base
743749 {
744- opstate_base<OuterReceiver>& opstate_;
750+ explicit propagate_receiver (OpState& opstate) noexcept
751+ : opstate_(opstate)
752+ {}
745753
746754 template <class ... Args>
747755 void set_value (Args&&... args) noexcept
@@ -760,10 +768,14 @@ namespace nv::execution
760768 opstate_.propagate_completion_signal (set_stopped_t ());
761769 }
762770
763- auto get_env () const noexcept -> decltype(auto )
771+ [[nodiscard]]
772+ auto get_env () const noexcept -> Env
764773 {
765774 return opstate_.make_env ();
766775 }
776+
777+ private:
778+ OpState& opstate_;
767779 };
768780
769781 template <class CvSender , class InnerReceiver , class OuterReceiver >
@@ -780,38 +792,6 @@ namespace nv::execution
780792 __if_c<stream_sender<CvSender, env_t >, InnerReceiver, stream_enqueue_receiver_t >;
781793 using inner_opstate_t = connect_result_t <CvSender, intermediate_receiver_t >;
782794
783- void start () & noexcept
784- {
785- started_.test_and_set (::cuda::std::memory_order::relaxed);
786-
787- if (this ->stream_provider_ .status_ != cudaSuccess)
788- {
789- // Couldn't allocate memory for opstate state, complete with error
790- this ->propagate_completion_signal (STDEXEC::set_error,
791- std::move (this ->stream_provider_ .status_ ));
792- return ;
793- }
794-
795- if constexpr (stream_receiver<InnerReceiver>)
796- {
797- if (InnerReceiver::memory_allocation_size ())
798- {
799- STDEXEC_TRY
800- {
801- this ->temp_storage_ = this ->ctx_ .managed_resource_ ->allocate (
802- InnerReceiver::memory_allocation_size ());
803- }
804- STDEXEC_CATCH_ALL
805- {
806- this ->propagate_completion_signal (STDEXEC::set_error, cudaErrorMemoryAllocation);
807- return ;
808- }
809- }
810- }
811-
812- STDEXEC::start (inner_op_);
813- }
814-
815795 template <class ReceiverProvider >
816796 requires stream_sender<CvSender, env_t >
817797 opstate (CvSender&& sender,
@@ -846,7 +826,7 @@ namespace nv::execution
846826 {
847827 if (this ->stream_provider_ .status_ == cudaSuccess)
848828 {
849- this ->stream_provider_ .status_ = task_->status_ ;
829+ this ->stream_provider_ .status_ = task_->status () ;
850830 }
851831 }
852832
@@ -870,6 +850,39 @@ namespace nv::execution
870850
871851 STDEXEC_IMMOVABLE (opstate);
872852
853+ void start () & noexcept
854+ {
855+ started_.test_and_set (::cuda::std::memory_order::relaxed);
856+
857+ if (this ->stream_provider_ .status_ != cudaSuccess)
858+ {
859+ // Couldn't allocate memory for opstate state, complete with error
860+ this ->propagate_completion_signal (STDEXEC::set_error,
861+ std::move (this ->stream_provider_ .status_ ));
862+ return ;
863+ }
864+
865+ if constexpr (stream_receiver<InnerReceiver>)
866+ {
867+ if (InnerReceiver::memory_allocation_size ())
868+ {
869+ STDEXEC_TRY
870+ {
871+ this ->temp_storage_ = this ->ctx_ .managed_resource_ ->allocate (
872+ InnerReceiver::memory_allocation_size ());
873+ }
874+ STDEXEC_CATCH_ALL
875+ {
876+ this ->propagate_completion_signal (STDEXEC::set_error, cudaErrorMemoryAllocation);
877+ return ;
878+ }
879+ }
880+ }
881+
882+ STDEXEC::start (inner_op_);
883+ }
884+
885+ private:
873886 host_ptr_t <variant_t > storage_;
874887 task_t * task_{};
875888 ::cuda::std::atomic_flag started_{};
@@ -880,60 +893,61 @@ namespace nv::execution
880893 template <class CvSender , class OuterReceiver >
881894 requires stream_receiver<OuterReceiver>
882895 using exit_opstate_t =
883- _strm::opstate<CvSender, propagate_receiver<OuterReceiver>, OuterReceiver>;
896+ _strm::opstate<CvSender, propagate_receiver<opstate_base< OuterReceiver> >, OuterReceiver>;
884897
885- template <class Sender , class OuterReceiver >
886- auto exit_opstate (Sender && sndr, OuterReceiver rcvr, context ctx) noexcept
887- -> exit_opstate_t<Sender , OuterReceiver>
898+ template <class CvSender , class OuterReceiver >
899+ auto exit_opstate (CvSender && sndr, OuterReceiver rcvr, context ctx) noexcept
900+ -> exit_opstate_t<CvSender , OuterReceiver>
888901 {
889- return exit_opstate_t <Sender , OuterReceiver>(
890- static_cast <Sender &&>(sndr),
902+ return exit_opstate_t <CvSender , OuterReceiver>(
903+ static_cast <CvSender &&>(sndr),
891904 static_cast <OuterReceiver&&>(rcvr),
892- [](opstate_base<OuterReceiver>& op) -> propagate_receiver<OuterReceiver>
893- { return propagate_receiver<OuterReceiver>{{}, op} ; },
905+ [](opstate_base<OuterReceiver>& op) noexcept
906+ { return propagate_receiver<opstate_base< OuterReceiver>>(op) ; },
894907 ctx);
895908 }
896909
897- template <class Sender , class E >
910+ template <class Sender , class Env >
898911 concept stream_completing_sender =
899912 sender<Sender>
900913 && gpu_stream_scheduler<
901- __result_of<get_completion_scheduler<set_value_t >, env_of_t <Sender>, E >,
902- E >;
914+ __result_of<get_completion_scheduler<set_value_t >, env_of_t <Sender>, Env >,
915+ Env >;
903916
904917 template <class InnerReceiverProvider , class OuterReceiver >
905918 using inner_receiver_t = __call_result_t <InnerReceiverProvider, opstate_base<OuterReceiver>&>;
906919
907920 template <class CvSender , class InnerReceiver , class OuterReceiver >
908921 using stream_opstate_t = _strm::opstate<CvSender, InnerReceiver, OuterReceiver>;
909922
910- template <class Sender , class OuterReceiver , class ReceiverProvider >
911- requires stream_completing_sender<Sender, env_of_t <OuterReceiver>>
912- auto
913- stream_opstate (Sender&& sndr, OuterReceiver&& out_receiver, ReceiverProvider receiver_provider)
914- -> stream_opstate_t<Sender, inner_receiver_t<ReceiverProvider, OuterReceiver>, OuterReceiver>
923+ template <class CvSender , class OuterReceiver , class ReceiverProvider >
924+ requires stream_completing_sender<CvSender, env_of_t <OuterReceiver>>
925+ auto stream_opstate (CvSender&& sndr,
926+ OuterReceiver&& out_receiver,
927+ ReceiverProvider receiver_provider)
928+ -> stream_opstate_t<CvSender, inner_receiver_t<ReceiverProvider, OuterReceiver>, OuterReceiver>
915929 {
916930 auto sch = get_completion_scheduler<set_value_t >(get_env (sndr), get_env (out_receiver));
917931 context ctx = sch.ctx_ ;
918932
919- return stream_opstate_t <Sender ,
933+ return stream_opstate_t <CvSender ,
920934 inner_receiver_t <ReceiverProvider, OuterReceiver>,
921- OuterReceiver>(static_cast <Sender &&>(sndr),
935+ OuterReceiver>(static_cast <CvSender &&>(sndr),
922936 static_cast <OuterReceiver&&>(out_receiver),
923937 receiver_provider,
924938 ctx);
925939 }
926940
927- template <class Sender , class OuterReceiver , class ReceiverProvider >
928- auto stream_opstate (Sender&& sndr,
941+ template <class CvSender , class OuterReceiver , class ReceiverProvider >
942+ auto stream_opstate (CvSender&& sndr,
929943 OuterReceiver&& out_receiver,
930944 ReceiverProvider receiver_provider,
931945 context ctx)
932- -> stream_opstate_t<Sender , inner_receiver_t<ReceiverProvider, OuterReceiver>, OuterReceiver>
946+ -> stream_opstate_t<CvSender , inner_receiver_t<ReceiverProvider, OuterReceiver>, OuterReceiver>
933947 {
934- return stream_opstate_t <Sender ,
948+ return stream_opstate_t <CvSender ,
935949 inner_receiver_t <ReceiverProvider, OuterReceiver>,
936- OuterReceiver>(static_cast <Sender &&>(sndr),
950+ OuterReceiver>(static_cast <CvSender &&>(sndr),
937951 static_cast <OuterReceiver&&>(out_receiver),
938952 receiver_provider,
939953 ctx);
0 commit comments