diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 6129fa6af9f9..73f6c7b55d82 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2435,11 +2435,12 @@ def evaluation_loop( observed_num_examples = 0 # Main evaluation loop + module_forward_fn = model.module.forward if isinstance(model, nn.DataParallel) else model.forward for step, inputs in enumerate(dataloader): inputs = { k: inputs[k] for k in inputs - if k in list(inspect.signature(model.forward).parameters.keys()) + if k in list(inspect.signature(module_forward_fn).parameters.keys()) } # Update the observed num examples