Updated Brain V torch ATLAS
sergiopaniego committed Apr 15, 2024
1 parent 47dbbd6 commit 4fe663b
Showing 3 changed files with 118 additions and 45 deletions.
78 changes: 71 additions & 7 deletions behavior_metrics/brains/CARLA/pytorch/brain_carla_bird_eye_deep_learning_torch_V.py
@@ -68,7 +68,9 @@ def __init__(self, sensors, actuators, model=None, handler=None, config=None):
self.net = torch.jit.load(PRETRAINED_MODELS + model).to(self.device)
# self.clean_model()
else:
self.net = PilotNet((200,66,4), 3).to(self.device)
#self.net = PilotNet((200,66,4), 3).to(self.device)
#self.net = PilotNet((200,66,4), 2).to(self.device)
self.net = PilotNet((66,200,4), 2).to(self.device)
self.net.load_state_dict(torch.load(PRETRAINED_MODELS + model,map_location=self.device))
else:
print("Brain not loaded")
@@ -119,27 +121,89 @@ def execute(self):
self.update_frame('frame_2', image)
self.update_frame('frame_3', image_3)

self.update_frame('frame_0', bird_eye_view_1)
#self.update_frame('frame_0', bird_eye_view_1)
self.update_frame('frame_0', np.array(image))

try:
image = Image.fromarray(image)

#print(image.size)
#image = image.resize((300, 225))
#image = image.resize((225, 300))
#print(image.size)
#image.save('imagen_guardada.png')
# (2048, 1536)
# (2048, 120)


# ATLAS original
#camera_bp.set_attribute('image_size_x', str(2048))
#camera_bp.set_attribute('image_size_y', str(1536))
# We adjust the size to retrieve images at a faster pace
#camera_bp.set_attribute('image_size_x', str(300))
#camera_bp.set_attribute('image_size_y', str(224.85))

# Crop: keep only the bottom 120-pixel strip of the frame
altura_recorte = 120
ancho, altura = image.size
y_superior = max(0, altura - altura_recorte)
image = image.crop((0, y_superior, ancho, altura))

#print('image.size', image.size)

image = self.transformations(image)

image = image / 255.0
speed = self.vehicle.get_velocity()
vehicle_speed = 3.6 * math.sqrt(speed.x**2 + speed.y**2 + speed.z**2)

valor_cuartadimension = torch.full((1, image.shape[1], image.shape[2]), float(vehicle_speed))
#print('vehicle_speed', vehicle_speed)

# Normalize the speed to [0, 1] against a 40 km/h ceiling
vehicle_speed_norm = torch.clamp(torch.tensor(vehicle_speed, dtype=torch.float32) / 40.0, 0, 1.0)

# Broadcast the normalized speed as a constant fourth input channel
valor_cuartadimension = torch.full((1, image.shape[1], image.shape[2]), float(vehicle_speed_norm))
image = torch.cat((image, valor_cuartadimension), dim=0).to(self.device)
image = image.unsqueeze(0)

#print('image.shape', image.shape)
#print('self.gpu_inference', self.gpu_inference)
#print(image)

start_time = time.time()
with torch.no_grad():
prediction = self.net(image).cpu().numpy() if self.gpu_inference else self.net(image).numpy()
self.inference_times.append(time.time() - start_time)
throttle = prediction[0][0]
steer = prediction[0][1] * (1 - (-1)) + (-1)
break_command = prediction[0][2]

print('prediction', prediction)
if (prediction[0][1] < 0.4):
print('LOG!!!!!!!!!!!!!!!!!!!!!!!!!!!')
#print(prediction)
prediction = prediction.flatten()
# prediction = prediction.detach().cpu().numpy().flatten()
#print('prediction', prediction)

#print('vehicle_speed', vehicle_speed)
#print('vehicle_speed_norm', vehicle_speed_norm)

combined, steer = prediction
combined = float(combined)
throttle, break_command = 0.0, 0.0
# The single combined output encodes throttle in its upper half and brake in its lower half
if combined >= 0.5:
throttle = (combined - 0.5) / 0.5
else:
break_command = (0.5 - combined) / 0.5
# Rescale steer from [0, 1] to CARLA's [-1, 1]
steer = (float(steer) * 2.0) - 1.0

#throttle = prediction[0][0]
#steer = prediction[0][1] * (1 - (-1)) + (-1)
#break_command = prediction[0][2]

print(throttle, steer, break_command)
#print('----')

if vehicle_speed > 30:
self.motors.sendThrottle(0)
@@ -151,7 +215,7 @@ def execute(self):
self.motors.sendSteer(0.0)
self.motors.sendBrake(0)
else:
self.motors.sendThrottle(throttle)
self.motors.sendThrottle(0.5)
self.motors.sendSteer(steer)
self.motors.sendBrake(break_command)

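For context, a minimal end-to-end sketch of what the updated execute() path does with a frame: crop the bottom strip, scale the pixels, append the normalized speed as a fourth channel, and decode the network's two outputs into CARLA controls. This is an illustrative reconstruction, assuming both outputs lie in [0, 1]; the helper names and the torchvision transforms are not from the repo.

import torch
from PIL import Image
from torchvision import transforms

def preprocess(frame: Image.Image, speed_kmh: float) -> torch.Tensor:
    """Illustrative stand-in for the brain's cropping/transform steps."""
    crop_height = 120  # keep only the bottom 120-pixel strip
    width, height = frame.size
    frame = frame.crop((0, max(0, height - crop_height), width, height))

    # Resize to the (66, 200) network input; ToTensor already scales pixels to [0, 1]
    tensor = transforms.Compose([
        transforms.Resize((66, 200)),
        transforms.ToTensor(),
    ])(frame)

    # Normalize speed against a 40 km/h ceiling and broadcast it as a 4th channel
    speed_norm = min(max(speed_kmh / 40.0, 0.0), 1.0)
    speed_channel = torch.full((1, tensor.shape[1], tensor.shape[2]), speed_norm)
    return torch.cat((tensor, speed_channel), dim=0).unsqueeze(0)  # NCHW batch of 1

def decode(prediction) -> tuple:
    """Map [combined, steer], both in [0, 1], to (throttle, steer, brake)."""
    combined, steer = float(prediction[0]), float(prediction[1])
    throttle, brake = 0.0, 0.0
    if combined >= 0.5:
        throttle = (combined - 0.5) / 0.5  # upper half of the range accelerates
    else:
        brake = (0.5 - combined) / 0.5     # lower half brakes
    return throttle, steer * 2.0 - 1.0, brake  # steer rescaled to [-1, 1]

# Example: combined = 0.75, steer = 0.5 -> half throttle, centered steering, no brake
print(decode([0.75, 0.5]))  # (0.5, 0.0, 0.0)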
79 changes: 44 additions & 35 deletions behavior_metrics/brains/CARLA/pytorch/utils/pilotnet.py
@@ -3,56 +3,65 @@


class PilotNet(nn.Module):
def __init__(self,
image_shape,
num_labels):
def __init__(self, image_shape, num_labels):
super(PilotNet, self).__init__()

self.img_height = image_shape[0]
self.img_width = image_shape[1]

self.num_channels = image_shape[2]
# Batch normalization?
self.batchnorm_input = nn.BatchNorm2d(self.num_channels) # For images in RGB format (3 channels)
self.cn_1 = nn.Conv2d(in_channels=self.num_channels, out_channels=24, kernel_size=5, stride=2)
self.relu_1 = nn.ReLU()
self.cn_2 = nn.Conv2d(in_channels=24, out_channels=36, kernel_size=5, stride=2)
self.relu_2 = nn.ReLU()
self.cn_3 = nn.Conv2d(in_channels=36, out_channels=48, kernel_size=5, stride=2)
self.relu_3 = nn.ReLU()
self.cn_4 = nn.Conv2d(in_channels=48, out_channels=64, kernel_size=3, stride=1)
self.relu_4 = nn.ReLU()
self.cn_5 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
self.relu_5 = nn.ReLU()
self.dropout_1 = nn.Dropout(0.2)
self.flatten = nn.Flatten()

self.output_size = num_labels

self.ln_1 = nn.BatchNorm2d(self.num_channels, eps=1e-03)

self.cn_1 = nn.Conv2d(self.num_channels, 24, kernel_size=5, stride=2)
self.cn_2 = nn.Conv2d(24, 36, kernel_size=5, stride=2)
self.cn_3 = nn.Conv2d(36, 48, kernel_size=5, stride=2)
self.cn_4 = nn.Conv2d(48, 64, kernel_size=3, stride=1)
self.cn_5 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

self.fc_1 = nn.Linear(1 * 18 * 64, 1164)
# Flatten layer?
self.fc_1 = nn.Linear(1152, 1164) # add embedding layer output size
self.relu_fc_1 = nn.ReLU()
self.fc_2 = nn.Linear(1164, 100)
self.relu_fc_2 = nn.ReLU()
self.fc_3 = nn.Linear(100, 50)
self.relu_fc_3 = nn.ReLU()
self.fc_4 = nn.Linear(50, 10)
self.fc_5 = nn.Linear(10, self.output_size)
self.relu_fc_4 = nn.ReLU()
self.fc_5 = nn.Linear(10, num_labels)

def forward(self, img):

out = self.ln_1(img)

out = self.cn_1(out)
out = torch.relu(out)
out = self.batchnorm_input(img)
out = self.cn_1(out)  # use the batch-normalized input, not the raw img
out = self.relu_1(out)
out = self.cn_2(out)
out = torch.relu(out)
out = self.relu_2(out)
out = self.cn_3(out)
out = torch.relu(out)
out = self.relu_3(out)

out = self.cn_4(out)
out = torch.relu(out)
out = self.relu_4(out)
out = self.cn_5(out)
out = torch.relu(out)

out = out.reshape(out.size(0), -1)

out = self.relu_5(out)

out = self.dropout_1(out)

#out = out.view(-1, 1152)
out = self.flatten(out)

out = self.fc_1(out)
out = torch.relu(out)
out = self.relu_fc_1(out)
out = self.fc_2(out)
out = torch.relu(out)
out = self.relu_fc_2(out)
out = self.fc_3(out)
out = torch.relu(out)
out = self.relu_fc_3(out)
out = self.fc_4(out)
out = torch.relu(out)
out = self.relu_fc_4(out)
out = self.fc_5(out)

return out
#out = torch.sigmoid(out)

return out
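The new fc_1 input size of 1152 matches the (66, 200) input after the five convolutions: 66x200 -> 31x98 -> 14x47 -> 5x22 -> 3x20 -> 1x18, and 1 x 18 x 64 = 1152. A quick sanity-check sketch, assuming the module is importable from the path shown in this diff:

import torch
from brains.CARLA.pytorch.utils.pilotnet import PilotNet  # assumes behavior_metrics/ is on PYTHONPATH

# Shape used by the updated brain: (height, width, channels) = (66, 200, 4), 2 outputs
net = PilotNet((66, 200, 4), 2)
net.eval()

# One RGB frame plus the constant speed channel, NCHW layout
dummy = torch.randn(1, 4, 66, 200)
with torch.no_grad():
    out = net(dummy)
print(out.shape)  # torch.Size([1, 2]) -> [combined throttle/brake, steer]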
6 changes: 3 additions & 3 deletions behavior_metrics/configs/CARLA/default_carla_torch.yml
@@ -33,11 +33,11 @@ Behaviors:
Topic: '/carla/ego_vehicle/vehicle_control_cmd'
MaxV: 3
MaxW: 0.3
BrainPath: 'brains/CARLA/brain_carla_segmentation_based_imitation_learning.py'
BrainPath: 'brains/CARLA/pytorch/brain_carla_bird_eye_deep_learning_torch_V.py'
PilotTimeCycle: 50
AsyncMode: True
Parameters:
Model: 'pilotnet_v8.0.pth'
Model: '15_04_best_model_checkpoint_189.pth'
ImageCropped: True
ImageSize: [ 100,50 ]
ImageNormalized: True
@@ -47,7 +47,7 @@ Behaviors:
ImageTranform: ''
Type: 'CARLA'
Simulation:
World: configs/CARLA/CARLA_launch_files/town_02_anticlockwise_imitation_learning.launch
World: configs/CARLA/CARLA_launch_files/carla_new_ada_bridge_updated.launch
RandomSpawnPoint: False
Dataset:
In: '/tmp/my_bag.bag'
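A short sketch of how the two updated settings come together at load time, mirroring the brain's __init__ above; the value of PRETRAINED_MODELS and the surrounding scaffolding are illustrative, while the checkpoint name and input shape come from this diff:

import torch
from brains.CARLA.pytorch.utils.pilotnet import PilotNet  # assumes behavior_metrics/ is on PYTHONPATH

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
PRETRAINED_MODELS = 'models/'  # illustrative checkpoint directory

# Shape and output count must match the checkpoint named in the YAML
net = PilotNet((66, 200, 4), 2).to(device)
state = torch.load(PRETRAINED_MODELS + '15_04_best_model_checkpoint_189.pth', map_location=device)
net.load_state_dict(state)
net.eval()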
