Skip to content

Commit 935c679

Browse files
committed
Add missing training support to onert Python API (experimental module)
This commit integrates previously omitted modifications for training support in the onert Python API. - Expose new experimental training functionality by updating the package's public API: - Modified `__init__.py` to include the `experimental` submodule. - Added the experimental module, which imports the train components and exposes `TrainSession`, `traininfo`, `DataLoader`, `optimizer`, `losses`, and `metrics`. - Implemented a flexible DataLoader in the experimental training module: - Supports input from file paths or NumPy arrays. - Handles loading of both .npy and raw binary files with configurable shapes and data types. - Includes batching logic and a split method for training/validation separation. - Improved training compiler behavior in `TrainingCompiler.cc`: - Adjusted the shape validation to accept unspecified dimensions (using `ir::Shape::kUnspecifiedDim`) in addition to dimensions of value 1. ONE-DCO-1.0-Signed-off-by: ragmani <ragmani0216@gmail.com>
1 parent c6eb5a0 commit 935c679

5 files changed

Lines changed: 247 additions & 2 deletions

File tree

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
# Define the public API of the onert package
2-
__all__ = ["infer", "tensorinfo"]
2+
__all__ = ["infer", "tensorinfo", "experimental"]
33

44
# Import and expose the infer module's functionalities
55
from . import infer
66

