From 97076024a787dd5007b1fc235ea98b08b772e1e6 Mon Sep 17 00:00:00 2001 From: Robert Farmer Date: Sat, 2 Dec 2023 14:31:22 +0000 Subject: [PATCH] Add initial support for callback arguments --- README.md | 15 ++++++++++++++ gfort2py/fProc.py | 45 +++++++++++++++++++++++++++++++---------- tests/proc_ptrs_test.py | 8 +++++++- 3 files changed, 56 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 0f463b5..30545a4 100644 --- a/README.md +++ b/README.md @@ -238,6 +238,21 @@ quad precision values is planned but not yet supported. Quad values can also not ``pyQuadp`` is currently an optional requirement, you must manually install it, it does not get auto-installed when ``gfort2py`` is installed. If you try to access a quad precision variable without ``pyQuadp`` you should get a ``TypeError``. +### Callback arguments + +To pass a Fortran function as a callback argument to another function then pass the function directly: + +````python + +y = x.callback_function(1) + +y = x.another_function(x.callback_function) + +```` + +Currently only Fortran functions can be passed. No checking is done to ensure that the callback function has the +correct signature to be a callback to the second function. + ## Testing diff --git a/gfort2py/fProc.py b/gfort2py/fProc.py index 7187210..82131a0 100644 --- a/gfort2py/fProc.py +++ b/gfort2py/fProc.py @@ -55,7 +55,7 @@ def __init__(self, lib, obj, allobjs, **kwargs): self._lib = lib self._return_value = None - self._func = getattr(lib, self.mangled_name) + self._func = getattr(self._lib, self.mangled_name) @property def mangled_name(self): @@ -64,10 +64,17 @@ def mangled_name(self): def in_dll(self, lib): return self._func + def from_param(self, *args): + return self._func + @property def module(self): return self.obj.head.module + @property + def value(self): + return self._func + @property def name(self): return self.obj.name @@ -121,11 +128,16 @@ def args_check(self, *args, **kwargs): arguments = [] # Build list of inputs for fval in self.obj.args(): - var = fVar(self._allobjs[fval.ref], allobjs=self._allobjs) + if self._allobjs[fval.ref].is_procedure(): + var = None + name = self._allobjs[fval.ref].name + else: + var = fVar(self._allobjs[fval.ref], allobjs=self._allobjs) + name = var.name try: - x = kwargs[var.name] - except KeyError: + x = kwargs[name] + except (KeyError, AttributeError): if count <= len(args): x = args[count] count = count + 1 @@ -139,6 +151,9 @@ def args_check(self, *args, **kwargs): var = x x = var.value + if isinstance(x, fProc): + var = x + arguments.append(variable(x, var, fval.ref)) return arguments @@ -148,7 +163,12 @@ def args_convert(self, input_args): args_end = [] # Convert to ctypes for var in input_args: - _, a, e = var.fvar.to_proc(var.value, input_args) + if var.fvar.obj.is_procedure(): + a = var.value.value + e = None + # print(a.value) + else: + _, a, e = var.fvar.to_proc(var.value, input_args) args.append(a) if e is not None: args_end.append(e) @@ -187,13 +207,16 @@ def _convert_result(self, result, args): if len(self.obj.args()): for var in self.input_args: - if var.fvar.unpack: - try: - x = ptr_unpack(var.fvar.value) - except AttributeError: # unset optional arguments - x = None + if var.fvar.obj.is_procedure(): + x = var.value else: - x = var.fvar.cvalue + if var.fvar.unpack: + try: + x = ptr_unpack(var.fvar.value) + except AttributeError: # unset optional arguments + x = None + else: + x = var.fvar.cvalue name = self._allobjs[var.symbol_ref].name if hasattr(x, "_type_"): diff --git a/tests/proc_ptrs_test.py b/tests/proc_ptrs_test.py index 78a95c8..7d125c5 100644 --- a/tests/proc_ptrs_test.py +++ b/tests/proc_ptrs_test.py @@ -17,11 +17,11 @@ x = gf.fFort(SO, MOD) -@pytest.mark.skip class TestProcPtrsMethods: def assertEqual(self, x, y): assert x == y + @pytest.mark.skip def test_proc_ptr_ffunc(self): x.sub_null_proc_ptr() with pytest.raises(AttributeError) as cm: @@ -37,6 +37,7 @@ def test_proc_ptr_ffunc(self): y2 = x.p_func_func_run_ptr(5) self.assertEqual(y.result, y2.result) + @pytest.mark.skip def test_proc_ptr_ffunc2(self): x.sub_null_proc_ptr() with pytest.raises(AttributeError) as cm: @@ -46,6 +47,7 @@ def test_proc_ptr_ffunc2(self): y = x.p_func_func_run_ptr2(10) self.assertEqual(y.result, 100) + @pytest.mark.skip def test_proc_update(self): x.sub_null_proc_ptr() x.p_func_func_run_ptr = x.func_func_run @@ -63,6 +65,10 @@ def test_proc_func_arg(self): y = x.func_func_arg(x.func_func_run) self.assertEqual(y.result, 10) + y = x.func_func_arg(func=x.func_func_run) + self.assertEqual(y.result, 10) + + @pytest.mark.skip def test_proc_proc_func_arg(self): x.sub_null_proc_ptr() x.p_func_func_run_ptr = x.func_func_run