From 69090118048acab40dacd46feaca2c31f2f8ecdc Mon Sep 17 00:00:00 2001 From: Matthias Baer Date: Tue, 22 Dec 2020 15:30:44 +0100 Subject: [PATCH] Refactor --- findiff/diff.py | 4 +-- findiff/operators.py | 6 ++-- findiff/stencils.py | 71 ++++++++++++++++---------------------------- test/test_bugs.py | 6 ++++ test/test_findiff.py | 13 ++++++++ 5 files changed, 48 insertions(+), 52 deletions(-) diff --git a/findiff/diff.py b/findiff/diff.py index c587f92..ae3e04a 100644 --- a/findiff/diff.py +++ b/findiff/diff.py @@ -44,8 +44,8 @@ def set_accuracy(self, acc): if isinstance(self.right, Operator): self.right.set_accuracy(acc) - def stencil(self, shape, h=None, acc=None, old_stl=None): - return Stencil(self, shape) + def stencil(self, shape, acc=None): + return Stencil(self, shape, acc=acc) class Plus(BinaryOperator): diff --git a/findiff/operators.py b/findiff/operators.py index 3c11a8c..43cf508 100644 --- a/findiff/operators.py +++ b/findiff/operators.py @@ -94,12 +94,10 @@ def apply(self, rhs, *args, **kwargs): args = self.coords, return self.pds(rhs, *args, **kwargs) - def stencil(self, shape, h=None, acc=None, old_stl=None): - if h is None and self.spac is not None: - h = self.spac + def stencil(self, shape, acc=None): if acc is None and self.acc is not None: acc = self.acc - return Stencil(self, shape) + return Stencil(self, shape, acc=acc) def matrix(self, shape, h=None, acc=None): if acc is None: diff --git a/findiff/stencils.py b/findiff/stencils.py index 1a7e8b9..3a1d6f3 100644 --- a/findiff/stencils.py +++ b/findiff/stencils.py @@ -1,8 +1,5 @@ from itertools import product -from copy import deepcopy -import operator import numpy as np -from .coefs import coefficients from .utils import to_long_index, to_index_tuple @@ -11,7 +8,7 @@ class Stencil(object): Represent the finite difference stencil for a given differential operator. """ - def __init__(self, diff_op, shape, old_stl=None): + def __init__(self, diff_op, shape, old_stl=None, acc=None): """ Constructor for Stencil objects. @@ -35,10 +32,11 @@ def __init__(self, diff_op, shape, old_stl=None): self.shape = shape self.diff_op = diff_op self.char_pts = self._det_characteristic_points() - if old_stl: - self.data = old_stl.data - else: - self.data = {} + self.acc = None + if acc is not None: + self.acc = acc + + self.data = {} self._create_stencil() @@ -136,34 +134,34 @@ def type_for_point(self, idx): def _create_stencil(self): - ndim = len(self.shape) - data = self.data - - matrix = self.diff_op.matrix(self.shape) + matrix = self.diff_op.matrix(self.shape, acc=self.acc) for pt in self.char_pts: - coef_dict = {} - data[pt] = coef_dict + char_point_stencil = {} + self.data[pt] = char_point_stencil - index_for_char_pt = [] - for axis, key in enumerate(pt): - if key == 'L': - index_for_char_pt.append(0) - elif key == 'C': - index_for_char_pt.append(self.shape[axis] // 2) - else: - index_for_char_pt.append(self.shape[axis] - 1) + index_tuple_for_char_pt = self._typical_index_tuple_for_char_point(pt) + long_index_for_char_pt = to_long_index(index_tuple_for_char_pt, self.shape) - long_index_for_char_pt = to_long_index(index_for_char_pt, self.shape) row = matrix[long_index_for_char_pt, :] long_row_inds, long_col_inds = row.nonzero() + for long_offset_ind in long_col_inds: offset_ind_tuple = np.array(to_index_tuple(long_offset_ind, self.shape), dtype=np.int) - offset_ind_tuple -= np.array(index_for_char_pt, dtype=np.int) - coef_dict[tuple(offset_ind_tuple)] = row[0, long_offset_ind] - - return None + offset_ind_tuple -= np.array(index_tuple_for_char_pt, dtype=np.int) + char_point_stencil[tuple(offset_ind_tuple)] = row[0, long_offset_ind] + + def _typical_index_tuple_for_char_point(self, pt): + index_tuple_for_char_pt = [] + for axis, key in enumerate(pt): + if key == 'L': + index_tuple_for_char_pt.append(0) + elif key == 'C': + index_tuple_for_char_pt.append(self.shape[axis] // 2) + else: + index_tuple_for_char_pt.append(self.shape[axis] - 1) + return tuple(index_tuple_for_char_pt) def _det_characteristic_points(self): shape = self.shape @@ -177,22 +175,3 @@ def __str__(self): s += str(typ) + ":\t" + str(stl) + "\n" return s - def _binaryop(self, other, op): - stl = deepcopy(self) - assert stl.shape == other.shape - - for char_pt, single_stl in stl.data.items(): - other_single_stl = other.data[char_pt] - for o, c in other_single_stl.items(): - if o in single_stl: - single_stl[o] = op(single_stl[o], c) - else: - single_stl[o] = op(0, c) - - return stl - - def __add__(self, other): - return self._binaryop(other, operator.__add__) - - def __sub__(self, other): - return self._binaryop(other, operator.__sub__) diff --git a/test/test_bugs.py b/test/test_bugs.py index 5386f84..20be4fa 100644 --- a/test/test_bugs.py +++ b/test/test_bugs.py @@ -75,6 +75,12 @@ def test_accuracy_should_be_passed_down_to_stencil(self): stl = stencil1.data[char_pt] self.assertDictEqual(expected[char_pt], stl) + d1x = FinDiff(0, dx, 1) + stencil1 = d1x.stencil(shape, acc=4) + for char_pt in stencil1.data: + stl = stencil1.data[char_pt] + self.assertDictEqual(expected[char_pt], stl) + if __name__ == '__main__': unittest.main() diff --git a/test/test_findiff.py b/test/test_findiff.py index 82d3836..5204df2 100644 --- a/test/test_findiff.py +++ b/test/test_findiff.py @@ -276,6 +276,19 @@ def test_local_stencil_operator_multiplication(self): np.testing.assert_array_almost_equal(np.ones_like(X), du_dx) + def test_local_stencil_operator_with_coef(self): + x = np.linspace(0, 10, 101) + y = np.linspace(0, 10, 101) + X, Y = np.meshgrid(x, y, indexing='ij') + u = X * Y + dx = x[1] - x[0] + dy = y[1] - y[0] + d1x = Coef(2) * FinDiff(0, dx) * FinDiff(1, dy) + stencil1 = d1x.stencil(u.shape) + du_dx = stencil1.apply_all(u) + + np.testing.assert_array_almost_equal(2*np.ones_like(X), du_dx) + def dict_almost_equal(self, d1, d2): self.assertEqual(len(d1), len(d2))