Medial Code Documentation
Loading...
Searching...
No Matches
socket.h
1
4#pragma once
5
6#if !defined(NOMINMAX) && defined(_WIN32)
7#define NOMINMAX
8#endif // !defined(NOMINMAX)
9
10#include <cerrno> // errno, EINTR, EBADF
11#include <climits> // HOST_NAME_MAX
12#include <cstddef> // std::size_t
13#include <cstdint> // std::int32_t, std::uint16_t
14#include <cstring> // memset
15#include <limits> // std::numeric_limits
16#include <string> // std::string
17#include <system_error> // std::error_code, std::system_category
18#include <utility> // std::swap
19
20#if !defined(xgboost_IS_MINGW)
21
22#if defined(__MINGW32__)
23#define xgboost_IS_MINGW 1
24#endif // defined(__MINGW32__)
25
26#endif // xgboost_IS_MINGW
27
28#if defined(_WIN32)
29
30#include <winsock2.h>
31#include <ws2tcpip.h>
32
33using in_port_t = std::uint16_t;
34
35#ifdef _MSC_VER
36#pragma comment(lib, "Ws2_32.lib")
37#endif // _MSC_VER
38
39#if !defined(xgboost_IS_MINGW)
40using ssize_t = int;
41#endif // !xgboost_IS_MINGW()
42
43#else // UNIX
44
45#include <arpa/inet.h> // inet_ntop
46#include <fcntl.h> // fcntl, F_GETFL, O_NONBLOCK
47#include <netinet/in.h> // sockaddr_in6, sockaddr_in, in_port_t, INET6_ADDRSTRLEN, INET_ADDRSTRLEN
48#include <netinet/in.h> // IPPROTO_TCP
49#include <netinet/tcp.h> // TCP_NODELAY
50#include <sys/socket.h> // socket, SOL_SOCKET, SO_ERROR, MSG_WAITALL, recv, send, AF_INET6, AF_INET
51#include <unistd.h> // close
52
53#if defined(__sun) || defined(sun)
54#include <sys/sockio.h>
55#endif // defined(__sun) || defined(sun)
56
57#endif // defined(_WIN32)
58
59#include "xgboost/base.h" // XGBOOST_EXPECT
60#include "xgboost/logging.h" // LOG
61#include "xgboost/string_view.h" // StringView
62
63#if !defined(HOST_NAME_MAX)
64#define HOST_NAME_MAX 256 // macos
65#endif
66
67namespace xgboost {
68
#if defined(xgboost_IS_MINGW)
// MinGW builds only have a dummy `poll` (see rabit), so socket-based distributed training is
// unavailable there. Entry points that cannot work on MinGW (e.g. socket creation, inet_ntop)
// report the problem through this fatal log.
inline void MingWError() {
  LOG(FATAL) << "Distributed training on mingw is not supported.";
}
#endif  // defined(xgboost_IS_MINGW)
73
74namespace system {
/**
 * @brief Fetch the error code left by the most recent failed socket/system call.
 *
 * Windows sockets report errors through WSAGetLastError() instead of errno.
 */
inline std::int32_t LastError() {
#if !defined(_WIN32)
  return errno;
#else
  return WSAGetLastError();
#endif  // !defined(_WIN32)
}
83
84#if defined(__GLIBC__)
85inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(),
86 std::int32_t line = __builtin_LINE(),
87 char const *file = __builtin_FILE()) {
88 auto err = std::error_code{errsv, std::system_category()};
89 LOG(FATAL) << "\n"
90 << file << "(" << line << "): Failed to call `" << fn_name << "`: " << err.message()
91 << std::endl;
92}
93#else
94inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError()) {
95 auto err = std::error_code{errsv, std::system_category()};
96 LOG(FATAL) << "Failed to call `" << fn_name << "`: " << err.message() << std::endl;
97}
98#endif // defined(__GLIBC__)
99
// Native socket handle: a SOCKET on Windows, a plain file descriptor everywhere else.
#if !defined(_WIN32)
using SocketT = int;
#else
using SocketT = SOCKET;
#endif  // !defined(_WIN32)

