Skip to content

Commit

Permalink
add tests for derivation process
Browse files Browse the repository at this point in the history
  • Loading branch information
LoannPeurey committed Mar 13, 2024
1 parent 727f69b commit 35f625e
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 20 deletions.
70 changes: 50 additions & 20 deletions tests/test_annotations.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from ChildProject.projects import ChildProject
from ChildProject.annotations import AnnotationManager
from ChildProject.tables import IndexTable
from ChildProject.converters import *
from functools import partial
import glob

import pandas as pd
import numpy as np
import datetime
import os
import pytest
import shutil
import subprocess
import sys
from pathlib import Path
import time


Expand All @@ -20,18 +18,19 @@ def standardize_dataframe(df, columns):
return df.sort_index(axis=1).sort_values(list(columns)).reset_index(drop=True)


DATA = os.path.join('tests', 'data')
TRUTH = os.path.join('tests', 'truth')
DATA = Path('tests', 'data')
TRUTH = Path('tests', 'truth')
PATH = Path('output', 'annotations')


@pytest.fixture(scope="function")
def project(request):
if os.path.exists("output/annotations"):
if os.path.exists(PATH):
# shutil.copytree(src="examples/valid_raw_data", dst="output/annotations")
shutil.rmtree("output/annotations")
shutil.copytree(src="examples/valid_raw_data", dst="output/annotations")
shutil.rmtree(PATH)
shutil.copytree(src="examples/valid_raw_data", dst=PATH)

project = ChildProject("output/annotations")
project = ChildProject(PATH)

yield project

