Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Feb 16, 2025
1 parent ffed37c commit f115e78
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 5 deletions.
2 changes: 1 addition & 1 deletion d3rlpy/algos/transformer/tacr.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def inner_create_impl(
device=self._device,
enable_ddp=self._enable_ddp,
)
critic_optim = self._config.optim_factory.create(
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
compiled=self.compiled,
Expand Down
2 changes: 2 additions & 0 deletions d3rlpy/models/torch/q_functions/ensemble_q_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def compute_error(
target: torch.Tensor,
terminals: torch.Tensor,
gamma: Union[float, torch.Tensor] = 0.99,
masks: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return compute_ensemble_q_function_error(
forwarders=self._forwarders,
Expand All @@ -193,6 +194,7 @@ def compute_error(
target=target,
terminals=terminals,
gamma=gamma,
masks=masks,
)

def compute_target(
Expand Down
4 changes: 2 additions & 2 deletions d3rlpy/torch_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def to_transition_batch(self) -> tuple[TorchMiniBatch, torch.Tensor]:
]
actions = self.actions[:, :-1].reshape(-1, *self.actions.shape[2:])
rewards = self.rewards[:, :-1].reshape(-1, 1)
terminals = self.terminals[:, 1:].reshape(-1, 1)
terminals = self.terminals[:, :-1].reshape(-1, 1)
next_actions = self.actions[:, 1:].reshape(-1, *self.actions.shape[2:])
returns_to_go = self.returns_to_go[:, :-1].reshape(-1, 1)
intervals = torch.ones_like(rewards)
Expand Down Expand Up @@ -583,7 +583,7 @@ def __call__(self, batch: BatchT_contra) -> RetT_co:
self._out = self._func(self._inpt)
if self._step >= self._warmup_steps: # reuse cuda graph
assert self._inpt
assert self._out
assert self._out is not None
assert self._graph
with torch.no_grad():
self._inpt.copy_(batch) # type: ignore
Expand Down
32 changes: 32 additions & 0 deletions tests/algos/transformer/test_tacr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Optional

import pytest

from d3rlpy.algos import TACRConfig
from d3rlpy.types import Shape

from ...models.torch.model_test import DummyEncoderFactory
from ...testing_utils import create_scaler_tuple
from .algo_test import algo_tester


@pytest.mark.parametrize(
    "observation_shape", [(100,), (4, 8, 8), ((100,), (200,))]
)
@pytest.mark.parametrize("scalers", [None, "min_max"])
def test_tacr(observation_shape: Shape, scalers: Optional[str]) -> None:
    """Build a TACR algorithm from its config and run the shared algo tester.

    The test is parametrized over three observation shapes (a 1-D shape,
    a 3-D shape, and a tuple of two 1-D shapes) and over ``scalers`` being
    either ``None`` or ``"min_max"``.

    Args:
        observation_shape: Observation shape forwarded to both
            ``create_scaler_tuple`` and ``algo_tester``.
        scalers: Scaler selector forwarded to ``create_scaler_tuple``.
    """
    # The helper yields the (observation, action, reward) scalers in order.
    scaler_triple = create_scaler_tuple(scalers, observation_shape)
    obs_scaler, act_scaler, rew_scaler = scaler_triple

    # Dummy encoder factories are used for both the policy-side and the
    # critic-side encoders of the config.
    algo = TACRConfig(
        encoder_factory=DummyEncoderFactory(),
        critic_encoder_factory=DummyEncoderFactory(),
        observation_scaler=obs_scaler,
        action_scaler=act_scaler,
        reward_scaler=rew_scaler,
    ).create()

    algo_tester(
        algo,  # type: ignore
        observation_shape,
    )
28 changes: 26 additions & 2 deletions tests/models/torch/q_functions/test_ensemble_q_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def test_reduce_quantile_ensemble(
@pytest.mark.parametrize("q_func_factory", ["mean", "qr", "iqn"])
@pytest.mark.parametrize("n_quantiles", [200])
@pytest.mark.parametrize("embed_size", [64])
@pytest.mark.parametrize("use_masks", [False, True])
def test_discrete_ensemble_q_function_forwarder(
observation_shape: Shape,
action_size: int,
Expand All @@ -91,6 +92,7 @@ def test_discrete_ensemble_q_function_forwarder(
q_func_factory: str,
n_quantiles: int,
embed_size: int,
use_masks: bool,
) -> None:
forwarders: list[DiscreteQFunctionForwarder] = []
for _ in range(ensemble_size):
Expand Down Expand Up @@ -172,23 +174,33 @@ def test_discrete_ensemble_q_function_forwarder(
q_tp1 = torch.rand(batch_size, 1)
else:
q_tp1 = torch.rand(batch_size, n_quantiles)
masks = (
torch.randint(2, size=(batch_size, 1), dtype=torch.float32)
if use_masks
else None
)
ref_td_sum = 0.0
for forwarder in forwarders:
ref_td_sum += forwarder.compute_error(
ind_loss = forwarder.compute_error(
observations=obs_t,
actions=act_t,
rewards=rew_tp1,
target=q_tp1,
terminals=ter_tp1,
gamma=gamma,
reduction="none",
)
if masks is not None:
ind_loss *= masks
ref_td_sum += ind_loss.mean()
loss = ensemble_forwarder.compute_error(
observations=obs_t,
actions=act_t,
rewards=rew_tp1,
target=q_tp1,
terminals=ter_tp1,
gamma=gamma,
masks=masks,
)
if q_func_factory != "iqn":
assert torch.allclose(ref_td_sum, loss)
Expand All @@ -202,6 +214,7 @@ def test_discrete_ensemble_q_function_forwarder(
@pytest.mark.parametrize("n_quantiles", [200])
@pytest.mark.parametrize("q_func_factory", ["mean", "qr", "iqn"])
@pytest.mark.parametrize("embed_size", [64])
@pytest.mark.parametrize("use_masks", [False, True])
def test_ensemble_continuous_q_function(
observation_shape: Shape,
action_size: int,
Expand All @@ -211,6 +224,7 @@ def test_ensemble_continuous_q_function(
q_func_factory: str,
n_quantiles: int,
embed_size: int,
use_masks: bool,
) -> None:
forwarders: list[ContinuousQFunctionForwarder] = []
for _ in range(ensemble_size):
Expand Down Expand Up @@ -282,23 +296,33 @@ def test_ensemble_continuous_q_function(
q_tp1 = torch.rand(batch_size, 1)
else:
q_tp1 = torch.rand(batch_size, n_quantiles)
masks = (
torch.randint(2, size=(batch_size, 1), dtype=torch.float32)
if use_masks
else None
)
ref_td_sum = 0.0
for forwarder in forwarders:
ref_td_sum += forwarder.compute_error(
ind_loss = forwarder.compute_error(
observations=obs_t,
actions=act_t,
rewards=rew_tp1,
target=q_tp1,
terminals=ter_tp1,
gamma=gamma,
reduction="none",
)
if masks is not None:
ind_loss = ind_loss * masks
ref_td_sum += ind_loss.mean()
loss = ensemble_forwarder.compute_error(
observations=obs_t,
actions=act_t,
rewards=rew_tp1,
target=q_tp1,
terminals=ter_tp1,
gamma=gamma,
masks=masks,
)
if q_func_factory != "iqn":
assert torch.allclose(ref_td_sum, loss)
49 changes: 49 additions & 0 deletions tests/test_torch_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,55 @@ def test_torch_trajectory_mini_batch(

assert np.all(torch_batch.terminals.numpy() == batch.terminals)

# test to_transition_batch
transition_batch_size = batch_size * (length - 1)
transition_batch, mask = torch_batch.to_transition_batch()
assert isinstance(transition_batch.observations, torch.Tensor)
assert isinstance(transition_batch.next_observations, torch.Tensor)
assert transition_batch.observations.shape == (
transition_batch_size,
*observation_shape,
)
assert transition_batch.next_observations.shape == (
transition_batch_size,
*observation_shape,
)
assert transition_batch.actions.shape == (
transition_batch_size,
action_size,
)
assert transition_batch.rewards.shape == (transition_batch_size, 1)
assert transition_batch.terminals.shape == (transition_batch_size, 1)
assert transition_batch.returns_to_go.shape == (transition_batch_size, 1)
assert mask.shape == (transition_batch_size, 1)
assert torch.all(
transition_batch.observations
== torch_batch.observations[:, :-1].reshape(-1, *observation_shape)
)
assert torch.all(
transition_batch.next_observations
== torch_batch.observations[:, 1:].reshape(-1, *observation_shape)
)
assert torch.all(
transition_batch.actions
== torch_batch.actions[:, :-1].reshape(-1, action_size)
)
assert torch.all(
transition_batch.next_actions
== torch_batch.actions[:, 1:].reshape(-1, action_size)
)
assert torch.all(
transition_batch.rewards == torch_batch.rewards[:, :-1].reshape(-1, 1)
)
assert torch.all(
transition_batch.terminals
== torch_batch.terminals[:, :-1].reshape(-1, 1)
)
assert torch.all(
transition_batch.returns_to_go
== torch_batch.returns_to_go[:, :-1].reshape(-1, 1)
)

torch_batch2 = TorchTrajectoryMiniBatch(
observations=torch.zeros_like(torch_batch.observations),
actions=torch.zeros_like(torch_batch.actions),
Expand Down

0 comments on commit f115e78

Please sign in to comment.