xgboost::collective Namespace Reference

Data Structures

class  AllgatherFunctor
 Functor for allgather. More...
 
struct  AllgatherVResult
 
class  AllreduceFunctor
 Functor for allreduce. More...
 
class  BroadcastFunctor
 Functor for broadcast. More...
 
class  Communicator
 A communicator class that handles collective communication. More...
 
class  CommunicatorContext
 
class  FederatedCommunicator
 A Federated Learning communicator class that handles collective communication. More...
 
class  FederatedCommunicatorTest
 
class  InMemoryCommunicator
 An in-memory communicator, useful for testing. More...
 
class  InMemoryCommunicatorTest
 
class  InMemoryHandler
 Handles collective communication primitives in memory. More...
 
class  NoOpCommunicator
 A no-op communicator, used for non-distributed training. More...
 
class  Op
 
class  RabitCommunicator
 
class  SockAddress
 Address for TCP socket, can be either IPv4 or IPv6. More...
 
class  SockAddrV4
 
class  SockAddrV6
 
class  TCPSocket
 TCP socket for simple communication. More...
 

Enumerations

enum class  SockDomain : std::int32_t { kV4 = AF_INET , kV6 = AF_INET6 }
 
enum class  DataType {
  kInt8 = 0 , kUInt8 = 1 , kInt32 = 2 , kUInt32 = 3 ,
  kInt64 = 4 , kUInt64 = 5 , kFloat = 6 , kDouble = 7
}
 Defines the integral and floating data types.
 
enum class  Operation {
  kMax = 0 , kMin = 1 , kSum = 2 , kBitwiseAND = 3 ,
  kBitwiseOR = 4 , kBitwiseXOR = 5
}
 Defines the reduction operation.
 
enum class  CommunicatorType {
  kUnknown , kRabit , kFederated , kInMemory ,
  kInMemoryNccl
}
 

Functions

SockAddress MakeSockAddress (StringView host, in_port_t port)
 Parse host address and return a SockAddress instance.
 
std::error_code Connect (SockAddress const &addr, TCPSocket *out)
 Connect to remote address, returns the error code if failed (no exception is raised so that we can retry).
 
std::string GetHostName ()
 Get the local host name.
 
None init (**Any args)
 
None finalize ()
 
int get_rank ()
 
int get_world_size ()
 
int is_distributed ()
 
None communicator_print (Any msg)
 
str get_processor_name ()
 
_T broadcast (_T data, int root)
 
np.ndarray allreduce (np.ndarray data, Op op)
 
template<typename Function >
void ApplyWithLabels (MetaInfo const &info, void *buffer, size_t size, Function &&function)
 Apply the given function where the labels are.
 
template<typename T >
T GlobalMax (MetaInfo const &info, T value)
 Find the global max of the given value across all workers.
 
template<typename T >
void GlobalSum (MetaInfo const &info, T *values, size_t size)
 Find the global sum of the given values across all workers.
 
template<typename Container >
void GlobalSum (MetaInfo const &info, Container *values)
 
template<typename T >
T GlobalRatio (MetaInfo const &info, T dividend, T divisor)
 Find the global ratio of the given two values across all workers.
 
void Init (Json const &config)
 Initialize the collective communicator.
 
void Finalize ()
 Finalize the collective communicator.
 
int GetRank ()
 Get rank of current process.
 
int GetWorldSize ()
 Get total number of processes.
 
bool IsDistributed ()
 Get if the communicator is distributed.
 
bool IsFederated ()
 Get if the communicator is federated.
 
void Print (char const *message)
 Print the message to the communicator.
 
void Print (std::string const &message)
 
std::string GetProcessorName ()
 Get the name of the processor.
 
void Broadcast (void *send_receive_buffer, size_t size, int root)
 Broadcast a memory region to all others from root. This function is NOT thread-safe.
 
void Broadcast (std::string *sendrecv_data, int root)
 
void Allgather (void *send_receive_buffer, std::size_t size)
 Gathers data from all processes and distributes it to all processes.
 
void Allreduce (void *send_receive_buffer, size_t count, int data_type, int op)
 Perform in-place allreduce. This function is NOT thread-safe.
 
void Allreduce (void *send_receive_buffer, size_t count, DataType data_type, Operation op)
 
template<Operation op>
void Allreduce (int8_t *send_receive_buffer, size_t count)
 
template<Operation op>
void Allreduce (uint8_t *send_receive_buffer, size_t count)
 
template<Operation op>
void Allreduce (int32_t *send_receive_buffer, size_t count)
 
template<Operation op>
void Allreduce (uint32_t *send_receive_buffer, size_t count)
 
template<Operation op>
void Allreduce (int64_t *send_receive_buffer, size_t count)
 
template<Operation op>
void Allreduce (uint64_t *send_receive_buffer, size_t count)
 
template<Operation op, typename T , typename = std::enable_if_t<std::is_same<size_t, T>{} && !std::is_same<uint64_t, T>{}>>
void Allreduce (T *send_receive_buffer, size_t count)
 
template<Operation op>
void Allreduce (float *send_receive_buffer, size_t count)
 
template<Operation op>
void Allreduce (double *send_receive_buffer, size_t count)
 
template<typename T >
AllgatherVResult< T > AllgatherV (std::vector< T > const &inputs, std::vector< std::size_t > const &sizes)
 Gathers variable-length data from all processes and distributes it to all processes.
 
std::size_t GetTypeSize (DataType data_type)
 Get the size of the data type.
 
int CompareStringsCaseInsensitive (const char *s1, const char *s2)
 Case-insensitive string comparison.
 
 TEST (CommunicatorFactory, TypeFromEnv)
 
 TEST (CommunicatorFactory, TypeFromArgs)
 
 TEST (CommunicatorFactory, TypeFromArgsUpperCase)
 
 TEST (InMemoryCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall)
 
 TEST (InMemoryCommunicatorSimpleTest, ThrowOnRankTooSmall)
 
 TEST (InMemoryCommunicatorSimpleTest, ThrowOnRankTooBig)
 
 TEST (InMemoryCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger)
 
 TEST (InMemoryCommunicatorSimpleTest, ThrowOnRankNotInteger)
 
 TEST (InMemoryCommunicatorSimpleTest, GetWorldSizeAndRank)
 
 TEST (InMemoryCommunicatorSimpleTest, IsDistributed)
 
 TEST_F (InMemoryCommunicatorTest, Allgather)
 
 TEST_F (InMemoryCommunicatorTest, AllreduceMax)
 
 TEST_F (InMemoryCommunicatorTest, AllreduceMin)
 
 TEST_F (InMemoryCommunicatorTest, AllreduceSum)
 
 TEST_F (InMemoryCommunicatorTest, AllreduceBitwiseAND)
 
 TEST_F (InMemoryCommunicatorTest, AllreduceBitwiseOR)
 
 TEST_F (InMemoryCommunicatorTest, AllreduceBitwiseXOR)
 
 TEST_F (InMemoryCommunicatorTest, Broadcast)
 
 TEST_F (InMemoryCommunicatorTest, Mixture)
 
 TEST (RabitCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall)
 
 TEST (RabitCommunicatorSimpleTest, ThrowOnRankTooSmall)
 
 TEST (RabitCommunicatorSimpleTest, ThrowOnRankTooBig)
 
 TEST (RabitCommunicatorSimpleTest, GetWorldSizeAndRank)
 
 TEST (RabitCommunicatorSimpleTest, IsNotDistributed)
 
 TEST (Socket, Basic)
 
 TEST (FederatedCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall)
 
 TEST (FederatedCommunicatorSimpleTest, ThrowOnRankTooSmall)
 
 TEST (FederatedCommunicatorSimpleTest, ThrowOnRankTooBig)
 
 TEST (FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger)
 
 TEST (FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger)
 
 TEST (FederatedCommunicatorSimpleTest, GetWorldSizeAndRank)
 
 TEST (FederatedCommunicatorSimpleTest, IsDistributed)
 
 TEST_F (FederatedCommunicatorTest, Allgather)
 
 TEST_F (FederatedCommunicatorTest, Allreduce)
 
 TEST_F (FederatedCommunicatorTest, Broadcast)
 

Variables

 LOGGER = logging.getLogger("[xgboost.collective]")
 
dict DTYPE_ENUM__
 

Detailed Description

XGBoost collective communication related API.

Copyright 2022 XGBoost contributors

Function Documentation

◆ Allgather()

void xgboost::collective::Allgather ( void *  send_receive_buffer,
std::size_t  size 
)
inline

Gathers data from all processes and distributes it to all processes.

This assumes all ranks contribute data of the same size, and that each rank has already written its own data into the corresponding slice of the buffer.

Parameters
send_receive_buffer  Buffer storing the data.
size  Size of the data in bytes.
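
Example (a minimal sketch, not taken from the library docs; it assumes the communicator has already been initialized with Init() and that every worker runs the same code):

    // Each worker writes its rank into its own slice, then gathers the rest.
    std::vector<std::int32_t> buffer(GetWorldSize(), 0);
    buffer[GetRank()] = GetRank();
    Allgather(buffer.data(), buffer.size() * sizeof(std::int32_t));
    // buffer now holds {0, 1, ..., world_size - 1} on every worker.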

◆ AllgatherV()

template<typename T >
AllgatherVResult< T > xgboost::collective::AllgatherV ( std::vector< T > const &  inputs,
std::vector< std::size_t > const &  sizes 
)
inline

Gathers variable-length data from all processes and distributes it to all processes.

We assume each worker has the same number of inputs, but each input may be of a different size.

Parameters
inputs  All the inputs from the local worker.
sizes  Sizes of each input.
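
Example (a minimal sketch for illustration only; the exact contents of AllgatherVResult are documented with that struct):

    // Each worker contributes a single input whose length depends on its rank.
    std::vector<std::int64_t> values(GetRank() + 1, GetRank());
    std::vector<std::size_t> sizes{values.size()};
    auto gathered = AllgatherV(values, sizes);
    // gathered holds the concatenation of every worker's values.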

◆ allreduce()

np.ndarray xgboost.collective.allreduce ( np.ndarray  data,
Op  op 
)
Perform allreduce, return the result.

Parameters
----------
data :
    Input data.
op :
    Reduction operator.

Returns
-------
result :
    The result of allreduce, with the same shape as data.

Notes
-----
This function is not thread-safe.

◆ Allreduce()

void xgboost::collective::Allreduce ( void *  send_receive_buffer,
size_t  count,
int  data_type,
int  op 
)
inline

Perform in-place allreduce. This function is NOT thread-safe.

Example Usage: the following code computes the element-wise sum across all workers:

    std::vector<int> data(10);
    ...
    Allreduce(&data[0], data.size(), DataType::kInt32, Operation::kSum);
    ...

Parameters
send_receive_buffer  Buffer for both sending and receiving data.
count  Number of elements to be reduced.
data_type  Enumeration of data type, see xgboost::collective::DataType in communicator.h.
op  Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
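
Example (a minimal sketch using the typed template overloads listed above, which avoid passing DataType explicitly; the histogram values are hypothetical):

    std::vector<double> hist(16, 1.0);  // per-worker partial histogram
    Allreduce<Operation::kSum>(hist.data(), hist.size());
    // Every element of hist now equals the number of workers.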

◆ ApplyWithLabels()

template<typename Function >
void xgboost::collective::ApplyWithLabels ( MetaInfo const &  info,
void *  buffer,
size_t  size,
Function &&  function 
)

Apply the given function where the labels are.

Normally all the workers have access to the labels, so the function is just applied locally. In vertical federated learning, we assume labels are only available on worker 0, so the function is applied there, with the results broadcast to other workers.

Template Parameters
Function  The function used to calculate the results.
Args  Arguments to the function.
Parameters
info  MetaInfo about the DMatrix.
buffer  The buffer storing the results.
size  The size of the buffer.
function  The function used to calculate the results.
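
Example (a minimal sketch; `info` is assumed to be the MetaInfo of the locally loaded DMatrix, and the computation inside the lambda is a stand-in for real label-dependent work):

    double stat = 0.0;
    ApplyWithLabels(info, &stat, sizeof(stat), [&] {
      stat = static_cast<double>(info.num_row_);  // stand-in for a label-dependent computation
    });
    // stat now holds the value computed where the labels are available.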

◆ broadcast()

_T xgboost.collective.broadcast ( _T  data,
int  root 
)
Broadcast object from one node to all other nodes.

Parameters
----------
data : any type that can be pickled
    Input data; if the current rank does not equal root, this can be None.
root : int
    Rank of the node to broadcast data from.

Returns
-------
object :
    The result of the broadcast.

◆ Broadcast()

void xgboost::collective::Broadcast ( void *  send_receive_buffer,
size_t  size,
int  root 
)
inline

Broadcast a memory region to all others from root. This function is NOT thread-safe.

Example:

    int a = 1;
    Broadcast(&a, sizeof(a), root);

Parameters
send_receive_buffer  Pointer to the send or receive buffer.
size  Size of the data.
root  The process rank to broadcast from.
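
Example (a minimal sketch using the std::string overload listed above; the payload is a placeholder):

    std::string blob;
    if (GetRank() == 0) {
      blob = "model-bytes";  // prepared only on the root
    }
    Broadcast(&blob, 0);
    // blob now holds "model-bytes" on every worker.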

◆ communicator_print()

None xgboost.collective.communicator_print ( Any  msg)
Print message to the communicator.

This function can be used to report progress information to the user who
monitors the communicator.

Parameters
----------
msg : str
    The message to be printed to the communicator.

◆ finalize()

None xgboost.collective.finalize ( )
Finalize the communicator.

◆ Finalize()

void xgboost::collective::Finalize ( void  )
inline

Finalize the collective communicator.

Call this function after you have finished all jobs.

◆ get_processor_name()

str xgboost.collective.get_processor_name ( )
Get the processor name.

Returns
-------
name : str
    The name of the processor (host).

◆ get_rank()

int xgboost.collective.get_rank ( )
Get rank of current process.

Returns
-------
rank : int
    Rank of current process.

◆ get_world_size()

int xgboost.collective.get_world_size ( )
Get the total number of workers.

Returns
-------
n : int
    Total number of processes.

◆ GetProcessorName()

std::string xgboost::collective::GetProcessorName ( )
inline

Get the name of the processor.

Returns
Name of the processor.

◆ GetRank()

int xgboost::collective::GetRank ( )
inline

Get rank of current process.

Returns
Rank of the worker.

◆ GetWorldSize()

int xgboost::collective::GetWorldSize ( )
inline

Get total number of processes.

Returns
Total world size.

◆ GlobalMax()

template<typename T >
T xgboost::collective::GlobalMax ( MetaInfo const &  info,
T  value 
)

Find the global max of the given value across all workers.

This only applies when the data is split row-wise (horizontally). When data is split column-wise (vertically), the local value is returned.

Template Parameters
T  The type of the value.
Parameters
info  MetaInfo about the DMatrix.
value  The input for finding the global max.
Returns
The global max of the input.
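
Example (a minimal sketch; `info` is assumed to be the MetaInfo of the local data partition and the local value is hypothetical):

    double local_max = 0.5;  // per-worker value
    double global_max = GlobalMax(info, local_max);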

◆ GlobalRatio()

template<typename T >
T xgboost::collective::GlobalRatio ( MetaInfo const &  info,
T  dividend,
T  divisor 
)

Find the global ratio of the given two values across all workers.

This only applies when the data is split row-wise (horizontally). When data is split column-wise (vertically), the local ratio is returned.

Template Parameters
T  The type of the values.
Parameters
info  MetaInfo about the DMatrix.
dividend  The dividend of the ratio.
divisor  The divisor of the ratio.
Returns
The global ratio of the two inputs.
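
Example (a minimal sketch; `info` is assumed to be the MetaInfo of the local data partition and the partial sums are hypothetical, e.g. to form a globally consistent mean):

    double local_sum = 42.0;    // per-worker partial sum
    double local_count = 10.0;  // per-worker partial count
    double global_mean = GlobalRatio(info, local_sum, local_count);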

◆ GlobalSum()

template<typename T >
void xgboost::collective::GlobalSum ( MetaInfo const &  info,
T *  values,
size_t  size 
)

Find the global sum of the given values across all workers.

This only applies when the data is split row-wise (horizontally). When data is split column-wise (vertically), the original values are returned.

Template Parameters
T  The type of the values.
Parameters
info  MetaInfo about the DMatrix.
values  Pointer to the inputs to sum.
size  Number of values to sum.
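
Example (a minimal sketch; `info` is assumed to be the MetaInfo of the local data partition and the statistics are hypothetical):

    double stats[2] = {1.0, 2.0};  // local partial sums
    GlobalSum(info, stats, 2);
    // stats now holds the element-wise sums across all row-wise workers.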

◆ init()

None xgboost.collective.init ( **Any  args)
Initialize the collective library with arguments.

Parameters
----------
args: Dict[str, Any]
    Keyword arguments representing the parameters and their values.

    Accepted parameters:
      - xgboost_communicator: The type of the communicator. Can be set as an environment
        variable.
        * rabit: Use Rabit. This is the default if the type is unspecified.
        * federated: Use the gRPC interface for Federated Learning.
    Only applicable to the Rabit communicator (these are case-sensitive):
      -- rabit_tracker_uri: Hostname of the tracker.
      -- rabit_tracker_port: Port number of the tracker.
      -- rabit_task_id: ID of the current task, can be used to obtain deterministic rank
         assignment.
      -- rabit_world_size: Total number of workers.
      -- rabit_hadoop_mode: Enable Hadoop support.
      -- rabit_tree_reduce_minsize: Minimal size for tree reduce.
      -- rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
      -- rabit_reduce_buffer: Size of the reduce buffer.
      -- rabit_bootstrap_cache: Size of the bootstrap cache.
      -- rabit_debug: Enable debugging.
      -- rabit_timeout: Enable timeout.
      -- rabit_timeout_sec: Timeout in seconds.
      -- rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
    Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
    environment variables):
      -- DMLC_TRACKER_URI: Hostname of the tracker.
      -- DMLC_TRACKER_PORT: Port number of the tracker.
      -- DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank
         assignment.
      -- DMLC_ROLE: Role of the current task, "worker" or "server".
      -- DMLC_NUM_ATTEMPT: Number of attempts after task failure.
      -- DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
    Only applicable to the Federated communicator (use upper case for environment variables, use
    lower case for runtime configuration):
      -- federated_server_address: Address of the federated server.
      -- federated_world_size: Number of federated workers.
      -- federated_rank: Rank of the current worker.
      -- federated_server_cert: Server certificate file path. Only needed for the SSL mode.
      -- federated_client_key: Client key file path. Only needed for the SSL mode.
      -- federated_client_cert: Client certificate file path. Only needed for the SSL mode.

◆ Init()

void xgboost::collective::Init ( Json const &  config)
inline

Initialize the collective communicator.

Currently the communicator API is experimental, function signatures may change in the future without notice.

Call this once before using anything.

The additional configuration is not required. Usually the communicator will detect settings from environment variables.

Parameters
json_config  JSON encoded configuration. Accepted JSON keys are:
  • xgboost_communicator: The type of the communicator. Can be set as an environment variable.
    • rabit: Use Rabit. This is the default if the type is unspecified.
    • mpi: Use MPI.
    • federated: Use the gRPC interface for Federated Learning.

  Only applicable to the Rabit communicator (these are case-sensitive):
  • rabit_tracker_uri: Hostname of the tracker.
  • rabit_tracker_port: Port number of the tracker.
  • rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
  • rabit_world_size: Total number of workers.
  • rabit_hadoop_mode: Enable Hadoop support.
  • rabit_tree_reduce_minsize: Minimal size for tree reduce.
  • rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
  • rabit_reduce_buffer: Size of the reduce buffer.
  • rabit_bootstrap_cache: Size of the bootstrap cache.
  • rabit_debug: Enable debugging.
  • rabit_timeout: Enable timeout.
  • rabit_timeout_sec: Timeout in seconds.
  • rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.

  Only applicable to the Rabit communicator (these are case-sensitive, and can be set as environment variables):
  • DMLC_TRACKER_URI: Hostname of the tracker.
  • DMLC_TRACKER_PORT: Port number of the tracker.
  • DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
  • DMLC_ROLE: Role of the current task, "worker" or "server".
  • DMLC_NUM_ATTEMPT: Number of attempts after task failure.
  • DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.

  Only applicable to the Federated communicator (use upper case for environment variables, use lower case for runtime configuration):
  • federated_server_address: Address of the federated server.
  • federated_world_size: Number of federated workers.
  • federated_rank: Rank of the current worker.
  • federated_server_cert: Server certificate file path. Only needed for the SSL mode.
  • federated_client_key: Client key file path. Only needed for the SSL mode.
  • federated_client_cert: Client certificate file path. Only needed for the SSL mode.
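
Example (a minimal end-to-end sketch; the tracker host and port are placeholders, and the JSON literal is parsed with xgboost's Json::Load purely for illustration):

    auto config = Json::Load(StringView{R"({
        "xgboost_communicator": "rabit",
        "rabit_tracker_uri": "127.0.0.1",
        "rabit_tracker_port": 9091
      })"});
    Init(config);
    Print("worker " + std::to_string(GetRank()) + " of " +
          std::to_string(GetWorldSize()) + "\n");
    Finalize();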

◆ is_distributed()

int xgboost.collective.is_distributed ( )
Return whether the collective communicator is distributed.

◆ IsDistributed()

bool xgboost::collective::IsDistributed ( )
inline

Get if the communicator is distributed.

Returns
True if the communicator is distributed.

◆ IsFederated()

bool xgboost::collective::IsFederated ( )
inline

Get if the communicator is federated.

Returns
True if the communicator is federated.

◆ MakeSockAddress()

SockAddress xgboost::collective::MakeSockAddress ( StringView  host,
in_port_t  port 
)

Parse host address and return a SockAddress instance.

Supports both IPv4 and IPv6 hosts.
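
Example (a minimal sketch; the port number is arbitrary):

    auto v4 = MakeSockAddress(StringView{"127.0.0.1"}, 9091);
    auto v6 = MakeSockAddress(StringView{"::1"}, 9091);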

◆ Print()

void xgboost::collective::Print ( char const *  message)
inline

Print the message to the communicator.

This function can be used to report progress information to the user who monitors the communicator.

Parameters
message  The message to be printed.

Variable Documentation

◆ DTYPE_ENUM__

dict xgboost.collective.DTYPE_ENUM__
Initial value:
= {
    np.dtype("int8"): 0,
    np.dtype("uint8"): 1,
    np.dtype("int32"): 2,
    np.dtype("uint32"): 3,
    np.dtype("int64"): 4,
    np.dtype("uint64"): 5,
    np.dtype("float32"): 6,
    np.dtype("float64"): 7,
}