// Evaluate `exp`; if the result differs from `expected`, report the stringified expression
// through ThrowAtError. The comparison is hinted as unlikely to fail.
#ifndef xgboost_CHECK_SYS_CALL
#define xgboost_CHECK_SYS_CALL(exp, expected)         \
  do {                                                \
    if (XGBOOST_EXPECT((exp) != (expected), false)) { \
      ::xgboost::system::ThrowAtError(#exp);          \
    }                                                 \
  } while (false)
#endif  // xgboost_CHECK_SYS_CALL
114
115inline std::int32_t CloseSocket(SocketT fd) {
116#if defined(_WIN32)
117 return closesocket(fd);
118#else
119 return close(fd);
120#endif
121}
122
123inline bool LastErrorWouldBlock() {
124 int errsv = LastError();
125#ifdef _WIN32
126 return errsv == WSAEWOULDBLOCK;
127#else
128 return errsv == EAGAIN || errsv == EWOULDBLOCK;
129#endif // _WIN32
130}
131
/**
 * @brief Initialise the platform socket library: Winsock 2.2 on Windows, a no-op elsewhere.
 *
 * Presumably paired with SocketFinalize(), which calls WSACleanup() on Windows.
 * Failures are reported through ThrowAtError / LOG(FATAL).
 */
inline void SocketStartup() {
#if defined(_WIN32)
  WSADATA wsa_data;
  // WSAStartup returns 0 on success and the error code itself on failure. It never returns
  // -1/SOCKET_ERROR, and WSAGetLastError() must not be used to fetch the reason, so the
  // returned code is forwarded explicitly.
  std::int32_t const rc = WSAStartup(MAKEWORD(2, 2), &wsa_data);
  if (rc != 0) {
    ThrowAtError("WSAStartup", rc);
  }
  if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
    WSACleanup();
    LOG(FATAL) << "Could not find a usable version of Winsock.dll";
  }
#endif  // defined(_WIN32)
}
144
/**
 * @brief Release the platform socket library (WSACleanup on Windows, nothing elsewhere).
 */
