Medial Code Documentation
Loading...
Searching...
No Matches
aggregator.h
1
8#pragma once
#include <xgboost/data.h>

#include <array>
#include <limits>
#include <string>
#include <utility>
#include <vector>

#include "communicator-inl.h"
17
18namespace xgboost {
19namespace collective {
20
35template <typename Function>
36void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&& function) {
37 if (info.IsVerticalFederated()) {
38 // We assume labels are only available on worker 0, so the calculation is done there and result
39 // broadcast to other workers.
40 std::string message;
41 if (collective::GetRank() == 0) {
42 try {
43 std::forward<Function>(function)();
44 } catch (dmlc::Error& e) {
45 message = e.what();
46 }
47 }
48
49 collective::Broadcast(&message, 0);
50 if (message.empty()) {
51 collective::Broadcast(buffer, size, 0);
52 } else {
53 LOG(FATAL) << &message[0];
54 }
55 } else {
56 std::forward<Function>(function)();
57 }
58}
59
71template <typename T>
72T GlobalMax(MetaInfo const& info, T value) {
73 if (info.IsRowSplit()) {
74 collective::Allreduce<collective::Operation::kMax>(&value, 1);
75 }
76 return value;
77}
78
90template <typename T>
91void GlobalSum(MetaInfo const& info, T* values, size_t size) {
92 if (info.IsRowSplit()) {
93 collective::Allreduce<collective::Operation::kSum>(values, size);
94 }
95}
96
97template <typename Container>
98void GlobalSum(MetaInfo const& info, Container* values) {
99 GlobalSum(info, values->data(), values->size());
100}
101
114template <typename T>
115T GlobalRatio(MetaInfo const& info, T dividend, T divisor) {
116 std::array<T, 2> results{dividend, divisor};
117 GlobalSum(info, &results);
118 std::tie(dividend, divisor) = std::tuple_cat(results);
119 if (divisor <= 0) {
120 return std::numeric_limits<T>::quiet_NaN();
121 } else {
122 return dividend / divisor;
123 }
124}
125
126} // namespace collective
127} // namespace xgboost
Meta information about the dataset; it always sits in memory.
Definition data.h:48
bool IsVerticalFederated() const
A convenient method to check if we are doing vertical federated learning, which requires some special...
Definition data.cc:807
bool IsRowSplit() const
Whether the data is split row-wise.
Definition data.h:184
Copyright 2015-2023 by XGBoost Contributors.
void Broadcast(void *send_receive_buffer, size_t size, int root)
Broadcast a memory region to all others from root. This function is NOT thread-safe.
Definition communicator-inl.h:129
void GlobalSum(MetaInfo const &info, T *values, size_t size)
Find the global sum of the given values across all workers.
Definition aggregator.h:91
T GlobalRatio(MetaInfo const &info, T dividend, T divisor)
Find the global ratio of the given two values across all workers.
Definition aggregator.h:115
void ApplyWithLabels(MetaInfo const &info, void *buffer, size_t size, Function &&function)
Apply the given function where the labels are.
Definition aggregator.h:36
int GetRank()
Get rank of current process.
Definition communicator-inl.h:76
T GlobalMax(MetaInfo const &info, T value)
Find the global max of the given value across all workers.
Definition aggregator.h:72
namespace of xgboost
Definition base.h:90
exception class that will be thrown by default logger if DMLC_LOG_FATAL_THROW == 1
Definition logging.h:29