Commit 79ab937
Add missing training support to onert Python API (experimental module) (#15175)
* Add missing training support to onert Python API (experimental module)

This commit integrates previously omitted modifications for training support in the onert Python API.

- Expose the new experimental training functionality by updating the package’s public API:
  - Modified `__init__.py` to include the `experimental` submodule.
  - Added the experimental module, which imports the train components and exposes `TrainSession`, `traininfo`, `DataLoader`, `optimizer`, `losses`, and `metrics`.
- Implemented a flexible DataLoader in the experimental training module:
  - Supports input from file paths or NumPy arrays.
  - Handles loading of both .npy and raw binary files with configurable shapes and data types.
  - Includes batching logic and a split method for training/validation separation.
- Improved training compiler behavior in `TrainingCompiler.cc`:
  - Adjusted the shape validation to accept unspecified dimensions (`ir::Shape::kUnspecifiedDim`) in addition to dimensions of value 1.

ONE-DCO-1.0-Signed-off-by: ragmani <ragmani0216@gmail.com>

* Add typing annotations to DataLoader
1 parent 662bb28 commit 79ab937
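For orientation, a minimal usage sketch of the surface this commit exposes. Only names visible in the diffs below (`onert.experimental`, `train`, `DataLoader`, `split`) are used; the sample data and shapes are made up:

import numpy as np
from onert.experimental import train

# Hypothetical in-memory dataset: 100 samples with 4 features and 10-class one-hot labels.
x = np.random.rand(100, 4).astype(np.float32)
y = np.eye(10, dtype=np.float32)[np.random.randint(0, 10, size=100)]

# DataLoader is re-exported by the new train package (see the diffs below).
loader = train.DataLoader([x], [y], batch_size=16)
train_loader, val_loader = loader.split(validation_split=0.2)  # 80/20 split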

5 files changed

Lines changed: 260 additions & 2 deletions


Lines changed: 4 additions & 1 deletion
@@ -1,8 +1,11 @@
 # Define the public API of the onert package
-__all__ = ["infer", "tensorinfo"]
+__all__ = ["infer", "tensorinfo", "experimental"]

 # Import and expose the infer module's functionalities
 from . import infer

 # Import and expose tensorinfo
 from .common import tensorinfo as tensorinfo
+
+# Import and expose the experimental module's functionalities
+from . import experimental
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+__all__ = ["train"]
+
+from . import train
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+from .session import TrainSession as session
+from onert.native.libnnfw_api_pybind import traininfo
+from .dataloader import DataLoader
+from . import optimizer
+from . import losses
+from . import metrics
+
+__all__ = ["session", "traininfo", "DataLoader", "optimizer", "losses", "metrics"]
Lines changed: 244 additions & 0 deletions
@@ -0,0 +1,244 @@
+import os
+import numpy as np
+from typing import List, Tuple, Union, Optional, Any, Iterator
+
+
+class DataLoader:
+    """
+    A flexible DataLoader to manage training and validation data.
+    Automatically detects whether inputs are paths or NumPy arrays.
+    """
+    def __init__(self,
+                 input_dataset: Union[List[np.ndarray], np.ndarray, str],
+                 expected_dataset: Union[List[np.ndarray], np.ndarray, str],
+                 batch_size: int,
+                 input_shape: Optional[Tuple[int, ...]] = None,
+                 expected_shape: Optional[Tuple[int, ...]] = None,
+                 dtype: Any = np.float32) -> None:
+        """
+        Initialize the DataLoader.
+
+        Args:
+            input_dataset (list of np.ndarray | np.ndarray | str):
+                List of input arrays where each array's first dimension is the batch dimension,
+                or a single NumPy array, or a file path.
+            expected_dataset (list of np.ndarray | np.ndarray | str):
+                List of expected arrays where each array's first dimension is the batch dimension,
+                or a single NumPy array, or a file path.
+            batch_size (int): Number of samples per batch.
+            input_shape (tuple[int, ...], optional): Shape of the input data if raw format is used.
+            expected_shape (tuple[int, ...], optional): Shape of the expected data if raw format is used.
+            dtype (type, optional): Data type of the raw file (default: np.float32).
+        """
+        self.batch_size: int = batch_size
+        self.inputs: List[np.ndarray] = self._process_dataset(input_dataset, input_shape,
+                                                              dtype)
+        self.expecteds: List[np.ndarray] = self._process_dataset(
+            expected_dataset, expected_shape, dtype)
+        self.batched_inputs: List[List[np.ndarray]] = []
+
+        # Verify data consistency
+        self.num_samples: int = self.inputs[0].shape[0]  # Batch dimension
+        if self.num_samples != self.expecteds[0].shape[0]:
+            raise ValueError(
+                "Input data and expected data must have the same number of samples.")
+
+        # Precompute batches
+        self.batched_inputs, self.batched_expecteds = self._create_batches()
+
+    def _process_dataset(self,
+                         data: Union[List[np.ndarray], np.ndarray, str],
+                         shape: Optional[Tuple[int, ...]],
+                         dtype: Any = np.float32) -> List[np.ndarray]:
+        """
+        Process a dataset or file path.
+
+        Args:
+            data (str | np.ndarray | list[np.ndarray]): Path to file or NumPy arrays.
+            shape (tuple[int, ...], optional): Shape of the data if raw format is used.
+            dtype (type, optional): Data type for raw files.
+
+        Returns:
+            list[np.ndarray]: Loaded or passed data as NumPy arrays.
+        """
+        if isinstance(data, list):
+            # Check if all elements in the list are NumPy arrays
+            if all(isinstance(item, np.ndarray) for item in data):
+                return data
+            raise ValueError("All elements in the list must be NumPy arrays.")
+        if isinstance(data, np.ndarray):
+            # If it's already a NumPy array and is not a list of arrays
+            if data.ndim > 1:
+                # If the array has multiple dimensions, split it into a list of arrays
+                return [data[i] for i in range(data.shape[0])]
+            else:
+                # If it's a single array, wrap it into a list
+                return [data]
+        elif isinstance(data, str):
+            # If it's a string, assume it's a file path
+            return [self._load_data(data, shape, dtype)]
+        else:
+            raise ValueError("Data must be a NumPy array or a valid file path.")
+
+    def _load_data(self,
+                   file_path: str,
+                   shape: Optional[Tuple[int, ...]],
+                   dtype: Any = np.float32) -> np.ndarray:
+        """
+        Load data from a file, supporting both .npy and raw formats.
+
+        Args:
+            file_path (str): Path to the file to load.
+            shape (tuple[int, ...], optional): Shape of the data if raw format is used.
+            dtype (type, optional): Data type of the raw file (default: np.float32).
+
+        Returns:
+            np.ndarray: Loaded data as a NumPy array.
+        """
+        _, ext = os.path.splitext(file_path)
+
+        if ext == ".npy":
+            # Load .npy file
+            return np.load(file_path)
+        elif ext in [".bin", ".raw"]:
+            # Load raw binary file
+            if shape is None:
+                raise ValueError(f"Shape must be provided for raw file: {file_path}")
+            return self._load_raw(file_path, shape, dtype)
+        else:
+            raise ValueError(f"Unsupported file format: {ext}")
+
+    def _load_raw(self, file_path: str, shape: Tuple[int, ...], dtype: Any) -> np.ndarray:
+        """
+        Load raw binary data.
+
+        Args:
+            file_path (str): Path to the raw binary file.
+            shape (tuple[int, ...]): Shape of the data to reshape into.
+            dtype (type): Data type of the binary file.
+
+        Returns:
+            np.ndarray: Loaded data as a NumPy array.
+        """
+        # Calculate the expected number of elements based on the provided shape
+        expected_elements: int = int(np.prod(shape))
+
+        # Calculate the expected size of the raw file in bytes
+        expected_size: int = expected_elements * np.dtype(dtype).itemsize
+
+        # Get the actual size of the raw file
+        actual_size: int = os.path.getsize(file_path)
+
+        # Check if the sizes match
+        if actual_size != expected_size:
+            raise ValueError(
+                f"Raw file size ({actual_size} bytes) does not match the expected size "
+                f"({expected_size} bytes) based on the provided shape {shape} and dtype {dtype}."
+            )
+
+        # Read and load the raw data
+        with open(file_path, "rb") as f:
+            data = f.read()
+            array = np.frombuffer(data, dtype=dtype)
+            if array.size != expected_elements:
+                raise ValueError(
+                    f"Raw data size does not match the expected shape: {shape}. "
+                    f"Expected {expected_elements} elements, got {array.size} elements.")
+            return array.reshape(shape)
+
+    def _create_batches(self) -> Tuple[List[List[np.ndarray]], List[List[np.ndarray]]]:
+        """
+        Precompute batches for inputs and expected outputs.
+
+        Returns:
+            tuple: Lists of batched inputs and batched expecteds.
+        """
+        batched_inputs: List[List[np.ndarray]] = []
+        batched_expecteds: List[List[np.ndarray]] = []
+
+        for batch_start in range(0, self.num_samples, self.batch_size):
+            batch_end = min(batch_start + self.batch_size, self.num_samples)
+
+            # Collect batched inputs
+            inputs_batch = [
+                input_array[batch_start:batch_end] for input_array in self.inputs
+            ]
+            if batch_end - batch_start < self.batch_size:
+                # Resize the last batch to match batch_size
+                inputs_batch = [
+                    np.resize(batch, (self.batch_size, *batch.shape[1:]))
+                    for batch in inputs_batch
+                ]
+
+            batched_inputs.append(inputs_batch)
+
+            # Collect batched expecteds
+            expecteds_batch = [
+                expected_array[batch_start:batch_end] for expected_array in self.expecteds
+            ]
+            if batch_end - batch_start < self.batch_size:
+                # Resize the last batch to match batch_size
+                expecteds_batch = [
+                    np.resize(batch, (self.batch_size, *batch.shape[1:]))
+                    for batch in expecteds_batch
+                ]
+
+            batched_expecteds.append(expecteds_batch)
+
+        return batched_inputs, batched_expecteds
+
+    def __iter__(self) -> Iterator[Tuple[List[np.ndarray], List[np.ndarray]]]:
+        """
+        Make the DataLoader iterable.
+
+        Returns:
+            self
+        """
+        self.index = 0
+        return self
+
+    def __next__(self) -> Tuple[List[np.ndarray], List[np.ndarray]]:
+        """
+        Return the next batch of data.
+
+        Returns:
+            tuple: (inputs, expecteds) for the next batch.
+        """
+        if self.index >= len(self.batched_inputs):
+            raise StopIteration
+
+        # Retrieve precomputed batch
+        input_batch = self.batched_inputs[self.index]
+        expected_batch = self.batched_expecteds[self.index]
+
+        self.index += 1
+        return input_batch, expected_batch
+
+    def split(self, validation_split: float) -> Tuple["DataLoader", "DataLoader"]:
+        """
+        Split the data into training and validation sets.
+
+        Args:
+            validation_split (float): Ratio of validation data. Must be between 0.0 and 1.0.
+
+        Returns:
+            tuple: Two DataLoader instances, one for training and one for validation.
+        """
+        if not (0.0 <= validation_split <= 1.0):
+            raise ValueError("Validation split must be between 0.0 and 1.0.")
+
+        split_index = int(len(self.inputs[0]) * (1.0 - validation_split))
+
+        train_inputs = [input_array[:split_index] for input_array in self.inputs]
+        val_inputs = [input_array[split_index:] for input_array in self.inputs]
+        train_expecteds = [
+            expected_array[:split_index] for expected_array in self.expecteds
+        ]
+        val_expecteds = [
+            expected_array[split_index:] for expected_array in self.expecteds
+        ]
+
+        train_loader = DataLoader(train_inputs, train_expecteds, self.batch_size)
+        val_loader = DataLoader(val_inputs, val_expecteds, self.batch_size)
+
+        return train_loader, val_loader
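
A hedged sketch of driving the loader above end to end, here through its raw-binary branch; the file names, shapes, and dtype are illustrative, not part of the commit:

import numpy as np
from onert.experimental.train import DataLoader

# Write headerless float32 blobs so the ".bin" branch of _load_data has input.
np.random.rand(32, 4).astype(np.float32).tofile("input.bin")
np.random.rand(32, 10).astype(np.float32).tofile("expected.bin")

# Raw files carry no shape/dtype metadata, so both must be given explicitly;
# _load_raw validates the byte size against shape and dtype before reshaping.
loader = DataLoader("input.bin", "expected.bin", batch_size=8,
                    input_shape=(32, 4), expected_shape=(32, 10), dtype=np.float32)

for inputs, expecteds in loader:
    # Each batch is a list with one array per model input/output.
    print(inputs[0].shape, expecteds[0].shape)  # (8, 4) (8, 10)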

runtime/onert/core/src/compiler/train/TrainingCompiler.cc

Lines changed: 1 addition & 1 deletion
@@ -181,7 +181,7 @@ std::shared_ptr<CompilerArtifact> TrainingCompiler::compile(void)
       auto &input = trainable_subg->operands().at(ind);
       auto new_shape = input.info().shape();
       // TODO Consider batch size index
-      if (new_shape.dim(0) != 1)
+      if (new_shape.dim(0) != 1 && new_shape.dim(0) != ir::Shape::kUnspecifiedDim)
         throw std::runtime_error("the first dim is not 1. It is not supported yet.");
       new_shape.dim(0) = _training_info.batchSize();
       input.info().shape(new_shape);
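
In Python terms, the relaxed guard behaves roughly as sketched below. It assumes `ir::Shape::kUnspecifiedDim` is the sentinel onert's IR uses for unknown dimensions (an assumption; the diff itself only names the constant):

KUNSPECIFIED_DIM = -1  # assumed stand-in for ir::Shape::kUnspecifiedDim

def rewrite_batch_dim(shape, batch_size):
    # After this commit, a leading dim that is 1 *or* unspecified may be
    # overwritten with the training batch size; anything else still throws.
    if shape[0] != 1 and shape[0] != KUNSPECIFIED_DIM:
        raise RuntimeError("the first dim is not 1. It is not supported yet.")
    return [batch_size] + list(shape[1:])

print(rewrite_batch_dim([1, 28, 28], 32))                 # [32, 28, 28]
print(rewrite_batch_dim([KUNSPECIFIED_DIM, 28, 28], 32))  # accepted after this commit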
