Skip to content

Commit

Permalink
update dataset doc
Browse files Browse the repository at this point in the history
  • Loading branch information
phython96 committed Dec 12, 2024
1 parent 3d1b82b commit 850692c
Show file tree
Hide file tree
Showing 8 changed files with 324 additions and 143 deletions.
3 changes: 1 addition & 2 deletions docs/source/data/dataset-event.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ Using Lightning to simplify the data loading process (especially for distributed

```python
import lightning as L
from tqdm import tqdm
from minestudio.data import MineDataModule

fabric = L.Fabric(accelerator="cuda", devices=2, strategy="ddp")
Expand All @@ -153,7 +152,7 @@ data_module.setup()
train_loader = data_module.train_dataloader()
train_loader = fabric.setup_dataloaders(train_loader, use_distributed_sampler=True)
rank = fabric.local_rank
for idx, batch in enumerate(tqdm(train_loader, disable=True)):
for idx, batch in enumerate(train_loader):
print(
f"{rank = } \t" + "\t".join(
[f"{a.shape} {b}" for a, b in zip(batch['image'], batch['text'])]
Expand Down
3 changes: 1 addition & 2 deletions docs/source/data/dataset-raw.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ We can use lightning fabric to simplify the distributed data loading (using buil

```python
import lightning as L
from tqdm import tqdm
from minestudio.data import MineDataModule


Expand Down Expand Up @@ -198,7 +197,7 @@ data_module.setup()
train_loader = data_module.train_dataloader()
train_loader = fabric.setup_dataloaders(train_loader, use_distributed_sampler=False)
rank = fabric.local_rank
for idx, batch in enumerate(tqdm(train_loader, disable=True)):
for idx, batch in enumerate(train_loader):
print(
f"{rank = } \t" + "\t".join(
[f"{a[-20:]} {b}" for a, b in zip(batch['episode'], batch['progress'])]
Expand Down
8 changes: 2 additions & 6 deletions docs/source/data/index.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<!--
* @Date: 2024-11-29 08:08:34
* @LastEditors: caishaofei caishaofei@stu.pku.edu.cn
* @LastEditTime: 2024-12-12 07:18:49
* @LastEditTime: 2024-12-12 09:48:08
* @FilePath: /MineStudio/docs/source/data/index.md
-->
# Data
Expand All @@ -13,6 +13,7 @@ We design a trajectory structure for storing Minecraft data. Based on this data
dataset-raw
dataset-event
visualization
```

## Quick Start
Expand Down Expand Up @@ -116,8 +117,3 @@ An video example generated by our tool to show video and the corresponding segme
```{youtube} QYBUxus3esI
```
````


### Build Dataset from Your Collected Trajectories


107 changes: 107 additions & 0 deletions docs/source/data/visualization.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
<!--
* @Date: 2024-12-12 09:18:35
* @LastEditors: caishaofei caishaofei@stu.pku.edu.cn
* @LastEditTime: 2024-12-12 11:40:10
* @FilePath: /MineStudio/docs/source/data/visualization.md
-->

# Visualization Script

We provide a visualization script that lets users check whether the configured dataloader behaves as expected. It is useful for debugging and for verifying the correctness of the data.

## Visualize Dataloader

Here are the arguments of the `visualize_dataloader` function:

| Arguments | Description |
| --- | --- |
| `dataloader` | PyTorch dataloader |
| `num_samples` | Number of batches to visualize |
| `resolution` | Resolution of the video |
| `legend` | Whether to overlay action, contractor info, and segment info on the video |
| `save_fps` | FPS of the saved video |
| `output_dir` | Output directory for the saved video |

## Visualize Continuous Batches

When visualizing continuous video frames, set `episode_continuous_batch=True` and `batch_size=1` in the `MineDataModule` configuration.

```python
import lightning as L
from tqdm import tqdm
from minestudio.data import MineDataModule
from minestudio.data.minecraft.utils import visualize_dataloader

data_module = MineDataModule(
data_params=dict(
mode='raw',
dataset_dirs=[
'/nfs-shared-2/data/contractors/dataset_10xx',
],
frame_width=224,
frame_height=224,
win_len=128,
split_ratio=0.8,
),
batch_size=1, # set to 1 for visualizing continuous video frames
num_workers=2,
prefetch_factor=4,
shuffle_episodes=True,
episode_continuous_batch=True, # `True` for visualizing continuous video frames
)
data_module.setup()
dataloader = data_module.val_dataloader()

visualize_dataloader(
dataloader,
num_samples=5,
resolution=(640, 360),
legend=True, # print action, contractor info, and segment info ... in the video
save_fps=30,
output_dir="./"
)
```

Here is the example video:


## Visualize Batches with Special Events

When visualizing video frames with special events, set `event_regex` in the `MineDataModule` configuration.

```python
import lightning as L
from tqdm import tqdm
from minestudio.data import MineDataModule
from minestudio.data.minecraft.utils import visualize_dataloader

data_module = MineDataModule(
data_params=dict(
mode='event',
dataset_dirs=[
'/nfs-shared-2/data/contractors/dataset_10xx',
],
frame_width=224,
frame_height=224,
win_len=128,
split_ratio=0.8,
shuffle_episodes=True,
event_regex='minecraft.mine_block:.*diamond.*',
),
batch_size=2,
)
data_module.setup()
dataloader = data_module.val_dataloader()

visualize_dataloader(
dataloader,
num_samples=5,
resolution=(640, 360),
legend=True, # print action, contractor info, and segment info ... in the video
save_fps=30,
output_dir="./"
)
```

Here is the example video:

4 changes: 2 additions & 2 deletions minestudio/data/minecraft/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
'''
Date: 2024-11-10 10:25:38
LastEditors: caishaofei caishaofei@stu.pku.edu.cn
LastEditTime: 2024-12-04 15:53:20
LastEditTime: 2024-12-12 11:34:15
FilePath: /MineStudio/minestudio/data/minecraft/dataset.py
'''
import torch
Expand Down Expand Up @@ -42,7 +42,7 @@ def __init__(
**kwargs,
) -> None:

super().__init__(**kwargs)
super().__init__()
self.mode = mode
self.split = split
self.common_kwargs = dict(
Expand Down
132 changes: 6 additions & 126 deletions minestudio/data/minecraft/demo.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
'''
Date: 2024-11-10 11:01:51
LastEditors: caishaofei caishaofei@stu.pku.edu.cn
LastEditTime: 2024-11-28 16:24:11
LastEditTime: 2024-12-12 10:55:34
FilePath: /MineStudio/minestudio/data/minecraft/demo.py
'''
import os
Expand All @@ -24,129 +24,9 @@
from minestudio.data.minecraft.part_event import EventDataset
from minestudio.data.minecraft.part_raw import RawDataset
from minestudio.data.minecraft.dataset import MinecraftDataset
from minestudio.data.minecraft.utils import MineDistributedBatchSampler, write_video, batchify

def write_to_frame(frame: np.ndarray, txt: str, row: int, col: int, color=(255, 0, 0)) -> None:
    """Draw ``txt`` onto ``frame`` with ``cv2.putText``.

    ``col`` is used as the x coordinate and ``row`` as the y coordinate of the
    text origin handed to OpenCV. The text is rendered with
    ``FONT_HERSHEY_SIMPLEX`` at scale 2.0 and thickness 1 in ``color``.
    """
    origin = (col, row)  # OpenCV expects (x, y), hence the swap
    font_scale = 2.0
    line_thickness = 1
    cv2.putText(
        frame,
        txt,
        origin,
        cv2.FONT_HERSHEY_SIMPLEX,
        font_scale,
        color,
        line_thickness,
    )

def dump_trajectories(
    dataloader,
    num_samples: int = 1,
    save_fps: int = 20,
    **kwargs
) -> None:
    """Dump batches from ``dataloader`` to disk as videos plus pickled actions.

    For every sample of every consumed batch, the frames in ``data['img']``
    are written to ``./traj_dir/videos/<id>.mp4`` and the matching entry of
    ``data['action']`` is pickled to ``./traj_dir/actions/<id>.pkl``, where
    ``<id>`` is a random 11-character alphanumeric string shared by both files.

    Args:
        dataloader: Iterable of batches; each batch is indexed with the keys
            ``'img'`` and ``'action'``.
        num_samples: Number of batches to dump.
        save_fps: Frame rate forwarded to ``write_video``.
        **kwargs: Accepted for call-site compatibility; not used.
    """

    def un_batchify_actions(actions_in: Dict[str, torch.Tensor]) -> List[Dict]:
        """Split a dict of batched action tensors into one dict per sample."""
        # The batch size is read from the 'attack' entry.
        batch_size = len(actions_in['attack'])
        return [
            {k: v[bidx].numpy() for k, v in actions_in.items()}
            for bidx in range(batch_size)
        ]

    traj_dir = Path("./traj_dir")
    video_dir = traj_dir / "videos"
    action_dir = traj_dir / "actions"
    video_dir.mkdir(parents=True, exist_ok=True)
    action_dir.mkdir(parents=True, exist_ok=True)

    for idx, data in enumerate(tqdm(dataloader)):
        # `>=` so that exactly `num_samples` batches are dumped; the previous
        # `>` comparison let one extra batch through.
        if idx >= num_samples:
            break
        image = data['img']
        actions = un_batchify_actions(data['action'])
        batch_size = image.shape[0]
        for i in range(batch_size):
            # Random id ties the video file to its action file.
            vid = ''.join(random.choices(string.ascii_letters + string.digits, k=11))
            write_video(
                file_name=str(video_dir / f"{vid}.mp4"),
                frames=image[i].numpy().astype(np.uint8),
                fps=save_fps,  # previously accepted but silently ignored
            )
            with open(action_dir / f"{vid}.pkl", 'wb') as f:
                pickle.dump(actions[i], f)

# Render batches from `dataloader` into a single annotated video file.
# NOTE(review): indentation was lost in this view, so the exact nesting of the
# statements below cannot be confirmed from here; each comment describes the
# statement it precedes in isolation.
# `temporal_mask` and `**kwargs` are accepted but never referenced in the body.
def read_dataloader(
dataloader,
num_samples: int = 1,
resolution: Tuple[int, int] = (320, 180),
legend: bool = False,
temporal_mask: bool = False,
save_fps: int = 20,
**kwargs,
) -> None:
# Accumulates every annotated frame across all consumed batches.
frames = []
for idx, data in enumerate(tqdm(dataloader)):
# continue
# NOTE(review): `>` (not `>=`) lets batches 0..num_samples through,
# i.e. num_samples + 1 batches are consumed.
if idx > num_samples:
break
# Each batch is read as a mapping: 'env_action', 'image' and 'text' are
# required; 'env_prev_action', 'segment', 'timestamp' and
# 'contractor_info' are optional.
action = data['env_action']
prev_action = data.get("env_prev_action", None)
image = data['image'].numpy()
text = data['text']

# Colour used for the legend text drawn when `legend` is set.
color = (255, 0, 0)
# `image` and `text` are zipped together, then each `tframes` is iterated
# frame by frame; `bidx` / `tidx` are the two resulting indices.
for bidx, (tframes, txt) in enumerate(zip(image, text)):
cache_frames = []
for tidx, frame in enumerate(tframes):
if 'segment' in data:
# Palette indexed by `obj_id` below (15 entries).
COLORS = [
(255, 0, 0), (0, 255, 0), (0, 0, 255),
(255, 255, 0), (255, 0, 255), (0, 255, 255),
(255, 255, 255), (0, 0, 0), (128, 128, 128),
(128, 0, 0), (128, 128, 0), (0, 128, 0),
(128, 0, 128), (0, 128, 128), (0, 0, 128),
]
obj_id = data['segment']['obj_id'][bidx][tidx].item()
# An obj_id of -1 skips the mask overlay.
if obj_id != -1:
segment_mask = data['segment']['obj_mask'][bidx][tidx]
if isinstance(segment_mask, torch.Tensor):
segment_mask = segment_mask.numpy()
# Colourise the mask with the palette entry, then blend it onto the
# frame at weight 0.5.
colors = np.array(COLORS[obj_id]).reshape(1, 1, 3)
segment_mask = (segment_mask[..., None] * colors).astype(np.uint8)
segment_mask = segment_mask[:, :, ::-1] # bgr -> rgb
frame = cv2.addWeighted(frame, 1.0, segment_mask, 0.5, 0.0)

if 'timestamp' in data:
timestamp = data['timestamp'][bidx][tidx]
cv2.putText(frame, f"timestamp: {timestamp}", (150, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (200, 0, 55), 2)

# NOTE(review): which of the following statements are gated by `legend`
# depends on the lost indentation — confirm against the original file.
if legend:
cv2.putText(frame, f"time: {tidx}", (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.75, color, 2)
cv2.putText(frame, txt, (200, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.75, color, 2)

if 'contractor_info' in data:
try:
pitch = data['contractor_info']['pitch'][bidx][tidx]
yaw = data['contractor_info']['yaw'][bidx][tidx]
cursor_x = data['contractor_info']['cursor_x'][bidx][tidx]
cursor_y = data['contractor_info']['cursor_y'][bidx][tidx]
isGuiInventory = data['contractor_info']['isGuiInventory'][bidx][tidx]
isGuiOpen = data['contractor_info']['isGuiOpen'][bidx][tidx]
cv2.putText(frame, f"Pitch: {pitch:.2f}", (150, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
cv2.putText(frame, f"Yaw: {yaw:.2f}", (150, 110), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
cv2.putText(frame, f"isGuiOpen: {isGuiOpen}", (150, 130), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
cv2.putText(frame, f"isGuiInventory: {isGuiInventory}", (150, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
cv2.putText(frame, f"CursorX: {cursor_x:.2f}", (150, 170), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
cv2.putText(frame, f"CursorY: {cursor_y:.2f}", (150, 190), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
# NOTE(review): bare `except` catches every exception raised above
# and draws a placeholder label instead.
except:
cv2.putText(frame, f"No Contractor Info", (150, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

# Per-frame action values for the current (bidx, tidx) position.
act = {k: v[bidx][tidx].numpy() for k, v in action.items()}
if prev_action is not None:
pre_act = {k: v[bidx][tidx].numpy() for k, v in prev_action.items()}
# NOTE(review): `pre_act` is only assigned when `prev_action` is not None;
# whether this loop is guarded by that check depends on the lost
# indentation — confirm.
for row, ((k, v), (_, pv)) in enumerate(zip(act.items(), pre_act.items())):
# Every key except 'camera' is converted to a plain int before printing.
if k != 'camera':
v = int(v.item())
pv = int(pv.item())
cv2.putText(frame, f"{k}: {v}({pv})", (10, 45 + row*15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

cache_frames.append(frame.astype(np.uint8))

# Builds a new list on each execution rather than extending in place.
frames = frames + cache_frames

# Output name is a relative path stamped with month-day_hour-minute.
timestamp = datetime.now().strftime("%m-%d_%H-%M")
file_name = f"save_{timestamp}.mp4"
write_video(file_name, frames, fps=save_fps, width=resolution[0], height=resolution[1])
from minestudio.data.minecraft.utils import (
MineDistributedBatchSampler, write_video, batchify, visualize_dataloader
)

def visualize_raw_dataset(args):
raw_dataset = RawDataset(
Expand Down Expand Up @@ -177,7 +57,7 @@ def visualize_raw_dataset(args):
collate_fn=batchify,
)

read_dataloader(
visualize_dataloader(
dataloader,
num_samples=args.num_samples,
resolution=(args.frame_width, args.frame_height),
Expand Down Expand Up @@ -213,7 +93,7 @@ def visualize_event_dataset(args):
)

# dump_trajectories(
read_dataloader(
visualize_dataloader(
dataloader,
num_samples=args.num_samples,
resolution=(args.frame_width, args.frame_height),
Expand Down
Loading

0 comments on commit 850692c

Please sign in to comment.