Skip to content

Commit

Permalink
Use int32 seeds for random sampler on HPU (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
kzawora-intel committed Jun 6, 2024
1 parent 7521cba commit 627b95d
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions vllm/model_executor/sampling_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData, SequenceGroupMetadata
from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
maybe_expand_dim)
maybe_expand_dim, is_hpu)

_SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558
Expand Down Expand Up @@ -507,22 +507,23 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
dtype=torch.int,
pin_memory=pin_memory,
)
idx_dtype = torch.long if not is_hpu() else torch.int # Gaudi doesn't have full native int64 support
sample_indices_t = torch.tensor(
sample_indices,
device="cpu",
dtype=torch.long,
dtype=idx_dtype,
pin_memory=pin_memory,
)
prompt_tensor = torch.tensor(
prompt_padded_tokens,
device="cpu",
dtype=torch.long,
dtype=idx_dtype,
pin_memory=pin_memory,
)
output_tensor = torch.tensor(
output_padded_tokens,
device="cpu",
dtype=torch.long,
dtype=idx_dtype,
pin_memory=pin_memory,
)
# need to transpose and make contiguous to
Expand All @@ -531,7 +532,7 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
sampling_seeds_t = torch.tensor(
sampling_seeds,
device="cpu",
dtype=torch.long,
dtype=idx_dtype,
pin_memory=pin_memory,
).T.contiguous()

Expand Down Expand Up @@ -580,7 +581,8 @@ def _get_sequence_seeds(
else:
generator = random.Random(str((seed, ) + extra_entropy))
randint_fn = generator.randint
lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
idx_dtype = torch.long if not is_hpu() else torch.int # Gaudi doesn't have full native int64 support
lo, hi = torch.iinfo(idx_dtype).min, torch.iinfo(idx_dtype).max
# If the user/random sets seed = 0 but request should
# have sampling, we need to change it to something
# else. We use a constant in that case.
Expand Down

0 comments on commit 627b95d

Please sign in to comment.