Medial Code Documentation
Loading...
Searching...
No Matches
socket.h
Go to the documentation of this file.
1
6#ifndef RABIT_INTERNAL_SOCKET_H_
7#define RABIT_INTERNAL_SOCKET_H_
8#include "xgboost/collective/socket.h"
9
10#if defined(_WIN32)
11#include <winsock2.h>
12#include <ws2tcpip.h>
13
14#else
15
16#include <arpa/inet.h>
17#include <fcntl.h>
18#include <netdb.h>
19#include <netinet/in.h>
20#include <sys/ioctl.h>
21#include <sys/socket.h>
22#include <unistd.h>
23
24#include <cerrno>
25
26#endif // defined(_WIN32)
27
#include <chrono>
#include <cstring>
#include <limits>
#include <string>
#include <unordered_map>
#include <vector>
33
34#include "utils.h"
35
#ifndef _WIN32
// Non-Windows platforms: poll() and struct pollfd come from <sys/poll.h>,
// and a socket handle is a plain int descriptor.
#include <sys/poll.h>

using SOCKET = int;
using sock_size_t = size_t;  // NOLINT
#endif  // _WIN32
43
// IS_MINGW() evaluates to 1 when compiling with MinGW and 0 otherwise.
// It deliberately expands to a literal rather than to `defined(__MINGW32__)`:
// a `defined` operator produced by macro expansion inside `#if` is undefined
// behavior ([cpp.cond]) and triggers GCC's -Wexpansion-to-defined.
#if defined(__MINGW32__)
#define IS_MINGW() 1
#else
#define IS_MINGW() 0
#endif  // defined(__MINGW32__)
45
46#if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
47/*
48 * On later mingw versions poll should be supported (with bugs). See:
49 * https://stackoverflow.com/a/60623080
50 *
51 * But right now the mingw distributed with R 3.6 doesn't support it.
52 * So we just give a warning and provide dummy implementation to get
53 * compilation passed. Otherwise we will have to provide a stub for
54 * RABIT.
55 *
56 * Even on mingw version that has these structures and flags defined,
57 * functions like `send` and `listen` might have unresolved linkage to
58 * their implementation. So supporting mingw is quite difficult at
59 * the time of writing.
60 */
61#pragma message("Distributed training on mingw is not supported.")
62typedef struct pollfd {
63 SOCKET fd;
64 short events;
65 short revents;
66} WSAPOLLFD, *PWSAPOLLFD, *LPWSAPOLLFD;
67
68// POLLRDNORM | POLLRDBAND
69#define POLLIN (0x0100 | 0x0200)
70#define POLLPRI 0x0400
71// POLLWRNORM
72#define POLLOUT 0x0010
73
74#endif // IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
75
76namespace rabit {
77namespace utils {
78
/*!
 * \brief Convert a poll timeout into the millisecond `int` that
 *        poll()/WSAPoll() accept.
 *
 * Negative timeouts map to -1, which means "wait indefinitely". Timeouts
 * whose millisecond count does not fit in an `int` are clamped to INT_MAX
 * instead of being silently narrowed. Narrowing could wrap to a negative,
 * infinite wait.
 */
inline int PollTimeoutMs(std::chrono::seconds timeout) {
  constexpr std::chrono::seconds kMaxTimeout{std::numeric_limits<int>::max() / 1000};
  if (timeout < std::chrono::seconds::zero()) {
    return -1;
  }
  if (timeout > kMaxTimeout) {
    return std::numeric_limits<int>::max();
  }
  return static_cast<int>(std::chrono::milliseconds(timeout).count());
}

/*!
 * \brief Platform dispatch to the native poll primitive.
 * \param pfd array of \p nfds descriptors to poll.
 * \param nfds number of entries in \p pfd.
 * \param timeout maximum time to wait; converted with PollTimeoutMs.
 * \return the native poll()/WSAPoll() result: positive count of ready
 *         descriptors, 0 on timeout, negative on failure. On MinGW this
 *         reports an error via xgboost::MingWError() and returns -1.
 */
template <typename PollFD>
int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) {
#if defined(_WIN32)

#if IS_MINGW()
  xgboost::MingWError();
  return -1;
#else
  return WSAPoll(pfd, nfds, PollTimeoutMs(timeout));
#endif  // IS_MINGW()

#else
  return poll(pfd, nfds, PollTimeoutMs(timeout));
#endif  // defined(_WIN32)
}
94
96struct PollHelper {
97 public:
102 inline void WatchRead(SOCKET fd) {
103 auto& pfd = fds[fd];
104 pfd.fd = fd;
105 pfd.events |= POLLIN;
106 }
107 void WatchRead(xgboost::collective::TCPSocket const &socket) { this->WatchRead(socket.Handle()); }
108
113 inline void WatchWrite(SOCKET fd) {
114 auto& pfd = fds[fd];
115 pfd.fd = fd;
116 pfd.events |= POLLOUT;
117 }
118 void WatchWrite(xgboost::collective::TCPSocket const &socket) {
119 this->WatchWrite(socket.Handle());
120 }
121
126 inline void WatchException(SOCKET fd) {
127 auto& pfd = fds[fd];
128 pfd.fd = fd;
129 pfd.events |= POLLPRI;
130 }
132 this->WatchException(socket.Handle());
133 }
138 inline bool CheckRead(SOCKET fd) const {
139 const auto& pfd = fds.find(fd);
140 return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
141 }
142 bool CheckRead(xgboost::collective::TCPSocket const &socket) const {
143 return this->CheckRead(socket.Handle());
144 }
145
150 inline bool CheckWrite(SOCKET fd) const {
151 const auto& pfd = fds.find(fd);
152 return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
153 }
154 bool CheckWrite(xgboost::collective::TCPSocket const &socket) const {
155 return this->CheckWrite(socket.Handle());
156 }
162 inline void Poll(std::chrono::seconds timeout) { // NOLINT(*)
163 std::vector<pollfd> fdset;
164 fdset.reserve(fds.size());
165 for (auto kv : fds) {
166 fdset.push_back(kv.second);
167 }
168 int ret = PollImpl(fdset.data(), fdset.size(), timeout);
169 if (ret == 0) {
170 LOG(FATAL) << "Poll timeout";
171 } else if (ret < 0) {
172 LOG(FATAL) << "Failed to poll.";
173 } else {
174 for (auto& pfd : fdset) {
175 auto revents = pfd.revents & pfd.events;
176 if (!revents) {
177 fds.erase(pfd.fd);
178 } else {
179 fds[pfd.fd].events = revents;
180 }
181 }
182 }
183 }
184
185 std::unordered_map<SOCKET, pollfd> fds;
186};
187} // namespace utils
188} // namespace rabit
189
190#if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
191#undef POLLIN
192#undef POLLPRI
193#undef POLLOUT
194#endif // IS_MINGW()
195
196#endif // RABIT_INTERNAL_SOCKET_H_
TCP socket for simple communication.
Definition socket.h:249
HandleT const & Handle() const
Return the native socket file descriptor.
Definition socket.h:397
namespace of rabit
Definition engine.h:18
helper data structure to perform poll
Definition socket.h:96
void WatchException(SOCKET fd)
add file descriptor to watch for exception
Definition socket.h:126
bool CheckRead(SOCKET fd) const
Check if the descriptor is ready for read.
Definition socket.h:138
bool CheckWrite(SOCKET fd) const
Check if the descriptor is ready for write.
Definition socket.h:150
void WatchWrite(SOCKET fd)
add file descriptor to watch for write
Definition socket.h:113
void Poll(std::chrono::seconds timeout)
Perform a poll on the watched descriptor set (read, write, and exception events).
Definition socket.h:162
void WatchRead(SOCKET fd)
add file descriptor to watch for read
Definition socket.h:102
simple utils to support the code