Medial Code Documentation
xgboost.spark.utils Namespace Reference

Data Structures

class  CommunicatorContext
 

Functions

str get_class_name (Type cls)
 
Dict[str, Any] _get_default_params_from_func (Callable func, Set[str] unsupported_set)
 
Dict[str, Any] _start_tracker (BarrierTaskContext context, int n_workers)
 
Dict[str, Any] _get_rabit_args (BarrierTaskContext context, int n_workers)
 
str _get_host_ip (BarrierTaskContext context)
 
SparkSession _get_spark_session ()
 
logging.Logger get_logger (str name, str level="INFO")
 
int _get_max_num_concurrent_tasks (SparkContext spark_context)
 
bool _is_local (SparkContext spark_context)
 
bool _is_standalone_or_localcluster (SparkContext spark_context)
 
int _get_gpu_id (TaskContext task_context)
 
str _get_or_create_tmp_dir ()
 
XGBModel deserialize_xgb_model (str model, Callable[[], XGBModel] xgb_model_creator)
 
str serialize_booster (Booster booster)
 
Booster deserialize_booster (str model)
 
bool use_cuda (Optional[str] device)
 

Detailed Description

Helper-function submodule of the XGBoost PySpark integration.

Function Documentation

◆ _get_default_params_from_func()

Dict[str, Any] xgboost.spark.utils._get_default_params_from_func (Callable func, Set[str] unsupported_set)
protected
Returns a dictionary that maps each parameter of function func to its default value. Only
parameters that have a default value and whose names are not in unsupported_set are included.
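The filtering rule can be sketched with the standard-library inspect module. This is a minimal illustration of the documented behavior, not the library's code; default_params_from_func and example_fit are hypothetical names.

```python
import inspect
from typing import Any, Callable, Dict, Set


def default_params_from_func(func: Callable, unsupported_set: Set[str]) -> Dict[str, Any]:
    """Map each parameter of func that has a default value to that default,
    skipping any parameter whose name is in unsupported_set."""
    params: Dict[str, Any] = {}
    for param in inspect.signature(func).parameters.values():
        # Parameters without a default (including *args/**kwargs) report Parameter.empty.
        if param.default is inspect.Parameter.empty:
            continue
        # Names the caller marked as unsupported are dropped.
        if param.name in unsupported_set:
            continue
        params[param.name] = param.default
    return params


# Hypothetical constructor-like function, used only for illustration.
def example_fit(data, max_depth=6, learning_rate=0.3, callbacks=None):
    pass


print(default_params_from_func(example_fit, {"callbacks"}))
# {'max_depth': 6, 'learning_rate': 0.3}
```

Reading defaults from the signature keeps the resulting dictionary in sync with the wrapped function without maintaining a separate list of defaults.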

◆ _get_gpu_id()

int xgboost.spark.utils._get_gpu_id (TaskContext task_context)
protected
Gets the GPU id from the resources assigned to the current task.
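A minimal sketch of the lookup, assuming the task's resources are a mapping from resource name to an object with an addresses list (the shape of the ResourceInformation values returned by PySpark's TaskContext.resources()) and that the first GPU address is the one used. FakeResourceInfo and gpu_id_from_resources are hypothetical stand-ins so the example runs without Spark.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class FakeResourceInfo:
    """Hypothetical stand-in for a Spark ResourceInformation object."""
    name: str
    addresses: List[str] = field(default_factory=list)


def gpu_id_from_resources(resources: Dict[str, FakeResourceInfo]) -> int:
    """Return the first GPU address assigned to the task as an int (assumed rule)."""
    if "gpu" not in resources:
        raise RuntimeError(
            "No 'gpu' resource is assigned to this task; check the GPU resource configuration."
        )
    # Addresses are strings; strip whitespace before converting.
    return int(resources["gpu"].addresses[0].strip())


print(gpu_id_from_resources({"gpu": FakeResourceInfo("gpu", ["1", "2"])}))  # 1
```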

◆ _get_host_ip()

str xgboost.spark.utils._get_host_ip (BarrierTaskContext context)
protected
Gets the host IP for Spark, which is the IP address of the first worker in the barrier stage.
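A minimal sketch, assuming each barrier task reports its address as a "host:port" string (as BarrierTaskInfo.address does in PySpark) and that the list is ordered so the first entry is the first worker. first_worker_ip is hypothetical, and plain strings stand in for the task info objects.

```python
from typing import List


def first_worker_ip(task_addresses: List[str]) -> str:
    """Return the host part of the first "host:port" address.

    Assumes IPv4 or hostname addresses; an IPv6 literal would need different parsing.
    """
    hosts = [address.split(":")[0] for address in task_addresses]
    return hosts[0]


print(first_worker_ip(["10.0.0.5:41234", "10.0.0.6:41235"]))  # 10.0.0.5
```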

◆ _get_max_num_concurrent_tasks()

int xgboost.spark.utils._get_max_num_concurrent_tasks (SparkContext spark_context)
protected
Gets the current maximum number of tasks that the Spark cluster can run concurrently.

◆ _get_rabit_args()

Dict[str, Any] xgboost.spark.utils._get_rabit_args (BarrierTaskContext context, int n_workers)
protected
Gets the Rabit context arguments to send to each worker.

◆ _get_spark_session()

SparkSession xgboost.spark.utils._get_spark_session ()
protected
Gets or creates a Spark session. Note: this function can only be invoked from the driver
side.

◆ _is_local()

bool xgboost.spark.utils._is_local (SparkContext spark_context)
protected
Returns whether Spark is running in local mode.

◆ _start_tracker()

Dict[str, Any] xgboost.spark.utils._start_tracker (BarrierTaskContext context, int n_workers)
protected
Starts the Rabit tracker for n_workers workers.

◆ deserialize_booster()

Booster xgboost.spark.utils.deserialize_booster (str model)
Deserializes an xgboost.core.Booster from the input serialized model string, model.

◆ deserialize_xgb_model()

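XGBModel xgboost.spark.utils.deserialize_xgb_model (str model, Callable[[], XGBModel] xgb_model_creator)
Deserializes an xgboost.XGBModel instance from the input serialized model string, model.
xgb_model_creator is a zero-argument callable that returns the XGBModel instance to load the model into.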
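Passing a zero-argument factory lets the caller choose which model class, and which constructor arguments, receive the serialized model. The sketch below illustrates only that pattern: StandInModel and deserialize_with_creator are hypothetical, the stand-in replaces XGBModel so the example runs without XGBoost, and loading the model from its UTF-8 bytes is an assumption about how the string is consumed.

```python
from functools import partial
from typing import Callable, Optional


class StandInModel:
    """Hypothetical stand-in for xgboost.XGBModel that just records what it loads."""

    def __init__(self, n_estimators: int = 100) -> None:
        self.n_estimators = n_estimators
        self.raw: Optional[bytearray] = None

    def load_model(self, raw: bytearray) -> None:
        self.raw = raw


def deserialize_with_creator(model: str, creator: Callable[[], StandInModel]) -> StandInModel:
    instance = creator()  # fresh instance chosen and configured by the caller
    instance.load_model(bytearray(model.encode("utf-8")))  # assumed: load from UTF-8 bytes
    return instance


# functools.partial turns a configured constructor into a zero-argument creator.
restored = deserialize_with_creator('{"learner": {}}', partial(StandInModel, n_estimators=10))
print(restored.n_estimators, restored.raw)
```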

◆ get_class_name()

str xgboost.spark.utils.get_class_name (Type cls)
Returns the name of the class cls.

◆ get_logger()

logging.Logger xgboost.spark.utils.get_logger (str name, str level="INFO")
Gets the logger with the given name, creating and configuring it on first use.
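A minimal sketch of the get-or-create pattern using only the standard logging module. The stderr handler, the format string, and the rule of configuring only when the logger has no handlers are assumptions, and get_logger_sketch is a hypothetical name.

```python
import logging
import sys


def get_logger_sketch(name: str, level: str = "INFO") -> logging.Logger:
    # logging.getLogger returns the same Logger object for the same name.
    logger = logging.getLogger(name)
    logger.setLevel(level)  # accepts level names such as "INFO" or "DEBUG"
    # Attach a handler only the first time, so repeated calls don't duplicate output.
    if not logger.handlers:
        handler = logging.StreamHandler(sys.stderr)
        handler.setFormatter(
            logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
        )
        logger.addHandler(handler)
    return logger


log = get_logger_sketch("XGBoost-PySpark")
log.info("training started")
```

Because the handler is attached only once, calling the function from several places returns the same logger without duplicating output.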

◆ serialize_booster()

str xgboost.spark.utils.serialize_booster (Booster booster)
Serializes the input booster to a string.

Parameters
----------
booster:
    an xgboost.core.Booster instance

◆ use_cuda()

bool xgboost.spark.utils.use_cuda (Optional[str] device)
Returns whether XGBoost uses CUDA (GPU) workers for the given device setting.
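A minimal sketch of the check, assuming the device parameter uses the strings "cuda" or "gpu" to request CUDA and that any other value, including None, means CPU. use_cuda_sketch is a hypothetical name, and the accepted spellings are an assumption rather than a statement of the library's rule.

```python
from typing import Optional


def use_cuda_sketch(device: Optional[str]) -> bool:
    # Assumed: only these two spellings select CUDA; None or "cpu" does not.
    return device in ("cuda", "gpu")


for device in ("cuda", "gpu", "cpu", None):
    print(device, use_cuda_sketch(device))
```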