From 053646a8f65fd40bb022dfe943ba92b7a0e6cafb Mon Sep 17 00:00:00 2001 From: Konstantin Gulin <66528950+KSGulin@users.noreply.github.com> Date: Fri, 27 May 2022 18:44:43 +0100 Subject: [PATCH] Fix DataParallel validation forward signatures (#47) * Fix: DataParallel validation forward signatures * Update: generalize forward_fn selection * nit: space --- src/transformers/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 6129fa6af9f975..73f6c7b55d82b5 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2435,11 +2435,12 @@ def evaluation_loop( observed_num_examples = 0 # Main evaluation loop + module_forward_fn = model.module.forward if isinstance(model, nn.DataParallel) else model.forward for step, inputs in enumerate(dataloader): inputs = { k: inputs[k] for k in inputs - if k in list(inspect.signature(model.forward).parameters.keys()) + if k in list(inspect.signature(module_forward_fn).parameters.keys()) } # Update the observed num examples