Skip to content

Commit

Permalink
Add initial support for callback arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
rjfarmer committed Dec 2, 2023
1 parent 992a1df commit 9707602
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 12 deletions.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,21 @@ quad precision values is planned but not yet supported. Quad values can also not
``pyQuadp`` is currently an optional requirement, you must manually install it, it does not get auto-installed when ``gfort2py`` is installed. If you try to access a quad precision variable without ``pyQuadp`` you should get a ``TypeError``.


### Callback arguments

To pass a Fortran function as a callback argument to another function then pass the function directly:

````python

y = x.callback_function(1)

y = x.another_function(x.callback_function)

````

Currently only Fortran functions can be passed. No checking is done to ensure that the callback function has the
correct signature to be a callback to the second function.


## Testing

Expand Down
45 changes: 34 additions & 11 deletions gfort2py/fProc.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self, lib, obj, allobjs, **kwargs):
self._lib = lib
self._return_value = None

self._func = getattr(lib, self.mangled_name)
self._func = getattr(self._lib, self.mangled_name)

@property
def mangled_name(self):
Expand All @@ -64,10 +64,17 @@ def mangled_name(self):
def in_dll(self, lib):
return self._func

def from_param(self, *args):
return self._func

@property
def module(self):
return self.obj.head.module

@property
def value(self):
return self._func

@property
def name(self):
return self.obj.name
Expand Down Expand Up @@ -121,11 +128,16 @@ def args_check(self, *args, **kwargs):
arguments = []
# Build list of inputs
for fval in self.obj.args():
var = fVar(self._allobjs[fval.ref], allobjs=self._allobjs)
if self._allobjs[fval.ref].is_procedure():
var = None
name = self._allobjs[fval.ref].name
else:
var = fVar(self._allobjs[fval.ref], allobjs=self._allobjs)
name = var.name

try:
x = kwargs[var.name]
except KeyError:
x = kwargs[name]
except (KeyError, AttributeError):
if count <= len(args):
x = args[count]
count = count + 1
Expand All @@ -139,6 +151,9 @@ def args_check(self, *args, **kwargs):
var = x
x = var.value

if isinstance(x, fProc):
var = x

arguments.append(variable(x, var, fval.ref))

return arguments
Expand All @@ -148,7 +163,12 @@ def args_convert(self, input_args):
args_end = []
# Convert to ctypes
for var in input_args:
_, a, e = var.fvar.to_proc(var.value, input_args)
if var.fvar.obj.is_procedure():
a = var.value.value
e = None
# print(a.value)
else:
_, a, e = var.fvar.to_proc(var.value, input_args)
args.append(a)
if e is not None:
args_end.append(e)
Expand Down Expand Up @@ -187,13 +207,16 @@ def _convert_result(self, result, args):

if len(self.obj.args()):
for var in self.input_args:
if var.fvar.unpack:
try:
x = ptr_unpack(var.fvar.value)
except AttributeError: # unset optional arguments
x = None
if var.fvar.obj.is_procedure():
x = var.value
else:
x = var.fvar.cvalue
if var.fvar.unpack:
try:
x = ptr_unpack(var.fvar.value)
except AttributeError: # unset optional arguments
x = None
else:
x = var.fvar.cvalue

name = self._allobjs[var.symbol_ref].name
if hasattr(x, "_type_"):
Expand Down
8 changes: 7 additions & 1 deletion tests/proc_ptrs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
x = gf.fFort(SO, MOD)


@pytest.mark.skip
class TestProcPtrsMethods:
def assertEqual(self, x, y):
assert x == y

@pytest.mark.skip
def test_proc_ptr_ffunc(self):
x.sub_null_proc_ptr()
with pytest.raises(AttributeError) as cm:
Expand All @@ -37,6 +37,7 @@ def test_proc_ptr_ffunc(self):
y2 = x.p_func_func_run_ptr(5)
self.assertEqual(y.result, y2.result)

@pytest.mark.skip
def test_proc_ptr_ffunc2(self):
x.sub_null_proc_ptr()
with pytest.raises(AttributeError) as cm:
Expand All @@ -46,6 +47,7 @@ def test_proc_ptr_ffunc2(self):
y = x.p_func_func_run_ptr2(10)
self.assertEqual(y.result, 100)

@pytest.mark.skip
def test_proc_update(self):
x.sub_null_proc_ptr()
x.p_func_func_run_ptr = x.func_func_run
Expand All @@ -63,6 +65,10 @@ def test_proc_func_arg(self):
y = x.func_func_arg(x.func_func_run)
self.assertEqual(y.result, 10)

y = x.func_func_arg(func=x.func_func_run)
self.assertEqual(y.result, 10)

@pytest.mark.skip
def test_proc_proc_func_arg(self):
x.sub_null_proc_ptr()
x.p_func_func_run_ptr = x.func_func_run
Expand Down

0 comments on commit 9707602

Please sign in to comment.