-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from Vykp00/1-create-data-prepossessing-pipelin…
…e-for-raw-data 1 create data prepossessing pipeline for raw data
- Loading branch information
Showing
4 changed files
with
224 additions
and
81 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import csv | ||
import json | ||
from collections import defaultdict | ||
|
||
# Path to the CSV file | ||
csv_file = 'walking_sample.csv' | ||
|
||
# JSON structure initialization | ||
result = defaultdict(list) | ||
|
||
# Read and process the CSV file | ||
with open(csv_file, 'r') as file: | ||
csv_reader = csv.DictReader(file) | ||
for row in csv_reader: | ||
result["timestamp"].append(int(row["timestamp"])) | ||
result["gazepoint_x"].append(float(row["gazepoint_x"])) | ||
result["gazepoint_y"].append(float(row["gazepoint_y"])) | ||
result["pupil_area_right_sq_mm"].append(float(row["pupil_area_right_sq_mm"])) | ||
result["pupil_area_left_sq_mm"].append(float(row["pupil_area_left_sq_mm"])) | ||
result["eye_event"].append(row["eye_event"].strip()) | ||
|
||
# Convert the defaultdict to a regular dictionary | ||
result = dict(result) | ||
|
||
# Save the JSON to a file (optional) or print it | ||
output_json = 'walking_sample.json' | ||
with open(output_json, 'w') as json_file: | ||
json.dump(result, json_file, indent=4) | ||
|
||
print(json.dumps(result, indent=4)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,85 +1,92 @@ | ||
from typing import Any, Dict, List, Optional | ||
import traceback | ||
from typing import List | ||
|
||
import uvicorn | ||
import ydf | ||
from fastapi import FastAPI, status, Response | ||
from pydantic import BaseModel | ||
import ydf | ||
import math | ||
|
||
from starlette.responses import JSONResponse | ||
|
||
from utils import logger, preprocess_data | ||
|
||
app = FastAPI() | ||
|
||
model = ydf.load_model("model") | ||
label_classes = model.label_classes() | ||
logger.info(f"Current label_classes: {label_classes}") | ||
|
||
# Label mapping | ||
label_mapping = { | ||
"1": "walking", | ||
"3": "playing", | ||
"2": "reading" | ||
} | ||
|
||
|
||
class Example(BaseModel): | ||
timestamp: float = math.nan | ||
gazepoint_x: float = math.nan | ||
gazepoint_y: float = math.nan | ||
pupil_area_right_sq_mm: float = math.nan | ||
pupil_area_left_sq_mm: float = math.nan | ||
eye_event: str = "" | ||
euclidean_distance: Optional[float] = None | ||
prev_euclidean_distance: Optional[float] = None # Allow None as a valid value | ||
class DataBatches(BaseModel): | ||
timestamp: List[float] = [] | ||
gazepoint_x: List[float] = [] | ||
gazepoint_y: List[float] = [] | ||
pupil_area_right_sq_mm: List[float] = [] | ||
pupil_area_left_sq_mm: List[float] = [] | ||
eye_event: List[str] = [] | ||
|
||
|
||
class Output(BaseModel): | ||
predictions: List[List[float]] | ||
label_classes: List[str] | ||
prev_euclidean_distance: Optional[List[float]] = None # Return prev_euclidean_distance to run interference | ||
walking: float | ||
playing: float | ||
reading: float | ||
process_data: int # Amount of processed data | ||
|
||
|
||
@app.get('/hello', status_code=status.HTTP_200_OK) | ||
def hello_world(response: Response): | ||
return {'Welcome to SeeTrue AI!': "data"} | ||
|
||
|
||
@app.post("/predict") | ||
async def predict(examples: List[Example]): | ||
processed_batches = [] | ||
prev_euclidean_distances = None | ||
|
||
for example in examples: | ||
# Handle prev_euclidean_distance logic based on eye_event and existing value | ||
if example.eye_event == "FE": | ||
if example.prev_euclidean_distance is None: | ||
example.prev_euclidean_distance = example.euclidean_distance | ||
else: | ||
example.prev_euclidean_distance = example.euclidean_distance | ||
elif example.eye_event == "FB": | ||
if example.prev_euclidean_distance is None: | ||
example.euclidean_distance = 1.0 | ||
example.prev_euclidean_distance = example.euclidean_distance | ||
|
||
# For other eye_event values, prev_euclidean_distance is not modified | ||
|
||
# Wrap the example features into a batch, excluding prev_euclidean_distance | ||
processed_batches.append({ | ||
k: v for k, v in example.model_dump().items() if k != "prev_euclidean_distance" | ||
}) | ||
prev_euclidean_distances = example.prev_euclidean_distance | ||
# For other eye_event values, prev_euclidean_distance is not modified | ||
|
||
# Transpose the batch for model input | ||
example_batch: Dict[str, List[Any]] = { | ||
key: [batch[key] for batch in processed_batches] for key in processed_batches[0] | ||
} | ||
print(example_batch) | ||
print("Previous Euclidean distances:", prev_euclidean_distances) | ||
# Perform prediction | ||
prediction_batch = model.predict(example_batch).tolist() | ||
|
||
# Return the prediction along with the updated prev_euclidean_distance | ||
response = { | ||
"predictions": prediction_batch, | ||
"label_classes": label_classes, | ||
"prev_euclidean_distance": prev_euclidean_distances | ||
} | ||
return JSONResponse(content=response, status_code=status.HTTP_200_OK) | ||
|
||
# | ||
# @app.post("/predict_batch") | ||
# async def predict_batch(example_batch): | ||
# return model.predict(example_batch).tolist() | ||
async def predict(payload: DataBatches): | ||
try: | ||
# Preprocess the payload into individual records | ||
data = payload.model_dump() | ||
processed_data = preprocess_data(data) | ||
process_data = len(processed_data) | ||
# Aggregate processed data into a single batch input | ||
batch_input = { | ||
key: [record[key] for record in processed_data if key != "prev_euclidean_distance"] | ||
for key in processed_data[0] if key != "prev_euclidean_distance" | ||
} | ||
prediction_batch = model.predict(batch_input).tolist() | ||
|
||
# Calculate means for each label class | ||
means = {label_mapping[label]: 0.0 for label in label_mapping.keys()} | ||
counts = {label_mapping[label]: 0 for label in label_mapping.keys()} | ||
|
||
for prediction in prediction_batch: | ||
for i, value in enumerate(prediction): | ||
class_label = label_classes[i] | ||
mapped_label = label_mapping[class_label] | ||
means[mapped_label] += value | ||
counts[mapped_label] += 1 | ||
|
||
# Calculate averages of each prediction | ||
for label in means: | ||
if counts[label] > 0: | ||
means[label] /= counts[label] | ||
|
||
# Construct response | ||
response = Output( | ||
walking=means["walking"], | ||
playing=means["playing"], | ||
reading=means["reading"], | ||
process_data=process_data | ||
) | ||
return response | ||
except Exception as e: | ||
trace_back_msg = traceback.format_exc() | ||
logger.error(f"{str(e)} \n {trace_back_msg}") | ||
return JSONResponse(content={"Error": str(e)}, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) | ||
|
||
|
||
if __name__ == '__main__': | ||
uvicorn.run("main:app", host="0.0.0.0", port=8080, workers=1, access_log=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import logging | ||
import traceback | ||
from typing import Optional | ||
|
||
import numpy as np | ||
import re | ||
|
||
# Set up logging to both file and console | ||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s') | ||
logger = logging.getLogger() | ||
""" | ||
Utils function for feature engineering and preprocessing the data | ||
""" | ||
# Regular expression to extract x, y, and d for distance calculation | ||
pattern = r'FEx(?P<x>[-+]?\d*\.\d+)y(?P<y>[-+]?\d*\.\d+)d(?P<d>[-+]?\d*\.\d+)' | ||
|
||
|
||
def calculate_euclidean_distance(eye_event: str) -> Optional[float]: | ||
"""Calculate the Euclidean distance from the eye_event field.""" | ||
try: | ||
match = re.match(pattern, eye_event) | ||
if match: | ||
x = float(match.group('x')) | ||
y = float(match.group('y')) | ||
d = float(match.group('d')) | ||
|
||
# Calculate the Euclidean Distance | ||
F = round(np.sqrt(x ** 2 + y ** 2) * d, 4) | ||
return F | ||
return None | ||
except Exception as e: | ||
trace_back_msg = traceback.format_exc() | ||
logger.error(f"Error while calculating euclidean distance : {str(e)} \n {trace_back_msg}") | ||
raise | ||
|
||
|
||
def remove_na_row(payload: dict) -> dict: | ||
""" | ||
Remove rows where 'eye_event' is 'NA' from the JSON-like payload. | ||
Args: | ||
payload (dict): The JSON-like batch payload. | ||
Returns: | ||
dict: The filtered payload with rows where 'eye_event' is 'NA' removed. | ||
""" | ||
try: | ||
# Determine the indices of rows to keep (where 'eye_event' is not 'NA') | ||
valid_indices = [ | ||
i for i, event in enumerate(payload["eye_event"]) if event.strip() != "NA" | ||
] | ||
|
||
# Filter the payload by keeping only the valid indices | ||
filtered_payload = {key: [values[i] for i in valid_indices] for key, values in payload.items()} | ||
|
||
removed_rows = len(payload["eye_event"]) - len(filtered_payload["eye_event"]) | ||
# print(f"Removed {removed_rows} rows with 'NA' in 'eye_event'.") | ||
|
||
return filtered_payload | ||
except Exception as e: | ||
trace_back_msg = traceback.format_exc() | ||
logger.error(f"Error during NA removal : {str(e)} \n {trace_back_msg}") | ||
# print(f"Error during NA removal: {e}") | ||
# return payload | ||
raise | ||
|
||
|
||
def preprocess_data(payload: dict) -> list: | ||
""" | ||
Process the incoming batch payload into individual records. | ||
Args: | ||
payload (dict): The JSON-like batch payload. | ||
Returns: | ||
list: A list of individual processed records. | ||
""" | ||
try: | ||
# Step 1: Remove rows with 'NA' in 'eye_event' | ||
payload = remove_na_row(payload) | ||
|
||
prev_euclidean_distance = None # Initialize previous distance | ||
processed_data = [] | ||
|
||
# Step 2: Iterate through the filtered records | ||
for i in range(len(payload['timestamp'])): | ||
record = { | ||
"timestamp": payload["timestamp"][i], | ||
"gazepoint_x": payload["gazepoint_x"][i], | ||
"gazepoint_y": payload["gazepoint_y"][i], | ||
"pupil_area_right_sq_mm": payload["pupil_area_right_sq_mm"][i], | ||
"pupil_area_left_sq_mm": payload["pupil_area_left_sq_mm"][i], | ||
"eye_event": payload["eye_event"][i], | ||
} | ||
|
||
# Step 3: Calculate Euclidean Distance | ||
record["euclidean_distance"] = calculate_euclidean_distance(record["eye_event"]) | ||
|
||
# Step 4: Handle prev_euclidean_distance | ||
if record["euclidean_distance"] is None: | ||
if record["eye_event"] in ["S", "BB", "BE"]: | ||
record["euclidean_distance"] = 0.0 | ||
elif record["eye_event"] == "FB": | ||
record["euclidean_distance"] = prev_euclidean_distance or 1.0 | ||
|
||
# Set prev_euclidean_distance for the next record | ||
record["prev_euclidean_distance"] = prev_euclidean_distance | ||
prev_euclidean_distance = record["euclidean_distance"] | ||
|
||
# Step 5: Append the fully processed record | ||
processed_data.append(record) | ||
|
||
return processed_data | ||
except Exception as e: | ||
trace_back_msg = traceback.format_exc() | ||
logger.error(f"Error processing data : {str(e)} \n {trace_back_msg}") | ||
# print(f"Error Error processing data: {e}") | ||
raise |