Commit
Merge pull request #1 from AlanWang611/ultralytics
Ultralytics
AnthonyYao7 authored Apr 26, 2024
2 parents 194204c + c6a4c67 commit d959d38
Showing 13 changed files with 392 additions and 82 deletions.
2 changes: 1 addition & 1 deletion .idea/SET2023.iml

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion .idea/misc.xml

Some generated files are not rendered by default.

Binary file added computer_vision/best_trash.pt
Binary file not shown.
63 changes: 63 additions & 0 deletions computer_vision/trash_detection.py
@@ -0,0 +1,63 @@
from ultralytics import YOLO
from PIL import Image
from http.server import BaseHTTPRequestHandler, HTTPServer
import io
import cv2
import time
import requests


image_count = 0
model = YOLO('best_trash.pt')

def predict_image(image_path):

# model.predict(image_path, save=True, imgsz=1280, conf=0.25, show_labels=True, show_conf=True, iou=0.5,
# line_width=3)

    # run the model and draw the detections onto the image
im = cv2.imread(image_path)
result = model(im, imgsz=1280, conf=0.15, show_labels=True, show_conf=True, iou=0.5, line_width=3)
annotated_frame = result[0].plot()
# cv2.imshow('Result', annotated_frame)
# if cv2.waitKey(1) & 0xFF==ord("q"):
# return
# cv2.destroyAllWindows()
# cv2.waitKey(0)


# result boxes: all coords: x1, y1, x2, y2
# result_boxes = result[0].boxes.xyxy.cpu().detach().numpy()
return result[0].tojson()


class HTTPRequestHandler(BaseHTTPRequestHandler):

# POST method handler
def do_POST(self):
print("Got post")
global image_count
content_length = int(self.headers['Content-Length'])
post_data = self.rfile.read(content_length)
image_filename = 'images/' + str(image_count) + '.jpg'
image_count += 1
with open(image_filename, 'wb') as f:
f.write(post_data)

self.send_response(200)
self.send_header('Content-type', 'application/json')
self.end_headers()
prediction = predict_image(image_filename)
print(prediction)
self.wfile.write(prediction.encode('utf-8'))


def main():
server_address = ('', 8001)
httpd = HTTPServer(server_address, HTTPRequestHandler)
print('start server')
httpd.serve_forever()


if __name__ == "__main__":
main()
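
For reference, a minimal sketch of how a caller might exercise this endpoint. It mirrors `do_yolo_inference` in main.py; the localhost:8001 address matches the `server_address` above, and the handler assumes an `images/` directory already exists next to the script. The `name` key is an assumption about the Ultralytics JSON output; the `box` keys are the ones main.py reads.

```python
import requests

# POST the raw bytes of a captured frame to the detection server.
with open('frame.jpg', 'rb') as f:
    resp = requests.post('http://localhost:8001', data=f.read())

# The handler returns result[0].tojson(): a JSON list of detections,
# each with a 'box' dict holding x1/y1/x2/y2 pixel coordinates.
for det in resp.json():
    box = det['box']
    print(det.get('name'), box['x1'], box['y1'], box['x2'], box['y2'])
```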
28 changes: 28 additions & 0 deletions computer_vision/trash_detection_client.py
@@ -0,0 +1,28 @@
import cv2, shutil

def capture_image():
# Open the camera
cap = cv2.VideoCapture(0)

# Check if the camera is opened successfully
if not cap.isOpened():
print("Error: Unable to open camera.")
return

# Capture a frame
ret, frame = cap.read()

if ret:
# Save the captured frame as an image
cv2.imwrite("frame.jpg", frame)
# Send frame to trash_detection.py
shutil.copy("frame.jpg", "trash_detection.pg")
else:
print("Error: Unable to capture image.")

# Release the camera
cap.release()

robot_on = True
while robot_on:
capture_image()
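
The comment in capture_image says the frame is meant to go to trash_detection.py, while the code copies it to a local path. A hedged sketch of what a direct HTTP hand-off to the detection server could look like (the localhost:8001 address and the one-second pacing are assumptions, not part of this commit):

```python
import time
import cv2
import requests

def capture_and_send(server_url='http://localhost:8001'):
    # Grab one frame from the default camera and POST it to the detection server.
    cap = cv2.VideoCapture(0)
    try:
        ret, frame = cap.read()
        if not ret:
            print("Error: Unable to capture image.")
            return None
        ok, buf = cv2.imencode('.jpg', frame)  # encode in memory instead of writing frame.jpg
        if not ok:
            return None
        return requests.post(server_url, data=buf.tobytes()).json()
    finally:
        cap.release()

robot_on = True
while robot_on:
    print(capture_and_send())
    time.sleep(1)  # pace the loop instead of spinning on the camera
```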
40 changes: 40 additions & 0 deletions demo_computervision_display/index.html
@@ -0,0 +1,40 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Image Fetch</title>
<style>
body, html {
margin: 0;
padding: 0;
width: 100%;
height: 100%;
overflow: hidden;
}
img {
max-width: 100%;
max-height: 100%;
display: block;
margin: auto;
}
</style>
</head>
<body>
<img id="dynamicImage" src="" alt="Dynamic Image">
<script>
function fetchImage() {
const imageElement = document.getElementById('dynamicImage');
// Update the source with a query string to avoid caching issues
imageElement.src = 'http://localhost:8000/image.jpg?' + new Date().getTime();
}

// Fetch an image every second
setInterval(fetchImage, 1000);

// Fetch the first image immediately on load
window.onload = fetchImage;
</script>
</body>
</html>
46 changes: 46 additions & 0 deletions demo_computervision_display/server.py
@@ -0,0 +1,46 @@
from http.server import BaseHTTPRequestHandler, HTTPServer
import os
from pathlib import Path


class HTTPRequestHandler(BaseHTTPRequestHandler):

    # GET method handler: respond with the most recent captured image
def do_GET(self):
directory = '../computer_vision/images/' # Adjust the path to your image directory

try:
# Get all image files in the directory
files = list(Path(directory).glob('*'))
# Filter files to only get those that are valid images
image_files = [file for file in files if file.suffix.lower() in ['.jpg', '.jpeg', '.png', '.gif', '.bmp']]
print(image_files)

            # Find the most recently created file
            latest_file = max(image_files, key=os.path.getctime, default=None)

            if latest_file is None:
                self.send_error(404, "No image files found.")
                return

            print(latest_file)

            # Open the newest image file and send it
            with open(latest_file, 'rb') as file:
self.send_response(200)
self.send_header('Content-type', 'image/jpeg') # Change if using different image types
self.end_headers()
self.wfile.write(file.read())
except Exception as e:
self.send_error(500, f"Server Error: {e}")


def run(server_class=HTTPServer, handler_class=HTTPRequestHandler, port=8000):
server_address = ('', port)
httpd = server_class(server_address, handler_class)
print(f"Server starting on port {port}...")
httpd.serve_forever()


if __name__ == "__main__":
run()
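
A quick way to exercise the endpoint, assuming the server is running locally on port 8000 and at least one image has been captured into ../computer_vision/images/:

```python
import requests

# The handler ignores the request path and always returns the newest image,
# so any GET against the server works.
r = requests.get('http://localhost:8000/image.jpg', timeout=5)
r.raise_for_status()
with open('latest.jpg', 'wb') as f:
    f.write(r.content)
print(f"Saved {len(r.content)} bytes")
```

index.html above relies on the same behaviour, re-requesting the image once per second with a timestamp query string to defeat caching.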
2 changes: 1 addition & 1 deletion image_listener.py
@@ -2,7 +2,7 @@
# 1/23/2024

import time
-import rospy
+# import rospy
import cv2
from std_msgs.msg import String
from sensor_msgs.msg import Image
154 changes: 139 additions & 15 deletions main.py
@@ -1,29 +1,153 @@
import os
from enum import IntEnum
import time
import shutil
import math

from PIL import Image
import cv2
-import serial
+# import serial
import requests
from pp.distance_measurement import bounding_box_angle, distance_estimation
from pp.config import *


"""
camera
navigation movement
sensors (lidar)
navigation algorithm
e-meet camera
horizontal: 68.61668666 degrees
vertical: 65.68702522 degrees
at 640x480
"""


class RobotState(IntEnum):
SEEKING_TRASH = 0
MOVING_TO_TRASH = 1
AT_TRASH = 2


class TrashTarget:
def __init__(self, horizontal_angle: float, distance: float):
self.horizontal_angle = horizontal_angle
self.distance = distance


class RobotStateVariables:
def __init__(self, robot_state: RobotState, trash_target: TrashTarget):
self.robot = Robot()
self.robot_state = robot_state
self.trash_target = trash_target


class Robot:
def rotate(self, angle_rad: float):
pass

def move_forward(self, distance_meters: float):
pass

def collect_trash(self):
pass


image_count = 0
YOLO_INFERENCE_SERVER_ADDRESS = os.environ.get('YOLO_INFERENCE_SERVER_ADDRESS')
IMAGE_SAVE_DIR = os.environ.get('IMAGE_SAVE_DIR')


def setup():

if not os.path.exists(IMAGE_SAVE_DIR):
os.makedirs(IMAGE_SAVE_DIR)
else:
for filename in os.listdir(IMAGE_SAVE_DIR):
file_path = os.path.join(IMAGE_SAVE_DIR, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
print(f'Failed to delete {file_path}. Reason: {e}')
    # make the image-saving directory if it does not already exist;
    # if it does, clear out its contents
    # TODO: check that the yolo inference server is running

pass


def read_image(cap: cv2.VideoCapture, save_directory):
global image_count
_, frame = cap.read()
im = Image.fromarray(frame)
save_path = save_directory + '/' + str(image_count) + '.jpg'
image_count += 1
im.save(save_path)
return save_path


def do_yolo_inference(image_path: str) -> dict:
with open(image_path, 'rb') as f:
data = f.read()

print("Sending image")
r = requests.post(YOLO_INFERENCE_SERVER_ADDRESS, data=data)
resp = r.json()
return resp


def do_seeking_trash(robot_state: RobotStateVariables, cap: cv2.VideoCapture):
image_path = read_image(cap, IMAGE_SAVE_DIR)
yolo_results = do_yolo_inference(image_path)
if len(yolo_results) == 0:
return
first_piece_of_trash = yolo_results[0]
box = first_piece_of_trash['box']
x1, y1, x2, y2 = box['x1'], box['y1'], box['x2'], box['y2']
p = ((x1 + x2) / 2, y1)
res1 = bounding_box_angle([p], CAMERA_DIMENSIONS_PIXELS, CAMERA_FOV_RADIANS)
res2 = distance_estimation(res1[0], CAMERA_HEIGHT_ABOVE_GROUND_METERS)

robot_state.trash_target = TrashTarget(res1[0][0], math.sqrt(pow(res2[0], 2) + pow(res2[1], 2)))


def do_moving_to_trash(robot_state: RobotStateVariables, cap: cv2.VideoCapture):
if robot_state.trash_target is None:
robot_state.robot_state = RobotState(0)
return

robot_state.robot.rotate(robot_state.trash_target.horizontal_angle)
robot_state.robot.move_forward(robot_state.trash_target.distance)


def do_at_trash(robot_state: RobotStateVariables, cap: cv2.VideoCapture):
robot_state.robot.collect_trash()



handler_functions = {
0: do_seeking_trash,
1: do_moving_to_trash,
2: do_at_trash,
}


def main():
-    ser = serial.Serial('/dev/cu.usbmodem101', 9600)
-
-    try:
-        distance = True
-        while True:
-            data = ser.readline().decode().strip()
-            if distance:
-                print("Distance:", data)
-                distance = not distance
-            else:
-                print("Strength:", data)
-                distance = not distance
-    except KeyboardInterrupt:
-        ser.close()
+    cap = cv2.VideoCapture('/dev/video0')
+    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
+    state = RobotStateVariables(RobotState(0), None)
+    setup()
+
+    # main loop
+    while True:
+        (handler_functions[state.robot_state])(state, cap)
+        time.sleep(1)


if __name__ == "__main__":
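
pp/distance_measurement.py is not included in this diff, so the exact behaviour of `bounding_box_angle` and `distance_estimation` is unknown. The sketch below shows the kind of flat-ground camera geometry they presumably implement, using the constants from pp/config.py; the function names and the linear pixel-to-angle mapping are illustrative assumptions only.

```python
import math

CAMERA_FOV_RADIANS = (68.61668666 * math.pi / 180, 65.68702522 * math.pi / 180)  # horizontal, vertical
CAMERA_DIMENSIONS_PIXELS = (640, 480)
CAMERA_HEIGHT_ABOVE_GROUND_METERS = 0.084

def pixel_to_angles(px, py, dims=CAMERA_DIMENSIONS_PIXELS, fov=CAMERA_FOV_RADIANS):
    """Approximate (horizontal, vertical) angles of a pixel off the optical axis.

    A linear pixel-offset-to-angle mapping; a true pinhole model would take the
    arctangent of offset over focal length, but for these FOVs the two are close.
    """
    w, h = dims
    horizontal = (px - w / 2) / w * fov[0]  # positive to the right of centre
    vertical = (py - h / 2) / h * fov[1]    # positive below centre (image y grows downward)
    return horizontal, vertical

def ground_distance(vertical_angle, camera_height=CAMERA_HEIGHT_ABOVE_GROUND_METERS):
    """Distance along flat ground to a point seen below the horizon."""
    return camera_height / math.tan(vertical_angle)  # requires vertical_angle > 0

# Example: a detection centred at pixel (400, 420) in a 640x480 frame.
h_ang, v_ang = pixel_to_angles(400, 420)
print(f"bearing {math.degrees(h_ang):.1f} deg, range {ground_distance(v_ang):.2f} m")
```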
2 changes: 1 addition & 1 deletion periodic_image_capture.py
@@ -4,7 +4,7 @@
import cv2
import os
import time
-import rospy
+# import rospy
from sensor_msgs.msg import Image
from cv_bridge import CvBridge

4 changes: 2 additions & 2 deletions pp/config.py
@@ -1,6 +1,6 @@
import math


-CAMERA_FOV_RADIANS = (52.31 * math.pi / 180, 51.83 * math.pi / 180)
+CAMERA_FOV_RADIANS = (68.61668666 * math.pi / 180, 65.68702522 * math.pi / 180)  # horizontal, vertical
CAMERA_DIMENSIONS_PIXELS = (640, 480)
-CAMERA_HEIGHT_ABOVE_GROUND_METERS = 0.263
+CAMERA_HEIGHT_ABOVE_GROUND_METERS = .084
