Skip to content

Commit bee0243

Browse files
committed
Add user_cpu_context and ability to provide host_policy via it
1 parent 4eb47a5 commit bee0243

File tree

5 files changed

+92
-0
lines changed

5 files changed

+92
-0
lines changed

cpp/oneapi/dal/compute.hpp

+6
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#pragma once
1818

1919
#include "oneapi/dal/detail/compute_ops.hpp"
20+
#include "oneapi/dal/detail/user_policy.hpp"
2021
#include "oneapi/dal/detail/spmd_policy.hpp"
2122
#include "oneapi/dal/spmd/communicator.hpp"
2223

@@ -28,6 +29,11 @@ auto compute(Args&&... args) {
2829
return dal::detail::compute_dispatch(std::forward<Args>(args)...);
2930
}
3031

32+
template <typename... Args>
33+
auto compute(detail::user_cpu_context uctx, Args&&... args) {
34+
return dal::detail::compute_dispatch(uctx.get_host_policy(), std::forward<Args>(args)...);
35+
}
36+
3137
#ifdef ONEDAL_DATA_PARALLEL
3238
template <typename... Args>
3339
auto compute(sycl::queue& queue, Args&&... args) {

cpp/oneapi/dal/detail/policy.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ class ONEDAL_EXPORT host_policy : public base {
103103
}
104104
host_policy(const host_policy&) = default;
105105
host_policy(host_policy&&) = default;
106+
host_policy& operator=(const host_policy&) = default;
106107

107108
static host_policy get_default() {
108109
return host_policy(make_default_impl());

cpp/oneapi/dal/detail/user_policy.cpp

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*******************************************************************************
2+
* Copyright 2023 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
#include "oneapi/dal/detail/user_policy.hpp"
17+
18+
namespace oneapi::dal::detail {
19+
class user_cpu_context_impl {
20+
public:
21+
user_cpu_context_impl() : policy_(host_policy::get_default()) {}
22+
user_cpu_context_impl(const host_policy& policy) : policy_(policy) {}
23+
void set_host_policy(const host_policy& policy) {
24+
policy_ = policy;
25+
}
26+
host_policy get_host_policy() {
27+
return policy_;
28+
}
29+
30+
private:
31+
detail::host_policy policy_;
32+
};
33+
34+
user_cpu_context::user_cpu_context() : impl_(new user_cpu_context_impl()) {}
35+
36+
user_cpu_context::user_cpu_context(const host_policy& policy)
37+
: impl_(new user_cpu_context_impl(policy)) {}
38+
39+
void user_cpu_context::set_host_policy(const host_policy& policy) {
40+
impl_->set_host_policy(policy);
41+
}
42+
43+
host_policy user_cpu_context::get_host_policy() {
44+
return impl_->get_host_policy();
45+
}
46+
47+
} // namespace oneapi::dal::detail

cpp/oneapi/dal/detail/user_policy.hpp

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*******************************************************************************
2+
* Copyright 2023 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
#include "oneapi/dal/detail/policy.hpp"
17+
18+
namespace oneapi::dal::detail {
19+
class user_cpu_context_impl;
20+
21+
class user_cpu_context {
22+
public:
23+
user_cpu_context();
24+
user_cpu_context(const host_policy& policy);
25+
host_policy get_host_policy();
26+
void set_host_policy(const host_policy& policy);
27+
28+
private:
29+
pimpl<user_cpu_context_impl> impl_;
30+
};
31+
32+
} //namespace oneapi::dal::detail

cpp/oneapi/dal/train.hpp

+6
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#pragma once
1818

1919
#include "oneapi/dal/detail/train_ops.hpp"
20+
#include "oneapi/dal/detail/user_policy.hpp"
2021
#include "oneapi/dal/detail/spmd_policy.hpp"
2122
#include "oneapi/dal/spmd/communicator.hpp"
2223

@@ -28,6 +29,11 @@ auto train(Args&&... args) {
2829
return dal::detail::train_dispatch(std::forward<Args>(args)...);
2930
}
3031

32+
template <typename... Args>
33+
auto train(detail::user_cpu_context uctx, Args&&... args) {
34+
return dal::detail::train_dispatch(uctx.get_host_policy(), std::forward<Args>(args)...);
35+
}
36+
3137
#ifdef ONEDAL_DATA_PARALLEL
3238
template <typename... Args>
3339
auto train(sycl::queue& queue, Args&&... args) {

0 commit comments

Comments
 (0)