diff --git a/docs/howto/how-to-use-nnfw-python-api.md b/docs/howto/how-to-use-nnfw-python-api.md
index 5a18b79aeb2..8951a4541fa 100644
--- a/docs/howto/how-to-use-nnfw-python-api.md
+++ b/docs/howto/how-to-use-nnfw-python-api.md
@@ -41,8 +41,137 @@ outputs = session.inference()
 ## Run Inference with app on the target devices
-reference app : [minimal-python app](https://github.com/Samsung/ONE/blob/master/runtime/onert/sample/minimal-python)
+reference app : [minimal-python app](https://github.com/Samsung/ONE/blob/master/runtime/onert/sample/minimal-python/infer)
 ```
 $ python3 minimal.py path_to_nnpackage_directory
 ```
+
+## Experimental API
+
+### Train with dataset
+
+1. Import the Module and Initialize TrainSession
+
+```python
+import onert
+
+# Create a training session and load the nnpackage.
+# The default backend is set to "train".
+session = onert.experimental.train.session(nnpackage_path, backends="train")
+```
+
+2. Prepare Input and Output Data
+
+```python
+# Create a DataLoader
+
+from onert.experimental.train import DataLoader
+
+# Define the paths for input and expected output data
+input_path = "path/to/input_data.npy"
+expected_path = "path/to/expected_data.npy"
+
+# Define batch size
+batch_size = 16
+
+# Initialize DataLoader
+data_loader = DataLoader(input_dataset=input_path,
+                         expected_dataset=expected_path,
+                         batch_size=batch_size)
+```
+
+3. Compile the Session
+
+```python
+# Set Optimizer, Loss, and Metrics
+
+from onert.experimental.train import optimizer, losses, metrics
+
+# Define optimizer
+optimizer_fn = optimizer.Adam(learning_rate=0.01)
+
+# Define loss function
+loss_fn = losses.CategoricalCrossentropy()
+
+# Define metrics
+metric_list = [metrics.CategoricalAccuracy()]
+
+# Compile the training session
+session.compile(optimizer=optimizer_fn, loss=loss_fn, metrics=metric_list, batch_size=batch_size)
+```
+
+4. Train the Model
+
+```python
+# Train and Validate
+
+# Train the model
+session.train(data_loader=data_loader,
+              epochs=5,
+              validation_split=0.2,
+              checkpoint_path="checkpoint.ckpt")
+```
+
+5. Train one step with the data loader (Optional)
+
+```python
+for batch_idx, (inputs, expecteds) in enumerate(data_loader):
+    # Run a single training step
+    results = session.train_step(inputs, expecteds)
+```
+
+### Custom Metric
+
+You can use custom metrics instead of the provided metrics.
+
+```python
+import numpy as np
+from onert.experimental.train.metrics import Metric
+
+class CustomMeanAbsoluteError(Metric):
+    """
+    Custom metric to calculate the mean absolute error (MAE) between predictions and ground truth.
+    """
+    def __init__(self):
+        self.total_absolute_error = 0.0
+        self.total_samples = 0
+
+    def update_state(self, outputs, expecteds):
+        """
+        Update the metric's state based on the outputs and expected values.
+
+        Args:
+            outputs (list of np.ndarray): List of model outputs.
+            expecteds (list of np.ndarray): List of expected (ground truth) values.
+        """
+        for output, expected in zip(outputs, expecteds):
+            self.total_absolute_error += np.sum(np.abs(output - expected))
+            self.total_samples += expected.size
+
+    def result(self):
+        """
+        Calculate and return the current mean absolute error.
+
+        Returns:
+            float: The mean absolute error.
+        """
+        return self.total_absolute_error / self.total_samples if self.total_samples > 0 else 0.0
+
+    def reset_state(self):
+        """
+        Reset the metric's state for the next epoch.
+ """ + self.total_absolute_error = 0.0 + self.total_samples = 0 + +# Add the custom metric to the list +metric_list = [ + CustomMeanAbsoluteError() +] + +# Compile the session with the custom metric +session.compile(optimizer=optimizer_fn, loss=loss_fn, metrics=metric_list, batch_size=batch_size) +``` + +### Run Train with dataset on the target devices +reference app : [minimal-python app](https://github.com/Samsung/ONE/blob/master/runtime/onert/sample/minimal-python/experimental/) diff --git a/infra/nnfw/CMakeLists.txt b/infra/nnfw/CMakeLists.txt index 2c3a30a3616..7d41c793f69 100644 --- a/infra/nnfw/CMakeLists.txt +++ b/infra/nnfw/CMakeLists.txt @@ -11,6 +11,8 @@ set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) ### CMAKE_BUILD_TYPE_LC: Build type lower case string(TOLOWER ${CMAKE_BUILD_TYPE} CMAKE_BUILD_TYPE_LC) +# string(TOLOWER ${CMAKE_BUILD_TYPE} "debug") +set(CMAKE_CXX_FLAGS_DEBUG "-g") set(NNAS_PROJECT_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../.." CACHE INTERNAL "Where to find nnas top-level source directory" diff --git a/infra/nnfw/python/setup.py b/infra/nnfw/python/setup.py index d140e17e2cb..37f63119d6d 100644 --- a/infra/nnfw/python/setup.py +++ b/infra/nnfw/python/setup.py @@ -52,12 +52,23 @@ # copy *py files to package_directory PY_DIR = os.path.join(THIS_FILE_DIR, '../../../runtime/onert/api/python/package') - for py_file in os.listdir(PY_DIR): - if py_file.endswith(".py"): - src_path = os.path.join(PY_DIR, py_file) - dest_path = os.path.join(THIS_FILE_DIR, package_directory) - shutil.copy(src_path, dest_path) - print(f"Copied '{src_path}' to '{dest_path}'") + for root, dirs, files in os.walk(PY_DIR): + # Calculate the relative path from the source directory + rel_path = os.path.relpath(root, PY_DIR) + dest_dir = os.path.join(THIS_FILE_DIR, package_directory) + dest_sub_dir = os.path.join(dest_dir, rel_path) + print(f"dest_sub_dir '{dest_sub_dir}'") + + # Ensure the corresponding destination subdirectory exists + os.makedirs(dest_sub_dir, exist_ok=True) + + # Copy only .py files + for py_file in files: + if py_file.endswith(".py"): + src_path = os.path.join(root, py_file) + # dest_path = os.path.join(THIS_FILE_DIR, package_directory) + shutil.copy(src_path, dest_sub_dir) + print(f"Copied '{src_path}' to '{dest_sub_dir}'") # remove architecture directory if os.path.exists(package_directory): @@ -136,12 +147,12 @@ def get_directories(): # copy .so files to architecture directories setup(name=package_name, - version='0.1.0', + version='0.2.0', description='onert API binding', long_description='It provides onert Python api', url='https://github.com/Samsung/ONE', license='Apache-2.0, MIT, BSD-2-Clause, BSD-3-Clause, Mozilla Public License 2.0', has_ext_modules=lambda: True, - packages=[package_directory], + packages=find_packages(), package_data={package_directory: so_list}, install_requires=['numpy >= 1.19']) diff --git a/runtime/onert/api/python/include/nnfw_api_wrapper.h b/runtime/onert/api/python/include/nnfw_api_wrapper.h index 23e76b5ce85..3b2d7cd71f2 100644 --- a/runtime/onert/api/python/include/nnfw_api_wrapper.h +++ b/runtime/onert/api/python/include/nnfw_api_wrapper.h @@ -14,13 +14,24 @@ * limitations under the License. 
*/ +#ifndef __ONERT_API_PYTHON_NNFW_API_WRAPPER_H__ +#define __ONERT_API_PYTHON_NNFW_API_WRAPPER_H__ + #include "nnfw.h" +#include "nnfw_experimental.h" #include #include namespace py = pybind11; +namespace onert +{ +namespace api +{ +namespace python +{ + /** * @brief tensor info describes the type and shape of tensors * @@ -120,6 +131,7 @@ class NNFW_SESSION void close_session(); void set_input_tensorinfo(uint32_t index, const tensorinfo *tensor_info); + void prepare(); void run(); void run_async(); void wait(); @@ -159,4 +171,68 @@ class NNFW_SESSION void set_output_layout(uint32_t index, const char *layout); tensorinfo input_tensorinfo(uint32_t index); tensorinfo output_tensorinfo(uint32_t index); + + ////////////////////////////////////////////// + // Experimental APIs for training + ////////////////////////////////////////////// + nnfw_train_info train_get_traininfo(); + void train_set_traininfo(const nnfw_train_info *info); + + template void train_set_input(uint32_t index, py::array_t &buffer) + { + nnfw_tensorinfo tensor_info; + nnfw_input_tensorinfo(this->session, index, &tensor_info); + + py::buffer_info buf_info = buffer.request(); + const auto buf_shape = buf_info.shape; + assert(tensor_info.rank == static_cast(buf_shape.size()) && buf_shape.size() > 0); + tensor_info.dims[0] = static_cast(buf_shape.at(0)); + + ensure_status(nnfw_train_set_input(this->session, index, buffer.request().ptr, &tensor_info)); + } + template void train_set_expected(uint32_t index, py::array_t &buffer) + { + nnfw_tensorinfo tensor_info; + nnfw_output_tensorinfo(this->session, index, &tensor_info); + + py::buffer_info buf_info = buffer.request(); + const auto buf_shape = buf_info.shape; + assert(tensor_info.rank == static_cast(buf_shape.size()) && buf_shape.size() > 0); + tensor_info.dims[0] = static_cast(buf_shape.at(0)); + + ensure_status( + nnfw_train_set_expected(this->session, index, buffer.request().ptr, &tensor_info)); + } + template void train_set_output(uint32_t index, py::array_t &buffer) + { + nnfw_tensorinfo tensor_info; + nnfw_output_tensorinfo(this->session, index, &tensor_info); + NNFW_TYPE type = tensor_info.dtype; + uint32_t output_elements = num_elems(&tensor_info); + size_t length = sizeof(T) * output_elements; + + ensure_status(nnfw_train_set_output(session, index, type, buffer.request().ptr, length)); + } + + void train_prepare(); + void train(bool update_weights); + float train_get_loss(uint32_t index); + + void train_export_circle(const py::str &path); + void train_import_checkpoint(const py::str &path); + void train_export_checkpoint(const py::str &path); + + ////////////////////////////////////////////// + // Optional APIs for training + ////////////////////////////////////////////// + // nnfw_tensorinfo train_input_tensorinfo(uint32_t index); + // nnfw_tensorinfo train_expected_tensorinfo(uint32_t index); + + // TODO Add other apis }; + +} // namespace python +} // namespace api +} // namespace onert + +#endif // __ONERT_API_PYTHON_NNFW_API_WRAPPER_H__ diff --git a/runtime/onert/api/python/include/nnfw_session_bindings.h b/runtime/onert/api/python/include/nnfw_session_bindings.h new file mode 100644 index 00000000000..d90cb4af5ba --- /dev/null +++ b/runtime/onert/api/python/include/nnfw_session_bindings.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_API_PYTHON_NNFW_SESSION_BINDINGS_H__ +#define __ONERT_API_PYTHON_NNFW_SESSION_BINDINGS_H__ + +#include + +// Declare binding common functions +void bind_nnfw_session(pybind11::module_ &m); + +// Declare binding experimental functinos +void bind_experimental_nnfw_session(pybind11::module_ &m); + +#endif // __ONERT_API_PYTHON_NNFW_SESSION_BINDINGS_H__ diff --git a/runtime/onert/api/python/include/nnfw_tensorinfo_bindings.h b/runtime/onert/api/python/include/nnfw_tensorinfo_bindings.h new file mode 100644 index 00000000000..931da47fdff --- /dev/null +++ b/runtime/onert/api/python/include/nnfw_tensorinfo_bindings.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_API_PYTHON_NNFW_TENSORINFO_BINDINGS_H__ +#define __ONERT_API_PYTHON_NNFW_TENSORINFO_BINDINGS_H__ + +#include + +// Declare binding tensorinfo +void bind_tensorinfo(pybind11::module_ &m); + +#endif // __ONERT_API_PYTHON_NNFW_TENSORINFO_BINDINGS_H__ diff --git a/runtime/onert/api/python/include/nnfw_traininfo_bindings.h b/runtime/onert/api/python/include/nnfw_traininfo_bindings.h new file mode 100644 index 00000000000..8ae9b42d104 --- /dev/null +++ b/runtime/onert/api/python/include/nnfw_traininfo_bindings.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_API_PYTHON_NNFW_TRAININFO_BINDINGS_H__ +#define __ONERT_API_PYTHON_NNFW_TRAININFO_BINDINGS_H__ + +#include +#include + +namespace py = pybind11; + +// Declare binding train enums +void bind_nnfw_train_enums(py::module_ &m); + +// Declare binding loss info +void bind_nnfw_loss_info(py::module_ &m); + +// Declare binding train info +void bind_nnfw_train_info(py::module_ &m); + +#endif // __ONERT_API_PYTHON_NNFW_TRAININFO_BINDINGS_H__ diff --git a/runtime/onert/api/python/package/__init__.py b/runtime/onert/api/python/package/__init__.py index 05e235a3112..01bd06372ac 100644 --- a/runtime/onert/api/python/package/__init__.py +++ b/runtime/onert/api/python/package/__init__.py @@ -1,2 +1,14 @@ -__all__ = ['infer'] -from . import infer +# Define the public API of the onert package +# __all__ = ["infer", "train"] +__all__ = ["infer", "tensorinfo", "train"] +# __all__ = ["tensorinfo", "train"] + +# Import and expose the infer module's functionalities +from . import infer as infer +# from . import session as infer, tensorinfo + +# Import and expose tensorinfo +from .infer import tensorinfo as tensorinfo + +# Import and expose the experimental module's functionalities +from . import experimental diff --git a/runtime/onert/api/python/package/common/__init__.py b/runtime/onert/api/python/package/common/__init__.py new file mode 100644 index 00000000000..6790b3aec34 --- /dev/null +++ b/runtime/onert/api/python/package/common/__init__.py @@ -0,0 +1,3 @@ +from .basesession import BaseSession + +__all__ = ["BaseSession"] diff --git a/runtime/onert/api/python/package/common/basesession.py b/runtime/onert/api/python/package/common/basesession.py new file mode 100644 index 00000000000..bf14f48247b --- /dev/null +++ b/runtime/onert/api/python/package/common/basesession.py @@ -0,0 +1,91 @@ +import numpy as np + + +def num_elems(tensor_info): + """Get the total number of elements in nnfw_tensorinfo.dims.""" + n = 1 + for x in range(tensor_info.rank): + n *= tensor_info.dims[x] + return n + + +class BaseSession: + """ + Base class providing common functionality for inference and training sessions. + """ + def __init__(self, backend_session=None): + """ + Initialize the BaseSession with a backend session. + + Args: + backend_session: A backend-specific session object (e.g., nnfw_session). + """ + self.session = backend_session + self.inputs = [] + self.outputs = [] + + def __getattr__(self, name): + """ + Delegate attribute access to the bound NNFW_SESSION instance. + + Args: + name (str): The name of the attribute or method to access. + + Returns: + The attribute or method from the bound NNFW_SESSION instance. + """ + if name in self.__dict__: + # First, try to get the attribute from the instance's own dictionary + return self.__dict__[name] + elif hasattr(self.session, name): + # If not found, delegate to the session object + return getattr(self.session, name) + else: + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'") + + def _recreate_session(self, backend_session): + """ + Protected method to recreate the session. + Subclasses can override this method to provide custom session recreation logic. + """ + if self.session is not None: + del self.session # Clean up the existing session + self.session = backend_session + + def set_inputs(self, size, inputs_array=[]): + """ + Set the input tensors for the session. + + Args: + size (int): Number of input tensors. + inputs_array (list): List of numpy arrays for the input data. 
+ """ + for i in range(size): + input_tensorinfo = self.session.input_tensorinfo(i) + + if len(inputs_array) > i: + input_array = np.array(inputs_array[i], dtype=input_tensorinfo.dtype) + else: + print( + f"Model's input size is {size}, but given inputs_array size is {len(inputs_array)}.\n{i}-th index input is replaced by an array filled with 0." + ) + input_array = np.zeros((num_elems(input_tensorinfo)), + dtype=input_tensorinfo.dtype) + + self.session.set_input(i, input_array) + self.inputs.append(input_array) + + def set_outputs(self, size): + """ + Set the output tensors for the session. + + Args: + size (int): Number of output tensors. + """ + for i in range(size): + output_tensorinfo = self.session.output_tensorinfo(i) + output_array = np.zeros((num_elems(output_tensorinfo)), + dtype=output_tensorinfo.dtype) + self.session.set_output(i, output_array) + self.outputs.append(output_array) diff --git a/runtime/onert/api/python/package/experimental/__init__.py b/runtime/onert/api/python/package/experimental/__init__.py new file mode 100644 index 00000000000..b472b1d73ce --- /dev/null +++ b/runtime/onert/api/python/package/experimental/__init__.py @@ -0,0 +1,3 @@ +__all__ = ["train"] + +from . import train diff --git a/runtime/onert/api/python/package/experimental/train/__init__.py b/runtime/onert/api/python/package/experimental/train/__init__.py new file mode 100644 index 00000000000..c5148773db0 --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/__init__.py @@ -0,0 +1,8 @@ +from .session import TrainSession as session +from onert.native.libnnfw_api_pybind import traininfo +from .dataloader import DataLoader +from . import optimizer +from . import losses +from . import metrics + +__all__ = ["session", "traininfo", "DataLoader", "optimizer", "losses", "metrics"] diff --git a/runtime/onert/api/python/package/experimental/train/dataloader.py b/runtime/onert/api/python/package/experimental/train/dataloader.py new file mode 100644 index 00000000000..281eee3cb9c --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/dataloader.py @@ -0,0 +1,265 @@ +import os +import numpy as np + + +class DataLoader: + """ + A flexible DataLoader to manage training and validation data. + Automatically detects whether inputs are paths or NumPy arrays. + """ + def __init__(self, + input_dataset, + expected_dataset, + batch_size, + input_shape=None, + expected_shape=None, + dtype=np.float32): + """ + Initialize the DataLoader. + + Args: + input_dataset (list of np.ndarray): List of input arrays where each array's first dimension is the batch dimension. + expected_dataset (list of np.ndarray): List of expected arrays where each array's first dimension is the batch dimension. + batch_size (int): Number of samples per batch. + input_shape (tuple, optional): Shape of the input data if raw format is used. + expected_shape (tuple, optional): Shape of the expected data if raw format is used. + dtype (type, optional): Data type of the raw file (default: np.float32). 
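+
+        Example (illustrative sketch, assuming in-memory NumPy arrays whose
+        first dimension is the sample dimension):
+
+            inputs = np.random.rand(32, 10).astype(np.float32)
+            targets = np.random.rand(32, 1).astype(np.float32)
+            loader = DataLoader([inputs], [targets], batch_size=8)
+            for batch_inputs, batch_targets in loader:
+                # batch_inputs[0].shape == (8, 10), batch_targets[0].shape == (8, 1)
+                pass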
+ """ + self.batch_size = batch_size + self.inputs = self._process_dataset(input_dataset, input_shape, dtype) + self.expecteds = self._process_dataset(expected_dataset, expected_shape, dtype) + self.batched_inputs = [] + + # Verify data consistency + self.num_samples = self.inputs[0].shape[0] # Batch dimension + if self.num_samples != self.expecteds[0].shape[0]: + raise ValueError( + "Input data and expected data must have the same number of samples.") + + # Precompute batches + self.batched_inputs, self.batched_expecteds = self._create_batches() + + def _process_dataset(self, data, shape, dtype=np.float32): + """ + Process a dataset or file path. + + Args: + data (str or np.ndarray): Path to file or NumPy arrays. + shape (tuple, optional): Shape of the data if raw format is used. + dtype (type, optional): Data type for raw files. + + Returns: + list of np.ndarray: Loaded or passed data as NumPy arrays. + """ + if isinstance(data, list): + # Check if all elements in the list are NumPy arrays + if all(isinstance(item, np.ndarray) for item in data): + return data + raise ValueError("All elements in the list must be NumPy arrays.") + if isinstance(data, np.ndarray): + # If it's already a NumPy array and is not a list of arrays + if len(data.shape) > 1: + # If the array has multiple dimensions, split it into a list of arrays + return [data[i] for i in range(data.shape[0])] + else: + # If it's a single array, wrap it into a list + return [data] + elif isinstance(data, str): + # If it's a string, assume it's a file path + return [self._load_data(data, shape, dtype)] + else: + raise ValueError("Data must be a NumPy array or a valid file path.") + + def _load_data(self, file_path, shape, dtype=np.float32): + """ + Load data from a file, supporting both .npy and raw formats. + + Args: + file_path (str): Path to the file to load. + shape (tuple, optional): Shape of the data if raw format is used. + dtype (type, optional): Data type of the raw file (default: np.float32). + + Returns: + np.ndarray: Loaded data as a NumPy array. + """ + _, ext = os.path.splitext(file_path) + + if ext == ".npy": + # Load .npy file + return np.load(file_path) + elif ext in [".bin", ".raw"]: + # Load raw binary file + if shape is None: + raise ValueError(f"Shape must be provided for raw file: {file_path}") + return self._load_raw(file_path, shape, dtype) + else: + raise ValueError(f"Unsupported file format: {ext}") + + def _load_raw(self, file_path, shape, dtype): + """ + Load raw binary data. + + Args: + file_path (str): Path to the raw binary file. + shape (tuple): Shape of the data to reshape into. + dtype (type): Data type of the binary file. + + Returns: + np.ndarray: Loaded data as a NumPy array. + """ + # Calculate the expected number of elements based on the provided shape + expected_elements = np.prod(shape) + + # Calculate the expected size of the raw file in bytes + expected_size = expected_elements * np.dtype(dtype).itemsize + + # Get the actual size of the raw file + actual_size = os.path.getsize(file_path) + + # Check if the sizes match + if actual_size != expected_size: + raise ValueError( + f"Raw file size ({actual_size} bytes) does not match the expected size " + f"({expected_size} bytes) based on the provided shape {shape} and dtype {dtype}." + ) + + # Read and load the raw data + with open(file_path, "rb") as f: + data = f.read() + array = np.frombuffer(data, dtype=dtype) + if array.size != expected_elements: + raise ValueError( + f"Raw data size does not match the expected shape: {shape}. 
" + f"Expected {expected_elements} elements, got {array.size} elements.") + return array.reshape(shape) + + def _create_batches(self): + """ + Precompute batches for inputs and expected outputs. + + Returns: + tuple: Lists of batched inputs and batched expecteds. + """ + batched_inputs = [] + batched_expecteds = [] + + for batch_start in range(0, self.num_samples, self.batch_size): + batch_end = min(batch_start + self.batch_size, self.num_samples) + + # Collect batched inputs + inputs_batch = [ + input_array[batch_start:batch_end] for input_array in self.inputs + ] + if batch_end - batch_start < self.batch_size: + # Resize the last batch to match batch_size + inputs_batch = [ + np.resize(batch, (self.batch_size, *batch.shape[1:])) + for batch in inputs_batch + ] + + batched_inputs.append(inputs_batch) + + # Collect batched expecteds + expecteds_batch = [ + expected_array[batch_start:batch_end] for expected_array in self.expecteds + ] + if batch_end - batch_start < self.batch_size: + # Resize the last batch to match batch_size + expecteds_batch = [ + np.resize(batch, (self.batch_size, *batch.shape[1:])) + for batch in expecteds_batch + ] + + batched_expecteds.append(expecteds_batch) + + return batched_inputs, batched_expecteds + + def __iter__(self): + """ + Make the DataLoader iterable. + + Returns: + self + """ + self.index = 0 + return self + + def __next__(self): + """ + Return the next batch of data. + + Returns: + tuple: (inputs, expecteds) for the next batch. + """ + if self.index >= len(self.batched_inputs): + raise StopIteration + + # Retrieve precomputed batch + input_batch = self.batched_inputs[self.index] + expected_batch = self.batched_expecteds[self.index] + + self.index += 1 + return input_batch, expected_batch + + def split(self, validation_split): + """ + Split the data into training and validation sets. + + Args: + validation_split (float): Ratio of validation data. Must be between 0.0 and 1.0. + + Returns: + tuple: Two DataLoader instances, one for training and one for validation. + """ + if not (0.0 <= validation_split <= 1.0): + raise ValueError("Validation split must be between 0.0 and 1.0.") + + split_index = int(len(self.inputs[0]) * (1.0 - validation_split)) + + train_inputs = [input_array[:split_index] for input_array in self.inputs] + val_inputs = [input_array[split_index:] for input_array in self.inputs] + train_expecteds = [ + expected_array[:split_index] for expected_array in self.expecteds + ] + val_expecteds = [ + expected_array[split_index:] for expected_array in self.expecteds + ] + + train_loader = DataLoader(train_inputs, train_expecteds, self.batch_size) + val_loader = DataLoader(val_inputs, val_expecteds, self.batch_size) + + return train_loader, val_loader + + # def generate_batches(self, batch_size): + # """ + # Generate batches of data. + + # Args: + # batch_size (int): Number of samples per batch. + + # Yields: + # tuple: A batch of inputs and expected outputs. + # """ + # for i in range(0, len(self.inputs), batch_size): + # yield ( + # self.inputs[i:i + batch_size], # Batch of input data + # self.expecteds[i:i + batch_size], # Batch of expected output data + # ) + + # @classmethod + # def from_data(cls, inputs, expecteds, batch_size): + # """ + # Create a DataLoader instance from raw data arrays. + + # Args: + # inputs (np.ndarray): Input data array. + # expecteds (np.ndarray): Expected output data array. + + # Returns: + # DataLoader: A new DataLoader instance. 
+ # """ + # loader = cls.__new__(cls) # Bypass __init__ + # loader.inputs = inputs + # loader.expecteds = expecteds + # loader.batch_size = batch_size + # return loader diff --git a/runtime/onert/api/python/package/experimental/train/losses/__init__.py b/runtime/onert/api/python/package/experimental/train/losses/__init__.py new file mode 100644 index 00000000000..12977444839 --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/losses/__init__.py @@ -0,0 +1,5 @@ +from .cce import CategoricalCrossentropy +from .mse import MeanSquaredError +from onert.native.libnnfw_api_pybind import lossinfo + +__all__ = ["CategoricalCrossentropy", "MeanSquaredError", "lossinfo", "loss"] diff --git a/runtime/onert/api/python/package/experimental/train/losses/cce.py b/runtime/onert/api/python/package/experimental/train/losses/cce.py new file mode 100644 index 00000000000..dd89f5bc249 --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/losses/cce.py @@ -0,0 +1,36 @@ +import numpy as np +from .loss import LossFunction + + +class CategoricalCrossentropy(LossFunction): + """ + Categorical Cross-Entropy Loss Function with reduction type. + """ + def __init__(self, reduction="mean"): + """ + Initialize the Categorical Cross-Entropy loss function. + + Args: + reduction (str): Reduction type ('mean', 'sum'). + """ + super().__init__(reduction) + + def __call__(self, y_true, y_pred): + """ + Compute the Categorical Cross-Entropy loss. + + Args: + y_true (np.ndarray): One-hot encoded ground truth values. + y_pred (np.ndarray): Predicted probabilities. + + Returns: + float or np.ndarray: Computed loss value(s). + """ + epsilon = 1e-7 # Prevent log(0) + y_pred = np.clip(y_pred, epsilon, 1 - epsilon) + loss = -np.sum(y_true * np.log(y_pred), axis=1) + + if self.reduction == "mean": + return np.mean(loss) + elif self.reduction == "sum": + return np.sum(loss) diff --git a/runtime/onert/api/python/package/experimental/train/losses/loss.py b/runtime/onert/api/python/package/experimental/train/losses/loss.py new file mode 100644 index 00000000000..01d1ed310e2 --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/losses/loss.py @@ -0,0 +1,25 @@ +from onert.native.libnnfw_api_pybind import loss_reduction + + +class LossFunction: + """ + Base class for loss functions with reduction type. + """ + def __init__(self, reduction="mean"): + """ + Initialize the Categorical Cross-Entropy loss function. + + Args: + reduction (str): Reduction type ('mean', 'sum'). + """ + reduction_mapping = { + "mean": loss_reduction.SUM_OVER_BATCH_SIZE, + "sum": loss_reduction.SUM + } + + # Validate and assign the reduction type + if reduction not in reduction_mapping: + raise ValueError( + f"Invalid reduction type. Choose from {list(reduction_mapping.keys())}.") + + self.reduction = reduction_mapping[reduction] diff --git a/runtime/onert/api/python/package/experimental/train/losses/mse.py b/runtime/onert/api/python/package/experimental/train/losses/mse.py new file mode 100644 index 00000000000..953795620cf --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/losses/mse.py @@ -0,0 +1,33 @@ +import numpy as np +from .loss import LossFunction + + +class MeanSquaredError(LossFunction): + """ + Mean Squared Error (MSE) Loss Function with reduction type. + """ + def __init__(self, reduction="mean"): + """ + Initialize the MSE loss function. + + Args: + reduction (str): Reduction type ('mean', 'sum'). 
+ """ + super().__init__(reduction) + + def __call__(self, y_true, y_pred): + """ + Compute the Mean Squared Error (MSE) loss. + + Args: + y_true (np.ndarray): Ground truth values. + y_pred (np.ndarray): Predicted values. + + Returns: + float or np.ndarray: Computed MSE loss value(s). + """ + loss = (y_true - y_pred)**2 + if self.reduction == "mean": + return np.mean(loss) + elif self.reduction == "sum": + return np.sum(loss) diff --git a/runtime/onert/api/python/package/experimental/train/metrics/__init__.py b/runtime/onert/api/python/package/experimental/train/metrics/__init__.py new file mode 100644 index 00000000000..7ec5015b1e5 --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/metrics/__init__.py @@ -0,0 +1,4 @@ +from .metric import Metric +from .categorical_accuracy import CategoricalAccuracy + +__all__ = ["Metric", "CategoricalAccuracy"] diff --git a/runtime/onert/api/python/package/experimental/train/metrics/categorical_accuracy.py b/runtime/onert/api/python/package/experimental/train/metrics/categorical_accuracy.py new file mode 100644 index 00000000000..f36db57e66b --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/metrics/categorical_accuracy.py @@ -0,0 +1,58 @@ +import numpy as np +from .metric import Metric + + +class CategoricalAccuracy(Metric): + """ + Metric for computing categorical accuracy. + """ + def __init__(self): + self.correct = 0 + self.total = 0 + self.axis = 0 + + def reset_state(self): + """ + Reset the metric's state. + """ + self.correct = 0 + self.total = 0 + + def update_state(self, outputs, expecteds): + """ + Update the metric's state based on the outputs and expecteds. + + Args: + outputs (list of np.ndarray): List of model outputs for each output layer. + expecteds (list of np.ndarray): List of expected ground truth values for each output layer. + """ + if len(outputs) != len(expecteds): + raise ValueError( + "The number of outputs and expecteds must match. " + f"Got {len(outputs)} outputs and {len(expecteds)} expecteds.") + + for output, expected in zip(outputs, expecteds): + if output.shape[self.axis] != expected.shape[self.axis]: + raise ValueError( + f"Output and expected shapes must match along the specified axis {self.axis}. " + f"Got output shape {output.shape} and expected shape {expected.shape}." + ) + + batch_size = output.shape[self.axis] + for b in range(batch_size): + output_idx = np.argmax(output[b]) + expected_idx = np.argmax(expected[b]) + if output_idx == expected_idx: + self.correct += 1 + self.total += batch_size + + def result(self): + """ + Compute and return the final metric value. + + Returns: + float: Metric value. + """ + if self.total == 0: + return 0.0 + return self.correct / self.total diff --git a/runtime/onert/api/python/package/experimental/train/metrics/metric.py b/runtime/onert/api/python/package/experimental/train/metrics/metric.py new file mode 100644 index 00000000000..141de18e52d --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/metrics/metric.py @@ -0,0 +1,28 @@ +class Metric: + """ + Abstract base class for all metrics. + """ + def reset_state(self): + """ + Reset the metric's state. + """ + raise NotImplementedError + + def update_state(self, outputs, expecteds): + """ + Update the metric's state based on the outputs and expecteds. + + Args: + outputs (np.ndarray): Model outputs. + expecteds (np.ndarray): Expected ground truth values. + """ + raise NotImplementedError + + def result(self): + """ + Compute and return the final metric value. 
+ + Returns: + float: Metric value. + """ + raise NotImplementedError diff --git a/runtime/onert/api/python/package/experimental/train/metrics/registry.py b/runtime/onert/api/python/package/experimental/train/metrics/registry.py new file mode 100644 index 00000000000..2ec2e288f89 --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/metrics/registry.py @@ -0,0 +1,26 @@ +from .categorical_accuracy import CategoricalAccuracy + + +class MetricsRegistry: + """ + Registry for creating metrics by name. + """ + _metrics = { + "categorical_accuracy": CategoricalAccuracy, + } + + @staticmethod + def create_metric(name): + """ + Create a metric instance by name. + + Args: + name (str): Name of the metric. + + Returns: + BaseMetric: Metric instance. + """ + if name not in MetricsRegistry._metrics: + raise ValueError( + f"Unknown Metric: {name}. Custom metric is not supported yet") + return MetricsRegistry._metrics[name]() diff --git a/runtime/onert/api/python/package/experimental/train/optimizer/__init__.py b/runtime/onert/api/python/package/experimental/train/optimizer/__init__.py new file mode 100644 index 00000000000..ae450a7d0f4 --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/optimizer/__init__.py @@ -0,0 +1,5 @@ +from .sgd import SGD +from .adam import Adam +from onert.native.libnnfw_api_pybind import trainable_ops + +__all__ = ["SGD", "Adam", "trainable_ops"] diff --git a/runtime/onert/api/python/package/experimental/train/optimizer/adam.py b/runtime/onert/api/python/package/experimental/train/optimizer/adam.py new file mode 100644 index 00000000000..8054024e12e --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/optimizer/adam.py @@ -0,0 +1,47 @@ +from .optimizer import Optimizer + + +class Adam(Optimizer): + """ + Adam optimizer. + """ + def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-7): + """ + Initialize the Adam optimizer. + + Args: + learning_rate (float): The learning rate for optimization. + beta1 (float): Exponential decay rate for the first moment estimates. + beta2 (float): Exponential decay rate for the second moment estimates. + epsilon (float): Small constant to prevent division by zero. + """ + super().__init__(learning_rate) + self.beta1 = beta1 + self.beta2 = beta2 + self.epsilon = epsilon + self.m = None + self.v = None + self.t = 0 + + def step(self, gradients, parameters): + """ + Update parameters using Adam optimization. + + Args: + gradients (list): List of gradients for each parameter. + parameters (list): List of parameters to be updated. + """ + if self.m is None: + self.m = [0] * len(parameters) + if self.v is None: + self.v = [0] * len(parameters) + + self.t += 1 + for i, (grad, param) in enumerate(zip(gradients, parameters)): + self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * grad + self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * (grad**2) + + m_hat = self.m[i] / (1 - self.beta1**self.t) + v_hat = self.v[i] / (1 - self.beta2**self.t) + + param -= self.learning_rate * m_hat / (v_hat**0.5 + self.epsilon) diff --git a/runtime/onert/api/python/package/experimental/train/optimizer/optimizer.py b/runtime/onert/api/python/package/experimental/train/optimizer/optimizer.py new file mode 100644 index 00000000000..a8b9311ce6a --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/optimizer/optimizer.py @@ -0,0 +1,26 @@ +from onert.native.libnnfw_api_pybind import trainable_ops + + +class Optimizer: + """ + Base class for optimizers. 
Subclasses should implement the `step` method. + """ + def __init__(self, learning_rate=0.001, nums_trainable_ops=trainable_ops.ALL): + """ + Initialize the optimizer. + + Args: + learning_rate (float): The learning rate for optimization. + """ + self.learning_rate = learning_rate + self.nums_trainable_ops = nums_trainable_ops + + def step(self, gradients, parameters): + """ + Update parameters based on gradients. Should be implemented by subclasses. + + Args: + gradients (list): List of gradients for each parameter. + parameters (list): List of parameters to be updated. + """ + raise NotImplementedError("Subclasses must implement the `step` method.") diff --git a/runtime/onert/api/python/package/experimental/train/optimizer/sgd.py b/runtime/onert/api/python/package/experimental/train/optimizer/sgd.py new file mode 100644 index 00000000000..682b5a7dfbe --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/optimizer/sgd.py @@ -0,0 +1,38 @@ +from .optimizer import Optimizer + + +class SGD(Optimizer): + """ + Stochastic Gradient Descent (SGD) optimizer. + """ + def __init__(self, learning_rate=0.001, momentum=0.0): + """ + Initialize the SGD optimizer. + + Args: + learning_rate (float): The learning rate for optimization. + momentum (float): Momentum factor (default: 0.0). + """ + super().__init__(learning_rate) + + if momentum != 0.0: + raise NotImplementedError( + "Momentum is not supported in the current version of SGD.") + self.momentum = momentum + self.velocity = None + + def step(self, gradients, parameters): + """ + Update parameters using SGD with optional momentum. + + Args: + gradients (list): List of gradients for each parameter. + parameters (list): List of parameters to be updated. + """ + if self.velocity is None: + self.velocity = [0] * len(parameters) + + for i, (grad, param) in enumerate(zip(gradients, parameters)): + self.velocity[ + i] = self.momentum * self.velocity[i] - self.learning_rate * grad + parameters[i] += self.velocity[i] diff --git a/runtime/onert/api/python/package/experimental/train/session.py b/runtime/onert/api/python/package/experimental/train/session.py new file mode 100644 index 00000000000..85196791aef --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/session.py @@ -0,0 +1,359 @@ +import numpy as np + +from onert.native import libnnfw_api_pybind +from onert.native.libnnfw_api_pybind import optimizer as optimizer_type +from onert.native.libnnfw_api_pybind import loss as loss_type +from onert.common.basesession import BaseSession +from .metrics.registry import MetricsRegistry +from .metrics.metric import Metric +from .losses import CategoricalCrossentropy, MeanSquaredError +from .optimizer import Adam, SGD +import time + + +# TODO: Support import checkpoint +class TrainSession(BaseSession): + """ + Class for training and inference using nnfw_session. + """ + def __init__(self, nnpackage_path, backends="train"): + """ + Initialize the train session. + + Args: + nnpackage_path (str): Path to the nnpackage file or directory. + backends (str): Backends to use, default is "train". 
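+
+        Example (illustrative sketch):
+            session = TrainSession("path/to/nnpackage", backends="train")
+            # or, via the package alias exposed by onert.experimental.train:
+            # session = onert.experimental.train.session("path/to/nnpackage")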
+ """ + load_start = time.perf_counter() + super().__init__( + libnnfw_api_pybind.experimental.nnfw_session(nnpackage_path, backends)) + load_end = time.perf_counter() + self.total_time = {'MODEL_LOAD': (load_end - load_start) * 1000} + self.train_info = self.session.train_get_traininfo() + self.optimizer = None + self.loss = None + self.metrics = [] + + def compile(self, optimizer, loss, metrics=[], batch_size=16): + """ + Compile the session with optimizer, loss, and metrics. + + Args: + optimizer (Optimizer): Optimizer instance. + loss (Loss): Loss instance. + metrics (list): List of metrics to evaluate during training. + batch_size (int): Number of samples per batch. + + Raises: + ValueError: If the number of metrics does not match the number of model outputs. + """ + self.optimizer = optimizer + self.loss = loss + self.metrics = [ + MetricsRegistry.create_metric(m) if isinstance(m, str) else m for m in metrics + ] + + # Validate that all elements in self.metrics are instances of Metric + for metric in self.metrics: + if not isinstance(metric, Metric): + raise TypeError( + f"Invalid metric type: {type(metric).__name__}. " + "All metrics must inherit from the Metric base class." + ) + + # Check if the number of metrics matches the number of outputs + num_model_outputs = self.session.output_size() + if 0 < len(self.metrics) != num_model_outputs: + raise ValueError( + f"Number of metrics ({len(self.metrics)}) does not match the number of model outputs ({num_model_outputs}). " + "Please ensure one metric is provided for each model output.") + + # Set training information + self.train_info.learning_rate = optimizer.learning_rate + self.train_info.batch_size = batch_size + self.train_info.loss_info.loss = self._map_loss_function_to_enum(loss) + self.train_info.loss_info.reduction_type = loss.reduction + self.train_info.opt = self._map_optimizer_to_enum(optimizer) + self.train_info.num_of_trainable_ops = optimizer.nums_trainable_ops + self.session.train_set_traininfo(self.train_info) + + # Print training parameters + self._print_training_parameters() + + # Prepare session for training + compile_start = time.perf_counter() + self.session.train_prepare() + compile_end = time.perf_counter() + self.total_time["COMPILE"] = (compile_end - compile_start) * 1000 + + def _map_loss_function_to_enum(self, loss_instance): + """ + Maps a LossFunction instance to the appropriate enum value. + + Args: + loss_instance (LossFunction): An instance of a loss function. + + Returns: + train_loss: Corresponding enum value for the loss function. + + Raises: + TypeError: If the loss_instance is not a recognized LossFunction type. + """ + if isinstance(loss_instance, CategoricalCrossentropy): + return loss_type.CATEGORICAL_CROSSENTROPY + elif isinstance(loss_instance, MeanSquaredError): + return loss_type.MEAN_SQUARED_ERROR + else: + raise TypeError( + f"Unsupported loss function type: {type(loss_instance).__name__}. " + "Supported types are CategoricalCrossentropy and MeanSquaredError.") + + def _map_optimizer_to_enum(self, optimizer_instance): + """ + Maps an Optimizer instance to the appropriate enum value. + + Args: + optimizer_instance (Optimizer): An instance of an optimizer. + + Returns: + train_optimizer: Corresponding enum value for the optimizer. + + Raises: + TypeError: If the optimizer_instance is not a recognized Optimizer type. 
+ """ + if isinstance(optimizer_instance, SGD): + return optimizer_type.SGD + elif isinstance(optimizer_instance, Adam): + return optimizer_type.ADAM + else: + raise TypeError( + f"Unsupported optimizer type: {type(optimizer_instance).__name__}. " + "Supported types are SGD and Adam.") + + def _print_training_parameters(self): + """ + Print the training parameters in a formatted way. + """ + # Get loss function name + loss_name = self.loss.__class__.__name__ if self.loss else "Unknown Loss" + + # Get reduction type name from enum value + reduction_name = self.train_info.loss_info.reduction_type.name.lower().replace( + "_", " ") + + # Get optimizer name + optimizer_name = self.optimizer.__class__.__name__ if self.optimizer else "Unknown Optimizer" + + print("== training parameter ==") + print( + f"- learning_rate = {f'{self.train_info.learning_rate:.4f}'.rstrip('0').rstrip('.')}" + ) + print(f"- batch_size = {self.train_info.batch_size}") + print( + f"- loss_info = {{loss = {loss_name}, reduction = {reduction_name}}}" + ) + print(f"- optimizer = {optimizer_name}") + print(f"- num_of_trainable_ops = {self.train_info.num_of_trainable_ops}") + print("========================") + + def train(self, data_loader, epochs, validation_split=0.0, checkpoint_path=None): + """ + Train the model using the given data loader. + + Args: + data_loader: A data loader providing input and expected data. + batch_size (int): Number of samples per batch. + epochs (int): Number of epochs to train. + validation_split (float): Ratio of validation data. Default is 0.0 (no validation). + checkpoint_path (str): Path to save or load the training checkpoint. + """ + if self.optimizer is None or self.loss is None: + raise RuntimeError( + "The training session is not properly configured. " + "Please call `compile(optimizer, loss)` before calling `train()`.") + + # Split data into training and validation + train_data, val_data = data_loader.split(validation_split) + + # Timings for summary + epoch_times = [] + + # Training loop + for epoch in range(epochs): + message = [f"Epoch {epoch + 1}/{epochs}"] + + epoch_start_time = time.perf_counter() + # Training phase + train_loss, avg_io_time, avg_train_time = self._run_phase(train_data, + train=True) + message.append(f"Train time: {avg_train_time:.3f}ms/step") + message.append(f"IO time: {avg_io_time:.3f}ms/step") + message.append(f"Train Loss: {train_loss:.4f}") + + # Validation phase + if validation_split > 0.0: + val_loss, _, _ = self._run_phase(val_data, train=False) + message.append(f"Validation Loss: {val_loss:.4f}") + + # Print metrics + for metric in self.metrics: + message.append(f"{metric.__class__.__name__}: {metric.result():.4f}") + metric.reset_state() + + epoch_time = (time.perf_counter() - epoch_start_time) * 1000 + epoch_times.append(epoch_time) + + print(" - ".join(message)) + + # Save checkpoint + if checkpoint_path is not None: + self.session.train_export_checkpoint(checkpoint_path) + + self.total_time["EXECUTE"] = sum(epoch_times) + self.total_time["EPOCH_TIMES"] = epoch_times + + return self.total_time + + def _run_phase(self, data, train=True): + """ + Run a training or validation phase. + + Args: + data: Data generator providing input and expected data. + train (bool): Whether to perform training or validation. + + Returns: + float: Average loss for the phase. 
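+            float: Average time spent setting inputs/outputs per step, in milliseconds.
+            float: Average train (or validation) time per step, in milliseconds.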
+ """ + total_loss = 0.0 + num_batches = 0 + + io_time = 0 + train_time = 0 + + for inputs, expecteds in data: + # Validate batch sizes + self._check_batch_size(inputs, self.train_info.batch_size, data_type="input") + self._check_batch_size(expecteds, + self.train_info.batch_size, + data_type="expected") + + set_io_start = time.perf_counter() + # Set inputs + for i, input_data in enumerate(inputs): + self.session.train_set_input(i, input_data) + + # Set expected outputs + outputs = [] + for i, expected_data in enumerate(expecteds): + expected = np.array(expected_data, + dtype=self.session.output_tensorinfo(i).dtype) + self.session.train_set_expected(i, expected) + + output = np.zeros(expected.shape, + dtype=self.session.output_tensorinfo(i).dtype) + self.session.train_set_output(i, output) + assert i == len(outputs) + outputs.append(output) + + set_io_end = time.perf_counter() + + # Run training or validation + train_start = time.perf_counter() + self.session.train(update_weights=train) + train_end = time.perf_counter() + + # Accumulate loss + batch_loss = sum( + self.session.train_get_loss(i) for i in range(len(expecteds))) + total_loss += batch_loss + num_batches += 1 + + # Update metrics + if not train: + for metric in self.metrics: + metric.update_state(outputs, expecteds) + + # Calculate times + io_time += (set_io_end - set_io_start) + train_time += (train_end - train_start) + + if num_batches > 0: + return (total_loss / num_batches, (io_time * 1000) / num_batches, + (train_time * 1000) / num_batches) + else: + return (0.0, 0.0, 0.0) + + def _check_batch_size(self, data, batch_size, data_type="input"): + """ + Validate that the batch size of the data matches the configured training batch size. + + Args: + data (list of np.ndarray): The data to validate. + batch_size (int): The expected batch size. + data_type (str): A string to indicate whether the data is 'input' or 'expected'. + + Raises: + ValueError: If the batch size does not match the expected value. + """ + for i, array in enumerate(data): + if array.shape[0] > batch_size: + raise ValueError( + f"Batch size mismatch for {data_type} data at index {i}: " + f"batch size ({array.shape[0]}) does not match the configured " + f"training batch size ({batch_size}).") + + def train_step(self, inputs, expecteds): + """ + Train the model for a single batch. + + Args: + inputs (list of np.ndarray): List of input arrays for the batch. + expecteds (list of np.ndarray): List of expected output arrays for the batch. + + Returns: + dict: A dictionary containing loss and metrics values. + """ + if self.optimizer is None or self.loss is None: + raise RuntimeError( + "The training session is not properly configured. 
" + "Please call `compile(optimizer, loss)` before calling `train_step()`.") + + # Validate batch sizes + self._check_batch_size(inputs, self.train_info.batch_size, data_type="input") + self._check_batch_size(expecteds, + self.train_info.batch_size, + data_type="expected") + + # Set inputs + for i, input_data in enumerate(inputs): + self.session.train_set_input(i, input_data) + + # Set expected outputs + outputs = [] + for i, expected_data in enumerate(expecteds): + self.session.train_set_expected(i, expected_data) + output = np.zeros(expected_data.shape, + dtype=self.session.output_tensorinfo(i).dtype) + self.session.train_set_output(i, output) + outputs.append(output) + + # Run a single training step + train_start = time.perf_counter() + self.session.train(update_weights=True) + train_end = time.perf_counter() + + # Calculate loss + losses = [self.session.train_get_loss(i) for i in range(len(expecteds))] + + # Update metrics + metric_results = {} + for metric in self.metrics: + metric.update_state(outputs, expecteds) + metric_results[metric.__class__.__name__] = metric.result() + + return { + "loss": losses, + "metrics": metric_results, + "train_time": (train_end - train_start) * 1000 + } diff --git a/runtime/onert/api/python/package/infer.py b/runtime/onert/api/python/package/infer.py deleted file mode 100644 index f3f95c63c52..00000000000 --- a/runtime/onert/api/python/package/infer.py +++ /dev/null @@ -1,58 +0,0 @@ -import numpy as np -import os -import shutil - -from .native import libnnfw_api_pybind - - -def num_elems(tensor_info): - """Get the total number of elements in nnfw_tensorinfo.dims.""" - n = 1 - for x in range(tensor_info.rank): - n *= tensor_info.dims[x] - return n - - -class session(libnnfw_api_pybind.nnfw_session): - """Class inherited nnfw_session for easily processing input/output""" - def __init__(self, nnpackage_path, backends="cpu"): - super().__init__(nnpackage_path, backends) - self.inputs = [] - self.outputs = [] - self.set_outputs(self.output_size()) - - def set_inputs(self, size, inputs_array=[]): - """Set inputs for each index""" - for i in range(size): - input_tensorinfo = self.input_tensorinfo(i) - - if len(inputs_array) > i: - input_array = np.array(inputs_array[i], dtype=input_tensorinfo.dtype) - else: - print( - f"model's input size is {size} but given inputs_array size is {len(inputs_array)}.\n{i}-th index input is replaced by an array filled with 0." 
- ) - input_array = np.zeros((num_elems(input_tensorinfo)), - dtype=input_tensorinfo.dtype) - - self.set_input(i, input_array) - self.inputs.append(input_array) - - def set_outputs(self, size): - """Set outputs for each index""" - for i in range(size): - output_tensorinfo = self.output_tensorinfo(i) - output_array = np.zeros((num_elems(output_tensorinfo)), - dtype=output_tensorinfo.dtype) - self.set_output(i, output_array) - self.outputs.append(output_array) - - def inference(self): - """Inference model and get outputs""" - self.run() - - return self.outputs - - -def tensorinfo(): - return libnnfw_api_pybind.nnfw_tensorinfo() diff --git a/runtime/onert/api/python/package/infer/__init__.py b/runtime/onert/api/python/package/infer/__init__.py new file mode 100644 index 00000000000..9c8078e7bbd --- /dev/null +++ b/runtime/onert/api/python/package/infer/__init__.py @@ -0,0 +1,3 @@ +from .session import session, tensorinfo + +__all__ = ["session", "tensorinfo"] diff --git a/runtime/onert/api/python/package/infer/session.py b/runtime/onert/api/python/package/infer/session.py new file mode 100644 index 00000000000..68e0bc2e984 --- /dev/null +++ b/runtime/onert/api/python/package/infer/session.py @@ -0,0 +1,59 @@ +import numpy as np + +from ..native import libnnfw_api_pybind +from ..common.basesession import BaseSession + + +class session(BaseSession): + """ + Class for inference using nnfw_session. + """ + def __init__(self, nnpackage_path: str = None, backends: str = "cpu"): + """ + Initialize the inference session. + + Args: + nnpackage_path (str): Path to the nnpackage file or directory. + backends (str): Backends to use, default is "cpu". + """ + if nnpackage_path is not None: + super().__init__( + libnnfw_api_pybind.infer.nnfw_session(nnpackage_path, backends)) + self.session.prepare() + self.set_outputs(self.session.output_size()) + else: + super().__init__() + + def compile(self, nnpackage_path: str, backends: str = "cpu"): + """ + Prepare the session by recreating it with new parameters. + + Args: + nnpackage_path (str): Path to the nnpackage file or directory. Defaults to the existing path. + backends (str): Backends to use. Defaults to the existing backends. + """ + # Update parameters if provided + if nnpackage_path is None: + raise ValueError("nnpackage_path must not be None.") + + # Recreate the session with updated parameters + self._recreate_session( + libnnfw_api_pybind.infer.nnfw_session(nnpackage_path, backends)) + + # Prepare the new session + self.session.prepare() + self.set_outputs(self.session.output_size()) + + def inference(self): + """ + Perform model and get outputs + + Returns: + list: Outputs from the model. + """ + self.session.run() + return self.outputs + + +def tensorinfo(): + return libnnfw_api_pybind.infer.nnfw_tensorinfo() diff --git a/runtime/onert/api/python/src/bindings/nnfw_api_module_pybind.cc b/runtime/onert/api/python/src/bindings/nnfw_api_module_pybind.cc new file mode 100644 index 00000000000..5b3b28ee30d --- /dev/null +++ b/runtime/onert/api/python/src/bindings/nnfw_api_module_pybind.cc @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <pybind11/pybind11.h>
+
+#include "nnfw_api_wrapper.h"
+#include "nnfw_tensorinfo_bindings.h"
+#include "nnfw_traininfo_bindings.h"
+#include "nnfw_session_bindings.h"
+
+namespace py = pybind11;
+
+using namespace onert::api::python;
+
+PYBIND11_MODULE(libnnfw_api_pybind, m)
+{
+  m.doc() = "Main module that contains infer and experimental submodules";
+
+  // Bind common `NNFW_SESSION` class
+  bind_nnfw_session(m);
+
+  // Bind `NNFW_SESSION` class for inference
+  // Currently, the `infer` session is the same as common.
+  auto infer = m.def_submodule("infer", "Inference submodule");
+  infer.attr("nnfw_session") = m.attr("nnfw_session");
+
+  // Bind experimental `NNFW_SESSION` class
+  auto experimental = m.def_submodule("experimental", "Experimental submodule");
+  experimental.attr("nnfw_session") = m.attr("nnfw_session");
+  bind_experimental_nnfw_session(experimental);
+
+  // Bind common `tensorinfo` class
+  bind_tensorinfo(m);
+
+  m.doc() = "NNFW Python Bindings for Training";
+
+  // Bind training enums
+  bind_nnfw_train_enums(m);
+
+  // Bind training nnfw_loss_info
+  bind_nnfw_loss_info(m);
+
+  // Bind training nnfw_train_info
+  bind_nnfw_train_info(m);
+}
diff --git a/runtime/onert/api/python/src/nnfw_api_wrapper_pybind.cc b/runtime/onert/api/python/src/bindings/nnfw_session_bindings.cc
similarity index 72%
rename from runtime/onert/api/python/src/nnfw_api_wrapper_pybind.cc
rename to runtime/onert/api/python/src/bindings/nnfw_session_bindings.cc
index 9737f0b58ce..1a9546d9ea0 100644
--- a/runtime/onert/api/python/src/nnfw_api_wrapper_pybind.cc
+++ b/runtime/onert/api/python/src/bindings/nnfw_session_bindings.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,24 +14,18 @@
  * limitations under the License.
  */
 
+#include "nnfw_session_bindings.h"
+
 #include "nnfw_api_wrapper.h"
 
 namespace py = pybind11;
 
-PYBIND11_MODULE(libnnfw_api_pybind, m)
-{
-  m.doc() = "nnfw python plugin";
-
-  py::class_<tensorinfo>(m, "tensorinfo", "tensorinfo describes the type and shape of tensors")
-    .def(py::init<>(), "The constructor of tensorinfo")
-    .def_readwrite("dtype", &tensorinfo::dtype, "The data type")
-    .def_readwrite("rank", &tensorinfo::rank, "The number of dimensions (rank)")
-    .def_property(
-      "dims", [](const tensorinfo &ti) { return get_dims(ti); },
-      [](tensorinfo &ti, const py::list &dims_list) { set_dims(ti, dims_list); },
-      "The dimension of tensor. Maximum rank is 6 (NNFW_MAX_RANK).");
+using namespace onert::api::python;
 
-  py::class_<NNFW_SESSION>(m, "nnfw_session")
+// Bind the `NNFW_SESSION` class with common inference APIs
+void bind_nnfw_session(py::module_ &m)
+{
+  py::class_<NNFW_SESSION>(m, "nnfw_session", py::module_local())
     .def(
       py::init<const char *, const char *>(), py::arg("package_file_path"), py::arg("backends"),
       "Create a new session instance, load model from nnpackage file or directory, "
@@ -48,6 +42,7 @@ PYBIND11_MODULE(libnnfw_api_pybind, m)
          "Parameters:\n"
          "\tindex (int): Index of input to be set (0-indexed)\n"
          "\ttensor_info (tensorinfo): Tensor info to be set")
+    .def("prepare", &NNFW_SESSION::prepare, "Prepare for inference")
     .def("run", &NNFW_SESSION::run, "Run inference")
     .def("run_async", &NNFW_SESSION::run_async, "Run inference asynchronously")
    .def("wait", &NNFW_SESSION::wait, "Wait for asynchronous run to finish")
@@ -224,3 +219,44 @@ PYBIND11_MODULE(libnnfw_api_pybind, m)
          "Returns:\n"
          "\ttensorinfo: Tensor info (shape, type, etc)");
 }
+
+// Bind the `NNFW_SESSION` class with experimental APIs
+void bind_experimental_nnfw_session(py::module_ &m)
+{
+  // Add experimental APIs for the `NNFW_SESSION` class
+  m.attr("nnfw_session")
+    .cast<py::class_<NNFW_SESSION>>()
+    .def("train_get_traininfo", &NNFW_SESSION::train_get_traininfo,
+         "Retrieve training information for the model.")
+    .def("train_set_traininfo", &NNFW_SESSION::train_set_traininfo, py::arg("info"),
+         "Set training information for the model.")
+    .def("train_prepare", &NNFW_SESSION::train_prepare, "Prepare for training")
+    .def("train", &NNFW_SESSION::train, py::arg("update_weights") = true,
+         "Run a training step, optionally updating weights.")
+    .def("train_get_loss", &NNFW_SESSION::train_get_loss, py::arg("index"),
+         "Retrieve the training loss for a specific index.")
+    .def("train_set_input", &NNFW_SESSION::train_set_input<float>, py::arg("index"),
+         py::arg("buffer"), "Set training input tensor for the given index (float).")
+    .def("train_set_input", &NNFW_SESSION::train_set_input<int>, py::arg("index"),
+         py::arg("buffer"), "Set training input tensor for the given index (int).")
+    .def("train_set_input", &NNFW_SESSION::train_set_input<uint8_t>, py::arg("index"),
+         py::arg("buffer"), "Set training input tensor for the given index (uint8).")
+    .def("train_set_expected", &NNFW_SESSION::train_set_expected<float>, py::arg("index"),
+         py::arg("buffer"), "Set expected output tensor for the given index (float).")
+    .def("train_set_expected", &NNFW_SESSION::train_set_expected<int>, py::arg("index"),
+         py::arg("buffer"), "Set expected output tensor for the given index (int).")
+    .def("train_set_expected", &NNFW_SESSION::train_set_expected<uint8_t>, py::arg("index"),
+         py::arg("buffer"), "Set expected output tensor for the given index (uint8).")
+    .def("train_set_output", &NNFW_SESSION::train_set_output<float>, py::arg("index"),
+         py::arg("buffer"), "Set output tensor for the given index (float).")
+    .def("train_set_output", &NNFW_SESSION::train_set_output<int>, py::arg("index"),
+         py::arg("buffer"), "Set output tensor for the given index (int).")
+    .def("train_set_output", &NNFW_SESSION::train_set_output<uint8_t>, py::arg("index"),
+         py::arg("buffer"), "Set output tensor for the given index (uint8).")
+    .def("train_export_circle", &NNFW_SESSION::train_export_circle, py::arg("path"),
+         "Export the trained model to a circle file.")
+    .def("train_import_checkpoint", &NNFW_SESSION::train_import_checkpoint, py::arg("path"),
+         "Import a training checkpoint from a file.")
+    .def("train_export_checkpoint", &NNFW_SESSION::train_export_checkpoint, py::arg("path"),
+         "Export the training checkpoint to a file.");
+}
diff --git a/runtime/onert/api/python/src/bindings/nnfw_tensorinfo_bindings.cc b/runtime/onert/api/python/src/bindings/nnfw_tensorinfo_bindings.cc
new file mode 100644
index 00000000000..88490a5e0eb
--- /dev/null
+++ b/runtime/onert/api/python/src/bindings/nnfw_tensorinfo_bindings.cc
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnfw_tensorinfo_bindings.h"
+
+#include "nnfw_api_wrapper.h"
+
+namespace py = pybind11;
+
+using namespace onert::api::python;
+
+// Bind the `tensorinfo` class
+void bind_tensorinfo(py::module_ &m)
+{
+  py::class_<tensorinfo>(m, "tensorinfo", "tensorinfo describes the type and shape of tensors",
+                         py::module_local())
+    .def(py::init<>(), "The constructor of tensorinfo")
+    .def_readwrite("dtype", &tensorinfo::dtype, "The data type")
+    .def_readwrite("rank", &tensorinfo::rank, "The number of dimensions (rank)")
+    .def_property(
+      "dims", [](const tensorinfo &ti) { return get_dims(ti); },
+      [](tensorinfo &ti, const py::list &dims_list) { set_dims(ti, dims_list); },
+      "The dimension of tensor. Maximum rank is 6 (NNFW_MAX_RANK).");
+}
diff --git a/runtime/onert/api/python/src/bindings/nnfw_traininfo_bindings.cc b/runtime/onert/api/python/src/bindings/nnfw_traininfo_bindings.cc
new file mode 100644
index 00000000000..b088936014a
--- /dev/null
+++ b/runtime/onert/api/python/src/bindings/nnfw_traininfo_bindings.cc
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnfw_traininfo_bindings.h" + +#include "nnfw_api_wrapper.h" + +namespace py = pybind11; + +using namespace onert::api::python; + +// Declare binding train enums +void bind_nnfw_train_enums(py::module_ &m) +{ + // Bind NNFW_TRAIN_LOSS + py::enum_(m, "loss", py::module_local()) + .value("UNDEFINED", NNFW_TRAIN_LOSS_UNDEFINED) + .value("MEAN_SQUARED_ERROR", NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR) + .value("CATEGORICAL_CROSSENTROPY", NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY); + + // Bind NNFW_TRAIN_LOSS_REDUCTION + py::enum_(m, "loss_reduction", py::module_local()) + .value("UNDEFINED", NNFW_TRAIN_LOSS_REDUCTION_UNDEFINED) + .value("SUM_OVER_BATCH_SIZE", NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE) + .value("SUM", NNFW_TRAIN_LOSS_REDUCTION_SUM); + + // Bind NNFW_TRAIN_OPTIMIZER + py::enum_(m, "optimizer", py::module_local()) + .value("UNDEFINED", NNFW_TRAIN_OPTIMIZER_UNDEFINED) + .value("SGD", NNFW_TRAIN_OPTIMIZER_SGD) + .value("ADAM", NNFW_TRAIN_OPTIMIZER_ADAM); + + // Bind NNFW_TRAIN_NUM_OF_TRAINABLE_OPS_SPECIAL_VALUES + py::enum_(m, "trainable_ops", py::module_local()) + .value("INCORRECT_STATE", NNFW_TRAIN_TRAINABLE_INCORRECT_STATE) + .value("ALL", NNFW_TRAIN_TRAINABLE_ALL) + .value("NONE", NNFW_TRAIN_TRAINABLE_NONE); +} + +// Declare binding loss info +void bind_nnfw_loss_info(py::module_ &m) +{ + py::class_(m, "lossinfo", py::module_local()) + .def(py::init<>()) // Default constructor + .def_readwrite("loss", &nnfw_loss_info::loss, "Loss type") + .def_readwrite("reduction_type", &nnfw_loss_info::reduction_type, "Reduction type"); +} + +// Declare binding train info +void bind_nnfw_train_info(py::module_ &m) +{ + py::class_(m, "traininfo", py::module_local()) + .def(py::init<>()) // Default constructor + .def_readwrite("learning_rate", &nnfw_train_info::learning_rate, "Learning rate") + .def_readwrite("batch_size", &nnfw_train_info::batch_size, "Batch size") + .def_readwrite("loss_info", &nnfw_train_info::loss_info, "Loss information") + .def_readwrite("opt", &nnfw_train_info::opt, "Optimizer type") + .def_readwrite("num_of_trainable_ops", &nnfw_train_info::num_of_trainable_ops, + "Number of trainable operations"); +} diff --git a/runtime/onert/api/python/src/nnfw_api_wrapper.cc b/runtime/onert/api/python/src/wrapper/nnfw_api_wrapper.cc similarity index 75% rename from runtime/onert/api/python/src/nnfw_api_wrapper.cc rename to runtime/onert/api/python/src/wrapper/nnfw_api_wrapper.cc index 513311fd367..53159680d5e 100644 --- a/runtime/onert/api/python/src/nnfw_api_wrapper.cc +++ b/runtime/onert/api/python/src/wrapper/nnfw_api_wrapper.cc @@ -18,6 +18,13 @@ #include +namespace onert +{ +namespace api +{ +namespace python +{ + void ensure_status(NNFW_STATUS status) { switch (status) @@ -164,7 +171,6 @@ NNFW_SESSION::NNFW_SESSION(const char *package_file_path, const char *backends) ensure_status(nnfw_create_session(&(this->session))); ensure_status(nnfw_load_model_from_file(this->session, package_file_path)); ensure_status(nnfw_set_available_backends(this->session, backends)); - ensure_status(nnfw_prepare(this->session)); } NNFW_SESSION::~NNFW_SESSION() { @@ -190,6 +196,7 @@ void NNFW_SESSION::set_input_tensorinfo(uint32_t index, const tensorinfo *tensor } ensure_status(nnfw_set_input_tensorinfo(session, index, &ti)); } +void NNFW_SESSION::prepare() { ensure_status(nnfw_prepare(this->session)); } void NNFW_SESSION::run() { ensure_status(nnfw_run(session)); } void NNFW_SESSION::run_async() { ensure_status(nnfw_run_async(session)); } void NNFW_SESSION::wait() { 
ensure_status(nnfw_await(session)); } @@ -243,3 +250,71 @@ tensorinfo NNFW_SESSION::output_tensorinfo(uint32_t index) } return ti; } + +////////////////////////////////////////////// +// Experimental APIs for training +////////////////////////////////////////////// +nnfw_train_info NNFW_SESSION::train_get_traininfo() +{ + nnfw_train_info train_info = nnfw_train_info(); + ensure_status(nnfw_train_get_traininfo(session, &train_info)); + return train_info; +} + +void NNFW_SESSION::train_set_traininfo(const nnfw_train_info *info) +{ + ensure_status(nnfw_train_set_traininfo(session, info)); +} + +void NNFW_SESSION::train_prepare() { ensure_status(nnfw_train_prepare(session)); } + +void NNFW_SESSION::train(bool update_weights) +{ + ensure_status(nnfw_train(session, update_weights)); +} + +float NNFW_SESSION::train_get_loss(uint32_t index) +{ + float loss = 0.f; + ensure_status(nnfw_train_get_loss(session, index, &loss)); + return loss; +} + +void NNFW_SESSION::train_export_circle(const py::str &path) +{ + const char *c_str_path = path.cast().c_str(); + ensure_status(nnfw_train_export_circle(session, c_str_path)); +} + +void NNFW_SESSION::train_import_checkpoint(const py::str &path) +{ + const char *c_str_path = path.cast().c_str(); + ensure_status(nnfw_train_import_checkpoint(session, c_str_path)); +} + +void NNFW_SESSION::train_export_checkpoint(const py::str &path) +{ + const char *c_str_path = path.cast().c_str(); + ensure_status(nnfw_train_export_checkpoint(session, c_str_path)); +} + +////////////////////////////////////////////// +// Optional APIs for training +////////////////////////////////////////////// +// nnfw_tensorinfo NNFW_SESSION::train_input_tensorinfo(uint32_t index) +// { +// nnfw_tensorinfo tensorinfo = nnfw_tensorinfo(); +// ensure_status(nnfw_train_input_tensorinfo(session, index, &tensorinfo)); +// return tensorinfo; +// } + +// nnfw_tensorinfo NNFW_SESSION::train_expected_tensorinfo(uint32_t index) +// { +// nnfw_tensorinfo tensorinfo = nnfw_tensorinfo(); +// ensure_status(nnfw_train_expected_tensorinfo(session, index, &tensorinfo)); +// return tensorinfo; +// } + +} // namespace python +} // namespace api +} // namespace onert diff --git a/runtime/onert/sample/minimal-python/experimental/src/train_step_with_dataset.py b/runtime/onert/sample/minimal-python/experimental/src/train_step_with_dataset.py new file mode 100644 index 00000000000..03b63ec4c8b --- /dev/null +++ b/runtime/onert/sample/minimal-python/experimental/src/train_step_with_dataset.py @@ -0,0 +1,139 @@ +# import sys +import argparse +import numpy as np +from onert.experimental.train import session, DataLoader, optimizer, losses, metrics + + +def initParse(): + parser = argparse.ArgumentParser() + parser.add_argument('-m', + '--nnpkg', + required=True, + help='Path to the nnpackage file or directory') + parser.add_argument('-i', + '--input', + required=True, + help='Path to the file containing input data (e.g., .npy or raw)') + parser.add_argument( + '-l', + '--label', + required=True, + help='Path to the file containing label data (e.g., .npy or raw).') + parser.add_argument('--data_length', required=True, type=int, help='data length') + parser.add_argument('--backends', default='train', help='Backends to use') + parser.add_argument('--batch_size', default=16, type=int, help='batch size') + parser.add_argument('--learning_rate', default=0.01, type=float, help='learning rate') + parser.add_argument('--loss', default='mse', choices=['mse', 'cce']) + parser.add_argument('--optimizer', default='sgd', 
choices=['sgd', 'adam']) + parser.add_argument('--loss_reduction_type', default='mean', choices=['mean', 'sum']) + + return parser.parse_args() + + +def createOptimizer(optimizer_type, learning_rate=0.001, **kwargs): + """ + Create an optimizer based on the specified type. + Args: + optimizer_type (str): The type of optimizer ('SGD' or 'Adam'). + learning_rate (float): The learning rate for the optimizer. + **kwargs: Additional parameters for the optimizer. + Returns: + Optimizer: The created optimizer instance. + """ + if optimizer_type.lower() == "sgd": + return optimizer.SGD(learning_rate=learning_rate, **kwargs) + elif optimizer_type.lower() == "adam": + return optimizer.Adam(learning_rate=learning_rate, **kwargs) + else: + raise ValueError(f"Unknown optimizer type: {optimizer_type}") + + +def createLoss(loss_type, reduction="mean"): + """ + Create a loss function based on the specified type and reduction. + Args: + loss_type (str): The type of loss function ('mse', 'cce'). + reduction (str): Reduction type ('mean', 'sum'). + Returns: + object: An instance of the specified loss function. + """ + if loss_type.lower() == "mse": + return losses.MeanSquaredError(reduction=reduction) + elif loss_type.lower() == "cce": + return losses.CategoricalCrossentropy(reduction=reduction) + else: + raise ValueError(f"Unknown loss type: {loss_type}") + + +def train_steps(args): + """ + Main function to train the model. + """ + # Create session and load nnpackage + sess = session(args.nnpkg, args.backends) + + # Load data + input_shape = sess.input_tensorinfo(0).dims + label_shape = sess.output_tensorinfo(0).dims + + input_shape[0] = args.data_length + label_shape[0] = args.data_length + + data_loader = DataLoader(args.input, + args.label, + args.batch_size, + input_shape=input_shape, + expected_shape=label_shape) + print('Load data') + + # optimizer + opt_fn = createOptimizer(args.optimizer, args.learning_rate) + + # loss + loss_fn = createLoss(args.loss, reduction=args.loss_reduction_type) + + sess.compile(optimizer=opt_fn, loss=loss_fn, batch_size=args.batch_size) + + # Train model + mtrs = [metrics.CategoricalAccuracy()] + total_loss = 0.0 + metric_aggregates = {metric.__class__.__name__: 0.0 for metric in mtrs} + train_time = 0.0 + + nums_steps = (args.data_length + args.batch_size - 1) // args.batch_size + for idx, (inputs, expecteds) in enumerate(data_loader): + # Train on a single step + results = sess.train_step(inputs, expecteds) + total_loss += sum(results['loss']) + + # Aggregate metrics + for metric_name, metric_value in results['metrics'].items(): + metric_aggregates[metric_name] += metric_value + + train_time += results['train_time'] + + print( + f"Step {idx + 1}/{nums_steps} - Train time: {results['train_time']:.3f} ms/step - Train Loss: {sum(results['loss']):.4f}" + ) + + # Average metrics + avg_metrics = { + name: value / args.batch_size + for name, value in metric_aggregates.items() + } + + # Print results + print("=" * 35) + print(f"Average Loss: {total_loss / nums_steps:.4f}") + for metric_name, metric_value in avg_metrics.items(): + print(f"{metric_name}: {metric_value:.4f}") + print(f"Average Time: {train_time / nums_steps:.4f} ms/step") + print("=" * 35) + + print(f"nnpackage {args.nnpkg.split('/')[-1]} trains successfully.") + + +if __name__ == "__main__": + args = initParse() + + train_steps(args) diff --git a/runtime/onert/sample/minimal-python/experimental/src/train_with_dataset.py b/runtime/onert/sample/minimal-python/experimental/src/train_with_dataset.py new file mode 
100644 index 00000000000..e16943161a4 --- /dev/null +++ b/runtime/onert/sample/minimal-python/experimental/src/train_with_dataset.py @@ -0,0 +1,127 @@ +# import sys +import argparse +import numpy as np +from onert.experimental.train import session, DataLoader, optimizer, losses, metrics + + +def initParse(): + parser = argparse.ArgumentParser() + parser.add_argument('-m', + '--nnpkg', + required=True, + help='Path to the nnpackage file or directory') + parser.add_argument('-i', + '--input', + required=True, + help='Path to the file containing input data (e.g., .npy or raw)') + parser.add_argument( + '-l', + '--label', + required=True, + help='Path to the file containing label data (e.g., .npy or raw).') + parser.add_argument('--data_length', required=True, type=int, help='data length') + parser.add_argument('--backends', default='train', help='Backends to use') + parser.add_argument('--batch_size', default=16, type=int, help='batch size') + parser.add_argument('--epoch', default=5, type=int, help='epoch number') + parser.add_argument('--learning_rate', default=0.01, type=float, help='learning rate') + parser.add_argument('--loss', default='mse', choices=['mse', 'cce']) + parser.add_argument('--optimizer', default='sgd', choices=['sgd', 'adam']) + parser.add_argument('--loss_reduction_type', default='mean', choices=['mean', 'sum']) + parser.add_argument('--validation_split', + default=0.0, + type=float, + help='validation split rate') + + return parser.parse_args() + + +def createOptimizer(optimizer_type, learning_rate=0.001, **kwargs): + """ + Create an optimizer based on the specified type. + Args: + optimizer_type (str): The type of optimizer ('SGD' or 'Adam'). + learning_rate (float): The learning rate for the optimizer. + **kwargs: Additional parameters for the optimizer. + Returns: + Optimizer: The created optimizer instance. + """ + if optimizer_type.lower() == "sgd": + return optimizer.SGD(learning_rate=learning_rate, **kwargs) + elif optimizer_type.lower() == "adam": + return optimizer.Adam(learning_rate=learning_rate, **kwargs) + else: + raise ValueError(f"Unknown optimizer type: {optimizer_type}") + + +def createLoss(loss_type, reduction="mean"): + """ + Create a loss function based on the specified type and reduction. + Args: + loss_type (str): The type of loss function ('mse', 'cce'). + reduction (str): Reduction type ('mean', 'sum'). + Returns: + object: An instance of the specified loss function. + """ + if loss_type.lower() == "mse": + return losses.MeanSquaredError(reduction=reduction) + elif loss_type.lower() == "cce": + return losses.CategoricalCrossentropy(reduction=reduction) + else: + raise ValueError(f"Unknown loss type: {loss_type}") + + +def train(args): + """ + Main function to train the model. 
+ """ + # Create session and load nnpackage + sess = session(args.nnpkg, args.backends) + + # Load data + input_shape = sess.input_tensorinfo(0).dims + label_shape = sess.output_tensorinfo(0).dims + + input_shape[0] = args.data_length + label_shape[0] = args.data_length + + data_loader = DataLoader(args.input, + args.label, + args.batch_size, + input_shape=input_shape, + expected_shape=label_shape) + print('Load data') + + # optimizer + opt_fn = createOptimizer(args.optimizer, args.learning_rate) + + # loss + loss_fn = createLoss(args.loss, reduction=args.loss_reduction_type) + + sess.compile(optimizer=opt_fn, + loss=loss_fn, + batch_size=args.batch_size, + metrics=[metrics.CategoricalAccuracy()]) + + # Train model + total_time = sess.train(data_loader, + epochs=args.epoch, + validation_split=args.validation_split, + checkpoint_path="checkpoint.ckpt") + + # Print timing summary + print("=" * 35) + print(f"MODEL_LOAD takes {total_time['MODEL_LOAD']:.4f} ms") + print(f"COMPILE takes {total_time['COMPILE']:.4f} ms") + print(f"EXECUTE takes {total_time['EXECUTE']:.4f} ms") + epoch_times = total_time['EPOCH_TIMES'] + for i, epoch_time in enumerate(epoch_times): + print(f"- Epoch {i + 1} takes {epoch_time:.4f} ms") + print("=" * 35) + + print(f"nnpackage {args.nnpkg.split('/')[-1]} trains successfully.") + + +if __name__ == "__main__": + args = initParse() + + train(args) diff --git a/runtime/onert/sample/minimal-python/src/minimal.py b/runtime/onert/sample/minimal-python/infer/src/minimal.py similarity index 87% rename from runtime/onert/sample/minimal-python/src/minimal.py rename to runtime/onert/sample/minimal-python/infer/src/minimal.py index 2ae3f249fcd..987a95fc398 100644 --- a/runtime/onert/sample/minimal-python/src/minimal.py +++ b/runtime/onert/sample/minimal-python/infer/src/minimal.py @@ -1,6 +1,9 @@ from onert import infer import sys +# from onert.native import libnnfw_api_pybind +# print(dir(libnnfw_api_pybind)) + def main(nnpackage_path, backends="cpu"): # Create session and load nnpackage