Skip to content

Commit

Permalink
Fix out-of-range value
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Jan 5, 2024
1 parent 42581e3 commit 030f097
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions d3rlpy/tokenizers/tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def decode(self, y: Int32NDArray) -> NDArray:


class FloatTokenizer(Tokenizer):
_maximum: float
_minimum: float
_bins: Float32NDArray
_use_mu_law_encode: bool
_mu: float
Expand All @@ -36,6 +38,8 @@ def __init__(
basis: float = 256.0,
token_offset: int = 0,
):
self._maximum = maximum
self._minimum = minimum
self._bins = np.array(
(maximum - minimum) * np.arange(num_bins) / num_bins + minimum,
dtype=np.float32,
Expand All @@ -48,6 +52,8 @@ def __init__(
def __call__(self, x: NDArray) -> Int32NDArray:
if self._use_mu_law_encode:
x = mu_law_encode(x, self._mu, self._basis)
else:
x = np.clip(x, self._minimum, self._maximum)
return np.digitize(x, self._bins) - 1 + self._token_offset

def decode(self, y: Int32NDArray) -> NDArray:
Expand Down

0 comments on commit 030f097

Please sign in to comment.