77
# Import and expose tensorinfo
88
from .common import tensorinfo as tensorinfo
9+
10+
# Import and expose the experimental module's functionalities
11+
from . import experimental
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
__all__ = ["train"]
2+
3+
from . import train
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from .session import TrainSession as session
2+
from onert.native.libnnfw_api_pybind import traininfo
3+
from .dataloader import DataLoader
4+
from . import optimizer
5+
from . import losses
6+
from . import metrics
7+
8+
__all__ = ["session", "traininfo", "DataLoader", "optimizer", "losses", "metrics"]
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
import os
2+
import numpy as np
3+
4+
5+
class DataLoader:
    """
    Iterable provider of (inputs, expecteds) training batches.

    Datasets may be given as lists of NumPy arrays, as a single NumPy array,
    or as a path to a ``.npy`` / raw binary (``.bin``/``.raw``) file. Samples
    are sliced along the leading axis into fixed-size batches; a short final
    batch is padded up to ``batch_size`` via ``np.resize`` (which repeats
    data cyclically).
    """

    def __init__(self,
                 input_dataset,
                 expected_dataset,
                 batch_size,
                 input_shape=None,
                 expected_shape=None,
                 dtype=np.float32):
        """
        Build the loader and precompute all batches.

        Args:
            input_dataset (list of np.ndarray | np.ndarray | str): Input data;
                each array's first dimension is the sample (batch) dimension,
                or a file path to load from.
            expected_dataset (list of np.ndarray | np.ndarray | str): Expected
                (label) data with the same sample count as the inputs.
            batch_size (int): Number of samples per batch.
            input_shape (tuple, optional): Shape to use when the input comes
                from a raw binary file.
            expected_shape (tuple, optional): Shape to use when the expected
                data comes from a raw binary file.
            dtype (type, optional): Element type for raw files
                (default: np.float32).

        Raises:
            ValueError: If inputs and expecteds disagree on sample count, or
                the dataset arguments are of an unsupported kind.
        """
        self.batch_size = batch_size
        self.inputs = self._process_dataset(input_dataset, input_shape, dtype)
        self.expecteds = self._process_dataset(expected_dataset, expected_shape, dtype)
        # Placeholder; replaced below once batches are materialized.
        self.batched_inputs = []

        # The leading axis of the first array is taken as the sample count.
        self.num_samples = self.inputs[0].shape[0]
        if self.expecteds[0].shape[0] != self.num_samples:
            raise ValueError(
                "Input data and expected data must have the same number of samples.")

        # Materialize every batch up front so iteration is just indexing.
        self.batched_inputs, self.batched_expecteds = self._create_batches()

    def _process_dataset(self, data, shape, dtype=np.float32):
        """
        Normalize a dataset argument into a list of NumPy arrays.

        Args:
            data (list | np.ndarray | str): Arrays, a single array, or a
                file path.
            shape (tuple, optional): Shape used only for raw-file loading.
            dtype (type, optional): Element type used only for raw-file
                loading.

        Returns:
            list of np.ndarray: The dataset as a list of arrays.

        Raises:
            ValueError: For lists containing non-arrays or unsupported types.
        """
        if isinstance(data, list):
            # A list is accepted only when every element is already an array.
            for item in data:
                if not isinstance(item, np.ndarray):
                    raise ValueError("All elements in the list must be NumPy arrays.")
            return data
        if isinstance(data, np.ndarray):
            if data.ndim > 1:
                # NOTE(review): a multi-dimensional array is split along axis 0
                # into per-sample arrays, so the *second* axis then acts as the
                # sample count — this looks inconsistent with the documented
                # "first dimension is batch" contract; confirm intended usage.
                return list(data)
            # A 1-D array is treated as a single input.
            return [data]
        if isinstance(data, str):
            # A string is interpreted as a path to load from disk.
            return [self._load_data(data, shape, dtype)]
        raise ValueError("Data must be a NumPy array or a valid file path.")

    def _load_data(self, file_path, shape, dtype=np.float32):
        """
        Load an array from disk, dispatching on the file extension.

        Args:
            file_path (str): Path to a ``.npy``, ``.bin`` or ``.raw`` file.
            shape (tuple, optional): Required for raw binary files.
            dtype (type, optional): Element type for raw binary files.

        Returns:
            np.ndarray: The loaded array.

        Raises:
            ValueError: For raw files without a shape, or unknown extensions.
        """
        ext = os.path.splitext(file_path)[1]
        if ext == ".npy":
            return np.load(file_path)
        if ext in (".bin", ".raw"):
            if shape is None:
                raise ValueError(f"Shape must be provided for raw file: {file_path}")
            return self._load_raw(file_path, shape, dtype)
        raise ValueError(f"Unsupported file format: {ext}")

    def _load_raw(self, file_path, shape, dtype):
        """
        Read a raw binary file and reshape it.

        Args:
            file_path (str): Path to the raw binary file.
            shape (tuple): Target shape for the loaded buffer.
            dtype (type): Element type stored in the file.

        Returns:
            np.ndarray: The file contents reshaped to ``shape``.

        Raises:
            ValueError: If the on-disk size does not match ``shape``/``dtype``.
        """
        element_count = np.prod(shape)
        byte_count = element_count * np.dtype(dtype).itemsize

        # Validate the on-disk size before reading anything.
        actual_size = os.path.getsize(file_path)
        if actual_size != byte_count:
            raise ValueError(
                f"Raw file size ({actual_size} bytes) does not match the expected size "
                f"({byte_count} bytes) based on the provided shape {shape} and dtype {dtype}."
            )

        with open(file_path, "rb") as f:
            payload = f.read()
        flat = np.frombuffer(payload, dtype=dtype)
        # Defensive re-check of the element count after decoding.
        if flat.size != element_count:
            raise ValueError(
                f"Raw data size does not match the expected shape: {shape}. "
                f"Expected {element_count} elements, got {flat.size} elements.")
        return flat.reshape(shape)

    def _create_batches(self):
        """
        Slice every array into ``batch_size`` chunks along axis 0.

        Returns:
            tuple: (batched inputs, batched expecteds) — each a list of
            per-batch lists of arrays.
        """
        all_inputs = []
        all_expecteds = []

        for start in range(0, self.num_samples, self.batch_size):
            stop = min(start + self.batch_size, self.num_samples)
            is_short = (stop - start) < self.batch_size

            def take(arrays):
                # Slice each array, padding a short final chunk to a full
                # batch with np.resize (repeats the data cyclically).
                chunks = [arr[start:stop] for arr in arrays]
                if is_short:
                    chunks = [
                        np.resize(chunk, (self.batch_size, *chunk.shape[1:]))
                        for chunk in chunks
                    ]
                return chunks

            all_inputs.append(take(self.inputs))
            all_expecteds.append(take(self.expecteds))

        return all_inputs, all_expecteds

    def __iter__(self):
        """
        Reset the batch cursor and return self.

        Returns:
            DataLoader: self, positioned at the first batch.
        """
        self.index = 0
        return self

    def __next__(self):
        """
        Advance to and return the next precomputed batch.

        Returns:
            tuple: (inputs, expecteds) lists for the current batch.

        Raises:
            StopIteration: When all batches have been consumed.
        """
        if self.index >= len(self.batched_inputs):
            raise StopIteration

        pair = (self.batched_inputs[self.index], self.batched_expecteds[self.index])
        self.index += 1
        return pair

    def split(self, validation_split):
        """
        Partition the samples into training and validation loaders.

        Args:
            validation_split (float): Fraction reserved for validation;
                must lie in [0.0, 1.0].

        Returns:
            tuple: (training DataLoader, validation DataLoader), both using
            this loader's batch size.

        Raises:
            ValueError: If ``validation_split`` is outside [0.0, 1.0].
        """
        if not 0.0 <= validation_split <= 1.0:
            raise ValueError("Validation split must be between 0.0 and 1.0.")

        boundary = int(len(self.inputs[0]) * (1.0 - validation_split))

        def head(arrays):
            return [arr[:boundary] for arr in arrays]

        def tail(arrays):
            return [arr[boundary:] for arr in arrays]

        train_loader = DataLoader(head(self.inputs), head(self.expecteds),
                                  self.batch_size)
        val_loader = DataLoader(tail(self.inputs), tail(self.expecteds),
                                self.batch_size)
        return train_loader, val_loader

runtime/onert/core/src/compiler/train/TrainingCompiler.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ std::shared_ptr<CompilerArtifact> TrainingCompiler::compile(void)
181181
auto &input = trainable_subg->operands().at(ind);
182182
auto new_shape = input.info().shape();
183183
// TODO Consider batch size index
184-
if (new_shape.dim(0) != 1)
184+
if (new_shape.dim(0) != 1 && new_shape.dim(0) != ir::Shape::kUnspecifiedDim)
185185
throw std::runtime_error("the first dim is not 1. It is not supported yet.");
186186
new_shape.dim(0) = _training_info.batchSize();
187187
input.info().shape(new_shape);

0 commit comments

Comments
 (0)