Skip to content

Commit

Permalink
Upgrade Python version for mypy config
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Feb 18, 2024
1 parent 8cb7c4d commit a2eddd6
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 14 deletions.
2 changes: 1 addition & 1 deletion d3rlpy/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def slice_observations(
observations: ObservationSequence, start: int, end: int
) -> ObservationSequence:
if isinstance(observations, np.ndarray):
return observations[start:end] # type: ignore
return observations[start:end]
elif isinstance(observations, (list, tuple)):
return [obs[start:end] for obs in observations]
else:
Expand Down
7 changes: 4 additions & 3 deletions d3rlpy/dataset/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def append(
self._cursor += 1

def to_episode(self, terminated: bool) -> Episode:
observations: ObservationSequence
if len(self._observations) == 1:
observations = self._observations[0][: self._cursor].copy()
else:
Expand Down Expand Up @@ -187,17 +188,17 @@ def size(self) -> int:
@property
def observations(self) -> ObservationSequence:
if len(self._observations) == 1:
return self._observations[0][: self._cursor] # type: ignore
return self._observations[0][: self._cursor]
else:
return [obs[: self._cursor] for obs in self._observations]

@property
def actions(self) -> NDArray:
return self._actions[: self._cursor] # type: ignore
return self._actions[: self._cursor]

@property
def rewards(self) -> NDArray:
return self._rewards[: self._cursor] # type: ignore
return self._rewards[: self._cursor]

@property
def terminated(self) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[mypy]
python_version = 3.7
python_version = 3.8
strict = True
strict_optional = True
disallow_untyped_defs = True
Expand Down
8 changes: 4 additions & 4 deletions tests/dataset/test_transition_pickers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,14 @@ def test_basic_transition_picker(
transition = picker(episode, length - 1)
if isinstance(observation_shape[0], tuple):
for i, shape in enumerate(observation_shape):
dummy_observation = np.zeros(shape)
dummy_observation = np.zeros(shape) # type: ignore
assert transition.observation_signature.shape[i] == shape
assert np.all(
transition.observation[i] == episode.observations[i][-1]
)
assert np.all(transition.next_observation[i] == dummy_observation)
else:
dummy_observation = np.zeros(observation_shape)
dummy_observation = np.zeros(observation_shape) # type: ignore
assert transition.observation_signature.shape[0] == observation_shape
assert np.all(transition.observation == episode.observations[-1])
assert np.all(transition.next_observation == dummy_observation)
Expand Down Expand Up @@ -188,14 +188,14 @@ def test_multi_step_transition_picker(
transition = picker(episode, length - n_steps)
if isinstance(observation_shape[0], tuple):
for i, shape in enumerate(observation_shape):
dummy_observation = np.zeros(shape)
dummy_observation = np.zeros(shape) # type: ignore
assert transition.observation_signature.shape[i] == shape
assert np.all(
transition.observation[i] == episode.observations[i][-n_steps]
)
assert np.all(transition.next_observation[i] == dummy_observation)
else:
dummy_observation = np.zeros(observation_shape)
dummy_observation = np.zeros(observation_shape) # type: ignore
assert transition.observation_signature.shape[0] == observation_shape
assert np.all(transition.observation == episode.observations[-n_steps])
assert np.all(transition.next_observation == dummy_observation)
Expand Down
11 changes: 6 additions & 5 deletions tests/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def create_observation(
np.random.random(shape).astype(dtype) for shape in observation_shape
]
else:
observation = np.random.random(observation_shape).astype(dtype)
observation = np.random.random(observation_shape) # type: ignore
observation = observation.astype(dtype)
return observation


Expand Down Expand Up @@ -205,10 +206,10 @@ def create_transition(
for shape in observation_shape
]
else:
observation = np.random.random(observation_shape).astype(np.float32)
next_observation = np.random.random(observation_shape).astype(
np.float32
)
observation = np.random.random(observation_shape) # type: ignore
observation = observation.astype(np.float32)
next_observation = np.random.random(observation_shape) # type: ignore
next_observation = next_observation.astype(np.float32)

action: NDArray
if discrete_action:
Expand Down

0 comments on commit a2eddd6

Please sign in to comment.