From dc12a66b6b8aeff2a25db618e924513d83fd278b Mon Sep 17 00:00:00 2001 From: Yuta Nagano <52748151+yutanagano@users.noreply.github.com> Date: Tue, 20 Aug 2024 13:57:46 +0900 Subject: [PATCH] Improve readability of ResidueRepresentations repr --- src/sceptr/model.py | 3 +++ tests/test_residue_representations.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+) create mode 100644 tests/test_residue_representations.py diff --git a/src/sceptr/model.py b/src/sceptr/model.py index c7e810c..91e7554 100644 --- a/src/sceptr/model.py +++ b/src/sceptr/model.py @@ -104,6 +104,9 @@ def __init__(self, representation_array: ndarray, compartment_mask: ndarray) -> self.representation_array = representation_array self.compartment_mask = compartment_mask + def __repr__(self) -> str: + return f"ResidueRepresentations[num_tcrs: {self.representation_array.shape[0]}, rep_dim: {self.representation_array.shape[2]}]" + class Sceptr: """ diff --git a/tests/test_residue_representations.py b/tests/test_residue_representations.py new file mode 100644 index 0000000..64b2355 --- /dev/null +++ b/tests/test_residue_representations.py @@ -0,0 +1,14 @@ +import numpy as np +import pytest +from sceptr.model import ResidueRepresentations + + +def test_repr(res_reps): + assert res_reps.__repr__() == "ResidueRepresentations[num_tcrs: 3, rep_dim: 64]" + + +@pytest.fixture +def res_reps() -> ResidueRepresentations: + rep_array = np.zeros((3, 10, 64)) + comp_mask = np.zeros_like(rep_array, dtype=int) + return ResidueRepresentations(rep_array, comp_mask)