Skip to content

Commit

Permalink
Merge pull request meadx#28 from edwin-nz/master
Browse files Browse the repository at this point in the history
Updated with neural network trained to classify site traffic light
  • Loading branch information
edwin-nz authored Dec 8, 2018
2 parents 8f67594 + f4a4198 commit edb36f2
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 65 deletions.
Binary file not shown.
166 changes: 104 additions & 62 deletions ros/src/tl_detector/light_classification/tl_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,55 +12,73 @@
from keras.models import Sequential
from keras.layers import Conv2D, Flatten, Dense, Lambda, MaxPooling2D, Dropout
import tensorflow as tf
import yaml

class TLClassifier(object):
def __init__(self):
#TODO load classifier
#pass
global tl_model
global graph_tl

self.threshold = .7

#Load the VGG16 model
# Save the graph after loading the model
global vgg_model
vgg_model = vgg16.VGG16(weights='imagenet')
global graph_vgg
graph_vgg = tf.get_default_graph()
config_string = rospy.get_param("/traffic_light_config")
self.config = yaml.load(config_string)
self.is_site = self.config['is_site']


keep_prob = 0.2
global tl_model
tl_model = Sequential()
tl_model.add(Lambda(lambda x: x / 127.5 - 1.0, input_shape=(224, 224, 3)))
tl_model.add(Conv2D(32, (3, 3), strides=(2, 2), activation='relu'))
tl_model.add(MaxPooling2D(2,2))
tl_model.add(Dropout(keep_prob))
tl_model.add(Conv2D(36, (5, 5), strides=(2, 2), activation='relu'))
tl_model.add(MaxPooling2D(2,2))
tl_model.add(Dropout(keep_prob))
tl_model.add(Conv2D(48, (5, 5), strides=(2, 2), activation='relu'))
tl_model.add(Dropout(keep_prob))
tl_model.add(Conv2D(64, (3, 3), activation='relu'))
tl_model.add(Dropout(keep_prob))
tl_model.add(Conv2D(64, (3, 3), activation='relu'))
tl_model.add(Dropout(keep_prob))
tl_model.add(Flatten())
tl_model.add(Dense(100))
tl_model.add(Dense(50))
tl_model.add(Dense(10))
tl_model.add(Dense(3, activation='softmax'))
if (not self.is_site):

self.threshold = .7
#Load the VGG16 model
# Save the graph after loading the model
global vgg_model
vgg_model = vgg16.VGG16(weights='imagenet')
global graph_vgg
graph_vgg = tf.get_default_graph()

os.chdir('.')
tl_model.load_weights('light_classification/highway_modelv3-ep15-wts.h5')
global graph_tl
graph_tl = tf.get_default_graph()

tl_model = Sequential()
tl_model.add(Lambda(lambda x: x / 127.5 - 1.0, input_shape=(224, 224, 3)))
tl_model.add(Conv2D(32, (3, 3), strides=(2, 2), activation='relu'))
tl_model.add(MaxPooling2D(2,2))
tl_model.add(Conv2D(36, (5, 5), strides=(2, 2), activation='relu'))
tl_model.add(MaxPooling2D(2,2))
tl_model.add(Conv2D(48, (5, 5), strides=(2, 2), activation='relu'))
tl_model.add(Conv2D(64, (3, 3), activation='relu'))
tl_model.add(Conv2D(64, (3, 3), activation='relu'))
tl_model.add(Flatten())
tl_model.add(Dense(100))
tl_model.add(Dense(50))
tl_model.add(Dense(10))
tl_model.add(Dense(3, activation='softmax'))

os.chdir('.')
tl_model.load_weights('light_classification/highway_modelv3-ep15-wts.h5')
graph_tl = tf.get_default_graph()

else:
self.threshold = .5

tl_model = Sequential()
tl_model.add(Lambda(lambda x: x / 127.5 - 1.0, input_shape=(224, 224, 3)))
tl_model.add(Conv2D(32, (3, 3), strides=(2, 2), activation='relu'))
tl_model.add(MaxPooling2D(2,2))
tl_model.add(Conv2D(36, (5, 5), strides=(2, 2), activation='relu'))
tl_model.add(MaxPooling2D(2,2))
tl_model.add(Conv2D(48, (5, 5), strides=(2, 2), activation='relu'))
tl_model.add(Conv2D(64, (3, 3), activation='relu'))
tl_model.add(Conv2D(64, (3, 3), activation='relu'))
tl_model.add(Flatten())
tl_model.add(Dense(100))
tl_model.add(Dense(50))
tl_model.add(Dense(10))
tl_model.add(Dense(4, activation='softmax'))

os.chdir('.')
tl_model.load_weights('light_classification/site_modelv5-ep30-wts.h5')
graph_tl = tf.get_default_graph()

print('Traffic light claasifier initialized')

import cv2



