Skip to content

Commit

Permalink
new api
Browse files Browse the repository at this point in the history
  • Loading branch information
DjDeveloperr authored Jun 25, 2022
1 parent a1eb9d7 commit 8156e2b
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 51 deletions.
130 changes: 86 additions & 44 deletions src/python.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// deno-lint-ignore-file no-explicit-any
// deno-lint-ignore-file no-explicit-any no-fallthrough

import { py } from "./ffi.ts";
import { cstr, SliceItemRegExp } from "./util.ts";

Expand Down Expand Up @@ -47,8 +48,9 @@ export interface PythonProxy {
*
* - `Set` becomes `set` in Python.
*
* - `function` becomes a Python function. First argument passed is
* an object containing kwargs and rest arguments are positional.
* - `Callback` (custom type) becomes a Python function. First argument
* passed is an object containing kwargs and rest arguments are
* positional.
*
* If you pass a PyObject, it is used as-is.
*
Expand All @@ -69,10 +71,10 @@ export type PythonConvertible =
| { [key: string]: PythonConvertible }
| Map<PythonConvertible, PythonConvertible>
| Set<PythonConvertible>
| PythonJSCallback;
| Callback;

export type PythonJSCallback = (
kwargs: Record<string, any>,
kwargs: any,
...args: any[]
) => PythonConvertible;

Expand Down Expand Up @@ -109,6 +111,59 @@ export function kw(
return new NamedArgument(strings[0].split("=")[0].trim(), value);
}

/**
* Wraps a JS function into Python callback which can be
* passed to Python land. It must be destroyed explicitly
* to free up resources on Rust-side.
*
* Example:
* ```ts
* // Creating
* const add = new Callback((_, a: number, b: number) => {
* return a + b;
* });
* // or
* const add = new Callback((kw: { a: number, b: number }) => {
* return kw.a + kw.b;
* });
*
* // Usage
* some_python_func(add);
*
* // Destroy
* add.destroy();
* ```
*/
export class Callback {
unsafe: Deno.UnsafeCallback;

constructor(public callback: PythonJSCallback) {
this.unsafe = new Deno.UnsafeCallback(
{
parameters: ["pointer", "pointer", "pointer"],
result: "pointer",
},
(
_self: bigint,
args: bigint,
kwargs: bigint,
) => {
return PyObject.from(callback(
kwargs === 0n ? {} : Object.fromEntries(
new PyObject(kwargs).asDict()
.entries(),
),
...(args === 0n ? [] : new PyObject(args).valueOf()),
)).handle;
},
);
}

destroy() {
this.unsafe.close();
}
}

/**
* Represents a Python object.
*
Expand Down Expand Up @@ -377,6 +432,24 @@ export class PyObject {
py.PyList_SetItem(list, i, PyObject.from(v[i]).owned.handle);
}
return new PyObject(list);
} else if (v instanceof Callback) {
const struct = new Uint8Array(8 + 8 + 4 + 8);
const view = new DataView(struct.buffer);
const LE =
new Uint8Array(new Uint32Array([0x12345678]).buffer)[0] !== 0x7;
const nameBuf = new TextEncoder().encode(
"JSCallback:" + (v.callback.name || "anonymous") + "\0",
);
view.setBigUint64(0, Deno.UnsafePointer.of(nameBuf), LE);
view.setBigUint64(8, v.unsafe.pointer, LE);
view.setInt32(16, 0x1 | 0x2, LE);
view.setBigUint64(20, Deno.UnsafePointer.of(nameBuf), LE);
const fn = py.PyCFunction_NewEx(
struct,
PyObject.from(null).handle,
0n,
);
return new PyObject(fn);
} else if (v instanceof PyObject) {
return v;
} else if (v instanceof Set) {
Expand Down Expand Up @@ -426,45 +499,6 @@ export class PyObject {
case "function": {
if (ProxiedPyObject in v) {
return (v as any)[ProxiedPyObject];
} else {
const resource = new Deno.UnsafeCallback(
{
parameters: ["pointer", "pointer", "pointer"],
result: "pointer",
},
(
_self: bigint,
args: bigint,
kwargs: bigint,
) => {
return PyObject.from(v(
kwargs === 0n ? {} : Object.fromEntries(
new PyObject(kwargs).asDict()
.entries(),
),
...(args === 0n ? [] : new PyObject(args).valueOf()),
)).handle;
},
);
const ptr = Deno.UnsafePointer.of(resource);
const struct = new Uint8Array(8 + 8 + 4 + 8);
const view = new DataView(struct.buffer);
const LE =
new Uint8Array(new Uint32Array([0x12345678]).buffer)[0] !== 0x7;
const nameBuf = new TextEncoder().encode(
(v.name || "anonymous") + "\0",
);
const docBuf = nameBuf;
view.setBigUint64(0, Deno.UnsafePointer.of(nameBuf), LE);
view.setBigUint64(8, ptr, LE);
view.setInt32(16, 0x1 | 0x2, LE);
view.setBigUint64(20, Deno.UnsafePointer.of(docBuf), LE);
const fn = py.PyCFunction_NewEx(
struct,
PyObject.from(null).handle,
0n,
);
return new PyObject(fn);
}
}

Expand Down Expand Up @@ -785,6 +819,9 @@ export class Python {
/** Python `Ellipsis` type proxied object */
Ellipsis: any;

/** Shortcut to kw function (template string tag) */
kw = kw;

constructor() {
py.Py_Initialize();
this.builtins = this.import("builtins");
Expand Down Expand Up @@ -858,6 +895,11 @@ export class Python {
import(name: string) {
return this.importObject(name).proxy;
}

/** Shortcut to create Callback instance. */
callback(cb: PythonJSCallback): Callback {
return new Callback(cb);
}
}

/**
Expand Down
13 changes: 6 additions & 7 deletions test/test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -277,21 +277,20 @@ async def test():
assertEquals(aio.run(test()).valueOf(), "ok");
});

Deno.test("callback", {
sanitizeResources: false,
}, () => {
Deno.test("callback", () => {
const { call } = python.runModule(
`
def call(cb):
return cb(61, reduce=1) + 1
`,
"cb_test.py",
);

const cb = python.callback((kw: { reduce: number }, num: number) => {
return num - kw.reduce + 8;
});
assertEquals(
call((kw: { reduce: number }, num: number) => {
return num - kw.reduce + 8;
}).valueOf(),
call(cb).valueOf(),
69,
);
cb.destroy();
});

0 comments on commit 8156e2b

Please sign in to comment.