Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev/asolovev rf optimizations with rng #3029

Draft
wants to merge 21 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
minor fixes
Alexandr-Solovev committed Nov 5, 2024
commit 81d7dfe7100a714152fe3203d5c193796ed1a68f
Original file line number Diff line number Diff line change
@@ -50,7 +50,7 @@ class train_kernel_hist_impl {
using model_manager_t = train_model_manager<Float, Index, Task>;
using train_context_t = train_context<Float, Index, Task>;
using imp_data_t = impurity_data<Float, Index, Task>;
using rng_engine_t = pr::engine;
using rng_engine_t = pr::daal_engine<pr::engine_list_cpu::mt2203>;
using rng_engine_list_t = std::vector<rng_engine_t>;
using msg = dal::detail::error_messages;
using comm_t = bk::communicator<spmd::device_memory_access::usm>;
Original file line number Diff line number Diff line change
@@ -396,12 +396,12 @@ sycl::event train_kernel_hist_impl<Float, Bin, Index, Task>::gen_initial_tree_or
Index* const node_list_ptr = node_list_host.get_mutable_data();

for (Index node_idx = 0; node_idx < node_count; ++node_idx) {
pr::rng<Index> rn_gen;
pr::daal_rng<Index> rn_gen;
Index* gen_row_idx_global_ptr =
selected_row_global_ptr + ctx.selected_row_total_count_ * node_idx;
rn_gen.uniform(ctx.selected_row_total_count_,
gen_row_idx_global_ptr,
rng_engine_list[engine_offset + node_idx].get_state(),
rng_engine_list[engine_offset + node_idx].get_cpu_engine_state(),
0,
ctx.row_total_count_);

@@ -483,15 +483,15 @@ train_kernel_hist_impl<Float, Bin, Index, Task>::gen_feature_list(

auto node_vs_tree_map_list_host = node_vs_tree_map_list.to_host(queue_);

pr::rng<Index> rn_gen;
pr::daal_rng<Index> rn_gen;
auto tree_map_ptr = node_vs_tree_map_list_host.get_mutable_data();
if (ctx.selected_ftr_count_ != ctx.column_count_) {
for (Index node = 0; node < node_count; ++node) {
rn_gen.uniform_without_replacement(
rn_gen.uniform_without_replacement_cpu(
ctx.selected_ftr_count_,
selected_features_host_ptr + node * ctx.selected_ftr_count_,
selected_features_host_ptr + (node + 1) * ctx.selected_ftr_count_,
rng_engine_list[tree_map_ptr[node]].get_state(),
rng_engine_list[tree_map_ptr[node]].get_cpu_engine_state(),
0,
ctx.column_count_);
}
@@ -524,7 +524,7 @@ train_kernel_hist_impl<Float, Bin, Index, Task>::gen_random_thresholds(

auto node_vs_tree_map_list_host = node_vs_tree_map.to_host(queue_);

pr::rng<Float> rn_gen;
pr::daal_rng<Float> rn_gen;
auto tree_map_ptr = node_vs_tree_map_list_host.get_mutable_data();

// Create arrays for random generated bins
@@ -539,7 +539,7 @@ train_kernel_hist_impl<Float, Bin, Index, Task>::gen_random_thresholds(
for (Index node = 0; node < node_count; ++node) {
rn_gen.uniform(ctx.selected_ftr_count_,
random_bins_host_ptr + node * ctx.selected_ftr_count_,
rng_engine_list[tree_map_ptr[node]].get_state(),
rng_engine_list[tree_map_ptr[node]].get_cpu_engine_state(),
0.0f,
1.0f);
}
@@ -1660,12 +1660,13 @@ sycl::event train_kernel_hist_impl<Float, Bin, Index, Task>::compute_results(

const Float div1 = Float(1) / Float(built_tree_count + tree_idx_in_block + 1);

pr::rng<Index> rn_gen;
pr::daal_rng<Index> rn_gen;

for (Index column_idx = 0; column_idx < ctx.column_count_; ++column_idx) {
rn_gen.shuffle(oob_row_count,
permutation_ptr,
engine_arr[built_tree_count + tree_idx_in_block].get_state());
rn_gen.shuffle(
oob_row_count,
permutation_ptr,
engine_arr[built_tree_count + tree_idx_in_block].get_cpu_engine_state());
const Float oob_err_perm = compute_oob_error_perm(ctx,
model_manager,
data_host,
17 changes: 17 additions & 0 deletions cpp/oneapi/dal/backend/primitives/rng/rng_cpu.hpp
Original file line number Diff line number Diff line change
@@ -40,6 +40,23 @@ class daal_engine {
}
}

explicit daal_engine(const daal::algorithms::engines::EnginePtr& eng) : daal_engine_(eng) {
impl_ = dynamic_cast<daal::algorithms::engines::internal::BatchBaseImpl*>(eng.get());
if (!impl_) {
throw domain_error(dal::detail::error_messages::rng_engine_is_not_supported());
}
}

daal_engine& operator=(const daal::algorithms::engines::EnginePtr& eng) {
daal_engine_ = eng;
impl_ = dynamic_cast<daal::algorithms::engines::internal::BatchBaseImpl*>(eng.get());
if (!impl_) {
throw domain_error(dal::detail::error_messages::rng_engine_is_not_supported());
}

return *this;
}

virtual ~daal_engine() = default;

void* get_cpu_engine_state() const {
72 changes: 70 additions & 2 deletions cpp/oneapi/dal/backend/primitives/rng/rng_engine_collection.hpp
Original file line number Diff line number Diff line change
@@ -30,10 +30,78 @@ namespace oneapi::dal::backend::primitives {

#ifdef ONEDAL_DATA_PARALLEL

template <typename Size = std::int64_t, engine_list EngineType = engine_list::mt2203>
template <typename Size = std::int64_t>
class engine_collection {
public:
engine_collection(sycl::queue& queue, Size count, std::int64_t seed = 777)
explicit engine_collection(Size count, std::int64_t seed = 777)
: count_(count),
engine_(daal::algorithms::engines::mt2203::Batch<>::create(seed)),
params_(count),
technique_(daal::algorithms::engines::internal::family),
daal_engine_list_(count) {}

template <typename Op>
std::vector<daal_engine<engine_list_cpu::mt2203>> operator()(Op&& op) {
daal::services::Status status;
for (Size i = 0; i < count_; ++i) {
op(i, params_.nSkip[i]);
}
select_parallelization_technique(technique_);
daal::algorithms::engines::internal::EnginesCollection<daal::sse2> engine_collection(
engine_,
technique_,
params_,
daal_engine_list_,
&status);
if (!status) {
dal::backend::interop::status_to_exception(status);
}

std::vector<daal_engine<engine_list_cpu::mt2203>> engine_list(count_);
for (Size i = 0; i < count_; ++i) {
engine_list[i] = daal_engine_list_[i];
}

//copy elision
return engine_list;
}

private:
void select_parallelization_technique(
daal::algorithms::engines::internal::ParallelizationTechnique& technique) {
auto daal_engine_impl =
dynamic_cast<daal::algorithms::engines::internal::BatchBaseImpl*>(engine_.get());

daal::algorithms::engines::internal::ParallelizationTechnique techniques[] = {
daal::algorithms::engines::internal::family,
daal::algorithms::engines::internal::leapfrog,
daal::algorithms::engines::internal::skipahead
};

for (auto& techn : techniques) {
if (daal_engine_impl->hasSupport(techn)) {
technique = techn;
return;
}
}

throw domain_error(
dal::detail::error_messages::rng_engine_does_not_support_parallelization_techniques());
}

private:
Size count_;
daal::algorithms::engines::EnginePtr engine_;
daal::algorithms::engines::internal::Params<daal::sse2> params_;
daal::algorithms::engines::internal::ParallelizationTechnique technique_;
daal::services::internal::TArray<daal::algorithms::engines::EnginePtr, daal::sse2>
daal_engine_list_;
};

template <typename Size = std::int64_t, engine_list EngineType = engine_list::mt2203>
class engine_collection_oneapi {
public:
engine_collection_oneapi(sycl::queue& queue, Size count, std::int64_t seed = 777)
: count_(count),
seed_(seed) {
engines_.reserve(count_);