From 78567a712284c2853bd6f6baf0b7701ff7b55741 Mon Sep 17 00:00:00 2001 From: e-moral-sanchez <88042165+e-moral-sanchez@users.noreply.github.com> Date: Fri, 1 Dec 2023 23:52:12 +0100 Subject: [PATCH] Import functions from cmath when domain is complex (#358) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Functions in printed code now are imported from `cmath` instead of `Numpy` in the case of a complex domain. This solves #313. If the domain is real they are imported from `math` as before. This change requires Pyccel version >= 1.9.2 which supports the `cmath` library. --------- Co-authored-by: Yaman Güçlü --- psydac/api/ast/parser.py | 19 ++++++------------- pyproject.toml | 4 ++-- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/psydac/api/ast/parser.py b/psydac/api/ast/parser.py index 9248e7ffd..eeeb2fafe 100644 --- a/psydac/api/ast/parser.py +++ b/psydac/api/ast/parser.py @@ -541,19 +541,12 @@ def _visit_DefNode(self, expr, **kwargs): body = tuple(inits) + body name = expr.name - # TODO : when Pyccel will work on the cmath library, we should import the math function from cmath and not from numpy - # If we are with complex object, we should import the mathematical function from numpy and not math to handle complex value. - if expr.domain_dtype=='complex': - numpy_imports = ('array', 'zeros', 'zeros_like', 'floor', *self._math_functions) - imports = [Import('numpy', numpy_imports)] + \ - [*expr.imports] - # Else we import them from math - else: - math_imports = (*self._math_functions,) - numpy_imports = ('array', 'zeros', 'zeros_like', 'floor') - imports = [Import('numpy', numpy_imports)] + \ - ([Import('math', math_imports)] if math_imports else []) + \ - [*expr.imports] + math_library = 'cmath' if expr.domain_dtype=='complex' else 'math' # Function names are the same + math_imports = (*self._math_functions,) + numpy_imports = ('array', 'zeros', 'zeros_like', 'floor') + imports = [Import('numpy', numpy_imports)] + \ + ([Import(math_library, math_imports)] if math_imports else []) + \ + [*expr.imports] results = [self._visit(a) for a in expr.results] diff --git a/pyproject.toml b/pyproject.toml index fbc42f1bf..26220066d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools >= 64.0", "wheel", "numpy", "pyccel >= 1.8.1"] +requires = ["setuptools >= 64.0", "wheel", "numpy", "pyccel >= 1.9.2"] build-backend = "setuptools.build_meta" [project] @@ -33,7 +33,7 @@ dependencies = [ # Our packages from PyPi 'sympde == 0.18.1', - 'pyccel >= 1.8.1', + 'pyccel >= 1.9.2', 'gelato == 0.12', # In addition, we depend on mpi4py and h5py (MPI version).