Fix guided decoding crashes #811
Conversation
vllm/model_executor/guided_decoding/outlines_logits_processors.py
@@ -36,12 +38,40 @@
def _cached(fn):
    cache: Dict[Any, Any] = {}

def is_hashable(obj):
Hmm, I'm wondering whether we should take a slightly different approach. Let me give you an example:

foo = [1, CFGState]
is_hashable(foo) == False => key(foo) = id(foo)

What if we change the logic so that, instead of checking whether everything is hashable, we hash everything we can and use id where hashing isn't possible? I.e.

semi_hash(obj) =
    case Iterable -> hash(semi_hash(sub) for sub in obj)
    case CFGState -> id(obj)
    case Hashable -> hash(obj)

This way we could hash foo like this:

semi_hash(foo) = hash(hash(int), id(CFGState))
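To make the proposal concrete, here is a minimal runnable sketch of semi_hash. The CFGState below is a stand-in class (the real one lives in outlines), and the str/bytes guard is an addition to avoid infinite recursion on strings:

```python
from collections.abc import Hashable, Iterable

class CFGState:
    """Stand-in for outlines' CFG state; deliberately unhashable."""
    __hash__ = None

def semi_hash(obj):
    if isinstance(obj, CFGState):        # known-unhashable state -> identity
        return id(obj)
    if isinstance(obj, (str, bytes)):    # iterable but atomic; don't recurse
        return hash(obj)
    if isinstance(obj, Iterable):        # recurse into containers
        return hash(tuple(semi_hash(sub) for sub in obj))
    if isinstance(obj, Hashable):
        return hash(obj)
    return id(obj)                       # anything else: identity fallback
```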
I adopted this approach in the new hash_args function, but I decided not to make a special case for CFGState, as it's an iterable (a namedtuple) containing some non-hashable fields. So now it's something like:

hash_args(obj) =
    case Iterable -> hash(tuple(hash_args(sub) for sub in obj))
    case Hashable -> hash(obj)
    case _ -> hash(id(obj))

With foo = [1, CFGState] the hash would look like this:

hash_args(foo) = hash(hash(int), hash(tuple(id(PartialParserState), hash(tuple(int, int, ...)))))

The hasher now has one drawback that is currently a non-issue: if CFGState objects share the same PartialParserState and that state is modified between calls, the id (and thus the hash) stays the same even though the contents don't. Fortunately, it seems like the second field of CFGState will give that away; if we run into bugs here, it would probably be better to use hashes of random UUIDs rather than the ids of unhashable objects.
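A minimal sketch of the hash_args idea described above (not the actual vLLM implementation; the str/bytes guard is an addition to avoid infinite recursion on strings):

```python
from collections.abc import Hashable, Iterable

def hash_args(obj):
    """Best-effort hash: recurse into iterables, fall back to id()."""
    # Strings are iterable but atomic; recursing into them would never end.
    if isinstance(obj, (str, bytes)):
        return hash(obj)
    if isinstance(obj, Iterable):
        # A container (even a hashable namedtuple) may hold unhashable
        # elements, so recurse instead of trusting hash() on the whole thing.
        return hash(tuple(hash_args(sub) for sub in obj))
    if isinstance(obj, Hashable):
        return hash(obj)
    # Unhashable, non-iterable object: key it by identity.
    return hash(id(obj))
```

The identity fallback means two equal-but-distinct unhashable objects hash differently, trading cache misses for crash-freedom.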
LGTM
This PR mostly ports vllm-project#11389 to the design introduced by #358 and makes the custom caching code a little more robust.

Currently there are two problems with guided decoding:

- `mask[list(allowed_tokens)] = 0` crashes because `allowed_tokens` contains tensors. Pretty easy fix.
- `self._fsm_state` was changed from `int` to a union of `int` and `outlines.state.CFGState`, which may cause `self._cached_get_mask_tensor(state_id, scores.size(-1), scores.device)` to crash, as `outlines.state.CFGState` is not hashable. This PR changes the caching mechanism so that if function arguments are not hashable, their id is used as the key. This might cause some cache misses, but that's better than crashing, as it does right now.

None of the above is a problem upstream, as it stems from code introduced in HPU: offload logits processing to CPU #358.

I've also added guided decoding tests to the CI suite.