"
]
diff --git a/nbs/15_timeseries_plots.ipynb b/nbs/15_timeseries_plots.ipynb
new file mode 100644
index 0000000..0351316
--- /dev/null
+++ b/nbs/15_timeseries_plots.ipynb
@@ -0,0 +1,921 @@
+{
+ "cells": [
+ {
+ "cell_type": "raw",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "output-file: timeseries_plots.html\n",
+ "title: Time series plots\n",
+ "\n",
+ "---"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| default_exp timeseries_plots"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "from nbdev.showdoc import *"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "from typing import Callable, Iterable, Optional, Union, Tuple\n",
+ "import warnings\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "\n",
+ "import seaborn as sns\n",
+ "import matplotlib.pyplot as plt\n",
+ "import matplotlib.dates as mdates\n",
+ "\n",
+ "from pheno_utils.config import DEFAULT_PALETTE, TIME_FORMAT, LEGEND_SHIFT"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "class TimeSeriesFigure:\n",
+ " def __init__(self, figsize: tuple = (10, 6), padding: float = 0.05):\n",
+ " \"\"\"\n",
+ " Initialize a TimeSeriesFigure instance. This class is used to create and manage\n",
+ " a figure with multiple axes for time series data.\n",
+ " \n",
+ " Args:\n",
+ " figsize (tuple): Size of the figure (width, height) in inches.\n",
+ " \"\"\"\n",
+ " self.fig = plt.figure(figsize=figsize)\n",
+ " self.axes: Iterable[tuple] = []\n",
+ " self.axis_names: dict = {}\n",
+ " self.padding = padding\n",
+ " self.custom_paddings = {} # To store custom padding for specific axes\n",
+ " self.shared_x_groups = [] # To keep track of shared x-axis groups\n",
+ "\n",
+ " def plot(\n",
+ " self, \n",
+ " plot_function: Callable, \n",
+ " *args, \n",
+ " n_axes: int = 1, \n",
+ " height: float = 1, \n",
+ " sharex: Union[str, int, plt.Axes] = None, \n",
+ " second_y: bool = False,\n",
+ " name: str = None, \n",
+ " ax: Union[str, int, plt.Axes] = None, \n",
+ " adjust_time: Optional[str] = 'union',\n",
+ " adjust_by_axis: Union[str, int, plt.Axes, Iterable[Union[str, int, plt.Axes]]] = None,\n",
+ " **kwargs\n",
+ " ) -> Union[plt.Axes, Iterable[plt.Axes]]:\n",
+ " \"\"\"\n",
+ " Plot using a dataset-specific function, creating a new axis if needed.\n",
+ " The plot function should accept the axis object as the argument `ax`, or\n",
+ " a list of axes if multiple axes are used.\n",
+ " \n",
+ " Args:\n",
+ " plot_function (Callable): The dataset-specific function to plot the data.\n",
+ " *args: Arguments to pass to the plot function.\n",
+ " n_axes (int): The number of axes required. Default is 1.\n",
+ " height (float): The proportional height of the axes relative to a single unit axis.\n",
+ " sharex (str, int, or plt.Axes): Index or name of the axis to share the x-axis with. If None, the x-axis is independent.\n",
+ " second_y (bool): If True, the plot is drawn on a secondary y-axis of the target axis. Default is False.\n",
+ " name (str): Name or ID to assign to the axis.\n",
+ " ax (plt.Axes, str, int): Pre-existing axis (object, name, or index) or list of axes to plot on.\n",
+ " adjust_time (str, None): Method to adjust the time limits of all axes to match the data.\n",
+ " adjust_by_axis (str, int, plt.Axes): Axes (single or multiple) to use as a reference for adjusting the time limits.\n",
+ " **kwargs: Keyword arguments to pass to the plot function.\n",
+ " \n",
+ " Returns:\n",
+ " Union[plt.Axes, Iterable[plt.Axes]]: A single axis object or a list of axis objects if multiple axes are used.\n",
+ " \"\"\"\n",
+ " if ax is None:\n",
+ " ax = self.add_axes(height=height, n_axes=n_axes, sharex=sharex, name=name)\n",
+ " else:\n",
+ " ax = self.get_axes(ax, squeeze=True)\n",
+ "\n",
+ " if second_y:\n",
+ " ax.yaxis.grid(False)\n",
+ " ax = ax.twinx()\n",
+ "\n",
+ " plot_function(*args, ax=ax, **kwargs)\n",
+ " if adjust_time:\n",
+ " self.set_time_limits(None, None, method=adjust_time, reference_axis=adjust_by_axis)\n",
+ " if second_y:\n",
+ " ax.yaxis.grid(False)\n",
+ " ax.yaxis.label.set_rotation(90)\n",
+ " ax.yaxis.label.set_ha('center')\n",
+ "\n",
+ " return ax\n",
+ "\n",
+ " def add_axes(\n",
+ " self, \n",
+ " height: float = 1, \n",
+ " n_axes: int = 1, \n",
+ " sharex: Optional[Union[str, int, plt.Axes]] = None, \n",
+ " name: Optional[str] = None,\n",
+ " ) -> Union[plt.Axes, Iterable[plt.Axes]]:\n",
+ " \"\"\"\n",
+ " Add one or more axes with a specific proportional height to the figure.\n",
+ " \n",
+ " Args:\n",
+ " height (float): The proportional height of each new axis relative to a single unit axis.\n",
+ " n_axes (int): The number of axes to create.\n",
+ " sharex (str, int, or plt.Axes): Index or name of the axis to share the x-axis with. If None, the x-axis is independent.\n",
+ " name (Optional[str]): Name or ID to assign to the new axes; when n_axes > 1 the name refers to all of them.\n",
+ " \n",
+ " Returns:\n",
+ " Union[plt.Axes, Iterable[plt.Axes]]: A single axis object or a list of axis objects if multiple axes are created.\n",
+ " \"\"\"\n",
+ " new_axes = []\n",
+ " shared_group = []\n",
+ " \n",
+ " if sharex is not None:\n",
+ " sharex = self.get_axes(sharex)[0]\n",
+ " shared_group.append(sharex)\n",
+ "\n",
+ " for _ in range(n_axes):\n",
+ " ax = self.fig.add_subplot(len(self.axes) + 1, 1, len(self.axes) + 1, sharex=sharex)\n",
+ " new_axes.append(ax)\n",
+ " self.axes.append((ax, height))\n",
+ " shared_group.append(ax)\n",
+ " # When creating multiple axes, always share their x-axis with the first one\n",
+ " if sharex is None:\n",
+ " sharex = ax\n",
+ " \n",
+ " if shared_group:\n",
+ " self.shared_x_groups.append(shared_group)\n",
+ "\n",
+ " if name is not None:\n",
+ " self.axis_names[name] = new_axes\n",
+ " \n",
+ " self._adjust_axes()\n",
+ "\n",
+ " return new_axes if n_axes > 1 else new_axes[0]\n",
+ "\n",
+ " def _adjust_axes(self) -> None:\n",
+ " \"\"\"\n",
+ " Adjust the positions and sizes of all axes based on their proportional height and apply padding.\n",
+ " \"\"\"\n",
+ " total_height = sum(height for _, height in self.axes)\n",
+ " total_padding = self.padding * (len(self.axes) - 1)\n",
+ " bottom = 1 - total_padding # Start from the top of the figure\n",
+ "\n",
+ " for i, (ax, height) in enumerate(self.axes):\n",
+ " ax_height = height / total_height * (1 - total_padding)\n",
+ " # Adjust for any custom padding before this axis\n",
+ " custom_pad = self.custom_paddings.get(i, 0)\n",
+ " ax.set_position([0.1, bottom - ax_height, 0.8, ax_height])\n",
+ " bottom -= ax_height + self.padding + custom_pad # Move down, considering padding\n",
+ "\n",
+ " def _get_axis_by_name(self, name: str) -> Optional[plt.Axes]:\n",
+ " \"\"\"\n",
+ " Retrieve an axis by its name or ID.\n",
+ " \n",
+ " Args:\n",
+ " name (str): The name or ID of the axis to retrieve.\n",
+ " \n",
+ " Returns:\n",
+ " list: The axes registered under this name, or an empty list if the name is unknown.\n",
+ " \"\"\"\n",
+ " return self.axis_names.get(name, [])\n",
+ "\n",
+ " def get_axes(self, ax: Union[str, int, plt.Axes, Iterable[Union[str, int, plt.Axes]]]=None, squeeze=False) -> Iterable[plt.Axes]:\n",
+ " \"\"\"\n",
+ " Retrieve the axis object(s) based on the input type.\n",
+ "\n",
+ " Args:\n",
+ " ax: The axis object, index, name, or list of those to retrieve.\n",
+ " squeeze (bool): Whether to return a single axis object if only one is found.\n",
+ " \n",
+ " Returns:\n",
+ " Iterable[plt.Axes]: A list of axis objects, or a single axis object if squeeze is True and exactly one is found.\n",
+ " \"\"\"\n",
+ " if ax is None:\n",
+ " return [a for a, _ in self.axes]\n",
+ " elif not isinstance(ax, list):\n",
+ " ax = [ax]\n",
+ " \n",
+ " ax_list = []\n",
+ " for a in ax:\n",
+ " if isinstance(a, str):\n",
+ " by_name = self._get_axis_by_name(a)\n",
+ " if len(by_name) == 0:\n",
+ " warnings.warn(f\"No axis found with name '{a}'\")\n",
+ " ax_list.extend(by_name)\n",
+ " elif isinstance(a, int):\n",
+ " ax_list.append(self.axes[a][0])\n",
+ "\n",
+ " if squeeze and len(ax_list) == 1:\n",
+ " return ax_list[0]\n",
+ " else:\n",
+ " return ax_list\n",
+ "\n",
+ " def print_shared_axes(self):\n",
+ " \"\"\"\n",
+ " Print which axes in the figure share their x-axis.\n",
+ "\n",
+ " Returns:\n",
+ " None\n",
+ " \"\"\"\n",
+ " shared_groups = {}\n",
+ " for i, (ax, _) in enumerate(self.axes):\n",
+ " for j, (other_ax, _) in enumerate(self.axes):\n",
+ " if i != j and ax.get_shared_x_axes().joined(ax, other_ax):\n",
+ " if i not in shared_groups:\n",
+ " shared_groups[i] = []\n",
+ " shared_groups[i].append(j)\n",
+ "\n",
+ " for ax_idx, shared_with in shared_groups.items():\n",
+ " print(f\"Axis {ax_idx} shares its x-axis with: {shared_with}\")\n",
+ "\n",
+ " def get_axis_properties(self, ax: Union[str, int, plt.Axes, Iterable[Union[str, int, plt.Axes]]]=None) -> dict:\n",
+ " \"\"\"\n",
+ " Get the properties of a specific axis or axes.\n",
+ " \n",
+ " Args:\n",
+ " ax (str, int, plt.Axes, or a list of those): The axis or axes to get the properties for.\n",
+ " \n",
+ " Returns:\n",
+ " dict: A dictionary of properties for the axis or axes.\n",
+ " \"\"\"\n",
+ " ax_list = self.get_axes(ax)\n",
+ " properties = {}\n",
+ " for a in ax_list:\n",
+ " properties = {key: properties.get(key, []) + [value] for key, value in a.properties().items()}\n",
+ "\n",
+ " for k, v in properties.items():\n",
+ " if len(v) == 1:\n",
+ " properties[k] = v[0]\n",
+ "\n",
+ " return properties\n",
+ "\n",
+ " def set_axis_properties(self, ax: Union[str, int, plt.Axes, Iterable[Union[str, int, plt.Axes]]]=None, **kwargs) -> None:\n",
+ " \"\"\"\n",
+ " Set properties for a specific axis or axes.\n",
+ " \n",
+ " Args:\n",
+ " ax (str, int, plt.Axes, or a list of those): The axis or axes to set the properties for.\n",
+ " **kwargs: Additional keyword arguments to pass to the axis object.\n",
+ " \"\"\"\n",
+ " ax_list = self.get_axes(ax)\n",
+ " for a in ax_list:\n",
+ " a.set(**kwargs)\n",
+ "\n",
+ " def set_axis_padding(self, padding: float, ax: Union[str, int, plt.Axes, Iterable[Union[str, int, plt.Axes]]]=None, above: bool = True) -> None:\n",
+ " \"\"\"\n",
+ " Set custom padding for a specific axis.\n",
+ " \n",
+ " Args:\n",
+ " padding (float): The amount of padding to add as a fraction of the figure height.\n",
+ " ax (str, int, plt.Axes, or a list of those): The axis or axes to pad. If None, applies to all axes.\n",
+ " above (bool): Whether to add padding above the axis (default) or below.\n",
+ " \"\"\"\n",
+ " ax_list = self.get_axes(ax)\n",
+ " all_axes = [a for a, _ in self.axes]\n",
+ "\n",
+ " for ax in ax_list:\n",
+ " axis_index = all_axes.index(ax)\n",
+ " if axis_index < 0:\n",
+ " warnings.warn(\"Axis not found in the figure.\")\n",
+ " continue\n",
+ " if above:\n",
+ " self.custom_paddings[axis_index] = padding\n",
+ " elif axis_index == len(self.axes) - 1:\n",
+ " continue\n",
+ " else:\n",
+ " self.custom_paddings[axis_index + 1] = padding\n",
+ " self._adjust_axes()\n",
+ "\n",
+ " def set_time_limits(\n",
+ " self, start_time: Union[float, str, pd.Timestamp, None],\n",
+ " end_time: Union[float, str, pd.Timestamp, None],\n",
+ " method: str='union',\n",
+ " reference_axis: Union[str, int, plt.Axes, Iterable[Union[str, int, plt.Axes]]] = None\n",
+ " ) -> None:\n",
+ " \"\"\"\n",
+ " Set the time limits for all axes in the figure. Calling with None will adjust the limits to the data.\n",
+ "\n",
+ " Args:\n",
+ " start_time (Union[float, str, pd.Timestamp, None]): The start time for the x-axis.\n",
+ " end_time (Union[float, str, pd.Timestamp, None]): The end time for the x-axis.\n",
+ " \"\"\"\n",
+ " # Default values\n",
+ " xlim = np.array(self.get_axis_properties(reference_axis)['xlim']).reshape((-1, 2))\n",
+ " if method == 'union':\n",
+ " xlim = xlim[:, 0].min(), xlim[:, 1].max()\n",
+ " elif method == 'intersect':\n",
+ " xlim = xlim[:, 0].max(), xlim[:, 1].min()\n",
+ " else:\n",
+ " raise ValueError(f\"Invalid method: {method} not in ['union', 'intersect']\")\n",
+ "\n",
+ " # Convert string inputs to pandas Timestamp objects\n",
+ " if start_time is not None:\n",
+ " start_time = pd.to_datetime(start_time)\n",
+ " else:\n",
+ " start_time = xlim[0]\n",
+ " if end_time is not None:\n",
+ " end_time = pd.to_datetime(end_time)\n",
+ " else:\n",
+ " end_time = xlim[1]\n",
+ "\n",
+ " self.set_axis_properties(xlim=(start_time, end_time))\n",
+ "\n",
+ " def set_periodic_ticks(\n",
+ " self, \n",
+ " interval: Union[str, pd.Timedelta], \n",
+ " start_time: str = '2018-01-01 00:00',\n",
+ " end_time: str = None,\n",
+ " fmt=TIME_FORMAT,\n",
+ " ax: Union[str, int, plt.Axes, Iterable[Union[str, int, plt.Axes]]] = None\n",
+ " ) -> None:\n",
+ " \"\"\"\n",
+ " Set periodic x-ticks at a regular interval throughout the day.\n",
+ "\n",
+ " Args:\n",
+ " interval (Union[str, pd.Timedelta]): The interval between ticks (e.g., '1h' for hourly ticks, '30min' for 30 minutes).\n",
+ " start_time (str): The timestamp to start the ticks from (default is '2018-01-01 00:00'). If None, the lower x-axis limit is used.\n",
+ " end_time (str): The timestamp to end the ticks at (default is None, meaning the upper x-axis limit is used).\n",
+ " fmt (str): The date format string to be used for the tick labels.\n",
+ " ax (str, int, plt.Axes, or a list of those): The axis (or axes) to apply the ticks to. \n",
+ " Can be an axis object, a list of axes, an index, or a name. If None, applies to all axes.\n",
+ " \"\"\"\n",
+ " # Convert interval to pandas Timedelta if it's a string\n",
+ " if isinstance(interval, str):\n",
+ " interval = pd.to_timedelta(interval)\n",
+ "\n",
+ " # Parse start_time and end_time into timezone-naive timestamps\n",
+ " if start_time is not None:\n",
+ " start_time = pd.to_datetime(start_time).tz_localize(None)\n",
+ " if end_time is not None:\n",
+ " end_time = pd.to_datetime(end_time).tz_localize(None)\n",
+ "\n",
+ " # Determine which axes to apply this to\n",
+ " axes = self.get_axes(ax)\n",
+ "\n",
+ " for a in axes:\n",
+ " if a is not None:\n",
+ " # Get the x-axis limits\n",
+ " min_x, max_x = a.get_xlim()\n",
+ "\n",
+ " # Convert limits to datetime if they are in float format\n",
+ " if isinstance(min_x, (float, int)):\n",
+ " min_x = mdates.num2date(min_x).replace(tzinfo=None)\n",
+ " if isinstance(max_x, (float, int)):\n",
+ " max_x = mdates.num2date(max_x).replace(tzinfo=None)\n",
+ "\n",
+ " # Generate ticks on a regular grid anchored at start_time\n",
+ " ticks = pd.date_range(start=start_time if start_time else min_x,\n",
+ " end=end_time if end_time else max_x,\n",
+ " freq=interval)\n",
+ "\n",
+ " # Make sure ticks are within the limits\n",
+ " ticks = [tick for tick in ticks if min_x <= tick and tick <= max_x]\n",
+ "\n",
+ " # Set the locator and formatter\n",
+ " format_xticks(a, ticks, fmt)\n",
+ "\n",
+ " plt.setp(a.get_xticklabels(), rotation=0, ha='center')\n",
+ "\n",
+ " def add_legend(self, ax: Union[str, int, plt.Axes, Iterable[Union[str, int, plt.Axes]]], **kwargs) -> None:\n",
+ " \"\"\"\n",
+ " Add a legend to a specific axis.\n",
+ " \n",
+ " Args:\n",
+ " ax (str, int, plt.Axes, or a list of those): The axis to add the legend to.\n",
+ " \"\"\"\n",
+ " ax_list = self.get_axes(ax)\n",
+ " for a in ax_list:\n",
+ " a.legend(**kwargs)\n",
+ "\n",
+ " def set_legend(self, ax: Union[str, int, plt.Axes, Iterable[Union[str, int, plt.Axes]]]=None, bbox_to_anchor: tuple=None, **kwargs):\n",
+ " \"\"\"\n",
+ " Update the legend properties for all axes in the figure, or a subset of them, if the legend exists.\n",
+ "\n",
+ " Args:\n",
+ " ax (str, int, plt.Axes, or a list of those): The axis or axes to update the legend for. If None, applies to all axes.\n",
+ " bbox_to_anchor (tuple, optional): The bounding box coordinates for the legend.\n",
+ " **kwargs: Additional keyword arguments passed to the legend object.\n",
+ " \"\"\"\n",
+ " ax_list = self.get_axes(ax)\n",
+ "\n",
+ " for a in ax_list:\n",
+ " legend = a.get_legend()\n",
+ " if legend is None:\n",
+ " continue\n",
+ " if bbox_to_anchor is not None:\n",
+ " legend.set_bbox_to_anchor(bbox_to_anchor)\n",
+ " legend.set(**kwargs)\n",
+ "\n",
+ " def show(self) -> None:\n",
+ " \"\"\"\n",
+ " Display the figure.\n",
+ " \"\"\"\n",
+ " plt.show()\n",
+ "\n",
+ "\n",
+ "def format_xticks(ax: plt.Axes, xticks: Iterable=None, format: str=TIME_FORMAT, **kwargs):\n",
+ " \"\"\" format datestrings on x axis \"\"\"\n",
+ " if xticks is None:\n",
+ " xticks = ax.get_xticks()\n",
+ " ax.set_xticks(xticks)\n",
+ " ax.set_xticklabels(xticks, **kwargs)\n",
+ " xfmt = mdates.DateFormatter(format)\n",
+ " ax.xaxis.set_major_formatter(xfmt)\n",
+ "\n",
+ "\n",
+ "def format_timeseries(\n",
+ " df: pd.DataFrame,\n",
+ " participant_id: int=None,\n",
+ " array_index: int=None,\n",
+ " time_range: Tuple[str, str]=None,\n",
+ " x_start: str='collection_timestamp',\n",
+ " x_end: str='collection_timestamp',\n",
+ " unique: bool=False,\n",
+ ") -> pd.DataFrame:\n",
+ " \"\"\"\n",
+ " Reformat and filter a time series DataFrame based on participant ID, array index, and date range.\n",
+ "\n",
+ " Args:\n",
+ " df (pd.DataFrame): The DataFrame to filter.\n",
+ " participant_id (int): The participant ID to filter by.\n",
+ " array_index (int): The array index to filter by.\n",
+ " time_range: The date range to filter by. Can be a tuple of two dates / times or two strings.\n",
+ " x_start (str): The name of the column containing the start time.\n",
+ " x_end (str): The name of the column containing the end time.\n",
+ "\n",
+ " Returns:\n",
+ " pd.DataFrame: The filtered DataFrame\n",
+ " \"\"\"\n",
+ " if participant_id is not None:\n",
+ " df = df.query('participant_id == @participant_id')\n",
+ " if array_index is not None:\n",
+ " df = df.query('array_index == @array_index')\n",
+ "\n",
+ " # Reset index to avoid issues with slicing and indexing\n",
+ " x_ind = np.unique([c for c in [x_start, x_end] if c in df.index.names])\n",
+ " if len(x_ind):\n",
+ " if np.isin(x_ind, df.index.names).any():\n",
+ " df = df.reset_index(x_ind)\n",
+ " df[x_start] = df[x_start].dt.tz_localize(None)\n",
+ " if x_start != x_end:\n",
+ " df[x_end] = df[x_end].dt.tz_localize(None)\n",
+ " if time_range is not None:\n",
+ " time_range = pd.to_datetime(time_range)\n",
+ " df = df.loc[(time_range[0] <= df[x_start]) & (df[x_end] <= time_range[1])]\n",
+ " if unique:\n",
+ " df = df.drop_duplicates()\n",
+ "\n",
+ " return df.sort_values(x_start)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "def plot_events_bars(\n",
+ " events: pd.DataFrame,\n",
+ " x_start: str = 'collection_timestamp',\n",
+ " x_end: str = 'event_end',\n",
+ " y: str = 'event',\n",
+ " hue: str = 'channel',\n",
+ " participant_id: Optional[int] = None,\n",
+ " array_index: Optional[int] = None,\n",
+ " time_range: Optional[Tuple[str, str]] = None,\n",
+ " y_include: Optional[Iterable[str]] = None,\n",
+ " y_exclude: Optional[Iterable[str]] = None,\n",
+ " legend: bool = True,\n",
+ " palette: str = DEFAULT_PALETTE,\n",
+ " alpha: Optional[float] = 0.7,\n",
+ " ax: Optional[plt.Axes] = None,\n",
+ " figsize: Tuple[float, float] = (12, 6),\n",
+ ") -> plt.Axes:\n",
+ " \"\"\"\n",
+ " Plot events as bars on a time series plot.\n",
+ "\n",
+ " Args:\n",
+ " events (pd.DataFrame): The events dataframe.\n",
+ " x_start (str): The column name for the start time of the event.\n",
+ " x_end (str): The column name for the end time of the event.\n",
+ " y (str): The column name for the y-axis values.\n",
+ " hue (str): The column name for the color of the event.\n",
+ " participant_id (int): The participant ID to filter events by.\n",
+ " array_index (int): The array index to filter events by.\n",
+ " time_range (Tuple[str, str]): The time range to filter events by.\n",
+ " y_include (Iterable[str]): The list of values to include in the plot.\n",
+ " y_exclude (Iterable[str]): The list of values to exclude from the plot.\n",
+ " legend (bool): Whether to show the legend.\n",
+ " palette (str): The name of the colormap to use for coloring events.\n",
+ " alpha (float): The transparency of the bars. Default is 0.7.\n",
+ " ax (plt.Axes): The axis to plot on. If None, a new figure is created.\n",
+ " figsize (Tuple[float, float]): The size of the figure (width, height) in inches.\n",
+ " \"\"\"\n",
+ " events, color_map = prep_to_plot_timeseries(\n",
+ " events, x_start, x_end,\n",
+ " hue, y,\n",
+ " participant_id, array_index, time_range,\n",
+ " y_include, y_exclude,\n",
+ " palette=palette)\n",
+ " if hue is None:\n",
+ " hue = 'hue'\n",
+ "\n",
+ " if ax is None:\n",
+ " fig, ax = plt.subplots(figsize=figsize)\n",
+ "\n",
+ " # Plot events\n",
+ " events = events.assign(diff=lambda x: x[x_end] - x[x_start]).sort_values([hue, y])\n",
+ " y_labels = []\n",
+ " legend_dicts = []\n",
+ " for i, (y_label, events) in enumerate(events.groupby(y, observed=True, sort=False)):\n",
+ " if len(y) == 0:\n",
+ " continue\n",
+ " y_labels.append(y_label)\n",
+ " for c, r in events.groupby(hue, observed=True):\n",
+ " data = r[[x_start, 'diff']]\n",
+ " if not len(data):\n",
+ " continue\n",
+ " h = ax.broken_barh(data.values, (i-0.4,0.8), color=color_map[c], alpha=alpha)\n",
+ " legend_dicts.append({'label': c, 'handle': h})\n",
+ "\n",
+ " # format plot\n",
+ " if legend:\n",
+ " legend_df = pd.DataFrame.from_dict(legend_dicts).drop_duplicates(subset='label')\n",
+ " ax.legend(\n",
+ " legend_df['handle'],\n",
+ " legend_df['label'],\n",
+ " loc='upper left', \n",
+ " bbox_to_anchor=LEGEND_SHIFT)\n",
+ "\n",
+ " ax.set_yticks(np.arange(len(y_labels)), y_labels)\n",
+ " format_xticks(ax)\n",
+ " ax.invert_yaxis() # Invert y-axis to match the order of the legend\n",
+ "\n",
+ " return ax\n",
+ "\n",
+ "\n",
+ "def plot_events_fill(\n",
+ " events: pd.DataFrame,\n",
+ " x_start: str = 'collection_timestamp',\n",
+ " x_end: str = 'event_end',\n",
+ " hue: str = 'channel',\n",
+ " label: str = None,\n",
+ " participant_id: Optional[int] = None,\n",
+ " array_index: Optional[int] = None,\n",
+ " time_range: Optional[Tuple[str, str]] = None,\n",
+ " y_include: Optional[Iterable[str]] = None,\n",
+ " y_exclude: Optional[Iterable[str]] = None,\n",
+ " legend: bool = True,\n",
+ " palette: str = DEFAULT_PALETTE,\n",
+ " alpha: Optional[float] = 0.5,\n",
+ " ax: Optional[plt.Axes] = None,\n",
+ " figsize: Iterable[float] = [12, 6],\n",
+ ") -> plt.Axes:\n",
+ " \"\"\"\n",
+ " Plot events as filled regions on a time series plot.\n",
+ "\n",
+ " Args:\n",
+ " events (pd.DataFrame): The events dataframe.\n",
+ " x_start (str): The column name for the start time of the event.\n",
+ " x_end (str): The column name for the end time of the event.\n",
+ " hue (str): The column name for the color of the event.\n",
+ " label (str): The column name for the label of the event.\n",
+ " participant_id (int): The participant ID to filter events by.\n",
+ " array_index (int): The array index to filter events by.\n",
+ " time_range (Iterable[str]): The time range to filter events by.\n",
+ " y_include (Iterable[str]): The list of values to include in the plot.\n",
+ " y_exclude (Iterable[str]): The list of values to exclude from the plot.\n",
+ " legend (bool): Whether to show the legend.\n",
+ " palette (str): The name of the palette to use for coloring events.\n",
+ " alpha (float): The transparency of the filled regions.\n",
+ " ax (plt.Axes): The axis to plot on. If None, a new figure is created.\n",
+ " figsize (Tuple[float, float]): The size of the figure (width, height) in inches.\n",
+ " \"\"\"\n",
+ " events, color_map = prep_to_plot_timeseries(\n",
+ " events, x_start, x_end,\n",
+ " hue, label,\n",
+ " participant_id, array_index, time_range,\n",
+ " y_include, y_exclude,\n",
+ " palette=palette)\n",
+ " if hue is None:\n",
+ " hue = 'hue'\n",
+ "\n",
+ " if ax is None:\n",
+ " fig, ax = plt.subplots(figsize=figsize)\n",
+ " if type(ax) is not list:\n",
+ " ax = [ax]\n",
+ "\n",
+ " for a in ax:\n",
+ " # Plotting events\n",
+ " this_color = hue if hue is not None else '#4c72b0'\n",
+ " for _, row in events.iterrows():\n",
+ " if color_map is not None:\n",
+ " this_color = color_map[row[hue]]\n",
+ " # Plot the event as a filled region, with zorder to ensure it's behind other elements\n",
+ " a.axvspan(\n",
+ " row[x_start], row[x_end], 0, 1,\n",
+ " color=this_color, alpha=alpha, zorder=0,\n",
+ " transform=a.get_xaxis_transform())\n",
+ "\n",
+ " # Add labels as xticks on the top secondary x-axis\n",
+ " if label:\n",
+ " secax = a.secondary_xaxis('top')\n",
+ " secax.set_xticks(events[x_start])\n",
+ " secax.set_xticklabels(events[label], rotation=0, ha='center')\n",
+ "\n",
+ " # Add legend\n",
+ " if legend:\n",
+ " # Get existing handles from existing legends in the axes\n",
+ " handles, labels = a.get_legend_handles_labels()\n",
+ " if color_map is not None:\n",
+ " handles += [plt.Rectangle((0, 0), 1, 1, color=c, alpha=alpha) for c in color_map]\n",
+ " labels += color_map.index.tolist()\n",
+ " else:\n",
+ " handles += [plt.Rectangle((0, 0), 1, 1, color=this_color, alpha=alpha)]\n",
+ " labels += ['events']\n",
+ " a.legend(handles, labels, loc='upper left', bbox_to_anchor=LEGEND_SHIFT)\n",
+ "\n",
+ " format_xticks(a)\n",
+ "\n",
+ " return ax\n",
+ "\n",
+ "\n",
+ "def prep_to_plot_timeseries(\n",
+ " data: pd.DataFrame,\n",
+ " x_start: str,\n",
+ " x_end: str,\n",
+ " hue: str,\n",
+ " label: str,\n",
+ " participant_id: int,\n",
+ " array_index: int,\n",
+ " time_range: Tuple[str, str],\n",
+ " y_include: Iterable[str],\n",
+ " y_exclude: Iterable[str],\n",
+ " add_columns: Iterable[str]=None,\n",
+ " palette=DEFAULT_PALETTE,\n",
+ ") -> Tuple[pd.DataFrame, pd.DataFrame]:\n",
+ " \"\"\"\n",
+ " Prepare timeseries / events data for plotting.\n",
+ "\n",
+ " Args:\n",
+ " data (pd.DataFrame): The timeseries / events dataframe.\n",
+ " x_start (str): The column name for the start time of the event.\n",
+ " x_end (str): The column name for the end time of the event.\n",
+ " hue (str): The column name for the color of the event.\n",
+ " label (str): The column name for the label of the event.\n",
+ " participant_id (int): The participant ID to filter events by.\n",
+ " array_index (int): The array index to filter events by.\n",
+ " time_range (Iterable[str]): The time range to filter events by.\n",
+ " y_include (Iterable[str]): The list of values to include in the plot.\n",
+ " y_exclude (Iterable[str]): The list of values to exclude from the plot.\n",
+ " add_columns (Iterable[str]): Additional columns to include in the plot.\n",
+ " palette (str): The name of the colormap to use for coloring events.\n",
+ "\n",
+ " Returns:\n",
+ " Tuple[pd.DataFrame, pd.Series]: The filtered events dataframe and the color map (None if the hue column is missing).\n",
+ " \"\"\"\n",
+ " if type(add_columns) is str:\n",
+ " add_columns = [add_columns]\n",
+ "\n",
+ " data = format_timeseries(data, participant_id, array_index, time_range, x_start, x_end)\n",
+ "\n",
+ " # Filter events based on y_include and y_exclude\n",
+ " data = data.dropna(subset=[x_start, x_end])\n",
+ " if hue is not None and hue in data.index.names:\n",
+ " data = data.reset_index(hue)\n",
+ " if label is not None and label in data.index.names:\n",
+ " data = data.reset_index(label)\n",
+ " if y_include is not None:\n",
+ " ind = pd.Series(False, index=data.index)\n",
+ " if hue is not None:\n",
+ " ind |= data[hue].isin(y_include)\n",
+ " if label is not None:\n",
+ " ind |= data[label].isin(y_include)\n",
+ " data = data.loc[ind]\n",
+ " if y_exclude is not None:\n",
+ " ind = pd.Series(False, index=data.index)\n",
+ " if hue is not None:\n",
+ " ind |= data[hue].isin(y_exclude)\n",
+ " if label is not None:\n",
+ " ind |= data[label].isin(y_exclude)\n",
+ " data = data.loc[~ind]\n",
+ " if hue is None:\n",
+ " hue = 'hue'\n",
+ " data[hue] = 'events'\n",
+ "\n",
+ " col_list = [x_start, x_end, hue, label]\n",
+ " if add_columns is not None:\n",
+ " col_list += list(add_columns)\n",
+ " col_list = pd.Series(col_list).dropna().drop_duplicates()\n",
+ "\n",
+ " # Set colors\n",
+ " if hue in data.columns:\n",
+ " colors = get_color_map(data, hue, palette)\n",
+ " else:\n",
+ " colors = None\n",
+ "\n",
+ " return data[col_list], colors\n",
+ "\n",
+ "\n",
+ "def get_events_period(\n",
+ " events_filtered: pd.DataFrame,\n",
+ " period_start: str,\n",
+ " period_end: str,\n",
+ " period_name: str,\n",
+ " col: str = 'event',\n",
+ " first_start: bool = True,\n",
+ " first_end: bool = True,\n",
+ " include_start: bool = True,\n",
+ " include_end: bool = True,\n",
+ " x_start: str = 'collection_timestamp',\n",
+ " x_end: str = 'event_end',\n",
+ ") -> pd.DataFrame:\n",
+ " \"\"\"\n",
+ " Get the period of time between the start and end events.\n",
+ "\n",
+ " Args:\n",
+ " events_filtered (pd.DataFrame): The events DataFrame.\n",
+ " period_start (str): The label of the start event.\n",
+ " period_end (str): The label of the end event.\n",
+ " period_name (str): The label to assign to the period.\n",
+ " col (str): The column name for the event labels. Default is 'event'.\n",
+ " first_start (bool): If True, get the first start event. Default is True.\n",
+ " first_end (bool): If True, get the first end event. Default is True.\n",
+ " include_start (bool): If True, include the start event in the period. Default is True.\n",
+ " include_end (bool): If True, include the end event in the period. Default is True.\n",
+ " x_start (str): The column name for the start time of the event. Default is 'collection_timestamp'.\n",
+ " x_end (str): The column name for the end time of the event. Default is 'event_end'.\n",
+ "\n",
+ " Returns:\n",
+ " pd.DataFrame: The period of events in the same format as the input DataFrame.\n",
+ " \"\"\"\n",
+ " events_filtered = format_timeseries(events_filtered, None, None, None, x_start, x_end)\n",
+ "\n",
+ " start_time = events_filtered.loc[\n",
+ " events_filtered[col] == period_start,\n",
+ " x_start if include_start else x_end]\\\n",
+ " .iloc[0 if first_start else -1]\n",
+ " end_time = events_filtered.loc[\n",
+ " events_filtered[col] == period_end,\n",
+ " x_end if include_end else x_start]\\\n",
+ " .iloc[0 if first_end else -1]\n",
+ "\n",
+ " return pd.DataFrame({\n",
+ " x_start: [start_time],\n",
+ " x_end: [end_time],\n",
+ " col: [period_name]\n",
+ " })\n",
+ "\n",
+ "\n",
+ "def get_color_map(data: pd.DataFrame, hue: str, palette: str) -> pd.DataFrame:\n",
+ " \"\"\"\n",
+ " Get a color map for a specific column in the data.\n",
+ "\n",
+ " Args:\n",
+ " data (pd.DataFrame): The data to get the color map from.\n",
+ " hue (str): The column name to use for the color map.\n",
+ " palette (str): The name of the colormap to use.\n",
+ "\n",
+ " Returns:\n",
+ " pd.Series: The colors, indexed by the sorted unique values of the hue column.\n",
+ " \"\"\"\n",
+ " colors = sorted(data[hue].unique())\n",
+ " colors = pd.DataFrame({\n",
+ " hue: colors,\n",
+ " 'color': sns.color_palette(palette, len(colors))\n",
+ " }).set_index(hue)['color']\n",
+ "\n",
+ " return colors"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The class `TimeSeriesFigure` provides a user-friendly interface for plotting multiple channels of time series data.\n",
+ "\n",
+ "First, we will load time series DataFrames from the sleep monitoring dataset. The data include sleep events and sensor channels for heart rate, respiratory movement, and oxygen saturation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/ec2-user/projects/pheno-utils/pheno_utils/pheno_loader.py:610: UserWarning: No date field found\n",
+ " warnings.warn(f'No date field found')\n"
+ ]
+ }
+ ],
+ "source": [
+ "#| eval: false\n",
+ "from pheno_utils import PhenoLoader\n",
+ "\n",
+ "pl = PhenoLoader('sleep')\n",
+ "channels_df = pl.load_bulk_data('channels_time_series', pivot='source')\n",
+ "events_df = pl.load_bulk_data('events_time_series')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Any plotting function that accepts an `ax` argument can be used with `TimeSeriesFigure`. The pheno-utils package includes a number of functions that are useful for plotting time series data, such as `plot_events_bars` and `plot_events_fill`; however, standard seaborn plotting functions (and others) can also be used."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "
"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "#| eval: false\n",
+ "sns.set_style('whitegrid')\n",
+ "\n",
+ "g = TimeSeriesFigure()\n",
+ "\n",
+ "channels_df = format_timeseries(channels_df).set_index('collection_timestamp')\n",
+ "g.plot(sns.lineplot, channels_df, x='collection_timestamp', y='heart_rate',\n",
+ " name='heart_rate') # Named axis 'heart_rate'\n",
+ "\n",
+ "# You can also use the `sharex` argument to share the x-axis between plots\n",
+ "# Named axes, such as 'heart_rate', can be referred to by name\n",
+ "g.plot(sns.lineplot, channels_df, x='collection_timestamp', y='spo2',\n",
+ " sharex='heart_rate')\n",
+ "\n",
+ "# You can increase the relative height of the plot by passing a `height` argument\n",
+ "g.plot(sns.lineplot, channels_df, x='collection_timestamp', y='respiratory_movement',\n",
+ " sharex='heart_rate', height=1.5)\n",
+ "\n",
+ "# You may add a plot to an existing axes by passing an `ax` argument to the plotting function\n",
+ "# Named axes, such as 'heart_rate', can be referred to by name\n",
+ "stage_events = ['Wake', 'Light Sleep', 'Deep Sleep', 'REM'] # Include only sleep stage events\n",
+ "g.plot(plot_events_fill, events_df, hue='event', y_include=stage_events,\n",
+ " ax='heart_rate')\n",
+ "\n",
+ "apnea_events = ['Resp. Event', 'Desaturation', 'A/H obstructive', 'A/H central', 'A/H unclassified']\n",
+ "g.plot(plot_events_bars, events_df, hue='event', y_include=apnea_events, height=1.5)\n",
+ "\n",
+ "# Control functions to conveniently modify all axes\n",
+ "g.set_periodic_ticks('1h')\n",
+ "g.set_axis_padding(0.05)\n",
+ "g.set_axis_properties(xlabel='')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "import nbdev; nbdev.nbdev_export()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "python3",
+ "language": "python",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/nbs/16_diet_plots.ipynb b/nbs/16_diet_plots.ipynb
new file mode 100644
index 0000000..a3d108f
--- /dev/null
+++ b/nbs/16_diet_plots.ipynb
@@ -0,0 +1,1105 @@
+{
+ "cells": [
+ {
+ "cell_type": "raw",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "output-file: diet_plots.html\n",
+ "title: Diet logging plots\n",
+ "---"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| default_exp diet_plots"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "from nbdev.showdoc import *"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "from typing import List, Tuple\n",
+ "\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "\n",
+ "import seaborn as sns\n",
+ "import matplotlib.pyplot as plt\n",
+ "import matplotlib.dates as mdates\n",
+ "from matplotlib.ticker import FuncFormatter\n",
+ "import matplotlib.patches as mpatches\n",
+ "import matplotlib.lines as mlines\n",
+ "import matplotlib.patches as Patch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "from pheno_utils.timeseries_plots import format_timeseries, format_xticks, plot_events_bars\n",
+ "from pheno_utils.config import DEFAULT_PALETTE, LEGEND_SHIFT\n",
+ "\n",
+ "\n",
+ "def plot_nutrient_bars(\n",
+ " diet_log: pd.DataFrame, \n",
+ " x: str='collection_timestamp',\n",
+ " label: str='short_food_name',\n",
+ " participant_id: int=None, \n",
+ " array_index: int = None,\n",
+ " time_range: Tuple[str, str]=None, \n",
+ " meals: bool=True,\n",
+ " summary: bool=False,\n",
+ " nut_include: List[str]=None,\n",
+ " nut_exclude: List[str]=None,\n",
+ " agg_units: dict={'kcal': 'sum', 'g': 'sum', 'mg': 'sum'},\n",
+ " legend: bool=True,\n",
+ " bar_width=np.timedelta64(15, 'm'),\n",
+ " palette: str=DEFAULT_PALETTE,\n",
+ " alpha: float=0.7,\n",
+ " ax: plt.Axes=None,\n",
+ " figsize: Tuple[float, float]=(14, 3),\n",
+ "):\n",
+ " \"\"\"\n",
+ " Plot a stacked bar chart representing nutrient intake for each meal over time.\n",
+ "\n",
+ " Args:\n",
+ " diet_log (pd.DataFrame): The dataframe containing the diet log data, with columns for timestamps, nutrients, and other measurements.\n",
+ " x (str): The name of the column in `diet_log` representing the x-axis variable, such as timestamps. Default is 'collection_timestamp'.\n",
+ " label (str): The name of the column in `diet_log` representing the labels for each meal. Default is 'short_food_name'.\n",
+ " participant_id (Optional[int]): The participant's ID to filter the diet log. If None, no filtering is done. Default is None.\n",
+ " array_index (Optional[int]): The array index to filter the diet log. If None, no filtering is done. Default is None.\n",
+ " time_range (Optional[Tuple[str, str]]): A tuple of strings representing the start and end dates for filtering the data. Format should be 'YYYY-MM-DD'. Default is None.\n",
+ " meals (bool): If True, includes individual meals in the plot. Default is True.\n",
+ " summary (bool): If True, includes a daily summary in the plot. Default is False.\n",
+ " nut_include (List[str]): A list of nutrients to include in the plot. Default is None.\n",
+ " nut_exclude (List[str]): A list of nutrients to exclude from the plot. Default is None.\n",
+ " agg_units (dict): A dictionary mapping nutrient units to aggregation functions. Only nutrients with units in this dictionary are plotted.\n",
+ " legend (bool): If True, includes a legend in the plot. Default is True.\n",
+ " bar_width (np.timedelta64): The width of the bars representing each meal on the time axis. Default is 15 minutes.\n",
+ " palette (str): The color palette to use for the stacked bars.\n",
+ " alpha (float): The transparency of the stacked bars. Default is 0.7.\n",
+ " ax (Optional[plt.Axes]): The Matplotlib axis on which to plot the bar chart. If None, a new axis is created. Default is None.\n",
+ " figsize (Tuple[float, float]): The size of the figure to create. Default is (14, 3).\n",
+ "\n",
+ " Returns:\n",
+ " Sequence of plt.Axes: The axes on which the stacked bar charts were drawn, one axis per plotted nutrient unit.\n",
+ " \"\"\"\n",
+ " # Prepare the data for plotting\n",
+ " df, grouped_nutrients = prepare_meals(\n",
+ " diet_log,\n",
+ " participant_id=participant_id,\n",
+ " array_index=array_index,\n",
+ " time_range=time_range,\n",
+ " label=label,\n",
+ " return_meals=meals,\n",
+ " return_summary=summary,\n",
+ " y_include=nut_include,\n",
+ " y_exclude=nut_exclude,\n",
+ " agg_units=agg_units,\n",
+ " x_col=x,\n",
+ " )\n",
+ "\n",
+ " if ax is None:\n",
+ " fig, ax = plt.subplots(\n",
+ " len(grouped_nutrients), 1,\n",
+ " figsize=(figsize[0], figsize[1] * len(grouped_nutrients)),\n",
+ " sharex=True)\n",
+ " if len(grouped_nutrients) == 1:\n",
+ " ax = [ax]\n",
+ "\n",
+ " colors = sns.color_palette(\n",
+ " palette, sum([len(g) for g in grouped_nutrients.values()]))\n",
+ "\n",
+ " # Calculate the width in time units\n",
+ " bar_width_in_days = bar_width / np.timedelta64(1, 'D')\n",
+ "\n",
+ " unit_list = [g for g in grouped_nutrients if g != 'kcal']\n",
+ " if 'kcal' in grouped_nutrients:\n",
+ " # kcal is last to keep colours synced with the lollipop plot\n",
+ " unit_list.append('kcal')\n",
+ "\n",
+ " # Stacked bar plots for grouped nutrients\n",
+ " c = 0\n",
+ " for idx, unit in enumerate(unit_list):\n",
+ " bottom = pd.Series([0] * len(df))\n",
+ " for nut in grouped_nutrients[unit]:\n",
+ " if nut in ['weight_g']:\n",
+ " continue\n",
+ " ax[idx].bar(\n",
+ " df[x], df[nut], bottom=bottom, width=bar_width_in_days,\n",
+ " color=colors[c], alpha=alpha, label=nut)\n",
+ " bottom += df[nut]\n",
+ " c += 1\n",
+ " ax[idx].set_ylabel(f'Nutrients ({unit})', rotation=0, horizontalalignment='right')\n",
+ " if legend:\n",
+ " ax[idx].legend(loc='upper left', bbox_to_anchor=LEGEND_SHIFT)\n",
+ " ax[idx].grid(True)\n",
+ "\n",
+ " # Set x-tick labels for the bottom and top axes\n",
+ " format_xticks(ax[-1], df[x])\n",
+ " if label is not None:\n",
+ " secax = ax[0].secondary_xaxis('top')\n",
+ " secax.set_xticks(df[x])\n",
+ " secax.set_xticklabels(df[label], ha='center', fontsize=9)\n",
+ "\n",
+ " return ax\n",
+ "\n",
+ "\n",
+ "def plot_nutrient_lollipop(\n",
+ " diet_log: pd.DataFrame, \n",
+ " x: str='collection_timestamp',\n",
+ " y: str='calories_kcal',\n",
+ " size: str='total_g', \n",
+ " label: str='short_food_name',\n",
+ " participant_id: int=None, \n",
+ " array_index: int=None,\n",
+ " time_range: Tuple[str, str]=None, \n",
+ " meals: bool=True,\n",
+ " summary: bool=False,\n",
+ " nut_include: List[str]=None,\n",
+ " nut_exclude: List[str]=None,\n",
+ " legend: bool=True,\n",
+ " size_scale: float=5,\n",
+ " palette: str=DEFAULT_PALETTE,\n",
+ " alpha: float=0.7,\n",
+ " ax: plt.Axes=None,\n",
+ " figsize: Tuple[float, float] = (12, 3),\n",
+ "):\n",
+ " \"\"\"\n",
+ " Plot a lollipop chart with pie charts representing nutrient composition for each meal.\n",
+ "\n",
+ " NOTE: The y-axis is scaled to match the units of the x-axis, to avoid distortion of the pie charts.\n",
+ " Due to scaling, if you intend to change `xlim` after plotting, you must also provide `time_range`.\n",
+ " Use the `second_y` option of g.plot() to plot it alongside other y-axis data.\n",
+ "\n",
+ " Args:\n",
+ " diet_log (pd.DataFrame): The dataframe containing the diet log data, with columns for timestamps, nutrients, and other measurements.\n",
+ " x (str): The name of the column in `diet_log` representing the x-axis variable, such as timestamps. Default is 'collection_timestamp'.\n",
+ " y (str): The name of the column in `diet_log` representing the y-axis variable, such as calories. Default is 'calories_kcal'.\n",
+ " size (str): The name of the column in `diet_log` representing the size of the pie charts. Default is 'total_g'.\n",
+ " label (str): The name of the column in `diet_log` representing the labels for each meal. Default is 'short_food_name'.\n",
+ " participant_id (Optional[int]): The participant's ID to filter the diet log. If None, no filtering is done. Default is None.\n",
+ " time_range (Optional[Tuple[str, str]]): A tuple of strings representing the start and end dates for filtering the data. Format should be 'YYYY-MM-DD'. Default is None.\n",
+ " meals (bool): If True, includes individual meals in the plot. Default is True.\n",
+ " summary (bool): If True, includes a daily summary in the plot. Default is False.\n",
+ " nut_include (List[str]): A list of nutrients to include in the plot. Default is None.\n",
+ " nut_exclude (List[str]): A list of nutrients to exclude from the plot. Default is None.\n",
+ " legend (bool): If True, includes a legend in the plot. Default is True.\n",
+ " size_scale (float): The scaling factor for the size of the pie charts. Default is 5.\n",
+ " palette (str): The color palette to use for the pie slices. Default is DEFAULT_PALETTE.\n",
+ " alpha (float): The transparency of the pie slices. Default is 0.7.\n",
+ " ax (Optional[plt.Axes]): The Matplotlib axis on which to plot the lollipop chart. If None, a new axis is created. Default is None.\n",
+ " figsize (Tuple[float, float]): The size of the figure to create. Default is (12, 3).\n",
+ "\n",
+ " Returns:\n",
+ " plt.Axes: The axis on which the lollipop plot with pie charts was drawn.\n",
+ " \"\"\"\n",
+ " # Prepare the data for plotting\n",
+ " df, grouped_nutrients = prepare_meals(\n",
+ " diet_log,\n",
+ " participant_id=participant_id,\n",
+ " array_index=array_index,\n",
+ " time_range=time_range,\n",
+ " return_meals=meals,\n",
+ " return_summary=summary,\n",
+ " y_include=nut_include,\n",
+ " y_exclude=nut_exclude,\n",
+ " x_col=x,\n",
+ " )\n",
+ "\n",
+ " if ax is None:\n",
+ " fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)\n",
+ "\n",
+ " # Convert nutrients in mg to grams\n",
+ " for nut in grouped_nutrients['mg']:\n",
+ " df[nut.replace('_mg', '_g')] = df[nut] / 1000\n",
+ " grouped_nutrients['g'] += [nut.replace('_mg', '_g')]\n",
+ "\n",
+ " pie_nuts = [nut for nut in grouped_nutrients['g']\n",
+ " if nut not in ['weight_g']]\n",
+ " df['total_g'] = df[pie_nuts].sum(axis=1)\n",
+ "\n",
+ " # Calculate unknown component and ensure all values are non-negative\n",
+ " df['other_g'] = (df['weight_g'] - df[pie_nuts].sum(axis=1)).clip(lower=0)\n",
+ " # pie_nuts += ['other_g']\n",
+ "\n",
+ " # Pre-set the x-axis limits based on the range of timestamps\n",
+ " if time_range is None:\n",
+ " min_x = mdates.date2num(df[x].min())\n",
+ " max_x = mdates.date2num(df[x].max())\n",
+ " else:\n",
+ " min_x = mdates.date2num(pd.to_datetime(time_range[0]))\n",
+ " max_x = mdates.date2num(pd.to_datetime(time_range[1]))\n",
+ "\n",
+ " # Pre-set the y-axis limits based on the range of the y-axis column\n",
+ " min_y = 0 # df[y_col].min()\n",
+ " max_y = df[y].max()\n",
+ "\n",
+ " # Calculate the aspect ratio between the x and y axes\n",
+ " # This is necessary to avoid distortion of the (circular) pie charts\n",
+ " x_range = max_x - min_x\n",
+ " y_range = max_y - min_y\n",
+ " aspect_ratio = x_range / y_range\n",
+ " y_delta = 0.1 * y_range\n",
+ "\n",
+ " # Scale the y-axis to match the aspect ratio of the x-axis\n",
+ " ax.set_xlim(min_x, max_x)\n",
+ " ax.set_ylim(min_y * aspect_ratio, (max_y + y_delta) * aspect_ratio)\n",
+ "\n",
+ " # Custom formatter to adjust the y-ticks back to the original scale\n",
+ " def ytick_formatter(y, pos):\n",
+ " return f'{y / aspect_ratio:.0f}'\n",
+ "\n",
+ " # Plotting the lollipop plot with pies using absolute figure coordinates\n",
+ " for idx, row in df.iterrows():\n",
+ " # Pie chart parameters\n",
+ " size_value = np.sqrt(row[size]) * aspect_ratio * size_scale\n",
+ " position = mdates.date2num(row[x])\n",
+ " y_value = row[y] * aspect_ratio # Scale y-value\n",
+ "\n",
+ " # Plot the stem (lollipop stick)\n",
+ " ax.plot([position, position], [0, y_value], color='gray', lw=1, zorder=1)\n",
+ "\n",
+ " # Plot the pie chart in figure coordinates (no distortion)\n",
+ " wedges = draw_pie_chart(ax, position, y_value, row[pie_nuts].fillna(0.).values, size_value, palette, alpha)\n",
+ "\n",
+ " if legend:\n",
+ " # Create a custom legend\n",
+ " ax.legend(handles=wedges, labels= pie_nuts, loc='upper left', bbox_to_anchor=LEGEND_SHIFT)\n",
+ "\n",
+ " # Format x-axis to display dates properly\n",
+ " ax.set_ylabel(y.replace('_', ' ').title(), rotation=0, horizontalalignment='right')\n",
+ " ax.grid(True)\n",
+ "\n",
+ " # Set y-ticks and x-ticks\n",
+ " ax.yaxis.set_major_formatter(FuncFormatter(ytick_formatter))\n",
+ " ylim = ax.get_ylim()\n",
+ " yticks = np.arange(0, ylim[1] / aspect_ratio, 100, dtype=int)\n",
+ " ax.set_yticks(yticks * aspect_ratio)\n",
+ " ax.set_yticklabels(yticks)\n",
+ "\n",
+ " format_xticks(ax, df[x])\n",
+ " if label is not None:\n",
+ " secax = ax.secondary_xaxis('top')\n",
+ " secax.set_xticks(df[x])\n",
+ " secax.set_xticklabels(df[label], ha='center', fontsize=9)\n",
+ "\n",
+ " return ax\n",
+ "\n",
+ "\n",
+ "def prepare_meals(\n",
+ " diet_log: pd.DataFrame,\n",
+ " participant_id: int=None,\n",
+ " array_index: int=None,\n",
+ " time_range: Tuple[str, str]=None,\n",
+ " label: str='short_food_name',\n",
+ " return_meals: bool = True,\n",
+ " return_summary: bool = False,\n",
+ " y_include: List[str] = None,\n",
+ " y_exclude: List[str] = None,\n",
+ " agg_units: dict={'kcal': 'sum', 'g': 'sum', 'mg': 'sum', 'unknown': 'first'},\n",
+ " x_col: str='collection_timestamp'\n",
+ ") -> pd.DataFrame:\n",
+ " \"\"\"\n",
+ " Prepare the diet log data for plotting meals and/or daily summaries.\n",
+ "\n",
+ " Args:\n",
+ " diet_log (pd.DataFrame): The dataframe containing the diet log data, with columns for timestamps, nutrients, and other measurements.\n",
+ " participant_id (Optional[int]): The participant's ID to filter the diet log. If None, no filtering is done. Default is None.\n",
+ " array_index (Optional[int]): The array index to filter the diet log. If None, no filtering is done. Default is None.\n",
+ " time_range (Optional[Tuple[str, str]]): A tuple of strings representing the start and end dates for filtering the data. Format should be 'YYYY-MM-DD'. Default is None.\n",
+ " label (str): The name of the column in `diet_log` representing the labels for each meal. Default is 'short_food_name'.\n",
+ " return_meals (bool): If True, includes individual meals in the plot. Default is True.\n",
+ " return_summary (bool): If True, includes a daily summary in the plot. Default is False.\n",
+ " y_include (List[str]): A list of nutrients (regex) to include in the plot. Default is None.\n",
+ " y_exclude (List[str]): A list of nutrients (regex) to exclude from the plot. Default is None.\n",
+ " agg_units (dict): A dictionary mapping nutrient units to aggregation functions.\n",
+ " x_col (str): The name of the column in `diet_log` representing the x-axis variable, such as timestamps. Default is 'collection_timestamp'.\n",
+ "\n",
+ " Returns:\n",
+ " Tuple[pd.DataFrame, dict]: The prepared dataframe for plotting, and a dictionary mapping each unit to the list of nutrient columns measured in that unit.\n",
+ " \"\"\"\n",
+ " diet_log = format_timeseries(\n",
+ " diet_log, participant_id, array_index, time_range,\n",
+ " x_start=x_col, x_end=x_col, unique=True)\n",
+ "\n",
+ " units = extract_units(diet_log.columns)\n",
+ " grouped_nutrients = {}\n",
+ " import re # Add this line to import the re module\n",
+ "\n",
+ " agg_dict = {}\n",
+ " for nut, unit in units.items():\n",
+ " if unit not in agg_units:\n",
+ " continue\n",
+ " if y_include is not None and not any([re.match(inc, nut) for inc in y_include]):\n",
+ " continue\n",
+ " if y_exclude is not None and any([re.match(exc, nut) for exc in y_exclude]):\n",
+ " continue\n",
+ " if unit not in grouped_nutrients:\n",
+ " grouped_nutrients[unit] = []\n",
+ " grouped_nutrients[unit].append(nut)\n",
+ " agg_dict[nut] = agg_units[unit]\n",
+ " nut_list = list(agg_dict.keys())\n",
+ " if label is not None:\n",
+ " agg_dict[label] = lambda x: '\\n'.join(x)\n",
+ "\n",
+ " df = diet_log\\\n",
+ " .dropna(subset=['short_food_name'])\\\n",
+ " .drop_duplicates()\\\n",
+ " .groupby([x_col])\\\n",
+ " .agg(agg_dict)\\\n",
+ " .reset_index()\n",
+ "\n",
+ " if return_summary:\n",
+ " # Add daily summary by grouping by date and summing up the nutrients\n",
+ " daily_df = df.groupby(df[x_col].dt.date)[nut_list]\\\n",
+ " .sum().reset_index()\n",
+ " if label is not None:\n",
+ " daily_df[label] = daily_df[x_col].astype('string') + '\\nDaily Summary'\n",
+ " daily_df[x_col] = pd.to_datetime(daily_df[x_col] + pd.Timedelta(hours=24))\n",
+ " if time_range is not None:\n",
+ " daily_df = daily_df[(time_range[0] <= daily_df[x_col]) & (daily_df[x_col] <= time_range[1])]\n",
+ " if return_meals:\n",
+ " # God knows why, but the two refuse to concat without this\n",
+ " df = pd.DataFrame(np.vstack([df, daily_df]), columns=df.columns)\n",
+ " else:\n",
+ " df = daily_df\n",
+ "\n",
+ " return df, grouped_nutrients\n",
+ "\n",
+ "\n",
+ "def extract_units(column_names: List[str]) -> dict:\n",
+ " units = {}\n",
+ " for col in column_names:\n",
+ " if '_' in col:\n",
+ " unit = col.split('_')[-1]\n",
+ " units[col] = unit\n",
+ " else:\n",
+ " units[col] = 'unknown'\n",
+ " return units\n",
+ "\n",
+ "\n",
+ "def draw_pie_chart(\n",
+ " ax: plt.Axes, \n",
+ " x: float, \n",
+ " y: float, \n",
+ " data: List[float], \n",
+ " size: float, \n",
+ " palette: str = DEFAULT_PALETTE,\n",
+ " alpha: float = 0.7,\n",
+ "):\n",
+ " \"\"\"\n",
+ " Draw a pie chart as an inset (positioned in axes-relative coordinates) within the given axes\n",
+ " at the specified data coordinates.\n",
+ " What this solves is the issue of y-axis and x-axis scaling being different, which\n",
+ " distorts the pie chart when drawn directly on the axes.\n",
+ "\n",
+ " Args:\n",
+ " ax (plt.Axes): The axis on which to draw the pie chart.\n",
+ " x (float): The x-coordinate in data coordinates where the pie chart's center will be placed.\n",
+ " y (float): The y-coordinate in data coordinates where the pie chart's center will be placed.\n",
+ " data (List[float]): The data values to be represented in the pie chart.\n",
+ " size (float): The size (radius) of the pie chart in axes-relative coordinates.\n",
+ " palette (str): The color palette to use for the pie slices.\n",
+ "\n",
+ " Returns:\n",
+ " List[plt.Patch]: A list of wedge objects representing the pie chart slices.\n",
+ " \"\"\"\n",
+ " # Convert the position from data coordinates to axes coordinates\n",
+ " axes_coords = ax.transData.transform((x, y))\n",
+ " axes_coords = ax.transAxes.inverted().transform(axes_coords)\n",
+ "\n",
+ " # Create a new inset axis to draw the pie, using axes-relative coordinates\n",
+ " inset_ax = ax.inset_axes([axes_coords[0] - size, axes_coords[1] - size, 2 * size, 2 * size])\n",
+ "\n",
+ " # Plot the pie chart using the calculated position and scaled radius\n",
+ " colors = [(r, g, b, alpha) for r, g, b in sns.color_palette(palette, len(data))]\n",
+ " wedges, _ = inset_ax.pie(data, radius=1, startangle=90, wedgeprops=dict(edgecolor='none'), normalize=True,\n",
+ " colors=colors)\n",
+ "\n",
+ " # Hide the axes for the inset (pie chart)\n",
+ " inset_ax.set_axis_off()\n",
+ "\n",
+ " return wedges\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "SHORT_FOOD_CATEGORIES = {\n",
+ " 'beef, veal, lamb, and other meat products': 'meat products',\n",
+ " 'milk, cream cheese and yogurts': 'milk products',\n",
+ " 'nuts, seeds, and products': 'nuts and seeds',\n",
+ " 'eggs and their products': 'eggs',\n",
+ " 'pulses and products': 'pulses',\n",
+ " 'fruit juices and soft drinks': 'juices and soft drinks',\n",
+ " 'low calories and diet drinks': 'low cal. drinks',\n",
+ " 'poultry and its products': 'poultry',\n",
+ " 'pasta, grains and side dishes': 'grains',\n",
+ " 'industrialized vegetarian food ready to eat': 'industrialized veg.',\n",
+ "}\n",
+ "\n",
+ "def plot_meals_hbars(\n",
+ " diet_log: pd.DataFrame, \n",
+ " x: str='collection_timestamp',\n",
+ " y: str='short_food_category',\n",
+ " size: str='weight_g', \n",
+ " hue: str='short_food_category',\n",
+ " participant_id: int=None, \n",
+ " array_index: int=None,\n",
+ " time_range: Tuple[str, str]=None, \n",
+ " y_include: List[str] = None,\n",
+ " y_exclude: List[str] = None,\n",
+ " rename_categories: dict=SHORT_FOOD_CATEGORIES,\n",
+ " legend: bool=True,\n",
+ " size_legend: List[int]=[100, 200, 500],\n",
+ " size_scale: float=5,\n",
+ " palette: str=DEFAULT_PALETTE,\n",
+ " alpha: float=0.7,\n",
+ " ax: plt.Axes=None,\n",
+ " figsize: Tuple[float, float] = (12, 6),\n",
+ "):\n",
+ " \"\"\"\n",
+ " Plot a diet chart with bars representing meals and their size over time.\n",
+ "\n",
+ " Args:\n",
+ " diet_log (pd.DataFrame): The dataframe containing the diet log data, with columns for timestamps, nutrients, and other measurements.\n",
+ " x (str): The name of the column in `diet_log` representing the x-axis variable, such as timestamps. Default is 'collection_timestamp'.\n",
+ " y (str): The name of the column in `diet_log` representing the y-axis variable, such as food categories. Default is 'short_food_category'.\n",
+ " size (str): The name of the column in `diet_log` representing the size of the bars. Default is 'weight_g'.\n",
+ " hue (str): The name of the column in `diet_log` representing the color of the bars. Default is 'short_food_category'.\n",
+ " participant_id (Optional[int]): The participant's ID to filter the diet log. If None, no filtering is done. Default is None.\n",
+ " time_range (Optional[Tuple[str, str]]): A tuple of strings representing the start and end dates for filtering the data. Format should be 'YYYY-MM-DD'. Default is None.\n",
+ " y_include (List[str]): A list of strings representing the categories to include in the plot. Default is None.\n",
+ " y_exclude (List[str]): A list of strings representing the categories to exclude from the plot. Default is None.\n",
+ " rename_categories (dict): A dictionary mapping original food categories to shorter names. Default is SHORT_FOOD_CATEGORIES.\n",
+ " legend (bool): If True, includes a legend in the plot. Default is True.\n",
+ " size_legend (List[int]): A list of integers representing the sizes to include in the size legend. Default is [100, 200, 500].\n",
+ " size_scale (float): The scaling factor for the size of the bars. Default is 5.\n",
+ " palette (str): The palette to use for the bars.\n",
+ " alpha (float): The transparency of the bars. Default is 0.7.\n",
+ " ax (Optional[plt.Axes]): The Matplotlib axis on which to plot the meals bar chart. If None, a new axis is created. Default is None.\n",
+ " figsize (Tuple[float, float]): The size of the figure to create. Default is (12, 6).\n",
+ " \"\"\"\n",
+ " diet_log = format_timeseries(\n",
+ " diet_log, participant_id, array_index,\n",
+ " time_range, x_start=x, x_end=x, unique=True)\n",
+ "\n",
+ " diet_log['event_end'] = diet_log[x] \\\n",
+ " + size_scale * pd.to_timedelta(diet_log[size], unit='s')\n",
+ "\n",
+ " # Categories\n",
+ " diet_log['short_food_category'] = diet_log['food_category'].str.lower()\n",
+ " for s, t in rename_categories.items():\n",
+ " diet_log['short_food_category'] = diet_log['short_food_category'].str.replace(s, t, regex=False)\n",
+ " diet_log['short_food_category'] = diet_log['short_food_category']\\\n",
+ " .str.replace(' and ', ' & ', regex=False)\\\n",
+ " .str.replace('_wholewheat', ' (whole/w)', regex=False)\n",
+ "\n",
+ " # Use the events plot to plot meals\n",
+ " ax = plot_events_bars(\n",
+ " diet_log,\n",
+ " x_start=x, x_end='event_end',\n",
+ " y=y, hue=hue,\n",
+ " y_include=y_include, y_exclude=y_exclude, alpha=alpha,\n",
+ " ax=ax, figsize=figsize, palette=palette, legend=legend)\n",
+ "\n",
+ " format_xticks(ax, diet_log[x].drop_duplicates())\n",
+ "\n",
+ " add_size_legend(ax, size_legend, size_scale, alpha)\n",
+ "\n",
+ " return ax\n",
+ "\n",
+ "\n",
+ "def add_size_legend(ax: plt.Axes, sizes: List[int], size_scale: float, alpha: float, shift: int=0):\n",
+ " \"\"\"\n",
+ " Add a size legend to a plot_meals_hbars plot using broken_barh.\n",
+ " \"\"\"\n",
+ " if len(sizes) == 0:\n",
+ " return\n",
+ "\n",
+ " # Manually add size legend using broken_barh\n",
+ " sec2day = 1 / (60 * 60 * 24) # Convert seconds to days\n",
+ " size_durations = [\n",
+ " s * size_scale * sec2day\n",
+ " for s in sizes]\n",
+ " max_duration = max(size_durations)\n",
+ "\n",
+ " # Calculate the xlim to place the legend bars right at the end\n",
+ " xlim = ax.get_xlim() # These are in days\n",
+ " y_start_legend = ax.get_ylim()[0] - 1 - shift\n",
+ " x_bar_start = \\\n",
+ " xlim[1] - \\\n",
+ " 1.5 * max_duration\n",
+ "\n",
+ " # Add a bounding box around the text and bars\n",
+ " ax.add_patch(mpatches.Rectangle(\n",
+ " (x_bar_start - 1.5*(max_duration + 10 * sec2day), y_start_legend - len(sizes) + 0.25),\n",
+ " 3 * (max_duration + 10 * sec2day), len(sizes) + 0.5,\n",
+ " edgecolor='gray', facecolor='white', lw=1))\n",
+ "\n",
+ " for i, (s, duration) in enumerate(zip(sizes, size_durations)):\n",
+ " # Plot the bar\n",
+ " ax.broken_barh(\n",
+ " xranges=[(x_bar_start, duration)],\n",
+ " yrange=(y_start_legend - i - 0.4, 0.8), \n",
+ " facecolors='gray', alpha=alpha\n",
+ " )\n",
+ "\n",
+ " # Add text next to the bar\n",
+ " ax.annotate(f'{s}g', \n",
+ " (x_bar_start - 10 * sec2day, y_start_legend - i),\n",
+ " va='center', ha='right', fontsize=10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "from pheno_utils.timeseries_plots import TimeSeriesFigure, plot_events_fill\n",
+ "from pheno_utils.sleep_plots import plot_sleep_channels, get_sleep_period\n",
+ "\n",
+ "def plot_diet_cgm_sleep(\n",
+ " diet: pd.DataFrame=None,\n",
+ " cgm: pd.DataFrame=None,\n",
+ " sleep_events: pd.DataFrame=None,\n",
+ " sleep_channels: pd.DataFrame=None,\n",
+ " cgm_grid: List[int] = [0, 54, 70, 100, 140, 180],\n",
+ " channel_filter: List[str]=['heart_rate', 'actigraph', 'spo2'],\n",
+ " participant_id=None,\n",
+ " array_index=None,\n",
+ " time_range: Tuple[str, str]=None,\n",
+ " figsize=(14, 10),\n",
+ " nutrient_kws: dict={},\n",
+ " meals_kws: dict={},\n",
+ " cgm_kws: dict={},\n",
+ " events_kws: dict={},\n",
+ " channels_kws: dict={},\n",
+ ") -> TimeSeriesFigure:\n",
+ " \"\"\"\n",
+ " Plot diet, CGM and sleep data together.\n",
+ "\n",
+ " Args:\n",
+ " diet (pd.DataFrame): Diet logging data. Set to None to remove from figure.\n",
+ " cgm (pd.DataFrame): CGM data. Set to None to remove from figure.\n",
+ " sleep_events (pd.DataFrame): Sleep events data. Set to None to remove from figure.\n",
+ " sleep_channels (pd.DataFrame): Sleep channels data. Set to None to remove from figure.\n",
+ " cgm_grid (List[int]): CGM grid lines. Default: [0, 54, 70, 100, 140, 180].\n",
+ " channel_filter (List[str]): Which sleep channels to include in the plot. Default: ['heart_rate', 'actigraph', 'spo2'].\n",
+ " participant_id (int): Participant ID.\n",
+ " array_index (int): Array index.\n",
+ " time_range (Tuple[str, str]): Time range to plot.\n",
+ " figsize (Tuple[int, int]): Figure size.\n",
+ " nutrient_kws (dict): Keyword arguments for diet nutrients lollipop plot.\n",
+ " meals_kws (dict): Keyword arguments for diet meals plot.\n",
+ " cgm_kws (dict): Keyword arguments for CGM plot.\n",
+ " events_kws (dict): Keyword arguments for sleep events plot.\n",
+ " channels_kws (dict): Keyword arguments for sleep channels plot.\n",
+ "\n",
+ " Returns:\n",
+ " TimeSeriesFigure: Plot.\n",
+ " \"\"\"\n",
+ " g = TimeSeriesFigure(figsize=figsize)\n",
+ "\n",
+ " # Add diet\n",
+ " if diet is not None:\n",
+ " g.plot(plot_nutrient_lollipop, diet,\n",
+ " second_y=True if cgm is not None else False,\n",
+ " participant_id=participant_id, array_index=array_index, time_range=time_range,\n",
+ " size_scale=10, name='diet_glucose', height=1.5, **nutrient_kws)\n",
+ " g.plot(plot_meals_hbars, diet,\n",
+ " participant_id=participant_id, array_index=array_index, time_range=time_range,\n",
+ " name='diet_bars', sharex='diet_glucose', height=3, **meals_kws)\n",
+ "\n",
+ " # Add CGM\n",
+ " if cgm is not None:\n",
+ " if diet is None:\n",
+ " g.add_axes(name='diet_glucose')\n",
+ " cgm = format_timeseries(\n",
+ " cgm,\n",
+ " participant_id=participant_id, array_index=array_index, time_range=time_range,\n",
+ " )\n",
+ " ax = g.get_axes('diet_glucose', squeeze=True)\n",
+ " ax.plot(cgm['collection_timestamp'], cgm['glucose'], label='glucose', color='#4c72b0', **cgm_kws)\n",
+ " ax.scatter(cgm['collection_timestamp'], cgm['glucose'], s=10, color='#4c72b0', **cgm_kws)\n",
+ " ax.set_ylabel('Glucose', rotation=0, horizontalalignment='right')\n",
+ " ax.set_yticks(cgm_grid)\n",
+ " ax.yaxis.grid(True)\n",
+ "\n",
+ " # Add sleep\n",
+ " if sleep_channels is not None:\n",
+ " plot_sleep_channels(\n",
+ " sleep_channels,\n",
+ " x='collection_timestamp', y='values', row='source', hue=None,\n",
+ " participant_id=participant_id, array_index=array_index, time_range=time_range,\n",
+ " y_include=channel_filter,\n",
+ " fig=g, height=1, **channels_kws,\n",
+ " )\n",
+ " if sleep_events is not None:\n",
+ " g.plot(plot_events_fill, sleep_events,\n",
+ " participant_id=participant_id, array_index=array_index, time_range=time_range,\n",
+ " y_include=[\"Wake\", \"REM\", \"Light Sleep\", \"Deep Sleep\", \"Sleep\"],\n",
+ " hue='event', ax=['sleep_channels'], sharex='sleep_channels', alpha=0.3, **events_kws)\n",
+ " if cgm is not None or diet is not None:\n",
+ " g.plot(plot_events_fill, get_sleep_period(sleep_events),\n",
+ " participant_id=participant_id, array_index=array_index, time_range=time_range,\n",
+ " y_include=[\"Wake\", \"REM\", \"Light Sleep\", \"Deep Sleep\", \"Sleep\"], legend=False,\n",
+ " hue=None, palette='gray', label='event',\n",
+ " ax=['diet_glucose'], sharex='sleep_channels', alpha=0.3, **events_kws)\n",
+ "\n",
+ " # Tidy up\n",
+ " g.set_axis_padding(0.03)\n",
+ " if time_range is not None:\n",
+ " g.set_time_limits(*time_range)\n",
+ " g.set_periodic_ticks('2H', ax='sleep_channels')\n",
+ "\n",
+ " return g"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This module provides functions for plotting diet data, as well as a function for plotting diet, CGM and sleep data together.\n",
+ "\n",
+ "First, we will load the time series data for diet, CGM and sleep. (See also the dedicated modules for sleep and CGM.)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Warning: index is not unique for diet_logging\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/ec2-user/projects/pheno-utils/pheno_utils/pheno_loader.py:610: UserWarning: No date field found\n",
+ " warnings.warn(f'No date field found')\n",
+ "/home/ec2-user/projects/pheno-utils/pheno_utils/pheno_loader.py:610: UserWarning: No date field found\n",
+ " warnings.warn(f'No date field found')\n",
+ "/home/ec2-user/projects/pheno-utils/pheno_utils/pheno_loader.py:610: UserWarning: No date field found\n",
+ " warnings.warn(f'No date field found')\n"
+ ]
+ }
+ ],
+ "source": [
+ "#| eval: false\n",
+ "from pheno_utils import PhenoLoader\n",
+ "\n",
+ "pl = PhenoLoader('sleep')\n",
+ "channels_df = pl.load_bulk_data('channels_time_series') # contains: heart_rate, spo2, respiratory_movement\n",
+ "events_df = pl.load_bulk_data('events_time_series')\n",
+ "\n",
+ "diet_df = PhenoLoader('diet_logging').dfs['diet_logging']\n",
+ "cgm_df = PhenoLoader('cgm').dfs['cgm']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ " glucose\n",
+ "participant_id collection_timestamp \n",
+ "0 2020-06-22 00:14:00+03:00 106.2\n",
+ " 2020-06-22 00:29:00+03:00 100.8\n",
+ " 2020-06-22 00:44:00+03:00 97.2\n",
+ " 2020-06-22 00:59:00+03:00 95.4\n",
+ " 2020-06-22 01:14:00+03:00 93.6"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "#| eval: false\n",
+ "cgm_df.head(5)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we will use `plot_diet_cgm_sleep` to plot the data together."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "#| eval: false\n",
+ "import seaborn as sns\n",
+ "sns.set_style('whitegrid')\n",
+ "\n",
+ "plot_diet_cgm_sleep(diet_df, cgm_df, events_df, channels_df,\n",
+ " channel_filter=['heart_rate', 'respiratory_movement', 'spo2'],\n",
+ " time_range=('2020-06-22 08:00', '2020-06-23 10:00'))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Each of the types of diet plots can be plotted independently as well. We will use `TimeSeriesFigure`, `plot_nutrient_lollipop`, `plot_meals_hbars` and `plot_nutrient_bars` to plot them together."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "#| eval: false\n",
+ "from pheno_utils.timeseries_plots import TimeSeriesFigure\n",
+ "\n",
+ "g = TimeSeriesFigure(figsize=(14, 7))\n",
+ "\n",
+ "time_range = ('2020-06-22 06:00', '2020-06-23 15:00')\n",
+ "# Each call to the plot() method adds a new time-synced subplot to the figure\n",
+ "g.plot(plot_nutrient_lollipop, diet_df, size_scale=15,\n",
+ " time_range=time_range,\n",
+ " name='diet_pie')\n",
+ "g.plot(plot_meals_hbars, diet_df,\n",
+ " time_range=time_range,\n",
+ " name='diet_meals', height=2)\n",
+ "g.plot(plot_nutrient_bars, diet_df,\n",
+ " time_range=time_range,\n",
+ " label=None, n_axes=2, nut_exclude=['sodium'],\n",
+ " name='diet_bars')\n",
+ "g.set_axis_padding(0.03)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "import nbdev; nbdev.nbdev_export()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "python3",
+ "language": "python",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/nbs/examples/cgm/cgm.parquet b/nbs/examples/cgm/cgm.parquet
index e65f5c5..1a59e82 100644
Binary files a/nbs/examples/cgm/cgm.parquet and b/nbs/examples/cgm/cgm.parquet differ
diff --git a/nbs/examples/diet_logging/diet_logging.parquet b/nbs/examples/diet_logging/diet_logging.parquet
index ae2d2b4..370e4ab 100644
Binary files a/nbs/examples/diet_logging/diet_logging.parquet and b/nbs/examples/diet_logging/diet_logging.parquet differ
diff --git a/nbs/examples/diet_logging/metadata/diet_logging_data_dictionary.csv b/nbs/examples/diet_logging/metadata/diet_logging_data_dictionary.csv
index fd2ea07..2611d14 100644
--- a/nbs/examples/diet_logging/metadata/diet_logging_data_dictionary.csv
+++ b/nbs/examples/diet_logging/metadata/diet_logging_data_dictionary.csv
@@ -1,19 +1,13 @@
tabular_field_name,field_string,description_string,parent_dataframe,relative_location,value_type,units,field_type,array,cohorts,data_type,debut,pandas_dtype,sampling_rate
collection_timestamp,Collection timestamp,Collection timestamp,,diet_logging/diet_logging.parquet,Time,Time,Data,Single,10K,Time Series,2019-01-29,"datetime64[ns, Asia/Jerusalem]",
-collection_date,Date,Datetime column relecting the time food item was logged,,diet_logging/diet_logging.parquet,Time,Time,Data,Single,10K,Time Series,2019-09-01,datetime64[ns],
food_id,Food ID,IDs in the diet logging app representing specific food ,,diet_logging/diet_logging.parquet,Categorical (single) ,None,Data,Single,10K,Time Series,2019-09-01,integer,
-logging_day,Logging day per participant,Integer indicating which day of logging period ,,diet_logging/diet_logging.parquet,Integer,None ,Data,Single,10K,Time Series,2019-09-01,float,
-weight,Weight,Weight of food item logged,,diet_logging/diet_logging.parquet,Continuous,g,Data,Single,10K,Time Series,2019-09-01,float,
short_food_name,Short food name,Classifcation of food item logged into a short food name category,,diet_logging/diet_logging.parquet,Categorical (single) ,None,Data,Single,10K,Time Series,2019-09-01,object,
food_category,Food category,Classifcation of food item logged into a food category,,diet_logging/diet_logging.parquet,Categorical (single) ,None,Data,Single,10K,Time Series,2019-09-01,object,
-product_name,Product name ,Product name of food logged,,diet_logging/diet_logging.parquet,Categorical (single) ,None,Data,Single,10K,Time Series,2019-09-01,object,
-calories,Calories,Calories of food item logged,,diet_logging/diet_logging.parquet,Continuous,kcal,Data,Single,10K,Time Series,2019-09-01,float,
+weight_g,Weight,Weight of food item logged,,diet_logging/diet_logging.parquet,Continuous,g,Data,Single,10K,Time Series,2019-09-01,float,
+calories_kcal,Calories,Calories of food item logged,,diet_logging/diet_logging.parquet,Continuous,kcal,Data,Single,10K,Time Series,2019-09-01,float,
carbohydrate_g,Carbohydrate intake per food logged,Carbohydrate intake per food logged,,diet_logging/diet_logging.parquet,Continuous,g,Data,Single,10K,Time Series,2019-09-01,float,
-llipid_g,Fat intake per food logged,Fat intake per food logged,,diet_logging/diet_logging.parquet,Continuous,g,Data,Single,10K,Time Series,2019-09-01,float,
+lipid_g,Fat intake per food logged,Fat intake per food logged,,diet_logging/diet_logging.parquet,Continuous,g,Data,Single,10K,Time Series,2019-09-01,float,
protein_g,Protein intake per food logged,Protein intake per food logged,,diet_logging/diet_logging.parquet,Continuous,g,Data,Single,10K,Time Series,2019-09-01,float,
-sodium_mg ,Sodium intake per food logged,Sodium intake per food logged,,diet_logging/diet_logging.parquet,Continuous,mg,Data,Single,10K,Time Series,2019-09-01,float,
+sodium_mg,Sodium intake per food logged,Sodium intake per food logged,,diet_logging/diet_logging.parquet,Continuous,mg,Data,Single,10K,Time Series,2019-09-01,float,
alcohol_g ,Alcohol intake per food logged,Alcohol intake per food logged,,diet_logging/diet_logging.parquet,Continuous,g,Data,Single,10K,Time Series,2019-09-01,float,
dietary_fiber_g,Dietary fiber intake per food logged,Dietary fiber intake per food logged,,diet_logging/diet_logging.parquet,Continuous,g,Data,Single,10K,Time Series,2019-09-01,float,
-local_timestamp,Local timestamp,Local timestamp of food logging,,diet_logging/diet_logging.parquet,Time,Time,Data,Single,10K,Time Series,2019-09-01,datetime64[ns],
-eaten_in_restaurant,Eaten at restaurant indication,Indication if food was eatn at home or at a restaurant,,diet_logging/diet_logging.parquet,Boolean,None,Data,Single,10K,Time Series,2019-09-01,bool,
-total_logging_days,Total number of days logged,Total number of days diet was logged per research stage,,diet_logging/diet_logging.parquet,Integer,None,Data,Single,10K,Time Series,2019-09-01,integer,
diff --git a/nbs/examples/sleep/metadata/sleep_data_dictionary.csv b/nbs/examples/sleep/metadata/sleep_data_dictionary.csv
index f5d6693..55c9018 100644
--- a/nbs/examples/sleep/metadata/sleep_data_dictionary.csv
+++ b/nbs/examples/sleep/metadata/sleep_data_dictionary.csv
@@ -1,3 +1,5 @@
-tabular_field_name,field_string,description_string,parent_dataframe,relative_location,value_type,units,sampling_rate,item_type,array,cohorts,field_type,debut,pandas_dtype
-ahi,AHI,AHI (Apnea-Hypopnea Index),,sleep/sleep.parquet,Continuous,Events / Hour,,Data,Multiple,10K,Continuous,2020-01-15,float64
-total_sleep_time,Total sleep time,Total sleep time,,sleep/sleep.parquet,Integer,Seconds,,Data,Multiple,10K,Continuous,2020-01-15,float64
+tabular_field_name,field_string,description_string,parent_dataframe,relative_location,units,sampling_rate,array,cohorts,field_type,debut,pandas_dtype
+ahi,AHI,AHI (Apnea-Hypopnea Index),,sleep/sleep.parquet,Events / Hour,,Multiple,10K,Continuous,2020-01-15,float
+total_sleep_time,Total sleep time,Total sleep time,,sleep/sleep.parquet,Seconds,,Multiple,10K,Continuous,2020-01-15,int
+channels_time_series,Channels time series,Sensor and derived channels time series parquet files,,sleep/sleep.parquet,,,Multiple,10K,Time series file (individual),2020-01-15,string
+events_time_series,Events time series,"Events during sleep derived from the raw channels, such as sleep stages, respiratory events, pulse rate events, and others",,sleep/sleep.parquet,,,Multiple,10K,Time series file (group),2020-01-15,string
\ No newline at end of file
diff --git a/nbs/examples/sleep/sleep.parquet b/nbs/examples/sleep/sleep.parquet
index 02a77e3..0b1d5f3 100644
Binary files a/nbs/examples/sleep/sleep.parquet and b/nbs/examples/sleep/sleep.parquet differ
diff --git a/nbs/examples/sleep/time_series/channels.parquet b/nbs/examples/sleep/time_series/channels.parquet
new file mode 100644
index 0000000..74c77e8
Binary files /dev/null and b/nbs/examples/sleep/time_series/channels.parquet differ
diff --git a/nbs/examples/sleep/time_series/events.parquet b/nbs/examples/sleep/time_series/events.parquet
new file mode 100644
index 0000000..2e91199
Binary files /dev/null and b/nbs/examples/sleep/time_series/events.parquet differ
diff --git a/nbs/sidebar.yml b/nbs/sidebar.yml
index 1213d18..15712ba 100644
--- a/nbs/sidebar.yml
+++ b/nbs/sidebar.yml
@@ -15,6 +15,7 @@ website:
- 01_basic_plots.ipynb
- 02_blandaltman_plots.ipynb
- 03_age_reference_plots.ipynb
+ - 15_timeseries_plots.ipynb
- 04_date_plots.ipynb
- 06_sleep_plots.ipynb
- 08_cgm_plots.ipynb
diff --git a/pheno_utils/cgm_plots.py b/pheno_utils/cgm_plots.py
index 87f6131..8b6e8f9 100644
--- a/pheno_utils/cgm_plots.py
+++ b/pheno_utils/cgm_plots.py
@@ -28,7 +28,7 @@ def __init__(
cgm_date_col: str = "collection_timestamp",
gluc_col: str = "glucose",
diet_date_col: str = "collection_timestamp",
- diet_text_col: str = "shortname_eng",
+ diet_text_col: str = "short_food_name",
ax: Optional[plt.Axes] = None,
smooth: bool = False,
sleep_tuples: Optional[List[Tuple[pd.Timestamp, pd.Timestamp]]] = None,
@@ -39,15 +39,15 @@ def __init__(
Args:
cgm_df (pd.DataFrame): DataFrame containing the glucose measurements.
diet_df (Optional[pd.DataFrame], optional): DataFrame containing the diet data. Defaults to None.
- cgm_date_col (str, optional): Name of the date column in cgm_df. Defaults to "Date".
+ cgm_date_col (str, optional): Name of the date column in cgm_df. Defaults to "collection_timestamp".
gluc_col (str, optional): Name of the glucose column in cgm_df. Defaults to "glucose".
- diet_date_col (str, optional): Name of the date column in diet_df. Defaults to "Date".
- diet_text_col (str, optional): Name of the text column in diet_df. Defaults to "shortname_eng".
+ diet_date_col (str, optional): Name of the date column in diet_df. Defaults to "collection_timestamp".
+ diet_text_col (str, optional): Name of the text column in diet_df. Defaults to "short_food_name".
ax (Optional[plt.Axes], optional): Matplotlib Axes object to plot on. Defaults to None.
smooth (bool, optional): Apply smoothing to the glucose curve. Defaults to False.
sleep_tuples (Optional[List[Tuple[pd.Timestamp, pd.Timestamp]]], optional): List of sleep start and end times. Defaults to None.
"""
- self.cgm_df = cgm_df
+ self.cgm_df = cgm_df.reset_index()
self.diet_df = diet_df
self.cgm_date_col = cgm_date_col
self.gluc_col = gluc_col
@@ -110,7 +110,7 @@ def plot_diet(self) -> None:
for i, (food_datetime, group) in enumerate(
self.diet_df.groupby(self.diet_date_col)
):
- food_str = "\n".join(group[self.diet_text_col])
+ food_str = "\n".join(group[self.diet_text_col].dropna())
txt_x = food_datetime - pd.to_timedelta(7.5, "m")
if i % 2 == 0:
diff --git a/pheno_utils/config.py b/pheno_utils/config.py
index 4e87367..469731c 100644
--- a/pheno_utils/config.py
+++ b/pheno_utils/config.py
@@ -1,11 +1,11 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_config.ipynb.
# %% auto 0
-__all__ = ['REF_COLOR', 'FEMALE_COLOR', 'MALE_COLOR', 'ALL_COLOR', 'GLUC_COLOR', 'FOOD_COLOR', 'DATASETS_PATH', 'COHORT',
- 'EVENTS_DATASET', 'ERROR_ACTION', 'CONFIG_FILES', 'BULK_DATA_PATH', 'PREFERRED_LANGUAGE', 'config_found',
- 'DICT_PROPERTY_PATH', 'DATA_CODING_PATH', 'copy_tre_config', 'get_dictionary_properties_file_path',
- 'get_data_coding_file_path', 'generate_synthetic_data', 'generate_synthetic_data_like',
- 'generate_categorical_synthetic_data']
+__all__ = ['DEFAULT_PALETTE', 'REF_COLOR', 'FEMALE_COLOR', 'MALE_COLOR', 'ALL_COLOR', 'GLUC_COLOR', 'FOOD_COLOR', 'LEGEND_SHIFT',
+ 'TIME_FORMAT', 'DATASETS_PATH', 'COHORT', 'EVENTS_DATASET', 'ERROR_ACTION', 'CONFIG_FILES', 'BULK_DATA_PATH',
+ 'PREFERRED_LANGUAGE', 'config_found', 'DICT_PROPERTY_PATH', 'DATA_CODING_PATH', 'copy_tre_config',
+ 'get_dictionary_properties_file_path', 'get_data_coding_file_path', 'generate_synthetic_data',
+ 'generate_synthetic_data_like', 'generate_categorical_synthetic_data']
# %% ../nbs/00_config.ipynb 3
import os
@@ -16,6 +16,7 @@
from glob import glob
# %% ../nbs/00_config.ipynb 4
+DEFAULT_PALETTE = 'muted'
REF_COLOR = "k"
FEMALE_COLOR = "C1"
MALE_COLOR = "C0"
@@ -24,6 +25,9 @@
GLUC_COLOR = "C0"
FOOD_COLOR = "C1"
+LEGEND_SHIFT = (1.05, 1.05)
+TIME_FORMAT = '%d/%m\n%H:%M'
+
DATASETS_PATH = '/home/ec2-user/studies/hpp_datasets/'
COHORT = None
EVENTS_DATASET = 'events'
@@ -34,8 +38,6 @@
config_found = False
-
-
# %% ../nbs/00_config.ipynb 5
def copy_tre_config():
default_config_found = False
diff --git a/pheno_utils/diet_plots.py b/pheno_utils/diet_plots.py
new file mode 100644
index 0000000..1ca0022
--- /dev/null
+++ b/pheno_utils/diet_plots.py
@@ -0,0 +1,633 @@
+# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/16_diet_plots.ipynb.
+
+# %% auto 0
+__all__ = ['SHORT_FOOD_CATEGORIES', 'plot_nutrient_bars', 'plot_nutrient_lollipop', 'prepare_meals', 'extract_units',
+ 'draw_pie_chart', 'plot_meals_hbars', 'add_size_legend', 'plot_diet_cgm_sleep']
+
+# %% ../nbs/16_diet_plots.ipynb 3
+from typing import List, Tuple
+
+import pandas as pd
+import numpy as np
+
+import seaborn as sns
+import matplotlib.pyplot as plt
+import matplotlib.dates as mdates
+from matplotlib.ticker import FuncFormatter
+import matplotlib.patches as mpatches
+import matplotlib.lines as mlines
+import matplotlib.patches as Patch
+
+# %% ../nbs/16_diet_plots.ipynb 4
+from .timeseries_plots import format_timeseries, format_xticks, plot_events_bars
+from .config import DEFAULT_PALETTE, LEGEND_SHIFT
+
+
+def plot_nutrient_bars(
+    diet_log: pd.DataFrame,
+    x: str='collection_timestamp',
+    label: str='short_food_name',
+    participant_id: int=None,
+    array_index: int = None,
+    time_range: Tuple[str, str]=None,
+    meals: bool=True,
+    summary: bool=False,
+    nut_include: List[str]=None,
+    nut_exclude: List[str]=None,
+    agg_units: dict={'kcal': 'sum', 'g': 'sum', 'mg': 'sum'},
+    legend: bool=True,
+    bar_width=np.timedelta64(15, 'm'),
+    palette: str=DEFAULT_PALETTE,
+    alpha: float=0.7,
+    ax: plt.Axes=None,
+    figsize: Tuple[float, float]=(14, 3),
+):
+    """
+    Plot a stacked bar chart representing nutrient intake for each meal over time.
+
+    Args:
+        diet_log (pd.DataFrame): The dataframe containing the diet log data, with columns for timestamps, nutrients, and other measurements.
+        x (str): The name of the column in `diet_log` representing the x-axis variable, such as timestamps. Default is 'collection_timestamp'.
+        label (str): The name of the column in `diet_log` representing the labels for each meal. Default is 'short_food_name'.
+        participant_id (Optional[int]): The participant's ID to filter the diet log. If None, no filtering is done. Default is None.
+        array_index (Optional[int]): The array index to filter the diet log. If None, no filtering is done. Default is None.
+        time_range (Optional[Tuple[str, str]]): A tuple of strings representing the start and end dates for filtering the data. Format should be 'YYYY-MM-DD'. Default is None.
+        meals (bool): If True, includes individual meals in the plot. Default is True.
+        summary (bool): If True, includes a daily summary in the plot. Default is False.
+        nut_include (List[str]): A list of nutrients to include in the plot. Default is None.
+        nut_exclude (List[str]): A list of nutrients to exclude from the plot. Default is None.
+        agg_units (dict): A dictionary mapping nutrient units to aggregation functions. Only nutrients with units in this dictionary are plotted.
+        legend (bool): If True, includes a legend in the plot. Default is True.
+        bar_width (np.timedelta64): The width of the bars representing each meal on the time axis. Default is 15 minutes.
+        palette (str): The color palette to use for the stacked bars.
+        alpha (float): The transparency of the stacked bars. Default is 0.7.
+        ax (Optional[plt.Axes]): A Matplotlib axis, or a sequence with one axis per nutrient unit, to plot on. If None, new axes are created. Default is None.
+        figsize (Tuple[float, float]): The size of the figure to create. Default is (14, 3).
+
+    Returns:
+        The axes that were drawn on, as a list or array holding one axis per plotted nutrient unit.
+    """
+    # Prepare the data for plotting
+    df, grouped_nutrients = prepare_meals(
+        diet_log,
+        participant_id=participant_id,
+        array_index=array_index,
+        time_range=time_range,
+        label=label,
+        return_meals=meals,
+        return_summary=summary,
+        y_include=nut_include,
+        y_exclude=nut_exclude,
+        agg_units=agg_units,
+        x_col=x,
+    )
+
+    if ax is None:
+        _, ax = plt.subplots(
+            len(grouped_nutrients), 1,
+            figsize=(figsize[0], figsize[1] * len(grouped_nutrients)),
+            sharex=True)
+    if isinstance(ax, plt.Axes):
+        ax = [ax]  # a lone axis (created above or passed in) must still support ax[idx]
+
+    colors = sns.color_palette(
+        palette, sum(len(g) for g in grouped_nutrients.values()))
+
+    # Calculate the width in time units
+    bar_width_in_days = bar_width / np.timedelta64(1, 'D')
+
+    unit_list = [g for g in grouped_nutrients if g != 'kcal']
+    if 'kcal' in grouped_nutrients:
+        # kcal is last to keep colours synced with the lollipop plot
+        unit_list.append('kcal')
+
+    # Stacked bar plots for grouped nutrients
+    c = 0
+    for idx, unit in enumerate(unit_list):
+        bottom = pd.Series([0] * len(df))
+        for nut in grouped_nutrients[unit]:
+            if nut == 'weight_g':
+                continue
+            ax[idx].bar(
+                df[x], df[nut], bottom=bottom, width=bar_width_in_days,
+                color=colors[c], alpha=alpha, label=nut)
+            bottom += df[nut]
+            c += 1
+        ax[idx].set_ylabel(f'Nutrients ({unit})', rotation=0, horizontalalignment='right')
+        if legend:
+            ax[idx].legend(loc='upper left', bbox_to_anchor=LEGEND_SHIFT)
+        ax[idx].grid(True)
+
+    # Set x-tick labels for the bottom and top axes
+    format_xticks(ax[-1], df[x])
+    if label is not None:
+        secax = ax[0].secondary_xaxis('top')
+        secax.set_xticks(df[x])
+        secax.set_xticklabels(df[label], ha='center', fontsize=9)
+
+    return ax
+
+
+def plot_nutrient_lollipop(
+    diet_log: pd.DataFrame,
+    x: str='collection_timestamp',
+    y: str='calories_kcal',
+    size: str='total_g',
+    label: str='short_food_name',
+    participant_id: int=None,
+    array_index: int=None,
+    time_range: Tuple[str, str]=None,
+    meals: bool=True,
+    summary: bool=False,
+    nut_include: List[str]=None,
+    nut_exclude: List[str]=None,
+    legend: bool=True,
+    size_scale: float=5,
+    palette: str=DEFAULT_PALETTE,
+    alpha: float=0.7,
+    ax: plt.Axes=None,
+    figsize: Tuple[float, float] = (12, 3),
+):
+    """
+    Plot a lollipop chart with pie charts representing nutrient composition for each meal.
+
+    NOTE: The y-axis is scaled to match the units of the x-axis, to avoid distortion of the pie charts.
+    Due to scaling, if you intend to change `xlim` after plotting, you must also provide `time_range`.
+    Use the `second_y` option of g.plot() to plot it with other y-axis data.
+
+    Args:
+        diet_log (pd.DataFrame): The dataframe containing the diet log data, with columns for timestamps, nutrients, and other measurements.
+        x (str): The name of the column in `diet_log` representing the x-axis variable, such as timestamps. Default is 'collection_timestamp'.
+        y (str): The name of the column in `diet_log` representing the y-axis variable, such as calories. Default is 'calories_kcal'.
+        size (str): The name of the column in `diet_log` representing the size of the pie charts. Default is 'total_g'.
+        label (str): The name of the column in `diet_log` representing the labels for each meal. Default is 'short_food_name'.
+        participant_id (Optional[int]): The participant's ID to filter the diet log. If None, no filtering is done. Default is None.
+        array_index (Optional[int]): The array index to filter the diet log. If None, no filtering is done. Default is None.
+        time_range (Optional[Tuple[str, str]]): A tuple of strings representing the start and end dates for filtering the data. Format should be 'YYYY-MM-DD'. Default is None.
+        meals (bool): If True, includes individual meals in the plot. Default is True.
+        summary (bool): If True, includes a daily summary in the plot. Default is False.
+        nut_include (List[str]): A list of nutrients to include in the plot. Default is None.
+        nut_exclude (List[str]): A list of nutrients to exclude from the plot. Default is None.
+        legend (bool): If True, includes a legend in the plot. Default is True.
+        size_scale (float): The scaling factor for the size of the pie charts. Default is 5.
+        palette (str): The color palette to use for the pie slices. Default is DEFAULT_PALETTE.
+        alpha (float): The transparency of the pie slices. Default is 0.7.
+        ax (Optional[plt.Axes]): The Matplotlib axis on which to plot the lollipop chart. If None, a new axis is created. Default is None.
+        figsize (Tuple[float, float]): The size of the figure to create. Default is (12, 3).
+
+    Returns:
+        plt.Axes: The axis holding the lollipop plot (the one passed in, or the newly created one).
+    """
+    # Prepare the data for plotting
+    df, grouped_nutrients = prepare_meals(
+        diet_log,
+        participant_id=participant_id,
+        array_index=array_index,
+        time_range=time_range,
+        label=label,
+        return_meals=meals,
+        return_summary=summary,
+        y_include=nut_include,
+        y_exclude=nut_exclude,
+        x_col=x,
+    )
+
+    if ax is None:
+        _, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
+
+    # Pie slices are the nutrients measured in grams, plus those in mg converted to grams.
+    # Either unit group may be absent (e.g. filtered out by nut_exclude), hence the .get()
+    pie_nuts = [nut for nut in grouped_nutrients.get('g', [])
+                if nut != 'weight_g']
+    for nut in grouped_nutrients.get('mg', []):
+        df[nut.replace('_mg', '_g')] = df[nut] / 1000
+        pie_nuts.append(nut.replace('_mg', '_g'))
+    df['total_g'] = df[pie_nuts].sum(axis=1)
+
+    # Pre-set the x-axis limits based on the range of timestamps
+    if time_range is None:
+        min_x = mdates.date2num(df[x].min())
+        max_x = mdates.date2num(df[x].max())
+    else:
+        min_x = mdates.date2num(pd.to_datetime(time_range[0]))
+        max_x = mdates.date2num(pd.to_datetime(time_range[1]))
+
+    # Pre-set the y-axis limits based on the range of the y-axis column
+    min_y = 0
+    max_y = df[y].max()
+
+    # Calculate the aspect ratio between the x and y axes
+    # This is necessary to avoid distortion of the (circular) pie charts
+    x_range = max_x - min_x
+    y_range = max_y - min_y
+    aspect_ratio = x_range / y_range
+    y_delta = 0.1 * y_range
+
+    # Scale the y-axis to match the aspect ratio of the x-axis
+    ax.set_xlim(min_x, max_x)
+    ax.set_ylim(min_y * aspect_ratio, (max_y + y_delta) * aspect_ratio)
+
+    # Custom formatter to adjust the y-ticks back to the original scale
+    def ytick_formatter(y, pos):
+        return f'{y / aspect_ratio:.0f}'
+
+    # Plot one stem and one pie per meal, with y-values in the rescaled units.
+    # `wedges` stays empty when there are no meals, so the legend below is skipped
+    wedges = []
+    for idx, row in df.iterrows():
+        # Pie chart parameters
+        size_value = np.sqrt(row[size]) * aspect_ratio * size_scale
+        position = mdates.date2num(row[x])
+        y_value = row[y] * aspect_ratio  # Scale y-value
+
+        # Plot the stem (lollipop stick)
+        ax.plot([position, position], [0, y_value], color='gray', lw=1, zorder=1)
+
+        # Plot the pie chart on an inset axis (no distortion)
+        wedges = draw_pie_chart(ax, position, y_value, row[pie_nuts].fillna(0.).values, size_value, palette, alpha)
+
+    if legend and wedges:
+        # Create a custom legend from the wedges of the last pie drawn
+        ax.legend(handles=wedges, labels=pie_nuts, loc='upper left', bbox_to_anchor=LEGEND_SHIFT)
+
+    # Label the y-axis and add a grid
+    ax.set_ylabel(y.replace('_', ' ').title(), rotation=0, horizontalalignment='right')
+    ax.grid(True)
+
+    # Set y-ticks (every 100 original units) and x-ticks
+    ax.yaxis.set_major_formatter(FuncFormatter(ytick_formatter))
+    ylim = ax.get_ylim()
+    yticks = np.arange(0, ylim[1] / aspect_ratio, 100, dtype=int)
+    ax.set_yticks(yticks * aspect_ratio)
+    ax.set_yticklabels(yticks)
+
+    format_xticks(ax, df[x])
+    if label is not None:
+        secax = ax.secondary_xaxis('top')
+        secax.set_xticks(df[x])
+        secax.set_xticklabels(df[label], ha='center', fontsize=9)
+
+    return ax
+
+
+def prepare_meals(
+    diet_log: pd.DataFrame,
+    participant_id: int=None,
+    array_index: int=None,
+    time_range: Tuple[str, str]=None,
+    label: str='short_food_name',
+    return_meals: bool = True,
+    return_summary: bool = False,
+    y_include: List[str] = None,
+    y_exclude: List[str] = None,
+    agg_units: dict={'kcal': 'sum', 'g': 'sum', 'mg': 'sum', 'unknown': 'first'},
+    x_col: str='collection_timestamp'
+) -> Tuple[pd.DataFrame, dict]:
+    """
+    Prepare the diet log data for plotting meals and/or daily summaries.
+
+    Args:
+        diet_log (pd.DataFrame): The dataframe containing the diet log data, with columns for timestamps, nutrients, and other measurements.
+        participant_id (Optional[int]): The participant's ID to filter the diet log. If None, no filtering is done. Default is None.
+        array_index (Optional[int]): The array index to filter the diet log. If None, no filtering is done. Default is None.
+        time_range (Optional[Tuple[str, str]]): A tuple of strings representing the start and end dates for filtering the data. Format should be 'YYYY-MM-DD'. Default is None.
+        label (str): The name of the column in `diet_log` representing the labels for each meal; rows missing it are dropped. Default is 'short_food_name'.
+        return_meals (bool): If True, includes individual meals in the plot. Default is True.
+        return_summary (bool): If True, includes a daily summary in the plot. Default is False.
+        y_include (List[str]): A list of nutrients (regex) to include in the plot. Default is None.
+        y_exclude (List[str]): A list of nutrients (regex) to exclude from the plot. Default is None.
+        agg_units (dict): A dictionary mapping nutrient units to aggregation functions.
+        x_col (str): The name of the column in `diet_log` representing the x-axis variable, such as timestamps. Default is 'collection_timestamp'.
+
+    Returns:
+        Tuple[pd.DataFrame, dict]: The per-meal (and/or daily) dataframe, and a dict mapping each unit to its list of nutrient columns.
+    """
+    import re  # function-scope import: used only for the include/exclude patterns below
+
+    diet_log = format_timeseries(
+        diet_log, participant_id, array_index, time_range,
+        x_start=x_col, x_end=x_col, unique=True)
+
+    units = extract_units(diet_log.columns)
+    grouped_nutrients = {}
+    agg_dict = {}
+    for nut, unit in units.items():
+        if unit not in agg_units or nut in (x_col, label):
+            continue  # unknown unit, or the grouping/label column rather than a nutrient
+        if y_include is not None and not any(re.match(inc, nut) for inc in y_include):
+            continue
+        if y_exclude is not None and any(re.match(exc, nut) for exc in y_exclude):
+            continue
+        grouped_nutrients.setdefault(unit, []).append(nut)
+        agg_dict[nut] = agg_units[unit]
+    nut_list = list(agg_dict.keys())
+    if label is not None:
+        agg_dict[label] = lambda x: '\n'.join(x)
+
+    # Rows without a label cannot be joined into the meal label, so drop them first
+    required_col = 'short_food_name' if label is None else label
+    df = diet_log\
+        .dropna(subset=[required_col])\
+        .drop_duplicates()\
+        .groupby([x_col])\
+        .agg(agg_dict)\
+        .reset_index()
+
+    if return_summary:
+        # Add daily summary by grouping by date and summing up the nutrients
+        daily_df = df.groupby(df[x_col].dt.date)[nut_list]\
+            .sum().reset_index()
+        if label is not None:
+            daily_df[label] = daily_df[x_col].astype('string') + '\nDaily Summary'
+        daily_df[x_col] = pd.to_datetime(daily_df[x_col] + pd.Timedelta(hours=24))
+        if time_range is not None:
+            daily_df = daily_df[(time_range[0] <= daily_df[x_col]) & (daily_df[x_col] <= time_range[1])]
+        if return_meals:
+            # NOTE(review): stacked by position with numpy because the frames would not concat; confirm dtypes downstream
+            df = pd.DataFrame(np.vstack([df, daily_df]), columns=df.columns)
+        else:
+            df = daily_df
+
+    return df, grouped_nutrients
+
+
+def extract_units(column_names: List[str]) -> dict:
+    """
+    Map each column name to its unit, taken as the text after the last underscore.
+
+    Names that contain no underscore are mapped to 'unknown'.
+    """
+    return {
+        name: name.split('_')[-1] if '_' in name else 'unknown'
+        for name in column_names}
+
+
+def draw_pie_chart(
+    ax: plt.Axes,
+    x: float,
+    y: float,
+    data: List[float],
+    size: float,
+    palette: str = DEFAULT_PALETTE,
+    alpha: float = 0.7,
+):
+    """
+    Draw a pie chart on its own inset axis, centred on a data point of `ax`.
+
+    Drawing the pie on an inset avoids the distortion it would suffer if drawn
+    directly on axes whose x and y scales differ.
+
+    Args:
+        ax (plt.Axes): The axis that hosts the pie chart inset.
+        x (float): The x-coordinate of the pie's centre, in data coordinates.
+        y (float): The y-coordinate of the pie's centre, in data coordinates.
+        data (List[float]): The values of the pie slices.
+        size (float): Half the side of the inset, in axes-relative coordinates.
+        palette (str): The color palette to use for the pie slices.
+        alpha (float): The transparency of the pie slices. Default is 0.7.
+
+    Returns:
+        list: The wedge patches of the pie, as returned by `Axes.pie`.
+    """
+    # Data coordinates -> display coordinates -> axes-relative coordinates
+    display_xy = ax.transData.transform((x, y))
+    cx, cy = ax.transAxes.inverted().transform(display_xy)
+
+    # Square inset centred on the converted point
+    pie_ax = ax.inset_axes([cx - size, cy - size, 2 * size, 2 * size])
+
+    # One palette colour per slice, with the requested transparency appended
+    rgba = [(*rgb, alpha) for rgb in sns.color_palette(palette, len(data))]
+    wedges, _ = pie_ax.pie(
+        data, colors=rgba, radius=1, startangle=90, normalize=True, wedgeprops={'edgecolor': 'none'})
+
+    pie_ax.set_axis_off()  # only the pie itself should be visible
+
+    return wedges
+
+
+# %% ../nbs/16_diet_plots.ipynb 5
+SHORT_FOOD_CATEGORIES = {  # lower-case food category text -> shorter replacement; default `rename_categories` of plot_meals_hbars
+ 'beef, veal, lamb, and other meat products': 'meat products',
+ 'milk, cream cheese and yogurts': 'milk products',
+ 'nuts, seeds, and products': 'nuts and seeds',
+ 'eggs and their products': 'eggs',
+ 'pulses and products': 'pulses',
+ 'fruit juices and soft drinks': 'juices and soft drinks',
+ 'low calories and diet drinks': 'low cal. drinks',
+ 'poultry and its products': 'poultry',
+ 'pasta, grains and side dishes': 'grains',
+ 'industrialized vegetarian food ready to eat': 'industrialized veg.',
+}
+
+def plot_meals_hbars(
+    diet_log: pd.DataFrame,
+    x: str='collection_timestamp',
+    y: str='short_food_category',
+    size: str='weight_g',
+    hue: str='short_food_category',
+    participant_id: int=None,
+    array_index: int=None,
+    time_range: Tuple[str, str]=None,
+    y_include: List[str] = None,
+    y_exclude: List[str] = None,
+    rename_categories: dict=SHORT_FOOD_CATEGORIES,
+    legend: bool=True,
+    size_legend: List[int]=[100, 200, 500],
+    size_scale: float=5,
+    palette: str=DEFAULT_PALETTE,
+    alpha: float=0.7,
+    ax: plt.Axes=None,
+    figsize: Tuple[float, float] = (12, 6),
+):
+    """
+    Plot a diet chart with bars representing meals and their size over time.
+
+    Args:
+        diet_log (pd.DataFrame): The dataframe containing the diet log data, with columns for timestamps, nutrients, and other measurements.
+        x (str): The name of the column in `diet_log` representing the x-axis variable, such as timestamps. Default is 'collection_timestamp'.
+        y (str): The name of the column in `diet_log` representing the y-axis variable, such as food categories. Default is 'short_food_category'.
+        size (str): The name of the column in `diet_log` representing the size of the bars. Default is 'weight_g'.
+        hue (str): The name of the column in `diet_log` representing the color of the bars. Default is 'short_food_category'.
+        participant_id (Optional[int]): The participant's ID to filter the diet log. If None, no filtering is done. Default is None.
+        array_index (Optional[int]): The array index to filter the diet log. If None, no filtering is done. Default is None.
+        time_range (Optional[Tuple[str, str]]): A tuple of strings representing the start and end dates for filtering the data. Format should be 'YYYY-MM-DD'. Default is None.
+        y_include (List[str]): A list of strings representing the categories to include in the plot. Default is None.
+        y_exclude (List[str]): A list of strings representing the categories to exclude from the plot. Default is None.
+        rename_categories (dict): A dictionary mapping original food categories to shorter names; None disables renaming. Default is SHORT_FOOD_CATEGORIES.
+        legend (bool): If True, includes a legend in the plot. Default is True.
+        size_legend (List[int]): A list of integers representing the sizes to include in the size legend; None or empty omits it. Default is [100, 200, 500].
+        size_scale (float): The scaling factor for the size of the bars. Default is 5.
+        palette (str): The palette to use for the bars.
+        alpha (float): The transparency of the bars. Default is 0.7.
+        ax (Optional[plt.Axes]): The Matplotlib axis on which to plot the meal bars. If None, a new axis is created. Default is None.
+        figsize (Tuple[float, float]): The size of the figure to create. Default is (12, 6).
+
+    Returns:
+        plt.Axes: The axis returned by `plot_events_bars`, with x-ticks formatted and the size legend added.
+    """
+    diet_log = format_timeseries(
+        diet_log, participant_id, array_index, time_range, x_start=x, x_end=x, unique=True)
+
+    diet_log['event_end'] = diet_log[x] + size_scale * pd.to_timedelta(diet_log[size], unit='s')
+
+    # Categories: lower-case, apply the optional renaming, then abbreviate
+    categories = diet_log['food_category'].str.lower()
+    for long_name, short_name in (rename_categories or {}).items():
+        categories = categories.str.replace(long_name, short_name, regex=False)
+    diet_log['short_food_category'] = categories\
+        .str.replace(' and ', ' & ', regex=False)\
+        .str.replace('_wholewheat', ' (whole/w)', regex=False)
+
+    # Use the events plot to draw each meal as a bar from `x` to 'event_end'
+    ax = plot_events_bars(
+        diet_log, x_start=x, x_end='event_end', y=y, hue=hue,
+        y_include=y_include, y_exclude=y_exclude, alpha=alpha,
+        ax=ax, figsize=figsize, palette=palette, legend=legend)
+
+    format_xticks(ax, diet_log[x].drop_duplicates())
+    if size_legend is not None:
+        add_size_legend(ax, size_legend, size_scale, alpha)
+
+    return ax
+
+
+def add_size_legend(ax: plt.Axes, sizes: List[int], size_scale: float, alpha: float, shift: int=0):
+    """
+    Add a boxed size key to a plot_meals_hbars plot, drawn with broken_barh.
+
+    Each entry of `sizes` gets a gray bar `size * size_scale` seconds long, labelled '<size>g'.
+    """
+    if len(sizes) == 0:
+        return
+
+    # The x-axis is in days, so lengths expressed in seconds are converted
+    day_per_sec = 1 / (60 * 60 * 24)
+    gap = 10 * day_per_sec
+    bar_lengths = [s * size_scale * day_per_sec for s in sizes]
+    longest = max(bar_lengths)
+    n_rows = len(sizes)
+
+    # Anchor the key to the right x-limit, one unit (plus `shift`) past the lower y-limit
+    right_edge = ax.get_xlim()[1]
+    first_row_y = ax.get_ylim()[0] - 1 - shift
+    bar_left = right_edge - 1.5 * longest
+
+    # Bounding box around the labels and bars
+    box = mpatches.Rectangle(
+        (bar_left - 1.5*(longest + gap), first_row_y - n_rows + 0.25),
+        3 * (longest + gap), n_rows + 0.5,
+        edgecolor='gray', facecolor='white', lw=1)
+    ax.add_patch(box)
+
+    for row, (grams, length) in enumerate(zip(sizes, bar_lengths)):
+        row_y = first_row_y - row
+        # Reference bar for this size
+        ax.broken_barh(
+            xranges=[(bar_left, length)], yrange=(row_y - 0.4, 0.8),
+            facecolors='gray', alpha=alpha)
+
+        # Text label to the left of the bar
+        ax.annotate(
+            f'{grams}g', (bar_left - gap, row_y),
+            va='center', ha='right', fontsize=10)
+
+# %% ../nbs/16_diet_plots.ipynb 6
+from .timeseries_plots import TimeSeriesFigure, plot_events_fill
+from .sleep_plots import plot_sleep_channels, get_sleep_period
+
+def plot_diet_cgm_sleep(
+ diet: pd.DataFrame=None,
+ cgm: pd.DataFrame=None,
+ sleep_events: pd.DataFrame=None,
+ sleep_channels: pd.DataFrame=None,
+ cgm_grid: List[int] = [0, 54, 70, 100, 140, 180],
+ channel_filter: List[str]=['heart_rate', 'actigraph', 'spo2'],
+ participant_id=None,
+ array_index=None,
+ time_range: Tuple[str, str]=None,
+ figsize=(14, 10),
+ nutrient_kws: dict={},
+ meals_kws: dict={},
+ cgm_kws: dict={},
+ events_kws: dict={},
+ channels_kws: dict={},
+) -> TimeSeriesFigure:
+ """
+ Plot diet, CGM and sleep data together.
+
+ Args:
+ diet (pd.DataFrame): Diet logging data. Set to None to remove from figure.
+ cgm (pd.DataFrame): CGM data. Set to None to remove from figure.
+ sleep_events (pd.DataFrame): Sleep events data. Set to None to remove from figure.
+ sleep_channels (pd.DataFrame): Sleep channels data. Set to None to remove from figure.
+ cgm_grid (List[int]): CGM grid lines. Default: [0, 54, 70, 100, 140, 180].
+ channel_filter (List[str]): Which sleep channels to include in the plot. Default: ['heart_rate', 'actigraph', 'spo2'].
+ participant_id (int): Participant ID.
+ array_index (int): Array index.
+ time_range (Tuple[str, str]): Time range to plot.
+ figsize (Tuple[int, int]): Figure size.
+ nutrient_kws (dict): Keyword arguments for diet nutrients lollipop plot.
+ meals_kws (dict): Keyword arguments for diet meals plot.
+ cgm_kws (dict): Keyword arguments for CGM plot.
+ events_kws (dict): Keyword arguments for sleep events plot.
+ channels_kws (dict): Keyword arguments for sleep channels plot.
+
+ Returns:
+ TimeSeriesFigure: Plot.
+ """
+ g = TimeSeriesFigure(figsize=figsize)
+
+ # Add diet
+ if diet is not None:
+ g.plot(plot_nutrient_lollipop, diet,
+ second_y=True if cgm is not None else False,
+ participant_id=participant_id, array_index=array_index, time_range=time_range,
+ size_scale=10, name='diet_glucose', height=1.5, **nutrient_kws)
+ g.plot(plot_meals_hbars, diet,
+ participant_id=participant_id, array_index=array_index, time_range=time_range,
+ name='diet_bars', sharex='diet_glucose', height=3, **meals_kws)
+
+ # Add CGM
+ if cgm is not None:
+ if diet is None:
+ g.add_axes(name='diet_glucose')
+ cgm = format_timeseries(
+ cgm,
+ participant_id=participant_id, array_index=array_index, time_range=time_range,
+ )
+ ax = g.get_axes('diet_glucose', squeeze=True)
+ ax.plot(cgm['collection_timestamp'], cgm['glucose'], label='glucose', color='#4c72b0', **cgm_kws)
+ ax.scatter(cgm['collection_timestamp'], cgm['glucose'], s=10, color='#4c72b0', **cgm_kws)
+ ax.set_ylabel('Glucose', rotation=0, horizontalalignment='right')
+ ax.set_yticks(cgm_grid)
+ ax.yaxis.grid(True)
+
+ # Add sleep
+ if sleep_channels is not None:
+ plot_sleep_channels(
+ sleep_channels,
+ x='collection_timestamp', y='values', row='source', hue=None,
+ participant_id=participant_id, array_index=array_index, time_range=time_range,
+ y_include=channel_filter,
+ fig=g, height=1, **channels_kws,
+ )
+ if sleep_events is not None:
+ g.plot(plot_events_fill, sleep_events,
+ participant_id=participant_id, array_index=array_index, time_range=time_range,
+ y_include=["Wake", "REM", "Light Sleep", "Deep Sleep", "Sleep"],
+ hue='event', ax=['sleep_channels'], sharex='sleep_channels', alpha=0.3, **events_kws)
+ if cgm is not None or diet is not None:
+ g.plot(plot_events_fill, get_sleep_period(sleep_events),
+ participant_id=participant_id, array_index=array_index, time_range=time_range,
+ y_include=["Wake", "REM", "Light Sleep", "Deep Sleep", "Sleep"], legend=False,
+ hue=None, palette='gray', label='event',
+ ax=['diet_glucose'], sharex='sleep_channels', alpha=0.3, **events_kws)
+
+ # Tidy up
+ g.set_axis_padding(0.03)
+ if time_range is not None:
+ g.set_time_limits(*time_range)
+ g.set_periodic_ticks('2H', ax='sleep_channels')
+
+ return g
diff --git a/pheno_utils/sleep_plots.py b/pheno_utils/sleep_plots.py
index c6411f8..50029ae 100644
--- a/pheno_utils/sleep_plots.py
+++ b/pheno_utils/sleep_plots.py
@@ -1,18 +1,17 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/06_sleep_plots.ipynb.
# %% auto 0
-__all__ = ['CHANNELS', 'DEFAULT_CHANNELS', 'COLOR_GROUPS', 'ENUMS', 'CHANNEL_LIMS', 'plot_sleep', 'plot_events', 'plot_channels',
- 'format_xticks', 'get_legend_colors']
+__all__ = ['CHANNELS', 'DEFAULT_CHANNELS', 'COLOR_GROUPS', 'ENUMS', 'CHANNEL_LIMS', 'plot_sleep', 'plot_sleep_channels',
+ 'get_channels_colors', 'get_sleep_period']
# %% ../nbs/06_sleep_plots.ipynb 3
-from typing import Iterable, Optional
+from typing import Iterable, Tuple, List
-import numpy as np
import pandas as pd
-import seaborn as sns
import matplotlib.pyplot as plt
-import matplotlib.dates as mdates
+
+from .timeseries_plots import TimeSeriesFigure, plot_events_bars, get_events_period, format_xticks, prep_to_plot_timeseries, DEFAULT_PALETTE
# %% ../nbs/06_sleep_plots.ipynb 4
CHANNELS = {
@@ -32,6 +31,7 @@
DEFAULT_CHANNELS = ['actigraph', 'pat_infra', 'body_position', 'snore_db', 'heart_rate', 'spo2']
+# Color groups are designed to match events and raw channels
COLOR_GROUPS= {
'actigraph': ['actigraph', 'sleep_stage'],
'general': ['body_position'],
@@ -48,262 +48,197 @@
CHANNEL_LIMS = {'spo2': [0, 100]}
# %% ../nbs/06_sleep_plots.ipynb 5
-def plot_sleep(events: pd.DataFrame, channels: pd.DataFrame,
- array_index: Optional[int] = None,
- trim_to_events: Optional[bool] = True,
- add_events: Optional[pd.DataFrame] = None,
- event_filter: Optional[Iterable[str]] = None,
- channel_filter: Optional[Iterable[str]] = DEFAULT_CHANNELS,
- event_height: float=2, channel_height: float=0.45, width: float=10, aspect: float=0.2,
- style: str='whitegrid',
- xlim: Iterable[float]=None, **kwargs):
+def plot_sleep(
+ events: pd.DataFrame,
+ channels: pd.DataFrame,
+ participant_id: int=None,
+ array_index: int=None,
+ time_range: Tuple[str, str]=None,
+ event_filter: Iterable[str]=None,
+ channel_filter: Iterable[str]=DEFAULT_CHANNELS,
+ event_height: float=1,
+ channel_height: float=0.5,
+ padding: float=-0.02,
+ figsize: Tuple[float, float]=None,
+ palette: str=DEFAULT_PALETTE,
+) -> TimeSeriesFigure:
"""
Plot sleep events and channels data.
Args:
-
- events (pd.DataFrame): A pandas dataframe containing sleep events data.
- channels (pd.DataFrame): A pandas dataframe containing raw channels data.
- array_index (int, optional): The index of the array. Defaults to None.
- trim_to_events (bool, optional): Whether to trim the plot to the start and end of the events. Defaults to True.
- add_events (pd.DataFrame, optional): Additional events data to include in the plot. Defaults to None.
- event_filter (Iterable[str], optional): A list of events to include in the plot. Defaults to None.
- channel_filter (Iterable[str], optional): A list of channels to include in the plot. Defaults to DEFAULT_CHANNELS.
- event_height (float, optional): The height of the event plot in inches. Defaults to 2.
- channel_height (float, optional): The height of each channel plot in inches. Defaults to 0.45.
- width (float, optional): The width of the plot in inches. Defaults to 10.
- aspect (float, optional): The aspect ratio of the plot. Defaults to 0.2.
- style (str, optional): The seaborn style to use. Defaults to 'whitegrid'.
- xlim (List[float], optional): The x-axis limits of the plot. Defaults to None.
- **kwargs: Additional arguments to be passed to plot_channels().
+ events (pd.DataFrame): The sleep events dataframe.
+ channels (pd.DataFrame): The sleep channels dataframe.
+ participant_id (int): The participant id to filter the data.
+ array_index (int): The array index to filter the data.
+ time_range (Tuple[str, str]): The time range to filter the data.
+ event_filter (Iterable[str]): The events to include in the plot.
+ channel_filter (Iterable[str]): The channels to include in the plot.
+ event_height (float): The relative height of the events subplot.
+ channel_height (float): The relative height of each channel's subplot.
+ padding (float): The padding between subplots.
+ figsize (Tuple[float, float]): The size of the figure.
+ palette (str): The color palette to use.
Returns:
-
- None
+ TimeSeriesFigure: The figure with the sleep events and channels data.
"""
- nC = min([len(channel_filter), channels.index.get_level_values('source').nunique()])
- if xlim is not None:
- trim_to_events = True
-
- fig, ax = plt.subplots(nrows=nC+1, ncols=1, sharex=True, squeeze=False,
- height_ratios=nC*[channel_height] + [event_height],
- figsize=(width, width*aspect*(nC*channel_height + event_height)),
- facecolor='white')
-
- with sns.axes_style(style):
- try:
- plot_channels(channels, array_index=array_index, y_filter=channel_filter, ax=ax[:-1],
- **kwargs)
- plot_events(events, array_index=array_index, y_include=event_filter, y_exclude=['Gross Trunc'],
- ax=ax[-1,0],
- set_xlim=trim_to_events, add_events=add_events,
- xlim=xlim)
- except Exception as err:
- print(f'plot_channels failed due to:\n{err}')
- fig.clf()
- plot_events(events, array_index, y_include=event_filter, y_exclude=['Gross Trunc'],
- set_xlim=trim_to_events, add_events=add_events,
- xlim=xlim)
-
-
-def plot_events(events: pd.DataFrame, array_index: Optional[int]=None,
- x_start: str='collection_timestamp', x_end: str='event_end',
- y: str='event', color: str='channel', cmap: str='muted', set_xlim: bool=True,
- xlim: Iterable[float]=None, figsize: Iterable[float]=[10, 4],
- y_include: Optional[Iterable[str]] = None,
- y_exclude: Optional[Iterable[str]] = None,
- ax: plt.Axes=None,
- add_events: Optional[pd.DataFrame] = None,
- rename_channels: dict={'PAT Amplitude': 'PAT', 'PulseRate': 'Heart Rate'},
- rename_events: dict={}):
- """ plot an events timeline for a given participant and array_index """
- # slice to night
- if (array_index is not None) and (('array_index' in events.columns) or ('array_index' in events.index.names)):
- plot_df = events.query('array_index == @array_index').copy()
- else:
- plot_df = events.copy()
- # extract start and end times
- if x_start in plot_df.index.names:
- plot_df = plot_df.reset_index(x_start)
- if x_end in plot_df.index.names:
- plot_df = plot_df.reset_index(x_end)
- # remove timezone for correct matplotlib labeling
- plot_df[x_start] = plot_df[x_start].dt.tz_localize(None)
- plot_df[x_end] = plot_df[x_end].dt.tz_localize(None)
-
- # filter events
- if y_include is not None:
- plot_df = plot_df.query(f'{y} in {y_include}')
- if y_exclude is not None:
- plot_df = plot_df.query(f'{y} not in {y_exclude}')
- # additional user-provided events (application logging, etc.)
- if add_events is not None:
- tlim = plot_df[x_start].min(), plot_df[x_end].max()
- add_events = add_events.loc[
- (tlim[0] < add_events[x_end]) & (add_events[x_start] < tlim[1])]
- if len(add_events):
- add_events = add_events.set_index(plot_df.index[[0]])
- plot_df = pd.concat([plot_df, add_events[
- plot_df.columns.intersection(add_events.columns)]], axis=0)
-
- # rename channels and events
- plot_df = plot_df.copy()
- plot_df['channel'] = plot_df['channel'].replace(rename_channels)
- plot_df['event'] = plot_df['event'].replace(rename_events)
-
- # set x limits
- if xlim is not None:
- if type(xlim[0]) is str:
- xlim = (pd.to_datetime(xlim[0]), xlim[1])
- if type(xlim[0]) is not pd.Timestamp:
- xlim = plot_df.loc[plot_df['start'] < xlim[0], x_start].iloc[-1], xlim[1]
- if type(xlim[1]) is str:
- xlim = (xlim[0], pd.to_datetime(xlim[1]))
- if type(xlim[1]) is not pd.Timestamp:
- xlim = xlim[0], plot_df.loc[plot_df['end'] > xlim[1], x_end].iloc[0]
- else:
- xlim = plot_df[x_start].min(), plot_df[x_end].max()
-
- if ax is None:
- fig, ax = plt.subplots(figsize=figsize)
-
- # set colors
- colors = sorted(plot_df[color].unique())
- colors = pd.DataFrame({color: colors, 'color': sns.color_palette(cmap, len(colors))})\
- .set_index(color)['color']
-
- # plot events
- plot_df = plot_df.assign(diff=lambda x: x[x_end] - x[x_start]).sort_values([color, y])
- labels = []
- legend = []
- for i, (y_label, y) in enumerate(plot_df.groupby(y, observed=True, sort=False)):
- if len(y) == 0:
- continue
- labels.append(y_label)
- for c, r in y.groupby(color, observed=True):
- data = r[[x_start, 'diff']]
- if not len(data):
- continue
- h = ax.broken_barh(data.values, (i-0.4,0.8), color=colors[c], alpha=0.7)
- legend.append({'label': c, 'handle': h})
-
- # format plot
- legend = pd.DataFrame.from_dict(legend).drop_duplicates(subset='label')
- ax.legend(legend['handle'], legend['label'],
- bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
-
- str_title = ''
- if 'participant_id' in events.index.names:
- str_title += events.index.get_level_values('participant_id')[0].astype(str)
- if array_index is not None:
- str_title += f' / {array_index}'
- plt.suptitle(str_title, fontsize=14, weight='bold')
- ax.set_yticks(np.arange(len(labels)), labels)
- plt.tight_layout()
- ax.set_xlabel('Time')
- if set_xlim:
- ax.set_xlim(*xlim)
- format_xticks(ax)
-
- return ax
-
-
-def plot_channels(channels: pd.DataFrame, array_index: Optional[int]=None,
- y_filter: Optional[Iterable[str]]=None, ax: plt.Axes=None,
- discrete_events: Optional[Iterable[str]]=['sleep_stage', 'body_position'],
- time_col='collection_timestamp', height=1.5, resample='1s', cmap='muted',
- rename_channels=CHANNELS, **kwargs):
+ # Create figure
+ if figsize is None:
+ if 'source' in channels.index.names:
+ nC = channels.index.get_level_values('source').nunique()
+ else:
+ nC = channels['source'].nunique()
+ figsize = 2 * (nC * channel_height + event_height)
+ figsize = (8 * 2 * channel_height, figsize)
+ g = TimeSeriesFigure(figsize=figsize)
+
+ # Set colors
+ channels, color_map = get_channels_colors(
+ channels, events,
+ participant_id=participant_id, array_index=array_index, time_range=time_range,
+ event_filter=event_filter, palette=palette,
+ )
+
+ # Plot
+ plot_sleep_channels(
+ channels,
+ x='collection_timestamp', y='values', row='source', hue='channel_group',
+ participant_id=participant_id, array_index=array_index, time_range=time_range,
+ y_include=channel_filter,
+ fig=g, height=channel_height,
+ color_map=color_map, palette=palette,
+ )
+ g.plot(plot_events_bars,
+ events,
+ x_start='collection_timestamp', x_end='event_end', y='event', hue='channel',
+ participant_id=participant_id, array_index=array_index, time_range=time_range,
+ y_include=event_filter, y_exclude=['Gross Trunc'],
+ palette=palette,
+ name='sleep_events', height=event_height, sharex='sleep_channels',
+ )
+ g.set_axis_padding(padding)
+
+ return g
+
+
+def plot_sleep_channels(
+ channels: pd.DataFrame,
+ x: str='collection_timestamp',
+ y: str='values',
+ row: str='source',
+ hue: str='channel_group',
+ participant_id: int=None,
+ array_index: int=None,
+ time_range: Tuple[str, str]=None,
+ y_include: Iterable[str]=None,
+ y_exclude: Iterable[str]=None,
+ rename_channels: dict=CHANNELS,
+ discrete_events: Iterable[str]=['sleep_stage', 'body_position'],
+ resample: str='1s',
+ color_map: pd.Series=None,
+ palette: str=DEFAULT_PALETTE,
+ fig: TimeSeriesFigure=None,
+ ax: List[plt.Axes]=None,
+ height=1,
+ **kwargs
+) -> List[plt.Axes]:
""" plot channels data for a given participant and array_index """
- # set colors
- colors = get_legend_colors(cmap).explode('source')
- colors['source'] = pd.Categorical(colors['source'])
- colors = colors.set_index('source')
-
- # filter data
- if (array_index is not None) and (('array_index' in channels.columns) or ('array_index' in channels.index.names)):
- data = channels.query('array_index == @array_index').copy()
- else:
- data = channels.copy()
- # extract time and channel name
- if time_col in channels.index.names:
- data = data.reset_index(time_col)
- if 'source' not in data.index.names:
- data = data.set_index('source')
- data[time_col] = data[time_col].dt.tz_localize(None)
-
- # grouping and coloring sources by event "channels"
- data = data.join(colors[['channel']])\
- .sort_values(['channel', time_col], ascending=[False, True])
-
- if ax is None:
- n = data.index.unique().size
- fig, ax = plt.subplots(nrows=n, figsize=(10, n*height), sharex=True, squeeze=False)
-
- # plot data
- ax_shift = 0
- for i, (source, d) in enumerate(data.groupby('source', observed=True, sort=False)):
- if (source not in CHANNELS) or (y_filter is not None and source not in y_filter):
- print(f'plot_channels: skipping {source}')
- ax_shift += 1
- continue
- iax = i - ax_shift
- if resample is not None:
- d = d.resample(resample, on=time_col).mean(numeric_only=True).reset_index()
- if source in colors.index:
- c = colors.loc[source, 'color']
+ # Filter data and prepare channels
+ channels, colors = prep_to_plot_timeseries(
+ channels, x, x, row, row,
+ participant_id, array_index, time_range,
+ y_include, y_exclude,
+ add_columns=[y, hue], palette=palette
+ )
+ if color_map is not None:
+ colors = color_map
+
+ # Create axes if necessary
+ n = channels[row].nunique()
+ if ax is None and fig is None:
+ fig, ax = plt.subplots(nrows=n, figsize=(12, n*height), sharex=True, squeeze=True)
+ elif ax is None:
+ ax = fig.add_axes(n_axes=n, height=height, name='sleep_channels')
+
+ # Plot data
+ for i, (source, d) in enumerate(channels.groupby(row, observed=True, sort=False)):
+ if colors is not None and hue is not None:
+ c = colors.get(d[hue].iloc[0], 'grey')
else:
- c = 'grey'
+ c = '#4c72b0'
+ if resample is not None:
+ d = d.resample(resample, on=x).mean(numeric_only=True).reset_index()
+
+ # Set channel value limits
if source in CHANNEL_LIMS:
- d = d.loc[(CHANNEL_LIMS[source][0] <= d['values']) & (d['values'] <= CHANNEL_LIMS[source][1])]
- ax[iax, 0].scatter(d[time_col].dt.tz_localize(None).values, d['values'].values, s=0.1, color=c)
+ d = d.loc[
+ (CHANNEL_LIMS[source][0] <= d[y]) &
+ (d[y] <= CHANNEL_LIMS[source][1])
+ ]
+ ax[i].scatter(d[x].values, d[y].values, s=0.1, color=c)
if source not in CHANNEL_LIMS:
- ylim = d['values'].quantile([0.001, 0.999]).tolist()
+ ylim = d[y].quantile([0.001, 0.999]).tolist()
ylim[0] = 0.95*ylim[0] if ylim[0] >= 0 else 1.1*ylim[0]
ylim[1] = 1.1*ylim[1] if ylim[1] >= 0 else 0.95*ylim[1]
- ax[iax,0].set_ylim(*ylim)
+ ax[i].set_ylim(*ylim)
+
if source in rename_channels:
- ax[iax, 0].set_ylabel(rename_channels[source], rotation=0, horizontalalignment='right')
+ ax[i].set_ylabel(rename_channels[source], rotation=0, horizontalalignment='right')
else:
- ax[iax,0].set_ylabel(source, rotation=0, horizontalalignment='right')
+ ax[i].set_ylabel(source, rotation=0, horizontalalignment='right')
if source in discrete_events:
if source in ENUMS:
- ax[iax, 0].set_yticks(ENUMS[source][0],labels=ENUMS[source][1])
+ ax[i].set_yticks(ENUMS[source][0],labels=ENUMS[source][1])
else:
- ax[iax, 0].set_yticks(d['values'].drop_duplicates().sort_values().values)
- ylabels = ax[iax, 0].get_yticklabels()
+ ax[i].set_yticks(d[y].drop_duplicates().sort_values().values)
+ ylabels = ax[i, 0].get_yticklabels()
for label in ylabels[1:-1]:
label.set_text('')
- ax[iax, 0].set_yticklabels(ylabels)
-
- # format plot
- for i in range(len(ax)):
- ax[i,0].set_xlabel('')
- ax[i,0].set_xticklabels([])
- if ax is None:
- print('entered')
- ax[-1,0].set_xlabel('Time')
- ax[-1,0].set_xlim(data[time_col].min(), data[time_col].max())
- format_xticks(ax[-1,0])
+ ax[i].set_yticklabels(ylabels)
+ format_xticks(ax[-1])
return ax
-def format_xticks(ax, format='%m/%d %H:%M'):
- """ format datestrings on x axis """
- xticks = ax.get_xticks()
- ax.set_xticks(xticks)
- ax.set_xticklabels(xticks, rotation=25, ha='right')
- xfmt = mdates.DateFormatter(format)
- ax.xaxis.set_major_formatter(xfmt)
-
-
-def get_legend_colors(cmap='muted'):
- # the following dict keys should correspond to sleep event channels
- colors = pd.Series(COLOR_GROUPS, name='source').to_frame().reset_index()\
- .rename(columns={'index': 'channel'}).sort_values('channel')
- colors['color'] = sns.color_palette(cmap, len(colors))
+def get_channels_colors(
+ channels: pd.DataFrame,
+ events: pd.DataFrame,
+ participant_id: int=None,
+ array_index: int=None,
+ time_range: Tuple[str, str]=None,
+ event_filter: Iterable[str]=None,
+ palette: str=DEFAULT_PALETTE,
+) -> Tuple[pd.DataFrame, pd.Series]:
+ # Group channels like events do
+ channel_groups = pd.Series(COLOR_GROUPS, name='source')\
+ .reset_index()\
+ .rename(columns={'index': 'channel_group'})\
+ .explode('source').set_index('source')
+ channels = channels.join(channel_groups)
+ # Simulate events colors
+ _, color_map = prep_to_plot_timeseries(
+ events,
+ x_start='collection_timestamp', x_end='event_end', label='event', hue='channel',
+ participant_id=participant_id, array_index=array_index, time_range=time_range,
+ y_include=event_filter, y_exclude=['Gross Trunc'],
+ palette=palette)
+
+ return channels, color_map
+
+# %% ../nbs/06_sleep_plots.ipynb 6
+def get_sleep_period(events: pd.DataFrame) -> pd.DataFrame:
+ """
+ Get the sleep period from the sleep events dataframe.
- return colors
+ Args:
+ events (pd.DataFrame): The sleep events dataframe.
+ Returns:
+ pd.DataFrame: The sleep period dataframe.
+ """
+ return events.groupby(['participant_id', 'research_stage', 'array_index'])\
+ .apply(get_events_period, 'Wake', 'Wake', 'Sleep',
+ first_start=True, first_end=False, include_start=False, include_end=False)
+
diff --git a/pheno_utils/timeseries_plots.py b/pheno_utils/timeseries_plots.py
new file mode 100644
index 0000000..22e1eb0
--- /dev/null
+++ b/pheno_utils/timeseries_plots.py
@@ -0,0 +1,763 @@
+# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/15_timeseries_plots.ipynb.
+
+# %% auto 0
+__all__ = ['TimeSeriesFigure', 'format_xticks', 'format_timeseries', 'plot_events_bars', 'plot_events_fill',
+ 'prep_to_plot_timeseries', 'get_events_period', 'get_color_map']
+
+# %% ../nbs/15_timeseries_plots.ipynb 3
+from typing import Callable, Iterable, Optional, Union, Tuple
+import warnings
+
+import numpy as np
+import pandas as pd
+
+import seaborn as sns
+import matplotlib.pyplot as plt
+import matplotlib.dates as mdates
+
+from .config import DEFAULT_PALETTE, TIME_FORMAT, LEGEND_SHIFT
+
+# %% ../nbs/15_timeseries_plots.ipynb 4
+class TimeSeriesFigure:
+ def __init__(self, figsize: tuple = (10, 6), padding: float = 0.05):
+ """
+ Initialize a TimeSeriesFigure instance. This class is used to create and manage
+ a figure with multiple axes for time series data.
+
+ Args:
+ figsize (tuple): Size of the figure (width, height) in inches.
+ """
+ self.fig = plt.figure(figsize=figsize)
+ self.axes: Iterable[tuple] = []
+ self.axis_names: dict = {}
+ self.padding = padding
+ self.custom_paddings = {} # To store custom padding for specific axes
+ self.shared_x_groups = [] # To keep track of shared x-axis groups
+
+ def plot(
+ self,
+ plot_function: Callable,
+ *args,
+ n_axes: int = 1,
+ height: float = 1,
+ sharex: Union[str, int, plt.Axes] = None,
+ second_y: bool = False,
+ name: str = None,
+ ax: Union[str, int, plt.Axes] = None,
+ adjust_time: Optional[str] = 'union',
+ adjust_by_axis: Union[str, int, plt.Axes, Iterable[Union[str, int, plt.Axes]]] = None,
+ **kwargs
+ ) -> Union[plt.Axes, Iterable[plt.Axes]]:
+ """
+ Plot using a dataset-specific function, creating a new axis if needed.
+ The plot function should accept the axis object as the argument `ax`, or
+ a list of axes if multiple axes are used.
+
+ Args:
+ plot_function (Callable): The dataset-specific function to plot the data.
+ *args: Arguments to pass to the plot function.
+ n_axes (int): The number of axes required. Default is 1.
+ height (float): The proportional height of the axes relative to a single unit axis.
+ sharex (str, int, or plt.Axes): Index or name of the axis to share the x-axis with. If None, the x-axis is independent.
+ second_y (bool): If True, plot will be done on a secondary y-axis in the plot. Default is False.
+ name (str): Name or ID to assign to the axis.
+ ax (plt.Axes, str, int): Pre-existing axis (object, name, or index) or list of axes to plot on.
+ adjust_time (str, None): Method to adjust the time limits of all axes to match the data.
+ adjust_by_axis (str, int, plt.Axes): Axes (single or multiple) to use as a reference for adjusting the time limits.
+ **kwargs: Keyword arguments to pass to the plot function.
+
+ Returns:
+ Union[plt.Axes, Iterable[plt.Axes]]: A single axis object or a list of axis objects if multiple axes are used.
+ """
+ if ax is None:
+ ax = self.add_axes(height=height, n_axes=n_axes, sharex=sharex, name=name)
+ else:
+ ax = self.get_axes(ax, squeeze=True)
+
+ if second_y:
+ ax.yaxis.grid(False)
+ ax = ax.twinx()
+
+ plot_function(*args, ax=ax, **kwargs)
+ if adjust_time:
+ self.set_time_limits(None, None, method=adjust_time, reference_axis=adjust_by_axis)
+ if second_y:
+ ax.yaxis.grid(False)
+ ax.yaxis.label.set_rotation(90)
+ ax.yaxis.label.set_ha('center')
+
+ return ax
+
+ def add_axes(
+ self,
+ height: float = 1,
+ n_axes: int = 1,
+ sharex: Optional[Union[str, int, plt.Axes]] = None,
+ name: Optional[str] = None,
+ ) -> Union[plt.Axes, Iterable[plt.Axes]]:
+ """
+ Add one or more axes with a specific proportional height to the figure.
+
+ Args:
+ height (float): The proportional height of each new axis relative to a single unit axis.
+ n_axes (int): The number of axes to create.
+ sharex (str, int, or plt.Axes): Index or name of the axis to share the x-axis with. If None, the x-axis is independent.
+ name (Optional[str]): Name or ID to assign to the new axis, or to the whole group of new axes when n_axes > 1.
+
+ Returns:
+ Union[plt.Axes, Iterable[plt.Axes]]: A single axis object or a list of axis objects if multiple axes are created.
+ """
+ new_axes = []
+ shared_group = []
+
+ if sharex is not None:
+ sharex = self.get_axes(sharex)[0]
+ shared_group.append(sharex)
+
+ for _ in range(n_axes):
+ ax = self.fig.add_subplot(len(self.axes) + 1, 1, len(self.axes) + 1, sharex=sharex)
+ new_axes.append(ax)
+ self.axes.append((ax, height))
+ shared_group.append(ax)
+ # When creating multiple axes, always share their x-axis with the first one
+ if sharex is None:
+ sharex = ax
+
+ if shared_group:
+ self.shared_x_groups.append(shared_group)
+
+ if name is not None:
+ self.axis_names[name] = new_axes
+
+ self._adjust_axes()
+
+ return new_axes if n_axes > 1 else new_axes[0]
+
+ def _adjust_axes(self) -> None:
+ """
+ Adjust the positions and sizes of all axes based on their proportional height and apply padding.
+ """
+ total_height = sum(height for _, height in self.axes)
+ total_padding = self.padding * (len(self.axes) - 1)
+ bottom = 1 - total_padding # Start from the top of the figure
+
+ for i, (ax, height) in enumerate(self.axes):
+ ax_height = height / total_height * (1 - total_padding)
+ # Adjust for any custom padding before this axis
+ custom_pad = self.custom_paddings.get(i, 0)
+ ax.set_position([0.1, bottom - ax_height, 0.8, ax_height])
+ bottom -= ax_height + self.padding + custom_pad # Move down, considering padding
+
+ def _get_axis_by_name(self, name: str) -> Optional[plt.Axes]:
+ """
+ Retrieve an axis by its name or ID.
+
+ Args:
+ name (str): The name or ID of the axis to retrieve.
+
+ Returns:
+ list: The axes registered under that name, or an empty list if the name is unknown.
+ """
+ return self.axis_names.get(name, [])
+
+ def get_axes(self, ax: Union[str, int, plt.Axes, Iterable[Union[str, int, plt.Axes]]]=None, squeeze=False) -> Iterable[plt.Axes]:
+ """
+ Retrieve the axis object(s) based on the input type.
+
+ Args:
+ ax: The axis object, index, name, or list of those to retrieve.
+ squeeze (bool): Whether to return a single axis object if only one is found.
+
+ Returns:
+ Iterable[plt.Axes]: A list of axis objects.
+ """
+ if ax is None:
+ return [a for a, _ in self.axes]
+ elif not isinstance(ax, list):
+ ax = [ax]
+
+ ax_list = []
+ for a in ax:
+ if isinstance(a, str):
+ by_name = self._get_axis_by_name(a)
+ if len(by_name) == 0:
+ warnings.warn(f"No axis found with name '{a}'")
+ ax_list.extend(by_name)
+ elif isinstance(a, int):
+ ax_list.append(self.axes[a][0])
+
+ if squeeze and len(ax_list) == 1:
+ return ax_list[0]
+ else:
+ return ax_list
+
+ def print_shared_axes(self):
+ """
+ Print which axes in the figure share their x-axis.
+
+ Returns:
+ None
+ """
+ shared_groups = {}
+ for i, (ax, _) in enumerate(self.axes):
+ for j, (other_ax, _) in enumerate(self.axes):
+ if i != j and ax.get_shared_x_axes().joined(ax, other_ax):
+ if i not in shared_groups:
+ shared_groups[i] = []
+ shared_groups[i].append(j)
+
+ for ax_idx, shared_with in shared_groups.items():
+ print(f"Axis {ax_idx} shares its x-axis with: {shared_with}")
+
+ def get_axis_properties(self, ax: Union[str, int, plt.Axes, Iterable[Union[str, int, plt.Axes]]]=None) -> dict:
+ """
+ Get the properties of a specific axis or axes.
+
+ Args:
+ ax (str, int, plt.Axes, or a list of those): The axis or axes to get the properties for.
+
+ Returns:
+ dict: A dictionary of properties for the axis or axes.
+ """
+ ax_list = self.get_axes(ax)
+ properties = {}
+ for a in ax_list:
+ properties = {key: properties.get(key, []) + [value] for key, value in a.properties().items()}
+
+ for k, v in properties.items():
+ if len(v) == 1:
+ properties[k] = v[0]
+
+ return properties
+
+ def set_axis_properties(self, ax: Union[str, int, plt.Axes, Iterable[Union[str, int, plt.Axes]]]=None, **kwargs) -> None:
+ """
+ Set properties for a specific axis or axes.
+
+ Args:
+ ax (str, int, plt.Axes, or a list of those): The axis or axes to set the properties for.
+ **kwargs: Additional keyword arguments to pass to the axis object.
+ """
+ ax_list = self.get_axes(ax)
+ for a in ax_list:
+ a.set(**kwargs)
+
+ def set_axis_padding(self, padding: float, ax: Union[str, int, plt.Axes, Iterable[Union[str, int, plt.Axes]]]=None, above: bool = True) -> None:
+ """
+ Set custom padding for a specific axis.
+
+ Args:
+ padding (float): The amount of padding to add as a fraction of the figure height.
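+ ax (str, int, plt.Axes, or a list of those): The axis or axes to pad. If None, applies to all axes.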
+ above (bool): Whether to add padding above the axis (default) or below.
+ """
+ ax_list = self.get_axes(ax)
+ all_axes = [a for a, _ in self.axes]
+
+ for ax in ax_list:
+ axis_index = all_axes.index(ax)
+ if axis_index < 0:
+ warnings.warn("Axis not found in the figure.")
+ continue
+ if above:
+ self.custom_paddings[axis_index] = padding
+ elif axis_index == len(self.axes) - 1:
+ continue
+ else:
+ self.custom_paddings[axis_index + 1] = padding
+ self._adjust_axes()
+
+ def set_time_limits(
+ self, start_time: Union[float, str, pd.Timestamp, None],
+ end_time: Union[float, str, pd.Timestamp, None],
+ method: str='union',
+ reference_axis: Union[str, int, plt.Axes, Iterable[Union[str, int, plt.Axes]]] = None
+ ) -> None:
+ """
+ Set the time limits for all axes in the figure. Calling with None will adjust the limits to the data.
+
+ Args:
+ start_time (Union[float, str, pd.Timestamp, None]): The start time for the x-axis.
+ end_time (Union[float, str, pd.Timestamp, None]): The end time for the x-axis.
+ """
+ # Default values
+ xlim = np.array(self.get_axis_properties(reference_axis)['xlim']).reshape((-1, 2))
+ if method == 'union':
+ xlim = xlim[:, 0].min(), xlim[:, 1].max()
+ elif method == 'intersect':
+ xlim = xlim[:, 0].max(), xlim[:, 1].min()
+ else:
+ raise ValueError(f"Invalid method: {method} not in ['union', 'intersect']")
+
+ # Convert string inputs to pandas Timestamp objects
+ if start_time is not None:
+ start_time = pd.to_datetime(start_time)
+ else:
+ start_time = xlim[0]
+ if end_time is not None:
+ end_time = pd.to_datetime(end_time)
+ else:
+ end_time = xlim[1]
+
+ self.set_axis_properties(xlim=(start_time, end_time))
+
+    def set_periodic_ticks(
+        self,
+        interval: Union[str, pd.Timedelta],
+        start_time: str = '2018-01-01 00:00',
+        end_time: str = None,
+        fmt=TIME_FORMAT,
+        ax: Union[str, int, plt.Axes, Iterable[Union[str, int, plt.Axes]]] = None
+    ) -> None:
+        """
+        Place x-ticks on a regular time grid and format them as dates.
+
+        Args:
+            interval (Union[str, pd.Timedelta]): Spacing between ticks (e.g., '1H' for hourly
+                ticks, '30T' for 30 minutes). Strings are parsed with `pd.to_timedelta`.
+            start_time (str): Timestamp the tick grid is anchored to
+                (default is '2018-01-01 00:00'). Timezone information is dropped.
+                If None, the grid starts at the lower x-limit of each axis.
+            end_time (str): Timestamp at which the grid stops (default is None, meaning
+                the upper x-limit of each axis). Timezone information is dropped.
+            fmt (str): The date format string to be used for the tick labels.
+            ax (str, int, plt.Axes, or a list of those): The axis (or axes) to apply the ticks to,
+                resolved through `self.get_axes`.
+        """
+        step = pd.to_timedelta(interval) if isinstance(interval, str) else interval
+
+        # Anchor times are compared against naive datetimes below, so drop any timezone
+        if start_time is not None:
+            start_time = pd.to_datetime(start_time).tz_localize(None)
+        if end_time is not None:
+            end_time = pd.to_datetime(end_time).tz_localize(None)
+
+        def _as_naive_datetime(limit):
+            """Turn a numeric axis limit into a timezone-free datetime."""
+            if isinstance(limit, (float, int)):
+                return mdates.num2date(limit).replace(tzinfo=None)
+            return limit
+
+        for a in self.get_axes(ax):
+            if a is None:
+                continue
+
+            lower, upper = (_as_naive_datetime(limit) for limit in a.get_xlim())
+
+            # Generate the grid from the anchor, then keep only what is visible
+            grid = pd.date_range(start=start_time if start_time else lower,
+                                 end=end_time if end_time else upper,
+                                 freq=step)
+            visible = [tick for tick in grid if lower <= tick <= upper]
+
+            format_xticks(a, visible, fmt)
+
+            plt.setp(a.get_xticklabels(), rotation=0, ha='center')
+
+    def add_legend(self, ax: Union[str, int, plt.Axes, Iterable[Union[str, int, plt.Axes]]], **kwargs) -> None:
+        """
+        Create a legend on one or more axes.
+
+        Args:
+            ax (str, int, plt.Axes, or a list of those): The axis (or axes) to add the legend to,
+                resolved through `self.get_axes`.
+            **kwargs: Keyword arguments forwarded to each axis' `legend` call.
+        """
+        for target in self.get_axes(ax):
+            target.legend(**kwargs)
+
+    def set_legend(self, ax: Union[str, int, plt.Axes, Iterable[Union[str, int, plt.Axes]]]=None, bbox_to_anchor: tuple=None, **kwargs):
+        """
+        Update properties of legends that already exist on the selected axes.
+
+        Axes without a legend are left untouched.
+
+        Args:
+            ax (str, int, plt.Axes, or a list of those): The axis (or axes) whose legend is updated,
+                resolved through `self.get_axes`.
+            bbox_to_anchor (tuple, optional): The bounding box coordinates for the legend.
+            **kwargs: Additional keyword arguments passed to the legend object's `set`.
+        """
+        for target in self.get_axes(ax):
+            existing = target.get_legend()
+            if existing is not None:
+                if bbox_to_anchor is not None:
+                    existing.set_bbox_to_anchor(bbox_to_anchor)
+                existing.set(**kwargs)
+
+    def show(self) -> None:
+        """
+        Render the figure on screen by delegating to `plt.show()`.
+        """
+        plt.show()
+
+
+def format_xticks(ax: plt.Axes, xticks: Iterable=None, format: str=TIME_FORMAT, **kwargs):
+    """
+    Fix the x-tick positions of an axis and render them as formatted dates.
+
+    Args:
+        ax (plt.Axes): The axis to format.
+        xticks (Iterable): Tick positions to use. If None, the current ticks are kept.
+        format (str): Date format string handed to `mdates.DateFormatter`.
+        **kwargs: Keyword arguments forwarded to `ax.set_xticklabels`.
+    """
+    tick_positions = ax.get_xticks() if xticks is None else xticks
+    ax.set_xticks(tick_positions)
+    ax.set_xticklabels(tick_positions, **kwargs)
+    # The date formatter is installed last, so it decides the label text
+    ax.xaxis.set_major_formatter(mdates.DateFormatter(format))
+
+
+def format_timeseries(
+    df: pd.DataFrame,
+    participant_id: int=None,
+    array_index: int=None,
+    time_range: Tuple[str, str]=None,
+    x_start: str='collection_timestamp',
+    x_end: str='collection_timestamp',
+    unique: bool=False,
+) -> pd.DataFrame:
+    """
+    Reformat and filter a time series DataFrame based on participant ID, array index, and date range.
+
+    The input DataFrame is never modified; a new DataFrame is returned.
+
+    Args:
+        df (pd.DataFrame): The DataFrame to filter.
+        participant_id (int): The participant ID to filter by.
+        array_index (int): The array index to filter by.
+        time_range: The date range to filter by. Can be a tuple of two dates / times or two strings.
+            Rows are kept when they start at or after the first value and end at or before the second.
+        x_start (str): The name of the column (or index level) containing the start time.
+        x_end (str): The name of the column (or index level) containing the end time.
+        unique (bool): If True, drop duplicated rows after filtering.
+
+    Returns:
+        pd.DataFrame: The filtered DataFrame, with timezone-free time columns, sorted by `x_start`.
+    """
+    if participant_id is not None:
+        df = df.query('participant_id == @participant_id')
+    if array_index is not None:
+        df = df.query('array_index == @array_index')
+
+    # x_start and x_end may be the same column; handle each name once
+    time_cols = list(dict.fromkeys([x_start, x_end]))
+
+    # Move time columns out of the index so they can be sliced as regular columns.
+    # A plain list is passed because reset_index cannot resolve an array of several levels.
+    index_levels = [c for c in time_cols if c in df.index.names]
+    if index_levels:
+        df = df.reset_index(index_levels)
+
+    # assign() builds a new frame, so the caller's data keeps its timezone
+    df = df.assign(**{c: df[c].dt.tz_localize(None) for c in time_cols})
+
+    if time_range is not None:
+        time_range = pd.to_datetime(time_range)
+        df = df.loc[(time_range[0] <= df[x_start]) & (df[x_end] <= time_range[1])]
+    if unique:
+        df = df.drop_duplicates()
+
+    return df.sort_values(x_start)
+
+# %% ../nbs/15_timeseries_plots.ipynb 5
+def plot_events_bars(
+    events: pd.DataFrame,
+    x_start: str = 'collection_timestamp',
+    x_end: str = 'event_end',
+    y: str = 'event',
+    hue: str = 'channel',
+    participant_id: Optional[int] = None,
+    array_index: Optional[int] = None,
+    time_range: Optional[Tuple[str, str]] = None,
+    y_include: Optional[Iterable[str]] = None,
+    y_exclude: Optional[Iterable[str]] = None,
+    legend: bool = True,
+    palette: str = DEFAULT_PALETTE,
+    alpha: Optional[float] = 0.7,
+    ax: Optional[plt.Axes] = None,
+    figsize: Tuple[float, float] = (12, 6),
+) -> plt.Axes:
+    """
+    Plot events as horizontal bars, one row per distinct value of `y`.
+
+    Args:
+        events (pd.DataFrame): The events dataframe.
+        x_start (str): The column name for the start time of the event.
+        x_end (str): The column name for the end time of the event.
+        y (str): The column name for the y-axis values.
+        hue (str): The column name for the color of the event.
+        participant_id (int): The participant ID to filter events by.
+        array_index (int): The array index to filter events by.
+        time_range (Tuple[str, str]): The time range to filter events by.
+        y_include (Iterable[str]): The list of values to include in the plot.
+        y_exclude (Iterable[str]): The list of values to exclude from the plot.
+        legend (bool): Whether to show the legend. No legend is drawn when nothing was plotted.
+        palette (str): The name of the colormap to use for coloring events.
+        alpha (float): The transparency of the bars. Default is 0.7.
+        ax (plt.Axes): The axis to plot on. If None, a new figure is created.
+        figsize (Tuple[float, float]): The size of the figure (width, height) in inches.
+
+    Returns:
+        plt.Axes: The axis the bars were drawn on.
+    """
+    events, color_map = prep_to_plot_timeseries(
+        events, x_start, x_end,
+        hue, y,
+        participant_id, array_index, time_range,
+        y_include, y_exclude,
+        palette=palette)
+    if hue is None:
+        # prep_to_plot_timeseries adds a constant 'hue' column in this case
+        hue = 'hue'
+
+    if ax is None:
+        _, ax = plt.subplots(figsize=figsize)
+
+    # Plot events
+    events = events.assign(diff=lambda x: x[x_end] - x[x_start]).sort_values([hue, y])
+    y_labels = []
+    legend_handles = {}  # first handle drawn for each hue value, in drawing order
+    for y_label, y_events in events.groupby(y, observed=True, sort=False):
+        if y_events.empty:
+            continue
+        # Row position follows the labels actually kept, so ticks and bars stay aligned
+        row = len(y_labels)
+        y_labels.append(y_label)
+        for level, level_events in y_events.groupby(hue, observed=True):
+            spans = level_events[[x_start, 'diff']]
+            if spans.empty:
+                continue
+            bars = ax.broken_barh(spans.values, (row - 0.4, 0.8), color=color_map[level], alpha=alpha)
+            legend_handles.setdefault(level, bars)
+
+    # format plot
+    if legend and legend_handles:
+        ax.legend(
+            list(legend_handles.values()),
+            list(legend_handles.keys()),
+            loc='upper left',
+            bbox_to_anchor=LEGEND_SHIFT)
+
+    ax.set_yticks(np.arange(len(y_labels)), y_labels)
+    format_xticks(ax)
+    ax.invert_yaxis()  # Invert y-axis to match the order of the legend
+
+    return ax
+
+
+def plot_events_fill(
+    events: pd.DataFrame,
+    x_start: str = 'collection_timestamp',
+    x_end: str = 'event_end',
+    hue: str = 'channel',
+    label: str = None,
+    participant_id: Optional[int] = None,
+    array_index: Optional[int] = None,
+    time_range: Optional[Tuple[str, str]] = None,
+    y_include: Optional[Iterable[str]] = None,
+    y_exclude: Optional[Iterable[str]] = None,
+    legend: bool = True,
+    palette: str = DEFAULT_PALETTE,
+    alpha: Optional[float] = 0.5,
+    ax: Optional[plt.Axes] = None,
+    figsize: Iterable[float] = [12, 6],
+) -> plt.Axes:
+    """
+    Plot events as filled regions spanning the full height of one or more axes.
+
+    Args:
+        events (pd.DataFrame): The events dataframe.
+        x_start (str): The column name for the start time of the event.
+        x_end (str): The column name for the end time of the event.
+        hue (str): The column name for the color of the event.
+        label (str): The column name for the label of the event. When given, the labels
+            are shown as ticks on a secondary x-axis at the top.
+        participant_id (int): The participant ID to filter events by.
+        array_index (int): The array index to filter events by.
+        time_range (Iterable[str]): The time range to filter events by.
+        y_include (Iterable[str]): The list of values to include in the plot.
+        y_exclude (Iterable[str]): The list of values to exclude from the plot.
+        legend (bool): Whether to show the legend.
+        palette (str): The name of the palette to use for coloring events.
+        alpha (float): The transparency of the filled regions.
+        ax (plt.Axes): The axis, or list of axes, to plot on. If None, a new figure is created.
+        figsize (Tuple[float, float]): The size of the figure (width, height) in inches.
+
+    Returns:
+        The list of axes that were drawn on (a single axis is returned wrapped in a list).
+    """
+    events, color_map = prep_to_plot_timeseries(
+        events, x_start, x_end,
+        hue, label,
+        participant_id, array_index, time_range,
+        y_include, y_exclude,
+        palette=palette)
+    if hue is None:
+        # prep_to_plot_timeseries adds a constant 'hue' column in this case
+        hue = 'hue'
+
+    if ax is None:
+        _, ax = plt.subplots(figsize=figsize)
+    if not isinstance(ax, list):
+        ax = [ax]
+
+    # Used only when no color map is available
+    fallback_color = '#4c72b0'
+
+    for a in ax:
+        # Plotting events
+        for _, row in events.iterrows():
+            span_color = fallback_color if color_map is None else color_map[row[hue]]
+            # Plot the event as a filled region, with zorder to ensure it's behind other elements
+            a.axvspan(
+                row[x_start], row[x_end], 0, 1,
+                color=span_color, alpha=alpha, zorder=0,
+                transform=a.get_xaxis_transform())
+
+        # Add labels as xticks on the top secondary x-axis
+        if label:
+            secax = a.secondary_xaxis('top')
+            secax.set_xticks(events[x_start])
+            secax.set_xticklabels(events[label], rotation=0, ha='center')
+
+        # Add legend
+        if legend:
+            # Start from the entries already registered on the axis
+            handles, labels = a.get_legend_handles_labels()
+            if color_map is not None:
+                handles += [plt.Rectangle((0, 0), 1, 1, color=c, alpha=alpha) for c in color_map]
+                labels += color_map.index.tolist()
+            else:
+                handles += [plt.Rectangle((0, 0), 1, 1, color=fallback_color, alpha=alpha)]
+                labels += ['events']
+            a.legend(handles, labels, loc='upper left', bbox_to_anchor=LEGEND_SHIFT)
+
+        format_xticks(a)
+
+    return ax
+
+
+def prep_to_plot_timeseries(
+    data: pd.DataFrame,
+    x_start: str,
+    x_end: str,
+    hue: str,
+    label: str,
+    participant_id: int,
+    array_index: int,
+    time_range: Tuple[str, str],
+    y_include: Iterable[str],
+    y_exclude: Iterable[str],
+    add_columns: Iterable[str]=None,
+    palette=DEFAULT_PALETTE,
+) -> Tuple[pd.DataFrame, Optional[pd.Series]]:
+    """
+    Prepare timeseries / events data for plotting.
+
+    Args:
+        data (pd.DataFrame): The timeseries / events dataframe.
+        x_start (str): The column name for the start time of the event.
+        x_end (str): The column name for the end time of the event.
+        hue (str): The column name for the color of the event. If None, a constant
+            column named 'hue' with the value 'events' is added and used instead.
+        label (str): The column name for the label of the event.
+        participant_id (int): The participant ID to filter events by.
+        array_index (int): The array index to filter events by.
+        time_range (Iterable[str]): The time range to filter events by.
+        y_include (Iterable[str]): Keep only rows whose hue or label value is in this list.
+        y_exclude (Iterable[str]): Drop rows whose hue or label value is in this list.
+        add_columns (Iterable[str]): Additional columns to keep in the output.
+            A single column name may be given as a string.
+        palette (str): The name of the colormap to use for coloring events.
+
+    Returns:
+        Tuple[pd.DataFrame, Optional[pd.Series]]: The filtered dataframe restricted to the
+        relevant columns, and the color of each hue value (None if the hue column is absent).
+    """
+    def _matches_any(frame, columns, values):
+        """Flag rows of `frame` whose value in any of `columns` is listed in `values`."""
+        mask = pd.Series(False, index=frame.index)
+        for column in columns:
+            if column is not None:
+                mask |= frame[column].isin(values)
+        return mask
+
+    if isinstance(add_columns, str):
+        add_columns = [add_columns]
+
+    data = format_timeseries(data, participant_id, array_index, time_range, x_start, x_end)
+
+    # Rows without both a start and an end time cannot be drawn
+    data = data.dropna(subset=[x_start, x_end])
+    for column in (hue, label):
+        if column is not None and column in data.index.names:
+            data = data.reset_index(column)
+
+    # Filter events based on y_include and y_exclude
+    if y_include is not None:
+        data = data.loc[_matches_any(data, (hue, label), y_include)]
+    if y_exclude is not None:
+        data = data.loc[~_matches_any(data, (hue, label), y_exclude)]
+
+    if hue is None:
+        hue = 'hue'
+        # assign() returns a new frame, avoiding an in-place write into a filtered slice
+        data = data.assign(**{hue: 'events'})
+
+    col_list = [x_start, x_end, hue, label]
+    if add_columns is not None:
+        col_list += list(add_columns)
+    # Drop unset names and repeated columns while keeping the order
+    col_list = list(dict.fromkeys(c for c in col_list if c is not None))
+
+    # Set colors
+    if hue in data.columns:
+        colors = get_color_map(data, hue, palette)
+    else:
+        colors = None
+
+    return data[col_list], colors
+
+
+def get_events_period(
+    events_filtered: pd.DataFrame,
+    period_start: str,
+    period_end: str,
+    period_name: str,
+    col: str = 'event',
+    first_start: bool = True,
+    first_end: bool = True,
+    include_start: bool = True,
+    include_end: bool = True,
+    x_start: str = 'collection_timestamp',
+    x_end: str = 'event_end',
+) -> pd.DataFrame:
+    """
+    Build a single-row events table spanning from a start event to an end event.
+
+    Events are first passed through `format_timeseries`, so "first" and "last"
+    refer to the order of `x_start`.
+
+    Args:
+        events_filtered (pd.DataFrame): The events DataFrame.
+        period_start (str): The label of the start event.
+        period_end (str): The label of the end event.
+        period_name (str): The label to assign to the period.
+        col (str): The column name for the event labels. Default is 'event'.
+        first_start (bool): If True, use the first start event, otherwise the last. Default is True.
+        first_end (bool): If True, use the first end event, otherwise the last. Default is True.
+        include_start (bool): If True, the period begins at the start event's start time,
+            otherwise at its end time. Default is True.
+        include_end (bool): If True, the period finishes at the end event's end time,
+            otherwise at its start time. Default is True.
+        x_start (str): The column name for the start time of the event. Default is 'collection_timestamp'.
+        x_end (str): The column name for the end time of the event. Default is 'event_end'.
+
+    Returns:
+        pd.DataFrame: One row with the columns `x_start`, `x_end` and `col`.
+    """
+    ordered = format_timeseries(events_filtered, None, None, None, x_start, x_end)
+
+    # Where the period begins
+    start_column = x_start if include_start else x_end
+    start_position = 0 if first_start else -1
+    start_time = ordered.loc[ordered[col] == period_start, start_column].iloc[start_position]
+
+    # Where the period finishes
+    end_column = x_end if include_end else x_start
+    end_position = 0 if first_end else -1
+    end_time = ordered.loc[ordered[col] == period_end, end_column].iloc[end_position]
+
+    period = {
+        x_start: [start_time],
+        x_end: [end_time],
+        col: [period_name]
+    }
+    return pd.DataFrame(period)
+
+
+def get_color_map(data: pd.DataFrame, hue: str, palette: str) -> pd.Series:
+    """
+    Get a color map for a specific column in the data.
+
+    Args:
+        data (pd.DataFrame): The data to get the color map from.
+        hue (str): The column name to use for the color map.
+        palette (str): The name of the colormap to use.
+
+    Returns:
+        pd.Series: One color per distinct value of `hue`, indexed by the sorted values.
+        The Series is named 'color' and its index is named after `hue`.
+    """
+    levels = sorted(data[hue].unique())
+    # Build the Series directly: going through a dict keyed by both `hue` and
+    # 'color' would collapse into one entry when the hue column is itself named 'color'.
+    return pd.Series(
+        list(sns.color_palette(palette, len(levels))),
+        index=pd.Index(levels, name=hue),
+        name='color',
+        dtype=object,
+    )