-
Notifications
You must be signed in to change notification settings - Fork 160
Expand file tree
/
Copy pathruntime.rs
More file actions
125 lines (105 loc) · 3.72 KB
/
runtime.rs
File metadata and controls
125 lines (105 loc) · 3.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
use cubecl_common::profile::TimingMethod;
use cubecl_core::{
CubeCount, CubeDim, MemoryConfiguration, Runtime,
channel::MpscComputeChannel,
client::ComputeClient,
ir::{StorageType, TargetProperties},
};
use cubecl_runtime::stride::{is_contiguous, is_inner_contiguous_rows};
use cubecl_runtime::{
ComputeRuntime, DeviceProperties,
memory_management::{HardwareProperties, MemoryDeviceProperties, MemoryManagement},
storage::BytesStorage,
};
use sysinfo::System;
use crate::{
compiler::{MlirCompiler, register_supported_types},
compute::server::{CpuContext, CpuServer},
device::CpuDevice,
};
/// Options consumed by `create_client` when building the CPU compute client.
///
/// `Default` is derived, so `RuntimeOptions::default()` carries the default
/// `MemoryConfiguration`.
#[derive(Default)]
pub struct RuntimeOptions {
/// Configures the memory management.
pub memory_config: MemoryConfiguration,
}
/// Unit type that carries the `Runtime` implementation for the CPU backend
/// (its `Runtime::name` reports `"cpu"`).
#[derive(Debug)]
pub struct CpuRuntime;
// Shared `ComputeRuntime` instance; `CpuRuntime::client` goes through it to
// obtain the compute client for a given `CpuDevice`.
static RUNTIME: ComputeRuntime<CpuDevice, Server, Channel> = ComputeRuntime::new();
/// Compiler type used by the CPU runtime (alias for `MlirCompiler`).
pub type CpuCompiler = MlirCompiler;
// Server type used by the CPU runtime (alias for `CpuServer`).
type Server = CpuServer;
// Channel type wrapping the server: an `MpscComputeChannel` over `Server`.
type Channel = MpscComputeChannel<Server>;
/// Builds the [`ComputeClient`] used by the CPU runtime.
///
/// The function performs the following steps, in order:
/// 1. Queries `sysinfo` for the memory limit (cgroup limit when available,
///    otherwise the total system memory) and uses it both as the shared-memory
///    size and as the maximum page size.
/// 2. Describes the "hardware" with a plane size of exactly 1 and no
///    streaming-multiprocessor / tensor-core information.
/// 3. Creates the memory management from `options.memory_config` on top of a
///    default [`BytesStorage`].
/// 4. Registers the supported types on the device properties.
/// 5. Wraps a [`CpuServer`] in the channel and returns the client.
///
/// # Parameters
/// * `options` - runtime options; only `memory_config` is read.
fn create_client(options: RuntimeOptions) -> ComputeClient<Server, Channel> {
    let max_cube_dim = CubeDim::new(u32::MAX, u32::MAX, u32::MAX);
    // NOTE(review): `CpuRuntime::max_cube_count()` reports `u32::MAX` on every
    // axis while the hardware properties built here advertise 64 x 64 x 64.
    // Confirm which limit is intended; the value is left unchanged here.
    let max_cube_count = CubeCount::Static(64, 64, 64);

    // NOTE(review): `System::new_all()` loads more than memory information;
    // a narrower refresh may be enough since only memory totals are read below.
    let system = System::new_all();
    let total_memory: u64 = system
        .cgroup_limits()
        .map_or_else(|| system.total_memory(), |limits| limits.total_memory);
    // Saturate instead of truncating: a plain `as usize` cast silently wraps
    // when `usize` is narrower than 64 bits and the reported amount does not
    // fit, which could yield a tiny (or zero) memory size.
    let max_shared_memory_size = usize::try_from(total_memory).unwrap_or(usize::MAX);

    let topology = HardwareProperties {
        plane_size_min: 1,
        plane_size_max: 1,
        max_bindings: u32::MAX,
        max_shared_memory_size,
        max_cube_count,
        max_units_per_cube: u32::MAX,
        max_cube_dim,
        num_streaming_multiprocessors: None,
        num_tensor_cores: None,
        min_tensor_cores_dim: None,
    };

    let storage = BytesStorage::default();

    const ALIGNMENT: u64 = 4;
    let mem_properties = MemoryDeviceProperties {
        // Widening `usize` -> `u64`; lossless on supported targets.
        max_page_size: max_shared_memory_size as u64,
        alignment: ALIGNMENT,
        data_transfer_async: false,
    };

    // Must be built before `mem_properties` is moved into `DeviceProperties`.
    let memory_management =
        MemoryManagement::from_configuration(storage, &mem_properties, options.memory_config);

    let mut device_props = DeviceProperties::new(
        Default::default(),
        mem_properties,
        topology,
        TimingMethod::Device,
        // Default to contiguous on CPU.
        cubecl_runtime::server::AllocationKind::Contiguous,
    );
    register_supported_types(&mut device_props);

    let ctx = CpuContext::new(memory_management);
    let server = CpuServer::new(ctx);
    ComputeClient::new(Channel::new(server), device_props, ())
}
/// `Runtime` implementation for the CPU backend.
impl Runtime for CpuRuntime {
    type Compiler = CpuCompiler;
    type Server = CpuServer;
    type Channel = Channel;
    type Device = CpuDevice;

