Skip to content

Commit

Permalink
add test for extra labels
Browse files Browse the repository at this point in the history
  • Loading branch information
aryarm authored Nov 22, 2024
1 parent f96a20a commit c7d550d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
5 changes: 1 addition & 4 deletions haptools/data/breakpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,7 @@ def encode(self, labels: tuple[str] = None):
# replace the "pop" labels
arr = rcf.drop_fields(blocks[strand_num], ["pop"])
blocks[strand_num] = rcf.merge_arrays((arr, ints), flatten=True)[names]
self.labels = {
k:v for k,v in labels.items()
if k in seen
}
self.labels = {k: v for k, v in labels.items() if k in seen}

def recode(self):
"""
Expand Down
19 changes: 18 additions & 1 deletion tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2080,7 +2080,24 @@ def test_encode_reorder(self):
expected.labels = {"CEU": 0, "YRI": 1}

observed = self._get_expected_breakpoints()
observed.encode(labels=("CEU", "YRI"))
observed.encode(labels=("CEU", "YRI", "AMR"))

assert observed.labels == expected.labels
assert len(expected.data) == len(observed.data)
for sample in expected.data:
for strand in range(len(expected.data[sample])):
exp_strand = expected.data[sample][strand]
obs_strand = observed.data[sample][strand]
assert len(exp_strand) == len(observed.data[sample][strand])
for obs, exp in zip(obs_strand["pop"], exp_strand["pop"]):
assert expected.labels[exp] == obs

# now try again with AMR in the middle
# In that case, it should keep the ordering when deciding the integers
# (so YRI is still encoded as 2), but the final labels should NOT include
# the AMR key, since AMR never appears in the data
expected.labels = {"CEU": 0, "YRI": 2}
observed = self._get_expected_breakpoints()
observed.encode(labels=("CEU", "AMR", "YRI"))

assert observed.labels == expected.labels
assert len(expected.data) == len(observed.data)
Expand Down

0 comments on commit c7d550d

Please sign in to comment.