Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
matthias-baer committed Dec 22, 2020
1 parent dd8ead9 commit 6909011
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 52 deletions.
4 changes: 2 additions & 2 deletions findiff/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def set_accuracy(self, acc):
if isinstance(self.right, Operator):
self.right.set_accuracy(acc)

def stencil(self, shape, h=None, acc=None, old_stl=None):
return Stencil(self, shape)
def stencil(self, shape, acc=None):
return Stencil(self, shape, acc=acc)


class Plus(BinaryOperator):
Expand Down
6 changes: 2 additions & 4 deletions findiff/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,10 @@ def apply(self, rhs, *args, **kwargs):
args = self.coords,
return self.pds(rhs, *args, **kwargs)

def stencil(self, shape, h=None, acc=None, old_stl=None):
if h is None and self.spac is not None:
h = self.spac
def stencil(self, shape, acc=None):
if acc is None and self.acc is not None:
acc = self.acc
return Stencil(self, shape)
return Stencil(self, shape, acc=acc)

def matrix(self, shape, h=None, acc=None):
if acc is None:
Expand Down
71 changes: 25 additions & 46 deletions findiff/stencils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from itertools import product
from copy import deepcopy
import operator
import numpy as np
from .coefs import coefficients
from .utils import to_long_index, to_index_tuple


Expand All @@ -11,7 +8,7 @@ class Stencil(object):
Represent the finite difference stencil for a given differential operator.
"""

def __init__(self, diff_op, shape, old_stl=None):
def __init__(self, diff_op, shape, old_stl=None, acc=None):
"""
Constructor for Stencil objects.
Expand All @@ -35,10 +32,11 @@ def __init__(self, diff_op, shape, old_stl=None):
self.shape = shape
self.diff_op = diff_op
self.char_pts = self._det_characteristic_points()
if old_stl:
self.data = old_stl.data
else:
self.data = {}
self.acc = None
if acc is not None:
self.acc = acc

self.data = {}

self._create_stencil()

Expand Down Expand Up @@ -136,34 +134,34 @@ def type_for_point(self, idx):

def _create_stencil(self):

ndim = len(self.shape)
data = self.data

matrix = self.diff_op.matrix(self.shape)
matrix = self.diff_op.matrix(self.shape, acc=self.acc)

for pt in self.char_pts:

coef_dict = {}
data[pt] = coef_dict
char_point_stencil = {}
self.data[pt] = char_point_stencil

index_for_char_pt = []
for axis, key in enumerate(pt):
if key == 'L':
index_for_char_pt.append(0)
elif key == 'C':
index_for_char_pt.append(self.shape[axis] // 2)
else:
index_for_char_pt.append(self.shape[axis] - 1)
index_tuple_for_char_pt = self._typical_index_tuple_for_char_point(pt)
long_index_for_char_pt = to_long_index(index_tuple_for_char_pt, self.shape)

long_index_for_char_pt = to_long_index(index_for_char_pt, self.shape)
row = matrix[long_index_for_char_pt, :]
long_row_inds, long_col_inds = row.nonzero()

for long_offset_ind in long_col_inds:
offset_ind_tuple = np.array(to_index_tuple(long_offset_ind, self.shape), dtype=np.int)
offset_ind_tuple -= np.array(index_for_char_pt, dtype=np.int)
coef_dict[tuple(offset_ind_tuple)] = row[0, long_offset_ind]

return None
offset_ind_tuple -= np.array(index_tuple_for_char_pt, dtype=np.int)
char_point_stencil[tuple(offset_ind_tuple)] = row[0, long_offset_ind]

def _typical_index_tuple_for_char_point(self, pt):
index_tuple_for_char_pt = []
for axis, key in enumerate(pt):
if key == 'L':
index_tuple_for_char_pt.append(0)
elif key == 'C':
index_tuple_for_char_pt.append(self.shape[axis] // 2)
else:
index_tuple_for_char_pt.append(self.shape[axis] - 1)
return tuple(index_tuple_for_char_pt)

def _det_characteristic_points(self):
shape = self.shape
Expand All @@ -177,22 +175,3 @@ def __str__(self):
s += str(typ) + ":\t" + str(stl) + "\n"
return s

def _binaryop(self, other, op):
stl = deepcopy(self)
assert stl.shape == other.shape

for char_pt, single_stl in stl.data.items():
other_single_stl = other.data[char_pt]
for o, c in other_single_stl.items():
if o in single_stl:
single_stl[o] = op(single_stl[o], c)
else:
single_stl[o] = op(0, c)

return stl

def __add__(self, other):
return self._binaryop(other, operator.__add__)

def __sub__(self, other):
return self._binaryop(other, operator.__sub__)
6 changes: 6 additions & 0 deletions test/test_bugs.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ def test_accuracy_should_be_passed_down_to_stencil(self):
stl = stencil1.data[char_pt]
self.assertDictEqual(expected[char_pt], stl)

d1x = FinDiff(0, dx, 1)
stencil1 = d1x.stencil(shape, acc=4)
for char_pt in stencil1.data:
stl = stencil1.data[char_pt]
self.assertDictEqual(expected[char_pt], stl)


if __name__ == '__main__':
unittest.main()
13 changes: 13 additions & 0 deletions test/test_findiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,19 @@ def test_local_stencil_operator_multiplication(self):

np.testing.assert_array_almost_equal(np.ones_like(X), du_dx)

def test_local_stencil_operator_with_coef(self):
x = np.linspace(0, 10, 101)
y = np.linspace(0, 10, 101)
X, Y = np.meshgrid(x, y, indexing='ij')
u = X * Y
dx = x[1] - x[0]
dy = y[1] - y[0]
d1x = Coef(2) * FinDiff(0, dx) * FinDiff(1, dy)
stencil1 = d1x.stencil(u.shape)
du_dx = stencil1.apply_all(u)

np.testing.assert_array_almost_equal(2*np.ones_like(X), du_dx)

def dict_almost_equal(self, d1, d2):

self.assertEqual(len(d1), len(d2))
Expand Down

0 comments on commit 6909011

Please sign in to comment.