feat: save task parameters into cache and model checkpoint #1719

Open · wants to merge 18 commits into base: develop
56 changes: 56 additions & 0 deletions pyannote/audio/core/model.py
@@ -38,6 +38,7 @@
from pyannote.core import SlidingWindow
from pytorch_lightning.utilities.model_summary.model_summary import ModelSummary
from torch.utils.data import DataLoader
from torch_audiomentations.core.composition import Compose

from pyannote.audio import __version__
from pyannote.audio.core.io import Audio
@@ -256,6 +257,8 @@ def on_save_checkpoint(self, checkpoint):
"specifications": self.specifications,
}

self.task.on_save_checkpoint(checkpoint)

def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
check_version(
"pyannote.audio",
@@ -525,6 +528,7 @@ def from_pretrained(
subfolder: Optional[str] = None,
token: Union[str, bool, None] = None,
cache_dir: Union[Path, str, None] = None,
protocol: Union[Protocol, None] = None,
**kwargs,
) -> Optional["Model"]:
"""Load pretrained model
@@ -548,6 +552,8 @@
Token to be used for the download.
cache_dir: Path or str, optional
Path to the folder where cached files are stored.
protocol: Protocol, optional
Protocol used to train the model. Needed to continue training.
kwargs: optional
Any extra keyword args needed to init the model.
Can also be used to override saved hyperparameter values.
@@ -627,4 +633,54 @@ def default_map_location(storage, loc):

raise e

# init task from the checkpoint, if any
if protocol and "task" in loaded_checkpoint["pyannote.audio"]:
task_module_name: str = loaded_checkpoint["pyannote.audio"]["task"]["module"]
task_module = import_module(task_module_name)
task_class_name: str = loaded_checkpoint["pyannote.audio"]["task"]["class"]
task_hparams = loaded_checkpoint["pyannote.audio"]["task"]["hyper_parameters"]

TaskClass = getattr(task_module, task_class_name)

# instantiate task augmentation
def instantiate_transform(transform_data):
transform_module = import_module(transform_data["module"])
transform_class = transform_data["class"]
transform_kwargs = transform_data["kwargs"]
TransformClass = getattr(transform_module, transform_class)
return TransformClass(**transform_kwargs)

augmentation_data = loaded_checkpoint["pyannote.audio"]["task"]["augmentation"]
# BaseWaveformTransform case
if isinstance(augmentation_data, dict):
task_hparams["augmentation"] = instantiate_transform(augmentation_data)

# Compose transform case
elif isinstance(augmentation_data, list):
transforms = []
for transform_data in augmentation_data:
transform = instantiate_transform(transform_data)
transforms.append(transform)

task_hparams["augmentation"] = Compose(transforms=transforms, output_type="dict")

# instantiate task metrics
metrics = loaded_checkpoint["pyannote.audio"]["task"]["metrics"]
if metrics:
metric = {}
for metadata in metrics:
metric_module = import_module(metadata["module"])
metric_class = metadata["class"]
metric_kwargs = metadata["kwargs"]

MetricClass = getattr(metric_module, metric_class)
metric[metric_class] = MetricClass(**metric_kwargs)
else:
metric = None

task_hparams["metric"] = metric

# instantiate training task
model.task = TaskClass(protocol, **task_hparams)

return model
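
For context, here is a minimal usage sketch (not part of the diff) of the new `protocol` argument to `Model.from_pretrained`: passing a protocol triggers the code path above, which re-instantiates the task class stored in the checkpoint with its saved hyper-parameters, augmentation and metrics. The database file, protocol name and checkpoint path are placeholders, and the snippet assumes pyannote.database's `registry` helpers.

```python
# Sketch only: database file, protocol name and checkpoint path are placeholders.
from pyannote.database import registry
from pytorch_lightning import Trainer

from pyannote.audio import Model

registry.load_database("database.yml")
protocol = registry.get_protocol("MyDatabase.SpeakerDiarization.MyProtocol")

# with `protocol` set, model.task is rebuilt from the "task" entry of the checkpoint
model = Model.from_pretrained("path/to/checkpoint.ckpt", protocol=protocol)

# training can then be resumed as usual
Trainer(max_epochs=1).fit(model)
```
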
82 changes: 82 additions & 0 deletions pyannote/audio/core/task.py
@@ -23,7 +23,9 @@

from __future__ import annotations

import inspect
import itertools
import json
import multiprocessing
import sys
import warnings
@@ -45,6 +47,7 @@
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torch_audiomentations import Identity
from torch_audiomentations.core.transforms_interface import BaseWaveformTransform
from torch_audiomentations.core.composition import BaseCompose
from torchmetrics import Metric, MetricCollection

from pyannote.audio.utils.loss import binary_cross_entropy, nll_loss
@@ -327,6 +330,7 @@ def prepare_data(self):
'metadata-values': dict of lists of values for subset, scope and database
'metadata-`database-name`-labels': array of `database-name` labels. Each database with "database" scope labels has its own array.
'metadata-labels': array of global scope labels
'task-parameters': hyper-parameters used for the task
}

"""
@@ -595,6 +599,20 @@ def prepare_data(self):
prepared_data["metadata-labels"] = np.array(unique_labels, dtype=np.str_)
unique_labels.clear()

# keep track of task hyperparameters
parameters = []
dtype = []
for param_name, param_value in self.hparams.items():
if isinstance(param_value, (bool, float, int, str, type(None))):
parameters.append(param_value)
dtype.append((param_name, type(param_value)))

prepared_data["task-parameters"] = np.array(
tuple(parameters), dtype=np.dtype(dtype)
)
parameters.clear()
dtype.clear()

if self.has_validation:
self.prepare_validation(prepared_data)

@@ -646,6 +664,18 @@ def setup(self, stage=None):
f"does not correspond to the cached one ({self.prepared_data['protocol']})"
)

# check that the current task hyperparameters match the cached ones
for param_name, param_value in self.hparams.items():
if param_name not in self.prepared_data["task-parameters"].dtype.names:
continue
cached_value = self.prepared_data["task-parameters"][param_name]
if param_value != cached_value:
warnings.warn(
f"Value specified for the task hyperparameter {param_name} differs from the one in the cached data."
f"Current value = {param_value}, cached value = {cached_value}."
"You may need to create a new cache with the new value for this hyperparameter.",
)

@property
def automatic_optimization(self) -> bool:
return self.model.automatic_optimization
@@ -878,3 +908,55 @@ def val_monitor(self):

name, metric = next(iter(self.metric.items()))
return name, "max" if metric.higher_is_better else "min"

def on_save_checkpoint(self, checkpoint):
checkpoint["pyannote.audio"]["task"] = {
"module": self.__class__.__module__,
"class": self.__class__.__name__,
"hyper_parameters": self.hparams,
}

def serialize_object(obj: Any) -> Dict:
serialized_obj = {
"module": obj.__class__.__module__,
"class": obj.__class__.__name__,
"kwargs": {},
}

for param in inspect.signature(obj.__init__).parameters:
param_value = getattr(obj, param, None)
if isinstance(param_value, (bool, float, int, list, dict, str, type(None))):
serialized_obj["kwargs"][param] = param_value
else:
msg = f"Cannot serialize {obj.__class__.__name__}.{param}. This parameter will not be saved in model checkpoint."
warnings.warn(msg, RuntimeWarning)

return serialized_obj

# save augmentation:
if not self.augmentation:
checkpoint["pyannote.audio"]["task"]["augmentation"] = None
elif isinstance(self.augmentation, BaseWaveformTransform):
checkpoint["pyannote.audio"]["task"]["augmentation"] = serialize_object(
self.augmentation
)
elif isinstance(self.augmentation, BaseCompose):
checkpoint["pyannote.audio"]["task"]["augmentation"] = []
for augmentation in self.augmentation.transforms:
checkpoint["pyannote.audio"]["task"]["augmentation"].append(
serialize_object(augmentation)
)

# save metrics:
if isinstance(self.metric, Metric):
checkpoint["pyannote.audio"]["task"]["metrics"] = [
serialize_object(self.metric)
]
elif isinstance(self.metric, MetricCollection):
checkpoint["pyannote.audio"]["task"]["metrics"] = []
for metric in self.metric.values():
checkpoint["pyannote.audio"]["task"]["metrics"].append(
serialize_object(metric)
)
else:
checkpoint["pyannote.audio"]["task"]["metrics"] = None
2 changes: 2 additions & 0 deletions pyannote/audio/tasks/embedding/arcface.py
@@ -111,6 +111,8 @@ def __init__(
metric=metric,
)

self.save_hyperparameters(ignore=["augmentation", "metric", "protocol"])

def setup_loss_func(self):

_, embedding_size = self.model(self.model.example_input_array).shape
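
These `save_hyperparameters(ignore=[...])` calls, here and in the remaining task files below, are what feed the machinery above: Lightning's hyper-parameter mixin records the other `__init__` arguments in `self.hparams`, which `prepare_data` then caches as 'task-parameters' and `on_save_checkpoint` stores under 'hyper_parameters'. A minimal sketch, reusing the placeholder `protocol` from the earlier sketch and assuming the ArcFace task accepts `duration`:

```python
# Minimal sketch; exact constructor arguments may differ.
from pyannote.audio.tasks import SupervisedRepresentationLearningWithArcFace

task = SupervisedRepresentationLearningWithArcFace(protocol, duration=3.0)

# augmentation/metric/protocol are excluded by ignore=[...]; the remaining
# arguments end up in task.hparams
print(dict(task.hparams))  # e.g. {'duration': 3.0, ...}
```
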
2 changes: 2 additions & 0 deletions pyannote/audio/tasks/segmentation/multilabel.py
@@ -125,6 +125,8 @@ def __init__(
self.weight = weight
self.classes = classes

self.save_hyperparameters(ignore=["augmentation", "metric", "protocol"])

# task specification depends on the data: we do not know in advance which
# classes should be detected. therefore, we postpone the definition of
# specifications to setup()
pyannote/audio/tasks/segmentation/overlapped_speech_detection.py
@@ -142,6 +142,9 @@ def __init__(
self.balance = balance
self.weight = weight

self.save_hyperparameters(ignore=["augmentation", "metric", "protocol"])


def prepare_chunk(self, file_id: int, start_time: float, duration: float):
"""Prepare chunk for overlapped speech detection

2 changes: 2 additions & 0 deletions pyannote/audio/tasks/segmentation/speaker_diarization.py
@@ -165,6 +165,8 @@ def __init__(
self.balance = balance
self.weight = weight

self.save_hyperparameters(ignore=["augmentation", "loss", "metric", "protocol"])

def setup(self, stage=None):
super().setup(stage)

2 changes: 2 additions & 0 deletions pyannote/audio/tasks/segmentation/voice_activity_detection.py
@@ -123,6 +123,8 @@
],
)

self.save_hyperparameters(ignore=["augmentation", "metric", "protocol"])

def prepare_chunk(self, file_id: int, start_time: float, duration: float):
"""Prepare chunk for voice activity detection

2 changes: 2 additions & 0 deletions pyannote/audio/tasks/separation/PixIT.py
@@ -221,6 +221,8 @@ def __init__(
self.separation_loss_weight = separation_loss_weight
self.mixit_loss = MixITLossWrapper(multisrc_neg_sisdr, generalized=True)

self.save_hyperparameters(ignore=["augmentation", "loss", "metric", "protocol"])

def setup(self, stage=None):
super().setup(stage)
