diff --git a/src/pydartdiags/plots/plots.py b/src/pydartdiags/plots/plots.py index 20406e8..a5ff2ad 100644 --- a/src/pydartdiags/plots/plots.py +++ b/src/pydartdiags/plots/plots.py @@ -62,54 +62,82 @@ def calculate_rank(df): return (rank, ens_size, result_df) -def plot_profile(df, levels): +def plot_profile(df, levels, verticalUnit = "pressure (Pa)"): """ - Plots RMSE and Bias profiles for different observation types across specified pressure levels. + Plots RMSE, bias, and total spread profiles for different observation types across specified vertical levels. This function takes a DataFrame containing observational data and model predictions, categorizes - the data into specified pressure levels, and calculates the RMSE and Bias for each level and - observation type. It then plots two line charts: one for RMSE and another for Bias, both as functions - of pressure level. The pressure levels are plotted on the y-axis in reversed order to represent - the vertical profile in the atmosphere correctly. + the data into specified vertical levels, and calculates the RMSE, bias and total spread for each level and + observation type. It then plots three line charts: one for RMSE, one for bias, one for total spread, as functions + of vertical level. The vertical levels are plotted on the y-axis in reversed order to represent + the vertical profile in the atmosphere correctly; the reversal is applied only when verticalUnit is 'pressure (Pa)'. Parameters: - df (pd.DataFrame): The input DataFrame containing at least the 'vertical' column for pressure levels, - and other columns required by the `rmse_bias` function for calculating RMSE and Bias. - levels (array-like): The bin edges for categorizing the 'vertical' column values into pressure levels. + df (pd.DataFrame): The input DataFrame containing at least the 'vertical' column for vertical levels, + the 'vert_unit' column, and the columns required by the `rmse_bias_totalspread` function ('type', 'sq_err', + 'bias', 'posterior_ensemble_spread' and 'obs_err_var'). 
+ levels (array-like): The bin edges for categorizing the 'vertical' column values into the desired + vertical levels. + verticalUnit (string) (optional): The vertical unit to be used. Only observations in df which have this + string in the vert_unit column will be plotted. Defaults to 'pressure (Pa)'. Returns: - tuple: A tuple containing the DataFrame with RMSE and Bias calculations, the RMSE plot figure, and the - Bias plot figure. The DataFrame includes a 'plevels' column representing the categorized pressure levels - and 'hPa' column representing the midpoint of each pressure level bin. + tuple: A 4-tuple (df_profile, fig_rmse, fig_ts, fig_bias). df_profile holds the RMSE, bias and total spread + for each 'midpoint' (the midpoint of each vertical level bin) and observation 'type'; the remaining items + are the RMSE, total spread and bias figures, in that order. Raises: ValueError: If there are missing values in the 'vertical' column of the input DataFrame. + ValueError: If no rows of the input DataFrame have verticalUnit in the 'vert_unit' column. Note: - The function modifies the input DataFrame by adding 'plevels' and 'hPa' columns. - The 'hPa' values are calculated as half the midpoint of each pressure level bin, which may need - adjustment based on the specific requirements for pressure level representation. + - The 'vlevels' and 'midpoint' columns are added to the subset of rows matching verticalUnit, not to the caller's DataFrame. + - The 'midpoint' values are the midpoint of each vertical level bin; for 'pressure (Pa)' they are divided by 100 + (Pa to hPa), so the plotted values are in hPa although the axis is labelled with verticalUnit. - The plots are generated using Plotly Express and are displayed inline. The y-axis of the plots is - reversed to align with standard atmospheric pressure level representation. + reversed to align with standard atmospheric pressure level representation if the vertical units + are atmospheric pressure. 
""" pd.options.mode.copy_on_write = True if df['vertical'].isnull().values.any(): # what about horizontal observations? raise ValueError("Missing values in 'vertical' column.") + elif verticalUnit not in df['vert_unit'].values: + raise ValueError("No obs with expected vertical unit '"+verticalUnit+"'.") else: - df.loc[:,'plevels'] = pd.cut(df['vertical'], levels) - df.loc[:,'hPa'] = df['plevels'].apply(lambda x: x.mid / 1000.) # HK todo units - - df_profile = rmse_bias(df) - fig_rmse = px.line(df_profile, y='hPa', x='rmse', title='RMSE by Level', markers=True, color='type', width=800, height=800) - fig_rmse.update_yaxes(autorange="reversed") + df = df[df["vert_unit"].isin({verticalUnit})] # Subset to only rows with the correct vertical unit + df.loc[:,'vlevels'] = pd.cut(df['vertical'], levels) + if verticalUnit == "pressure (Pa)": + df.loc[:,'midpoint'] = df['vlevels'].apply(lambda x: x.mid / 100.) # HK todo units + else: + df.loc[:,'midpoint'] = df['vlevels'].apply(lambda x: x.mid) + + # Calculations + df_profile = rmse_bias_totalspread(df) + + # RMSE plot + fig_rmse = px.line(df_profile, y='midpoint', x='rmse', title='RMSE by Level', markers=True, color='type', width=800, height=800, + labels={"midpoint": verticalUnit}) + if verticalUnit == "pressure (Pa)": + fig_rmse.update_yaxes(autorange="reversed") fig_rmse.show() - fig_bias = px.line(df_profile, y='hPa', x='bias', title='Bias by Level', markers=True, color='type', width=800, height=800) - fig_bias.update_yaxes(autorange="reversed") + # totalspread plot + fig_ts = px.line(df_profile, y='midpoint', x='totalspread', title='Totalspread by Level', markers=True, color='type', width=800, height=800, + labels={"midpoint": verticalUnit}) + if verticalUnit == "pressure (Pa)": + fig_ts.update_yaxes(autorange="reversed") + fig_ts.show() + + # bias plot + fig_bias = px.line(df_profile, y='midpoint', x='bias', title='Bias by Level', markers=True, color='type', width=800, height=800, + labels={"midpoint": verticalUnit}) + 
if verticalUnit == "pressure (Pa)": + fig_bias.update_yaxes(autorange="reversed") fig_bias.show() + - return df_profile, fig_rmse, fig_bias + return df_profile, fig_rmse, fig_ts, fig_bias def mean_then_sqrt(x): @@ -130,12 +158,17 @@ def mean_then_sqrt(x): return np.sqrt(np.mean(x)) -def rmse_bias(df): - rmse_bias_df = df.groupby(['hPa', 'type']).agg({'sq_err':mean_then_sqrt, 'bias':'mean'}).reset_index() - rmse_bias_df.rename(columns={'sq_err':'rmse'}, inplace=True) - - return rmse_bias_df +def rmse_bias_totalspread(df): + rmse_bias_ts_df = df.groupby(['midpoint', 'type'], observed=False) + rmse_bias_ts_df = rmse_bias_ts_df.agg({'sq_err':mean_then_sqrt, 'bias':'mean', 'posterior_ensemble_spread':mean_then_sqrt, 'obs_err_var':mean_then_sqrt}).reset_index() + # Add column for totalspread + rmse_bias_ts_df['totalspread'] = np.sqrt(rmse_bias_ts_df['posterior_ensemble_spread']+rmse_bias_ts_df['obs_err_var']) + + # Rename square error to root mean square error + rmse_bias_ts_df.rename(columns={'sq_err':'rmse'}, inplace=True) + + return rmse_bias_ts_df def rmse_bias_by_obs_type(df, obs_type): """ @@ -155,7 +188,7 @@ def rmse_bias_by_obs_type(df, obs_type): raise ValueError(f"Observation type '{obs_type}' not found in DataFrame.") else: obs_type_df = df[df['type'] == obs_type] - obs_type_agg = obs_type_df.groupby('plevels').agg({'sq_err':mean_then_sqrt, 'bias':'mean'}).reset_index() + obs_type_agg = obs_type_df.groupby('vlevels', observed=False).agg({'sq_err':mean_then_sqrt, 'bias':'mean'}).reset_index() obs_type_agg.rename(columns={'sq_err':'rmse'}, inplace=True) return obs_type_agg