    /// Returns the compute client for `device`.
    ///
    /// Delegates to `RUNTIME.client`, handing it a closure that builds a client
    /// with default [`RuntimeOptions`] via `create_client`.
    fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
        RUNTIME.client(device, || create_client(RuntimeOptions::default()))
    }

    /// Name of this runtime; always `"cpu"`, independent of the client.
    fn name(_client: &ComputeClient<Self::Server, Self::Channel>) -> &'static str {
        "cpu"
    }

    /// All line sizes this runtime may use, in descending order.
    fn supported_line_sizes() -> &'static [u8] {
        &[64, 32, 16, 8, 4, 2, 1]
    }

    /// Line sizes usable for `elem`: the supported line sizes for which
    /// `line_size * elem.size()` does not exceed 64, in descending order.
    fn line_size_type(elem: &StorageType) -> impl Iterator<Item = u8> + Clone {
        // NOTE(review): assuming `StorageType::size()` is in bytes, this caps a
        // line at 64 bytes (512 bits). The previous inline comment said
        // "128 bits", which does not match the constant; confirm the intended
        // limit.
        const MAX_LINE_SIZE: usize = 64;

        // Read the element size once so the returned iterator does not need to
        // borrow `elem`.
        let elem_size = elem.size();
        Self::supported_line_sizes()
            .iter()
            .copied()
            .filter(move |&line_size| usize::from(line_size) * elem_size <= MAX_LINE_SIZE)
    }

    /// Maximum cube count per axis reported by this runtime.
    fn max_cube_count() -> (u32, u32, u32) {
        // NOTE(review): `create_client` advertises `CubeCount::Static(64, 64, 64)`
        // in its hardware properties, which disagrees with this value; confirm
        // which limit is intended.
        (u32::MAX, u32::MAX, u32::MAX)
    }

    /// A tensor can be read when it is fully contiguous or when its inner rows
    /// are contiguous, as decided by the `cubecl_runtime::stride` helpers.
    fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
        is_contiguous(shape, strides) || is_inner_contiguous_rows(shape, strides)
    }

    /// Target properties for this runtime; the MMA settings are left at their
    /// defaults.
    fn target_properties() -> TargetProperties {
        TargetProperties {
            // NOTE(review): the previous comment justified the default by
            // referring to wgsl backends, which looks copied from another
            // runtime. Presumably manual MMA is not used by the CPU backend
            // either; confirm.
            mma: Default::default(),
        }
    }
}