Skip to content

Commit

Permalink
Refactor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jurihock committed Mar 29, 2024
1 parent 5565784 commit 4475c4f
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 64 deletions.
120 changes: 56 additions & 64 deletions tests/test_remucs.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,35 @@
import pathlib
from dataclasses import dataclass
from pathlib import Path
from typing import List
from test_utils import find, freqs, isless, issame, time, wave

import numpy
import pytest
import soundfile

import remucs

numpy.set_printoptions(suppress=True)

def time(d, sr):

return numpy.arange(0, d, 1/sr)

def wave(f, t):

return numpy.sin(t[..., None] * f * numpy.pi * 2)

def freqs(x, sr):

n = min(4096, len(x))

w = numpy.hanning(n)
y = numpy.fft.rfft(w[..., None] * x[-n:], axis=0)

db = 20 * numpy.log10(2 * numpy.abs(y) / numpy.sum(w))
hz = numpy.fft.rfftfreq(n, 1/sr)

db = numpy.ceil(db).astype(int)

return db, hz

def find(hz, f):
DEBUG = False

return numpy.abs(hz[..., None] - f).argmin(axis=0)

def issame(x, y, tol=1):
numpy.set_printoptions(suppress=True)

return abs(x - y) <= tol

def isless(x, y):
@dataclass
class Session:
data: Path
src: Path
dst: Path
sr: int
f: List[int]

return x <= y

@pytest.fixture(scope='session')
def session(tmpdir_factory):
@pytest.fixture(name='session', scope='session')
def create_test_session(tmpdir_factory) -> Session:

if False:
data = pathlib.Path(__file__).resolve().parent.parent
if DEBUG:
data = Path(__file__).resolve().parent.parent
else:
data = pathlib.Path(tmpdir_factory.mktemp('remucs'))
data = Path(tmpdir_factory.mktemp('remucs'))

src = data / 'test.wav'
dst = data / 'test.remucs.wav'
Expand All @@ -59,37 +40,39 @@ def session(tmpdir_factory):
hz = numpy.fft.rfftfreq(4096, 1/sr)
f = hz[find(hz, f)]

f = numpy.floor(f).astype(int)
f = numpy.floor(f).astype(int).tolist()

return Session(data=data, src=src, dst=dst, sr=sr, f=f)

return dict(sr=sr, f=f, src=src, dst=dst, data=data)

def probe(session, **kwargs):
def probe(session: Session, **kwargs):

data = session['data']
data = session.data

src = session['src']
dst = session['dst']
src = session.src
dst = session.dst

sr = session['sr']
f = session['f']
sr = session.sr
f = session.f

remucs.remucs(src, data, remucs.RemucsOptions(**kwargs))
y = numpy.array(soundfile.read(dst)[0])

db, hz = freqs(y, sr)
i = find(hz, f)
idx = find(hz, f)

return db[i]
return db[idx]

def test_setup(session):

data = session['data']
def test_setup(session: Session):

src = session['src']
dst = session['dst']
data = session.data

sr = session['sr']
f = session['f']
src = session.src
dst = session.dst

sr = session.sr
f = session.f

t = time(1, sr)

Expand All @@ -111,19 +94,22 @@ def test_setup(session):
assert y.shape[0] == len(t)
assert y.shape[1] == 2

def test_debug(session):

data = session['data']
def test_debug(session: Session):

data = session.data

src = session['src']
src = session.src
dst = data / '.remucs' / src.stem / 'htdemucs' / 'other.wav'

sr = session['sr']
f = session['f']
sr = session.sr
f = session.f

x = numpy.array(soundfile.read(src)[0])
y = numpy.array(soundfile.read(dst)[0])

print('f debug', f)

db, hz = freqs(x, sr)
assert len(db) == len(hz)

Expand All @@ -138,7 +124,8 @@ def test_debug(session):
j = numpy.floor(hz[i]).astype(int)
print('y debug', db[i[0]], db[i[1]], '@', j)

def test_stereo(session):

def test_stereo(session: Session):

db = probe(session, mono=False)
print('y stereo', db[0], db[1])
Expand All @@ -148,7 +135,8 @@ def test_stereo(session):
assert isless(db[1, 0], -40)
assert issame(db[1, 1], 0)

def test_mono(session):

def test_mono(session: Session):

db = probe(session, mono=True)
print('y mono', db[0], db[1])
Expand All @@ -158,7 +146,8 @@ def test_mono(session):
assert issame(db[1, 0], -6)
assert issame(db[1, 1], -6)

def test_gain(session):

def test_gain(session: Session):

db = probe(session, norm=False, gain=[1, 1, 0.5, 1])
print('y gain', db[0], db[1])
Expand All @@ -168,7 +157,8 @@ def test_gain(session):
assert isless(db[1, 0], -40)
assert issame(db[1, 1], -6)

def test_norm(session):

def test_norm(session: Session):

db = probe(session, norm=True, gain=[1, 1, 0.5, 1])
print('y norm', db[0], db[1])
Expand All @@ -178,7 +168,8 @@ def test_norm(session):
assert isless(db[1, 0], -40)
assert issame(db[1, 1], 0)

def test_balance_left(session):

def test_balance_left(session: Session):

db = probe(session, bala=[0, 0, -1, 0])
print('y balance left', db[0], db[1])
Expand All @@ -188,7 +179,8 @@ def test_balance_left(session):
assert isless(db[1, 0], -40)
assert isless(db[1, 1], -40)

def test_balance_right(session):

def test_balance_right(session: Session):

db = probe(session, bala=[0, 0, +1, 0])
print('y balance right', db[0], db[1])
Expand Down
44 changes: 44 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Tuple
from numpy.typing import ArrayLike, NDArray

import numpy


def issame(x: float, y: float, tol: float = 1) -> bool:

return abs(x - y) <= tol


def isless(x: float, y: float) -> bool:

return x <= y


def find(x: NDArray, y: ArrayLike) -> NDArray:

return numpy.abs(x[..., None] - y).argmin(axis=0)


def freqs(x: NDArray, sr: int) -> Tuple[NDArray, NDArray]:

n = min(4096, len(x))

w = numpy.hanning(n)
y = numpy.fft.rfft(w[..., None] * x[-n:], axis=0)

db = 20 * numpy.log10(2 * numpy.abs(y) / numpy.sum(w))
hz = numpy.fft.rfftfreq(n, 1/sr)

db = numpy.ceil(db).astype(int)

return db, hz


def time(d: int, sr: int) -> NDArray:

return numpy.arange(0, d, 1/sr)


def wave(f: ArrayLike, t: NDArray) -> NDArray:

return numpy.sin(t[..., None] * f * numpy.pi * 2)

0 comments on commit 4475c4f

Please sign in to comment.