Skip to content

Commit

Permalink
fix nan and saved twice
Browse files Browse the repository at this point in the history
  • Loading branch information
bagustris committed May 30, 2024
2 parents b2e4b77 + d2c40d8 commit 00d1324
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 34 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Changelog
=========

Version 0.86.2
--------------
* plots epoch progression for finetuned models now

Version 0.86.1
--------------
* functionality to push to hub
Expand Down
2 changes: 1 addition & 1 deletion nkululeko/constants.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
VERSION="0.86.1"
VERSION="0.86.2"
SAMPLING_RATE = 16000
4 changes: 2 additions & 2 deletions nkululeko/models/model_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@


class MLP_model(Model):
"""MLP = multi layer perceptron"""
"""MLP = multi layer perceptron."""

is_classifier = True

def __init__(self, df_train, df_test, feats_train, feats_test):
"""Constructor taking the configuration and all dataframes"""
"""Constructor taking the configuration and all dataframes."""
super().__init__(df_train, df_test, feats_train, feats_test)
super().set_model_type("ann")
self.name = "mlp"
Expand Down
36 changes: 33 additions & 3 deletions nkululeko/models/model_tuned.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Code based on @jwagner."""

import ast
import dataclasses
import json
import os
Expand Down Expand Up @@ -242,8 +243,8 @@ def compute_metrics(self, p: transformers.EvalPrediction):
def train(self):
"""Train the model."""
model_root = self.util.get_path("model_dir")
log_root = os.path.join(self.util.get_exp_dir(), "log")
audeer.mkdir(log_root)
self.log_root = os.path.join(self.util.get_exp_dir(), "log")
audeer.mkdir(self.log_root)
self.torch_root = audeer.path(model_root, "torch")
conf_file = os.path.join(self.torch_root, "config.json")
if os.path.isfile(conf_file):
Expand Down Expand Up @@ -351,8 +352,15 @@ def compute_loss(
tokenizer=self.processor.feature_extractor,
callbacks=[transformers.integrations.TensorBoardCallback()],
)

trainer.train()
# trainer.save_model(self.torch_root) # already saved above
# trainer.save_model(self.torch_root)
log_file = os.path.join(
self.log_root,
"log.txt",
)
with open(log_file, "w") as text_file:
print(trainer.state.log_history, file=text_file)
self.util.debug(f"saved best model to {self.torch_root}")
self.load(self.run, self.epoch)

Expand Down Expand Up @@ -383,8 +391,30 @@ def predict(self):
self.run,
self.epoch_num,
)
self._plot_epoch_progression(report)
return report

def _plot_epoch_progression(self, report):
    """Plot evaluation metric and loss per epoch from the training log.

    Reads ``log.txt`` in ``self.log_root`` (the printed trainer log history
    that ``train`` writes), keeps only the entries that carry an epoch, the
    evaluation measure and the evaluation loss, and passes them as a data
    frame (columns ``results`` and ``losses``, indexed by epoch) to
    ``report.plot_epoch_progression_finetuned``.

    Nothing is plotted (a debug message is logged instead) when no log
    file is available or when it holds no evaluation entries.

    Args:
        report: object providing ``plot_epoch_progression_finetuned(df)``.
    """
    import re  # local import: only needed to sanitise the log text

    # log_root is only assigned in train(); it may be absent otherwise
    log_root = getattr(self, "log_root", None)
    log_file = None if log_root is None else os.path.join(log_root, "log.txt")
    if log_file is None or not os.path.isfile(log_file):
        self.util.debug("no training log found, skipping epoch progression plot")
        return

    with open(log_file, "r") as file:
        data = file.read()

    # The log is a printed Python list of dicts; non-finite floats are
    # printed as bare ``nan`` / ``inf`` which are not valid literals.
    # ``nan`` becomes None (mapped back to NaN below), ``inf`` becomes an
    # overflowing float literal that evaluates to infinity.
    # NOTE(review): assumes these words never occur inside string values.
    data = re.sub(r"\bnan\b", "None", data)
    data = re.sub(r"\binf\b", "1e999", data)
    log_history = ast.literal_eval(data)

    measure_key = f"eval_{self.measure.upper()}"
    required_keys = ("epoch", measure_key, "eval_loss")
    nan = float("nan")

    epochs, vals, loss = [], [], []
    for entry in log_history:
        # Only complete evaluation entries are used, so the three lists
        # always stay aligned.
        if not all(key in entry for key in required_keys):
            continue
        epochs.append(entry["epoch"])
        result = entry[measure_key]
        eval_loss = entry["eval_loss"]
        vals.append(nan if result is None else result)
        loss.append(nan if eval_loss is None else eval_loss)

    if not epochs:
        self.util.debug(
            f"no evaluation entries in {log_file}, skipping epoch progression plot"
        )
        return

    df = pd.DataFrame({"results": vals, "losses": loss}, index=epochs)
    report.plot_epoch_progression_finetuned(df)

def predict_sample(self, signal):
"""Predict one sample"""
prediction = {}
Expand Down
45 changes: 19 additions & 26 deletions nkululeko/plots.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
# plots.py
import pandas as pd
import ast

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import seaborn as sns
import numpy as np
import ast
import pandas as pd
from scipy import stats
from nkululeko.utils.util import Util
import nkululeko.utils.stats as su
import seaborn as sns
from sklearn.manifold import TSNE

import nkululeko.glob_conf as glob_conf
from nkululeko.reporting.report_item import ReportItem
from nkululeko.reporting.defines import Header
from nkululeko.reporting.report_item import ReportItem
import nkululeko.utils.stats as su
from nkululeko.utils.util import Util


class Plots:
def __init__(self):
"""Initializing the util system"""
"""Initializing the util system."""
self.util = Util("plots")
self.format = self.util.config_val("PLOT", "format", "png")
self.target = self.util.config_val("DATA", "target", "emotion")
Expand Down Expand Up @@ -138,8 +140,7 @@ def plot_distributions(self, df, type_s="samples"):
df, att1, class_label, att1, type_s
)
else:
ax, caption = self._plot2cont(
df, class_label, att1, type_s)
ax, caption = self._plot2cont(df, class_label, att1, type_s)
self._save_plot(
ax,
caption,
Expand All @@ -152,8 +153,7 @@ def plot_distributions(self, df, type_s="samples"):
att1 = att[0]
att2 = att[1]
if att1 == self.target or att2 == self.target:
self.util.debug(
f"no need to correlate {self.target} with itself")
self.util.debug(f"no need to correlate {self.target} with itself")
return
if att1 not in df:
self.util.error(f"unknown feature: {att1}")
Expand All @@ -168,8 +168,7 @@ def plot_distributions(self, df, type_s="samples"):
if self.util.is_categorical(df[att1]):
if self.util.is_categorical(df[att2]):
# class_label = cat, att1 = cat, att2 = cat
ax, caption = self._plot2cat(
df, att1, att2, att1, type_s)
ax, caption = self._plot2cat(df, att1, att2, att1, type_s)
else:
# class_label = cat, att1 = cat, att2 = cont
ax, caption = self._plotcatcont(
Expand All @@ -190,8 +189,7 @@ def plot_distributions(self, df, type_s="samples"):
if self.util.is_categorical(df[att1]):
if self.util.is_categorical(df[att2]):
# class_label = cont, att1 = cat, att2 = cat
ax, caption = self._plot2cat(
df, att1, att2, att1, type_s)
ax, caption = self._plot2cat(df, att1, att2, att1, type_s)
else:
# class_label = cont, att1 = cat, att2 = cont
ax, caption = self._plot2cont_cat(
Expand All @@ -205,8 +203,7 @@ def plot_distributions(self, df, type_s="samples"):
)
else:
# class_label = cont, att1 = cont, att2 = cont
ax, caption = self._plot2cont(
df, att1, att2, type_s)
ax, caption = self._plot2cont(df, att1, att2, type_s)

self._save_plot(
ax, caption, f"Correlation of {att1} and {att2}", filename, type_s
Expand Down Expand Up @@ -238,8 +235,7 @@ def _save_plot(self, ax, caption, header, filename, type_s):
)

def _check_binning(self, att, df):
bin_reals_att = eval(self.util.config_val(
"EXPL", f"{att}.bin_reals", "False"))
bin_reals_att = eval(self.util.config_val("EXPL", f"{att}.bin_reals", "False"))
if bin_reals_att:
self.util.debug(f"binning continuous variable {att} to categories")
att_new = f"{att}_binned"
Expand Down Expand Up @@ -342,8 +338,7 @@ def plot_durations(self, df, filename, sample_selection, caption=""):

def describe_df(self, name, df, target, filename):
"""Make a stacked barplot of samples and speakers per sex and target values. speaker, gender and target columns must be present"""
fig_dir = self.util.get_path(
"fig_dir") + "../" # one up because of the runs
fig_dir = self.util.get_path("fig_dir") + "../" # one up because of the runs
sampl_num = df.shape[0]
sex_col = "gender"
if target == "gender":
Expand Down Expand Up @@ -392,8 +387,7 @@ def scatter_plot(self, feats, label_df, label, dimred_type):
dim_num = int(self.util.config_val("EXPL", "scatter.dim", 2))
# one up because of the runs
fig_dir = self.util.get_path("fig_dir") + "../"
sample_selection = self.util.config_val(
"EXPL", "sample_selection", "all")
sample_selection = self.util.config_val("EXPL", "sample_selection", "all")
filename = f"{label}_{self.util.get_feattype_name()}_{sample_selection}_{dimred_type}_{str(dim_num)}d"
filename = f"{fig_dir}{filename}.{self.format}"
self.util.debug(f"computing {dimred_type}, this might take a while...")
Expand Down Expand Up @@ -435,8 +429,7 @@ def scatter_plot(self, feats, label_df, label, dimred_type):

if dim_num == 2:
plot_data = np.vstack((data.T, labels)).T
plot_df = pd.DataFrame(
data=plot_data, columns=("Dim_1", "Dim_2", "label"))
plot_df = pd.DataFrame(data=plot_data, columns=("Dim_1", "Dim_2", "label"))
# plt.tight_layout()
ax = (
sns.FacetGrid(plot_df, hue="label", height=6)
Expand Down
17 changes: 17 additions & 0 deletions nkululeko/reporting/reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,23 @@ def make_conf_animation(self, out_name):
def get_result(self):
return self.result

def plot_epoch_progression_finetuned(self, df):
    """Plot the per-epoch values of a finetuned model and save the figure.

    Every column of ``df`` is drawn over the data frame's index. The image
    is written to ``<fig_dir><name>_epoch_progression.<format>``, where
    ``name`` is the ``PLOT``/``name`` configuration value (the experiment
    name is passed as its fallback) and ``format`` is ``self.format``.

    Args:
        df: data frame to plot; its index is labelled "epochs" and the
            y axis is labelled with ``self.MEASURE``.
    """
    fig_dir = self.util.get_path("fig_dir")
    default_name = self.util.get_exp_name()
    plot_name = (
        self.util.config_val("PLOT", "name", default_name) + "_epoch_progression"
    )
    plot_path = f"{fig_dir}{plot_name}.{self.format}"

    ax = df.plot()
    fig = ax.figure
    try:
        # Address the axes/figure created above explicitly instead of
        # relying on pyplot's notion of the "current" figure.
        ax.set_xlabel("epochs")
        ax.set_ylabel(f"{self.MEASURE}")
        fig.savefig(plot_path)
        self.util.debug(f"plotted epoch progression to {plot_path}")
    finally:
        # Release the figure even when labelling or saving fails.
        plt.close(fig)
        fig.clear()

def plot_epoch_progression(self, reports, out_name):
fig_dir = self.util.get_path("fig_dir")
results, losses, train_results, losses_eval = [], [], [], []
Expand Down
4 changes: 2 additions & 2 deletions tests/exp_emodb_finetune.ini
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ emodb = ./data/emodb/emodb
emodb.split_strategy = specified
emodb.test_tables = ['emotion.categories.test.gold_standard']
emodb.train_tables = ['emotion.categories.train.gold_standard']
labels = ['anger', 'happiness']
labels = ['anger', 'sadness']
target = emotion
[FEATS]
type = []
[MODEL]
type = finetune
device = cpu
batch_size = 8
batch_size = 4
# pretrained_model = microsoft/wavlm-base

0 comments on commit 00d1324

Please sign in to comment.