diff --git a/tdc/test/test_model_server.py b/tdc/test/test_model_server.py index d47bd8f1..d57a7116 100644 --- a/tdc/test/test_model_server.py +++ b/tdc/test/test_model_server.py @@ -87,9 +87,14 @@ def testGeneformerPerturb(self): ]) assert input_tensor.shape[0] == attention_mask.shape[0] assert input_tensor.shape[1] == attention_mask.shape[1] - outputs = model(input_tensor, - attention_mask=attention_mask, - output_hidden_states=True) + try: + outputs = model(input_tensor, + attention_mask=attention_mask, + output_hidden_states=True) + except Exception as e: + raise Exception( + f"sizes: {input_tensor.shape[0]}, {input_tensor.shape[1]}\n {e}" + ) num_out_in_batch = len(outputs.hidden_states[-1]) input_batch_size = input_tensor.shape[0] num_gene_out_in_batch = len(outputs.hidden_states[-1][0])