Skip to content

Commit

Permalink
empty
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Jan 11, 2024
1 parent 89df04d commit 77747db
Showing 1 changed file with 28 additions and 32 deletions.
60 changes: 28 additions & 32 deletions examples/tutorials/audio_resampling_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
This tutorial shows how to use torchaudio's resampling API.
"""

# %%
import torch
import torchaudio
import torchaudio.functional as F
Expand All @@ -16,7 +17,7 @@
print(torch.__version__)
print(torchaudio.__version__)

######################################################################
# %%
# Preparation
# -----------
#
Expand All @@ -26,15 +27,10 @@
import math
import timeit

import librosa
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import pandas as pd
import resampy
from IPython.display import Audio

pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)

DEFAULT_OFFSET = 201

Expand Down Expand Up @@ -107,7 +103,7 @@ def plot_sweep(
plt.colorbar(cax)


######################################################################
# %%
# Resampling Overview
# -------------------
#
Expand Down Expand Up @@ -151,7 +147,7 @@ def plot_sweep(
plot_sweep(waveform, sample_rate, title="Original Waveform")
Audio(waveform.numpy()[0], rate=sample_rate)

######################################################################
# %%
#
# Now we resample (downsample) it.
#
Expand All @@ -168,7 +164,7 @@ def plot_sweep(
plot_sweep(resampled_waveform, resample_rate, title="Resampled Waveform")
Audio(resampled_waveform.numpy()[0], rate=resample_rate)

######################################################################
# %%
# Controling resampling quality with parameters
# ---------------------------------------------
#
Expand All @@ -190,13 +186,13 @@ def plot_sweep(
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=6)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=6")

######################################################################
# %%
#

resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=128)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=128")

######################################################################
# %%
# Rolloff
# ~~~~~~~
#
Expand All @@ -216,14 +212,14 @@ def plot_sweep(
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.99)
plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.99")

######################################################################
# %%
#

resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.8)
plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.8")


######################################################################
# %%
# Window function
# ~~~~~~~~~~~~~~~
#
Expand All @@ -242,14 +238,14 @@ def plot_sweep(
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interp_hann")
plot_sweep(resampled_waveform, resample_rate, title="Hann Window Default")

######################################################################
# %%
#

resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interp_kaiser")
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default")


######################################################################
# %%
# Comparison against librosa
# --------------------------
#
Expand All @@ -260,7 +256,7 @@ def plot_sweep(
sample_rate = 48000
resample_rate = 32000

######################################################################
# %%
# kaiser_best
# ~~~~~~~~~~~
#
Expand All @@ -275,21 +271,21 @@ def plot_sweep(
)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Best (torchaudio)")

######################################################################
# %%
#

librosa_resampled_waveform = torch.from_numpy(
librosa.resample(waveform.squeeze().numpy(), orig_sr=sample_rate, target_sr=resample_rate, res_type="kaiser_best")
).unsqueeze(0)
plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Best (librosa)")

######################################################################
# %%
#

mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser best MSE:", mse)

######################################################################
# %%
# kaiser_fast
# ~~~~~~~~~~~
#
Expand All @@ -304,21 +300,21 @@ def plot_sweep(
)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)")

######################################################################
# %%
#

librosa_resampled_waveform = torch.from_numpy(
librosa.resample(waveform.squeeze().numpy(), orig_sr=sample_rate, target_sr=resample_rate, res_type="kaiser_fast")
).unsqueeze(0)
plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Fast (librosa)")

######################################################################
# %%
#

mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser fast MSE:", mse)

######################################################################
# %%
# Performance Benchmarking
# ------------------------
#
Expand All @@ -334,7 +330,7 @@ def plot_sweep(
print(f"librosa: {librosa.__version__}")
print(f"resampy: {resampy.__version__}")

######################################################################
# %%
#


Expand Down Expand Up @@ -370,7 +366,7 @@ def benchmark_resample_functional(
)


######################################################################
# %%
#


Expand Down Expand Up @@ -409,7 +405,7 @@ def benchmark_resample_transforms(
)


######################################################################
# %%
#


Expand Down Expand Up @@ -440,7 +436,7 @@ def benchmark_resample_librosa(
)


######################################################################
# %%
#


Expand Down Expand Up @@ -492,7 +488,7 @@ def benchmark(sample_rate, resample_rate):
return df


######################################################################
# %%
#
def plot(df):
print(df.round(2))
Expand All @@ -504,39 +500,39 @@ def plot(df):
ax.bar_label(cont, labels=label, color=color, fontweight="bold", fontsize="x-small")


######################################################################
# %%
#
# Downsample (48 -> 44.1 kHz)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~

df = benchmark(48_000, 44_100)
plot(df)

######################################################################
# %%
#
# Downsample (16 -> 8 kHz)
# ~~~~~~~~~~~~~~~~~~~~~~~~

df = benchmark(16_000, 8_000)
plot(df)

######################################################################
# %%
#
# Upsample (44.1 -> 48 kHz)
# ~~~~~~~~~~~~~~~~~~~~~~~~~

df = benchmark(44_100, 48_000)
plot(df)

######################################################################
# %%
#
# Upsample (8 -> 16 kHz)
# ~~~~~~~~~~~~~~~~~~~~~~

df = benchmark(8_000, 16_000)
plot(df)

######################################################################
# %%
#
# Summary
# ~~~~~~~
Expand Down

0 comments on commit 77747db

Please sign in to comment.