llamamodel: prevent CUDA OOM crash by allocating VRAM early #2393
This is a proposed fix for the issue where CUDA OOM can happen later than expected and crash GPT4All. The question is whether the benefit (falling back early instead of crashing later) is worth the load latency cost.
After a model is loaded onto a CUDA device, we run one full batch of (meaningless) input through it. Small batches don't use as much VRAM, and llama.cpp appears to allocate the full KV cache for the context regardless of where in the context the input lies, so n_batch matters a lot while n_past does not seem to matter at all.
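A minimal sketch of that idea, assuming a llama.cpp model and context already loaded with the requested GPU offload. The function name testModel() comes from the description below, but the body, the use of the BOS token as filler, and the exception-based error handling (which assumes llama.cpp reports the allocation failure via an exception rather than aborting) are illustrative assumptions, not the PR's exact code:

```cpp
#include <exception>
#include <iostream>
#include <llama.h>

// Sketch only: run one full batch of throwaway input so that any CUDA OOM
// happens here, during load, instead of on the user's first real prompt.
static bool testModel(llama_model *model, llama_context *ctx)
{
    const int32_t n_batch = int32_t(llama_n_batch(ctx)); // full batch so worst-case VRAM is allocated now

    llama_batch batch = llama_batch_init(n_batch, /*embd*/ 0, /*n_seq_max*/ 1);

    // Fill the batch with meaningless input: the BOS token repeated n_batch times.
    const llama_token bos = llama_token_bos(model);
    for (int32_t i = 0; i < n_batch; i++) {
        batch.token   [i]    = bos;
        batch.pos     [i]    = i;
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = 0;
        batch.logits  [i]    = false;
    }
    batch.logits[n_batch - 1] = true; // only the last token needs logits
    batch.n_tokens = n_batch;

    bool ok = true;
    try {
        // If the compute buffers or KV cache do not fit in VRAM, the failure
        // surfaces here rather than later during normal inference.
        if (llama_decode(ctx, batch) != 0)
            ok = false;
    } catch (const std::exception &e) {
        std::cerr << "warning: model test failed: " << e.what() << '\n';
        ok = false;
    }

    // Discard the throwaway tokens so the real conversation starts from an empty KV cache.
    llama_kv_cache_clear(ctx);
    llama_batch_free(batch);
    return ok;
}
```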
The call to testModel() shows up in the UI as the progress bar sitting near 100% before the load completes. With 24 layers of Llama 3 8B offloaded, this takes about 2 seconds on my GTX 970 and 0.3 seconds on my Tesla P40. Worst-case timing under high memory pressure and a batch size of 512 (which I had to patch in, since the upper limit is normally 128) is about 11.2 seconds. At a batch size of 128 I have seen it take as long as 7.6 seconds.
Testing
You can test this PR by choosing a model that does not fit in your card's VRAM and finding a number of layers to offload that just barely doesn't fit. On the main branch, GPT4All can crash either during load or while you are sending input to it. With this PR, an exception is logged to the console during testModel() and GPT4All falls back to CPU, as it already does for Kompute.
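For context, a hedged sketch of what that fallback path could look like at the call site, under the same assumptions as the testModel() sketch above. The parameter values and variable names are illustrative; the real GPT4All load path differs:

```cpp
// Hypothetical call-site sketch: load with GPU offload, run the dummy-batch
// test, and reload CPU-only if it fails.
llama_model_params mparams = llama_model_default_params();
mparams.n_gpu_layers = 24;                       // e.g. the layer count requested in the UI

llama_context_params cparams = llama_context_default_params();
cparams.n_ctx   = 2048;                          // assumed context size
cparams.n_batch = 128;                           // the usual upper limit mentioned above

llama_model   *model = llama_load_model_from_file("/path/to/model.gguf", mparams);
llama_context *ctx   = model ? llama_new_context_with_model(model, cparams) : nullptr;

if (!ctx || !testModel(model, ctx)) {
    std::cerr << "warning: GPU offload failed the VRAM test, falling back to CPU\n";
    if (ctx)   llama_free(ctx);
    if (model) llama_free_model(model);

    mparams.n_gpu_layers = 0;                    // CPU-only reload
    model = llama_load_model_from_file("/path/to/model.gguf", mparams);
    ctx   = model ? llama_new_context_with_model(model, cparams) : nullptr;
}
```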