Issue resuming training on transformer-based NER #6323

Open
fcggamou opened this issue Oct 29, 2020 · 11 comments
Labels: feat / transformer · 🌙 nightly · perf / memory · training

Comments

fcggamou commented Oct 29, 2020

I'm using the nightly version. I have successfully trained a transformer-based NER model and saved it, and now I'm trying to resume training on it.

First, I'm not sure whether I have set up the config file correctly; the relevant part looks like this:

[components]

[components.ner]
# This is the path to my trained model
source = "best-model"

[components.ner.model]
@architectures = "spacy.TransitionBasedParser.v1"
state_type = "ner"
extra_state_tokens = false
hidden_width = 64
maxout_pieces = 2
use_upper = false
nO = null

[components.ner.model.tok2vec]
@architectures = "spacy-transformers.TransformerListener.v1"
grad_factor = 1.0
pooling = {"@layers":"reduce_mean.v1"}

[components.transformer]
# This is the path to my trained model
source = "best-model"

[components.transformer.model]
@architectures = "spacy-transformers.TransformerModel.v1"
name = "dccuchile/bert-base-spanish-wwm-uncased"

[components.transformer.model.get_spans]
@span_getters = "spacy-transformers.strided_spans.v1"
window = 128
stride = 96

[components.transformer.model.tokenizer_config]
use_fast = true

Now, after trying to train like this:
!python -m spacy train 'config.cfg' --output='model_t' --gpu-id=0 --paths.train train.spacy --paths.dev test.spacy

I'm getting this error message:

2020-10-29 14:36:11.541313: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
ℹ Using GPU: 0

=========================== Initializing pipeline ===========================
Set up nlp object from config
Pipeline: ['transformer', 'ner']
Resuming training for: ['ner', 'transformer']
Created vocabulary
Finished initializing nlp object
Initialized pipeline components: []
✔ Initialized pipeline

============================= Training pipeline =============================
ℹ Pipeline: ['transformer', 'ner']
ℹ Initial learn rate: 0.0
E    #       LOSS TRANS...  LOSS NER  ENTS_F  ENTS_P  ENTS_R  SCORE 
---  ------  -------------  --------  ------  ------  ------  ------
⚠ Aborting and saving the final best model. Encountered exception: CUDA
out of memory. Tried to allocate 94.00 MiB (GPU 0; 15.75 GiB total capacity;
13.81 GiB already allocated; 78.88 MiB free; 14.34 GiB reserved in total by
PyTorch)
✔ Saved pipeline to output directory
model_t2/model-last
Traceback (most recent call last):
  File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.6/dist-packages/spacy/__main__.py", line 4, in <module>
    setup_cli()
  File "/usr/local/lib/python3.6/dist-packages/spacy/cli/_util.py", line 65, in setup_cli
    command(prog_name=COMMAND)
  File "/usr/local/lib/python3.6/dist-packages/click/core.py", line 829, in __call__
    return self.main(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/click/core.py", line 782, in main
    rv = self.invoke(ctx)
  File "/usr/local/lib/python3.6/dist-packages/click/core.py", line 1259, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/usr/local/lib/python3.6/dist-packages/click/core.py", line 1066, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/usr/local/lib/python3.6/dist-packages/click/core.py", line 610, in invoke
    return callback(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/typer/main.py", line 497, in wrapper
    return callback(**use_params)  # type: ignore
  File "/usr/local/lib/python3.6/dist-packages/spacy/cli/train.py", line 59, in train_cli
    train(nlp, output_path, use_gpu=use_gpu, stdout=sys.stdout, stderr=sys.stderr)
  File "/usr/local/lib/python3.6/dist-packages/spacy/training/loop.py", line 105, in train
    raise e
  File "/usr/local/lib/python3.6/dist-packages/spacy/training/loop.py", line 85, in train
    for batch, info, is_best_checkpoint in training_step_iterator:
  File "/usr/local/lib/python3.6/dist-packages/spacy/training/loop.py", line 201, in train_while_improving
    score, other_scores = evaluate()
  File "/usr/local/lib/python3.6/dist-packages/spacy/training/loop.py", line 253, in evaluate
    scores = nlp.evaluate(dev_examples)
  File "/usr/local/lib/python3.6/dist-packages/spacy/language.py", line 1312, in evaluate
    docs = list(docs)
  File "/usr/local/lib/python3.6/dist-packages/spacy/util.py", line 1363, in _pipe
    yield from proc.pipe(docs, **kwargs)
  File "spacy/pipeline/transition_parser.pyx", line 170, in pipe
  File "/usr/local/lib/python3.6/dist-packages/spacy/util.py", line 1322, in minibatch
    batch = list(itertools.islice(items, int(batch_size)))
  File "/usr/local/lib/python3.6/dist-packages/spacy/util.py", line 1363, in _pipe
    yield from proc.pipe(docs, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/spacy_transformers/pipeline_component.py", line 173, in pipe
    self.set_annotations(subbatch, self.predict(subbatch))
  File "/usr/local/lib/python3.6/dist-packages/spacy_transformers/pipeline_component.py", line 189, in predict
    activations = self.model.predict(docs)
  File "/usr/local/lib/python3.6/dist-packages/thinc/model.py", line 312, in predict
    return self._func(self, X, is_train=False)[0]
  File "/usr/local/lib/python3.6/dist-packages/spacy_transformers/layers/transformer_model.py", line 111, in forward
    tensors, bp_tensors = transformer(token_data, is_train)
  File "/usr/local/lib/python3.6/dist-packages/thinc/model.py", line 288, in __call__
    return self._func(self, X, is_train=is_train)
  File "/usr/local/lib/python3.6/dist-packages/thinc/layers/pytorchwrapper.py", line 79, in forward
    Ytorch, torch_backprop = model.shims[0](Xtorch, is_train)
  File "/usr/local/lib/python3.6/dist-packages/thinc/shims/pytorch.py", line 29, in __call__
    return self.predict(inputs), lambda a: ...
  File "/usr/local/lib/python3.6/dist-packages/thinc/shims/pytorch.py", line 38, in predict
    outputs = self._model(*inputs.args, **inputs.kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_bert.py", line 762, in forward
    output_hidden_states=output_hidden_states,
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_bert.py", line 439, in forward
    output_attentions,
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_bert.py", line 388, in forward
    intermediate_output = self.intermediate(attention_output)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_bert.py", line 333, in forward
    hidden_states = self.intermediate_act_fn(hidden_states)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py", line 1369, in gelu
    return torch._C._nn.gelu(input)
RuntimeError: CUDA out of memory. Tried to allocate 94.00 MiB (GPU 0; 15.75 GiB total capacity; 13.81 GiB already allocated; 78.88 MiB free; 14.34 GiB reserved in total by PyTorch)

I understand the message is telling me I'm out of memory, but it seems strange that I can train from scratch with no issues yet get this error when trying to resume training on the saved model. Any help is appreciated.
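
For reference, here is a minimal sketch (assuming "best-model" and "test.spacy" are the paths used above) of running the saved pipeline over the dev texts outside the training loop, to check whether prediction over the dev set alone already exhausts GPU memory:

import spacy
from spacy.tokens import DocBin

# Minimal sketch: load the saved pipeline on GPU and run prediction over the
# dev texts in small batches, outside the training loop. "best-model" and
# "test.spacy" are the paths used above.
spacy.require_gpu()
nlp = spacy.load("best-model")
doc_bin = DocBin().from_disk("test.spacy")
texts = [doc.text for doc in doc_bin.get_docs(nlp.vocab)]
for doc in nlp.pipe(texts, batch_size=32):
    pass  # predictions only; watch nvidia-smi while this runs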

Your Environment

  • spaCy version: 3.0.0rc2
  • Platform: Linux-4.19.112+-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.6.9
  • Pipelines: es_core_news_md (3.0.0a0), es_dep_news_trf (3.0.0a0)
fcggamou changed the title from "Issue continue training on transformer based NER" to "Issue resuming training on transformer based NER" on Oct 29, 2020
adrianeboyd added the training and 🌙 nightly labels on Oct 29, 2020
fcggamou (Author) commented

I noticed from the stack trace that the error is raised during evaluation over the dev set, so I reduced the dev set to half its size and now it works.
It's not an ideal workaround though; any other suggestions are appreciated.
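
A minimal sketch of the halving step (assuming "test.spacy" is the original dev corpus; the output file name is just illustrative):

import spacy
from spacy.tokens import DocBin

# Sketch: keep only the first half of the dev docs and write them to a new
# file, then pass that file via --paths.dev. "test_half.spacy" is illustrative.
nlp = spacy.blank("es")  # only the vocab is needed to deserialize the docs
docs = list(DocBin().from_disk("test.spacy").get_docs(nlp.vocab))
half = DocBin(docs=docs[: len(docs) // 2], store_user_data=True)
half.to_disk("test_half.spacy")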

svlandeg added the perf / memory and feat / transformer labels on Nov 12, 2020
svlandeg (Member) commented

I'm a bit confused: when you originally trained the model, didn't you evaluate it on the dev set?

fcggamou (Author) commented

I did evaluate it on the dev set and it worked, hence my confusion as well: why does it work when training from scratch but fail when attempting to resume training?

svlandeg (Member) commented

And the pipeline was otherwise entirely the same? So there are no differences between the first training run and the resumed run (other than the "source" bit in the config, of course)?

fcggamou (Author) commented

Yes, exactly the same. Also the same train and dev data.

maxtrem commented Nov 12, 2020

Looking at the output:

Aborting and saving the final best model. Encountered exception: CUDA
out of memory. Tried to allocate 94.00 MiB (GPU 0; 15.75 GiB total capacity;
13.81 GiB already allocated; 78.88 MiB free; 14.34 GiB reserved in total by
PyTorch)

It appears that your GPU memory is already occupied by something else. I'm not sure how this fits together with training being started as a fresh shell command (so it shouldn't really apply here), but PyTorch can sometimes be problematic when it comes to releasing GPU memory.

Edit: It could of course be that the allocation happens gradually and only the last part is shown. But it may be worth checking the memory allocation before resuming the training.
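
A minimal sketch of checking the allocation from the Python side before resuming (note that torch only reports memory used by the current process, so nvidia-smi is also worth a look for other processes):

import torch

# Sketch: report PyTorch's view of GPU memory before resuming training.
print(f"{torch.cuda.memory_allocated() / 1024**2:.1f} MiB allocated")
print(f"{torch.cuda.memory_reserved() / 1024**2:.1f} MiB reserved by the caching allocator")
print(torch.cuda.memory_summary())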

svlandeg commented Nov 12, 2020

Either way, this part in language.evaluate() is probably the culprit:

        if len(self.pipeline):
            docs = list(docs)

This was added just for timing purposes. I think the code should still run without these two lines; any chance you can check whether removing them improves things memory-wise? (Your timing results will be temporarily wrong, but let's worry about that later.)

adrianeboyd (Contributor) commented

Oops, yeah that should be done differently. (But I don't understand why this ends up different in the second round than in the first?)

fcggamou commented Nov 12, 2020

Great! Thanks a lot for the workaround; I will test this and post an update.

fcggamou commented Dec 7, 2020

Just FYI, the workaround did not work; I still get the same error at this line:

for i, (doc, eg) in enumerate(zip(docs, examples)):

I pulled your fix, @adrianeboyd, and I still get the same OOM exception at language.py line 1319:

# iterate over final generator
if docs is not None:
   for doc in docs:
      pass

Is this just for timing purposes? Can I safely remove those lines? Thanks!

adrianeboyd (Contributor) commented

It isn't just for timing purposes because you're not actually running the final component (which is the NER model you're trying to train) unless you iterate over that generator. (Earlier versions had the scorer iterate over this generator, and the overall goal here was to separate the pipeline timing from the scorer timing.) I think the previous version was still a bit clunky so I've reworked it a bit more. Can you try the updated version here? #6386

Looking at this again, I think the problem might actually be that the default batch_size (256) is too high for a GPU if you have some longer dev docs. We've trained a fair number of models internally, but we don't have many docs that are over a paragraph or so long. How many dev docs were you using? Were any particularly long? Using my updated PR, is it better if you manually lower the default batch_size in the evaluate() kwargs?
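
For example, a sketch of that experiment would be changing the call inside the evaluate() helper in spacy/training/loop.py to pass an explicit, lower batch size (64 is just an arbitrary lower value to try):

# In spacy/training/loop.py, inside evaluate(): pass a smaller batch size
# than the default of 256 (64 here is just an arbitrary lower value).
scores = nlp.evaluate(dev_examples, batch_size=64)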

We're also running into some memory issues internally on CPU (for xx models that we haven't published yet) either due to large training corpora or long dev docs, so I'll be looking into a few spots where we can improve the memory usage in the near future.

Since this is something that may need to be adjusted and have different defaults for CPU vs. GPU, I think we'll most likely need a way to specify the batch size for evaluate from the config, but I'm not sure exactly how yet. We may need to add a training parameter like eval_batch_size? We'll have to discuss what makes sense...

(And I still don't know what's going on with the differences between training from scratch and resuming.)
