mptensor v0.4.0
Parallel Library for Tensor Network Methods
Loading...
Searching...
No Matches
mpi_wrapper.hpp
Go to the documentation of this file.
1/*
2 mptensor - Parallel Library for Tensor Network Methods
3
4 Copyright 2016 Satoshi Morita
5
6 mptensor is free software: you can redistribute it and/or modify it
7 under the terms of the GNU Lesser General Public License as
8 published by the Free Software Foundation, either version 3 of the
9 License, or (at your option) any later version.
10
11 mptensor is distributed in the hope that it will be useful, but
12 WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 Lesser General Public License for more details.
15
16 You should have received a copy of the GNU Lesser General Public
17 License along with mptensor. If not, see
18 <https://www.gnu.org/licenses/>.
19*/
20
31#ifndef _MPI_WRAPPER_HPP_
32#define _MPI_WRAPPER_HPP_
33
34#ifndef _NO_MPI
35
36#include <mpi.h>
37
38#include <vector>
39
40namespace mptensor {
41
43namespace mpi_wrapper {
44
46template <typename C>
48
49template <>
51 return MPI_CHAR;
52};
53template <>
57template <>
61template <>
63 return MPI_SHORT;
64};
65template <>
69template <>
71 return MPI_INT;
72};
73template <>
77template <>
79 return MPI_LONG;
80};
81template <>
85template <>
89template <>
93template <>
95 return MPI_DOUBLE;
96};
97template <>
101
103
109template <typename C>
110inline C allreduce_sum(C val, const MPI_Comm &comm) {
111 C recv;
113 return recv;
114};
115
117
123template <typename C>
124inline std::vector<C> allreduce_vec(const std::vector<C> &vec,
125 const MPI_Comm &comm) {
126 size_t n = vec.size();
127 std::vector<C> recv(n);
128 MPI_Allreduce(const_cast<C *>(&(vec[0])), &(recv[0]), static_cast<int>(n),
129 mpi_datatype<C>(), MPI_SUM, comm);
130 return recv;
131};
132
134
141template <typename C>
142inline void allreduce(const C *sendbuf, C *recvbuf, int count, MPI_Op op,
143 const MPI_Comm &comm) {
145 comm);
146};
147
149
160template <typename C>
161inline void sendrecv(const C *sendbuf, int sendcount, int dest, int sendtag,
162 C *recvbuf, int recvcount, int source, int recvtag,
163 const MPI_Comm &comm) {
166 comm, MPI_STATUS_IGNORE);
167};
168
170
179template <typename C>
180inline void sendrecv(const std::vector<C> &send_vec, int dest, int sendtag,
181 std::vector<C> &recv_vec, int source, int recvtag,
182 const MPI_Comm &comm) {
183 MPI_Sendrecv(const_cast<C *>(&(send_vec[0])),
184 static_cast<int>(send_vec.size()), mpi_datatype<C>(), dest,
185 sendtag, &(recv_vec[0]), static_cast<int>(recv_vec.size()),
187};
188
190
197template <typename C>
198inline void alltoall(const C *sendbuf, int sendcount, C *recvbuf, int recvcount,
199 const MPI_Comm &comm) {
201 recvcount, mpi_datatype<C>(), comm);
202};
203
205
219template <typename C>
220inline void alltoallv(const C *sendbuf, const int *sendcounts,
221 const int *sdispls, C *recvbuf, const int *recvcounts,
222 const int *rdispls, const MPI_Comm &comm) {
223 MPI_Alltoallv(const_cast<C *>(sendbuf), const_cast<int *>(sendcounts),
224 const_cast<int *>(sdispls), mpi_datatype<C>(), recvbuf,
225 const_cast<int *>(recvcounts), const_cast<int *>(rdispls),
226 mpi_datatype<C>(), comm);
227};
228
230
236template <typename C>
237inline void bcast(C *buffer, int count, int root, const MPI_Comm &comm) {
239};
240
241} // namespace mpi_wrapper
242} // namespace mptensor
243
244#endif // _NO_MPI
245#endif // _MPI_WRAPPER_HPP_
std::complex< double > complex
Definition complex.hpp:38
void sendrecv(const C *sendbuf, int sendcount, int dest, int sendtag, C *recvbuf, int recvcount, int source, int recvtag, const MPI_Comm &comm)
Wrapper of MPI_Sendrecv.
Definition mpi_wrapper.hpp:161
MPI_Datatype mpi_datatype< unsigned long long int >()
Definition mpi_wrapper.hpp:90
void allreduce(const C *sendbuf, C *recvbuf, int count, MPI_Op op, const MPI_Comm &comm)
Wrapper of MPI_Allreduce.
Definition mpi_wrapper.hpp:142
MPI_Datatype mpi_datatype< int >()
Definition mpi_wrapper.hpp:70
MPI_Datatype mpi_datatype< unsigned char >()
Definition mpi_wrapper.hpp:58
MPI_Datatype mpi_datatype< unsigned short >()
Definition mpi_wrapper.hpp:66
MPI_Datatype mpi_datatype< long long int >()
Definition mpi_wrapper.hpp:86
MPI_Datatype mpi_datatype< double >()
Definition mpi_wrapper.hpp:94
MPI_Datatype mpi_datatype< signed char >()
Definition mpi_wrapper.hpp:54
MPI_Datatype mpi_datatype< short >()
Definition mpi_wrapper.hpp:62
MPI_Datatype mpi_datatype< unsigned int >()
Definition mpi_wrapper.hpp:74
void bcast(C *buffer, int count, int root, const MPI_Comm &comm)
Wrapper of MPI_Bcast.
Definition mpi_wrapper.hpp:237
void alltoallv(const C *sendbuf, const int *sendcounts, const int *sdispls, C *recvbuf, const int *recvcounts, const int *rdispls, const MPI_Comm &comm)
Wrapper of MPI_Alltoallv.
Definition mpi_wrapper.hpp:220
void alltoall(const C *sendbuf, int sendcount, C *recvbuf, int recvcount, const MPI_Comm &comm)
Wrapper of MPI_Alltoall.
Definition mpi_wrapper.hpp:198
MPI_Datatype mpi_datatype< char >()
Definition mpi_wrapper.hpp:50
MPI_Datatype mpi_datatype< long int >()
Definition mpi_wrapper.hpp:78
MPI_Datatype mpi_datatype< complex >()
Definition mpi_wrapper.hpp:98
C allreduce_sum(C val, const MPI_Comm &comm)
Calculate a summation over MPI communicator.
Definition mpi_wrapper.hpp:110
MPI_Datatype mpi_datatype< unsigned long int >()
Definition mpi_wrapper.hpp:82
MPI_Datatype mpi_datatype()
Template function for MPI Datatype.
std::vector< C > allreduce_vec(const std::vector< C > &vec, const MPI_Comm &comm)
Calculate a summation of each element of vector over MPI communicator.
Definition mpi_wrapper.hpp:124
Definition complex.hpp:34