From 9f5b2448201e467e69484dee48f76b959ed2ecb6 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 28 May 2024 14:59:10 +0200 Subject: [PATCH 01/11] Add a warning when task parameters differ from those of the cache in use --- pyannote/audio/core/task.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 04c73ab51..8aec4b0cb 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -327,6 +327,7 @@ def prepare_data(self): 'metadata-values': dict of lists of values for subset, scope and database 'metadata-`database-name`-labels': array of `database-name` labels. Each database with "database" scope labels has it own array. 'metadata-labels': array of global scope labels + 'task-parameters': hyper-parameters used for the task } """ @@ -595,6 +596,23 @@ def prepare_data(self): prepared_data["metadata-labels"] = np.array(unique_labels, dtype=np.str_) unique_labels.clear() + # keep track of task parameters + parameters = [] + dtype = [] + for param_name, param_value in self.__dict__.items(): + # only keep public parameters with native type + if param_name[0] == "_": + continue + if isinstance(param_value, (bool, float, int, str)): + parameters.append(param_value) + dtype.append((param_name, type(param_value))) + + prepared_data["task-parameters"] = np.array( + tuple(parameters), dtype=np.dtype(dtype) + ) + parameters.clear() + dtype.clear() + if self.has_validation: self.prepare_validation(prepared_data) @@ -646,6 +664,19 @@ def setup(self, stage=None): f"does not correspond to the cached one ({self.prepared_data['protocol']})" ) + # checks that the task current hyperparameters matches the cached ones + for param_name, param_value in self.__dict__.items(): + if param_name not in self.prepared_data["task-parameters"].dtype.names: + continue + cached_value = self.prepared_data["task-parameters"][param_name] + if param_value != cached_value: + warnings.warn( + f"Value specified for {param_name} of the task differs from the one in the cached data." + f"Current one = {param_value}, cached one = {cached_value}." + "You may need to create a new cache for this task with" + " the new value for this hyperparameter.", + ) + @property def specifications(self) -> Union[Specifications, Tuple[Specifications]]: # setup metadata on-demand the first time specifications are requested and missing From e95f3c37b6d17d58d172c79640df6dbe85864e0d Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 4 Jun 2024 15:28:51 +0200 Subject: [PATCH 02/11] use `inspect.signature` instead `__dict__` --- pyannote/audio/core/task.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 8aec4b0cb..c10de3ac8 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -23,6 +23,7 @@ from __future__ import annotations +import inspect import itertools import multiprocessing import sys @@ -599,11 +600,15 @@ def prepare_data(self): # keep track of task parameters parameters = [] dtype = [] - for param_name, param_value in self.__dict__.items(): - # only keep public parameters with native type - if param_name[0] == "_": + for param_name in inspect.signature(self.__init__).parameters: + try: + param_value = getattr(self, param_name) + # skip specification-dependent parameters and non-attributed parameters + # (for instance because they were deprecated) + except (AttributeError, UnknownSpecificationsError): + print(param_name) continue - if isinstance(param_value, (bool, float, int, str)): + if isinstance(param_value, (bool, float, int, str, type(None))): parameters.append(param_value) dtype.append((param_name, type(param_value))) @@ -665,11 +670,19 @@ def setup(self, stage=None): ) # checks that the task current hyperparameters matches the cached ones - for param_name, param_value in self.__dict__.items(): + for param_name in inspect.signature(self.__init__).parameters: + try: + param_value = getattr(self, param_name) + # skip specification-dependent parameters and non-attributed parameters + # (for instance because they were deprecated) + except (AttributeError, UnknownSpecificationsError): + continue + if param_name not in self.prepared_data["task-parameters"].dtype.names: continue cached_value = self.prepared_data["task-parameters"][param_name] if param_value != cached_value: + print("passing here") warnings.warn( f"Value specified for {param_name} of the task differs from the one in the cached data." f"Current one = {param_value}, cached one = {cached_value}." From 4bccb40e5e389e0b0ed837a6539562fe555dc942 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 4 Jun 2024 15:30:35 +0200 Subject: [PATCH 03/11] clear the code --- pyannote/audio/core/task.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index c10de3ac8..4f4670272 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -606,7 +606,6 @@ def prepare_data(self): # skip specification-dependent parameters and non-attributed parameters # (for instance because they were deprecated) except (AttributeError, UnknownSpecificationsError): - print(param_name) continue if isinstance(param_value, (bool, float, int, str, type(None))): parameters.append(param_value) @@ -682,7 +681,6 @@ def setup(self, stage=None): continue cached_value = self.prepared_data["task-parameters"][param_name] if param_value != cached_value: - print("passing here") warnings.warn( f"Value specified for {param_name} of the task differs from the one in the cached data." f"Current one = {param_value}, cached one = {cached_value}." From 63472e4ec25c6a1be20a63eab6a01f3c0a59007b Mon Sep 17 00:00:00 2001 From: clement-pages Date: Thu, 12 Dec 2024 09:40:07 +0100 Subject: [PATCH 04/11] save task hyper-parameters into checkpoint --- pyannote/audio/core/model.py | 25 ++++++++++++++++ pyannote/audio/core/task.py | 29 +++++-------------- pyannote/audio/tasks/embedding/arcface.py | 2 ++ .../audio/tasks/segmentation/multilabel.py | 2 ++ .../overlapped_speech_detection.py | 3 ++ .../tasks/segmentation/speaker_diarization.py | 2 ++ .../segmentation/voice_activity_detection.py | 2 ++ pyannote/audio/tasks/separation/PixIT.py | 2 ++ 8 files changed, 46 insertions(+), 21 deletions(-) diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index f3e1ec26b..51cbf3f96 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -52,6 +52,8 @@ from pyannote.audio.utils.multi_task import map_with_specifications from pyannote.audio.utils.version import check_version +from pyannote.database import get_protocol, FileFinder + CACHE_DIR = os.getenv( "PYANNOTE_CACHE", os.path.expanduser("~/.cache/torch/pyannote"), @@ -262,6 +264,11 @@ def on_save_checkpoint(self, checkpoint): "class": self.__class__.__name__, }, "specifications": self.specifications, + "task": { + "module": self.task.__class__.__module__, + "class": self.task.__class__.__name__, + "hyper_parameters": self.task.hparams, + }, } def on_load_checkpoint(self, checkpoint: Dict[str, Any]): @@ -533,6 +540,7 @@ def from_pretrained( subfolder: Optional[str] = None, use_auth_token: Union[Text, None] = None, # todo: deprecate in favor of token cache_dir: Union[Path, Text] = CACHE_DIR, + database_path: Union[Path, Text, None] = None, **kwargs, ) -> "Model": """Load pretrained model @@ -663,4 +671,21 @@ def default_map_location(storage, loc): raise e + # obtain task class from the checkpoint, if any + if "task" in loaded_checkpoint["pyannote.audio"]: + # move code to core.Task + + task_module_name: str = loaded_checkpoint["pyannote.audio"]["task"]["module"] + task_module = import_module(task_module_name) + task_class_name: str = loaded_checkpoint["pyannote.audio"]["task"]["class"] + task_hparams = loaded_checkpoint["pyannote.audio"]["task"]["hyper_parameters"] + + TaskClass = getattr(task_module, task_class_name) + + protocol = get_protocol( + task_hparams.pop("protocol"), preprocessors={"audio": FileFinder()} + ) + + model.task = TaskClass(protocol, **task_hparams) + return model diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 8401b83be..1156a9bab 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -23,7 +23,6 @@ from __future__ import annotations -import inspect import itertools import multiprocessing import sys @@ -306,6 +305,8 @@ def __init__( self.augmentation = augmentation or Identity(output_type="dict") self._metric = metric + self.hparams["protocol"] = protocol.name + def prepare_data(self): """Use this to prepare data from task protocol @@ -597,16 +598,10 @@ def prepare_data(self): prepared_data["metadata-labels"] = np.array(unique_labels, dtype=np.str_) unique_labels.clear() - # keep track of task parameters + # keep track of task hyperparameters parameters = [] dtype = [] - for param_name in inspect.signature(self.__init__).parameters: - try: - param_value = getattr(self, param_name) - # skip specification-dependent parameters and non-attributed parameters - # (for instance because they were deprecated) - except (AttributeError, UnknownSpecificationsError): - continue + for param_name, param_value in self.hparams.items(): if isinstance(param_value, (bool, float, int, str, type(None))): parameters.append(param_value) dtype.append((param_name, type(param_value))) @@ -669,23 +664,15 @@ def setup(self, stage=None): ) # checks that the task current hyperparameters matches the cached ones - for param_name in inspect.signature(self.__init__).parameters: - try: - param_value = getattr(self, param_name) - # skip specification-dependent parameters and non-attributed parameters - # (for instance because they were deprecated) - except (AttributeError, UnknownSpecificationsError): - continue - + for param_name, param_value in self.hparams.items(): if param_name not in self.prepared_data["task-parameters"].dtype.names: continue cached_value = self.prepared_data["task-parameters"][param_name] if param_value != cached_value: warnings.warn( - f"Value specified for {param_name} of the task differs from the one in the cached data." - f"Current one = {param_value}, cached one = {cached_value}." - "You may need to create a new cache for this task with" - " the new value for this hyperparameter.", + f"Value specified for the task hyperparameter {param_name} differs from the one in the cached data." + f"Current value = {param_value}, cached value = {cached_value}." + "You may need to create a new cache with the new value for this hyperparameter.", ) @property diff --git a/pyannote/audio/tasks/embedding/arcface.py b/pyannote/audio/tasks/embedding/arcface.py index cb6401e2b..5bc2f9f06 100644 --- a/pyannote/audio/tasks/embedding/arcface.py +++ b/pyannote/audio/tasks/embedding/arcface.py @@ -111,6 +111,8 @@ def __init__( metric=metric, ) + self.save_hyperparameters(ignore=["augmentation", "metric", "protocol"]) + def setup_loss_func(self): _, embedding_size = self.model(self.model.example_input_array).shape diff --git a/pyannote/audio/tasks/segmentation/multilabel.py b/pyannote/audio/tasks/segmentation/multilabel.py index 9184121c4..7112c2c23 100644 --- a/pyannote/audio/tasks/segmentation/multilabel.py +++ b/pyannote/audio/tasks/segmentation/multilabel.py @@ -125,6 +125,8 @@ def __init__( self.weight = weight self.classes = classes + self.save_hyperparameters(ignore=["augmentation", "metric", "protocol"]) + # task specification depends on the data: we do not know in advance which # classes should be detected. therefore, we postpone the definition of # specifications to setup() diff --git a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py index 89d299a8d..0503fc33a 100644 --- a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py +++ b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py @@ -142,6 +142,9 @@ def __init__( self.balance = balance self.weight = weight + self.save_hyperparameters(ignore=["augmentation", "metric", "protocol"]) + + def prepare_chunk(self, file_id: int, start_time: float, duration: float): """Prepare chunk for overlapped speech detection diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index 6ca6417e9..6fa5e06f4 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -162,6 +162,8 @@ def __init__( self.balance = balance self.weight = weight + self.save_hyperparameters(ignore=["augmentation", "loss", "metric", "protocol"]) + def setup(self, stage=None): super().setup(stage) diff --git a/pyannote/audio/tasks/segmentation/voice_activity_detection.py b/pyannote/audio/tasks/segmentation/voice_activity_detection.py index e52613aeb..57defa88c 100644 --- a/pyannote/audio/tasks/segmentation/voice_activity_detection.py +++ b/pyannote/audio/tasks/segmentation/voice_activity_detection.py @@ -123,6 +123,8 @@ def __init__( ], ) + self.save_hyperparameters(ignore=["augmentation", "metric", "protocol"]) + def prepare_chunk(self, file_id: int, start_time: float, duration: float): """Prepare chunk for voice activity detection diff --git a/pyannote/audio/tasks/separation/PixIT.py b/pyannote/audio/tasks/separation/PixIT.py index 88c3495ee..362e146e9 100644 --- a/pyannote/audio/tasks/separation/PixIT.py +++ b/pyannote/audio/tasks/separation/PixIT.py @@ -221,6 +221,8 @@ def __init__( self.separation_loss_weight = separation_loss_weight self.mixit_loss = MixITLossWrapper(multisrc_neg_sisdr, generalized=True) + self.save_hyperparameters(ignore=["augmentation", "loss", "metric", "protocol"]) + def setup(self, stage=None): super().setup(stage) From 0d0c19a9e1c4afcff3e56852ecf45ec8190e92dd Mon Sep 17 00:00:00 2001 From: clement-pages Date: Thu, 12 Dec 2024 10:39:13 +0100 Subject: [PATCH 05/11] set protocol instead of database path when loading checkpoint --- pyannote/audio/core/model.py | 12 +++++------- pyannote/audio/core/task.py | 2 -- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 51cbf3f96..16a51f6ad 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -52,7 +52,7 @@ from pyannote.audio.utils.multi_task import map_with_specifications from pyannote.audio.utils.version import check_version -from pyannote.database import get_protocol, FileFinder +from pyannote.database import Protocol CACHE_DIR = os.getenv( "PYANNOTE_CACHE", @@ -540,7 +540,7 @@ def from_pretrained( subfolder: Optional[str] = None, use_auth_token: Union[Text, None] = None, # todo: deprecate in favor of token cache_dir: Union[Path, Text] = CACHE_DIR, - database_path: Union[Path, Text, None] = None, + protocol: Union[Protocol, None] = None, **kwargs, ) -> "Model": """Load pretrained model @@ -567,6 +567,8 @@ def from_pretrained( cache_dir: Path or str, optional Path to model cache directory. Defaults to content of PYANNOTE_CACHE environment variable, or "~/.cache/torch/pyannote" when unset. + protocol: Protocol, optional + Protocol used to train the model. Needed to continue training. kwargs: optional Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values. @@ -672,7 +674,7 @@ def default_map_location(storage, loc): raise e # obtain task class from the checkpoint, if any - if "task" in loaded_checkpoint["pyannote.audio"]: + if protocol and "task" in loaded_checkpoint["pyannote.audio"]: # move code to core.Task task_module_name: str = loaded_checkpoint["pyannote.audio"]["task"]["module"] @@ -682,10 +684,6 @@ def default_map_location(storage, loc): TaskClass = getattr(task_module, task_class_name) - protocol = get_protocol( - task_hparams.pop("protocol"), preprocessors={"audio": FileFinder()} - ) - model.task = TaskClass(protocol, **task_hparams) return model diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 1156a9bab..172f23076 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -305,8 +305,6 @@ def __init__( self.augmentation = augmentation or Identity(output_type="dict") self._metric = metric - self.hparams["protocol"] = protocol.name - def prepare_data(self): """Use this to prepare data from task protocol From 45a626bd0854865cc9e615fc692efb9aa4ed61ef Mon Sep 17 00:00:00 2001 From: clement-pages Date: Thu, 12 Dec 2024 11:50:51 +0100 Subject: [PATCH 06/11] add task metric to checkpoint --- pyannote/audio/core/model.py | 22 ++++++++++++++-------- pyannote/audio/core/task.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 16a51f6ad..4250fd807 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -264,13 +264,10 @@ def on_save_checkpoint(self, checkpoint): "class": self.__class__.__name__, }, "specifications": self.specifications, - "task": { - "module": self.task.__class__.__module__, - "class": self.task.__class__.__name__, - "hyper_parameters": self.task.hparams, - }, } + self.task.on_save_checkpoint(checkpoint) + def on_load_checkpoint(self, checkpoint: Dict[str, Any]): check_version( "pyannote.audio", @@ -673,10 +670,8 @@ def default_map_location(storage, loc): raise e - # obtain task class from the checkpoint, if any + # init task from the checkpoint, if any if protocol and "task" in loaded_checkpoint["pyannote.audio"]: - # move code to core.Task - task_module_name: str = loaded_checkpoint["pyannote.audio"]["task"]["module"] task_module = import_module(task_module_name) task_class_name: str = loaded_checkpoint["pyannote.audio"]["task"]["class"] @@ -684,6 +679,17 @@ def default_map_location(storage, loc): TaskClass = getattr(task_module, task_class_name) + metrics = loaded_checkpoint["pyannote.audio"]["task"]["metrics"] + metric = {} + for name, metadata in metrics.items(): + metric_module = import_module(metadata["module"]) + metric_class = metadata["class"] + metric_kwargs = metadata["kwargs"] + + MetricClass = getattr(metric_module, metric_class) + metric[name] = MetricClass(**metric_kwargs) + task_hparams["metric"] = metric + model.task = TaskClass(protocol, **task_hparams) return model diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 172f23076..81c244b9d 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -23,6 +23,7 @@ from __future__ import annotations +import inspect import itertools import multiprocessing import sys @@ -905,3 +906,30 @@ def val_monitor(self): name, metric = next(iter(self.metric.items())) return name, "max" if metric.higher_is_better else "min" + + def on_save_checkpoint(self, checkpoint): + checkpoint["pyannote.audio"]["task"] = { + "module": self.__class__.__module__, + "class": self.__class__.__name__, + "hyper_parameters": self.hparams, + } + + # save metrics: + if isinstance(self.metric, Metric): + metrics = {self.metric.__class__.__name__: self.metric} + elif isinstance(self.metric, Sequence): + metrics = {metric.__class__.__name__: metric for metric in self.metric} + else: + metrics = self.metric + + if metrics: + checkpoint["pyannote.audio"]["task"]["metrics"] = { + name: { + "module": metric.__class__.__module__, + "class": metric.__class__.__name__, + "kwargs": { + param : getattr(metric, param, None) + for param in inspect.signature(metric.__init__).parameters + } + } for name, metric in metrics.items() + } From da3356bd48d2724ab572cf94642d57601debbd9e Mon Sep 17 00:00:00 2001 From: clement-pages Date: Thu, 12 Dec 2024 14:44:35 +0100 Subject: [PATCH 07/11] save task augmentation into checkpoint --- pyannote/audio/core/model.py | 33 +++++++++++++++++++++++++-------- pyannote/audio/core/task.py | 16 ++++++++++++++++ 2 files changed, 41 insertions(+), 8 deletions(-) diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 4250fd807..6dfe1a8bf 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -679,17 +679,34 @@ def default_map_location(storage, loc): TaskClass = getattr(task_module, task_class_name) + # instanciate task augmentation + augmentation = loaded_checkpoint["pyannote.audio"]["task"]["augmentation"] + if augmentation: + augmentation_module = import_module(augmentation["module"]) + augmentation_class = augmentation["class"] + augmentation_kwargs = augmentation["kwargs"] + AugmentationClass = getattr(augmentation_module, augmentation_class) + augmentation = AugmentationClass(**augmentation_kwargs) + + task_hparams["augmentation"] = augmentation + + # instanciate task metrics metrics = loaded_checkpoint["pyannote.audio"]["task"]["metrics"] - metric = {} - for name, metadata in metrics.items(): - metric_module = import_module(metadata["module"]) - metric_class = metadata["class"] - metric_kwargs = metadata["kwargs"] - - MetricClass = getattr(metric_module, metric_class) - metric[name] = MetricClass(**metric_kwargs) + if metrics: + metric = {} + for name, metadata in metrics.items(): + metric_module = import_module(metadata["module"]) + metric_class = metadata["class"] + metric_kwargs = metadata["kwargs"] + + MetricClass = getattr(metric_module, metric_class) + metric[name] = MetricClass(**metric_kwargs) + else: + metric = None + task_hparams["metric"] = metric + # instanciate training task model.task = TaskClass(protocol, **task_hparams) return model diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 81c244b9d..605e3a975 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -914,6 +914,20 @@ def on_save_checkpoint(self, checkpoint): "hyper_parameters": self.hparams, } + # save augmentation: + # TODO: add support for compose augmentation + if not self.augmentation: + checkpoint["pyannote.audio"]["task"]["augmentation"] = None + elif isinstance(self.augmentation, BaseWaveformTransform): + checkpoint["pyannote.audio"]["task"]["augmentation"] = [{ + "module": self.augmentation.__class__.__module__, + "class": self.augmentation.__class__.__name__, + "kwargs": { + param: getattr(self.augmentation, param, None) + for param in inspect.signature(self.augmentation.__init__).parameters + } + }] + # save metrics: if isinstance(self.metric, Metric): metrics = {self.metric.__class__.__name__: self.metric} @@ -933,3 +947,5 @@ def on_save_checkpoint(self, checkpoint): } } for name, metric in metrics.items() } + else: + checkpoint["pyannote.audio"]["task"]["metrics"] = None From 2a59bc7b0dcce819edd97986e8000435f8e4ed23 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Thu, 9 Jan 2025 09:51:22 +0100 Subject: [PATCH 08/11] add support for compose transform --- pyannote/audio/core/model.py | 32 ++++++++++++++++++++++---------- pyannote/audio/core/task.py | 25 ++++++++++++++++--------- 2 files changed, 38 insertions(+), 19 deletions(-) diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 6dfe1a8bf..fbc45c249 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -40,6 +40,7 @@ from pyannote.core import SlidingWindow from pytorch_lightning.utilities.model_summary.model_summary import ModelSummary from torch.utils.data import DataLoader +from torch_audiomentations.core.composition import Compose from pyannote.audio import __version__ from pyannote.audio.core.io import Audio @@ -679,16 +680,27 @@ def default_map_location(storage, loc): TaskClass = getattr(task_module, task_class_name) - # instanciate task augmentation - augmentation = loaded_checkpoint["pyannote.audio"]["task"]["augmentation"] - if augmentation: - augmentation_module = import_module(augmentation["module"]) - augmentation_class = augmentation["class"] - augmentation_kwargs = augmentation["kwargs"] - AugmentationClass = getattr(augmentation_module, augmentation_class) - augmentation = AugmentationClass(**augmentation_kwargs) - - task_hparams["augmentation"] = augmentation + # instantiate task augmentation + def instantiate_transform(transform_data): + transform_module = import_module(transform_data["module"]) + transform_class = transform_data["class"] + transform_kwargs = transform_data["kwargs"] + TransformClass = getattr(transform_module, transform_class) + return TransformClass(**transform_kwargs) + + augmentation_data = loaded_checkpoint["pyannote.audio"]["task"]["augmentation"] + # BaseWaveformTransform case + if isinstance(augmentation_data, Dict): + task_hparams["augmentation"] = instantiate_transform(augmentation_data) + + # Compose transform case + elif isinstance(augmentation_data , List): + transforms = [] + for transform_data in augmentation_data: + transform = instantiate_transform(transform_data) + transforms.append(transform) + + task_hparams["augmentation"] = Compose(transforms=transforms, output_type="dict") # instanciate task metrics metrics = loaded_checkpoint["pyannote.audio"]["task"]["metrics"] diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 605e3a975..89706b0b3 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -46,6 +46,7 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset from torch_audiomentations import Identity from torch_audiomentations.core.transforms_interface import BaseWaveformTransform +from torch_audiomentations.core.composition import BaseCompose from torchmetrics import Metric, MetricCollection from pyannote.audio.utils.loss import binary_cross_entropy, nll_loss @@ -915,18 +916,24 @@ def on_save_checkpoint(self, checkpoint): } # save augmentation: - # TODO: add support for compose augmentation + def serialize_augmentation(augmentation) -> Dict: + return { + "module": augmentation.__class__.__module__, + "class": augmentation.__class__.__name__, + "kwargs": { + param: getattr(augmentation, param, None) + for param in inspect.signature(augmentation.__init__).parameters + } + } + if not self.augmentation: checkpoint["pyannote.audio"]["task"]["augmentation"] = None elif isinstance(self.augmentation, BaseWaveformTransform): - checkpoint["pyannote.audio"]["task"]["augmentation"] = [{ - "module": self.augmentation.__class__.__module__, - "class": self.augmentation.__class__.__name__, - "kwargs": { - param: getattr(self.augmentation, param, None) - for param in inspect.signature(self.augmentation.__init__).parameters - } - }] + checkpoint["pyannote.audio"]["task"]["augmentation"] = serialize_augmentation(self.augmentation) + elif isinstance(self.augmentation, BaseCompose): + checkpoint["pyannote.audio"]["task"]["augmentation"] = [] + for augmentation in self.augmentation.transforms: + checkpoint["pyannote.audio"]["task"]["augmentation"].append(serialize_augmentation(augmentation)) # save metrics: if isinstance(self.metric, Metric): From 092dd22d858e8848357906f5a7f1df9bef3e2376 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Thu, 23 Jan 2025 11:55:08 +0100 Subject: [PATCH 09/11] add a warning message when param cannot be serialized --- pyannote/audio/core/model.py | 4 +-- pyannote/audio/core/task.py | 60 +++++++++++++++++++----------------- 2 files changed, 34 insertions(+), 30 deletions(-) diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index fbc45c249..c6f6b43fc 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -706,13 +706,13 @@ def instantiate_transform(transform_data): metrics = loaded_checkpoint["pyannote.audio"]["task"]["metrics"] if metrics: metric = {} - for name, metadata in metrics.items(): + for metadata in metrics: metric_module = import_module(metadata["module"]) metric_class = metadata["class"] metric_kwargs = metadata["kwargs"] MetricClass = getattr(metric_module, metric_class) - metric[name] = MetricClass(**metric_kwargs) + metric[metric_class] = MetricClass(**metric_kwargs) else: metric = None diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 89706b0b3..eeaff89ca 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -25,6 +25,7 @@ import inspect import itertools +import json import multiprocessing import sys import warnings @@ -915,44 +916,47 @@ def on_save_checkpoint(self, checkpoint): "hyper_parameters": self.hparams, } - # save augmentation: - def serialize_augmentation(augmentation) -> Dict: - return { - "module": augmentation.__class__.__module__, - "class": augmentation.__class__.__name__, - "kwargs": { - param: getattr(augmentation, param, None) - for param in inspect.signature(augmentation.__init__).parameters - } + def serialize_object(obj: Any) -> Dict: + serialized_obj = { + "module": obj.__class__.__module__, + "class": obj.__class__.__name__, + "kwargs": {}, } + for param in inspect.signature(obj.__init__).parameters: + param_value = getattr(obj, param, None) + if isinstance(param_value, (bool, float, int, list, dict, str, type(None))): + serialized_obj["kwargs"][param] = param_value + else: + msg = f"Cannot serialize {obj.__class__.__name__}.{param}. This parameter will not be saved in model checkpoint." + warnings.warn(msg, RuntimeWarning) + + return serialized_obj + + # save augmentation: if not self.augmentation: checkpoint["pyannote.audio"]["task"]["augmentation"] = None elif isinstance(self.augmentation, BaseWaveformTransform): - checkpoint["pyannote.audio"]["task"]["augmentation"] = serialize_augmentation(self.augmentation) + checkpoint["pyannote.audio"]["task"]["augmentation"] = serialize_object( + self.augmentation + ) elif isinstance(self.augmentation, BaseCompose): checkpoint["pyannote.audio"]["task"]["augmentation"] = [] for augmentation in self.augmentation.transforms: - checkpoint["pyannote.audio"]["task"]["augmentation"].append(serialize_augmentation(augmentation)) + checkpoint["pyannote.audio"]["task"]["augmentation"].append( + serialize_object(augmentation) + ) # save metrics: if isinstance(self.metric, Metric): - metrics = {self.metric.__class__.__name__: self.metric} - elif isinstance(self.metric, Sequence): - metrics = {metric.__class__.__name__: metric for metric in self.metric} - else: - metrics = self.metric - - if metrics: - checkpoint["pyannote.audio"]["task"]["metrics"] = { - name: { - "module": metric.__class__.__module__, - "class": metric.__class__.__name__, - "kwargs": { - param : getattr(metric, param, None) - for param in inspect.signature(metric.__init__).parameters - } - } for name, metric in metrics.items() - } + checkpoint["pyannote.audio"]["task"]["metrics"] = [ + json.dumps(self.metric, default=serialize_object) + ] + elif isinstance(self.metric, MetricCollection): + checkpoint["pyannote.audio"]["task"]["metrics"] = [] + for metric in self.metric.values(): + checkpoint["pyannote.audio"]["task"]["metrics"].append( + serialize_object(metric) + ) else: checkpoint["pyannote.audio"]["task"]["metrics"] = None From 1791fbacd01e5424af10fbfd62bdab7e79b552ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Pag=C3=A9s?= <55240756+clement-pages@users.noreply.github.com> Date: Thu, 23 Jan 2025 13:44:54 +0100 Subject: [PATCH 10/11] Update model.py --- pyannote/audio/core/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 10b181a19..01238772e 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -527,7 +527,7 @@ def from_pretrained( strict: bool = True, subfolder: Optional[str] = None, token: Union[str, bool, None] = None, - cache_dir: Union[Path, Text] = CACHE_DIR, + cache_dir: Union[Path, str, None] = None protocol: Union[Protocol, None] = None, **kwargs, ) -> Optional["Model"]: From b265eb20951c0ad4d75492a6d6482e94e2d63fa1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Pag=C3=A9s?= <55240756+clement-pages@users.noreply.github.com> Date: Thu, 23 Jan 2025 13:45:39 +0100 Subject: [PATCH 11/11] Update model.py --- pyannote/audio/core/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 01238772e..74fe9ae8b 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -527,7 +527,7 @@ def from_pretrained( strict: bool = True, subfolder: Optional[str] = None, token: Union[str, bool, None] = None, - cache_dir: Union[Path, str, None] = None + cache_dir: Union[Path, str, None] = None, protocol: Union[Protocol, None] = None, **kwargs, ) -> Optional["Model"]: