Skip to content

Commit

Permalink
refactored gui graphs with data processor
Browse files Browse the repository at this point in the history
  • Loading branch information
SamedVossberg committed Dec 5, 2024
1 parent bd76651 commit 87be269
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 170 deletions.
156 changes: 106 additions & 50 deletions gui_dev/data_processor/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,81 +1,137 @@
use wasm_bindgen::prelude::*;
use serde_cbor::Value;
use serde_wasm_bindgen::to_value;
use serde_wasm_bindgen::{from_value, Serializer};
use serde::Serialize;
use std::collections::{BTreeMap, BTreeSet};
use std::collections::{BTreeMap};
use web_sys::console;

#[wasm_bindgen]
pub fn process_cbor_data(data: &[u8]) -> JsValue {
match serde_cbor::from_slice::<BTreeMap<String, Value>>(data) {
Ok(decoded_data) => {
let mut data_by_channel: BTreeMap<String, ChannelData> = BTreeMap::new();
let mut all_features_set: BTreeSet<u32> = BTreeSet::new();
pub fn process_cbor_data(data: &[u8], channels_js: JsValue) -> JsValue {
// Deserialize channels_js into Vec<String>
let channels: Vec<String> = match from_value(channels_js) {
Ok(c) => c,
Err(err) => {
console::error_1(&format!("Failed to parse channels: {:?}", err).into());
return JsValue::NULL;
}
};

for (key, value) in decoded_data {
let (channel_name, feature_name) = get_channel_and_feature(&key);
match serde_cbor::from_slice::<Value>(data) {
Ok(decoded_value) => {
console::log_1(&format!("Decoded value: {:?}", decoded_value).into());
if let Value::Map(decoded_map) = decoded_value {
// create output data structures for each graph
let mut psd_data_by_channel: BTreeMap<String, ChannelData> = BTreeMap::new();
let mut raw_data_by_channel: BTreeMap<String, Value> = BTreeMap::new();
let mut bandwidth_data_by_channel: BTreeMap<String, BTreeMap<String, Value>> = BTreeMap::new();
let mut all_data: BTreeMap<String, Value> = BTreeMap::new();

if channel_name.is_empty() {
continue;
}
let bandwidth_features = vec![
"fft_theta_mean",
"fft_alpha_mean",
"fft_low_beta_mean",
"fft_high_beta_mean",
"fft_low_gamma_mean",
"fft_high_gamma_mean",
];

if !feature_name.starts_with("fft_psd_") {
continue;
}
for (key_value, value) in decoded_map {
let key_str = match key_value {
Value::Text(s) => s,
_ => continue,
};

let feature_number = &feature_name["fft_psd_".len()..];
let feature_index = match feature_number.parse::<u32>() {
Ok(n) => n,
Err(_) => continue,
};
// Insert into all_data
all_data.insert(key_str.clone(), value.clone());

all_features_set.insert(feature_index);
let (channel_name, feature_name) =
get_channel_and_feature(&key_str, &channels);

let channel_data = data_by_channel
.entry(channel_name.clone())
.or_insert_with(|| ChannelData {
channel_name: channel_name.clone(),
feature_map: BTreeMap::new(),
});
if channel_name.is_empty() {
continue;
}

channel_data.feature_map.insert(feature_index, value);
}
if feature_name == "raw" {
raw_data_by_channel.insert(channel_name.clone(), value.clone());
} else if feature_name.starts_with("fft_psd_") {
let feature_number = &feature_name["fft_psd_".len()..];
let feature_index = match feature_number.parse::<u32>() {
Ok(n) => n,
Err(_) => continue,
};

let feature_index_str = feature_index.to_string();

let all_features: Vec<u32> = all_features_set.into_iter().collect();
let channel_data = psd_data_by_channel
.entry(channel_name.clone())
.or_insert_with(|| ChannelData {
channel_name: channel_name.clone(),
feature_map: BTreeMap::new(),
});

let result = ProcessedData {
data_by_channel,
all_features,
};
channel_data
.feature_map
.insert(feature_index_str, value.clone());
} else if bandwidth_features.contains(&feature_name.as_str()) {

to_value(&result).unwrap_or(JsValue::NULL)
let channel_bandwidth_data = bandwidth_data_by_channel
.entry(channel_name.clone())
.or_insert_with(BTreeMap::new);

channel_bandwidth_data.insert(feature_name.clone(), value.clone());
}
}

let result = ProcessedData {
psd_data_by_channel,
raw_data_by_channel,
bandwidth_data_by_channel,
all_data,
};

// Serialize maps as plain JavaScript objects
let serializer = Serializer::new().serialize_maps_as_objects(true);
match result.serialize(&serializer) {
Ok(js_value) => js_value,
Err(err) => {
console::error_1(&format!("Serialization error: {:?}", err).into());
JsValue::NULL
}
}
} else {
console::error_1(&"Decoded CBOR data is not a map.".into());
JsValue::NULL
}
}
Err(e) => {
// Optionally log the error for debugging
Err(err) => {
console::error_1(&format!("Failed to decode CBOR data: {:?}", err).into());
JsValue::NULL
},
}
}
}

fn get_channel_and_feature(key: &str) -> (String, String) {
// Adjusted to split at the "_fft_psd_" pattern
let pattern = "_fft_psd_";
if let Some(pos) = key.find(pattern) {
let channel_name = &key[..pos];
let feature_name = &key[pos + 1..]; // Skip the underscore
(channel_name.to_string(), feature_name.to_string())
} else {
("".to_string(), key.to_string())
fn get_channel_and_feature(key: &str, channels: &[String]) -> (String, String) {
// Iterate over channels to find if the key starts with any channel name
for channel in channels {
if key.starts_with(channel) {
let feature_name = key[channel.len()..].trim_start_matches('_');
return (channel.clone(), feature_name.to_string());
}
}
// No matching channel found
("".to_string(), key.to_string())
}

#[derive(Serialize)]
struct ChannelData {
channel_name: String,
feature_map: BTreeMap<u32, Value>,
feature_map: BTreeMap<String, Value>,
}

#[derive(Serialize)]
struct ProcessedData {
data_by_channel: BTreeMap<String, ChannelData>,
all_features: Vec<u32>,
psd_data_by_channel: BTreeMap<String, ChannelData>,
raw_data_by_channel: BTreeMap<String, Value>,
bandwidth_data_by_channel: BTreeMap<String, BTreeMap<String, Value>>,
all_data: BTreeMap<String, Value>,
}
49 changes: 15 additions & 34 deletions gui_dev/src/components/BandPowerGraph.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import {
FormControlLabel,
} from "@mui/material";
import { CollapsibleBox } from "./CollapsibleBox";
import { getChannelAndFeature } from "./utils";
import { shallow } from "zustand/shallow";

const generateColors = (numColors) => {
Expand Down Expand Up @@ -47,45 +46,27 @@ export const BandPowerGraph = () => {
const [selectedChannel, setSelectedChannel] = useState("");
const hasInitialized = useRef(false);

const socketData = useSocketStore((state) => state.graphData);
const processedData = useSocketStore((state) => state.processedData);

const data = useMemo(() => {
if (!socketData || !selectedChannel) return null;
const dataByChannel = {};
if (!processedData || !selectedChannel) return null;

Object.entries(socketData).forEach(([key, value]) => {
const { channelName = "", featureName = "" } = getChannelAndFeature(
availableChannels,
key
);
if (!channelName) return;

if (!fftFeatures.includes(featureName)) return;

if (!dataByChannel[channelName]) {
dataByChannel[channelName] = {
channelName,
features: [],
values: [],
};
}
const bandwidthDataByChannel = processedData.bandwidth_data_by_channel;

dataByChannel[channelName].features.push(featureName);
dataByChannel[channelName].values.push(value);
});

const channelData = dataByChannel[selectedChannel];
const channelData = bandwidthDataByChannel[selectedChannel];
if (channelData) {
const sortedValues = fftFeatures.map((feature) => {
const index = channelData.features.indexOf(feature);
return index !== -1 ? channelData.values[index] : null;
const features = fftFeatures.map((f) =>
f.replace("_mean", "").replace("fft_", "")
);
const values = fftFeatures.map((feature) => {
const value = channelData[feature];
return value !== undefined ? value : null;
});

return {
channelName: selectedChannel,
features: fftFeatures.map((f) =>
f.replace("_mean", "").replace("fft_", "")
),
values: sortedValues,
features,
values,
};
} else {
return {
Expand All @@ -96,7 +77,7 @@ export const BandPowerGraph = () => {
values: fftFeatures.map(() => null),
};
}
}, [socketData, selectedChannel, availableChannels]);
}, [processedData, selectedChannel]);

const graphRef = useRef(null);
const plotlyRef = useRef(null);
Expand Down Expand Up @@ -169,7 +150,7 @@ export const BandPowerGraph = () => {
<Typography variant="h6" sx={{ flexGrow: 1 }}>
Band Power
</Typography>
<Box sx={{ ml: 2, mr:4, minWidth: 200 }}>
<Box sx={{ ml: 2, mr: 4, minWidth: 200 }}>
<CollapsibleBox title="Channel Selection" defaultExpanded={true}>
<Box display="flex" flexDirection="column">
<RadioGroup
Expand Down
27 changes: 14 additions & 13 deletions gui_dev/src/components/HeatmapGraph.jsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import React, { useEffect, useState, useRef, useMemo } from 'react';
import { useSocketStore, useSessionStore } from '@/stores';
import { useSocketStore } from '@/stores/socketStore';
import { useSessionStore } from '@/stores/sessionStore';
import Plot from 'react-plotly.js';
import {
Box,
Expand All @@ -12,7 +13,6 @@ import {
Slider,
} from '@mui/material';
import { CollapsibleBox } from './CollapsibleBox';
import { getChannelAndFeature } from './utils';
import { shallow } from 'zustand/shallow';

export const HeatmapGraph = () => {
Expand All @@ -37,7 +37,7 @@ export const HeatmapGraph = () => {
const [isDataStale, setIsDataStale] = useState(false);
const [lastDataTime, setLastDataTime] = useState(null);

const graphData = useSocketStore((state) => state.graphData);
const processedData = useSocketStore((state) => state.processedData);

const [maxDataPoints, setMaxDataPoints] = useState(100);

Expand Down Expand Up @@ -83,9 +83,10 @@ export const HeatmapGraph = () => {
}, [usedChannels, selectedChannel]);

useEffect(() => {
if (!graphData || !selectedChannel) return;
if (!processedData || !selectedChannel) return;

const dataKeys = Object.keys(graphData);
const allData = processedData.all_data || {};
const dataKeys = Object.keys(allData);
const channelPrefix = `${selectedChannel}_`;
const featureKeys = dataKeys.filter(
(key) => key.startsWith(channelPrefix) && key !== 'time'
Expand Down Expand Up @@ -113,17 +114,18 @@ export const HeatmapGraph = () => {
setIsDataStale(false);
setLastDataTime(null);
}
}, [graphData, selectedChannel, features]);
}, [processedData, selectedChannel]);

useEffect(() => {
if (
!graphData ||
!processedData ||
!selectedChannel ||
features.length === 0 ||
selectedFeatures.length === 0
)
return;

const allData = processedData.all_data || {};
setLastDataTime(Date.now());
setIsDataStale(false);

Expand All @@ -137,7 +139,7 @@ export const HeatmapGraph = () => {

selectedFeatures.forEach((featureName, idx) => {
const key = `${selectedChannel}_${featureName}`;
const value = graphData[key];
const value = allData[key];
const numericValue = typeof value === 'number' && !isNaN(value) ? value : 0;

// Shift existing data to the left if necessary
Expand All @@ -155,7 +157,7 @@ export const HeatmapGraph = () => {

setHeatmapData({ x, z });
}, [
graphData,
processedData,
selectedChannel,
features,
selectedFeatures,
Expand Down Expand Up @@ -183,13 +185,12 @@ export const HeatmapGraph = () => {
plot_bgcolor: '#333',
autosize: true,
xaxis: {
title: { text: 'Nr. of Samples', font: { color: '#f4f4f4' } },
title: { text: 'Number of Samples', font: { color: '#f4f4f4' } },
color: '#cccccc',
tickfont: {
color: '#cccccc',
},
automargin: false,
// autorange: 'reversed'
},
yaxis: {
title: { text: 'Features', font: { color: '#f4f4f4' } },
Expand Down Expand Up @@ -300,8 +301,8 @@ export const HeatmapGraph = () => {
]}
layout={layout}
useResizeHandler={true}
style={{ width: '100%', height: '100%'}}
config={{ responsive: true, displayModeBar: false}}
style={{ width: '100%', height: '100%' }}
config={{ responsive: true, displayModeBar: false }}
/>
)}
</Box>
Expand Down
Loading

0 comments on commit 87be269

Please sign in to comment.