Skip to content

Commit b6f8a0e

Browse files
authored
Merge pull request #1927 from ericniebler/fix-secondary-env-of-nvexec-let-algorithms
change `nvexec::let_*` to tell the secondary sender where it is executing
2 parents b21a841 + 72e93c8 commit b6f8a0e

File tree

5 files changed: +303 additions, −226 deletions

include/nvexec/stream/common.cuh

Lines changed: 107 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -488,27 +488,31 @@ namespace nv::execution
488488
template <class Env, class Variant>
489489
struct stream_enqueue_receiver
490490
{
491-
Env* env_;
492-
Variant* variant_;
493-
queue::task_base* task_;
494-
queue::producer producer_;
495-
496-
public:
497491
using receiver_concept = STDEXEC::receiver_t;
498492

493+
explicit stream_enqueue_receiver(Env const * env,
494+
Variant* variant,
495+
queue::task_base* task,
496+
queue::producer producer)
497+
: env_(env)
498+
, variant_(variant)
499+
, task_(task)
500+
, producer_(producer)
501+
{}
502+
499503
template <class... Args>
500504
STDEXEC_ATTRIBUTE(host, device)
501505
void set_value(Args&&... args) noexcept
502506
{
503-
variant_->template emplace<decayed_tuple_t<set_value_t, Args...>>(set_value_t(),
504-
static_cast<Args&&>(
505-
args)...);
507+
using tuple_t = decayed_tuple_t<set_value_t, Args...>;
508+
variant_->template emplace<tuple_t>(set_value_t(), static_cast<Args&&>(args)...);
506509
producer_(task_);
507510
}
508511

509512
STDEXEC_ATTRIBUTE(host, device) void set_stopped() noexcept
510513
{
511-
variant_->template emplace<decayed_tuple_t<set_stopped_t>>(set_stopped_t());
514+
using tuple_t = decayed_tuple_t<set_stopped_t>;
515+
variant_->template emplace<tuple_t>(set_stopped_t());
512516
producer_(task_);
513517
}
514518

@@ -519,14 +523,13 @@ namespace nv::execution
519523
if constexpr (__decays_to<Error, std::exception_ptr>)
520524
{
521525
// What is `exception_ptr` but death pending
522-
variant_->template emplace<decayed_tuple_t<set_error_t, cudaError_t>>(STDEXEC::set_error,
523-
cudaErrorUnknown);
526+
using tuple_t = decayed_tuple_t<set_error_t, cudaError_t>;
527+
variant_->template emplace<tuple_t>(STDEXEC::set_error, cudaErrorUnknown);
524528
}
525529
else
526530
{
527-
variant_->template emplace<decayed_tuple_t<set_error_t, Error>>(set_error_t(),
528-
static_cast<Error&&>(
529-
err));
531+
using tuple_t = decayed_tuple_t<set_error_t, Error>;
532+
variant_->template emplace<tuple_t>(set_error_t(), static_cast<Error&&>(err));
530533
}
531534
producer_(task_);
532535
}
@@ -536,15 +539,11 @@ namespace nv::execution
536539
return *env_;
537540
}
538541

539-
stream_enqueue_receiver(Env* env,
540-
Variant* variant,
541-
queue::task_base* task,
542-
queue::producer producer)
543-
: env_(env)
544-
, variant_(variant)
545-
, task_(task)
546-
, producer_(producer)
547-
{}
542+
private:
543+
Env const * env_;
544+
Variant* variant_;
545+
queue::task_base* task_;
546+
queue::producer producer_;
548547
};
549548

550549
template <class Receiver, class... Args, class Tag>
@@ -558,16 +557,10 @@ namespace nv::execution
558557
template <class Receiver, class Variant>
559558
struct continuation_task : queue::task_base
560559
{
561-
Receiver rcvr_;
562-
Variant* variant_;
563-
cudaStream_t stream_{};
564-
std::pmr::memory_resource* pinned_resource_{};
565-
cudaError_t status_{cudaSuccess};
566-
567-
continuation_task(Receiver rcvr,
568-
Variant* variant,
569-
cudaStream_t stream,
570-
std::pmr::memory_resource* pinned_resource) noexcept
560+
explicit continuation_task(Receiver rcvr,
561+
Variant* variant,
562+
cudaStream_t stream,
563+
std::pmr::memory_resource* pinned_resource) noexcept
571564
: rcvr_{rcvr}
572565
, variant_{variant}
573566
, stream_{stream}
@@ -606,6 +599,18 @@ namespace nv::execution
606599
status_ = STDEXEC_LOG_CUDA_API(cudaMemsetAsync(this->atom_next_, 0, ptr_size, stream_));
607600
}
608601
}
602+
603+
cudaError_t status() const noexcept
604+
{
605+
return status_;
606+
}
607+
608+
private:
609+
Receiver rcvr_;
610+
Variant* variant_;
611+
cudaStream_t stream_{};
612+
std::pmr::memory_resource* pinned_resource_{};
613+
cudaError_t status_{cudaSuccess};
609614
};
610615

