Skip to content

Commit

Permalink
Update with correct changes
Browse files Browse the repository at this point in the history
  • Loading branch information
justin-richling authored Nov 19, 2024
1 parent d8241db commit 0c10f6e
Showing 1 changed file with 66 additions and 18 deletions.
84 changes: 66 additions & 18 deletions scripts/plotting/global_mean_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,31 +29,38 @@ def global_mean_timeseries(adfobj):
Include the CESM2 LENS result if it can be found.
"""

plot_loc = get_plot_loc(adfobj)
#Notify user that script has started:
print("\n Generating global mean time series plots...")

# Gather ADF configurations
plot_loc = get_plot_loc(adfobj)
plot_type = adfobj.read_config_var("diag_basic_info").get("plot_type", "png")

# Loop over variables
for field in adfobj.diag_var_list:
# reference time series (DataArray)
ref_ts_da = adfobj.data.load_reference_timeseries_da(field)

# Check to see if this field is available
if ref_ts_da is None:
print(
f"\t Variable named {field} provides Nonetype. Skipping this variable"
)
continue
# check if this is a "2-d" varaible:
has_lat_ref, has_lev_ref = pf.zm_validate_dims(ref_ts_da)
if has_lev_ref:
print(
f"Variable named {field} has a lev dimension, which does not work with this script."
)
continue
else:
# reference time series global average
ref_ts_da_ga = pf.spatial_average(ref_ts_da, weights=None, spatial_dims=None)

# annually averaged
ref_ts_da = pf.annual_mean(ref_ts_da_ga, whole_years=True, time_name="time")

# reference time series global average
ref_ts_da_ga = pf.spatial_average(ref_ts_da, weights=None, spatial_dims=None)
# check if this is a "2-d" varaible:
has_lat_ref, has_lev_ref = pf.zm_validate_dims(ref_ts_da)
if has_lev_ref:
print(
f"Variable named {field} has a lev dimension, which does not work with this script."
)
continue

# annually averaged
ref_ts_da = pf.annual_mean(ref_ts_da_ga, whole_years=True, time_name="time")

## SPECIAL SECTION -- CESM2 LENS DATA:
lens2_data = Lens2Data(
Expand All @@ -74,15 +81,39 @@ def global_mean_timeseries(adfobj):
if adfobj.data.ref_nickname
else adfobj.data.ref_case_label
)

has_lev = False
for case_name in adfobj.data.case_names:
c_ts_da = adfobj.data.load_timeseries_da(case_name, field)

# If no reference, we still neeed to check if this is a "2-d" varaible:
if ref_ts_da is None:
has_lat_ref, has_lev_ref = pf.zm_validate_dims(c_ts_da)
# End if

# If 3-d variable, notify useer, flag and move to next test case
if has_lev_ref:
print(
f"Variable named {field} has a lev dimension for '{case_name}', which does not work with this script."
)

has_lev = True
continue
# End if

# Gather spatial avg for test case
c_ts_da_ga = pf.spatial_average(c_ts_da)
case_ts[labels[case_name]] = pf.annual_mean(c_ts_da_ga)
# now have to plot the timeseries

# If this case is 3-d, then break the loop and go to next variable
if has_lev:
continue

# Plot the timeseries
fig, ax = make_plot(
ref_ts_da, case_ts, lens2_data, label=adfobj.data.ref_nickname
case_ts, lens2_data, label=adfobj.data.ref_nickname, ref_ts_da=ref_ts_da
)
ax.set_ylabel(getattr(ref_ts_da,"units", "[-]")) # add units
ax.set_ylabel(getattr(ref_ts_da,"new_unit", "[-]")) # add units
plot_name = plot_loc / f"{field}_GlobalMean_ANN_TimeSeries_Mean.{plot_type}"

conditional_save(adfobj, plot_name, fig)
Expand All @@ -96,6 +127,12 @@ def global_mean_timeseries(adfobj):
plot_type="TimeSeries",
)

#Notify user that script has ended:
print(" ...lat/lon maps have been generated successfully.")


# Helper/plotting functions
###########################

def conditional_save(adfobj, plot_name, fig, verbose=None):
"""Determines whether to save figure"""
Expand All @@ -118,6 +155,7 @@ def conditional_save(adfobj, plot_name, fig, verbose=None):
f"Conditional save found unknown condition. File will not be written: {plot_name}"
)
plt.close(fig)
######


def get_plot_loc(adfobj, verbose=None):
Expand Down Expand Up @@ -147,6 +185,7 @@ def get_plot_loc(adfobj, verbose=None):
plot_loc = Path(plot_location)
print(f"Determined plot location: {plot_loc}")
return plot_loc
######


class Lens2Data:
Expand All @@ -168,13 +207,17 @@ def _include_lens(self):
has_lens = False
lens2 = None
return has_lens, lens2
######


def make_plot(ref_ts_da, case_ts, lens2, label=None):
def make_plot(case_ts, lens2, label=None, ref_ts_da=None):
"""plot yearly values of ref_ts_da"""
field = lens2.field # this will be defined even if no LENS2 data
fig, ax = plt.subplots()
ax.plot(ref_ts_da.year, ref_ts_da, label=label)

# Plot reference/baseline if available
if ref_ts_da:
ax.plot(ref_ts_da.year, ref_ts_da, label=label)
for c, cdata in case_ts.items():
ax.plot(cdata.year, cdata, label=c)
if lens2.has_lens:
Expand Down Expand Up @@ -202,3 +245,8 @@ def make_plot(ref_ts_da, case_ts, lens2, label=None):
plt.tight_layout(pad=2, w_pad=1.0, h_pad=1.0)

return fig, ax
######


##############
#END OF SCRIPT

0 comments on commit 0c10f6e

Please sign in to comment.