Medial Code Documentation
Loading...
Searching...
No Matches
socket_wrapper.hpp
1#ifndef LIGHTGBM_NETWORK_SOCKET_WRAPPER_HPP_
2#define LIGHTGBM_NETWORK_SOCKET_WRAPPER_HPP_
3#ifdef USE_SOCKET
4
#if defined(_WIN32)
#ifdef _MSC_VER
#define NOMINMAX
#endif
#include <winsock2.h>
#include <ws2tcpip.h>
#include <iphlpapi.h>

#else

#include <sys/types.h>
#include <fcntl.h>
#include <netdb.h>
#include <cerrno>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
#include <sys/time.h>
#include <ifaddrs.h>
#include <netinet/tcp.h>

#endif

#include <cstdlib>
#include <cstring>

#include <unordered_set>
#include <string>

#include <LightGBM/utils/log.h>
35
36#ifdef _MSC_VER
37#pragma comment(lib, "Ws2_32.lib")
38#pragma comment(lib, "IPHLPAPI.lib")
39#endif
40
41namespace LightGBM {
42
#ifndef _WIN32

// POSIX stand-ins for the Winsock names used throughout this header: sockets are plain
// file descriptors, an invalid descriptor is -1, and socket calls report failure with -1.
using SOCKET = int;
constexpr int INVALID_SOCKET = -1;
#define SOCKET_ERROR (-1)

#endif
50
#ifdef _WIN32
#ifndef _MSC_VER
// Fallback inet_pton for non-MSVC Windows toolchains (e.g. MinGW), built on
// WSAStringToAddressA.
//
// Converts the textual address `src` of family `af` into binary form and stores it in `dst`
// (an in_addr for AF_INET, an in6_addr for AF_INET6).
// Returns 1 on success, 0 if `src` is not a valid address of family `af` (or is too long to
// be one), and -1 with WSAEAFNOSUPPORT set if `af` is neither AF_INET nor AF_INET6 -- the
// same contract as POSIX inet_pton.
inline int inet_pton(int af, const char *src, void *dst) {
  if (af != AF_INET && af != AF_INET6) {
    WSASetLastError(WSAEAFNOSUPPORT);
    return -1;
  }
  // Reject input that would not fit the copy buffer instead of parsing a truncated prefix.
  if (strlen(src) > INET6_ADDRSTRLEN) {
    return 0;
  }

  struct sockaddr_storage ss;
  int size = sizeof(ss);
  char src_copy[INET6_ADDRSTRLEN + 1];

  ZeroMemory(&ss, sizeof(ss));
  // WSAStringToAddressA takes a non-const string, so hand it a bounded, terminated copy.
  strncpy(src_copy, src, INET6_ADDRSTRLEN);
  src_copy[INET6_ADDRSTRLEN] = '\0';

  // Call the ANSI variant explicitly: the generic WSAStringToAddress macro resolves to the
  // wide-char WSAStringToAddressW when UNICODE is defined, which does not accept a char buffer.
  if (WSAStringToAddressA(src_copy, af, NULL, reinterpret_cast<struct sockaddr *>(&ss), &size) != 0) {
    return 0;
  }
  if (af == AF_INET) {
    *static_cast<struct in_addr *>(dst) = reinterpret_cast<struct sockaddr_in *>(&ss)->sin_addr;
  } else {
    *static_cast<struct in6_addr *>(dst) = reinterpret_cast<struct sockaddr_in6 *>(&ss)->sin6_addr;
  }
  return 1;
}
#endif
#endif
78
#if defined(_WIN32)
// Win32 heap helpers used by the Windows GetAdaptersInfo() code path. They expand to Win32
// heap calls, so they are defined only on Windows rather than injecting such generic macro
// names into every translation unit that includes this header on other platforms.
#define MALLOC(x) HeapAlloc(GetProcessHeap(), 0, (x))
#define FREE(x) HeapFree(GetProcessHeap(), 0, (x))
#endif
81
namespace SocketConfig {
// Requested size in bytes of the kernel receive and send buffers (SO_RCVBUF / SO_SNDBUF).
const int kSocketBufferSize = 100 * 1000;
// Maximum receive size in bytes. Not referenced in this header -- presumably a per-call
// receive chunk limit used by callers (TODO confirm).
const int kMaxReceiveSize = 100 * 1000;
// Value handed to setsockopt(TCP_NODELAY); nonzero disables Nagle's algorithm.
// Must be int-sized: Linux rejects a TCP_NODELAY option shorter than sizeof(int) (such as a
// 1-byte bool) with EINVAL, which previously made the option silently fail to apply.
const int kNoDelay = 1;
}  // namespace SocketConfig
87
88class TcpSocket {
89public:
90 TcpSocket() {
91 sockfd_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
92 if (sockfd_ == INVALID_SOCKET) {
93 Log::Fatal("Socket construction error");
94 return;
95 }
96 ConfigSocket();
97 }
98
99 explicit TcpSocket(SOCKET socket) {
100 sockfd_ = socket;
101 if (sockfd_ == INVALID_SOCKET) {
102 Log::Fatal("Passed socket error");
103 return;
104 }
105 ConfigSocket();
106 }
107
108 TcpSocket(const TcpSocket &object) {
109 sockfd_ = object.sockfd_;
110 ConfigSocket();
111 }
112 ~TcpSocket() {
113 }
114 inline void SetTimeout(int timeout) {
115 setsockopt(sockfd_, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char*>(&timeout), sizeof(timeout));
116 }
117 inline void ConfigSocket() {
118 if (sockfd_ == INVALID_SOCKET) {
119 return;
120 }
121
122 if (setsockopt(sockfd_, SOL_SOCKET, SO_RCVBUF, reinterpret_cast<const char*>(&SocketConfig::kSocketBufferSize), sizeof(SocketConfig::kSocketBufferSize)) != 0) {
123 Log::Warning("Set SO_RCVBUF failed, please increase your net.core.rmem_max to 100k at least");
124 }
125
126 if (setsockopt(sockfd_, SOL_SOCKET, SO_SNDBUF, reinterpret_cast<const char*>(&SocketConfig::kSocketBufferSize), sizeof(SocketConfig::kSocketBufferSize)) != 0) {
127 Log::Warning("Set SO_SNDBUF failed, please increase your net.core.wmem_max to 100k at least");
128 }
129 if (setsockopt(sockfd_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<const char*>(&SocketConfig::kNoDelay), sizeof(SocketConfig::kNoDelay)) != 0) {
130 Log::Warning("Set TCP_NODELAY failed");
131 }
132 }
133
134 inline static void Startup() {
135#if defined(_WIN32)
136 WSADATA wsa_data;
137 if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
138 Log::Fatal("Socket error: WSAStartup error");
139 }
140 if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
141 WSACleanup();
142 Log::Fatal("Socket error: Winsock.dll version error");
143 }
144#else
145#endif
146 }
147 inline static void Finalize() {
148#if defined(_WIN32)
149 WSACleanup();
150#endif
151 }
152
153 inline static int GetLastError() {
154#if defined(_WIN32)
155 return WSAGetLastError();
156#else
157 return errno;
158#endif
159 }
160
161
162
163#if defined(_WIN32)
164 inline static std::unordered_set<std::string> GetLocalIpList() {
165 std::unordered_set<std::string> ip_list;
166 char buffer[512];
167 // get hostName
168 if (gethostname(buffer, sizeof(buffer)) == SOCKET_ERROR) {
169 Log::Fatal("Error code %d, when getting local host name", WSAGetLastError());
170 }
171 // push local ip
172 PIP_ADAPTER_INFO pAdapterInfo;
173 PIP_ADAPTER_INFO pAdapter = NULL;
174 DWORD dwRetVal = 0;
175 ULONG ulOutBufLen = sizeof(IP_ADAPTER_INFO);
176 pAdapterInfo = (IP_ADAPTER_INFO *)MALLOC(sizeof(IP_ADAPTER_INFO));
177 if (pAdapterInfo == NULL) {
178 Log::Fatal("GetAdaptersinfo error: allocating memory");
179 }
180 // Make an initial call to GetAdaptersInfo to get
181 // the necessary size into the ulOutBufLen variable
182 if (GetAdaptersInfo(pAdapterInfo, &ulOutBufLen) == ERROR_BUFFER_OVERFLOW) {
183 FREE(pAdapterInfo);
184 pAdapterInfo = (IP_ADAPTER_INFO *)MALLOC(ulOutBufLen);
185 if (pAdapterInfo == NULL) {
186 Log::Fatal("GetAdaptersinfo error: allocating memory");
187 }
188 }
189 if ((dwRetVal = GetAdaptersInfo(pAdapterInfo, &ulOutBufLen)) == NO_ERROR) {
190 pAdapter = pAdapterInfo;
191 while (pAdapter) {
192 ip_list.insert(pAdapter->IpAddressList.IpAddress.String);
193 pAdapter = pAdapter->Next;
194 }
195 } else {
196 Log::Fatal("GetAdaptersinfo error: code %d", dwRetVal);
197 }
198 if (pAdapterInfo)
199 FREE(pAdapterInfo);
200 return ip_list;
201 }
202#else
203 inline static std::unordered_set<std::string> GetLocalIpList() {
204 std::unordered_set<std::string> ip_list;
205 struct ifaddrs * ifAddrStruct = NULL;
206 struct ifaddrs * ifa = NULL;
207 void * tmpAddrPtr = NULL;
208
209 getifaddrs(&ifAddrStruct);
210
211 for (ifa = ifAddrStruct; ifa != NULL; ifa = ifa->ifa_next) {
212 if (!ifa->ifa_addr) {
213 continue;
214 }
215 if (ifa->ifa_addr->sa_family == AF_INET) {
216 tmpAddrPtr = &((struct sockaddr_in *)ifa->ifa_addr)->sin_addr;
217 char addressBuffer[INET_ADDRSTRLEN];
218 inet_ntop(AF_INET, tmpAddrPtr, addressBuffer, INET_ADDRSTRLEN);
219 ip_list.insert(std::string(addressBuffer));
220 }
221 }
222 if (ifAddrStruct != NULL) freeifaddrs(ifAddrStruct);
223 return ip_list;
224 }
225#endif
226 inline static sockaddr_in GetAddress(const char* url, int port) {
227 sockaddr_in addr = sockaddr_in();
228 std::memset(&addr, 0, sizeof(sockaddr_in));
229 inet_pton(AF_INET, url, &addr.sin_addr);
230 addr.sin_family = AF_INET;
231 addr.sin_port = htons(static_cast<u_short>(port));
232 return addr;
233 }
234
235 inline bool Bind(int port) {
236 sockaddr_in local_addr = GetAddress("0.0.0.0", port);
237 if (bind(sockfd_, reinterpret_cast<const sockaddr*>(&local_addr), sizeof(sockaddr_in)) == 0) {
238 return true;
239 }
240 return false;
241 }
242
243 inline bool Connect(const char *url, int port) {
244 sockaddr_in server_addr = GetAddress(url, port);
245 if (connect(sockfd_, reinterpret_cast<const sockaddr*>(&server_addr), sizeof(sockaddr_in)) == 0) {
246 return true;
247 }
248 return false;
249 }
250
251 inline void Listen(int backlog = 128) {
252 listen(sockfd_, backlog);
253 }
254
255 inline TcpSocket Accept() {
256 SOCKET newfd = accept(sockfd_, NULL, NULL);
257 if (newfd == INVALID_SOCKET) {
258 Log::Fatal("Socket accept error, code: %d", GetLastError());
259 }
260 return TcpSocket(newfd);
261 }
262
263 inline int Send(const char *buf_, int len, int flag = 0) {
264 int cur_cnt = send(sockfd_, buf_, len, flag);
265 if (cur_cnt == SOCKET_ERROR) {
266 Log::Fatal("Socket send error, code: %d", GetLastError());
267 }
268 return cur_cnt;
269 }
270
271 inline int Recv(char *buf_, int len, int flags = 0) {
272 int cur_cnt = recv(sockfd_, buf_ , len , flags);
273 if (cur_cnt == SOCKET_ERROR) {
274 Log::Fatal("Socket recv error, code: %d", GetLastError());
275 }
276 return cur_cnt;
277 }
278
279 inline bool IsClosed() {
280 return sockfd_ == INVALID_SOCKET;
281 }
282
283 inline void Close() {
284 if (!IsClosed()) {
285#if defined(_WIN32)
286 closesocket(sockfd_);
287#else
288 close(sockfd_);
289#endif
290 sockfd_ = INVALID_SOCKET;
291 }
292 }
293
294private:
295 SOCKET sockfd_;
296};
297
298} // namespace LightGBM
299#endif // USE_SOCKET
300#endif // LightGBM_NETWORK_SOCKET_WRAPPER_HPP_
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
bool Finalize()
finalizes the engine module
Definition engine.cc:55
std::error_code Connect(SockAddress const &addr, TCPSocket *out)
Connect to remote address, returns the error code if failed (no exception is raised so that we can re...
Definition socket.cc:74