Merge pull request #41 from yutanagano/allow-batch-size-adjustments
Allow batch size adjustments
yutanagano authored Jan 29, 2025
2 parents 317ea66 + 29ef2e1 commit 9d07cf1
Showing 2 changed files with 35 additions and 5 deletions.
22 changes: 17 additions & 5 deletions src/sceptr/model.py
@@ -11,7 +11,7 @@
 from torch.nn import utils
 
 
-BATCH_SIZE = 512
+BATCH_SIZE_DEFAULT = 512
 
 
 logger = logging.getLogger(__name__)
@@ -134,6 +134,7 @@ def __init__(self, name: str, tokeniser: Tokeniser, bert: Bert) -> None:
         self._tokeniser = tokeniser
         self._bert = bert.eval()
         self._device = torch.device("cpu")
+        self._batch_size = BATCH_SIZE_DEFAULT
 
     def enable_hardware_acceleration(self) -> None:
         """
@@ -157,6 +158,17 @@ def disable_hardware_acceleration(self) -> None:
             f"disable_hardware_acceleration called on {self} ({self.name}), setting device to cpu"
         )
 
+    def set_batch_size(self, batch_size: int) -> None:
+        """
+        Set the batch size used when generating TCR vector representations.
+        That is, how many representations are computed at a time on the CPU / GPU.
+        By default, the batch size is set to 512.
+        """
+        if not isinstance(batch_size, int):
+            raise TypeError(f"The batch size must be an int. Got {type(batch_size)}.")
+
+        self._batch_size = batch_size
+
     def calc_vector_representations(self, instances: DataFrame) -> ndarray:
         """
         Map TCRs to their corresponding vector representations.
@@ -215,8 +227,8 @@ def calc_residue_representations(
         residue_reps_collection = []
         compartment_masks_collection = []
 
-        for idx in range(0, len(tcrs), BATCH_SIZE):
-            batch = tcrs.iloc[idx : idx + BATCH_SIZE]
+        for idx in range(0, len(tcrs), self._batch_size):
+            batch = tcrs.iloc[idx : idx + self._batch_size]
             tokenised_batch = [self._tokeniser.tokenise(tcr) for tcr in batch]
             padded_batch = utils.rnn.pad_sequence(
                 sequences=tokenised_batch,
@@ -257,8 +269,8 @@ def _calc_torch_representations(self, instances: DataFrame) -> FloatTensor:
         tcrs = schema.generate_tcr_series(instances)
 
         representations = []
-        for idx in range(0, len(tcrs), BATCH_SIZE):
-            batch = tcrs.iloc[idx : idx + BATCH_SIZE]
+        for idx in range(0, len(tcrs), self._batch_size):
+            batch = tcrs.iloc[idx : idx + self._batch_size]
             tokenised_batch = [self._tokeniser.tokenise(tcr) for tcr in batch]
             padded_batch = utils.rnn.pad_sequence(
                 sequences=tokenised_batch,
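Taken together, the model.py changes replace the module-level BATCH_SIZE constant with a per-instance _batch_size attribute that both batching loops now read. A minimal usage sketch (assuming the variant loader exercised in the tests below; df stands for any pandas DataFrame in the format sceptr expects):

    from sceptr import variant

    model = variant.default()
    model.set_batch_size(128)  # default is 512
    vecs = model.calc_vector_representations(df)  # now computed in chunks of 128

Smaller batches lower the peak CPU/GPU memory needed per forward pass at some cost in throughput; larger batches do the opposite, so the right value depends on available device memory.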
18 changes: 18 additions & 0 deletions tests/test_variants.py
@@ -16,6 +16,11 @@ def dummy_data():
     return df
 
 
+@pytest.fixture
+def default_model():
+    return variant.default()
+
+
 @pytest.mark.parametrize(
     "model",
     (
@@ -45,12 +50,14 @@ def test_enable_hardware_acceleration(self, model, caplog):
         caplog.set_level(logging.DEBUG)
         model.enable_hardware_acceleration()
         assert "enable_hardware_acceleration" in caplog.text
+        assert model.name in caplog.text
         model.disable_hardware_acceleration()
 
     def test_disable_hardware_acceleration(self, model, caplog):
         caplog.set_level(logging.DEBUG)
         model.disable_hardware_acceleration()
         assert "disable_hardware_acceleration" in caplog.text
+        assert model.name in caplog.text
 
     def test_embed(self, model, dummy_data):
         result = model.calc_vector_representations(dummy_data)
@@ -94,3 +101,14 @@ def test_residue_representations(self, model, dummy_data):
         ):
             with pytest.raises(NotImplementedError):
                 model.calc_residue_representations(dummy_data)
+
+
+def test_set_batch_size(default_model):
+    assert default_model._batch_size == 512
+    default_model.set_batch_size(128)
+    assert default_model._batch_size == 128
+
+
+def test_set_batch_size_type_error(default_model):
+    with pytest.raises(TypeError):
+        default_model.set_batch_size("128")
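One caveat worth noting about the guard these tests exercise (a general Python fact, not anything sceptr-specific): bool is a subclass of int, so set_batch_size(True) would pass the isinstance check, as would zero or negative ints, which would then break the range-based batching loops above. A small sketch of the underlying behaviour:

    assert isinstance(True, int)  # a bool slips past the isinstance guard
    try:
        range(0, 10, 0)  # a zero step is rejected by range itself
    except ValueError as err:
        print(err)  # "range() arg 3 must not be zero"
    assert list(range(0, 10, -1)) == []  # a negative step silently yields no batches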
