6#if !defined(NOMINMAX) && defined(_WIN32)
17#include <system_error>
20#if !defined(xgboost_IS_MINGW)
22#if defined(__MINGW32__)
23#define xgboost_IS_MINGW 1
33using in_port_t = std::uint16_t;
36#pragma comment(lib, "Ws2_32.lib")
39#if !defined(xgboost_IS_MINGW)
47#include <netinet/in.h>
48#include <netinet/in.h>
49#include <netinet/tcp.h>
50#include <sys/socket.h>
53#if defined(__sun) || defined(sun)
54#include <sys/sockio.h>
61#include "xgboost/string_view.h"
63#if !defined(HOST_NAME_MAX)
64#define HOST_NAME_MAX 256
69#if defined(xgboost_IS_MINGW)
71inline void MingWError() { LOG(FATAL) <<
"Distributed training on mingw is not supported."; }
75inline std::int32_t LastError() {
77 return WSAGetLastError();
85inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(),
86 std::int32_t line = __builtin_LINE(),
87 char const *file = __builtin_FILE()) {
88 auto err = std::error_code{errsv, std::system_category()};
90 << file <<
"(" << line <<
"): Failed to call `" << fn_name <<
"`: " << err.message()
94inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError()) {
95 auto err = std::error_code{errsv, std::system_category()};
96 LOG(FATAL) <<
"Failed to call `" << fn_name <<
"`: " << err.message() << std::endl;
101using SocketT = SOCKET;
106#if !defined(xgboost_CHECK_SYS_CALL)
107#define xgboost_CHECK_SYS_CALL(exp, expected) \
109 if (XGBOOST_EXPECT((exp) != (expected), false)) { \
110 ::xgboost::system::ThrowAtError(#exp); \
115inline std::int32_t CloseSocket(SocketT fd) {
117 return closesocket(fd);
123inline bool LastErrorWouldBlock() {
124 int errsv = LastError();
126 return errsv == WSAEWOULDBLOCK;
128 return errsv == EAGAIN || errsv == EWOULDBLOCK;
132inline void SocketStartup() {
135 if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
136 ThrowAtError(
"WSAStartup");
138 if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
140 LOG(FATAL) <<
"Could not find a usable version of Winsock.dll";
145inline void SocketFinalize() {
151#if defined(_WIN32) && defined(xgboost_IS_MINGW)
153inline const char *inet_ntop(
int,
const void *,
char *, socklen_t) {
163namespace collective {
166enum class SockDomain : std::int32_t { kV4 = AF_INET, kV6 = AF_INET6 };
178 explicit SockAddrV6(sockaddr_in6 addr) : addr_{addr} {}
179 SockAddrV6() { std::memset(&addr_,
'\0',
sizeof(addr_)); }
184 in_port_t Port()
const {
return ntohs(addr_.sin6_port); }
186 std::string Addr()
const {
187 char buf[INET6_ADDRSTRLEN];
188 auto const *s = system::inet_ntop(
static_cast<std::int32_t
>(SockDomain::kV6), &addr_.sin6_addr,
189 buf, INET6_ADDRSTRLEN);
191 system::ThrowAtError(
"inet_ntop");
195 sockaddr_in6
const &Handle()
const {
return addr_; }
203 explicit SockAddrV4(sockaddr_in addr) : addr_{addr} {}
204 SockAddrV4() { std::memset(&addr_,
'\0',
sizeof(addr_)); }
209 in_port_t Port()
const {
return ntohs(addr_.sin_port); }
211 std::string Addr()
const {
212 char buf[INET_ADDRSTRLEN];
213 auto const *s = system::inet_ntop(
static_cast<std::int32_t
>(SockDomain::kV4), &addr_.sin_addr,
214 buf, INET_ADDRSTRLEN);
216 system::ThrowAtError(
"inet_ntop");
220 sockaddr_in
const &Handle()
const {
return addr_; }
230 SockDomain domain_{SockDomain::kV4};
237 auto Domain()
const {
return domain_; }
239 bool IsV4()
const {
return Domain() == SockDomain::kV4; }
240 bool IsV6()
const {
return !IsV4(); }
242 auto const &V4()
const {
return v4_; }
243 auto const &V6()
const {
return v6_; }
251 using HandleT = system::SocketT;
254 HandleT handle_{InvalidSocket()};
257#if defined(__APPLE__)
258 SockDomain domain_{SockDomain::kV4};
261 constexpr static HandleT InvalidSocket() {
return -1; }
263 explicit TCPSocket(HandleT newfd) : handle_{newfd} {}
271 auto ret_iafamily = [](std::int32_t domain) {
274 return SockDomain::kV4;
276 return SockDomain::kV6;
278 LOG(FATAL) <<
"Unknown IA family.";
281 return SockDomain::kV4;
285 WSAPROTOCOL_INFOA info;
286 socklen_t len =
sizeof(info);
287 xgboost_CHECK_SYS_CALL(
288 getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFO,
reinterpret_cast<char *
>(&info), &len),
290 return ret_iafamily(info.iAddressFamily);
291#elif defined(__APPLE__)
293#elif defined(__unix__)
296 socklen_t len =
sizeof(domain);
297 xgboost_CHECK_SYS_CALL(
298 getsockopt(handle_, SOL_SOCKET, SO_DOMAIN,
reinterpret_cast<char *
>(&domain), &len), 0);
299 return ret_iafamily(domain);
302 socklen_t sizeofsa =
sizeof(sa);
303 xgboost_CHECK_SYS_CALL(getsockname(handle_, &sa, &sizeofsa), 0);
304 if (sizeofsa <
sizeof(uchar_t) * 2) {
305 return ret_iafamily(AF_INET);
307 return ret_iafamily(sa.sa_family);
310 LOG(FATAL) <<
"Unknown platform.";
311 return ret_iafamily(AF_INET);
315 bool IsClosed()
const {
return handle_ == InvalidSocket(); }
319 std::int32_t error = 0;
320 socklen_t len =
sizeof(error);
321 xgboost_CHECK_SYS_CALL(
322 getsockopt(handle_, SOL_SOCKET, SO_ERROR,
reinterpret_cast<char *
>(&error), &len), 0);
327 if (IsClosed())
return true;
329 if (err == EBADF || err == EINTR)
return true;
334 bool non_block{
true};
336 u_long mode = non_block ? 1 : 0;
337 xgboost_CHECK_SYS_CALL(ioctlsocket(handle_, FIONBIO, &mode), NO_ERROR);
339 std::int32_t flag = fcntl(handle_, F_GETFL, 0);
341 system::ThrowAtError(
"fcntl");
348 if (fcntl(handle_, F_SETFL, flag) == -1) {
349 system::ThrowAtError(
"fcntl");
354 void SetKeepAlive() {
355 std::int32_t keepalive = 1;
356 xgboost_CHECK_SYS_CALL(setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE,
357 reinterpret_cast<char *
>(&keepalive),
sizeof(keepalive)),
362 std::int32_t tcp_no_delay = 1;
363 xgboost_CHECK_SYS_CALL(
364 setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY,
reinterpret_cast<char *
>(&tcp_no_delay),
365 sizeof(tcp_no_delay)),
373 HandleT newfd = accept(handle_,
nullptr,
nullptr);
374 if (newfd == InvalidSocket()) {
375 system::ThrowAtError(
"accept");
387 TCPSocket(TCPSocket
const &that) =
delete;
388 TCPSocket(TCPSocket &&that)
noexcept(
true) {
std::swap(this->handle_, that.handle_); }
389 TCPSocket &operator=(TCPSocket
const &that) =
delete;
390 TCPSocket &operator=(TCPSocket &&that) {
397 HandleT
const &
Handle()
const {
return handle_; }
401 void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); }
406 if (
Domain() == SockDomain::kV6) {
407 auto addr = SockAddrV6::InaddrAny();
408 auto handle =
reinterpret_cast<sockaddr
const *
>(&addr.Handle());
409 xgboost_CHECK_SYS_CALL(
410 bind(handle_, handle,
sizeof(std::remove_reference_t<
decltype(addr.Handle())>)), 0);
412 sockaddr_in6 res_addr;
413 socklen_t addrlen =
sizeof(res_addr);
414 xgboost_CHECK_SYS_CALL(
415 getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen), 0);
416 return ntohs(res_addr.sin6_port);
418 auto addr = SockAddrV4::InaddrAny();
419 auto handle =
reinterpret_cast<sockaddr
const *
>(&addr.Handle());
420 xgboost_CHECK_SYS_CALL(
421 bind(handle_, handle,
sizeof(std::remove_reference_t<
decltype(addr.Handle())>)), 0);
423 sockaddr_in res_addr;
424 socklen_t addrlen =
sizeof(res_addr);
425 xgboost_CHECK_SYS_CALL(
426 getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen), 0);
427 return ntohs(res_addr.sin_port);
433 auto SendAll(
void const *buf, std::size_t len) {
434 char const *_buf =
reinterpret_cast<const char *
>(buf);
435 std::size_t ndone = 0;
436 while (ndone < len) {
437 ssize_t ret = send(handle_, _buf, len - ndone, 0);
439 if (system::LastErrorWouldBlock()) {
442 system::ThrowAtError(
"send");
453 char *_buf =
reinterpret_cast<char *
>(buf);
454 std::size_t ndone = 0;
455 while (ndone < len) {
456 ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
458 if (system::LastErrorWouldBlock()) {
461 system::ThrowAtError(
"recv");
478 auto Send(
const void *buf_, std::size_t len, std::int32_t flags = 0) {
479 const char *buf =
reinterpret_cast<const char *
>(buf_);
480 return send(handle_, buf, len, flags);
489 auto Recv(
void *buf, std::size_t len, std::int32_t flags = 0) {
490 char *_buf =
reinterpret_cast<char *
>(buf);
491 return recv(handle_, _buf, len, flags);
500 std::size_t
Recv(std::string *p_str);
505 if (InvalidSocket() != handle_) {
506 xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0);
507 handle_ = InvalidSocket();
514#if defined(xgboost_IS_MINGW)
518 auto fd = socket(
static_cast<std::int32_t
>(domain), SOCK_STREAM, 0);
519 if (fd == InvalidSocket()) {
520 system::ThrowAtError(
"socket");
524#if defined(__APPLE__)
525 socket.domain_ = domain;
536std::error_code
Connect(SockAddress
const &addr, TCPSocket *out);
542 char buf[HOST_NAME_MAX];
543 xgboost_CHECK_SYS_CALL(gethostname(&buf[0], HOST_NAME_MAX), 0);
549#undef xgboost_CHECK_SYS_CALL
551#if defined(xgboost_IS_MINGW)
552#undef xgboost_IS_MINGW
Address for TCP socket, can be either IPv4 or IPv6.
Definition socket.h:226
TCP socket for simple communication.
Definition socket.h:249
std::int32_t GetSockError() const
Get the last error code, if any.
Definition socket.h:318
auto SendAll(void const *buf, std::size_t len)
Send data; if no error occurs, all of the data is sent.
Definition socket.h:433
void Close()
Close the socket, called automatically in destructor if the socket is not closed.
Definition socket.h:504
void Listen(std::int32_t backlog=16)
Listen to incoming requests.
Definition socket.h:401
auto Domain() const -> SockDomain
Return the socket domain.
Definition socket.h:270
static TCPSocket Create(SockDomain domain)
Create a TCP socket on specified domain.
Definition socket.h:513
auto Recv(void *buf, std::size_t len, std::int32_t flags=0)
Receive data using the socket.
Definition socket.h:489
auto RecvAll(void *buf, std::size_t len)
Receive data; if no error occurs, all of the requested data is received.
Definition socket.h:452
TCPSocket Accept()
Accept new connection, returns a new TCP socket for the new connection.
Definition socket.h:372
auto Send(const void *buf_, std::size_t len, std::int32_t flags=0)
Send data using the socket.
Definition socket.h:478
in_port_t BindHost()
Bind socket to INADDR_ANY, return the port selected by the OS.
Definition socket.h:405
bool BadSocket() const
Check whether anything bad has happened to the socket.
Definition socket.h:326
HandleT const & Handle() const
Return the native socket file descriptor.
Definition socket.h:397
Copyright 2015-2023 by XGBoost Contributors.
Defines console logging options for XGBoost; used to enforce unified print behavior.
NLOHMANN_BASIC_JSON_TPL_DECLARATION void swap(nlohmann::NLOHMANN_BASIC_JSON_TPL &j1, nlohmann::NLOHMANN_BASIC_JSON_TPL &j2) noexcept(//NOLINT(readability-inconsistent-declaration-parameter-name, cert-dcl58-cpp) is_nothrow_move_constructible< nlohmann::NLOHMANN_BASIC_JSON_TPL >::value &&//NOLINT(misc-redundant-expression) is_nothrow_move_assignable< nlohmann::NLOHMANN_BASIC_JSON_TPL >::value)
exchanges the values of two JSON objects
Definition json.hpp:24418
std::error_code Connect(SockAddress const &addr, TCPSocket *out)
Connect to remote address, returns the error code if failed (no exception is raised so that we can re...
Definition socket.cc:74
std::string GetHostName()
Get the local host name.
Definition socket.h:541
SockAddress MakeSockAddress(StringView host, in_port_t port)
Parse host address and return a SockAddress instance.
Definition socket.cc:17
namespace of xgboost
Definition base.h:90
Definition string_view.h:15