Skip to content

Commit 6024d85

Browse files
committed
ensure that displacements must be given in bytes
This is to ensure that the api behaves consistently, i.e., there should be no method expecting displacements that are given in the unit size of some date type.
1 parent 29a7963 commit 6024d85

File tree

5 files changed

+40
-55
lines changed

5 files changed

+40
-55
lines changed

mpl/comm_group.hpp

Lines changed: 26 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -411,14 +411,13 @@ namespace mpl {
411411
std::vector<int> counts;
412412
counts.reserve(layouts.size());
413413
std::transform(layouts.begin(), layouts.end(), std::back_inserter(counts),
414-
[](const auto &layout) {
415-
return static_cast<int>(layout.size());
416-
});
414+
[](const auto &layout) { return static_cast<int>(layout.size()); });
417415
return counts;
418-
}
416+
}
419417

420418
template<typename T>
421-
std::vector<int> count_displacements_as_vector_of_ints(const displacements &displs) const {
419+
std::vector<int> count_displacements_as_vector_of_ints(
420+
const displacements &displs) const {
422421
std::vector<int> displs_as_int;
423422
displs_as_int.reserve(displs.size());
424423
std::transform(displs.begin(), displs.end(), std::back_inserter(displs_as_int),
@@ -2698,13 +2697,10 @@ namespace mpl {
26982697
check_root(root_rank);
26992698
check_size(recvls);
27002699
check_size(recvdispls);
2701-
std::vector<int> recvcounts;
2702-
recvcounts.reserve(recvls.size());
2703-
for (const auto &layout : recvls)
2704-
recvcounts.push_back(static_cast<int>(layout.size()));
2705-
const std::vector<int> displs(recvdispls.begin(), recvdispls.end());
2700+
const auto recvcounts{sizes_as_vector_of_ints(recvls)};
2701+
const auto reacvdispls_int(count_displacements_as_vector_of_ints<T>(recvdispls));
27062702
MPI_Gatherv(send_data, sendl.size(), detail::datatype_traits<T>::get_datatype(),
2707-
recv_data, recvcounts.data(), displs.data(),
2703+
recv_data, recvcounts.data(), reacvdispls_int.data(),
27082704
detail::datatype_traits<T>::get_datatype(), root_rank, comm_);
27092705
}
27102706

@@ -2780,14 +2776,11 @@ namespace mpl {
27802776
check_root(root_rank);
27812777
check_size(recvls);
27822778
check_size(recvdispls);
2783-
std::vector<int> recvcounts;
2784-
recvcounts.reserve(recvls.size());
2785-
for (const auto &layout : recvls)
2786-
recvcounts.push_back(static_cast<int>(layout.size()));
2787-
const std::vector<int> displs(recvdispls.begin(), recvdispls.end());
2779+
const auto recvcounts{sizes_as_vector_of_ints(recvls)};
2780+
const auto recvdispls_int(count_displacements_as_vector_of_ints<T>(recvdispls));
27882781
MPI_Request req;
27892782
MPI_Igatherv(send_data, sendl.size(), detail::datatype_traits<T>::get_datatype(),
2790-
recv_data, recvcounts.data(), displs.data(),
2783+
recv_data, recvcounts.data(), recvdispls_int.data(),
27912784
detail::datatype_traits<T>::get_datatype(), root_rank, comm_, &req);
27922785
return base_irequest{req};
27932786
}
@@ -3004,13 +2997,11 @@ namespace mpl {
30042997
const displacements &recvdispls) const {
30052998
check_size(recvls);
30062999
check_size(recvdispls);
3007-
std::vector<int> recvcounts;
3008-
recvcounts.reserve(recvls.size());
3009-
for (const auto &layout : recvls)
3010-
recvcounts.push_back(static_cast<int>(layout.size()));
3000+
const auto recvcounts{sizes_as_vector_of_ints(recvls)};
3001+
const auto recvdispls_int(count_displacements_as_vector_of_ints<T>(recvdispls));
30113002
const std::vector<int> displs(recvdispls.begin(), recvdispls.end());
30123003
MPI_Allgatherv(send_data, sendl.size(), detail::datatype_traits<T>::get_datatype(),
3013-
recv_data, recvcounts.data(), displs.data(),
3004+
recv_data, recvcounts.data(), recvdispls_int.data(),
30143005
detail::datatype_traits<T>::get_datatype(), comm_);
30153006
}
30163007

@@ -3073,14 +3064,11 @@ namespace mpl {
30733064
const displacements &recvdispls) const {
30743065
check_size(recvls);
30753066
check_size(recvdispls);
3076-
std::vector<int> recvcounts;
3077-
recvcounts.reserve(recvls.size());
3078-
for (const auto &layout : recvls)
3079-
recvcounts.push_back(static_cast<int>(layout.size()));
3080-
const std::vector<int> displs(recvdispls.begin(), recvdispls.end());
3067+
const auto recvcounts{sizes_as_vector_of_ints(recvls)};
3068+
const auto recvdispls_int(count_displacements_as_vector_of_ints<T>(recvdispls));
30813069
MPI_Request req;
30823070
MPI_Iallgatherv(send_data, sendl.size(), detail::datatype_traits<T>::get_datatype(),
3083-
recv_data, recvcounts.data(), displs.data(),
3071+
recv_data, recvcounts.data(), recvdispls_int.data(),
30843072
detail::datatype_traits<T>::get_datatype(), comm_, &req);
30853073
return base_irequest{req};
30863074
}
@@ -3304,12 +3292,10 @@ namespace mpl {
33043292
check_root(root_rank);
33053293
check_size(sendls);
33063294
check_size(senddispls);
3307-
std::vector<int> sendcounts;
3308-
sendcounts.reserve(sendls.size());
3309-
for (const auto &layout : sendls)
3310-
sendcounts.push_back(static_cast<int>(layout.size()));
3295+
const auto sendcounts{sizes_as_vector_of_ints(sendls)};
3296+
const auto senddispls_int(count_displacements_as_vector_of_ints<T>(senddispls));
33113297
const std::vector<int> displs(senddispls.begin(), senddispls.end());
3312-
MPI_Scatterv(send_data, sendcounts.data(), displs.data(),
3298+
MPI_Scatterv(send_data, sendcounts.data(), senddispls_int.data(),
33133299
detail::datatype_traits<T>::get_datatype(), recv_data, recvl.size(),
33143300
detail::datatype_traits<T>::get_datatype(), root_rank, comm_);
33153301
}
@@ -3387,14 +3373,10 @@ namespace mpl {
33873373
check_root(root_rank);
33883374
check_size(sendls);
33893375
check_size(senddispls);
3390-
3391-
std::vector<int> sendcounts;
3392-
sendcounts.reserve(sendls.size());
3393-
for (const auto &layout : sendls)
3394-
sendcounts.push_back(static_cast<int>(layout.size()));
3395-
const std::vector<int> displs(senddispls.begin(), senddispls.end());
3376+
const auto sendcounts{sizes_as_vector_of_ints(sendls)};
3377+
const auto senddispls_int(count_displacements_as_vector_of_ints<T>(senddispls));
33963378
MPI_Request req;
3397-
MPI_Iscatterv(send_data, sendcounts.data(), displs.data(),
3379+
MPI_Iscatterv(send_data, sendcounts.data(), senddispls_int.data(),
33983380
detail::datatype_traits<T>::get_datatype(), recv_data, recvl.size(),
33993381
detail::datatype_traits<T>::get_datatype(), root_rank, comm_, &req);
34003382
return base_irequest{req};
@@ -3613,8 +3595,8 @@ namespace mpl {
36133595
check_size(recvdispls);
36143596
check_size(recvls);
36153597
const std::vector<int> counts(recvls.size(), 1);
3616-
const std::vector<int> senddispls_int(senddispls.begin(), senddispls.end());
3617-
const std::vector<int> recvdispls_int(recvdispls.begin(), recvdispls.end());
3598+
const auto senddispls_int{byte_displacements_as_vector_of_ints(senddispls)};
3599+
const auto recvdispls_int{byte_displacements_as_vector_of_ints(recvdispls)};
36183600
static_assert(
36193601
sizeof(decltype(*sendls())) == sizeof(MPI_Datatype),
36203602
"compiler adds some unexpected padding, reinterpret cast will yield wrong results");
@@ -3841,7 +3823,8 @@ namespace mpl {
38413823
template<typename T>
38423824
irequest ialltoallv(const T *send_data, const contiguous_layouts<T> &sendls,
38433825
const displacements &senddispls, T *recv_data,
3844-
const contiguous_layouts<T> &recvls, const displacements &recvdispls) const {
3826+
const contiguous_layouts<T> &recvls,
3827+
const displacements &recvdispls) const {
38453828
check_size(senddispls);
38463829
check_size(sendls);
38473830
check_size(recvdispls);

mpl/displacements.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010
namespace mpl {
1111

12-
/// Indicates the beginning of data buffers in various collective communication
13-
/// operations.
12+
/// Set of %displacements indicates the beginning of data buffers in various collective
13+
/// communication operations.
14+
/// \note Individual %displacements are always given in bytes.
1415
class displacements : private std::vector<MPI_Aint> {
1516
using base = std::vector<MPI_Aint>;
1617

@@ -48,7 +49,7 @@ namespace mpl {
4849
using base::resize;
4950

5051
/// Get raw displacement data.
51-
/// \return pointer to array of displacements
52+
/// \return pointer to the array of displacements
5253
const MPI_Aint *operator()() const {
5354
return base::data();
5455
}

test/test_communicator_allgatherv.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,12 @@ bool allgatherv_contiguous_test(const T &val) {
3535
mpl::displacements displacements;
3636
for (int i{0}, i_end{comm_world.size()}, offset{0}; i < i_end; ++i) {
3737
l.push_back(mpl::contiguous_layout<T>(i + 1));
38-
displacements.push_back(offset);
38+
displacements.push_back(sizeof(T) * offset);
3939
offset += i + 1;
4040
}
4141
const auto rank{comm_world.rank()};
42-
comm_world.allgatherv(v1.data() + displacements[rank], l[rank], v2.data(), l, displacements);
42+
comm_world.allgatherv(v1.data() + displacements[rank] / sizeof(T), l[rank], v2.data(), l,
43+
displacements);
4344
return v1 == v2;
4445
}
4546

@@ -73,12 +74,12 @@ bool iallgatherv_contiguous_test(const T &val) {
7374
mpl::displacements displacements;
7475
for (int i{0}, i_end{comm_world.size()}, offset{0}; i < i_end; ++i) {
7576
l.push_back(mpl::contiguous_layout<T>(i + 1));
76-
displacements.push_back(offset);
77+
displacements.push_back(sizeof(T) * offset);
7778
offset += i + 1;
7879
}
7980
const auto rank{comm_world.rank()};
80-
auto r{comm_world.iallgatherv(v1.data() + displacements[rank], l[rank], v2.data(), l,
81-
displacements)};
81+
auto r{comm_world.iallgatherv(v1.data() + displacements[rank] / sizeof(T), l[rank], v2.data(),
82+
l, displacements)};
8283
r.wait();
8384
return v1 == v2;
8485
}

test/test_communicator_gatherv.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ bool gatherv_contiguous_test(const T &val) {
4949
mpl::displacements displacements;
5050
for (int i{0}, i_end{comm_world.size()}, offset{0}; i < i_end; ++i) {
5151
layouts.push_back(mpl::contiguous_layout<T>(i + 1));
52-
displacements.push_back(offset);
52+
displacements.push_back(sizeof(T) * offset);
5353
offset += i + 1;
5454
}
5555
T t_val{val};
@@ -117,7 +117,7 @@ bool igatherv_contiguous_test(const T &val) {
117117
mpl::displacements displacements;
118118
for (int i{0}, i_end{comm_world.size()}, offset{0}; i < i_end; ++i) {
119119
layouts.push_back(mpl::contiguous_layout<T>(i + 1));
120-
displacements.push_back(offset);
120+
displacements.push_back(sizeof(T) * offset);
121121
offset += i + 1;
122122
}
123123
T t_val{val};

test/test_communicator_scatterv.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ bool scatterv_contiguous_test(const T &val) {
4949
mpl::displacements displacements;
5050
for (int i{0}, i_end{comm_world.size()}, offset{0}; i < i_end; ++i) {
5151
layouts.push_back(mpl::contiguous_layout<T>(i + 1));
52-
displacements.push_back(offset);
52+
displacements.push_back(sizeof(T) * offset);
5353
offset += i + 1;
5454
}
5555
T t_val{val};
@@ -117,7 +117,7 @@ bool iscatterv_contiguous_test(const T &val) {
117117
mpl::displacements displacements;
118118
for (int i{0}, i_end{comm_world.size()}, offset{0}; i < i_end; ++i) {
119119
layouts.push_back(mpl::contiguous_layout<T>(i + 1));
120-
displacements.push_back(offset);
120+
displacements.push_back(sizeof(T) * offset);
121121
offset += i + 1;
122122
}
123123
T t_val{val};

0 commit comments

Comments
 (0)