Skip to content

Commit

Permalink
Merge pull request #29 from NCAR/prior_post_plots
Browse files Browse the repository at this point in the history
fix: prior and posterior profile plotting
  • Loading branch information
hkershaw-brown authored Dec 23, 2024
2 parents bd0041c + cb9ac3b commit cb354e8
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 30 deletions.
4 changes: 2 additions & 2 deletions src/pydartdiags/obs_sequence/obs_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ def __init__(self, file, synonyms=None):
# calculate bias and sq_err is the obs_seq is an obs_seq.final
if 'prior_ensemble_mean'.casefold() in map(str.casefold, self.columns):
self.has_assimilation_info = True
self.df['bias'] = (self.df['prior_ensemble_mean'] - self.df['observation'])
self.df['sq_err'] = self.df['bias']**2 # squared error
self.df['prior_bias'] = (self.df['prior_ensemble_mean'] - self.df['observation'])
self.df['prior_sq_err'] = self.df['prior_bias']**2 # squared error
if 'posterior_ensemble_mean'.casefold() in map(str.casefold, self.columns):
self.has_posterior_info = True
self.df['posterior_bias'] = (self.df['posterior_ensemble_mean'] - self.df['observation'])
Expand Down
197 changes: 171 additions & 26 deletions src/pydartdiags/plots/plots.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd

def plot_rank_histogram(df):
Expand Down Expand Up @@ -68,7 +69,7 @@ def plot_profile(df, levels, verticalUnit = "pressure (Pa)"):
This function takes a DataFrame containing observational data and model predictions, categorizes
the data into specified vertical levels, and calculates the RMSE, bias and total spread for each level and
observation type. It then plots three line charts: one for RMSE, one for bias, one for total spread, as functions
observation type. It then plots three line charts: one for RMSE, one for bias, one for total spread, as functions
of vertical level. The vertical levels are plotted on the y-axis in reversed order to represent
the vertical profile in the atmosphere correctly if the vertical units are pressure.
Expand Down Expand Up @@ -113,32 +114,157 @@ def plot_profile(df, levels, verticalUnit = "pressure (Pa)"):
df.loc[:,'midpoint'] = df['vlevels'].apply(lambda x: x.mid)

# Calculations
df_profile = rmse_bias_totalspread(df)
df_profile_prior = rmse_bias_totalspread(df, phase='prior')
df_profile_posterior = None
if 'posterior_ensemble_mean' in df.columns:
df_profile_posterior = rmse_bias_totalspread(df, phase='posterior')

# Merge prior and posterior dataframes
if df_profile_posterior is not None:
df_profile = pd.merge(df_profile_prior, df_profile_posterior, on=['midpoint', 'type'], suffixes=('_prior', '_posterior'))
fig_rmse = plot_profile_prior_post(df_profile, 'rmse', verticalUnit)
fig_rmse.show()
fig_bias = plot_profile_prior_post(df_profile, 'bias', verticalUnit)
fig_bias.show()
fig_ts = plot_profile_prior_post(df_profile, 'totalspread', verticalUnit)
fig_ts.show()
else:
df_profile = df_profile_prior
fig_rmse = plot_profile_prior(df_profile, 'rmse', verticalUnit)
fig_rmse.show()
fig_bias = plot_profile_prior(df_profile, 'bias', verticalUnit)
fig_bias.show()
fig_ts = plot_profile_prior(df_profile, 'totalspread', verticalUnit)
fig_ts.show()

# RMSE plot
fig_rmse = px.line(df_profile, y='midpoint', x='rmse', title='RMSE by Level', markers=True, color='type', width=800, height=800,
labels={"midpoint": verticalUnit})
if verticalUnit == "pressure (Pa)":
fig_rmse.update_yaxes(autorange="reversed")
fig_rmse.show()
return df_profile, fig_rmse, fig_ts, fig_bias

def plot_profile_prior_post(df_profile, stat, verticalUnit):
"""
Plots prior and posterior statistics by vertical level for different observation types.
Parameters:
df_profile (pd.DataFrame): DataFrame containing the prior and posterior statistics.
stat (str): The statistic to plot (e.g., 'rmse', 'bias', 'totalspread').
verticalUnit (str): The unit of the vertical axis (e.g., 'pressure (Pa)').
Returns:
plotly.graph_objects.Figure: The generated Plotly figure.
"""
# Reshape DataFrame to long format for easier plotting
df_long = pd.melt(
df_profile,
id_vars=["midpoint", "type"],
value_vars=["prior_"+stat, "posterior_"+stat],
var_name=stat+"_type",
value_name=stat+"_value"
)

# Define a color mapping for observation each type
unique_types = df_long["type"].unique()
colors = px.colors.qualitative.Plotly
color_mapping = {type_: colors[i % len(colors)] for i, type_ in enumerate(unique_types)}

# Create a mapping for line styles based on stat
line_styles = {"prior_"+stat: "solid", "posterior_"+stat: "dash"}

# Create the figure
fig_stat = go.Figure()

# Loop through each type and type to add traces
for t in df_long["type"].unique():
for stat_type, dash_style in line_styles.items():
# Filter the DataFrame for this type and stat
df_filtered = df_long[(df_long[stat+"_type"] == stat_type) & (df_long["type"] == t)]

# Add a trace
fig_stat.add_trace(go.Scatter(
x=df_filtered[stat+"_value"],
y=df_filtered["midpoint"],
mode='lines+markers',
name='prior '+t if stat_type == "prior_"+stat else 'post ', # Show legend for "prior_stat OBS TYPE" only
line=dict(dash=dash_style, color=color_mapping[t]), # Same color for all traces in group
marker=dict(size=8, color=color_mapping[t]),
legendgroup=t # Group traces by type
))

# Update layout
fig_stat.update_layout(
title= stat+' by Level',
xaxis_title=stat,
yaxis_title=verticalUnit,
width=800,
height=800,
template="plotly_white"
)

# totalspread plot
fig_ts = px.line(df_profile, y='midpoint', x='totalspread', title='Totalspread by Level', markers=True, color='type', width=800, height=800,
labels={"midpoint": verticalUnit})
if verticalUnit == "pressure (Pa)":
fig_ts.update_yaxes(autorange="reversed")
fig_ts.show()
fig_stat.update_yaxes(autorange="reversed")

return fig_stat


def plot_profile_prior(df_profile, stat, verticalUnit):
"""
Plots prior statistics by vertical level for different observation types.
Parameters:
df_profile (pd.DataFrame): DataFrame containing the prior statistics.
stat (str): The statistic to plot (e.g., 'rmse', 'bias', 'totalspread').
verticalUnit (str): The unit of the vertical axis (e.g., 'pressure (Pa)').
Returns:
plotly.graph_objects.Figure: The generated Plotly figure.
"""
# Reshape DataFrame to long format for easier plotting - not needed for prior only, but
# leaving it in for consistency with the plot_profile_prior_post function for now
df_long = pd.melt(
df_profile,
id_vars=["midpoint", "type"],
value_vars=["prior_"+stat],
var_name=stat+"_type",
value_name=stat+"_value"
)

# Define a color mapping for observation each type
unique_types = df_long["type"].unique()
colors = px.colors.qualitative.Plotly
color_mapping = {type_: colors[i % len(colors)] for i, type_ in enumerate(unique_types)}

# Create the figure
fig_stat = go.Figure()

# Loop through each type to add traces
for t in df_long["type"].unique():
# Filter the DataFrame for this type and stat
df_filtered = df_long[(df_long["type"] == t)]

# Add a trace
fig_stat.add_trace(go.Scatter(
x=df_filtered[stat+"_value"],
y=df_filtered["midpoint"],
mode='lines+markers',
name='prior ' + t,
line=dict(color=color_mapping[t]), # Same color for all traces in group
marker=dict(size=8, color=color_mapping[t]),
legendgroup=t # Group traces by type
))

# Update layout
fig_stat.update_layout(
title=stat + ' by Level',
xaxis_title=stat,
yaxis_title=verticalUnit,
width=800,
height=800,
template="plotly_white"
)

# bias plot
fig_bias = px.line(df_profile, y='midpoint', x='bias', title='Bias by Level', markers=True, color='type', width=800, height=800,
labels={"midpoint": verticalUnit})
if verticalUnit == "pressure (Pa)":
fig_bias.update_yaxes(autorange="reversed")
fig_bias.show()

fig_stat.update_yaxes(autorange="reversed")
return fig_stat

return df_profile, fig_rmse, fig_ts, fig_bias


def mean_then_sqrt(x):
"""
Expand All @@ -158,15 +284,34 @@ def mean_then_sqrt(x):

return np.sqrt(np.mean(x))

def rmse_bias_totalspread(df):
rmse_bias_ts_df = df.groupby(['midpoint', 'type'], observed=False)
rmse_bias_ts_df = rmse_bias_ts_df.agg({'sq_err':mean_then_sqrt, 'bias':'mean', 'posterior_ensemble_spread':mean_then_sqrt, 'obs_err_var':mean_then_sqrt}).reset_index()
def rmse_bias_totalspread(df, phase='prior'):
if phase == 'prior':
sq_err_column = 'prior_sq_err'
bias_column = 'prior_bias'
rmse_column = 'prior_rmse'
spread_column = 'prior_ensemble_spread'
totalspread_column = 'prior_totalspread'
elif phase == 'posterior':
sq_err_column = 'posterior_sq_err'
bias_column = 'posterior_bias'
rmse_column = 'posterior_rmse'
spread_column = 'posterior_ensemble_spread'
totalspread_column = 'posterior_totalspread'
else:
raise ValueError("Invalid phase. Must be 'prior' or 'posterior'.")

rmse_bias_ts_df = df.groupby(['midpoint', 'type'], observed=False).agg({
sq_err_column: mean_then_sqrt,
bias_column: 'mean',
spread_column: mean_then_sqrt,
'obs_err_var': mean_then_sqrt
}).reset_index()

# Add column for totalspread
rmse_bias_ts_df['totalspread'] = np.sqrt(rmse_bias_ts_df['posterior_ensemble_spread']+rmse_bias_ts_df['obs_err_var'])
rmse_bias_ts_df[totalspread_column] = np.sqrt(rmse_bias_ts_df[spread_column] + rmse_bias_ts_df['obs_err_var'])

# Rename square error to root mean square error
rmse_bias_ts_df.rename(columns={'sq_err':'rmse'}, inplace=True)
rmse_bias_ts_df.rename(columns={sq_err_column: rmse_column}, inplace=True)

return rmse_bias_ts_df

Expand Down
4 changes: 2 additions & 2 deletions tests/test_obs_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def test_read1d(self, obs_seq_file_path):
obj = obsq.obs_sequence(obs_seq_file_path)
assert obj.loc_mod == 'loc1d'
assert len(obj.df) == 40 # 40 obs in the file
assert obj.df.columns.str.contains('posterior').sum() == 22 + 2 # members + sq_err + bias
assert obj.df.columns.str.contains('prior').sum() == 22
assert obj.df.columns.str.contains('posterior').sum() == 24 # 20 members + mean + spread + sq_err + bias
assert obj.df.columns.str.contains('prior').sum() == 24



Expand Down

0 comments on commit cb354e8

Please sign in to comment.