611616
template <class Env>
@@ -695,6 +700,7 @@ namespace nv::execution
695700
}
696701
}
697702

703+
[[nodiscard]]
698704
auto make_env() const noexcept -> env_t
699705
{
700706
return make_stream_env(get_env(rcvr_), get_stream_provider());
@@ -738,10 +744,12 @@ namespace nv::execution
738744
stream_provider stream_provider_;
739745
};
740746

741-
template <class OuterReceiver>
747+
template <class OpState, class Env = decltype(__declval<OpState&>().make_env())>
742748
struct propagate_receiver : stream_receiver_base
743749
{
744-
opstate_base<OuterReceiver>& opstate_;
750+
explicit propagate_receiver(OpState& opstate) noexcept
751+
: opstate_(opstate)
752+
{}
745753

746754
template <class... Args>
747755
void set_value(Args&&... args) noexcept
@@ -760,10 +768,14 @@ namespace nv::execution
760768
opstate_.propagate_completion_signal(set_stopped_t());
761769
}
762770

763-
auto get_env() const noexcept -> decltype(auto)
771+
[[nodiscard]]
772+
auto get_env() const noexcept -> Env
764773
{
765774
return opstate_.make_env();
766775
}
776+
777+
private:
778+
OpState& opstate_;
767779
};
768780

769781
template <class CvSender, class InnerReceiver, class OuterReceiver>
@@ -780,38 +792,6 @@ namespace nv::execution
780792
__if_c<stream_sender<CvSender, env_t>, InnerReceiver, stream_enqueue_receiver_t>;
781793
using inner_opstate_t = connect_result_t<CvSender, intermediate_receiver_t>;
782794