inline void SocketFinalize() {
#if !defined(_WIN32)
  // POSIX sockets need no global teardown.
#else
  WSACleanup();
#endif  // !defined(_WIN32)
}
150
#if !defined(_WIN32) || !defined(xgboost_IS_MINGW)
// Re-export the platform inet_ntop so callers can always write `system::inet_ntop`.
using ::inet_ntop;
#else
// Dummy definition for old msys32 (MinGW) toolchains: fails via MingWError() and yields nullptr.
inline const char *inet_ntop(int, const void *, char *, socklen_t) {  // NOLINT
  MingWError();
  return nullptr;
}
#endif  // !defined(_WIN32) || !defined(xgboost_IS_MINGW)
160
161} // namespace system
162
163namespace collective {
164class SockAddress;
165
166enum class SockDomain : std::int32_t { kV4 = AF_INET, kV6 = AF_INET6 };
167
172SockAddress MakeSockAddress(StringView host, in_port_t port);
173
175 sockaddr_in6 addr_;
176
177 public:
178 explicit SockAddrV6(sockaddr_in6 addr) : addr_{addr} {}
179 SockAddrV6() { std::memset(&addr_, '\0', sizeof(addr_)); }
180
181 static SockAddrV6 Loopback();
182 static SockAddrV6 InaddrAny();
183
184 in_port_t Port() const { return ntohs(addr_.sin6_port); }
185
186 std::string Addr() const {
187 char buf[INET6_ADDRSTRLEN];
188 auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV6), &addr_.sin6_addr,
189 buf, INET6_ADDRSTRLEN);
190 if (s == nullptr) {
191 system::ThrowAtError("inet_ntop");
192 }
193 return {buf};
194 }
195 sockaddr_in6 const &Handle() const { return addr_; }
196};
197
199 private:
200 sockaddr_in addr_;
201
202 public:
203 explicit SockAddrV4(sockaddr_in addr) : addr_{addr} {}
204 SockAddrV4() { std::memset(&addr_, '\0', sizeof(addr_)); }
205
206 static SockAddrV4 Loopback();
207 static SockAddrV4 InaddrAny();
208
209 in_port_t Port() const { return ntohs(addr_.sin_port); }
210
211 std::string Addr() const {
212 char buf[INET_ADDRSTRLEN];
213 auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV4), &addr_.sin_addr,
214 buf, INET_ADDRSTRLEN);
215 if (s == nullptr) {
216 system::ThrowAtError("inet_ntop");
217 }
218 return {buf};
219 }
220 sockaddr_in const &Handle() const { return addr_; }
221};
222
227 private:
228 SockAddrV6 v6_;
229 SockAddrV4 v4_;
230 SockDomain domain_{SockDomain::kV4};
231
232 public:
233 SockAddress() = default;
234 explicit SockAddress(SockAddrV6 const &addr) : v6_{addr}, domain_{SockDomain::kV6} {}
235 explicit SockAddress(SockAddrV4 const &addr) : v4_{addr} {}
236
237 auto Domain() const { return domain_; }
238
239 bool IsV4() const { return Domain() == SockDomain::kV4; }
240 bool IsV6() const { return !IsV4(); }
241
242 auto const &V4() const { return v4_; }
243 auto const &V6() const { return v6_; }
244};
245
250 public:
251 using HandleT = system::SocketT;
252
253 private:
254 HandleT handle_{InvalidSocket()};
255 // There's reliable no way to extract domain from a socket without first binding that
256 // socket on macos.
257#if defined(__APPLE__)
258 SockDomain domain_{SockDomain::kV4};
259#endif
260
261 constexpr static HandleT InvalidSocket() { return -1; }
262
263 explicit TCPSocket(HandleT newfd) : handle_{newfd} {}
264
265 public:
266 TCPSocket() = default;
270 auto Domain() const -> SockDomain {
271 auto ret_iafamily = [](std::int32_t domain) {
272 switch (domain) {
273 case AF_INET:
274 return SockDomain::kV4;
275 case AF_INET6:
276 return SockDomain::kV6;
277 default: {
278 LOG(FATAL) << "Unknown IA family.";
279 }
280 }
281 return SockDomain::kV4;
282 };
283
284#if defined(_WIN32)
285 WSAPROTOCOL_INFOA info;
286 socklen_t len = sizeof(info);
287 xgboost_CHECK_SYS_CALL(
288 getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFO, reinterpret_cast<char *>(&info), &len),
289 0);
290 return ret_iafamily(info.iAddressFamily);
291#elif defined(__APPLE__)
292 return domain_;
293#elif defined(__unix__)
294#ifndef __PASE__
295 std::int32_t domain;
296 socklen_t len = sizeof(domain);
297 xgboost_CHECK_SYS_CALL(
298 getsockopt(handle_, SOL_SOCKET, SO_DOMAIN, reinterpret_cast<char *>(&domain), &len), 0);
299 return ret_iafamily(domain);
300#else
301 struct sockaddr sa;
302 socklen_t sizeofsa = sizeof(sa);
303 xgboost_CHECK_SYS_CALL(getsockname(handle_, &sa, &sizeofsa), 0);
304 if (sizeofsa < sizeof(uchar_t) * 2) {
305 return ret_iafamily(AF_INET);
306 }
307 return ret_iafamily(sa.sa_family);
308#endif // __PASE__
309#else
310 LOG(FATAL) << "Unknown platform.";
311 return ret_iafamily(AF_INET);
312#endif // platforms
313 }
314
315 bool IsClosed() const { return handle_ == InvalidSocket(); }
316
318 std::int32_t GetSockError() const {
319 std::int32_t error = 0;
320 socklen_t len = sizeof(error);
321 xgboost_CHECK_SYS_CALL(
322 getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&error), &len), 0);
323 return error;
324 }
326 bool BadSocket() const {
327 if (IsClosed()) return true;
328 std::int32_t err = GetSockError();
329 if (err == EBADF || err == EINTR) return true;
330 return false;
331 }
332
333 void SetNonBlock() {
334 bool non_block{true};
335#if defined(_WIN32)
336 u_long mode = non_block ? 1 : 0;
337 xgboost_CHECK_SYS_CALL(ioctlsocket(handle_, FIONBIO, &mode), NO_ERROR);
338#else
339 std::int32_t flag = fcntl(handle_, F_GETFL, 0);
340 if (flag == -1) {
341 system::ThrowAtError("fcntl");
342 }
343 if (non_block) {
344 flag |= O_NONBLOCK;
345 } else {
346 flag &= ~O_NONBLOCK;
347 }
348 if (fcntl(handle_, F_SETFL, flag) == -1) {
349 system::ThrowAtError("fcntl");
350 }
351#endif // _WIN32
352 }
353
354 void SetKeepAlive() {
355 std::int32_t keepalive = 1;
356 xgboost_CHECK_SYS_CALL(setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE,
357 reinterpret_cast<char *>(&keepalive), sizeof(keepalive)),
358 0);
359 }
360
361 void SetNoDelay() {
362 std::int32_t tcp_no_delay = 1;
363 xgboost_CHECK_SYS_CALL(
364 setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
365 sizeof(tcp_no_delay)),
366 0);
367 }
368
373 HandleT newfd = accept(handle_, nullptr, nullptr);
374 if (newfd == InvalidSocket()) {
375 system::ThrowAtError("accept");
376 }
377 TCPSocket newsock{newfd};
378 return newsock;
379 }
380
381 ~TCPSocket() {
382 if (!IsClosed()) {
383 Close();
384 }
385 }
386
387 TCPSocket(TCPSocket const &that) = delete;
388 TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); }
389 TCPSocket &operator=(TCPSocket const &that) = delete;
390 TCPSocket &operator=(TCPSocket &&that) {
391 std::swap(this->handle_, that.handle_);
392 return *this;
393 }
397 HandleT const &Handle() const { return handle_; }
401 void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); }
405 in_port_t BindHost() {
406 if (Domain() == SockDomain::kV6) {
407 auto addr = SockAddrV6::InaddrAny();
408 auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
409 xgboost_CHECK_SYS_CALL(
410 bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
411
412 sockaddr_in6 res_addr;
413 socklen_t addrlen = sizeof(res_addr);
414 xgboost_CHECK_SYS_CALL(
415 getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
416 return ntohs(res_addr.sin6_port);
417 } else {
418 auto addr = SockAddrV4::InaddrAny();
419 auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
420 xgboost_CHECK_SYS_CALL(
421 bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
422
423 sockaddr_in res_addr;
424 socklen_t addrlen = sizeof(res_addr);
425 xgboost_CHECK_SYS_CALL(
426 getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
427 return ntohs(res_addr.sin_port);
428 }
429 }
433 auto SendAll(void const *buf, std::size_t len) {
434 char const *_buf = reinterpret_cast<const char *>(buf);
435 std::size_t ndone = 0;
436 while (ndone < len) {
437 ssize_t ret = send(handle_, _buf, len - ndone, 0);
438 if (ret == -1) {
439 if (system::LastErrorWouldBlock()) {
440 return ndone;
441 }
442 system::ThrowAtError("send");
443 }
444 _buf += ret;
445 ndone += ret;
446 }
447 return ndone;
448 }
452 auto RecvAll(void *buf, std::size_t len) {
453 char *_buf = reinterpret_cast<char *>(buf);
454 std::size_t ndone = 0;
455 while (ndone < len) {
456 ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
457 if (ret == -1) {
458 if (system::LastErrorWouldBlock()) {
459 return ndone;
460 }
461 system::ThrowAtError("recv");
462 }
463 if (ret == 0) {
464 return ndone;
465 }
466 _buf += ret;
467 ndone += ret;
468 }
469 return ndone;
470 }
478 auto Send(const void *buf_, std::size_t len, std::int32_t flags = 0) {
479 const char *buf = reinterpret_cast<const char *>(buf_);
480 return send(handle_, buf, len, flags);
481 }
489 auto Recv(void *buf, std::size_t len, std::int32_t flags = 0) {
490 char *_buf = reinterpret_cast<char *>(buf);
491 return recv(handle_, _buf, len, flags);
492 }
496 std::size_t Send(StringView str);
500 std::size_t Recv(std::string *p_str);
504 void Close() {
505 if (InvalidSocket() != handle_) {
506 xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0);
507 handle_ = InvalidSocket();
508 }
509 }
513 static TCPSocket Create(SockDomain domain) {
514#if defined(xgboost_IS_MINGW)
515 MingWError();
516 return {};
517#else
518 auto fd = socket(static_cast<std::int32_t>(domain), SOCK_STREAM, 0);
519 if (fd == InvalidSocket()) {
520 system::ThrowAtError("socket");
521 }
522
523 TCPSocket socket{fd};
524#if defined(__APPLE__)
525 socket.domain_ = domain;
526#endif // defined(__APPLE__)
527 return socket;
528#endif // defined(xgboost_IS_MINGW)
529 }
530};
531
536std::error_code Connect(SockAddress const &addr, TCPSocket *out);
537
541inline std::string GetHostName() {
542 char buf[HOST_NAME_MAX];
543 xgboost_CHECK_SYS_CALL(gethostname(&buf[0], HOST_NAME_MAX), 0);
544 return buf;
545}
546} // namespace collective
547} // namespace xgboost
548
549#undef xgboost_CHECK_SYS_CALL
550
551#if defined(xgboost_IS_MINGW)
552#undef xgboost_IS_MINGW
553#endif
Definition socket.h:198
Definition socket.h:174
Address for TCP socket, can be either IPv4 or IPv6.
Definition socket.h:226
TCP socket for simple communication.
Definition socket.h:249
std::int32_t GetSockError() const
Get the pending error code of the socket, if any.
Definition socket.h:318
auto SendAll(void const *buf, std::size_t len)
Send data; returns the number of bytes sent, which is less than len only if a non-blocking socket would block.
Definition socket.h:433
void Close()
Close the socket, called automatically in destructor if the socket is not closed.
Definition socket.h:504
void Listen(std::int32_t backlog=16)
Listen to incoming requests.
Definition socket.h:401
auto Domain() const -> SockDomain
Return the socket domain.
Definition socket.h:270
static TCPSocket Create(SockDomain domain)
Create a TCP socket on specified domain.
Definition socket.h:513
auto Recv(void *buf, std::size_t len, std::int32_t flags=0)
Receive data using the socket.
Definition socket.h:489
auto RecvAll(void *buf, std::size_t len)
Receive data; returns the number of bytes received, which is less than len only if a non-blocking socket would block or the peer closed the connection.
Definition socket.h:452
TCPSocket Accept()
Accept new connection, returns a new TCP socket for the new connection.
Definition socket.h:372
auto Send(const void *buf_, std::size_t len, std::int32_t flags=0)
Send data using the socket.
Definition socket.h:478
in_port_t BindHost()
Bind socket to INADDR_ANY, return the port selected by the OS.
Definition socket.h:405
bool BadSocket() const
Check whether the socket is closed or in a bad state.
Definition socket.h:326
HandleT const & Handle() const
Return the native socket file descriptor.
Definition socket.h:397
Copyright 2015-2023 by XGBoost Contributors.
defines console logging options for xgboost. Use to enforce unified print behavior.
NLOHMANN_BASIC_JSON_TPL_DECLARATION void swap(nlohmann::NLOHMANN_BASIC_JSON_TPL &j1, nlohmann::NLOHMANN_BASIC_JSON_TPL &j2) noexcept(//NOLINT(readability-inconsistent-declaration-parameter-name, cert-dcl58-cpp) is_nothrow_move_constructible< nlohmann::NLOHMANN_BASIC_JSON_TPL >::value &&//NOLINT(misc-redundant-expression) is_nothrow_move_assignable< nlohmann::NLOHMANN_BASIC_JSON_TPL >::value)
exchanges the values of two JSON objects
Definition json.hpp:24418
std::error_code Connect(SockAddress const &addr, TCPSocket *out)
Connect to remote address, returns the error code if failed (no exception is raised so that we can re...
Definition socket.cc:74
std::string GetHostName()
Get the local host name.
Definition socket.h:541
SockAddress MakeSockAddress(StringView host, in_port_t port)
Parse host address and return a SockAddress instance.
Definition socket.cc:17
namespace of xgboost
Definition base.h:90
Definition string_view.h:15