
[BUG] SyncDataCollector Crashes with Resources Leak During Data Collection #2644

Closed
AlexandreBrown opened this issue Dec 11, 2024 · 20 comments
Labels: bug (Something isn't working)

@AlexandreBrown (Author)

Describe the bug

I've observed that my latest trainings crash after 180k steps with the following message:

micromamba/envs/dmc_env/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 4 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
Killed

To Reproduce

  1. Create the DMControlEnv as follows:
from torchrl.envs import DMControlEnv, TransformedEnv
from torchrl.envs.transforms import (
    CatFrames,
    DoubleToFloat,
    FrameSkipTransform,
    InitTracker,
    PermuteTransform,
    Resize,
    UnsqueezeTransform,
)

env = TransformedEnv(
    DMControlEnv(
        env_name="cartpole",
        task_name="swingup",
        from_pixels=True,
        pixels_only=True,
        device="cuda",
    )
)

env.append_transform(DoubleToFloat())

env.append_transform(
    FrameSkipTransform(frame_skip=2)
)
env.append_transform(InitTracker())

env.append_transform(
    PermuteTransform(
        dims=(-1, -2, -3), in_keys=["pixels"], out_keys=["pixels"]
    ),  # H W C -> C W H
)
env.append_transform(
    Resize(
        w=cfg["env"]["pixels"]["width"],
        h=cfg["env"]["pixels"]["height"],
        in_keys=["pixels"],
        out_keys=["pixels"],
    ),  # C W H -> C W' H'
)

env.append_transform(
    PermuteTransform(
        dims=(-3, -1, -2), in_keys=["pixels"], out_keys=["pixels"]
    ),  # C W' H' -> C H' W'
)

env.append_transform(
    UnsqueezeTransform(
        dim=-4,
        in_keys=["pixels"],
        out_keys=["pixels"],
    )  # C H' W' -> 1 C H' W'
)

env.append_transform(
    CatFrames(
        N=int(frame_stack),
        dim=-4,
        in_keys=["pixels"],
        out_keys=["pixels"],
    )  # 1 C H' W' -> N C H' W'
)

# Other transforms omitted for brevity
  2. Create the replay buffer as follows:
import torch
from omegaconf import DictConfig
from torchrl.data import ReplayBuffer
from torchrl.data import TensorDictReplayBuffer
from torchrl.data import LazyMemmapStorage
from torchrl.data import LazyTensorStorage
from torchrl.envs.transforms import Compose
from torchrl.envs.transforms import ExcludeTransform
from segdac.action_scaling.env_action_scaler import TanhEnvActionScaler
from segdac_dev.envs.transforms.unscale_image import UnscaleImage
from segdac_dev.envs.transforms.unscale_action import UnscaleAction
from hydra.utils import instantiate


def get_replay_buffer_data_saving_transforms(cfg: DictConfig) -> list:
    """
    These are transforms executed when saving data to the replay buffer.
    We want to exclude pixels_transformed because it is float32 (expensive to store); we store the uint8 RGB image instead.
    """
    transforms = [
        ExcludeTransform(
            "pixels_transformed", ("next", "pixels_transformed"), inverse=True
        ),
    ]

    for save_transform_config in (
        cfg.get("algo", {}).get("replay_buffer", {}).get("save_transforms", [])
    ):
        save_transform = instantiate(save_transform_config)
        assert save_transform.inverse is True

        if isinstance(save_transform, ExcludeTransform):
            arg_key = "_args_"
            next_args = []
            for key in save_transform_config.get(arg_key, []):
                next_args.append(("next", key))
            save_transform.excluded_keys = save_transform.excluded_keys + next_args
        elif isinstance(save_transform, UnscaleImage):
            arg_key = "in_keys_inv"
            next_in_keys_inv_args = []
            for key in save_transform_config.get(arg_key, []):
                next_in_keys_inv_args.append(("next", key))
            save_transform.in_keys_inv = (
                save_transform.in_keys_inv + next_in_keys_inv_args
            )
            arg_key = "out_keys_inv"
            next_out_keys_inv_args = []
            for key in save_transform_config.get(arg_key, []):
                next_out_keys_inv_args.append(("next", key))
            save_transform.out_keys_inv = (
                save_transform.out_keys_inv + next_out_keys_inv_args
            )

        transforms.append(save_transform)

    return transforms


def get_replay_buffer_sample_transforms(
    cfg: DictConfig, env_action_scaler: TanhEnvActionScaler
) -> list:
    """
    These are transforms executed when sampling data from the replay buffer.
    """
    transforms = []
    for sample_transform_config in (
        cfg.get("algo", {}).get("replay_buffer", {}).get("sample_transforms", [])
    ):
        sample_transform = instantiate(sample_transform_config)
        sample_transform.in_keys = sample_transform.in_keys + [
            ("next", key) for key in sample_transform.in_keys
        ]
        sample_transform.out_keys = sample_transform.out_keys + [
            ("next", key) for key in sample_transform.out_keys
        ]
        transforms.append(sample_transform)

    transforms.append(UnscaleAction(env_action_scaler))

    return transforms


def create_replay_buffer(cfg: DictConfig, env_action_scaler) -> ReplayBuffer:
    storage_device = torch.device(cfg["storage_device"]) # cpu in my test
    capacity = cfg["algo"]["replay_buffer"]["capacity"] # 1M in my test

    transforms = []
    transforms.extend(get_replay_buffer_data_saving_transforms(cfg))
    transforms.extend(get_replay_buffer_sample_transforms(cfg, env_action_scaler))
    transform = Compose(*transforms)

    storage_kwargs = {}
    storage_kwargs["max_size"] = capacity
    storage_kwargs["device"] = storage_device
    storage_dim = 1
    if cfg["env"]["num_workers"] > 1: # In my test num_workers = 1
        storage_dim += 1
    storage_kwargs["ndim"] = storage_dim

    if "cpu" in storage_device.type: # cpu was used in my test
        # LazyMemmapStorage is only supported on CPU
        replay_buffer = TensorDictReplayBuffer(
            storage=LazyMemmapStorage(**storage_kwargs),
            transform=transform,
            batch_size=int(cfg["training"]["batch_size"]), # 128 in my test
        )
    else:
        replay_buffer = TensorDictReplayBuffer(
            storage=LazyTensorStorage(**storage_kwargs),
            transform=transform,
            batch_size=int(cfg["training"]["batch_size"]),
        )

    return replay_buffer
  3. Create the SyncDataCollector:
SyncDataCollector(
    create_env_fn=env,
    policy=policy,
    total_frames=data_collector_cfg["total_frames"],  # 1M in my test
    max_frames_per_traj=max_frames_per_traj,  # 1000 in my test
    frames_per_batch=frames_per_batch,  # 1 in my test
    env_device=cfg["env"]["device"],  # cuda in my test
    storing_device=cfg["storage_device"],  # cpu in my test
    policy_device=cfg["policy_device"],  # cuda in my test
    exploration_type=exploration_type,  # RANDOM
    init_random_frames=data_collector_cfg.get("init_random_frames", 0),  # 1000 in my test
    postproc=None,
)
  4. Yield from the data collector (crash occurs at ~180k steps for me, 2 trainings in a row):
from tqdm import tqdm

num_iters = 1_000_000
for data in tqdm(
    self.train_data_collector, "Env Data Collection", total=num_iters
):
    env_step += self.train_frames_per_batch

    self.replay_buffer.extend(data)

Expected behavior

No crash

System info

  • CPU : 6
  • GPU : 1xA100
  • Disk : 100GB
  • RAM : 48GB
  • Headless : yes (cluster)
  • Python : 3.10.16
  • TorchRL : 0.6.0
  • Torch: 2.5.1
  • pip list :
Package                   Version                                                       Editable project location
------------------------- ------------------------------------------------------------- -------------------------------------
absl-py                   2.1.0
annotated-types           0.7.0
antlr4-python3-runtime    4.9.3
asttokens                 3.0.0
attrs                     24.2.0
av                        13.1.0
certifi                   2024.8.30
charset-normalizer        3.4.0
click                     8.1.7
clip                      1.0
cloudpickle               3.1.0
coloredlogs               15.0.1
comet-ml                  3.47.1
comm                      0.2.2
configobj                 5.0.9
contourpy                 1.3.1
cycler                    0.12.1
Cython                    3.0.11
debugpy                   1.8.9
decorator                 5.1.1
diffusers                 0.31.0
dm_control                1.0.25
dm-env                    1.6
dm-tree                   0.1.8
docker-pycreds            0.4.0
dulwich                   0.22.6
efficientvit              0.0.0
einops                    0.8.0
etils                     1.11.0
everett                   3.1.0
exceptiongroup            1.2.2
executing                 2.1.0
filelock                  3.16.1
flatbuffers               24.3.25
fonttools                 4.55.2
fsspec                    2024.10.0
ftfy                      6.3.1
gitdb                     4.0.11
GitPython                 3.1.43
glfw                      2.8.0
huggingface-hub           0.26.2
humanfriendly             10.0
hydra-core                1.3.2
idna                      3.10
igraph                    0.11.8
imageio                   2.36.1
importlib_metadata        8.5.0
importlib_resources       6.4.5
ipdb                      0.13.13
ipykernel                 6.29.5
ipython                   8.30.0
ipywidgets                8.1.5
jedi                      0.19.2
Jinja2                    3.1.4
jsonschema                4.23.0
jsonschema-specifications 2024.10.1
jupyter_client            8.6.3
jupyter_core              5.7.2
jupyterlab_widgets        3.0.13
kiwisolver                1.4.7
labmaze                   1.0.6
lazy_loader               0.4
lightning-utilities       0.11.9
loguru                    0.7.2
lvis                      0.5.3
lxml                      5.3.0
markdown-it-py            3.0.0
MarkupSafe                3.0.2
matplotlib                3.9.3
matplotlib-inline         0.1.7
mdurl                     0.1.2
mpmath                    1.3.0
mujoco                    3.2.6
nest-asyncio              1.6.0
networkx                  3.4.2
numpy                     1.26.4
nvidia-cublas-cu12        12.4.5.8
nvidia-cuda-cupti-cu12    12.4.127
nvidia-cuda-nvrtc-cu12    12.4.127
nvidia-cuda-runtime-cu12  12.4.127
nvidia-cudnn-cu12         9.1.0.70
nvidia-cufft-cu12         11.2.1.3
nvidia-curand-cu12        10.3.5.147
nvidia-cusolver-cu12      11.6.1.9
nvidia-cusparse-cu12      12.3.1.170
nvidia-nccl-cu12          2.21.5
nvidia-nvjitlink-cu12     12.4.127
nvidia-nvtx-cu12          12.4.127
omegaconf                 2.3.0
onnx                      1.17.0
onnxruntime               1.20.1
onnxsim                   0.4.36
opencv-python             4.10.0.84
opencv-python-headless    4.10.0.84
orjson                    3.10.12
packaging                 24.2
pandas                    2.2.3
parso                     0.8.4
pexpect                   4.9.0
pillow                    11.0.0
pip                       24.3.1
platformdirs              4.3.6
prompt_toolkit            3.0.48
protobuf                  5.29.1
psutil                    6.1.0
ptyprocess                0.7.0
pure_eval                 0.2.3
py-cpuinfo                9.0.0
pycocotools               2.0.8
pydantic                  2.10.3
pydantic_core             2.27.1
Pygments                  2.18.0
PyOpenGL                  3.1.7
PyOpenGL-accelerate       3.1.7
pyparsing                 3.2.0
python-box                6.1.0
python-dateutil           2.9.0.post0
pytz                      2024.2
PyYAML                    6.0.2
pyzmq                     26.2.0
referencing               0.35.1
regex                     2024.11.6
requests                  2.32.3
requests-toolbelt         1.0.0
rich                      13.9.4
rpds-py                   0.22.3
ruamel.yaml               0.18.6
ruamel.yaml.clib          0.2.12
safetensors               0.4.5
scikit-image              0.24.0
scipy                     1.14.1
seaborn                   0.13.2
XXX                    0.0.1                                                       
XXX                0.0.1                                                   
segment_anything          1.0
semantic-version          2.10.0
sentry-sdk                2.19.2
setproctitle              1.3.4
setuptools                75.6.0
simplejson                3.19.3
six                       1.17.0
smmap                     5.0.1
stack-data                0.6.3
sympy                     1.13.1
tensordict                0.6.2
texttable                 1.7.0
tifffile                  2024.9.20
timm                      1.0.12
TinyNeuralNetwork         0.1.0.20241202154922+f79b0ccf02a92247c9cae4ac403c33917f8f6f6f
tokenizers                0.21.0
tomli                     2.2.1
torch                     2.5.1
torch-fidelity            0.3.0
torchaudio                2.5.1
torchmetrics              1.6.0
torchprofile              0.0.4
torchrl                   0.6.0
torchvision               0.20.1
tornado                   6.4.2
tqdm                      4.66.5
traitlets                 5.14.3
transformers              4.47.0
triton                    3.1.0
typing_extensions         4.12.2
tzdata                    2024.2
ultralytics               8.3.48
ultralytics-thop          2.0.13
urllib3                   2.2.3
wandb                     0.19.0
wcwidth                   0.2.13
wheel                     0.45.1
widgetsnbextension        4.0.13
wrapt                     1.17.0
wurlitzer                 3.1.1
zipp                      3.21.0
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)

output :

0.6.0 1.26.4 3.10.16 | packaged by conda-forge | (main, Dec  5 2024, 14:16:10) [GCC 13.3.0] linux

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
AlexandreBrown added the bug (Something isn't working) label on Dec 11, 2024
@AlexandreBrown (Author)

Maybe there is a common/root cause between #2614 and this issue

@AlexandreBrown (Author)

Maybe it's related to my disk storage being too small? I'm storing stacked frames (4, 3, 84, 84) in my replay buffer, which uses LazyMemmapStorage.
Could be related to #914

@yu-fz (Contributor) commented Dec 17, 2024

I have run into a similar problem before. When I was using TorchRL with IsaacLab, I would have training runs die midway through when using SyncDataCollector. I made a wrapper for SyncDataCollector that overrides the iterator() method to remove the CUDA memory management at the beginning, and that seemed to fix the problem.

https://github.com/isaac-sim/IsaacLab/pull/1178/files#diff-82f19e3b1196887a446d1932e2626119f009999373febb153e0fe60e422da9aa


from collections.abc import Iterator

from tensordict import TensorDictBase
from torchrl.collectors import SyncDataCollector


class SyncDataCollectorWrapper(SyncDataCollector):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def iterator(self) -> Iterator[TensorDictBase]:
        """Iterates through the DataCollector.

        Yields: TensorDictBase objects containing (chunks of) trajectories
        """
        # The portion of the code handling CUDA streams has been removed in this
        # inherited method, which caused CUDA memory allocation issues with
        # IsaacSim during env stepping.
        total_frames = self.total_frames
        # ... (the rest of the overridden iterator() body is unchanged from
        # upstream SyncDataCollector; see the linked IsaacLab PR for the full diff)
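
If helpful, a minimal drop-in usage sketch (hypothetical: it reuses the env, policy and replay_buffer objects built in the reproduction steps above, and the argument values mirror the collector constructed there):

# Hypothetical usage: the wrapper accepts the same keyword arguments as
# SyncDataCollector, so only the class name changes at the call site.
collector = SyncDataCollectorWrapper(
    create_env_fn=env,
    policy=policy,
    total_frames=1_000_000,
    frames_per_batch=1,
    env_device="cuda",
    policy_device="cuda",
    storing_device="cpu",
)

for batch in collector:
    replay_buffer.extend(batch)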

@vmoens (Contributor) commented Dec 18, 2024

Is this what you were running into @AlexandreBrown?
@yu-fz Do you think I should patch SyncDataCollector to allow users to turn off the stream? Another option would be to deactivate that whenever self.return_same_td=False (

if self.return_same_td:
    # This is used with multiprocessed collectors to use the buffers
    # stored in the tensordict.
    if events:
        for event in events:
            event.record()
            event.synchronize()
    yield tensordict_out

) - IIRC this is the only scenario where this is really needed and you won't be using multiprocessed collectors with Isaac (presumably)

@fyu-bdai

I think either would work! There shouldn't be a case where one would use multiprocessed collectors with IsaacSim.

@AlexandreBrown (Author)

I haven't tried the fix but I'm not opposed to having the option.

@tobiabir commented Jan 3, 2025

@yu-fz Do you think I should patch SyncDataCollector to allow users to turn off the stream?

This option would be great. We are using nvidia warp in our custom environment and had synchronisation problems because SyncDataCollector sets custom streams and warp uses the default stream. To resolve it we reset the warp stream to the torch stream at every step with warp.set_stream(warp.stream_from_torch()).
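
For reference, a minimal sketch of that workaround (assuming a CUDA machine with NVIDIA Warp installed; the two calls are the ones mentioned above, applied before launching Warp kernels in the env step):

import torch
import warp as wp

wp.init()

# SyncDataCollector may run the rollout on a non-default torch CUDA stream,
# so before launching Warp kernels we point Warp's current stream at the
# stream torch is currently using, keeping both libraries on the same stream.
with torch.cuda.stream(torch.cuda.Stream()):
    wp.set_stream(wp.stream_from_torch())
    # ... launch Warp kernels / step the simulation here ...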

@fyu-bdai

@vmoens Any updates on this?

@AlexandreBrown (Author) commented Jan 26, 2025

The crash occurred again today when training a policy on Maniskill3 (which uses PhysX just like Isaac, so this might be why I hit the same issue as @yu-fz).
Will try @yu-fz's solution.
From a user's perspective it seems odd that a SyncDataCollector crashes with multiprocessing errors.
I would really love this bug to be fixed because it makes training unsustainable (it keeps crashing before 1M steps).

@AlexandreBrown (Author) commented Jan 26, 2025

I can confirm that the bug can still occur even with @yu-fz 's SyncDataCollectorWrapper.

/home/mila/b/xxx/.conda/envs/maniskill3_env/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 4 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
Killed

Maybe the issue is LazyMemmapStorage; I will try again with LazyTensorStorage.

@AlexandreBrown (Author)

Update: It does not crash with LazyTensorStorage + the SyncDataCollector wrapper.

I will try with LazyTensorStorage + the official SyncDataCollector.

I suspect the issue happens when I run trainings on a VM with too little free disk space (required by LazyMemmapStorage).
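
For what it's worth, a back-of-envelope check of that hypothesis using the numbers from this issue (1M capacity, (4, 3, 84, 84) uint8 frame stacks, 100GB disk) and assuming both "pixels" and ("next", "pixels") end up memmapped:

import numpy as np

# Rough disk footprint of the pixel data alone in LazyMemmapStorage.
bytes_per_stack = int(np.prod((4, 3, 84, 84)))  # uint8 -> 1 byte per element
capacity = 1_000_000
total_gb = bytes_per_stack * 2 * capacity / 1e9  # x2: "pixels" + ("next", "pixels")
print(f"{total_gb:.0f} GB")  # ~169 GB, well above the 100GB disk listed above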

@vmoens (Contributor) commented Jan 28, 2025

Update: It does not crash with LazyTensorStorage + the SyncDataCollector wrapper.

I will try with LazyTensorStorage + the official SyncDataCollector.

I suspect the issue happens when I run trainings on a VM with too little free disk space (required by LazyMemmapStorage).

Thanks so much for looking into it @AlexandreBrown! I guess you're seeing that only when using an RB within the collector then?

I will add an arg in SyncDataCollector to bypass the CUDA syncs, that should be sufficient to avoid subclassing, right @yu-fz and @fyu-bdai?

See #2727 for a solution
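
Roughly, the call site would then look something like this (sketch only, assuming the argument is exposed as no_cuda_sync as in #2727; the other arguments mirror the collector from the reproduction steps):

collector = SyncDataCollector(
    create_env_fn=env,
    policy=policy,
    total_frames=1_000_000,
    frames_per_batch=1,
    env_device="cuda",
    policy_device="cuda",
    storing_device="cpu",
    no_cuda_sync=True,  # skip the collector's internal CUDA stream syncs
)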

@AlexandreBrown (Author) commented Jan 28, 2025

Update: When using the official SyncDataCollector + LazyMemmapStorage on a machine with enough storage (3TB+), I get an error fairly early, during the first evaluation loop where I call env.rollout:

  File "/home/user/miniconda3/envs/maniskill3_env/lib/python3.10/site-packages/tensordict/base.py", line 3289, in cpu
    return self.to("cpu", **kwargs)
  File "/home/user/miniconda3/envs/maniskill3_env/lib/python3.10/site-packages/tensordict/base.py", line 10642, in to
    self._sync_all()
  File "/home/user/miniconda3/envs/maniskill3_env/lib/python3.10/site-packages/tensordict/base.py", line 10739, in _sync_all
    torch.cuda.synchronize()
  File "/home/user/miniconda3/envs/maniskill3_env/lib/python3.10/site-packages/torch/cuda/__init__.py", line 954, in synchronize
    return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

And:

[2025-01-28 07:53:15.039] [SAPIEN] [critical] Mem free failed with error code 700!

[2025-01-28 07:53:15.039] [SAPIEN] [critical] /buildAgent/work/eb2f45c4acc808a0/physx/source/gpucommon/src/PxgCudaMemoryAllocator.cpp
[2025-01-28 07:53:15.039] [SAPIEN] [critical] /buildAgent/work/eb2f45c4acc808a0/physx/source/gpucommon/src/PxgCudaMemoryAllocator.cpp
[2025-01-28 07:53:15.039] [SAPIEN] [critical] /buildAgent/work/eb2f45c4acc808a0/physx/source/gpucommon/src/PxgCudaMemoryAllocator.cpp
[2025-01-28 07:53:15.039] [SAPIEN] [critical] /buildAgent/work/eb2f45c4acc808a0/physx/source/gpucommon/src/PxgCudaMemoryAllocator.cpp
[2025-01-28 07:53:15.040] [SAPIEN] [critical] /buildAgent/work/eb2f45c4acc808a0/physx/source/gpucommon/src/PxgCudaMemoryAllocator.cpp
CUDA error at /__w/SAPIEN/SAPIEN/3rd_party/sapien-vulkan-2/src/core/buffer.cpp 103: an illegal memory access was encountered

This suggests that the SyncDataCollector CUDA memory management might cause issues in environments like Isaac Lab and/or Maniskill3.

@vmoens (Contributor) commented Jan 28, 2025

I updated the docstrings of no_cuda_sync with this, thanks for the suggestion.
Now that I'm thinking about it, we should not only prevent syncs in the datacollector but also in tensordict!

So here's the gist:

  • when you do td.to("cpu") we want to do that fast. So what we do is that we use non_blocking=True -- but that's unsafe because there's no sync and your data could be corrupted.
  • So what we do is that tensordict detects whether you have cuda tensors that need to be transferred and if so, we sync when we're done. Problem is: Isaac doesn't like that.
  • Solution: we should use non_blocking=False but that means slow data transfer :(
  • For cpu -> cuda we should be good though bc CUDA is smart enough to sync itself when needed (provided your tensor is NOT in pinned memory, see this).

So I'm going to fall back on asking you guys: If no_cuda_sync is on like I implemented in #2727, do we want also to use non_blocking=False for D2H transfers and pay the price of sync transfers to CPU? (that also means that you will have a sync at each step in the collector which could drastically slow down your collection!)
Another option, also unsafe, is to avoid syncing CUDA and do the D2H transfer async with non_blocking=True. That will be very fast but if you're reading your cpu tensor in the loop you may have corrupted data.
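
To make the hazard concrete, here is a small illustration (plain torch, not TorchRL internals) of why the async device-to-host copy needs the sync that tensordict's _sync_all() performs before the CPU tensor is read:

import torch

# An async GPU -> CPU copy is only enqueued; reading the destination before
# the stream has finished can observe stale or partially-written data.
gpu_tensor = torch.randn(1 << 20, device="cuda")
cpu_tensor = gpu_tensor.to("cpu", non_blocking=True)  # fast, possibly not done yet

torch.cuda.synchronize()      # the sync that makes the copy safe to consume
checksum = cpu_tensor.sum()   # only now is the CPU data guaranteed valid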

@AlexandreBrown @yu-fz @fyu-bdai @tobiabir

@fyu-bdai

Update: It does not crash with LazyTensorStorage + the SyncDataCollector wrapper.
I will try with LazyTensorStorage + the official SyncDataCollector.
I suspect the issue happens when I run trainings on a VM with too little free disk space (required by LazyMemmapStorage).

Thanks so much for looking into it @AlexandreBrown! I guess you're seeing that only when using an RB within the collector then?

I will add an arg in SyncDataCollector to bypass the CUDA syncs, that should be sufficient to avoid subclassing, right @yu-fz and @fyu-bdai?

See #2727 for a solution

Yup! That will be sufficient.

@AlexandreBrown (Author) commented Jan 30, 2025

I updated the docstrings of no_cuda_sync with this, thanks for the suggestion. Now that I'm thinking about it, we should not only prevent syncs in the datacollector but also in tensordict!

So here's the gist:

  • when you do td.to("cpu") we want to do that fast. So what we do is that we use non_blocking=True -- but that's unsafe because there's no sync and your data could be corrupted.

This is pretty major, I encountered this today during evaluation data collection.
Using non_blocking=False did not fix the issue (maybe TorchRL does some .to("cpu") under the hood that I'm not able to control).

for _ in tqdm(range(nb_iters), "Evaluation"):
    rollouts = self.eval_env.rollout(
        max_steps=self.env_max_frames_per_traj,
        policy=policy,
        auto_reset=False,
        auto_cast_to_device=True,
        tensordict=tensordict,
    ).cpu(non_blocking=False)

So I'm going to fall back on asking you guys: If no_cuda_sync is on like I implemented in #2727, do we want also to use non_blocking=False for D2H transfers and pay the price of sync transfers to CPU? (that also means that you will have a sync at each step in the collector which could drastically slow down your collection!) Another option, also unsafe, is to avoid syncing CUDA and do the D2H transfer async with non_blocking=True. That will be very fast but if you're reading your cpu tensor in the loop you may have corrupted data.

It seems like a sync is unavoidable if skipping it leads to corrupted data that crashes training.
I would rather lose a few fps than have a training crash.

This makes CUDA support for Maniskill impossible for me right now...

@vmoens (Contributor) commented Jan 30, 2025

Sorry @AlexandreBrown I don't really get it: does the new feature introduce more bugs? I don't understand what the problem is precisely

@AlexandreBrown (Author)

Hi @vmoens, no I don't think the new feature introduces more bugs, or at least that's not what I encountered.
Your fix and @yu-fz's data collector wrapper work and prevent issues when using a data collector.
The issue I am encountering is when doing env.rollout (no data collector/replay buffer in play).
Ex:

for _ in tqdm(range(nb_iters), "Evaluation"):
    rollouts = self.eval_env.rollout(
        max_steps=self.env_max_frames_per_traj,
        policy=policy,
        auto_reset=False,
        auto_cast_to_device=False,
        tensordict=tensordict,
    ).to(device="cpu", non_blocking=False)

Stacktrace:

Traceback (most recent call last):
  File "/home/user/miniconda3/envs/maniskill3_env/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/user/miniconda3/envs/maniskill3_env/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/user/.vscode/extensions/ms-python.debugpy-2024.14.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 71, in <module>
    cli.main()
  File "/home/user/.vscode/extensions/ms-python.debugpy-2024.14.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 501, in main
    run()
  File "/home/user/.vscode/extensions/ms-python.debugpy-2024.14.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 351, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/user/.vscode/extensions/ms-python.debugpy-2024.14.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 310, in run_path
    return _run_module_code(code, init_globals, run_name, pkg_name=pkg_name, script_name=fname)
  File "/home/user/.vscode/extensions/ms-python.debugpy-2024.14.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 127, in _run_module_code
    _run_code(code, mod_globals, init_globals, mod_name, mod_spec, pkg_name, script_name)
  File "/home/user/.vscode/extensions/ms-python.debugpy-2024.14.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 118, in _run_code
    exec(code, run_globals)
  File "scripts/train_rl.py", line 118, in <module>
    main()
  File "/home/user/miniconda3/envs/maniskill3_env/lib/python3.10/site-packages/hydra/main.py", line 94, in decorated_main
    _run_hydra(
  File "/home/user/miniconda3/envs/maniskill3_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
    _run_app(
  File "/home/user/miniconda3/envs/maniskill3_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 457, in _run_app
    run_and_report(
  File "/home/user/miniconda3/envs/maniskill3_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
    raise ex
  File "/home/user/miniconda3/envs/maniskill3_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
    return func()
  File "/home/user/miniconda3/envs/maniskill3_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
    lambda: hydra.run(
  File "/home/user/miniconda3/envs/maniskill3_env/lib/python3.10/site-packages/hydra/_internal/hydra.py", line 132, in run
    _ = ret.return_value
  File "/home/user/miniconda3/envs/maniskill3_env/lib/python3.10/site-packages/hydra/core/utils.py", line 260, in return_value
    raise self._return_value
  File "/home/user/miniconda3/envs/maniskill3_env/lib/python3.10/site-packages/hydra/core/utils.py", line 186, in run_job
    ret.return_value = task_function(task_cfg)
  File "scripts/train_rl.py", line 107, in main
    trainer.train()
  File "/home/user/Documents/SegDAC/segdac_dev/src/segdac_dev/trainers/rl_trainer.py", line 90, in train
    eval_metrics = self.evaluator.evaluate(
  File "/home/user/miniconda3/envs/maniskill3_env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/Documents/SegDAC/segdac_dev/src/segdac_dev/evaluation/rl_evaluator.py", line 147, in evaluate
    eval_metrics = self.log_eval_metrics(agent, env_step)
  File "/home/user/Documents/SegDAC/segdac_dev/src/segdac_dev/evaluation/rl_evaluator.py", line 158, in log_eval_metrics
    eval_metrics = self.gather_eval_rollouts_metrics(policy)
  File "/home/user/Documents/SegDAC/segdac_dev/src/segdac_dev/evaluation/rl_evaluator.py", line 171, in gather_eval_rollouts_metrics
    rollouts = self.eval_env.rollout(
  File "/home/user/miniconda3/envs/maniskill3_env/lib/python3.10/site-packages/tensordict/base.py", line 10623, in to
    tensors = [to(t) for t in tensors]
  File "/home/user/miniconda3/envs/maniskill3_env/lib/python3.10/site-packages/tensordict/base.py", line 10623, in <listcomp>
    tensors = [to(t) for t in tensors]
  File "/home/user/miniconda3/envs/maniskill3_env/lib/python3.10/site-packages/tensordict/base.py", line 10595, in to
    return tensor.to(
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

[2025-01-30 19:23:26.032] [SAPIEN] [critical] Mem free failed with error code 700!

[2025-01-30 19:23:26.032] [SAPIEN] [critical] /buildAgent/work/eb2f45c4acc808a0/physx/source/gpucommon/src/PxgCudaMemoryAllocator.cpp
[2025-01-30 19:23:26.032] [SAPIEN] [critical] /buildAgent/work/eb2f45c4acc808a0/physx/source/gpucommon/src/PxgCudaMemoryAllocator.cpp
[2025-01-30 19:23:26.032] [SAPIEN] [critical] /buildAgent/work/eb2f45c4acc808a0/physx/source/gpucommon/src/PxgCudaMemoryAllocator.cpp
[2025-01-30 19:23:26.032] [SAPIEN] [critical] /buildAgent/work/eb2f45c4acc808a0/physx/source/gpucommon/src/PxgCudaMemoryAllocator.cpp
[2025-01-30 19:23:26.033] [SAPIEN] [critical] /buildAgent/work/eb2f45c4acc808a0/physx/source/gpucommon/src/PxgCudaMemoryAllocator.cpp
CUDA error at /__w/SAPIEN/SAPIEN/3rd_party/sapien-vulkan-2/src/core/buffer.cpp 103: an illegal memory access was encountered

This occurs when the Maniskill env is using cuda.

@vmoens (Contributor) commented Jan 31, 2025

Ah ok I thought it was the same issue. Your policy is on cuda?
Maybe open another issue to keep track since this one is solved?

@AlexandreBrown (Author)

@vmoens You are right, I should open another issue, thank you
