Fix guided decoding crashes #811

Merged
merged 5 commits into from
Feb 12, 2025

@kzawora-intel kzawora-intel commented Feb 10, 2025

This PR mostly ports vllm-project#11389 to the design introduced by #358 and makes the custom caching code a little bit more robust.
Currently there are two problems with guided decoding:

  • mask[list(allowed_tokens)] = 0 crashes because allowed_tokens can contain tensors. Pretty easy fix.
  • The value type of self._fsm_state was changed from int to a union of int and outlines.state.CFGState, which may cause self._cached_get_mask_tensor(state_id, scores.size(-1), scores.device) to crash, as outlines.state.CFGState is not hashable. This PR changes the caching mechanism so that if function arguments are not hashable, their id is used as the key. This might cause some cache misses, but that's better than crashing, as it does right now.

Neither of the above is a problem upstream, as both stem from code introduced in "HPU: offload logits processing to CPU" (#358).
I've also added guided decoding tests to the CI suite.
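
As a rough illustration of the caching change described above, here is a minimal sketch of a decorator that keys on the argument itself when it is hashable and on its id otherwise. The names cached_with_id_fallback and _is_hashable, and the tagged ("__id__", ...) key, are assumptions made for this example and are not taken from the PR's code.

```python
from typing import Any, Callable, Dict, Tuple


def _is_hashable(obj: Any) -> bool:
    """Return True if hash(obj) succeeds (hypothetical helper)."""
    try:
        hash(obj)
        return True
    except TypeError:
        return False


def cached_with_id_fallback(fn: Callable) -> Callable:
    """Cache results of fn, using id() for arguments that cannot be hashed."""
    cache: Dict[Tuple, Any] = {}

    def wrapper(*args):
        # Hashable arguments are used directly; unhashable ones are replaced
        # by a tagged id so they cannot collide with a plain int argument.
        key = tuple(a if _is_hashable(a) else ("__id__", id(a)) for a in args)
        if key not in cache:
            cache[key] = fn(*args)
        return cache[key]

    return wrapper
```

With this scheme, two equal but distinct unhashable objects produce different keys, which is the source of the extra cache misses mentioned above. An id is also only unique while the object is alive, so a long-lived cache keyed this way can in principle return a stale entry once an id is reused.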

@@ -36,12 +38,40 @@
def _cached(fn):
    cache: Dict[Any, Any] = {}

    def is_hashable(obj):

@madamczykhabana madamczykhabana Feb 11, 2025

Hmm, I wonder whether we should take a slightly different approach. Let me give you an example:
foo = [1, CFGState]
is_hashable(foo) = False => key(foo) = id(foo)
What if, instead of checking whether everything is hashable, we hash everything we can and use id only where hashing is not possible? I.e.

semi_hash(obj) =
  case Iterable -> hash(semi_hash(sub) for sub in obj)
  case CFGState -> id(obj)
  case Hashable -> hash(obj)

This way we could hash foo like this:
semi_hash(foo) = hash(hash(int), id(CFGState))

Author

@kzawora-intel kzawora-intel Feb 11, 2025

I adopted this approach in the new hash_args function, but I decided not to special-case CFGState, as it is an iterable (a namedtuple) containing some non-hashable fields. So now it's something like:

hash_args(obj) =
  case Iterable -> hash(tuple(hash_args(sub) for sub in obj))
  case Hashable -> hash(obj)
  case _ -> hash(id(obj))

So with foo = [1, CFGState], the hash would look like this:
hash_args(foo) = hash(hash(int), hash(tuple(id(PartialParserState), hash(tuple(int, int, ...)))))

The hasher now has one drawback that is currently a non-issue: if CFGState objects share the same PartialParserState and that state is modified on each call, the id (and thus the hash) stays the same even though the contents change. Fortunately for us, it seems the second field of CFGState will give that away. If we encounter bugs here, it would probably be better to use hashes of random UUIDs rather than the ids of unhashable objects.
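
To make the pseudocode above concrete, here is a minimal runnable sketch of a recursive hasher in the same spirit. It is not the PR's actual hash_args implementation: FakeParserState and FakeCFGState (and its field names) are hypothetical stand-ins for PartialParserState and outlines.state.CFGState, and the str/bytes special case is an addition needed to stop the recursion on strings.

```python
from collections import namedtuple
from collections.abc import Hashable, Iterable


def hash_args(obj):
    """Hash what can be hashed; fall back to id() for everything else."""
    # Strings are iterables of strings, so recursing into them never ends.
    if isinstance(obj, (str, bytes)):
        return hash(obj)
    if isinstance(obj, Iterable):
        return hash(tuple(hash_args(sub) for sub in obj))
    if isinstance(obj, Hashable):
        return hash(obj)
    return hash(id(obj))


class FakeParserState:
    """Hypothetical unhashable object: defining __eq__ removes __hash__."""

    def __init__(self):
        self.tokens = []

    def __eq__(self, other):
        return isinstance(other, FakeParserState) and self.tokens == other.tokens


# Hypothetical stand-in for CFGState: a namedtuple with an unhashable field.
FakeCFGState = namedtuple("FakeCFGState", ["parser_state", "prev_token"])

shared = FakeParserState()
state_a = FakeCFGState(shared, 7)
state_b = FakeCFGState(shared, 7)
state_c = FakeCFGState(shared, 8)
```

Here hash_args([1, state_a]) equals hash_args([1, state_b]), because both wrap the same parser-state object and the same second field, while state_c differs in its second field and should hash differently. Mutating shared.tokens in place leaves all of these hashes unchanged, which is the drawback described above.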


@madamczykhabana madamczykhabana left a comment

LGTM

@mgawarkiewicz mgawarkiewicz merged commit 4d91f3b into habana_main Feb 12, 2025
15 of 33 checks passed