Skip to content

Commit

Permalink
Rewrite segment annotation so that it's contiguous
Browse files Browse the repository at this point in the history
  • Loading branch information
hang-yin committed Jan 30, 2025
1 parent 6104457 commit 0d98ffd
Showing 1 changed file with 140 additions and 52 deletions.
192 changes: 140 additions & 52 deletions omnigibson/envs/data_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
from typing import List, Optional
from typing import Dict, List, Optional

import h5py
import imageio
Expand Down Expand Up @@ -480,8 +480,10 @@ def create_from_hdf5(
external_sensors_config=None,
n_render_iterations=5,
only_successes=False,
action_only=False,
enable_annotation=False,
subtasks=None,
annotation_fields=None,
):
"""
Create a DataPlaybackWrapper environment instance form the recorded demonstration info
Expand All @@ -507,8 +509,26 @@ def create_from_hdf5(
the physical state changes. Increasing this number will improve the rendered quality at the expense of
speed.
only_successes (bool): Whether to only save successful episodes
action_only (bool): If True, only replay actions but not all states
enable_annotation (bool): Whether to enable trajectory annotation during playback
subtasks (None or List[str]): List of possible subtask names for annotation (required if enable_annotation is True)
annotation_fields (None or Dict[str, dict]): Dictionary specifying additional annotation fields
and their properties. Each key is a field name, and the value is a dict containing:
- 'type': The expected type of the field (e.g., bool, str, int)
- 'prompt': The prompt to show when collecting this field
- 'options': (optional) List of valid options for this field
Example:
{
'success': {
'type': bool,
'prompt': 'Was this segment successful?'
},
'difficulty': {
'type': str,
'prompt': 'Select difficulty level:',
'options': ['easy', 'medium', 'hard']
}
}
Returns:
DataPlaybackWrapper: Generated playback environment
Expand All @@ -519,10 +539,11 @@ def create_from_hdf5(

# Hot swap in additional info for playing back data

# Minimize physics leakage during playback (we need to take an env step when loading state)
config["env"]["action_frequency"] = 1000.0
config["env"]["rendering_frequency"] = 1000.0
config["env"]["physics_frequency"] = 1000.0
if not action_only:
# Minimize physics leakage during playback (we need to take an env step when loading state)
config["env"]["action_frequency"] = 1000.0
config["env"]["rendering_frequency"] = 1000.0
config["env"]["physics_frequency"] = 1000.0

# Make sure obs space is flattened for recording
config["env"]["flatten_obs_space"] = True
Expand Down Expand Up @@ -554,8 +575,10 @@ def create_from_hdf5(
output_path=output_path,
n_render_iterations=n_render_iterations,
only_successes=only_successes,
action_only=action_only,
enable_annotation=enable_annotation,
subtasks=subtasks,
annotation_fields=annotation_fields,
)

def __init__(
Expand All @@ -565,8 +588,10 @@ def __init__(
output_path: str,
n_render_iterations: int = 5,
only_successes: bool = False,
action_only: bool = False,
enable_annotation: bool = False,
subtasks: Optional[List[str]] = None,
annotation_fields: Optional[Dict[str, dict]] = None,
):
"""
Args:
Expand All @@ -576,8 +601,11 @@ def __init__(
n_render_iterations (int): Number of rendering iterations to use when loading each stored frame from the
recorded data
only_successes (bool): Whether to only save successful episodes
action_only (bool): If True, only replay actions but not all states
enable_annotation: Whether to enable trajectory annotation during playback
subtasks: List of possible subtask names for annotation (required if enable_annotation is True)
annotation_fields (None or Dict[str, dict]): Additional annotation fields configuration.
See create_from_hdf5() docstring for details.
"""
# Make sure transition rules are DISABLED for playback since we manually propagate transitions
assert not gm.ENABLE_TRANSITION_RULES, "Transition rules must be disabled for DataPlaybackWrapper env!"
Expand All @@ -588,13 +616,15 @@ def __init__(

# Store additional variables
self.n_render_iterations = n_render_iterations
self.action_only = action_only
self.enable_annotation = enable_annotation
self.annotation_fields = annotation_fields or {}
if enable_annotation:
assert subtasks is not None, "subtasks must be provided when enable_annotation is True"
self.subtasks = subtasks
self.playback_state = PlaybackState.PAUSED
self.current_segment = {"start_step": None, "end_step": None, "subtask": None}
self.annotations = []
self.current_segment = {"start_step": None, "end_step": None}
self.annotations = {task: None for task in self.subtasks}
self._setup_keyboard_controls()

# Run super
Expand All @@ -617,61 +647,117 @@ def _toggle_pause(self):
if self.playback_state == PlaybackState.PLAYING:
self.playback_state = PlaybackState.PAUSED
print(f"\nPlayback paused at step {self.step_count}")
if self.current_segment["start_step"] is None:
print("Press 'Z' to mark segment start")
if all(v is None for v in self.annotations.values()): # If this is the first segment
print("Press 'Z' to mark first segment start")
else:
print("Press 'Z' to mark segment end")
print("Press 'Z' to mark current segment end")
print("Press 'ENTER' to resume playback")
else:
self.playback_state = PlaybackState.PLAYING
print("\nPlayback resumed")

def _mark_segment_boundary(self):
"""Mark segment boundaries and prompt for subtask selection when a segment is completed."""
def _prompt_for_subtask_annotation(self):
"""
Prompt user to select a subtask for annotation.
Returns:
str: Selected subtask or "skip"
"""
# Filter out already annotated subtasks
available_subtasks = {
task: f"Annotate segment as '{task}'" for task in self.subtasks if self.annotations[task] is None
}
available_subtasks["skip"] = "Skip annotation for this segment"

return choose_from_options(available_subtasks, "subtask")

def _handle_remaining_annotations(self):
"""Handle annotation for the final segment if any subtasks remain unannotated."""
unannotated_tasks = [task for task, annot in self.annotations.items() if annot is None]

if unannotated_tasks and self.current_segment["start_step"] is not None:
print("\nTrajectory complete. Would you like to annotate the final segment?")
print(f"Final segment spans steps {self.current_segment['start_step']} to {self.step_count}")

available_subtasks = {task: f"Annotate final segment as '{task}'" for task in unannotated_tasks}
available_subtasks["skip"] = "Skip final segment annotation"

selected = choose_from_options(available_subtasks, "subtask")
if selected != "skip":
annotation_data = self._collect_annotation_fields()
self.annotations[selected] = annotation_data
print(f"\nMarked final segment with subtask: {selected}")
for field, value in annotation_data.items():
print(f"{field}: {value}")

def _collect_annotation_fields(self) -> dict:
"""Collect values for all annotation fields for the current segment."""
field_values = {
"start_step": self.current_segment["start_step"],
"end_step": self.current_segment["end_step"],
}

for field_name, field_config in self.annotation_fields.items():
prompt = field_config["prompt"]
field_type = field_config["type"]

if "options" in field_config:
# For fields with predefined options
options = {str(opt): str(opt) for opt in field_config["options"]}
value = choose_from_options(options, field_name)
# Convert to appropriate type
field_values[field_name] = field_type(value)
else:
# For free-form input fields
while True:
try:
print(f"\n{prompt}")
value = input("> ")
field_values[field_name] = field_type(value)
break
except ValueError:
print(f"Invalid input. Please enter a valid {field_type.__name__}")

return field_values

def _mark_segment(self):
"""Mark segment boundaries and handle subtask selection."""
if self.playback_state == PlaybackState.PLAYING:
return

# If no segment start is marked, this is the start of a new segment
if self.current_segment["start_step"] is None:
# For the very first segment, we need to mark its start
if all(v is None for v in self.annotations.values()) and self.current_segment["start_step"] is None:
self.current_segment["start_step"] = self.step_count
print(f"\nMarked segment start at step {self.step_count}")
print(f"\nMarked first segment start at step {self.step_count}")
print("Press 'ENTER' to resume playback")
print("Press 'Z' when you want to mark the segment end")
return

# If we already have a start point, this marks the end of the segment
# For all other cases, pressing Z marks the end of current segment
self.current_segment["end_step"] = self.step_count

# Prompt for subtask selection
self.playback_state = PlaybackState.ANNOTATING
subtask_dict = {task: f"Annotate segment as '{task}'" for task in self.subtasks}
subtask_dict["skip"] = "Skip annotation for this segment"

selected = choose_from_options(subtask_dict, "subtask")
selected = self._prompt_for_subtask_annotation()

# Record the completed segment
self.current_segment["subtask"] = None if selected == "skip" else selected
self.annotations.append(
{
"start_step": self.current_segment["start_step"],
"end_step": self.current_segment["end_step"],
"subtask": self.current_segment["subtask"],
}
)

# Print segment info
if self.current_segment["subtask"] is not None:
print(f"\nMarked segment {len(self.annotations)} with subtask: {self.current_segment['subtask']}")
if selected != "skip":
# Collect all annotation fields
annotation_data = self._collect_annotation_fields()
self.annotations[selected] = annotation_data
print(f"\nMarked segment {len(self.annotations)} with subtask: {selected}")
for field, value in annotation_data.items():
print(f"{field}: {value}")
else:
print(f"\nMarked segment {len(self.annotations)} as unannotated")
print(f"From step {self.current_segment['start_step']} to {self.current_segment['end_step']}")

# Reset current segment for the next annotation
self.current_segment = {"start_step": None, "end_step": None, "subtask": None}
# Automatically start the next segment from where the previous one ended
self.current_segment = {
"start_step": self.step_count, # Start new segment from current step
"end_step": None,
}

print(f"\nStarted new segment at step {self.step_count}")
self.playback_state = PlaybackState.PAUSED
print("\nReady for next segment")
print("Press 'Z' to mark new segment start")
print("Press 'Z' to mark segment end")
print("Press 'ENTER' to resume playback")

def _parse_step_data(self, action, obs, reward, terminated, truncated, info):
Expand All @@ -694,6 +780,7 @@ def playback_episode(self, episode_id, record=True, video_path=None, video_write
record (bool): Whether to record data during playback or not
video_path (None or str): If specified, path to write the playback video to
video_writer (None or str): If specified, an imageio video writer to use for writing the video (can be specified in place of @video_path)
action_only (bool): If True, only replay actions but not all states
"""
using_external_writer = video_writer is not None
if video_writer is None and video_path is not None:
Expand Down Expand Up @@ -735,9 +822,11 @@ def playback_episode(self, episode_id, record=True, video_path=None, video_write
zip(action, state[1:], state_size[1:], reward, terminated, truncated)
):
if self.enable_annotation:
while self.playback_state in [PlaybackState.PAUSED, PlaybackState.ANNOTATING]:
if self.playback_state == PlaybackState.PAUSED:
og.sim.render()
while self.playback_state in [
PlaybackState.PAUSED,
PlaybackState.ANNOTATING,
]:
og.sim.render()

# Execute any transitions that should occur at this current step
if str(i) in transitions:
Expand All @@ -759,7 +848,8 @@ def playback_episode(self, episode_id, record=True, video_path=None, video_write

# Restore the sim state, and take a very small step with the action to make sure physics are
# properly propagated after the sim state update
og.sim.load_state(s[: int(ss)], serialized=True)
if not self.action_only:
og.sim.load_state(s[: int(ss)], serialized=True)
self.current_obs, _, _, _, info = self.env.step(action=a, n_render_iterations=self.n_render_iterations)

# If recording, record data
Expand All @@ -780,9 +870,11 @@ def playback_episode(self, episode_id, record=True, video_path=None, video_write

self.step_count += 1

# Handle any remaining annotations
if self.enable_annotation:
self._handle_remaining_annotations()

if record:
if self.enable_annotation:
self.current_traj_history[-1]["annotations"] = self.annotations
self.flush_current_traj()

# If we weren't using an external writer but we're still writing a video, close the writer
Expand Down Expand Up @@ -812,12 +904,8 @@ def process_traj_to_hdf5(self, traj_data, traj_grp_name, nested_keys=("obs",)):
# Get trajectory group from parent processing
traj_grp = super().process_traj_to_hdf5(traj_data, traj_grp_name, nested_keys)

# Add annotations as metadata if enabled and present
if self.enable_annotation and len(traj_data) > 0 and "annotations" in traj_data[-1]:
self.add_metadata(
group=traj_grp,
name="annotations",
data=traj_data[-1]["annotations"],
)
# Add annotations as metadata if enabled
if self.enable_annotation:
self.add_metadata(group=traj_grp, name="annotations", data=self.annotations)

return traj_grp

0 comments on commit 0d98ffd

Please sign in to comment.