Skip to content

Commit

Permalink
mend
Browse files Browse the repository at this point in the history
  • Loading branch information
amva13 committed Jan 21, 2025
1 parent c06a67f commit 373b671
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions tdc/test/test_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,14 @@ def testGeneformerPerturb(self):
])
assert input_tensor.shape[0] == attention_mask.shape[0]
assert input_tensor.shape[1] == attention_mask.shape[1]
outputs = model(input_tensor,
attention_mask=attention_mask,
output_hidden_states=True)
try:
outputs = model(input_tensor,
attention_mask=attention_mask,
output_hidden_states=True)
except Exception as e:
raise Exception(
f"sizes: {input_tensor.shape[0]}, {input_tensor.shape[1]}\n {e}"
)
num_out_in_batch = len(outputs.hidden_states[-1])
input_batch_size = input_tensor.shape[0]
num_gene_out_in_batch = len(outputs.hidden_states[-1][0])
Expand Down

0 comments on commit 373b671

Please sign in to comment.