
Survival Random Forest predict_survival_function does not scale with n_jobs #382

@PierrickPochelu

Description

My laptop has a multi-core CPU. Increasing n_jobs improves the speed of fit and predict, but it does not improve the speed of predict_survival_function.

Code Sample to Reproduce the Bug

import numpy as np
import pandas as pd
import time

np.random.seed(42)

def create_data(nb_events, nb_features):
    np_X=np.random.rand(nb_events, nb_features)
    np_time=np.random.rand(nb_events, 1)
    np_is_living=np_X[:,0] < np_time[:,0]
    y=np.empty(nb_events, dtype=[('event', '?'), ('time', '<f8')])  # float64 times; '<f16' is not available on every platform
    y['event']=np_is_living.reshape(-1)
    y['time']=np_time.reshape(-1)
    X=pd.DataFrame(np_X,columns=['f'+str(i) for i in range(1,nb_features+1)])
    return X, y

X_train,y_train=create_data(nb_events=150, nb_features=8)
X_test,y_test=create_data(nb_events=150, nb_features=8)

from sksurv.ensemble import RandomSurvivalForest
rsf=RandomSurvivalForest(random_state=42, n_jobs=8) #<------------- Increasing n_jobs does not improve predict_survival_function speed


print("Fitting ....")
st=time.time()
rsf.fit(X_train,y_train)
print(f"Fit time:{time.time()-st}")

st=time.time()
for i in range(100):
    pred=rsf.predict_survival_function(X_test)
print(f"Predict time:{time.time()-st}")
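To compare several n_jobs values in a single run, the timing loop above can be wrapped in a small helper. This is only a sketch: `time_per_setting` is a hypothetical helper (not part of scikit-survival) that reports the median of several calls using `time.perf_counter`, and the commented usage assumes the forest reads `n_jobs` at prediction time, as scikit-learn forests do, so it can be changed with `set_params` after fitting.

```python
import statistics
import time
from typing import Callable, Dict, Iterable


def time_per_setting(
    run: Callable[[int], object], settings: Iterable[int], repeats: int = 5
) -> Dict[int, float]:
    """Call run(setting) `repeats` times for each setting and return the
    median wall-clock time in seconds, keyed by setting."""
    timings: Dict[int, float] = {}
    for setting in settings:
        samples = []
        for _ in range(repeats):
            start = time.perf_counter()  # monotonic, high-resolution timer
            run(setting)
            samples.append(time.perf_counter() - start)
        timings[setting] = statistics.median(samples)
    return timings


# Hypothetical usage with the reproducer above:
# def run(n_jobs):
#     rsf.set_params(n_jobs=n_jobs)
#     rsf.predict_survival_function(X_test)
#
# for n_jobs, seconds in time_per_setting(run, [1, 2, 4, 8]).items():
#     print(f"n_jobs={n_jobs}: {seconds:.4f} s per call")
```

Using the median of several calls makes the comparison less sensitive to one-off pauses than a single total time.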

Expected Results
Compared to n_jobs=1, n_jobs=8 should theoretically divide the computing time by 8; at the very least, I would expect it to be halved.
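For reference, Amdahl's law gives the ideal speedup on n workers when a fraction p of the runtime parallelizes perfectly: S(n) = 1 / ((1 - p) + p / n). The sketch below (helper names are illustrative) evaluates it and solves for the p needed to reach a target speedup. It ignores scheduling and communication overhead, so it is an upper bound that never drops below 1; a measured slowdown would therefore suggest that overhead outweighs the parallel gain, not merely that the parallel fraction is small.

```python
def amdahl_speedup(p: float, n: int) -> float:
    """Ideal speedup on n workers when a fraction p of the runtime parallelizes perfectly."""
    return 1.0 / ((1.0 - p) + p / n)


def required_parallel_fraction(target: float, n: int) -> float:
    """Smallest p with amdahl_speedup(p, n) == target (requires 1 <= target <= n, n > 1)."""
    # Solve 1 / ((1 - p) + p / n) = target for p.
    return (1.0 - 1.0 / target) / (1.0 - 1.0 / n)


print(amdahl_speedup(1.0, 8))                         # prints 8.0
print(round(required_parallel_fraction(2, 8), 3))     # prints 0.571
```

So the "at least halved" expectation with 8 workers corresponds to at least about 57% (4/7) of the prediction time being parallelizable.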

Actual Results
predict_survival_function is slower with n_jobs=8 than with n_jobs=1.
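Until this is resolved, one user-side workaround worth trying is to split the test rows into chunks and predict the chunks concurrently, leaving the forest itself at n_jobs=1 to avoid nested parallelism. This is a hedged sketch: `parallel_predict` is a hypothetical helper, it assumes each row's prediction is independent of the others, and threads only help if the heavy work releases the GIL; if it does not, a process pool would be needed at the cost of pickling the model.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable

import numpy as np


def _take_rows(X: Any, start: int, stop: int) -> Any:
    # Positional row slice that works for both pandas DataFrames and NumPy arrays.
    return X.iloc[start:stop] if hasattr(X, "iloc") else X[start:stop]


def parallel_predict(predict_fn: Callable[[Any], np.ndarray], X: Any, n_jobs: int) -> np.ndarray:
    """Split X into up to n_jobs contiguous row chunks, run predict_fn on each
    chunk in a thread, and concatenate the results in the original row order."""
    n_rows = len(X)
    n_chunks = max(1, min(n_jobs, n_rows))
    # Integer chunk boundaries from 0 to n_rows, as even as possible.
    bounds = [i * n_rows // n_chunks for i in range(n_chunks + 1)]
    chunks = [_take_rows(X, a, b) for a, b in zip(bounds[:-1], bounds[1:])]
    with ThreadPoolExecutor(max_workers=n_chunks) as pool:
        parts = list(pool.map(predict_fn, chunks))  # map preserves chunk order
    return np.concatenate(parts, axis=0)
```

With the reproducer above, a possible call is `parallel_predict(lambda chunk: rsf.predict_survival_function(chunk, return_array=True), X_test, n_jobs=8)`; with `return_array=True` each chunk yields a 2D array of survival probabilities, which concatenates row-wise. Whether this is actually faster should be checked with the same timing loop.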

Versions
System:
python: 3.10.9 (main, Jan 11 2023, 15:21:40) [GCC 11.2.0]
executable: /home/pierrick/PycharmProjects/venv/bin/python
machine: Linux-5.19.0-45-generic-x86_64-with-glibc2.35

Python dependencies:
sklearn: 1.2.2
pip: 22.3.1
setuptools: 65.5.1
numpy: 1.24.3
scipy: 1.10.1
Cython: None
pandas: 2.0.2
matplotlib: 3.7.1
joblib: 1.2.0
threadpoolctl: 3.1.0

Built with OpenMP: True

threadpoolctl info:
user_api: openmp
internal_api: openmp
prefix: libgomp
filepath: /home/pierrick/PycharmProjects/venv/lib/python3.10/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
version: None
num_threads: 20

user_api: blas
internal_api: openblas
prefix: libopenblas
filepath: /home/pierrick/PycharmProjects/venv/lib/python3.10/site-packages/numpy.libs/libopenblas64_p-r0-15028c96.3.21.so
version: 0.3.21
threading_layer: pthreads
architecture: Prescott
num_threads: 20

user_api: blas
internal_api: openblas
prefix: libopenblas
filepath: /home/pierrick/PycharmProjects/venv/lib/python3.10/site-packages/scipy.libs/libopenblasp-r0-41284840.3.18.so
version: 0.3.18
threading_layer: pthreads
architecture: Prescott
num_threads: 20

sksurv: 0.21.0
numexpr: 2.8.4
osqp: 0.6.3
