Skip to content

Commit

Permalink
opt code
Browse files Browse the repository at this point in the history
  • Loading branch information
tonycaisy committed May 20, 2024
1 parent 4d3589c commit a3b0902
Show file tree
Hide file tree
Showing 14 changed files with 81 additions and 35 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,5 @@ If you find this repository useful, please cite this paper:
<tbody>
</table>
<!-- readme: contributors -end -->

TODO: test monitor save
1 change: 1 addition & 0 deletions car_dreamer/carla_base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def _is_terminal(self):
terminal = False
for k, v in terminal_conds.items():
if v:
print(f'[CARLA] Terminal condition triggered: {k}')
terminal = True
terminal_conds[k] = np.array([v], dtype=np.bool_)
if terminal:
Expand Down
12 changes: 10 additions & 2 deletions car_dreamer/carla_wpt_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,17 @@ def reward(self):
dist = perp_direction_norm
if dist > 0.5:
r_out_of_lane = -reward_scales['out_of_lane'] * (dist - 0.5)

# Reward for reaching the destination
r_destination = 0.0
if self.is_destination_reached():
r_destination = reward_scales['destination_reached']

# Time penalty
time_penalty = -reward_scales['time']

# Total reward
total_reward = r_waypoints + r_speed + r_collision + r_out_of_lane + time_penalty
total_reward = r_waypoints + r_speed + r_collision + r_out_of_lane + r_destination + time_penalty

ttc = TTCCalculator.get_ttc(ego, self._world.carla_world, self._world.carla_map)

Expand All @@ -124,6 +129,9 @@ def reward(self):
}

return total_reward, info

def is_destination_reached(self):
return len(self.waypoints) <= 3

def get_terminal_conditions(self):
terminal_config = self._config.terminal
Expand All @@ -132,7 +140,7 @@ def get_terminal_conditions(self):
'is_collision': self.is_collision(),
'time_exceeded': self._time_step > terminal_config.time_limit,
'out_of_lane': self.get_wpt_dist(ego_location) > terminal_config.out_lane_thres,
'destination_reached': len(self.waypoints) == 0,
'destination_reached': self.is_destination_reached(),
}
return conds

Expand Down
2 changes: 1 addition & 1 deletion car_dreamer/configs/tasks.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ carla_wpt: &carla_wpt
# It's a base configuration for wpt tasks, don't use it directly
reward:
desired_speed: 4 # desired speed (m/s)
scales: { waypoint: 2.0, speed: 0.5, collision: 30.0, out_of_lane: 3.0, time: 0.0 }
scales: { waypoint: 2.0, speed: 0.5, collision: 30.0, out_of_lane: 3.0, time: 0.0, destination_reached: 20.0 }
terminal:
time_limit: 500 # maximum timesteps per episode
out_lane_thres: 3 # threshold for out of lane
Expand Down
2 changes: 1 addition & 1 deletion car_dreamer/toolkit/carla_manager/world_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def actors(self) -> List[carla.Actor]:
def _get_actor_polygons(self) -> ActorPolygonDict:
actor_polygons: ActorPolygonDict = {}

for actor in self.actor_dict.values():
for actor in self.actors:
actor_transform = actor.get_transform()
x = actor_transform.location.x
y = actor_transform.location.y
Expand Down
4 changes: 2 additions & 2 deletions car_dreamer/toolkit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ def _ensure_values(self, mapping):
if len(value) == 0:
message = 'Empty lists are disallowed because their type is unclear.'
raise TypeError(message)
if not isinstance(value[0], (str, float, int, bool, list)):
message = 'Lists can only contain strings, floats, ints, bools, lists'
if not isinstance(value[0], (str, float, int, bool, list, dict)):
message = 'Lists can only contain strings, floats, ints, bools, lists, dict'
message += f' but not {type(value[0])}'
raise TypeError(message)
if not all(isinstance(x, type(value[0])) for x in value[1:]):
Expand Down
45 changes: 37 additions & 8 deletions car_dreamer/toolkit/monitor/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
grid-gap: 20px;
width: 50%;
width: 60%;
}

.render-item {
Expand All @@ -59,26 +59,41 @@
.render-image {
width: 100%;
height: auto;
max-height: 512px;
object-fit: contain;
border-radius: 5px;
box-shadow: 0 2px 5px rgba(0, 0, 0, 0.15);
}

.save-button {
margin-top: 10px;
padding: 5px 10px;
background-color: #4CAF50;
color: white;
border: none;
border-radius: 4px;
cursor: pointer;
}

.save-button:hover {
background-color: #45a049;
}

.render-key {
margin-top: 10px;
font-weight: bold;
color: #555;
}

.reward-plot-container {
width: calc(50% - 30px);
width: 40%;
height: auto;
max-height: 400px;
margin-right: 20px;
}

.plot-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
grid-template-columns: repeat(auto-fit, minmax(400px, 1fr));
grid-gap: 20px;
width: 100%;
max-width: 100%;
Expand Down Expand Up @@ -169,7 +184,7 @@
</div>
</div>
<div class="section">
<div class="section-title">Scalars</div>
<div class="section-title">Info</div>
<div class="info-grid"></div>
</div>
<div class="section">
Expand Down Expand Up @@ -203,17 +218,31 @@
const renderItem = document.createElement('div');
renderItem.className = 'render-item';
const img = document.createElement('img');
img.src = `data:image/jpeg;base64,${item.image}`;
img.src = `data:image/webp;base64,${item.image}`;
img.className = 'render-image';
const key = document.createElement('div');
key.className = 'render-key';
key.textContent = item.key;
const saveButton = document.createElement('button');
saveButton.className = 'save-button';
saveButton.textContent = 'Save';
saveButton.addEventListener('click', () => {
saveImage(item.key, item.image);
});
renderItem.appendChild(img);
renderItem.appendChild(key);
renderItem.appendChild(saveButton);
renderGrid.appendChild(renderItem);
});
}

