Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SesquilinearForm and complex space #133

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
36 changes: 30 additions & 6 deletions sympde/calculus/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,11 @@

from operator import mul, add
from functools import reduce
from math import sqrt

from sympy import Indexed, sympify
from sympy import Matrix, ImmutableDenseMatrix
from sympy import cacheit
from sympy import cacheit, conjugate
from sympy.core import Basic
from sympy.core import Add, Mul, Pow
from sympy.core.containers import Tuple
Expand Down Expand Up @@ -182,6 +183,17 @@ def is_zero(x):
else:
return x == 0

# def is_Function_complex(list):
# res=False
# for expr in list:
# if isinstance(expr, (ScalarFunction, VectorFunction)):
# res = res or expr.is_complex
# else:
# res = res or is_Function_complex(expr.args)
# if res:
# break
# return res

#==============================================================================
class BasicOperator(CalculusFunction):
"""
Expand Down Expand Up @@ -266,6 +278,7 @@ def is_scalar(atom):
"""
return is_constant(atom) or isinstance(atom, ScalarFunction)


#==============================================================================
# TODO add dot(u,u) +2*dot(u,v) + dot(v,v) = dot(u+v,u+v)
# now we only have dot(u,u) + dot(u,v)+ dot(v,u) + dot(v,v) = dot(u+v,u+v)
Expand Down Expand Up @@ -351,9 +364,14 @@ def __new__(cls, arg1, arg2, **options):
args_2 = [i for i in b if not i.is_commutative]
c2 = [i for i in b if not i in args_2]


c = Mul(*c1)*Mul(*c2)
# 1D case where everything is commutative
if args_1==[] and args_2==[]:
return(c)

a = reduce(mul, args_1)
b = reduce(mul, args_2)
c = Mul(*c1)*Mul(*c2)

if str(a) > str(b):
a,b = b,a
Expand Down Expand Up @@ -505,12 +523,19 @@ def __new__(cls, arg1, arg2, **options):
args_2 = [i for i in b if not i.is_commutative]
c2 = [i for i in b if not i in args_2]

# if is_Function_complex(b):
# args_2 = [i if isinstance(i, conjugate) else conjugate(i) for i in args_2]
# c2 = [i if isinstance(i, conjugate) else conjugate(i) for i in c2]

c = Mul(*c1)*Mul(*c2)
if args_1==[] and args_2==[]:
return(c)

a = reduce(mul, args_1)
b = reduce(mul, args_2)
c = Mul(*c1)*Mul(*c2)

if str(a) > str(b):
a,b = b,a
# if str(a) > str(b):
# a,b = b,a

obj = Basic.__new__(cls, a, b)

Expand Down Expand Up @@ -780,7 +805,6 @@ def eval(cls, expr):
raise ArgumentTypeError(msg)

return cls(expr, evaluate=False)

#==============================================================================
class Curl(DiffOperator):
"""
Expand Down
6 changes: 4 additions & 2 deletions sympde/expr/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,16 @@ class BasicExpr(Expr):
is_Function = True
is_linear = False
is_bilinear = False
# TODO put is_sesquilinear=False here but it don't work
#is_sesquilinear = False
is_functional = False

@property
def fields(self):
atoms = self.expr.atoms(ScalarFunction, VectorFunction)
if self.is_bilinear or self.is_linear:
if self.is_bilinear or self.is_sesquilinear or self.is_linear:
args = self.variables
if self.is_bilinear:
if self.is_bilinear or self.is_sesquilinear :
args = args[0]+args[1]
fields = tuple(atoms.difference(args))
else:
Expand Down
13 changes: 8 additions & 5 deletions sympde/expr/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from sympde.topology.space import VectorFunction, IndexedVectorFunction
from sympde.topology import Boundary, NormalVector, TangentVector
from sympde.topology import Trace, trace_0, trace_1
from sympde.calculus import grad, dot
from sympde.calculus import grad, dot, inner
from sympde.core.utils import random_string

from .expr import BilinearForm, LinearForm
from .expr import BilinearForm, LinearForm, SesquilinearForm
from .expr import linearize
from .errors import ( UnconsistentLhsError, UnconsistentRhsError,
UnconsistentArgumentsError, UnconsistentBCError )
Expand Down Expand Up @@ -192,8 +192,8 @@ class Equation(Basic):

def __new__(cls, lhs, rhs, trials, tests, bc=None, constraint=None):
# ...
if not isinstance(lhs, BilinearForm):
raise UnconsistentLhsError('> lhs must be a bilinear')
if not isinstance(lhs, (BilinearForm, SesquilinearForm)):
raise UnconsistentLhsError('> lhs must be a bilinear or a Sesquilinear')

if not isinstance(rhs, LinearForm):
raise UnconsistentRhsError('> rhs must be a linear')
Expand Down Expand Up @@ -380,7 +380,10 @@ def __new__(cls, form, fields, bc=None, trials=None):
def find(trials, *, forall, lhs, rhs, bc=None, constraint=None):

tests = forall
lhs = BilinearForm((trials, tests), lhs)
if tests.space.codomain_complex:
lhs = SesquilinearForm((trials, tests), lhs)
else:
lhs = BilinearForm((trials, tests), lhs)
rhs = LinearForm( tests , rhs)

return Equation(lhs, rhs, trials, tests, bc=bc, constraint=constraint)
8 changes: 6 additions & 2 deletions sympde/expr/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

from sympy import Abs, S, cacheit
from sympy import Indexed, Matrix, ImmutableDenseMatrix
from sympy import Indexed, Matrix, ImmutableDenseMatrix, conjugate
from sympy import expand
from sympy.core import Basic, Symbol
from sympy.core import Add, Mul, Pow
Expand Down Expand Up @@ -139,7 +139,7 @@ def _get_trials_tests(expr, *, flatten=False):
if not isinstance(expr, (BasicForm, BasicExpr)):
raise TypeError("Expression must be of type BasicForm or BasicExpr, got '{}' instead".format(type(expr)))

if expr.is_bilinear:
if expr.is_bilinear or expr.is_sesquilinear:
trials = _unpack_functions(expr.variables[0]) if flatten else expr.variables[0]
tests = _unpack_functions(expr.variables[1]) if flatten else expr.variables[1]

Expand Down Expand Up @@ -828,6 +828,10 @@ def eval(cls, expr, domain):
expr = cls(expr.expr, domain=domain)
return LogicalExpr(expr, domain=domain)

elif isinstance(expr, conjugate):
expr = cls(expr.args[0], domain=domain)
return conjugate(expr)

return expr


Expand Down
Loading