Expand Down Expand Up @@ -286,13 +285,6 @@ def test_multiple_imports(project, am, input_file, ow, rimported, rerrors, excep
assert len(errors) == 0 and len(warnings) == 0, "malformed annotation indexes detected"


@pytest.mark.parametrize("input_set",
[(1)]
)
def test_derive(project, am, input_set):
pass


# Function used as a derivation function; it should throw errors if it does not return a dataframe or lacks the required columns.
def dv_func(a, b, x, type):
if type == 'number':
Expand All @@ -303,17 +295,50 @@ def dv_func(a, b, x, type):
return x


@pytest.mark.parametrize("exists,ow",
                         [(False, False),
                          (True, False),
                          (False, True),
                          (True, True),
                          ])
def test_derive(project, am, exists, ow):
    """Derive the ``output`` set from ``vtc_present`` and check the index rows.

    ``exists`` controls whether an ``output`` set is put in place before the
    derivation runs; ``ow`` is forwarded as ``overwrite_existing``.
    """
    input_set = 'vtc_present'
    output_set = 'output'
    function = partial(dv_func, type='normal')
    am.read()

    # copy the input set to act as an existing output_set
    if exists:
        shutil.copytree(src=PATH / 'annotations' / input_set, dst=PATH / 'annotations' / output_set)
        additions = am.annotations[am.annotations['set'] == input_set].copy()
        additions['set'] = output_set
        am.annotations = pd.concat([am.annotations, additions])

    imported, errors = am.derive_annotations(input_set, output_set, function, overwrite_existing=ow)
    assert imported.shape[0] == am.annotations[am.annotations['set'] == input_set].shape[0]
    assert errors.shape[0] == 0

    # Take an explicit copy before assigning columns: writing into a
    # boolean-mask selection is chained assignment and makes pandas emit
    # SettingWithCopyWarning.
    truth = am.annotations[am.annotations['set'] == input_set].copy()
    truth['merged_from'] = truth['set']
    truth['set'] = output_set
    truth['format'] = 'NA'
    # These columns are dropped from both frames before the comparison.
    cols = ['imported_at', 'package_version']
    pd.testing.assert_frame_equal(truth.drop(columns=cols).reset_index(drop=True),
                                  imported.drop(columns=cols).reset_index(drop=True))


# function used for derivation but does not have the correct signature
def bad_func(a, b):
    """Return ``b`` unchanged; takes only two positional parameters."""
    return b


@pytest.mark.parametrize("input_set,function,output_set,exists,ow,error",
[("missing", partial(dv_func, type='normal'), "output", False, False, AssertionError),
("vtc_present", partial(dv_func, type='number'), "output", False, False, None),
("vtc_present", partial(dv_func, type='columns'), "output", False, False, None),
("vtc_present", bad_func, "output", False, False, None),
("vtc_present", partial(dv_func, type='normal'), "vtc_present", False, False, AssertionError),
("vtc_present", partial(dv_func, type='normal'), "output", True, False, AssertionError),
("vtc_present", partial(dv_func, type='normal'), "output", True, True, AssertionError),
("vtc_present", 'not_a_function', "output", False, False, ValueError),
])
def test_derive_inputs(project, am, input_set, function, output_set, exists, ow, error):
am.read()
Expand All @@ -326,7 +351,12 @@ def test_derive_inputs(project, am, input_set, function, output_set, exists, ow,

if error:
with pytest.raises(error):
am.rename_set(old, new)
am.derive_annotations(input_set, output_set, function, overwrite_existing=ow)
else:
imported, errors = am.derive_annotations(input_set, output_set, function, overwrite_existing=ow)
# check that 0 lines were imported because of bad input
assert imported.shape[0] == 0
assert errors.shape[0] == am.annotations[am.annotations['set'] == input_set].shape[0]

def test_intersect(project, am):
input_annotations = pd.read_csv("examples/valid_raw_data/annotations/intersect.csv")
Expand Down
49 changes: 49 additions & 0 deletions tests/test_derivations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pandas as pd
from pathlib import Path

from ChildProject.projects import ChildProject
import ChildProject.pipelines.derivations as deriv

# Root of the example dataset used by the tests in this module.
CP_PATH = Path('examples', 'valid_raw_data')
# Converted segments of the vtc_present set, read once at import time;
# each test works on its own copy.
CSV_DF = pd.read_csv(CP_PATH / 'annotations' / 'vtc_present' / 'converted' / 'sound2_0_4000.csv')
# Directory holding the expected ("truth") CSV outputs.
TRUTH = Path('tests') / 'truth' / 'derivations'


def test_conversations():
    """Compare ``deriv.conversations`` output with the stored truth CSV."""
    segments = CSV_DF.copy()

    # Called with no project and empty metadata.
    derived = deriv.conversations(None, {}, segments)
    # derived.to_csv(TRUTH / 'conversations.csv', index=False)
    expected = pd.read_csv(TRUTH / 'conversations.csv')

    pd.testing.assert_frame_equal(derived, expected, check_dtype=False)

def test_acoustics():
    """Compare ``deriv.acoustics`` output with the stored truth CSV."""
    dataset = ChildProject(CP_PATH)
    dataset.read()
    segments = CSV_DF.copy()
    metadata = {'recording_filename': 'sound.wav'}

    derived = deriv.acoustics(dataset, metadata, segments, profile=None, target_sr=4096)
    # derived.to_csv(TRUTH / 'acoustics.csv', index=False)
    expected = pd.read_csv(TRUTH / 'acoustics.csv')

    # Dump both frames so a failing comparison can be inspected in the
    # captured output.
    print(expected.to_string())
    print(derived.to_string())

    pd.testing.assert_frame_equal(derived, expected)


def test_remove_overlaps():
    """Compare ``deriv.remove_overlaps`` output with the stored truth CSV."""
    segments = CSV_DF.copy()

    # Called with no project and empty metadata.
    derived = deriv.remove_overlaps(None, {}, segments)
    # derived.to_csv(TRUTH / 'remove-overlaps.csv', index=False)
    expected = pd.read_csv(TRUTH / 'remove-overlaps.csv')

    pd.testing.assert_frame_equal(derived, expected, check_dtype=False)
17 changes: 17 additions & 0 deletions tests/truth/derivations/acoustics.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
segment_onset,segment_offset,speaker_type,raw_filename,mean_pitch_semitone,median_pitch_semitone,p5_pitch_semitone,p95_pitch_semitone,pitch_range_semitone
0,342,CHI,example.rttm,30.556933071954266,40.273725722703276,14.038385565680954,40.273725722703276,26.23534015702232
113,486,FEM,example.rttm,39.84245110739157,40.273725722703276,39.10928426136167,40.273725722703276,1.1644414613416032
200,901,OCH,example.rttm,34.18339779437408,40.273725722703276,12.867250045221894,40.273725722703276,27.40647567748138
782,1421,MAL,example.rttm,38.294152588657674,40.273725722703276,31.36564661949806,40.273725722703276,8.908079103205214
1401,1753,,example.rttm,40.273725722703276,40.273725722703276,40.273725722703276,40.273725722703276,0.0
203,1000,CHI,example.rttm,35.17511859935001,40.273725722703276,15.29055081827227,40.273725722703276,24.983174904431007
1200,1656,OCH,example.rttm,34.6622956535196,40.273725722703276,21.19486348747878,40.273725722703276,19.078862235224495
1350,2111,MAL,example.rttm,39.05753217752032,40.273725722703276,34.3143773513068,40.273725722703276,5.959348371396473
1821,2324,FEM,example.rttm,38.991258223371645,40.273725722703276,35.14385572537674,40.273725722703276,5.129869997326537
2301,2845,,example.rttm,37.46240668323427,40.273725722703276,29.02844956482724,40.273725722703276,11.245276157876035
2556,2890,CHI,example.rttm,33.547635907323105,40.273725722703276,22.11328322117681,40.273725722703276,18.160442501526465
2901,3205,FEM,example.rttm,36.16407064848453,40.273725722703276,29.177657022312662,40.273725722703276,11.096068700390614
3056,3457,OCH,example.rttm,32.7496604661,40.273725722703276,14.691903850252153,40.273725722703276,25.581821872451123
3301,3654,MAL,example.rttm,30.550152227679973,40.273725722703276,14.020077286140356,40.273725722703276,26.25364843656292
3456,3705,,example.rttm,36.88339947865374,36.88339947865374,33.83210585900916,39.934693098298325,6.102587239289164
3900,3987,CHI,example.rttm,34.47045449069267,34.47045449069267,34.47045449069267,34.47045449069267,0.0
14 changes: 14 additions & 0 deletions tests/truth/derivations/conversations.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
segment_onset,segment_offset,speaker_type,raw_filename,iti,prev_speaker_type,delay,is_CT,conv_count
0,342,CHI,example.rttm,,,,False,1
113,486,FEM,example.rttm,-229.0,CHI,113.0,True,1
200,901,OCH,example.rttm,-286.0,FEM,87.0,True,1
782,1421,MAL,example.rttm,-119.0,OCH,582.0,True,1
203,1000,CHI,example.rttm,-1218.0,MAL,-579.0,False,2
1200,1656,OCH,example.rttm,200.0,CHI,997.0,True,2
1350,2111,MAL,example.rttm,-306.0,OCH,150.0,True,2
1821,2324,FEM,example.rttm,-290.0,MAL,471.0,True,2
2556,2890,CHI,example.rttm,232.0,FEM,735.0,True,2
2901,3205,FEM,example.rttm,11.0,CHI,345.0,True,2
3056,3457,OCH,example.rttm,-149.0,FEM,155.0,True,2
3301,3654,MAL,example.rttm,-156.0,OCH,245.0,True,2
3900,3987,CHI,example.rttm,246.0,MAL,599.0,True,2
10 changes: 10 additions & 0 deletions tests/truth/derivations/remove-overlaps.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
segment_onset,segment_offset,speaker_type,raw_filename
0,113,CHI,example.rttm
1000,1200,MAL,example.rttm
1656,1821,MAL,example.rttm
2111,2324,FEM,example.rttm
2556,2890,CHI,example.rttm
2901,3056,FEM,example.rttm
3205,3301,OCH,example.rttm
3457,3654,MAL,example.rttm
3900,3987,CHI,example.rttm

0 comments on commit 35f625e

Please sign in to comment.