Skip to content

Commit

Permalink
Fix DataParallel validation forward signatures (#47)
Browse files Browse the repository at this point in the history
* Fix: DataParallel validation forward signatures

* Update: generalize forward_fn selection

* nit: space
  • Loading branch information
KSGulin authored May 27, 2022
1 parent 5afbd46 commit 053646a
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2435,11 +2435,12 @@ def evaluation_loop(

observed_num_examples = 0
# Main evaluation loop
module_forward_fn = model.module.forward if isinstance(model, nn.DataParallel) else model.forward
for step, inputs in enumerate(dataloader):
inputs = {
k: inputs[k]
for k in inputs
if k in list(inspect.signature(model.forward).parameters.keys())
if k in list(inspect.signature(module_forward_fn).parameters.keys())
}

# Update the observed num examples
Expand Down

0 comments on commit 053646a

Please sign in to comment.