783-
void start() & noexcept
784-
{
785-
started_.test_and_set(::cuda::std::memory_order::relaxed);
786-
787-
if (this->stream_provider_.status_ != cudaSuccess)
788-
{
789-
// Couldn't allocate memory for opstate state, complete with error
790-
this->propagate_completion_signal(STDEXEC::set_error,
791-
std::move(this->stream_provider_.status_));
792-
return;
793-
}
794-
795-
if constexpr (stream_receiver<InnerReceiver>)
796-
{
797-
if (InnerReceiver::memory_allocation_size())
798-
{
799-
STDEXEC_TRY
800-
{
801-
this->temp_storage_ = this->ctx_.managed_resource_->allocate(
802-
InnerReceiver::memory_allocation_size());
803-
}
804-
STDEXEC_CATCH_ALL
805-
{
806-
this->propagate_completion_signal(STDEXEC::set_error, cudaErrorMemoryAllocation);
807-
return;
808-
}
809-
}
810-
}
811-
812-
STDEXEC::start(inner_op_);
813-
}
814-
815795
template <class ReceiverProvider>
816796
requires stream_sender<CvSender, env_t>
817797
opstate(CvSender&& sender,
@@ -846,7 +826,7 @@ namespace nv::execution
846826
{
847827
if (this->stream_provider_.status_ == cudaSuccess)
848828
{
849-
this->stream_provider_.status_ = task_->status_;
829+
this->stream_provider_.status_ = task_->status();
850830
}
851831
}
852832

@@ -870,6 +850,39 @@ namespace nv::execution
870850

871851
STDEXEC_IMMOVABLE(opstate);
872852

853+
void start() & noexcept
854+
{
855+
started_.test_and_set(::cuda::std::memory_order::relaxed);
856+
857+
if (this->stream_provider_.status_ != cudaSuccess)
858+
{
859+
// Couldn't allocate memory for opstate state, complete with error
860+
this->propagate_completion_signal(STDEXEC::set_error,
861+
std::move(this->stream_provider_.status_));
862+
return;
863+
}
864+
865+
if constexpr (stream_receiver<InnerReceiver>)
866+
{
867+
if (InnerReceiver::memory_allocation_size())
868+
{
869+
STDEXEC_TRY
870+
{
871+
this->temp_storage_ = this->ctx_.managed_resource_->allocate(
872+
InnerReceiver::memory_allocation_size());
873+
}
874+
STDEXEC_CATCH_ALL
875+
{
876+
this->propagate_completion_signal(STDEXEC::set_error, cudaErrorMemoryAllocation);
877+
return;
878+
}
879+
}
880+
}
881+
882+
STDEXEC::start(inner_op_);
883+
}
884+
885+
private:
873886
host_ptr_t<variant_t> storage_;
874887
task_t* task_{};
875888
::cuda::std::atomic_flag started_{};
@@ -880,60 +893,61 @@ namespace nv::execution
880893
template <class CvSender, class OuterReceiver>
881894
requires stream_receiver<OuterReceiver>
882895
using exit_opstate_t =
883-
_strm::opstate<CvSender, propagate_receiver<OuterReceiver>, OuterReceiver>;
896+
_strm::opstate<CvSender, propagate_receiver<opstate_base<OuterReceiver>>, OuterReceiver>;
884897

885-
template <class Sender, class OuterReceiver>
886-
auto exit_opstate(Sender&& sndr, OuterReceiver rcvr, context ctx) noexcept
887-
-> exit_opstate_t<Sender, OuterReceiver>
898+
template <class CvSender, class OuterReceiver>
899+
auto exit_opstate(CvSender&& sndr, OuterReceiver rcvr, context ctx) noexcept
900+
-> exit_opstate_t<CvSender, OuterReceiver>
888901
{
889-
return exit_opstate_t<Sender, OuterReceiver>(
890-
static_cast<Sender&&>(sndr),
902+
return exit_opstate_t<CvSender, OuterReceiver>(
903+
static_cast<CvSender&&>(sndr),
891904
static_cast<OuterReceiver&&>(rcvr),
892-
[](opstate_base<OuterReceiver>& op) -> propagate_receiver<OuterReceiver>
893-
{ return propagate_receiver<OuterReceiver>{{}, op}; },
905+
[](opstate_base<OuterReceiver>& op) noexcept
906+
{ return propagate_receiver<opstate_base<OuterReceiver>>(op); },
894907
ctx);
895908
}
896909

897-
template <class Sender, class E>
910+
template <class Sender, class Env>
898911
concept stream_completing_sender =
899912
sender<Sender>
900913
&& gpu_stream_scheduler<
901-
__result_of<get_completion_scheduler<set_value_t>, env_of_t<Sender>, E>,
902-
E>;
914+
__result_of<get_completion_scheduler<set_value_t>, env_of_t<Sender>, Env>,
915+
Env>;
903916

904917
template <class InnerReceiverProvider, class OuterReceiver>
905918
using inner_receiver_t = __call_result_t<InnerReceiverProvider, opstate_base<OuterReceiver>&>;
906919

907920
template <class CvSender, class InnerReceiver, class OuterReceiver>
908921
using stream_opstate_t = _strm::opstate<CvSender, InnerReceiver, OuterReceiver>;
909922

910-
template <class Sender, class OuterReceiver, class ReceiverProvider>
911-
requires stream_completing_sender<Sender, env_of_t<OuterReceiver>>
912-
auto
913-
stream_opstate(Sender&& sndr, OuterReceiver&& out_receiver, ReceiverProvider receiver_provider)
914-
-> stream_opstate_t<Sender, inner_receiver_t<ReceiverProvider, OuterReceiver>, OuterReceiver>
923+
template <class CvSender, class OuterReceiver, class ReceiverProvider>
924+
requires stream_completing_sender<CvSender, env_of_t<OuterReceiver>>
925+
auto stream_opstate(CvSender&& sndr,
926+
OuterReceiver&& out_receiver,
927+
ReceiverProvider receiver_provider)
928+
-> stream_opstate_t<CvSender, inner_receiver_t<ReceiverProvider, OuterReceiver>, OuterReceiver>
915929
{
916930
auto sch = get_completion_scheduler<set_value_t>(get_env(sndr), get_env(out_receiver));
917931
context ctx = sch.ctx_;
918932

919-
return stream_opstate_t<Sender,
933+
return stream_opstate_t<CvSender,
920934
inner_receiver_t<ReceiverProvider, OuterReceiver>,
921-
OuterReceiver>(static_cast<Sender&&>(sndr),
935+
OuterReceiver>(static_cast<CvSender&&>(sndr),
922936
static_cast<OuterReceiver&&>(out_receiver),
923937
receiver_provider,
924938
ctx);
925939
}
926940

927-
template <class Sender, class OuterReceiver, class ReceiverProvider>
928-
auto stream_opstate(Sender&& sndr,
941+
template <class CvSender, class OuterReceiver, class ReceiverProvider>
942+
auto stream_opstate(CvSender&& sndr,
929943
OuterReceiver&& out_receiver,
930944
ReceiverProvider receiver_provider,
931945
context ctx)
932-
-> stream_opstate_t<Sender, inner_receiver_t<ReceiverProvider, OuterReceiver>, OuterReceiver>
946+
-> stream_opstate_t<CvSender, inner_receiver_t<ReceiverProvider, OuterReceiver>, OuterReceiver>
933947
{
934-
return stream_opstate_t<Sender,
948+
return stream_opstate_t<CvSender,
935949
inner_receiver_t<ReceiverProvider, OuterReceiver>,
936-
OuterReceiver>(static_cast<Sender&&>(sndr),
950+
OuterReceiver>(static_cast<CvSender&&>(sndr),
937951
static_cast<OuterReceiver&&>(out_receiver),
938952
receiver_provider,
939953
ctx);

0 commit comments

Comments (0)