Skip to content

Commit

Permalink
Merge pull request #1117 from NVIDIA/correct-domain-of-on
Browse files Browse the repository at this point in the history
`on` should use the input scheduler's domain
  • Loading branch information
ericniebler authored Oct 20, 2023
2 parents 9d84acb + ed39902 commit 089ddcb
Showing 1 changed file with 38 additions and 27 deletions.
65 changes: 38 additions & 27 deletions include/stdexec/execution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3098,7 +3098,8 @@ namespace stdexec {
completion_signatures<
__minvoke< __remove<void, __qf<set_value_t>>, std::invoke_result_t<_Fun, _Args...>>>;

struct __get_env_fn {
struct __default_get_env_fn {
// BUGBUG This should hide all but the forwarding queries
template <class _Sender>
env_of_t<_Sender> operator()(__ignore, __ignore, const _Sender& __child) const noexcept {
return get_env(__child);
Expand All @@ -3112,11 +3113,11 @@ namespace stdexec {
};

template <class _Tag>
struct __default_get_env {
struct __with_default_get_env {
template <sender_expr_for<_Tag> _Sender>
static auto get_env(const _Sender& __sndr) noexcept
-> __call_result_t<apply_sender_t, const _Sender&, __get_env_fn> {
return apply_sender(__sndr, __get_env_fn());
-> __call_result_t<apply_sender_t, const _Sender&, __default_get_env_fn> {
return apply_sender(__sndr, __default_get_env_fn());
}
};

Expand Down Expand Up @@ -3212,7 +3213,7 @@ namespace stdexec {
};

////////////////////////////////////////////////////////////////////////////////////////////////
struct then_t : __default_get_env<then_t> {
struct then_t : __with_default_get_env<then_t> {
template <sender _Sender, __movable_value _Fun>
auto operator()(_Sender&& __sndr, _Fun __fun) const {
auto __domain = __get_sender_domain((_Sender&&) __sndr);
Expand Down Expand Up @@ -3353,7 +3354,7 @@ namespace stdexec {
};

////////////////////////////////////////////////////////////////////////////////////////////////
struct upon_error_t : __default_get_env<upon_error_t> {
struct upon_error_t : __with_default_get_env<upon_error_t> {
template <sender _Sender, __movable_value _Fun>
auto operator()(_Sender&& __sndr, _Fun __fun) const {
auto __domain = __get_sender_domain((_Sender&&) __sndr, set_error);
Expand Down Expand Up @@ -3495,7 +3496,7 @@ namespace stdexec {
};

////////////////////////////////////////////////////////////////////////////////////////////////
struct upon_stopped_t : __default_get_env<upon_stopped_t> {
struct upon_stopped_t : __with_default_get_env<upon_stopped_t> {
template <sender _Sender, __movable_value _Fun>
requires __callable<_Fun>
auto operator()(_Sender&& __sndr, _Fun __fun) const {
Expand Down Expand Up @@ -3688,7 +3689,7 @@ namespace stdexec {
}
};

struct bulk_t : __default_get_env<bulk_t> {
struct bulk_t : __with_default_get_env<bulk_t> {
template <sender _Sender, integral _Shape, __movable_value _Fun>
STDEXEC_ATTRIBUTE((host, device))
auto operator()(_Sender&& __sndr, _Shape __shape, _Fun __fun) const {
Expand Down Expand Up @@ -4742,7 +4743,7 @@ namespace stdexec {
};

template <class _LetTag, class _SetTag>
struct __let_xxx_t : __default_get_env<_LetTag> {
struct __let_xxx_t : __with_default_get_env<_LetTag> {
using _Sender = __1;
using _Function = __0;
using __legacy_customizations_t = __types<
Expand Down Expand Up @@ -5561,6 +5562,26 @@ namespace stdexec {
};
}

template <class _Env>
auto __make_transform_fn(const _Env& __env) {
return [&]<class _Scheduler, class... _Values>(_Scheduler&& __sched, _Values&&... __vals) {
auto __domain = __get_env_domain(__env);
return stdexec::transform_sender(
__domain,
transfer(
stdexec::transform_sender(__domain, just((_Values&&) __vals...), __env),
(_Scheduler&&) __sched),
__env);
};
}

template <class _Env>
auto __transform_sender_fn(const _Env& __env) {
return [&]<class _Data>(__ignore, _Data&& __data) {
return __apply(__make_transform_fn(__env), (_Data&&) __data);
};
}

struct transfer_just_t {
using _Data = __0;
using __legacy_customizations_t = //
Expand All @@ -5582,17 +5603,7 @@ namespace stdexec {

template <class _Sender, class _Env>
static auto transform_sender(_Sender&& __sndr, const _Env& __env) {
return __apply(
[&]<class _Scheduler, class... _Values>(_Scheduler&& __sched, _Values&&... __vals) {
auto __domain = __get_env_domain(__env);
return stdexec::transform_sender(
__domain,
transfer(
stdexec::transform_sender(__domain, just((_Values&&) __vals...), __env),
(_Scheduler&&) __sched),
__env);
},
apply_sender((_Sender&&) __sndr, __detail::__get_data()));
return apply_sender((_Sender&&) __sndr, __transform_sender_fn(__env));
}
};
} // namespace __transfer_just
Expand Down Expand Up @@ -5636,7 +5647,7 @@ namespace stdexec {
}
};

struct __write_t : __default_get_env<__write_t> {
struct __write_t : __with_default_get_env<__write_t> {
template <sender _Sender, class... _Envs>
auto operator()(_Sender&& __sndr, _Envs... __envs) const {
auto __domain = __get_sender_domain(__sndr);
Expand Down Expand Up @@ -5854,7 +5865,7 @@ namespace stdexec {

template <scheduler _Scheduler, sender _Sender>
auto operator()(_Scheduler&& __sched, _Sender&& __sndr) const {
auto __domain = __get_sender_domain((_Sender&&) __sndr);
auto __domain = query_or(get_domain, __sched, default_domain());
return stdexec::transform_sender(
__domain, make_sender_expr<on_t>((_Scheduler&&) __sched, (_Sender&&) __sndr));
}
Expand All @@ -5865,8 +5876,8 @@ namespace stdexec {

template <sender_expr_for<on_t> _Sender>
static auto get_env(const _Sender& __sndr) noexcept
-> __call_result_t<apply_sender_t, const _Sender&, __get_env_fn> {
return apply_sender(__sndr, __get_env_fn());
-> __call_result_t<apply_sender_t, const _Sender&, __default_get_env_fn> {
return apply_sender(__sndr, __default_get_env_fn());
}

template <class _Sender, class _Env>
Expand Down Expand Up @@ -5960,7 +5971,7 @@ namespace stdexec {
__meval<__variant_completions, __variant_t<_Sender, _Env>>,
__mconst<completion_signatures<>>>;

struct into_variant_t : __default_get_env<into_variant_t> {
struct into_variant_t : __with_default_get_env<into_variant_t> {
template <sender _Sender>
auto operator()(_Sender&& __sndr) const {
auto __domain = __get_sender_domain(__sndr);
Expand Down Expand Up @@ -6832,7 +6843,7 @@ namespace stdexec {
};

struct on_t
: __default_get_env<on_t>
: __with_default_get_env<on_t>
, __no_scheduler_in_environment {
template <scheduler _Scheduler, sender _Sender>
auto operator()(_Scheduler&& __sched, _Sender&& __sndr) const {
Expand Down Expand Up @@ -6871,7 +6882,7 @@ namespace stdexec {
__continue_on_data(_Scheduler, _Closure) -> __continue_on_data<_Scheduler, _Closure>;

struct continue_on_t
: __default_get_env<continue_on_t>
: __with_default_get_env<continue_on_t>
, __no_scheduler_in_environment {
template <sender _Sender, scheduler _Scheduler, __sender_adaptor_closure_for<_Sender> _Closure>
auto operator()(_Sender&& __sndr, _Scheduler&& __sched, _Closure&& __clsur) const {
Expand Down

0 comments on commit 089ddcb

Please sign in to comment.