forked from vllm-project/vllm
This PR mostly ports vllm-project#11389 to the design introduced by #358 and makes the custom caching code a little more robust.

Currently there are two problems with guided decode:

- `mask[list(allowed_tokens)] = 0` crashes because `allowed_tokens` contains tensors. This is a pretty easy fix.
- The value type of `self._fsm_state` was changed from `int` to a union of `int` and `outlines.state.CFGState`. This may cause `self._cached_get_mask_tensor(state_id, scores.size(-1), scores.device)` to crash, as `outlines.state.CFGState` is not hashable.

This PR changes the caching mechanism so that if a function argument is not hashable, its `id` is used as the key instead. This might cause some cache misses, but that is better than crashing, which is what happens right now.

Neither of the above is a problem upstream, as both stem from code introduced in #358.

I've also added guided decode tests to the CI suite.
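The id-based fallback described above can be illustrated with a minimal sketch. This is not the code from this commit: the decorator name `cache_with_id_fallback`, the helper `_key_part`, and the example function `get_mask` are hypothetical and exist only to show the idea of keying on `id(arg)` when `hash(arg)` fails.

```python
from typing import Any, Callable, Dict, Hashable, Tuple


def _key_part(arg: Any) -> Hashable:
    """Return arg itself if it is hashable, otherwise a tagged id(arg)."""
    try:
        hash(arg)
        return arg
    except TypeError:
        # Unhashable argument: fall back to object identity.
        # Two equal-but-distinct objects get different keys (a cache miss).
        return ("__id__", id(arg))


def cache_with_id_fallback(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical memoizing decorator that tolerates unhashable arguments."""
    cache: Dict[Tuple[Hashable, ...], Any] = {}

    def wrapper(*args: Any) -> Any:
        key = tuple(_key_part(a) for a in args)
        if key not in cache:
            cache[key] = fn(*args)
        return cache[key]

    wrapper.cache = cache  # exposed for inspection only
    return wrapper


@cache_with_id_fallback
def get_mask(state: Any, vocab_size: int) -> list:
    # Stand-in for an expensive mask computation.
    return [0] * vocab_size


get_mask(7, 4)          # hashable state: cached by value
get_mask([1, 2, 3], 4)  # unhashable state: cached by identity, no crash
```

One caveat of this sketch: an `id` can be reused once the original object has been garbage collected, so an identity-keyed entry is only trustworthy while the object it was computed for is still alive.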
1 parent e8f66d5, commit 4d91f3b
Showing 3 changed files with 32 additions and 11 deletions.