function saveImage(key, base64Image) {
const link = document.createElement('a');
link.href = `data:image/webp;base64,${base64Image}`;
link.download = `${key}_${Date.now()}.webp`;
link.click();
}

function updateRewardPlot(info) {
const rewardData = {};
Object.entries(info).forEach(([key, value]) => {
Expand Down Expand Up @@ -457,11 +486,11 @@

function adjustGridColumns() {
const renderItems = renderGrid.querySelectorAll('.render-item');
const numRenderColumns = Math.min(Math.max(Math.floor(renderGrid.clientWidth / 200), 1), renderItems.length);
const numRenderColumns = Math.min(Math.max(Math.floor(renderGrid.clientWidth / 200), 1), 2);
renderGrid.style.gridTemplateColumns = `repeat(${numRenderColumns}, 1fr)`;

const plotItems = plotGrid.querySelectorAll('.plot-item');
const numPlotColumns = Math.min(Math.max(Math.floor(plotGrid.clientWidth / 200), 1), 6);
const numPlotColumns = Math.min(Math.max(Math.floor(plotGrid.clientWidth / 400), 1), plotItems.length);
plotGrid.style.gridTemplateColumns = `repeat(${numPlotColumns}, 1fr)`;

const infoItems = infoGrid.querySelectorAll('.info-item');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ....carla_manager import WorldManager, Command, ActorPolygon
from .constants import BirdeyeEntity, Color
from .map_renderer import MapRenderer
from ..utils import should_filter

class BirdeyeRenderer:
def __init__(
Expand Down Expand Up @@ -130,7 +131,7 @@ def _render_background_vehicles(self, **env_state):
ego_id = self._ego.id

for vehicle_id, polygon in vehicle_polygons.items():
if vehicle_id == ego_id:
if vehicle_id == ego_id or should_filter(self._ego.get_transform(), self._world_manager.actor_transforms[vehicle_id]):
continue
vehicle_color = color.get(vehicle_id, None)
if vehicle_color is not None:
Expand All @@ -148,7 +149,7 @@ def _render_background_waypoints(self, **env_state):
vehicle_polygons = self._world_manager.actor_polygons

for vehicle_id, path in background_waypoints.items():
if vehicle_id == self._ego.id:
if vehicle_id == self._ego.id or should_filter(self._ego.get_transform(), self._world_manager.actor_transforms[vehicle_id]):
continue
vehicle_polygon = vehicle_polygons.get(vehicle_id, None)
if vehicle_polygon is None:
Expand Down Expand Up @@ -193,7 +194,7 @@ def render_character(location, message, message_color):
}

for vehicle_id, message in background_messages.items():
if vehicle_id == self._ego.id:
if vehicle_id == self._ego.id or should_filter(self._ego.get_transform(), self._world_manager.actor_transforms[vehicle_id]):
continue
polygon = self._world_manager.actor_polygons.get(vehicle_id, None)
if polygon is None:
Expand Down
4 changes: 3 additions & 1 deletion car_dreamer/toolkit/observer/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ def get_visibility(

return fov_visible, recursive_visible

def should_filter(ego_transform, actor_transform):
return abs(actor_transform.location.z - ego_transform.location.z) > 3

def get_neighbors(
ego: carla.Actor,
Expand All @@ -129,7 +131,7 @@ def get_neighbors(
ego_location = ego_transform.location

for id, transform in actor_transforms.items():
if id == ego_id or not fov_visible[id]:
if id == ego_id or not fov_visible[id] or should_filter(ego_transform, transform):
continue
actor_location = transform.location
if actor_location.x > 5.0 and actor_location.x < 16.2 and abs(actor_location.x - ego_location.x) < 4.0:
Expand Down
16 changes: 6 additions & 10 deletions car_dreamer/toolkit/planner/base_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(self, vehicle: carla.Actor, max_waypoints=60, reach_threshold=0.5):
# The distance below which the waypoint is considered as reached
self._reach_threshold = reach_threshold
self._vehicle_location = get_vehicle_pos(self._vehicle)
self._prev_location = self._vehicle_location

@abstractmethod
def init_route(self):
Expand Down Expand Up @@ -93,16 +94,17 @@ def run_step(self):
'''
Run one step of the route planner, extending the route and removing expired waypoints.
:return: tuple(list of waypoints ``(x, y, yaw)``, number of completed waypoints, number of obsoleted waypoints)
:return: tuple(list of waypoints ``(x, y, yaw)``, additional stats)
'''
if not self._initialized:
self.init_route()
self._initialized = True

self.extend_route()
self._vehicle_location = get_vehicle_pos(self._vehicle)
planner_stats = self._update_waypoints_queue()
waypoints = self._get_waypoints()
self._vehicle_location = get_vehicle_pos(self._vehicle)
self._prev_location = self._vehicle_location
return waypoints, planner_stats

def _get_waypoints(self):
Expand All @@ -117,11 +119,9 @@ def _update_waypoints_queue(self):
num_completed = 0
num_obsolete = 0
num_to_delete = 0
travel_distance = 0.0
min_distance = 100
vehicle_location = get_vehicle_pos(self._vehicle)
for i, waypoint in enumerate(self._waypoints_queue):
dist = get_location_distance(vehicle_location, waypoint)
dist = get_location_distance(self._vehicle_location, waypoint)
if dist < self._reach_threshold:
num_completed += 1
num_to_delete = i + 1
Expand All @@ -130,16 +130,12 @@ def _update_waypoints_queue(self):
min_distance = dist
num_to_delete = i
num_obsolete = i
prev_location = self._vehicle_location
for i in range(num_completed):
travel_distance += get_location_distance(prev_location, self._waypoints_queue[i])
prev_location = self._waypoints_queue[i]
for _ in range(num_to_delete):
self.pop_waypoint()
num_obsolete -= num_completed
planner_stats = dict(
num_completed=num_completed,
num_obsolete=num_obsolete,
travel_distance=travel_distance
travel_distance=get_location_distance(self._prev_location, self._vehicle_location)
)
return planner_stats
2 changes: 1 addition & 1 deletion dreamerv2/dreamerv2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ defaults:
filter: '.*'
tbtt: 0
train:
steps: 3e5
steps: 4e5
expl_until: 0
log_every: 2e3
eval_every: 1e3
Expand Down
3 changes: 1 addition & 2 deletions dreamerv2/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
def save_configs(config, logdir):
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
cardreamer_id = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=str(directory.parent)).decode('utf-8').strip()
gym_id = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=str(directory.parent / 'gym-carla')).decode('utf-8').strip()
config = config.update(cardreamer_id=cardreamer_id, gym_id=gym_id)
config = config.update(cardreamer_id=cardreamer_id)
config_filename = f'config_{timestamp}.yaml'
config_path = pathlib.Path(logdir) / config_filename
config.save(str(config_path))
Expand Down
10 changes: 6 additions & 4 deletions dreamerv3/embodied/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def save(self, filename):
elif filename.suffix in ('.yml', '.yaml'):
import ruamel.yaml as yaml
with io.StringIO() as stream:
yaml.safe_dump(dict(self), stream)
yaml = yaml.YAML(typ='safe', pure=True)
yaml.dump(dict(self), stream)
filename.write(stream.getvalue())
else:
raise NotImplementedError(filename.suffix)
Expand All @@ -44,7 +45,8 @@ def load(cls, filename):
return cls(json.loads(filename.read_text()))
elif filename.suffix in ('.yml', '.yaml'):
import ruamel.yaml as yaml
return cls(yaml.safe_load(filename.read_text()))
yaml = yaml.YAML(typ='safe', pure=True)
return cls(yaml.load(filename.read_text()))
else:
raise NotImplementedError(filename.suffix)

Expand Down Expand Up @@ -174,8 +176,8 @@ def _ensure_values(self, mapping):
if len(value) == 0:
message = 'Empty lists are disallowed because their type is unclear.'
raise TypeError(message)
if not isinstance(value[0], (str, float, int, bool, list)):
message = 'Lists can only contain strings, floats, ints, bools, lists'
if not isinstance(value[0], (str, float, int, bool, list, dict)):
message = 'Lists can only contain strings, floats, ints, bools, lists, dict'
message += f' but not {type(value[0])}'
raise TypeError(message)
if not all(isinstance(x, type(value[0])) for x in value[1:]):
Expand Down
6 changes: 6 additions & 0 deletions dreamerv3/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import ruamel.yaml as yaml
import dreamerv3
import embodied
import datetime

import car_dreamer

Expand Down Expand Up @@ -60,6 +61,11 @@ def main(argv=None):
env = wrap_env(env, dreamerv3_config)
env = embodied.BatchEnv([env], parallel=False)

timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
config_filename = f'config_{timestamp}.yaml'
config.save(str(logdir / config_filename))
print(f'[Train] Config saved to {logdir / config_filename}')

agent = dreamerv3.Agent(env.obs_space, env.act_space, step, dreamerv3_config)
replay = embodied.replay.Uniform(
dreamerv3_config.batch_length, dreamerv3_config.replay_size, logdir / 'replay')
Expand Down

0 comments on commit a3b0902

Please sign in to comment.