def get_classification(self, image):
"""Determines the color of the traffic light in the image
Expand All @@ -79,31 +97,55 @@ def get_classification(self, image):
image224 = img_to_array(image224)
# Convert the image into 4D Tensor (samples, height, width, channels) by adding an extra dimension to the axis 0.
input_image = np.expand_dims(image224, axis=0)
processed_image_vgg16 = vgg16.preprocess_input(input_image.copy())
with graph_vgg.as_default():
predictions_vgg16 = vgg_model.predict(processed_image_vgg16)
label_vgg16 = decode_predictions(predictions_vgg16)

if (not self.is_site):

processed_image_vgg16 = vgg16.preprocess_input(input_image.copy())
with graph_vgg.as_default():
predictions_vgg16 = vgg_model.predict(processed_image_vgg16)
label_vgg16 = decode_predictions(predictions_vgg16)

if (int(label_vgg16[0][0][1] == 'traffic_light') & int(label_vgg16[0][0][2] > 0.7)):
if (int(label_vgg16[0][0][1] == 'traffic_light') & int(label_vgg16[0][0][2] > 0.7)):
# make a prediction
with graph_tl.as_default():
predict = tl_model.predict(input_image)
#predict = tl_model.predict(processed_image_vgg16)

#print('Traffic Light Prediction ', predict[0])
if predict[0][0] > self.threshold:
#print('Classifier Prediction - RED', predict[0][0])
rospy.loginfo(" Classifier Prediction - RED %f", predict[0][0])
return TrafficLight.RED
elif predict[0][1] > self.threshold:
#print('Classifier Prediction - YELLOW', predict[0][1])
rospy.loginfo(" Classifier Prediction - YELLOW %f", predict[0][1])
return TrafficLight.YELLOW
elif predict[0][2] > self.threshold:
#print('Classifier Prediction - YELLOW', predict[0][1])
rospy.loginfo(" Classifier Prediction - GREEN %f", predict[0][2])
return TrafficLight.GREEN
else:
#print('Classifier Prediction - UNKNOWN', label_vgg16[0][0])
#rospy.loginfo(" Classifier Prediction - UNKNOWN [%s, %f]", label_vgg16[0][0][1], label_vgg16[0][0][2])
return TrafficLight.UNKNOWN

else:
# make a prediction
with graph_tl.as_default():
predict = tl_model.predict(np.expand_dims(image224, axis=0))
#predict = tl_model.predict(processed_image_vgg16)

#print('Traffic Light Prediction ', predict[0])
if predict[0][0] > self.threshold:
#print('Classifier Prediction - RED', predict[0][0])
rospy.loginfo(" Classifier Prediction - RED %f", predict[0][0])
return TrafficLight.RED
elif predict[0][1] > self.threshold:
#print('Classifier Prediction - YELLOW', predict[0][1])
rospy.loginfo(" Classifier Prediction - YELLOW %f", predict[0][1])
return TrafficLight.YELLOW
elif predict[0][2] > self.threshold:
#print('Classifier Prediction - YELLOW', predict[0][1])
rospy.loginfo(" Classifier Prediction - GREEN %f", predict[0][2])
return TrafficLight.GREEN
else:
#print('Classifier Prediction - UNKNOWN', label_vgg16[0][0])
rospy.loginfo(" Classifier Prediction - UNKNOWN [%s, %f]", label_vgg16[0][0][1], label_vgg16[0][0][2])
return TrafficLight.UNKNOWN
predict = tl_model.predict(input_image)

if predict[0][0] > self.threshold:
#print('Classifier Prediction - RED', predict[0][0])
rospy.loginfo(" Classifier Prediction - RED %f", predict[0][0])
return TrafficLight.RED
elif predict[0][1] > self.threshold:
#print('Classifier Prediction - YELLOW', predict[0][1])
rospy.loginfo(" Classifier Prediction - YELLOW %f", predict[0][1])
return TrafficLight.YELLOW
elif predict[0][2] > self.threshold:
#print('Classifier Prediction - YELLOW', predict[0][1])
rospy.loginfo(" Classifier Prediction - GREEN %f", predict[0][2])
return TrafficLight.GREEN

rospy.loginfo(" Classifier Prediction - UNKNOWN %f", predict[0][3])
return TrafficLight.UNKNOWN
12 changes: 9 additions & 3 deletions ros/src/tl_detector/tl_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ def __init__(self):
rely on the position of the light and the camera image to predict it.
'''
sub3 = rospy.Subscriber('/vehicle/traffic_lights', TrafficLightArray, self.traffic_cb)
sub6 = rospy.Subscriber('/image_color', Image, self.image_cb)

if (self.is_site):
sub6 = rospy.Subscriber('/image_raw', Image, self.image_cb)
else:
sub6 = rospy.Subscriber('/image_color', Image, self.image_cb)


self.upcoming_red_light_pub = rospy.Publisher('/traffic_waypoint', Int32, queue_size=1)
Expand Down Expand Up @@ -161,8 +165,9 @@ def get_light_state(self, light):
cv_image = self.bridge.imgmsg_to_cv2(self.camera_image, "bgr8")

# To collect camera images
#now_time_str = datetime.now().strftime("%y%m%d%H%M%S%f")
now_time_str = datetime.now().strftime("%y%m%d%H%M%S%f")
#cv2.imwrite('camera_images/'+str(light.state)+'/{0}.jpg'.format(now_time_str), cv_image)
cv2.imwrite('camera_images/'+'/{0}.jpg'.format(now_time_str), cv_image)

#Get classification
if self.light_classifier is None:
Expand All @@ -171,7 +176,8 @@ def get_light_state(self, light):

#self.light_classifier.get_classification(cv_image)
return self.light_classifier.get_classification(cv_image)
# For testing, just return the light state
# For testing, just rerutn the light state
#print('Light State' ,light.state)
#return light.state

def process_traffic_lights(self):
Expand Down

0 comments on commit edb36f2

Please sign in to comment.