Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/beman/execution/detail/bulk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ struct bulk_transform_signatures<IsChunked, F, Shape, ::beman::execution::comple
};

template <bool IsChunked>
struct bulk_algo_t : ::beman::execution::sender_adaptor_closure<bulk_algo_t<IsChunked>> {
struct bulk_algo_t {
template <typename Policy, typename Shape, typename F>
requires(::beman::execution::is_execution_policy_v<::std::remove_cvref_t<Policy>> && ::std::integral<Shape> &&
::std::copy_constructible<::std::decay_t<F>>)
Expand Down Expand Up @@ -190,7 +190,7 @@ using bulk_chunked_t = ::beman::execution::detail::bulk_algo_t<true>;

using bulk_unchunked_t = ::beman::execution::detail::bulk_algo_t<false>;

struct bulk_t : ::beman::execution::sender_adaptor_closure<bulk_t> {
struct bulk_t {
template <typename Policy, typename Shape, typename F>
requires(::beman::execution::is_execution_policy_v<::std::remove_cvref_t<Policy>> && ::std::integral<Shape> &&
::std::copy_constructible<::std::decay_t<F>>)
Expand Down
8 changes: 8 additions & 0 deletions include/beman/execution/detail/write_env.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import beman.execution.detail.make_sender;
import beman.execution.detail.nested_sender_has_affine;
import beman.execution.detail.queryable;
import beman.execution.detail.sender;
import beman.execution.detail.sender_adaptor_closure;
#else
#include <beman/execution/detail/default_impls.hpp>
#include <beman/execution/detail/get_env.hpp>
Expand All @@ -33,6 +34,7 @@ import beman.execution.detail.sender;
#include <beman/execution/detail/nested_sender_has_affine.hpp>
#include <beman/execution/detail/queryable.hpp>
#include <beman/execution/detail/sender.hpp>
#include <beman/execution/detail/sender_adaptor_closure.hpp>
#endif

// ----------------------------------------------------------------------------
Expand All @@ -53,6 +55,12 @@ struct write_env_t {
return ::beman::execution::detail::make_sender(
*this, ::std::forward<Env>(env), ::std::forward<Sender>(sender));
}

template <::beman::execution::detail::queryable Env>
constexpr auto operator()(Env&& env) const {
return ::beman::execution::detail::make_sender_adaptor(*this, ::std::forward<Env>(env));
}

template <::beman::execution::sender Sender>
requires ::beman::execution::detail::nested_sender_has_affine<Sender>
static auto affine(Sender&& sndr) noexcept {
Expand Down
2 changes: 0 additions & 2 deletions src/beman/execution/basic_sender.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ namespace beman::execution::detail {
export using beman::execution::detail::basic_sender;
} // namespace beman::execution::detail

#if defined(_MSC_VER) && _MSC_VER <= 1944L
namespace std {
template <typename Tag, typename Data, typename... Child>
struct tuple_size<::beman::execution::detail::basic_sender<Tag, Data, Child...>>
Expand All @@ -24,4 +23,3 @@ struct tuple_element<I, ::beman::execution::detail::basic_sender<T...>> {
::std::decay_t<decltype(::std::declval<::beman::execution::detail::basic_sender<T...>>().template get<I>())>;
};
} // namespace std
#endif
162 changes: 161 additions & 1 deletion tests/beman/execution/exec-bulk.test.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
// src/beman/execution/tests/exec-bulk.test.cpp -*-C++-*-
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <version>
#include <algorithm>
#include <cstdlib>
#include <functional>
#include <numeric>
#include <ranges>
#include <stdexcept>
#include <tuple>
#include <vector>
#ifdef __cpp_lib_parallel_algorithm
#include <execution>
#endif
#include <test/execution.hpp>
#ifdef BEMAN_HAS_MODULES
import beman.execution;
import beman.execution.detail.meta.combine;
import beman.execution.detail.meta.unique;
#else
#include <beman/execution/execution.hpp>
#endif
Expand Down Expand Up @@ -369,6 +379,154 @@ auto test_bulk_multiple_values() {
ASSERT(sum_a == 30);
ASSERT(sum_b == 60 + 0 + 1 + 2);
}
#ifdef __cpp_lib_parallel_algorithm
template <bool IsChunked, typename Policy, typename Shape, typename Fn, typename Child>
struct pstl_for_each_sender {
using sender_concept = test_std::sender_tag;

template <typename Rcvr>
struct receiver {
using receiver_concept = test_std::receiver_tag;

template <typename... Args>
auto set_value(Args&&... args) noexcept -> void {
try {
auto iota = std::views::iota(Shape(0), shape);
std::for_each(policy, std::ranges::begin(iota), std::ranges::end(iota), [&](Shape i) {
if constexpr (IsChunked) {
std::invoke(fn, i, i + 1, args...);
} else {
std::invoke(fn, i, args...);
};
});
test_std::set_value(std::move(rcvr), std::forward<Args>(args)...);
} catch (...) {
test_std::set_error(std::move(rcvr), std::current_exception());
}
}

template <typename E>
auto set_error(E e) noexcept -> void {
test_std::set_error(std::move(rcvr), std::move(e));
}

auto set_stopped() noexcept -> void { test_std::set_stopped(std::move(rcvr)); }

auto get_env() const noexcept { return test_std::get_env(rcvr); }

Rcvr rcvr;
Policy policy;
Shape shape;
Fn fn;
};

template <typename, typename... Env>
static consteval auto get_completion_signatures() {
constexpr auto compl_sigs = test_std::get_completion_signatures<Child, Env...>();
return test_std::detail::meta::unique<test_std::detail::meta::combine<
std::remove_cvref_t<decltype(compl_sigs)>,
test_std::completion_signatures<test_std::set_error_t(std::exception_ptr)>>>{};
}

template <typename Rcvr>
auto connect(Rcvr rcvr) && noexcept {
return test_std::connect(child, receiver{std::move(rcvr), std::move(policy), std::move(shape), std::move(fn)});
}

auto get_env() const noexcept { return test_std::get_env(child); }

Policy policy;
Shape shape;
Fn fn;
Child child;
};

template <typename Tag>
struct pstl_domain {
template <typename Sndr, typename Env>
requires std::same_as<test_std::tag_of_t<Sndr>, test_std::bulk_chunked_t> ||
std::same_as<test_std::tag_of_t<Sndr>, test_std::bulk_unchunked_t>
static auto transform_sender(Tag, Sndr sndr, const Env&) {
auto [_, data, child] = std::move(sndr);
auto [policy, shape, fn] = std::move(data);
return pstl_for_each_sender<std::same_as<test_std::tag_of_t<Sndr>, test_std::bulk_chunked_t>,
decltype(policy),
decltype(shape),
decltype(fn),
decltype(child)>{
std::move(policy), std::move(shape), std::move(fn), std::move(child)};
}
};

struct pstl_env1 {
static auto query(test_std::get_domain_t) noexcept { return pstl_domain<test_std::start_t>{}; }
};

struct pstl_env2 {
template <typename Env>
static auto query(test_std::get_completion_domain_t<>, const Env&) noexcept {
return pstl_domain<test_std::set_value_t>{};
}
};

struct pstl_just_sender {
using sender_concept = test_std::sender_tag;

template <typename Rcvr>
struct state {
using operation_state_concept = test_std::operation_state_tag;

auto start() & noexcept -> void { test_std::set_value(std::move(rcvr)); }

Rcvr rcvr;
};

template <typename...>
static consteval auto get_completion_signatures() noexcept {
return test_std::completion_signatures<test_std::set_value_t()>{};
}

template <typename Rcvr>
auto connect(Rcvr rcvr) && noexcept {
return state<Rcvr>{std::move(rcvr)};
}

static auto get_env() noexcept { return pstl_env2{}; }
};

static_assert(test_std::sender<pstl_just_sender>);

auto test_bulk_customization() {
{
// starting-domain customization
std::vector<int> vec(32);
std::ranges::iota(vec, 1);
std::vector<int> result(vec.size());
test_std::sync_wait(
test_std::just() |
test_std::bulk(test_std::par, vec.size(), [&](std::size_t i) noexcept { result[i] = vec[i] * vec[i]; }) |
test_std::write_env(pstl_env1{}));

for (std::size_t i = 0; i < vec.size(); ++i) {
ASSERT(result[i] == vec[i] * vec[i]);
}
}
{
// completing-domain customization
test::use_type<pstl_just_sender::sender_concept>();
std::vector<int> vec(32);
std::ranges::iota(vec, 1);
std::vector<int> result(vec.size());
test_std::sync_wait(
pstl_just_sender{} |
test_std::bulk(test_std::par, vec.size(), [&](std::size_t i) noexcept { result[i] = vec[i] * 2; }));

for (std::size_t i = 0; i < vec.size(); ++i) {
ASSERT(result[i] == vec[i] * 2);
}
}
}
#endif

} // namespace

Expand All @@ -395,7 +553,9 @@ TEST(exec_bulk) {
test_bulk_shape_one();
test_bulk_chunked_covers_full_range();
test_bulk_multiple_values();

#ifdef __cpp_lib_parallel_algorithm
test_bulk_customization();
#endif
} catch (...) {

ASSERT(nullptr == +"the bulk tests shouldn't throw");
Expand Down
Loading