[PR]: Improving regrid2 performance #533
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@            Coverage Diff            @@
##              main     #533    +/-  ##
=========================================
  Coverage   100.00%  100.00%
=========================================
  Files           15       15
  Lines         1602     1588    -14
=========================================
  Hits          1602     1588    -14
```

☔ View full report in Codecov by Sentry.
Hey @jasonb5, just checking in to see your estimated timeline for when this will be ready for review and merge. I'm shooting to have xCDAT v0.6.0 released in the next week or so.
Notes from 9/13/23 meeting:

10/11/23 Meeting Notes: Next steps:
Any status updates here?
@tomvothecoder @lee1043 @chengzhuzhang @pochedls
I placed the notebook under …. If everything looks alright, let's merge, and I'll get the next few fixes out and continue working on performance.
Thank you for the update. Sure, I can test the branch.
Hi @jasonb5, here's my initial code review with questions and minor suggestions.
xcdat/regridder/regrid2.py (Outdated)
```diff
 mapping = [
     np.where(
         np.logical_and(
-            shifted_src_west < dst_east[i], shifted_src_east > dst_west[i]
+            shifted_src_west < dst_east[x], shifted_src_east > dst_west[x]
         )
     )[0]
     for x in range(dst_length)
 ]

-weight = np.minimum(dst_east[i], shifted_src_east[contrib]) - np.maximum(
-    dst_west[i], shifted_src_west[contrib]
-)
-
-weights.append(weight.values.reshape(1, contrib.shape[0]))
-
-contrib += shift
+weights = [
+    (
+        np.minimum(dst_east[x], shifted_src_east[y])
+        - np.maximum(dst_west[x], shifted_src_west[y])
+    ).reshape((1, -1))
+    for x, y in enumerate(mapping)
+]
```
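For readers following along, here is a minimal, self-contained sketch of the idea behind this mapping/weights computation. The bounds values are made up, and it omits the longitude shifting (`shifted_src_*`) that the real code uses to handle wraparound:

```python
import numpy as np

# Hypothetical 1-D cell bounds in degrees: 4 source cells, 2 destination cells.
src_west = np.array([0.0, 90.0, 180.0, 270.0])
src_east = np.array([90.0, 180.0, 270.0, 360.0])
dst_west = np.array([0.0, 180.0])
dst_east = np.array([180.0, 360.0])

# For each destination cell x, find the indices of the source cells that
# overlap it: a source cell contributes when its west edge lies west of the
# destination's east edge AND its east edge lies east of the destination's
# west edge.
mapping = [
    np.where(np.logical_and(src_west < dst_east[x], src_east > dst_west[x]))[0]
    for x in range(len(dst_west))
]

# Each contributing source cell is weighted by its overlap width.
weights = [
    (
        np.minimum(dst_east[x], src_east[y]) - np.maximum(dst_west[x], src_west[y])
    ).reshape((1, -1))
    for x, y in enumerate(mapping)
]

print(mapping)  # [array([0, 1]), array([2, 3])]
print(weights)  # [array([[90., 90.]]), array([[90., 90.]])]
```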
Same comment about adding a comment to explain the logic and purpose.
Resolved with 7695e2b
xcdat/regridder/regrid2.py (Outdated)
```python
name = input_data_var.cf.axes[cf_axis_name]

if isinstance(name, list):
    name = name[0]
```
Unless the intent here is to only interpret the `axis` CF attribute? Instead of using `cf_xarray` directly, I think you can use `xc.get_dim_keys()`, which can also interpret the `standard_name` attribute or use the xCDAT fall-back table of generally accepted axis names.
I could use `get_dim_keys`, but we would be trading performance for robustness. If we accept this, I'm fine making the change.
I see. Is the performance hit significant using `get_dim_keys()`? If so, I think it is fine to only interpret the `axis` attribute for performance. @lee1043 any thoughts?
Since this part of the code is at the back-end level and less likely to be accessed by users, I would prefer prioritizing performance, unless the robustness trade-off is too significant. How much of a performance difference would this make?
Try passing the data variable directly to `get_dim_keys()`.
Looks like using `get_dim_keys` works just fine with no decrease in performance; there's actually a small increase.
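For reference, a minimal sketch of what passing the variable directly to `get_dim_keys()` looks like; the DataArray here is made up for illustration:

```python
import numpy as np
import xarray as xr
import xcdat as xc

# Hypothetical variable whose coordinates carry standard_name metadata but
# no "axis" attribute -- the case where get_dim_keys()'s fall-backs help.
da = xr.DataArray(
    np.zeros((3, 4)),
    dims=["latitude", "longitude"],
    coords={
        "latitude": ("latitude", [-45.0, 0.0, 45.0], {"standard_name": "latitude"}),
        "longitude": ("longitude", [0.0, 90.0, 180.0, 270.0], {"standard_name": "longitude"}),
    },
    name="ts",
)

print(xc.get_dim_keys(da, "Y"))  # "latitude"
```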
```python
try:
    name = ds.cf.bounds[axis][0]
except (KeyError, IndexError):
    raise RuntimeError(f"Could not determine {axis!r} bounds")
```
Suggested change:

```diff
-try:
-    name = ds.cf.bounds[axis][0]
-except (KeyError, IndexError):
-    raise RuntimeError(f"Could not determine {axis!r} bounds")
+try:
+    name = ds.bounds.get_bounds(axis)
+except (ValueError, KeyError):
+    raise RuntimeError(f"Could not determine {axis!r} bounds")
```
I think you can use xCDAT's `ds.bounds.get_bounds()`.
As noted above, I could use `get_bounds`, but we're again trading performance for robustness.
Gotcha. If the current implementation is faster and using `.get_bounds()` isn't necessary, we can keep your current implementation.
After some review, I'm going to keep the `ds.cf.bounds` usage. In this case it's fine because it's being called on the input/output grids, both of which have been generated and validated by previous code. We can guarantee that both grid objects only have lat/lon coordinates and bounds with the correct metadata, so we do not need the robustness of `ds.bounds.get_bounds`.
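A quick sketch of why the fast lookup is safe here, assuming a grid produced by xCDAT's grid utilities (e.g. `create_uniform_grid`), which attach lat/lon bounds with the CF metadata that `cf_xarray` needs:

```python
import xcdat as xc

# A generated grid carries lat/lon bounds with correct metadata, so the
# direct cf_xarray lookup cannot miss.
grid = xc.create_uniform_grid(-90.0, 90.0, 4.0, 0.0, 360.0, 5.0)

print(grid.cf.bounds["Y"][0])  # e.g. "lat_bnds"
print(grid.cf.bounds["X"][0])  # e.g. "lon_bnds"
```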
xcdat/regridder/regrid2.py (Outdated)
```diff
 for y in range(y_length):
     y_seg = np.take(input_data, lat_mapping[y], axis=y_index)

-    for lon_index, lon_map in enumerate(self._lon_mapping):
-        lon_weight = self._lon_weights[lon_index]
+    for x in range(x_length):
+        x_seg = np.take(y_seg, lon_mapping[x], axis=x_index, mode="wrap")

-        dot_weight = np.dot(lat_weight, lon_weight)
+        cell_weight = np.dot(lat_weights[y], lon_weights[x])

-        cell_weight = np.sum(dot_weight)
+        output_seg_index = y * x_length + x

-        input_lon_segment = np.take(
-            input_lat_segment, lon_map, axis=input_lon_index
-        )
-
-        data = (
-            np.nansum(
-                np.multiply(input_lon_segment, dot_weight),
-                axis=(input_lat_index, input_lon_index),
-            )
-            / cell_weight
-        )
+        if is_2d:
+            output_data[output_seg_index] = np.divide(
+                np.sum(
+                    np.multiply(x_seg, cell_weight),
+                    axis=(y_index, x_index),
+                ),
+                np.sum(cell_weight),
+            )
+        else:
+            output_seg = output_data[output_seg_index]
+
+            np.divide(
+                np.sum(
+                    np.multiply(x_seg, cell_weight),
+                    axis=(y_index, x_index),
+                ),
+                np.sum(cell_weight),
+                out=output_seg,
+            )
```
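To make the loop body above concrete, here is a small, self-contained sketch of the per-destination-cell computation with made-up shapes (a 3x2 block of contributing source cells under a leading time dimension):

```python
import numpy as np

# x_seg: source data contributing to one destination cell,
# shape (time, lat_contrib, lon_contrib).
x_seg = np.arange(2 * 3 * 2, dtype=float).reshape(2, 3, 2)

# cell_weight: outer product of latitude overlap widths (3, 1) and
# longitude overlap widths (1, 2) -> a (3, 2) weight matrix.
lat_weight = np.array([[0.5], [1.0], [0.5]])
lon_weight = np.array([[1.0, 1.0]])
cell_weight = np.dot(lat_weight, lon_weight)

# Weighted average over the contributing cells (the last two axes),
# leaving the leading time dimension intact: one value per time step.
out = np.divide(
    np.sum(np.multiply(x_seg, cell_weight), axis=(1, 2)),
    np.sum(cell_weight),
)
print(out.shape)  # (2,)
```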
Comments explaining the logic here would be good. Maybe in the docstring.
Resolved with 7695e2b
Sorry for the delayed approval. I thought I had marked it approved, but I missed it.
Description
After some analysis, I determined we were losing some performance moving back and forth between xarray and numpy.
The first fix ensures we're doing all the heavy computation in numpy. This reduced the example time from ~4.8583 to ~1.6833.
There's still some more room for improvement.
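A hedged micro-benchmark sketch of the overhead being avoided: the timings quoted above are from the author's example, while this snippet merely illustrates that repeated small operations through xarray carry per-call metadata costs that raw numpy arrays do not. Exact numbers will vary by machine.

```python
import timeit

import numpy as np
import xarray as xr

da = xr.DataArray(np.random.rand(64, 128), dims=["lat", "lon"])
arr = da.values  # the underlying numpy array

# Same arithmetic, with and without xarray's wrapping layer.
t_xr = timeit.timeit(lambda: float((da * 2.0).sum()), number=1000)
t_np = timeit.timeit(lambda: float((arr * 2.0).sum()), number=1000)

print(f"xarray: {t_xr:.3f}s  numpy: {t_np:.3f}s")  # numpy is typically much faster
```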