diff --git a/Cargo.lock b/Cargo.lock index 67a7ba4..9c5f180 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -223,14 +223,12 @@ dependencies = [ [[package]] name = "cjyafn" -version = "0.1.2" +version = "0.2.0" dependencies = [ "cbindgen", "get-size", "jyafn", - "libloading", "serde_json", - "smallvec", ] [[package]] @@ -697,7 +695,7 @@ dependencies = [ [[package]] name = "jyafn" -version = "0.1.2" +version = "0.2.0" dependencies = [ "bincode", "byte-slice-cast", @@ -710,9 +708,6 @@ dependencies = [ "home", "lazy_static", "libloading", - "maplit", - "memmap", - "no-panic", "qbe", "rand", "scopeguard", @@ -721,7 +716,6 @@ dependencies = [ "serde_derive", "serde_json", "serde_with", - "smallvec", "special-fun", "tempfile", "thiserror", @@ -743,14 +737,12 @@ dependencies = [ [[package]] name = "jyafn-python" -version = "0.1.2" +version = "0.2.0" dependencies = [ - "byte-slice-cast", "chrono", "get-size", "jyafn", "pyo3", - "semver", "serde_json", ] @@ -831,28 +823,12 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" -[[package]] -name = "maplit" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" - [[package]] name = "memchr" version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" -[[package]] -name = "memmap" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6585fd95e7bb50d6cc31e20d4cf9afb4e2ba16c5846fc76793f11218da9c475b" -dependencies = [ - "libc", - "winapi", -] - [[package]] name = "memoffset" version = "0.9.1" @@ -877,17 +853,6 @@ dependencies = [ "adler", ] -[[package]] -name = "no-panic" -version = "0.1.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8540b7d99a20166178b42a05776aef900cdbfec397f861dfc7819bf1d7760b3d" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.71", -] - [[package]] name = "nom" version = "7.1.3" @@ -1285,9 +1250,6 @@ name = "smallvec" version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" -dependencies = [ - "serde", -] [[package]] name = "special-fun" diff --git a/Makefile b/Makefile index 970dbac..e9c459d 100644 --- a/Makefile +++ b/Makefile @@ -35,3 +35,9 @@ install-dylib: cjyafn install-so: cjyafn cp target/release/libcjyafn.so /usr/local/lib/ + +bump-minor: + cargo set-version --bump minor --package cjyafn jyafn jyafn-python + +bump: + cargo set-version --bump patch --package cjyafn --package jyafn --package jyafn-python diff --git a/cjyafn/Cargo.toml b/cjyafn/Cargo.toml index d6aab05..56a25cb 100644 --- a/cjyafn/Cargo.toml +++ b/cjyafn/Cargo.toml @@ -1,21 +1,18 @@ [package] name = "cjyafn" -version = "0.1.2" +version = "0.2.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lib] -doc = false name = "cjyafn" crate-type = ["cdylib"] [dependencies] get-size = "0.1.4" jyafn = { path = "../jyafn", default-features = false } -libloading = "0.8.4" serde_json = "1.0.115" -smallvec = { version = "1.13.2", features = ["serde"] } [build-dependencies] cbindgen = "0.26.0" diff --git a/cjyafn/src/lib.rs b/cjyafn/src/lib.rs index 3df77be..761a4dd 100644 --- a/cjyafn/src/lib.rs +++ b/cjyafn/src/lib.rs @@ -291,7 +291,9 @@ pub unsafe extern "C" fn graph_to_json(graph: *const ()) -> *const c_char { /// Expects `graph` to be a valid pointer to a graph. #[no_mangle] pub unsafe extern "C" fn graph_render(graph: *const ()) -> Outcome { - with(graph, |graph: &Graph| new_c_str(graph.render().to_string())) + try_with(graph, |graph: &Graph| { + Ok(new_c_str(graph.render()?.to_string())) + }) } /// # Safety @@ -361,7 +363,7 @@ pub unsafe extern "C" fn layout_from_json(json: *const c_char) -> Outcome { /// Expects the `layout` parameter to be a valid pointer to a layout. #[no_mangle] pub unsafe extern "C" fn layout_size(layout: *const ()) -> usize { - with_unchecked(layout, |layout: &Layout| layout.size()) + with_unchecked(layout, |layout: &Layout| layout.size()).in_bytes() } /// # Safety @@ -416,6 +418,14 @@ pub unsafe extern "C" fn layout_is_struct(layout: *const ()) -> bool { }) } +/// # Safety +/// +/// Expects the `layout` parameter to be a valid pointer to a layout. +#[no_mangle] +pub unsafe extern "C" fn layout_is_tuple(layout: *const ()) -> bool { + with_unchecked(layout, |layout: &Layout| matches!(layout, Layout::Tuple(_))) +} + /// # Safety /// /// Expects the `layout` parameter to be a valid pointer to a layout. @@ -454,6 +464,36 @@ pub unsafe extern "C" fn layout_as_struct(layout: *const ()) -> *const () { }) } +/// # Safety +/// +/// Expects the `strct` parameter to be a valid pointer to a jyafn struct. +#[no_mangle] +pub unsafe extern "C" fn layout_tuple_size(layout: *const ()) -> usize { + with_unchecked(layout, |layout: &Layout| { + if let Layout::Tuple(t) = layout { + return t.len(); + } + + 0 + }) +} + +/// # Safety +/// +/// Expects the `layout` parameter to be a valid pointer to a layout. +#[no_mangle] +pub unsafe extern "C" fn layout_get_tuple_item(layout: *const (), index: usize) -> *const () { + with_unchecked(layout, |layout: &Layout| { + if let Layout::Tuple(t) = layout { + if index < t.len() { + return &t[index] as *const Layout as *const (); + } + } + + std::ptr::null() + }) +} + /// # Safety /// /// Expects the `layout` parameter to be a valid pointer to a layout. @@ -565,7 +605,7 @@ pub unsafe extern "C" fn function_name(func: *const ()) -> *const c_char { /// Expects the `func` parameter to be a valid pointer to a jyafn function. #[no_mangle] pub unsafe extern "C" fn function_input_size(func: *const ()) -> usize { - with_unchecked(func, |func: &Function| func.input_size()) + with_unchecked(func, |func: &Function| func.input_size()).in_bytes() } /// # Safety @@ -573,7 +613,7 @@ pub unsafe extern "C" fn function_input_size(func: *const ()) -> usize { /// Expects the `func` parameter to be a valid pointer to a jyafn function. #[no_mangle] pub unsafe extern "C" fn function_output_size(func: *const ()) -> usize { - with_unchecked(func, |func: &Function| func.output_size()) + with_unchecked(func, |func: &Function| func.output_size()).in_bytes() } /// # Safety @@ -692,13 +732,15 @@ pub unsafe extern "C" fn function_call_raw( ) -> Outcome { with_unchecked(func, |func: &Function| { match std::panic::catch_unwind(|| { - let input = std::slice::from_raw_parts(input, func.input_size()); - let output = std::slice::from_raw_parts_mut(output, func.output_size()); + let input = std::slice::from_raw_parts(input, func.input_size().in_bytes()); + let output = std::slice::from_raw_parts_mut(output, func.output_size().in_bytes()); let fn_err = func.call_raw(input, output); if !fn_err.is_null() { let fn_err = Box::from_raw(fn_err).take(); - return Outcome::from_result(Result::<(), Error>::Err(fn_err.into())); + return Outcome::from_result(Result::<(), Error>::Err(rust::Error::StatusRaised( + fn_err, + ))); } Outcome::from_result(Result::<(), Error>::Ok(())) @@ -720,7 +762,7 @@ pub unsafe extern "C" fn function_call_raw( #[no_mangle] pub unsafe extern "C" fn function_eval_raw(func: *const (), input: *const u8) -> Outcome { with(func, |func: &Function| { - let input = std::slice::from_raw_parts(input, func.input_size()); + let input = std::slice::from_raw_parts(input, func.input_size().in_bytes()); Outcome::from_result( func.eval_raw(input) .map(|output| Box::leak(output) as *const [u8] as *const ()), diff --git a/jyafn-ext/Makefile b/jyafn-ext/Makefile index d17216e..c259a63 100644 --- a/jyafn-ext/Makefile +++ b/jyafn-ext/Makefile @@ -1,2 +1,5 @@ build: - ./build-all.sh + ./run-all.sh build + +install: + ./run-all.sh install diff --git a/jyafn-ext/extensions/dummy/src/lib.rs b/jyafn-ext/extensions/dummy/src/lib.rs index addaeda..e36bf5d 100644 --- a/jyafn-ext/extensions/dummy/src/lib.rs +++ b/jyafn-ext/extensions/dummy/src/lib.rs @@ -78,3 +78,19 @@ impl Resource for Dummy { } } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_load() { + unsafe { + let ptr = extension_init() as *mut i8; + if ptr.is_null() { + panic!("prt was null"); + } + println!("{:?}", CString::from_raw(ptr)); + } + } +} diff --git a/jyafn-ext/jyafn_ext.h b/jyafn-ext/jyafn_ext.h new file mode 100644 index 0000000..8f78f6b --- /dev/null +++ b/jyafn-ext/jyafn_ext.h @@ -0,0 +1,160 @@ +/* + * This is a tentative implementation of jyafn extensions in pure C, for the purists. + */ + +#include +#include +#include +#include + + +#define QUOTE(...) #__VA_ARGS__ + + +typedef const void* Outcome; + +extern const char* outcome_get_err(Outcome); +extern void* outcome_get_ok(Outcome); +extern void outcome_drop(Outcome); + +#define OUTCOME_OF(T) Outcome +#define OUTCOME_MANIFEST QUOTE({ \ + "fn_get_err": "outcome_get_err", \ + "fn_get_ok": "outcome_get_ok", \ + "fn_drop": "outcome_drop" \ +}) + + +typedef const void* Dumped; + +extern size_t dumped_get_len(Dumped); +extern const unsigned char* dumped_get_ptr(Dumped); +extern void dumped_drop(Dumped); + +#define DUMPED_MANIFEST QUOTE({ \ + "fn_get_len": "dumped_get_len", \ + "fn_get_ptr": "dumped_get_ptr", \ + "fn_drop": "dumped_drop" \ +}) + + +extern void string_drop(char*); + +#define STRING_MANIFEST QUOTE({ \ + "fn_drop": "string_drop" \ +}) + + +typedef const void* RawResource; + +typedef OUTCOME_OF(RawResource) (*FnFromBytes)(const unsigned char*, size_t); +typedef OUTCOME_OF(Dumped) (*FnDump)(RawResource); +typedef size_t (*FnSize)(RawResource); +typedef char* (*FnGetMethodDef)(RawResource, const char*); +typedef void (*FnDrop)(RawResource); + +#define DEF_SYMBOL_T(FN_TY) typedef struct { FN_TY fn_ptr; char* name; } Symbol##FN_TY +#define SYMBOL_T(FN_TY) Symbol##FN_TY +#define SYMBOL(FUNC) { fn_ptr: &FUNC, name: #FUNC } +DEF_SYMBOL_T(FnFromBytes); +DEF_SYMBOL_T(FnDump); +DEF_SYMBOL_T(FnSize); +DEF_SYMBOL_T(FnGetMethodDef); +DEF_SYMBOL_T(FnDrop); + + +#define MANIFEST_BEGIN "{" \ + "\"outcome\": "OUTCOME_MANIFEST \ + ", \"dumped\": "DUMPED_MANIFEST \ + ", \"string\": "STRING_MANIFEST \ + ", \"resources\": {" +#define MANIFEST_END "}}" + +typedef char* DeclaredResource; + +DeclaredResource declare_resource( + char* resource_name, + SYMBOL_T(FnFromBytes) fn_from_bytes, + SYMBOL_T(FnDump) fn_dump, + SYMBOL_T(FnSize) fn_size, + SYMBOL_T(FnGetMethodDef) fn_get_method_def, + SYMBOL_T(FnDrop) fn_drop +) { + const char* fmt_entry = "\"%s\": " QUOTE({ + "fn_from_bytes": %s, + "fn_dump": %s, + "fn_size": %s, + "fn_get_method_def": %s, + "fn_drop": %s, + }); + + size_t needed = snprintf( + NULL, + 0, + fmt_entry, + resource_name, + fn_from_bytes, + fn_dump, + fn_size, + fn_get_method_def, + fn_drop + ) + 1; + char *buffer = malloc(needed); + sprintf( + buffer, + fmt_entry, + resource_name, + fn_from_bytes, + fn_dump, + fn_size, + fn_get_method_def, + fn_drop + ); + + return (DeclaredResource)buffer; +} + +void joinstr(char** buf, size_t* cap, size_t* len, char* src) { + if (*buf == NULL) { + *buf = malloc(10); + *cap = 10; + *len = 0; + } + + size_t i = 0; + while (src[i] != '\0') { + if (*len == *cap) { + *cap *= 2; + char* newbuf = malloc(*cap); + memcpy(newbuf, *buf, *len); + free(*buf); + *buf = newbuf; + } + + (*buf)[*len] = src[i]; + *len += 1; + } +} + +char* build_manifest(DeclaredResource* resources, size_t n_resources) { + char* buf = NULL; + size_t cap = 0; + size_t len = 0; + + joinstr(&buf, &cap, &len, &*(MANIFEST_BEGIN)); + + for (size_t i = 0; i < n_resources - 1; i++) { + joinstr(&buf, &cap, &len, resources[i]); + free(resources[i]); + joinstr(&buf, &cap, &len, ", "); + } + + if (n_resources > 0) { + joinstr(&buf, &cap, &len, resources[n_resources - 1]); + } + + joinstr(&buf, &cap, &len, MANIFEST_END); + joinstr(&buf, &cap, &len, "\0"); + + return buf; +} diff --git a/jyafn-ext/build-all.sh b/jyafn-ext/run-all.sh similarity index 95% rename from jyafn-ext/build-all.sh rename to jyafn-ext/run-all.sh index bc3b536..46a827b 100755 --- a/jyafn-ext/build-all.sh +++ b/jyafn-ext/run-all.sh @@ -17,6 +17,6 @@ done for extension in $extensions; do echo Building $extension... cd $extension - make build + make $1 cd $BASEDIR done diff --git a/jyafn-ext/src/fn_error.rs b/jyafn-ext/src/fn_error.rs deleted file mode 100644 index b4654ef..0000000 --- a/jyafn-ext/src/fn_error.rs +++ /dev/null @@ -1,17 +0,0 @@ -/// The error type returned from the compiled function. If you need to create a new error -/// from your code, use `String::into`. -pub struct FnError(Option); - -impl From for FnError { - fn from(s: String) -> FnError { - FnError(Some(s)) - } -} - -impl FnError { - /// Takes the underlying error message from this error. Calling this method more than - /// once will result in a panic. - pub fn take(&mut self) -> String { - self.0.take().expect("can only call take once") - } -} diff --git a/jyafn-ext/src/lib.rs b/jyafn-ext/src/lib.rs index 16893f9..11d4dfe 100644 --- a/jyafn-ext/src/lib.rs +++ b/jyafn-ext/src/lib.rs @@ -1,7 +1,6 @@ //! This crate is intended to help extension authors. It exposes a minimal version of //! `jyafn` and many convenience macros to generate all the boilerplate involved. -mod fn_error; mod io; mod layout; mod outcome; @@ -13,19 +12,72 @@ pub use paste::paste; /// We need JSON support to zip JSON values around the FFI boundary. pub use serde_json; -pub use fn_error::FnError; pub use io::{Input, OutputBuilder}; pub use layout::{Layout, Struct, ISOFORMAT}; pub use outcome::Outcome; pub use resource::{Method, Resource}; /// Generates the boilerplate code for a `jyafn` extension. +/// +/// # Usage +/// +/// This macro accepts a list of comman-separated types, each of which has to implement +/// the [`Resource`] trait, like so +/// ``` +/// extension! { +/// Foo, Bar, Baz +/// } +/// ``` +/// Optionally, you may define an init function, which takes no arguments and returns +/// `Result<(), String>`, like so +/// ``` +/// extension! { +/// init = my_init; +/// Foo, Bar, Baz +/// } +/// +/// fn my_init() -> Result<(), String> { /* ... */} +/// ``` #[macro_export] macro_rules! extension { ($($ty:ty),*) => { + fn noop() -> Result<(), String> { Ok (()) } + + $crate::extension! { + init = noop; + $($ty),* + } + }; + (init = $init_fn:ident; $($ty:ty),*) => { use std::ffi::{c_char, CString}; use $crate::Outcome; + /// Creates a C-style string out of a `String` in a way that doesn't produce errors. This + /// function substitutes nul characters by the ` ` (space) character. This avoids an + /// allocation. + /// + /// This method **leaks** the string. So, don't forget to guarantee that somene somewhere + /// is freeing it. + /// + /// # Note + /// + /// Yes, I know! It's a pretty lousy implementation that is even... O(n^2) (!!). You can + /// do better than I in 10mins. + pub(crate) fn make_safe_c_str(s: String) -> CString { + let mut v = s.into_bytes(); + loop { + match std::ffi::CString::new(v) { + Ok(c_str) => return c_str, + Err(err) => { + let nul_position = err.nul_position(); + v = err.into_vec(); + v[nul_position] = b' '; + } + } + } + } + + /// # Safety /// /// Expecting a valid pointer from input. @@ -87,13 +139,15 @@ macro_rules! extension { } #[no_mangle] - pub unsafe extern "C" fn method_def_drop(method: *mut c_char) { + pub unsafe extern "C" fn string_drop(method: *mut c_char) { let _ = CString::from_raw(method); } #[no_mangle] pub extern "C" fn extension_init() -> *const c_char { - fn safe_extension_init() -> String { + fn safe_extension_init() -> Result<$crate::serde_json::Value, String> { + $init_fn()?; + let manifest = $crate::serde_json::json!({ "metadata": { "name": env!("CARGO_PKG_NAME"), @@ -112,33 +166,40 @@ macro_rules! extension { "fn_get_len": "dump_get_len", "fn_drop": "dump_drop" }, + "string": { + "fn_drop": "string_drop" + }, "resources": {$( stringify!($ty): { "fn_from_bytes": stringify!($ty).to_string() + "_from_bytes", "fn_dump": stringify!($ty).to_string() + "_dump", "fn_size": stringify!($ty).to_string() + "_size", "fn_get_method_def": stringify!($ty).to_string() + "_get_method", - "fn_drop_method_def": "method_def_drop", "fn_drop": stringify!($ty).to_string() + "_drop" }, )*} }); - manifest.to_string() + Ok(manifest) } - std::panic::catch_unwind(|| { - // This leak will never be un-leaked. - let boxed = CString::new(safe_extension_init()) - .expect("json output shouldn't contain nul characters") - .into_boxed_c_str(); - let c_str = Box::leak(boxed); - c_str.as_ptr() - }) - .unwrap_or_else(|_| { - eprintln!("extension initialization panicked. See stderr"); - std::ptr::null() - }) + let outcome = std::panic::catch_unwind(|| { + match safe_extension_init() { + Ok(manifest) => manifest, + Err(err) => { + $crate::serde_json::json!({"error": err}) + } + } + }).unwrap_or_else(|_| { + $crate::serde_json::json!({ + "error": "extension initialization panicked. See stderr" + }) + }); + + match CString::new(outcome.to_string()) { + Ok(s) => s.into_raw(), + Err(_) => std::ptr::null(), + } } $( @@ -279,7 +340,7 @@ macro_rules! method { input_slots: u64, output_ptr: *mut u8, output_slots: u64, - ) -> *mut $crate::FnError { + ) -> *mut u8 { match std::panic::catch_unwind(|| { unsafe { // Safety: all this stuff came from jyafn code. The jyafn code should @@ -297,16 +358,14 @@ macro_rules! method { }) { Ok(Ok(())) => std::ptr::null_mut(), Ok(Err(err)) => { - let boxed = Box::new(err.to_string().into()); - Box::leak(boxed) + make_safe_c_str(err).into_raw() as *mut u8 } // DON'T forget the nul character when working with bytes directly! Err(_) => { - let boxed = Box::new(format!( + make_safe_c_str(format!( "method {:?} panicked. See stderr", stringify!($safe_interface), - ).into()); - Box::leak(boxed) + )).into_raw() as *mut u8 } } } diff --git a/jyafn-go/pkg/jyafn/ffi.go b/jyafn-go/pkg/jyafn/ffi.go index 7795268..97d9714 100644 --- a/jyafn-go/pkg/jyafn/ffi.go +++ b/jyafn-go/pkg/jyafn/ffi.go @@ -78,9 +78,12 @@ type ffiType struct { layoutIsDatetime func(LayoutPtr) bool layoutIsSymbol func(LayoutPtr) bool layoutIsStruct func(LayoutPtr) bool + layoutIsTuple func(LayoutPtr) bool layoutIsList func(LayoutPtr) bool layoutDatetimeFormat func(LayoutPtr) AllocatedStr layoutAsStruct func(LayoutPtr) StructPtr + layoutTupleSize func(LayoutPtr) uintptr + layoutGetTupleItem func(LayoutPtr) LayoutPtr layoutListElement func(LayoutPtr) LayoutPtr layoutListSize func(LayoutPtr) uintptr layoutIsSuperset func(LayoutPtr, LayoutPtr) bool @@ -185,9 +188,12 @@ func init() { register(&ffi.layoutIsDatetime, "layout_is_datetime") register(&ffi.layoutIsSymbol, "layout_is_symbol") register(&ffi.layoutIsStruct, "layout_is_struct") + register(&ffi.layoutIsTuple, "layout_is_tuple") register(&ffi.layoutIsList, "layout_is_list") register(&ffi.layoutDatetimeFormat, "layout_datetime_format") register(&ffi.layoutAsStruct, "layout_as_struct") + register(&ffi.layoutTupleSize, "layout_tuple_size") + register(&ffi.layoutGetTupleItem, "layout_get_tuple_item") register(&ffi.layoutListElement, "layout_list_element") register(&ffi.layoutListSize, "layout_list_size") register(&ffi.layoutIsSuperset, "layout_is_superset") diff --git a/jyafn-go/pkg/jyafn/layout.go b/jyafn-go/pkg/jyafn/layout.go index 71ab228..291b6e5 100644 --- a/jyafn-go/pkg/jyafn/layout.go +++ b/jyafn-go/pkg/jyafn/layout.go @@ -101,6 +101,10 @@ func (l Layout) IsStruct() bool { return ffi.layoutIsStruct(l.ptr) } +func (l Layout) IsTuple() bool { + return ffi.layoutIsTuple(l.ptr) +} + func (l Layout) IsList() bool { return ffi.layoutIsList(l.ptr) } @@ -109,6 +113,18 @@ func (l Layout) AsStruct() Struct { return Struct{ptr: ffi.layoutAsStruct(l.ptr)} } +func (l Layout) TupleSize() uint { + return uint(ffi.layoutTupleSize(l.ptr)) +} + +func (l Layout) GetItemLayout() Layout { + item := ffi.layoutGetTupleItem(l.ptr) + if item == 0 { + panic("called GetItemLayout on a Tuple out of bounds") + } + return Layout{ptr: item} +} + func (l Layout) DateTimeFormat() string { format := ffi.layoutDatetimeFormat(l.ptr) if format == 0 { diff --git a/jyafn-python/Cargo.toml b/jyafn-python/Cargo.toml index 6531331..e0cfce8 100644 --- a/jyafn-python/Cargo.toml +++ b/jyafn-python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "jyafn-python" -version = "0.1.2" +version = "0.2.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -10,10 +10,8 @@ name = "jyafn" crate-type = ["cdylib"] [dependencies] -byte-slice-cast = "1.2.2" chrono = "0.4.37" get-size = "0.1.4" jyafn = { path = "../jyafn" } pyo3 = { version = "0.22.0", features = ["extension-module"] } -semver = "1.0.23" serde_json = "1.0.115" diff --git a/jyafn-python/python/jyafn/__init__.py b/jyafn-python/python/jyafn/__init__.py index 628677a..aa7189d 100644 --- a/jyafn-python/python/jyafn/__init__.py +++ b/jyafn-python/python/jyafn/__init__.py @@ -1,5 +1,3 @@ -# type: ignore - from .jyafn import * import jyafn as fn @@ -15,6 +13,7 @@ from dataclasses import dataclass from .np_dropin import * +from .describe import describe # re-export __version__ = fn.__get_version() @@ -98,8 +97,10 @@ def make_layout(cls, args: tuple[Any, ...]) -> fn.Layout: match args: case (): return fn.Layout.datetime() - case (format,): + case (format,) if isinstance(format, str): return fn.Layout.datetime(format) + case (format,) if isinstance(format, bytes): + return fn.Layout.datetime(format.decode("utf8")) case _: raise TypeError(f"Invalid args for datetime annotation: {args}") @@ -273,7 +274,9 @@ def build(self) -> fn.Graph: for arg, param in type_hints.items() if arg != "return" } - _ret_from_annotation(self.original(**inputs), type_hints["return"]) + _ret_from_annotation( + self.original(**inputs), type_hints.get("return", inspect._empty) + ) for key, value in self.metadata.items(): g.set_metadata(str(key), str(value)) @@ -373,11 +376,23 @@ def inner(f: Any) -> fn.Function: return inner(args[0]) if len(args) == 1 else inner +ANONYMOUS_COUNTER: dict[str, int] = {} + + +def __anonymous_name(kind: str) -> str: + """Generates an anononymous name for a "kind" of thing.""" + num = ANONYMOUS_COUNTER.setdefault(kind, 0) + name = f"{kind}_{num}" + ANONYMOUS_COUNTER[kind] += 1 + + return name + + def mapping( - name: str, key_layout: fn.Layout | type[BaseAnnotation] | types.GenericAlias, value_layout: fn.Layout | type[BaseAnnotation] | types.GenericAlias, - obj: Any, + obj: Any = {}, + name: str | None = None, ) -> fn.LazyMapping: """ Creates a new key-value mapping to be used in a graph. Mappings in JYAFN work very @@ -389,6 +404,8 @@ def mapping( mapping will be marked as consumed and an exception will be raised on reuse. This is done to avoid errors stemming from already spent iterators. """ + if name is None: + name = __anonymous_name("mapping") return fn.LazyMapping(name, make_layout(key_layout), make_layout(value_layout), obj) @@ -490,7 +507,7 @@ def resource_type( def resource( - name: str, + name: str | None = None, *, data: bytes, type: str = "External", @@ -508,6 +525,39 @@ def resource( to make sure that they are installed in your environment (you can use the `jyafn get` CLI utility for managing extensions). """ + if name is None: + name = __anonymous_name("resource") return resource_type( type=type, extension=extension, resource=resource, **kwargs ).load(name, data) + + +# This needs to be down here because it redefines the name for `tuple`. +pytuple = tuple + + +class tuple(BaseAnnotation): + """Annotates the `tuple` layout.""" + + @classmethod + def make_layout(cls, args: tuple[Any, ...]) -> fn.Layout: + match args: + case fields if isinstance(fields, pytuple): + tup = [] + for field in fields: + match field: + case type(): + tup.append(field.make_layout(())) + case types.GenericAlias(): + tup.append( + typing.get_origin(field).make_layout( + typing.get_args(field) + ) + ) + case _: + raise TypeError( + f"Invalid arg for tuple field annotation: {field}" + ) + return fn.Layout.tuple_of(pytuple(tup)) + case _: + raise TypeError(f"Invalid args for tuple annotation: {args}") diff --git a/jyafn-python/python/jyafn/describe.py b/jyafn-python/python/jyafn/describe.py index ca1d07c..a88763a 100644 --- a/jyafn-python/python/jyafn/describe.py +++ b/jyafn-python/python/jyafn/describe.py @@ -1,4 +1,3 @@ -# type:ignore import jyafn as fn import sys @@ -87,6 +86,7 @@ def print(*args): def describe(thing: str | fn.Graph | fn.Function) -> str: + """Describes a graph, function or file as a nicely-formatted report.""" if isinstance(thing, str): return describe_fn(fn.read_fn(thing)) elif isinstance(thing, fn.Graph): diff --git a/jyafn-python/python/jyafn/jyafn.pyi b/jyafn-python/python/jyafn/jyafn.pyi index 2a68445..8330247 100644 --- a/jyafn-python/python/jyafn/jyafn.pyi +++ b/jyafn-python/python/jyafn/jyafn.pyi @@ -373,6 +373,10 @@ class Layout: """Whether this layout is of the flavor "datetime".""" def is_symbol(self) -> bool: """Whether this layout is of the flavor "symbol".""" + def is_struct(self) -> bool: + """Whether this layout is of the flavor "struct".""" + def is_tuple(self) -> bool: + """Whether this layout is of the flavor "tuple".""" def struct_keys(self) -> Optional[list[str]]: """ Returns the field names of this struct layout, if it is of flavor "struct", else @@ -405,6 +409,12 @@ class Layout: Returns a new layout of flavor "struct", with the fields given by the supplied Python dictionary. """ + @staticmethod + def tuple_of(fields: tuple[Layout, ...]) -> Layout: + """ + Returns a new layout of flavor "tuple", with the fields given by the supplied + Python dictionary. + """ def putative_layout(obj: Any) -> Layout: """ @@ -556,21 +566,23 @@ class Extension: ``` """ + @staticmethod + def list_loaded() -> dict[str, list[str]]: + """ + Lists the loaded extensions as the keys of a dictionary and all the loaded + versions as the elements of the values. + """ def __init__(self, name: str, version_req: str = "*") -> None: """Loads a new extension of name `name` with the given version requirements.""" - @property def name(self) -> str: """The name of this extension.""" - @property def version(self) -> str: """The version of this extension.""" - @property def resources(self) -> list[str]: """The resources that this extension provides.""" - def get(self, resource_name: str) -> ResourceType: """ Gets a resource that this extension declares. Throws an `IndexError` if the diff --git a/jyafn-python/src/extension.rs b/jyafn-python/src/extension.rs index 85dc302..d63e95b 100644 --- a/jyafn-python/src/extension.rs +++ b/jyafn-python/src/extension.rs @@ -1,5 +1,6 @@ use pyo3::exceptions::{PyIndexError, PyValueError}; use pyo3::prelude::*; +use std::collections::HashMap; use std::sync::Arc; use crate::resource::ResourceType; @@ -11,6 +12,14 @@ pub struct Extension(Arc); #[pymethods] impl Extension { + #[staticmethod] + fn list_loaded() -> HashMap> { + rust::extension::list() + .into_iter() + .map(|(name, versions)| (name, versions.into_iter().map(|v| v.to_string()).collect())) + .collect() + } + #[new] #[pyo3(signature = (name, version_req = "*"))] fn new(name: &str, version_req: &str) -> PyResult { diff --git a/jyafn-python/src/function.rs b/jyafn-python/src/function.rs index bb94450..b77805a 100644 --- a/jyafn-python/src/function.rs +++ b/jyafn-python/src/function.rs @@ -65,12 +65,12 @@ impl Function { #[getter] fn input_size(&self) -> usize { - self.inner().input_size() + self.inner().input_size().in_bytes() } #[getter] fn output_size(&self) -> usize { - self.inner().output_size() + self.inner().output_size().in_bytes() } #[getter] diff --git a/jyafn-python/src/graph/indexed.rs b/jyafn-python/src/graph/indexed.rs index 4a850fe..460b2df 100644 --- a/jyafn-python/src/graph/indexed.rs +++ b/jyafn-python/src/graph/indexed.rs @@ -28,10 +28,7 @@ impl IndexedList { }; let layout = first.putative_layout(); - if let Some(different) = depythonized - .iter() - .find(|v| v.putative_layout() != layout) - { + if let Some(different) = depythonized.iter().find(|v| v.putative_layout() != layout) { return Err(exceptions::PyTypeError::new_err(format!( "not all elements in list have the same layout. Expected {layout} and found \ {different}" diff --git a/jyafn-python/src/graph/mod.rs b/jyafn-python/src/graph/mod.rs index 7502b98..63cb7f4 100644 --- a/jyafn-python/src/graph/mod.rs +++ b/jyafn-python/src/graph/mod.rs @@ -199,8 +199,14 @@ impl Graph { .insert(key, value); } - fn render(&self) -> String { - self.0.lock().expect("poisoned").render().to_string() + fn render(&self) -> PyResult { + Ok(self + .0 + .lock() + .expect("poisoned") + .render() + .map_err(ToPyErr)? + .to_string()) } fn render_assembly(&self) -> PyResult { diff --git a/jyafn-python/src/layout.rs b/jyafn-python/src/layout.rs index 26ee04d..2987283 100644 --- a/jyafn-python/src/layout.rs +++ b/jyafn-python/src/layout.rs @@ -1,6 +1,7 @@ use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::PyDict; +use pyo3::types::PyTuple; use rust::layout::{Decoder, Encode, Layout as RustLayout, Sym, Visitor}; pub struct Obj<'py>(pub Bound<'py, PyAny>); @@ -69,6 +70,17 @@ impl<'py> Encode for Obj<'py> { Obj(item).visit(field, symbols, visitor)?; } } + RustLayout::Tuple(fields) => { + for (idx, field) in fields.iter().enumerate() { + let Ok(item) = self.0.get_item(idx) else { + return Err(exceptions::PyTypeError::new_err(format!( + "missing field {idx} in {}", + self.0 + ))); + }; + Obj(item).visit(field, symbols, visitor)?; + } + } RustLayout::List(element, size) => { let mut n_items = 0; for item in self.0.iter()? { @@ -131,6 +143,16 @@ impl<'py> Decoder for PyDecoder<'py> { dict.to_object(self.0) } + RustLayout::Tuple(fields) => { + let tuple = pyo3::types::PyTuple::new_bound( + self.0, + fields + .iter() + .map(|field| self.build(field, symbols, visitor)), + ); + + tuple.to_object(self.0) + } RustLayout::List(element, size) => pyo3::types::PyList::new_bound( self.0, (0..*size).map(|_| self.build(element, symbols, visitor)), @@ -192,6 +214,14 @@ impl Layout { matches!(&self.0, rust::layout::Layout::Symbol) } + fn is_struct(&self) -> bool { + matches!(&self.0, rust::layout::Layout::Struct(_)) + } + + fn is_tuple(&self) -> bool { + matches!(&self.0, rust::layout::Layout::Tuple(_)) + } + fn struct_keys(&self, py: Python) -> PyResult { let rust::layout::Layout::Struct(s) = &self.0 else { return Ok(pyo3::types::PyNone::get_bound(py).to_object(py)); @@ -249,4 +279,14 @@ impl Layout { fields, )))) } + + #[staticmethod] + fn tuple_of(fields: &Bound<'_, PyTuple>) -> PyResult { + let fields = fields + .iter() + .map(|value| Ok(value.extract::()?.0)) + .collect::>>()?; + + Ok(Layout(rust::layout::Layout::Tuple(fields))) + } } diff --git a/jyafn-python/src/lib.rs b/jyafn-python/src/lib.rs index c8373af..49149fd 100644 --- a/jyafn-python/src/lib.rs +++ b/jyafn-python/src/lib.rs @@ -132,6 +132,16 @@ fn pythonize_ref_value(py: Python, val: rust::layout::RefValue) -> PyResult { + let tuple = PyTuple::new_bound( + py, + fields + .into_iter() + .map(|val| pythonize_ref_value(py, val)) + .collect::, _>>()?, + ); + tuple.unbind().into() + } rust::layout::RefValue::List(l) => PyTuple::new_bound( py, l.into_iter() @@ -180,7 +190,7 @@ fn depythonize_ref_value( .iter() .map(|val| depythonize_inner(g, &val)) .collect::>>()?; - return Ok(rust::layout::RefValue::List(vals)); + return Ok(rust::layout::RefValue::Tuple(vals)); } if let Ok(scalar) = const_from_py(g, obj) { diff --git a/jyafn-python/tests/__main__.py b/jyafn-python/tests/__main__.py new file mode 100644 index 0000000..ec71af5 --- /dev/null +++ b/jyafn-python/tests/__main__.py @@ -0,0 +1,11 @@ +import os +import sys + +if __name__ == "__main__": + basedir = os.path.dirname(__file__) + for path in sorted(os.listdir(basedir)): + if path.endswith(".py") and path != "__main__.py": + status = os.system(f"python {basedir}/{path}") + if status: + print(f"error: {path} test failed") + exit(1) diff --git a/jyafn-python/tests/assert.py b/jyafn-python/tests/assert.py index ec8925f..2e259f1 100644 --- a/jyafn-python/tests/assert.py +++ b/jyafn-python/tests/assert.py @@ -1,11 +1,16 @@ import jyafn as fn +import traceback +try: -@fn.func(debug=True) -def asserts(x: fn.scalar) -> None: - fn.assert_(x > 0.0, "x must be positive") + @fn.func(debug=True) + def asserts(x: fn.scalar) -> None: + fn.assert_(x > 0.0, "x must be positive") - -print(asserts.get_graph().to_json()) -print(asserts(1.0)) -print(asserts(-1.0)) + print(asserts.get_graph().to_json()) + print(asserts(1.0)) + print(asserts(-1.0)) +except Exception: + traceback.print_exc() +else: + raise Exception("should raise") diff --git a/jyafn-python/tests/dataset.py b/jyafn-python/tests/dataset.py deleted file mode 100644 index 8a691e9..0000000 --- a/jyafn-python/tests/dataset.py +++ /dev/null @@ -1,28 +0,0 @@ -import jyafn as fn - -if __name__ == "__main__": - - @fn.func - def a_fun(a: fn.scalar, b: fn.scalar) -> fn.scalar: - return 2.0 * a + b + 1.0 - - data = fn.Dataset.build( - a_fun.input_layout, - [ - {"a": 3, "b": 1}, - {"a": 2, "b": 2}, - {"a": 1, "b": 3}, - ], - ) - print(data) - print(data.decode()) - - mapped = data.map(a_fun) - print(mapped) - print(mapped.decode()) - - par_mapped = data.par_map(a_fun) - print(par_mapped) - print(par_mapped.decode()) - - assert mapped.decode() == par_mapped.decode() diff --git a/jyafn-python/tests/datetime_.py b/jyafn-python/tests/datetime_.py index 29edae1..4f62dc8 100644 --- a/jyafn-python/tests/datetime_.py +++ b/jyafn-python/tests/datetime_.py @@ -4,7 +4,7 @@ @fn.func -def make_date(dt: fn.datetime) -> fn.datetime["%Y-%m-%d"]: +def make_date(dt: fn.datetime) -> fn.datetime[b"%Y-%m-%d"]: return dt diff --git a/jyafn-python/tests/describe.py b/jyafn-python/tests/describe.py index 0be415f..4df9fc3 100644 --- a/jyafn-python/tests/describe.py +++ b/jyafn-python/tests/describe.py @@ -1,6 +1,7 @@ import jyafn as fn +import traceback -fn_file = "data/vbt.jyafn" +fn_file = "data/a_fun.jyafn" func = fn.read_fn(fn_file) graph = func.get_graph() @@ -9,4 +10,10 @@ fn.describe(func) fn.describe(graph) fn.describe(fn_file) -fn.describe(None) + +try: + fn.describe(None) +except TypeError: + traceback.print_exc() +else: + raise Exception("should raise") diff --git a/jyafn-python/tests/eval_const.py b/jyafn-python/tests/eval_const.py index 0eea846..2e0051b 100644 --- a/jyafn-python/tests/eval_const.py +++ b/jyafn-python/tests/eval_const.py @@ -5,6 +5,5 @@ def func(a: fn.scalar) -> fn.scalar: return fn.const(True).choose(fn.exp(a + 0.0) * 1.0, -1e-100) - print(func.get_graph().render()) print(func(1.0)) diff --git a/jyafn-python/tests/extensions.py b/jyafn-python/tests/extensions.py index 5e8d082..79f6383 100644 --- a/jyafn-python/tests/extensions.py +++ b/jyafn-python/tests/extensions.py @@ -5,6 +5,8 @@ import jyafn as fn import traceback +print("Before func creation...") + @fn.func def with_resources(x: fn.scalar) -> fn.scalar: @@ -19,6 +21,7 @@ def with_resources(x: fn.scalar) -> fn.scalar: return the_result +print("Starting call...") print(with_resources.get_graph().render()) assert with_resources(2.5) == 1.0 @@ -81,3 +84,5 @@ def with_resources(x: fn.scalar) -> fn.scalar: serialized = with_resources.write("with_resources.jyafn") deserialized = fn.read_fn("with_resources.jyafn") assert deserialized(2.5) == 1.0 + +print(fn.Extension.list_loaded()) diff --git a/jyafn-python/tests/illegal.py b/jyafn-python/tests/illegal.py index 62ff2e5..0a77c46 100644 --- a/jyafn-python/tests/illegal.py +++ b/jyafn-python/tests/illegal.py @@ -1,6 +1,11 @@ import jyafn as fn +import traceback - -@fn.func -def illegal(x: fn.scalar) -> None: - fn.assert_(2.0 + 2.0 == 5.0, "ingsoc") +try: + @fn.func + def illegal(x: fn.scalar) -> None: + fn.assert_(2.0 + 2.0 == 5.0, "ingsoc") +except Exception: + traceback.print_exc() +else: + raise Exception("should fail") diff --git a/jyafn-python/tests/logic.py b/jyafn-python/tests/logic.py index 4aae3ac..e0eefc0 100644 --- a/jyafn-python/tests/logic.py +++ b/jyafn-python/tests/logic.py @@ -1,4 +1,5 @@ import jyafn as fn +import traceback @fn.func @@ -13,13 +14,17 @@ def relu(a: fn.scalar) -> fn.scalar: for c in cases: print(f"relu({c}) = {relu(c)}") - -@fn.func -def should_fail(a: fn.scalar) -> fn.scalar: - if a.to_bool(): - return 0.0 - else: - return 1.0 +try: + @fn.func + def should_fail(a: fn.scalar) -> fn.scalar: + if a.to_bool(): + return 0.0 + else: + return 1.0 -print(f"should_fail({1.0}) = {should_fail(1.0)}") + print(f"should_fail({1.0}) = {should_fail(1.0)}") +except Exception: + traceback.print_exc() +else: + raise Exception("should fail") diff --git a/jyafn-python/tests/mapping.py b/jyafn-python/tests/mapping.py index d20620d..25c16b2 100644 --- a/jyafn-python/tests/mapping.py +++ b/jyafn-python/tests/mapping.py @@ -1,6 +1,6 @@ import jyafn as fn -silly_map = fn.mapping("silly", fn.symbol, fn.scalar, {"a": 2, "b": 4}) +silly_map = fn.mapping(fn.symbol, fn.scalar, {"a": 2, "b": 4}) @fn.func diff --git a/jyafn-python/tests/serde.py b/jyafn-python/tests/serde.py index 28f9102..f28fce1 100644 --- a/jyafn-python/tests/serde.py +++ b/jyafn-python/tests/serde.py @@ -8,7 +8,7 @@ def a_fun(a: fn.scalar, b: fn.scalar, c: fn.symbol) -> fn.scalar: print(a_fun.to_json()) -a_fun.write("a_fun.jyafn") +a_fun.write("data/a_fun.jyafn") -other_fun = fn.read_fn("a_fun.jyafn") +other_fun = fn.read_fn("data/a_fun.jyafn") print(a_fun(5, 6, "a"), other_fun(5, 6, "a")) diff --git a/jyafn-python/tests/showcase.py b/jyafn-python/tests/showcase.py deleted file mode 100644 index e37855c..0000000 --- a/jyafn-python/tests/showcase.py +++ /dev/null @@ -1,43 +0,0 @@ -import json -from time import time - -import boto3 -import jyafn as fn -import numpy as np -import ppca_rs - -s3 = boto3.client("s3") - -latest_meta = json.load( - s3.get_object( - Bucket="fh-ca-data", Key="cheapest_providers/prod/predictions/latest_meta.json" - )["Body"] -) -model = ppca_rs.PPCAModel.load( - s3.get_object( - Bucket="fh-ca-data", - Key=f"cheapest_providers/prod/model/{latest_meta['model_id']}.bincode", - )["Body"].read() -) - -tic = time() - - -@fn.func -def from_components( - comps: fn.tensor[model.state_size], -) -> fn.tensor[model.output_size]: - total = comps @ model.transform.T + model.mean - return total - - -toc = time() -print(f"Took {toc-tic}s") - -with open("from_components.ssa", "w") as f: - f.write(from_components.get_graph().render()) - -with open("from_components.s", "w") as f: - f.write(from_components.get_graph().render_assembly()) - -from_components.write("from_components.jyafn") diff --git a/jyafn-python/tests/subgraph.py b/jyafn-python/tests/subgraph.py index 5332ee0..3f2cd15 100644 --- a/jyafn-python/tests/subgraph.py +++ b/jyafn-python/tests/subgraph.py @@ -6,9 +6,10 @@ def simple(a: fn.scalar, b: fn.scalar): return 2.0 * a + b -@fn.func(debug=True) +@fn.func() def call_simple(a: fn.scalar, b: fn.scalar): return simple(a, b) +print(simple.build().render()) assert call_simple(2.0, 3.0) == 7.0 diff --git a/jyafn-python/tests/tuples.py b/jyafn-python/tests/tuples.py new file mode 100644 index 0000000..6a0f731 --- /dev/null +++ b/jyafn-python/tests/tuples.py @@ -0,0 +1,9 @@ +import jyafn as fn + + +@fn.func +def tuples(tup: fn.tuple[fn.scalar, fn.scalar]) -> fn.tuple[fn.scalar, fn.scalar]: + return tup[0] + tup[1], tup[0] - tup[1] + + +assert tuples((1.0, 3.0)) == (4.0, -2.0) diff --git a/jyafn/Cargo.toml b/jyafn/Cargo.toml index 789f696..ec0a468 100644 --- a/jyafn/Cargo.toml +++ b/jyafn/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "jyafn" -version = "0.1.2" +version = "0.2.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -23,9 +23,6 @@ hashbrown = { version = "0.14.3", features = ["serde", "raw"] } home = "0.5.9" lazy_static = "1.4.0" libloading = "0.8.4" -maplit = "1.0.2" -memmap = "0.7.0" -no-panic = "0.1.29" qbe = { path = "../vendored/qbe-rs" } rand = "0.8.5" scopeguard = "1.2.0" @@ -34,7 +31,6 @@ serde = { version = "1.0.197", features = ["rc"] } serde_derive = "1.0.197" serde_json = "1.0.115" serde_with = "3.9.0" -smallvec = { version = "1.13.2", features = ["serde"] } special-fun = "0.3.0" tempfile = "3.10.1" thiserror = "1.0.58" diff --git a/jyafn/src/const.rs b/jyafn/src/const.rs index 396fd14..a21e36b 100644 --- a/jyafn/src/const.rs +++ b/jyafn/src/const.rs @@ -1,10 +1,10 @@ //! Constant values in the computational graph. Constants need to have a type and a binary //! representation as a 64-bit peice of data. -use super::Type; - use std::fmt::Debug; +use super::Type; + /// A constant. Constants need to have a type and a binary representation as a 64-bit /// peice of data. #[typetag::serde(tag = "type")] diff --git a/jyafn/src/extension.rs b/jyafn/src/extension.rs index 7c0127e..758cb93 100644 --- a/jyafn/src/extension.rs +++ b/jyafn/src/extension.rs @@ -33,6 +33,17 @@ pub struct RawResource(pub(crate) *mut ()); #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Dumped(pub(crate) *mut ()); +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum LoadOutcome { + // Extension loading failed + Failed { + /// The reason why it failed. + error: String, + }, + Loaded(ExtensionManifest), +} + /// This is the data format, returned as a C-style string from the `extension_init` /// initialization function. This describes which symbols to be used by each resource in /// this extension. @@ -44,6 +55,9 @@ pub struct ExtensionManifest { pub outcome: OutcomeManifest, /// Describes the symbols to be used when accessing buffers of binary memory. pub dumped: DumpedManifest, + /// Describes the symbols to be used when accessing C-style strings generated by this + /// library. + pub string: StringManifest, /// Describes the symbols to be used when interfacing with each resource type provided /// by this extension. pub resources: HashMap, @@ -80,6 +94,14 @@ pub struct DumpedManifest { pub fn_drop: String, } +/// Lists the name of the symbols needed to create the interface between C-style strings +/// from the extension and jyafn. See [`StringSymbols`] for detailed information on the +/// contract for each symbol. +#[derive(Debug, Serialize, Deserialize)] +pub struct StringManifest { + pub fn_drop: String, +} + /// Lists the names of the symbols needed to create the interface between a resource and /// jyafn. See [`ResourceSymbols`] for detailed information on the contract for each /// symbol. @@ -89,14 +111,18 @@ pub struct ResourceManifest { pub fn_dump: String, pub fn_size: String, pub fn_get_method_def: String, - pub fn_drop_method_def: String, pub fn_drop: String, } +/// A declaration of an external method. Gives a raw function pointer to be called in +/// jyafn code accompanied by the layout of the input and the output of the function. #[derive(Debug, Serialize, Deserialize)] pub struct ExternalMethod { + /// The function pointer to be called in jyafn code. pub fn_ptr: usize, + /// The input layout of the given function. pub input_layout: Struct, + /// Output layout of the given function. pub output_layout: Layout, } @@ -177,6 +203,33 @@ impl DumpedSymbols { } } +/// Lists the name of the symbols needed to create the interface between C-style strings +/// from the extension and jyafn. +#[derive(Debug, Clone)] +pub struct StringSymbols { + /// Drops any allocated memory created for the C-style string pointed by the input. + /// This will be called only on pointers generated from within the extension (returned + /// strings) and will be called only once per pointer. + pub fn_drop: unsafe extern "C" fn(*mut c_char), +} + +impl StringSymbols { + /// Loads the resource symbols from the supplied library, given a manifest. + unsafe fn load(library: &Library, manifest: &StringManifest) -> Result { + /// For building structs that are symbol tables. + macro_rules! symbol { + ($($sym:ident),*) => { Self {$( + $sym: get_symbol(library, &manifest.$sym).context( + concat!("getting symbol for ", stringify!($sym) + ) + )?, + )*}} + } + + Ok(symbol!(fn_drop)) + } +} + /// Lists the names of the symbols needed to create the interface between a resource and /// jyafn. #[derive(Debug, Clone)] @@ -192,9 +245,6 @@ pub struct ResourceSymbols { /// C-style strings, returns the JSON representation of an [`ExternalMethod`] as a /// C-style string. pub fn_get_method_def: unsafe extern "C" fn(RawResource, *const c_char) -> *mut c_char, - /// Drops any allocated memory created for this given method definition. Will be - /// called only once per method definiton created by `fn_get_method_def`. - pub fn_drop_method_def: unsafe extern "C" fn(*mut c_char), /// Drops any allocation memory created for this resource. This will be called only /// once per resource and, after this call, no more calls are expected on the given /// resource. @@ -222,7 +272,6 @@ impl ResourceSymbols { fn_dump, fn_size, fn_get_method_def, - fn_drop_method_def, fn_drop )) } @@ -248,6 +297,9 @@ pub struct Extension { outcome: OutcomeSymbols, /// Describes the symbols to be used when accessing buffers of binary memory. dumped: DumpedSymbols, + /// Describes the symbols to be used when accessing C-style strings generated by this + /// library. + pub string: StringSymbols, /// Describes the symbols to be used when interfacing with each resource type provided /// by this extension. resources: HashMap, @@ -260,15 +312,32 @@ impl Extension { unsafe { // Safety: we can only pray nobody loads anything funny here. However, it's // not my responsibilty what kind of crap you install in your computer. - let library = Library::new(path)?; - let extension_init: Symbol = library.get(EXTENSION_INIT_SYMBOL)?; + println!("starting to load {path:?}"); + let library = Library::new(path).inspect_err(|err| println!("oops:{err}"))?; + println!("aisndoa"); + let extension_init: Symbol = library + .get(EXTENSION_INIT_SYMBOL) + .inspect_err(|err| println!("oops:{err}"))?; + println!("init is not null"); let outcome = extension_init(); + println!("init is not null"); if outcome.is_null() { return Err(format!("library {path:?} failed to load").into()); } - let manifest: ExtensionManifest = - serde_json::from_slice(CStr::from_ptr(outcome).to_bytes()) - .map_err(|err| err.to_string())?; + let parsed: LoadOutcome = serde_json::from_slice(CStr::from_ptr(outcome).to_bytes()) + .map_err(|err| err.to_string())?; + + let manifest = match parsed { + LoadOutcome::Loaded(manifest) => manifest, + LoadOutcome::Failed { error } => return Err(error.into()), + }; + + let string = StringSymbols::load(&library, &manifest.string) + .with_context(|| format!("loading `string` symbols from {path:?}"))?; + let fn_drop = string.fn_drop; + scopeguard::defer! { + (fn_drop)(outcome as *mut i8); + } let outcome = OutcomeSymbols::load(&library, &manifest.outcome) .with_context(|| format!("loading `outcome` symbols from {path:?}"))?; @@ -293,6 +362,7 @@ impl Extension { metadata: manifest.metadata, outcome, dumped, + string, resources, }) } @@ -476,21 +546,27 @@ pub fn try_get(name: &str, version_req: &semver::VersionReq) -> Result Option> { -// let lock = EXTENSIONS.read().expect("poisoned"); -// let loaded_extensions = lock.get(name)?; - -// for (version, extension) in loaded_extensions { -// if version_req.matches(version) { -// return Some(extension.clone()); -// } -// } - -// None -// } - /// Gets an extension by its name, panicking if it was not loaded. pub fn get(name: &str, version_req: &semver::VersionReq) -> Arc { try_get(name, version_req).expect("extension not loaded") } + +/// Lists the names and versions of all currently loaded extensions. +pub fn list() -> HashMap> { + EXTENSIONS + .read() + .expect("poisoned") + .iter() + .map(|(name, versions)| (name.clone(), versions.keys().cloned().collect::>())) + .collect() +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_load_extension() { + get("dummy", &"*".parse().unwrap()); + } +} diff --git a/jyafn/src/function.rs b/jyafn/src/function.rs index 5421817..31c489f 100644 --- a/jyafn/src/function.rs +++ b/jyafn/src/function.rs @@ -1,6 +1,7 @@ use get_size::GetSize; use libloading::Library; -use std::ffi::{c_char, CStr}; +use std::borrow::Cow; +use std::ffi::{c_char, CStr, CString}; use std::{ cell::RefCell, fmt::Debug, @@ -10,30 +11,39 @@ use std::{ use tempfile::NamedTempFile; use thread_local::ThreadLocal; -use super::{layout, Error, Graph, Type}; +use crate::size::Size; + +use super::{layout, Error, Graph}; /// The error type returned from the compiled function. If you need to create a new error /// from your code, use `String::into`. -pub struct FnError(Option); +pub struct FnError(Option>); impl FnError { /// Takes the underlying error message from this error. Calling this method more than /// once will result in a panic. - pub fn take(&mut self) -> String { + pub fn take(&mut self) -> Cow<'static, CStr> { self.0.take().expect("can only call take once") } /// This is used from inside jyafn to create an error from static C-style error /// messages. pub(crate) unsafe extern "C" fn make_static(s: *const c_char) -> *mut FnError { - let boxed = Box::new(Self(Some(CStr::from_ptr(s).to_string_lossy().to_string()))); + let boxed = Box::new(Self(Some(Cow::Borrowed(CStr::from_ptr(s))))); + Box::leak(boxed) + } + + /// This is used from inside jyafn to create an error from static C-style error + /// messages. + pub(crate) unsafe extern "C" fn make_allocated(s: *mut c_char) -> *mut FnError { + let boxed = Box::new(Self(Some(Cow::Owned(CString::from_raw(s))))); Box::leak(boxed) } } impl From for FnError { fn from(s: String) -> FnError { - FnError(Some(s)) + FnError(Some(Cow::Owned(crate::utils::make_safe_c_str(s)))) } } @@ -48,8 +58,8 @@ pub struct FunctionData { library_len: u64, input_layout: layout::Layout, output_layout: layout::Layout, - input_size: usize, - output_size: usize, + input_size: Size, + output_size: Size, fn_ptr: RawFn, input: ThreadLocal>, output: ThreadLocal>, @@ -95,12 +105,12 @@ impl<'a> From<&'a Function> for Arc { impl Function { /// The size of the input of this function, in bytes. - pub fn input_size(&self) -> usize { + pub fn input_size(&self) -> Size { self.data.input_size } /// The size of the output of this function, in bytes. - pub fn output_size(&self) -> usize { + pub fn output_size(&self) -> Size { self.data.output_size } @@ -160,9 +170,9 @@ impl Function { let mut data = FunctionData { _library: library, library_len: std::fs::metadata(shared_object.path())?.len(), - input_size: input_size_in_floats * Type::Float.size(), + input_size: input_size_in_floats, input_layout: input_layout.into(), - output_size: output_size_in_floats * Type::Float.size(), + output_size: output_size_in_floats, output_layout, fn_ptr, graph, @@ -195,8 +205,8 @@ impl Function { let input = input.as_ref(); let output = output.as_mut(); - assert_eq!(self.data.input_size, input.len()); - assert_eq!(self.data.output_size, output.len()); + assert_eq!(self.data.input_size.in_bytes(), input.len()); + assert_eq!(self.data.output_size.in_bytes(), output.len()); // Safety: input and output sizes are checked and function pinky-promisses not to // accesses anything out of bounds. @@ -213,7 +223,7 @@ impl Function { where I: AsRef<[u8]>, { - let mut output = vec![0; self.data.output_size].into_boxed_slice(); + let mut output = vec![0; self.data.output_size.in_bytes()].into_boxed_slice(); let status = self.call_raw(input, &mut output); if status.is_null() { Ok(output) @@ -237,11 +247,11 @@ impl Function { let local_input = self .data .input - .get_or(|| RefCell::new(layout::Visitor::new(self.data.input_size / 8))); + .get_or(|| RefCell::new(layout::Visitor::new(self.data.input_size))); let local_output = self .data .output - .get_or(|| RefCell::new(layout::Visitor::new(self.data.output_size / 8))); + .get_or(|| RefCell::new(layout::Visitor::new(self.data.output_size))); let mut encode_visitor = local_input.borrow_mut(); encode_visitor.reset(); let mut decode_visitor = local_output.borrow_mut(); diff --git a/jyafn/src/graph/compile/mod.rs b/jyafn/src/graph/compile/mod.rs index 2d863c8..a2dbc05 100644 --- a/jyafn/src/graph/compile/mod.rs +++ b/jyafn/src/graph/compile/mod.rs @@ -9,23 +9,51 @@ use tempfile::NamedTempFile; use crate::Function; -use super::{Error, Graph, Node}; +use super::{Error, Graph, Node, SLOT_SIZE}; impl Graph { + /// Renders this graph as a QBE module. This fails if the graph contains illegal + /// operations that cannot be optimized away (e.g., unconditional errors). + pub fn render(&self) -> Result, Error> { + let mut module = qbe::Module::new(); + let mut graph = self.clone(); + graph.do_check_optimize()?; + graph.do_render(&mut module, "run"); + + Ok(module) + } + + /// Finds illegal instructions in graphs. fn find_illegal(&self) -> Option<&Node> { self.nodes .iter() .find(|node| node.op.is_illegal(&node.args)) } - /// Renders this graph as a QBE module. - pub fn render(&self) -> qbe::Module<'static> { - let mut module = qbe::Module::new(); - self.clone().do_render(&mut module, "run"); - module + /// Performs optimizations in the current graph. These optimizations currently are, + /// in this order: + /// 1. Constant evaluation: things like `1 * x` or `2 + 2`, which we already know the + /// result beforehand. + /// 2. Reachability eliminations: remove nodes that will never be computed. + /// 3. Finds illegal instructions that remain: thigs that are not allowed, such as + /// unconditionally failing assertions. + fn do_check_optimize(&mut self) -> Result<(), Error> { + // Constant evaluation: + optimize::const_eval(self); + + // Reachability (needs to be after const eval): + let reachable = optimize::find_reachable(&self.outputs, &self.nodes); + optimize::remap_reachable(self, &reachable); + + // Find illegal (needs to be after reachability): + if let Some(node) = self.find_illegal() { + return Err(Error::IllegalInstruction(format!("{node:?}"))); + } + + Ok(()) } - fn do_render(&mut self, module: &mut qbe::Module<'static>, namespace: &str) { + fn do_render(&self, module: &mut qbe::Module<'static>, namespace: &str) { // Rendering main: let main = module.add_function(qbe::Function::new( qbe::Linkage::public(), @@ -48,12 +76,11 @@ impl Graph { qbe::Value::Temporary("in".to_string()), qbe::Type::Long, qbe::Instr::Add( - qbe::Value::Const(input.size() as u64), + qbe::Value::Const(SLOT_SIZE.in_bytes() as u64), qbe::Value::Temporary("in".to_string()), ), ); } - // This is the old naive implementation, kept here in case you need a quick // rollback... // // Supposes that the nodes were already declared in topological order: @@ -64,11 +91,8 @@ impl Graph { // } // } - // This is the fancier implementation, that passes the right nodes to the inside - // of conditionals. - optimize::const_eval(self); - let reachable = optimize::find_reachable(&self.outputs, &self.nodes); - optimize::Statements::build(&self.nodes).render_into(self, &reachable, main, namespace); + // optimize::Statements::build(&self.nodes).render_into(self, &reachable, main, namespace); + optimize::Statements::build(&self.nodes).render_into(self, main, namespace); for output in &self.outputs { main.add_instr(qbe::Instr::Store( @@ -80,7 +104,7 @@ impl Graph { qbe::Value::Temporary("out".to_string()), qbe::Type::Long, qbe::Instr::Add( - qbe::Value::Const(self.type_of(*output).size() as u64), + qbe::Value::Const(SLOT_SIZE.in_bytes() as u64), qbe::Value::Temporary("out".to_string()), ), ); @@ -107,7 +131,7 @@ impl Graph { } // Render sub-graphs: - for (i, subgraph) in self.subgraphs.iter_mut().enumerate() { + for (i, subgraph) in self.subgraphs.iter().enumerate() { subgraph.do_render(module, &format!("{namespace}.graph.{i}")) } } @@ -134,22 +158,14 @@ impl Graph { /// Renders this graph as assembly code for the current machine's architecture, /// using a standard assembler under the hood. pub fn render_assembly(&self) -> Result { - let rendered = self.render(); + let rendered = self.render()?; create_assembly(rendered) } /// Compiles this graph to machine code and loads the resulting shared object into /// the current process. pub fn compile(&self) -> Result { - let mut graph = self.clone(); - let mut module = qbe::Module::new(); - graph.do_render(&mut module, "run"); - - if let Some(node) = graph.find_illegal() { - return Err(Error::IllegalInstruction(format!("{node:?}"))); - } - - let assembly = create_assembly(module)?; + let assembly = self.render_assembly()?; let unlinked = assemble(&assembly)?; let shared_object = link(&unlinked)?; diff --git a/jyafn/src/graph/compile/optimize.rs b/jyafn/src/graph/compile/optimize.rs index fce2acd..e540b90 100644 --- a/jyafn/src/graph/compile/optimize.rs +++ b/jyafn/src/graph/compile/optimize.rs @@ -1,6 +1,6 @@ //! Graph optimizations (those not covered by qbe). -use std::collections::BTreeSet; +use std::collections::{BTreeSet, BTreeMap}; use crate::{Graph, Node, Ref}; @@ -12,6 +12,7 @@ pub fn find_reachable(outputs: &[Ref], nodes: &[Node]) -> Vec { let mut stack = outputs .iter() .filter_map(|r| { + // All output nodes. if let &Ref::Node(node_id) = r { Some(node_id) } else { @@ -43,6 +44,43 @@ pub fn find_reachable(outputs: &[Ref], nodes: &[Node]) -> Vec { reachable } +/// Remaps the nodes of this graph to exclude unreachable nodes. +pub fn remap_reachable(graph: &mut Graph, reachable: &[bool]) { + // Create new ids: + let id_map = reachable + .iter() + .enumerate() + .filter(|(_, &is_reachable)| is_reachable) + .map(|(old_id, _)| old_id) + .enumerate() + .map(|(new_id, old_id)| (old_id, new_id)) + .collect::>(); + + // Retain only reachable nodes: + let mut node_id = 0; + graph.nodes.retain(|_| { + let retain = id_map.contains_key(&node_id); + node_id += 1; + retain + }); + + // Rewrite references in nodes: + for node in &mut graph.nodes { + for arg in &mut node.args { + if let Ref::Node(id) = arg { + *id = id_map[id]; + } + } + } + + // Rewrite references in output: + for output in &mut graph.outputs { + if let Ref::Node(id) = output { + *id = id_map[id]; + } + } +} + /// Runs constant evaluation optimization on the graph. pub fn const_eval(graph: &mut Graph) { let mut visited = vec![false; graph.nodes.len()]; @@ -275,13 +313,12 @@ impl Statements { pub fn render_into( &self, graph: &Graph, - reachable: &[bool], func: &mut qbe::Function, namespace: &str, ) { for statement in &self.0 { match statement { - &StatementOrConditional::Statement(node_id) if reachable[node_id] => { + &StatementOrConditional::Statement(node_id) => { let node = &graph.nodes[node_id]; node.op.render_into( graph, @@ -310,7 +347,7 @@ impl Statements { )); func.add_block(true_label); - true_side.render_into(graph, reachable, func, namespace); + true_side.render_into(graph, func, namespace); func.assign_instr( output.clone(), node.ty.render(), @@ -319,7 +356,7 @@ impl Statements { func.add_instr(qbe::Instr::Jmp(end_label.clone())); func.add_block(false_label); - false_side.render_into(graph, reachable, func, namespace); + false_side.render_into(graph, func, namespace); func.assign_instr( output, node.ty.render(), @@ -328,7 +365,6 @@ impl Statements { func.add_block(end_label); } - _ => {} } } } diff --git a/jyafn/src/graph/mod.rs b/jyafn/src/graph/mod.rs index e9260cc..46ba4b9 100644 --- a/jyafn/src/graph/mod.rs +++ b/jyafn/src/graph/mod.rs @@ -2,8 +2,12 @@ mod check; mod compile; mod node; mod serde; +mod r#type; -pub use node::{Node, Ref, Type}; +pub mod size; + +pub use node::{Node, Ref}; +pub use r#type::{Type, SLOT_SIZE}; use get_size::GetSize; use serde_derive::{Deserialize, Serialize}; @@ -171,6 +175,9 @@ impl Graph { .map(|(name, field)| (name.clone(), self.alloc_input(field))) .collect(), ), + Layout::Tuple(fields) => { + RefValue::Tuple(fields.iter().map(|field| self.alloc_input(field)).collect()) + } Layout::List(element, size) => { RefValue::List((0..*size).map(|_| self.alloc_input(element)).collect()) } diff --git a/jyafn/src/graph/node.rs b/jyafn/src/graph/node.rs index 6962b4a..8bea357 100644 --- a/jyafn/src/graph/node.rs +++ b/jyafn/src/graph/node.rs @@ -6,88 +6,7 @@ use std::fmt::{self, Display}; use crate::{Error, Op}; use super::Graph; - -/// The primitive types of data that can be represented in the computational graph. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, GetSize)] -#[repr(u8)] -pub enum Type { - /// A floating point number. - Float, - /// A boolean. - Bool, - /// An _id_ referencing a piece of imutable text "somewhere". - Symbol, - /// A pointer, with an origin node id. Pointers _cannot_ appear in the public - /// interface of a graph. - Ptr { origin: usize }, - /// An integer timestamp in microseconds. - DateTime, -} - -impl TryFrom for Type { - type Error = Error; - - fn try_from(v: u8) -> Result { - match v { - 0 => Ok(Type::Float), - 1 => Ok(Type::Bool), - 2 => Ok(Type::Symbol), - 3 => Ok(Type::Ptr { origin: usize::MAX }), - 4 => Ok(Type::DateTime), - _ => Err(format!("{v} is not a valid type id"))?, - } - } -} - -impl Display for Type { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Type::Float => write!(f, "scalar"), - Type::Bool => write!(f, "bool"), - Type::Symbol => write!(f, "symbol"), - Type::Ptr { origin } => write!(f, "ptr@{origin}"), - Type::DateTime => write!(f, "datetime"), - } - } -} - -/// All slots in jyafn are 64 bits long. -pub const SIZE: usize = 8; - -impl Type { - pub(crate) fn render(self) -> qbe::Type<'static> { - match self { - Type::Float => qbe::Type::Double, - Type::Bool => qbe::Type::Long, - Type::Symbol => qbe::Type::Long, - Type::Ptr { .. } => qbe::Type::Long, - Type::DateTime => qbe::Type::Long, - } - } - - /// All types in jyafn are 64 bits long. This function returns a constant. - pub fn size(&self) -> usize { - SIZE - } - - fn print(self, val: u64) -> String { - match self { - Type::Float => format!("{}", f64::from_ne_bytes(val.to_ne_bytes())), - Type::Bool => format!("{}", val == 1), - Type::Symbol => format!("{val}"), - Type::Ptr { .. } => format!("{val:#x}"), - Type::DateTime => { - if let Some(date) = - chrono::DateTime::::from_timestamp_micros(val as i64) - { - format!("{date}",) - } else { - "".to_string() - } - } - } - } -} +use super::Type; /// A reference to a value in a graph. #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, GetSize)] diff --git a/jyafn/src/graph/size.rs b/jyafn/src/graph/size.rs new file mode 100644 index 0000000..bd3cf69 --- /dev/null +++ b/jyafn/src/graph/size.rs @@ -0,0 +1,67 @@ +//! Utilities for dealing with memory sizes without fantastically messing up the units. + +use std::{ + iter::Sum, + ops::{Add, Mul}, +}; + +/// A size of something in memory. This is just a newtype on top of a `usize` that is +/// also type-checked to make sure that we are representing that size in bytes, in jyafn +/// slots, etc... and not fantastically mess up the units. +#[derive(Debug, Clone, Copy, Default)] +pub struct Size(usize); + +impl Size { + /// Gets this size represented in _bytes_. + pub const fn in_bytes(self) -> usize { + self.0 + } +} + +impl Mul for usize { + type Output = Size; + fn mul(self, other: Size) -> Size { + Size(self * other.0) + } +} + +impl Add for Size { + type Output = Size; + fn add(self, other: Size) -> Size { + Size(self.0 + other.0) + } +} + +impl Sum for Size { + fn sum>(iter: I) -> Self { + let mut sum = Size::default(); + for el in iter { + sum = sum + el; + } + sum + } +} + +/// Represents a unit of memory size to be used in [`Size`]. +pub trait Unit: Send + Sync + Copy { + /// The size of "1 unit". + const UNIT: Size; +} + +/// The unit of "1 byte". +#[derive(Debug, Clone, Copy)] +pub struct InBytes; + +impl Unit for InBytes { + const UNIT: Size = Size(1); +} + +/// The unit of "1 jyafn slot". +#[derive(Debug, Clone, Copy)] +pub struct InSlots; + +const SLOT_SIZE: usize = 8; + +impl Unit for InSlots { + const UNIT: Size = Size(SLOT_SIZE); +} diff --git a/jyafn/src/graph/type.rs b/jyafn/src/graph/type.rs new file mode 100644 index 0000000..18f5d06 --- /dev/null +++ b/jyafn/src/graph/type.rs @@ -0,0 +1,84 @@ +use get_size::GetSize; +use serde_derive::{Deserialize, Serialize}; +use std::fmt::{self, Display}; + +use crate::Error; + +use super::size::{InSlots, Size, Unit}; + +/// The primitive types of data that can be represented in the computational graph. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, GetSize)] +#[repr(u8)] +pub enum Type { + /// A floating point number. + Float, + /// A boolean. + Bool, + /// An _id_ referencing a piece of imutable text "somewhere". + Symbol, + /// A pointer, with an origin node id. Pointers _cannot_ appear in the public + /// interface of a graph. + Ptr { origin: usize }, + /// An integer timestamp in microseconds. + DateTime, +} + +impl TryFrom for Type { + type Error = Error; + + fn try_from(v: u8) -> Result { + match v { + 0 => Ok(Type::Float), + 1 => Ok(Type::Bool), + 2 => Ok(Type::Symbol), + 3 => Ok(Type::Ptr { origin: usize::MAX }), + 4 => Ok(Type::DateTime), + _ => Err(format!("{v} is not a valid type id"))?, + } + } +} + +impl Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Type::Float => write!(f, "scalar"), + Type::Bool => write!(f, "bool"), + Type::Symbol => write!(f, "symbol"), + Type::Ptr { origin } => write!(f, "ptr@{origin}"), + Type::DateTime => write!(f, "datetime"), + } + } +} + +/// All slots in jyafn are 64 bits long. +pub const SLOT_SIZE: Size = InSlots::UNIT; + +impl Type { + pub(crate) fn render(self) -> qbe::Type<'static> { + match self { + Type::Float => qbe::Type::Double, + Type::Bool => qbe::Type::Long, + Type::Symbol => qbe::Type::Long, + Type::Ptr { .. } => qbe::Type::Long, + Type::DateTime => qbe::Type::Long, + } + } + + pub(crate) fn print(self, val: u64) -> String { + match self { + Type::Float => format!("{}", f64::from_ne_bytes(val.to_ne_bytes())), + Type::Bool => format!("{}", val == 1), + Type::Symbol => format!("{val}"), + Type::Ptr { .. } => format!("{val:#x}"), + Type::DateTime => { + if let Some(date) = + chrono::DateTime::::from_timestamp_micros(val as i64) + { + format!("{date}",) + } else { + "".to_string() + } + } + } + } +} diff --git a/jyafn/src/layout/decode.rs b/jyafn/src/layout/decode.rs index eabf625..266b914 100644 --- a/jyafn/src/layout/decode.rs +++ b/jyafn/src/layout/decode.rs @@ -79,6 +79,11 @@ impl Decode for serde_json::Value { .map(|(name, field)| (name.clone(), Self::build(field, symbols, visitor))) .collect::>() .into(), + Layout::Tuple(fields) => fields + .iter() + .map(|field| Self::build(field, symbols, visitor)) + .collect::>() + .into(), Layout::List(element, size) => (0..*size) .map(|_| Self::build(element, symbols, visitor)) .collect::>() diff --git a/jyafn/src/layout/encode.rs b/jyafn/src/layout/encode.rs index 28025bb..2b0a12d 100644 --- a/jyafn/src/layout/encode.rs +++ b/jyafn/src/layout/encode.rs @@ -258,6 +258,47 @@ impl> Encode for BTreeMap { } } +macro_rules! impl_encode_tuple { + ($($n:tt: $typ:ident),*) => { + impl< $( $typ, )* > Encode for ( $( $typ, )* ) where $($typ: Encode),* { + type Err = Error; + fn visit( + &self, + layout: &Layout, + symbols: &mut dyn Sym, + visitor: &mut Visitor, + ) -> Result<(), Error> { + match layout { + Layout::Tuple(fields) => { + $( + self.$n.visit( + fields.get($n).ok_or_else(|| format!("missing field {} in tuple", $n))?, + symbols, + visitor, + )?; + )* + } + _ => return Err("expected a tuple".to_string().into()), + } + + Ok(()) + } + } + }; +} + +impl_encode_tuple!(0: A); +impl_encode_tuple!(0: A, 1: B); +impl_encode_tuple!(0: A, 1: B, 2: C); +impl_encode_tuple!(0: A, 1: B, 2: C, 3: D); +impl_encode_tuple!(0: A, 1: B, 2: C, 3: D, 4: E); +impl_encode_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F); +impl_encode_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G); +impl_encode_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H); +impl_encode_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I); +impl_encode_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J); +impl_encode_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K); + impl Encode for serde_json::Value { type Err = Error; fn visit( diff --git a/jyafn/src/layout/mod.rs b/jyafn/src/layout/mod.rs index 0dbae14..bb377e6 100644 --- a/jyafn/src/layout/mod.rs +++ b/jyafn/src/layout/mod.rs @@ -20,6 +20,8 @@ use serde_derive::{Deserialize, Serialize}; use std::collections::BTreeSet; use std::fmt::{self, Display}; +use crate::size::{InSlots, Size, Unit}; + use super::{Ref, Type}; /// The `strptime` format for ISO 8601, the standard used in the [`Layout::DateTime`] @@ -58,7 +60,7 @@ impl Display for Struct { impl Struct { /// The size in slots of this struct. - pub fn size(&self) -> usize { + pub fn size(&self) -> Size { self.0.iter().map(|(_, layout)| layout.size()).sum() } @@ -121,11 +123,7 @@ impl Struct { } for name in other_keys { - let (_, self_field) = self - .0 - .iter() - .find(|&(n, _)| n == name) - .expect("key exists"); + let (_, self_field) = self.0.iter().find(|&(n, _)| n == name).expect("key exists"); let (_, other_field) = other .0 .iter() @@ -159,8 +157,10 @@ pub enum Layout { DateTime(String), /// An imutable piece of text. Symbol, - /// An ordered sequence of values, layed out in memory sequentially. + /// An ordered sequence of named values, layed out in memory sequentially. Struct(Struct), + /// An ordered sequence of unnamed values, layed out in memory sequentially. + Tuple(Vec), /// A layout repeated a given number of times. List(Box, usize), } @@ -182,6 +182,15 @@ impl Display for Layout { Layout::Symbol => write!(f, "symbol"), Layout::Struct(fields) if f.alternate() => write!(f, "{fields:#}"), Layout::Struct(fields) => write!(f, "{fields}"), + Layout::Tuple(fields) => write!( + f, + "({})", + fields + .iter() + .map(|field| field.to_string()) + .collect::>() + .join(", ") + ), Layout::List(element, size) if element.as_ref() == &Layout::Scalar => { write!(f, "[{size}]") } @@ -192,15 +201,16 @@ impl Display for Layout { impl Layout { /// The size in slots of this struct. - pub fn size(&self) -> usize { + pub fn size(&self) -> Size { match self { - Layout::Unit => 0, - Layout::Scalar => 1, - Layout::Bool => 1, - Layout::DateTime(_) => 1, - Layout::Symbol => 1, + Layout::Unit => 0 * InSlots::UNIT, + Layout::Scalar => 1 * InSlots::UNIT, + Layout::Bool => 1 * InSlots::UNIT, + Layout::DateTime(_) => 1 * InSlots::UNIT, + Layout::Symbol => 1 * InSlots::UNIT, Layout::Struct(fields) => fields.size(), - Layout::List(element, size) => size * element.size(), + Layout::Tuple(fields) => fields.iter().map(Layout::size).sum(), + Layout::List(element, size) => *size * element.size(), } } @@ -213,6 +223,7 @@ impl Layout { Layout::DateTime(_) => vec![Type::DateTime], Layout::Symbol => vec![Type::Symbol], Layout::Struct(fields) => fields.slots(), + Layout::Tuple(fields) => fields.iter().map(Layout::slots).flatten().collect(), Layout::List(element, size) => [element.slots()] .into_iter() .cycle() @@ -241,6 +252,14 @@ impl Layout { }) .collect::>()?, ), + Layout::Tuple(fields) => RefValue::Tuple( + fields + .iter() + .map(|field| { + Some(field.build_ref_value_inner(it.by_ref())?) + }) + .collect::>()?, + ), Layout::List(element, size) => RefValue::List( (0..*size) .map(|_| element.build_ref_value_inner(it.by_ref())) diff --git a/jyafn/src/layout/ref_value.rs b/jyafn/src/layout/ref_value.rs index a2f4a75..1a9bcc5 100644 --- a/jyafn/src/layout/ref_value.rs +++ b/jyafn/src/layout/ref_value.rs @@ -20,6 +20,8 @@ pub enum RefValue { Symbol(Ref), /// A struct of values. Struct(HashMap), + /// Atuple of values. + Tuple(Vec), /// A list of values, all of the same layout. List(Vec), } @@ -39,6 +41,13 @@ impl Display for RefValue { } write!(f, "}}") } + Self::Tuple(fields) => { + write!(f, "( ")?; + for field in fields { + write!(f, "{field}, ")?; + } + write!(f, ")") + } Self::List(list) => { write!(f, "[ ")?; for field in list { @@ -68,6 +77,9 @@ impl RefValue { strct.sort_unstable_by_key(|(n, _)| n.clone()); strct })), + Self::Tuple(fields) => { + Layout::Tuple(fields.iter().map(Self::putative_layout).collect()) + } Self::List(list) => { if let Some(first) = list.first() { Layout::List(Box::new(first.putative_layout()), list.len()) @@ -99,6 +111,15 @@ impl RefValue { vals.get(name)?.build_output_vec(field, buf); } } + (Self::Tuple(vals), Layout::Tuple(fields)) => { + if vals.len() != fields.len() { + return None; + } + + for (val, field) in vals.iter().zip(fields) { + val.build_output_vec(field, buf)?; + } + } (Self::List(list), Layout::List(element, size)) if list.len() == *size => { for item in list { item.build_output_vec(element, buf)?; diff --git a/jyafn/src/layout/visitor.rs b/jyafn/src/layout/visitor.rs index 82fdd88..a0b0484 100644 --- a/jyafn/src/layout/visitor.rs +++ b/jyafn/src/layout/visitor.rs @@ -1,6 +1,8 @@ use byte_slice_cast::*; use std::convert::AsRef; +use crate::size::Size; + /// A builder of binary data to be sent to and from functions. This represents a sequence /// of slots of 64-bit data that can be grown by pushing more 64-bid data into it. #[derive(Debug, Clone)] @@ -13,8 +15,8 @@ impl AsRef<[u8]> for Visitor { } impl Visitor { - pub(crate) fn new(size: usize) -> Visitor { - Visitor(vec![0; size * 8].into_boxed_slice(), 0) + pub(crate) fn new(size: Size) -> Visitor { + Visitor(vec![0; size.in_bytes()].into_boxed_slice(), 0) } pub(crate) fn into_inner(self) -> Box<[u8]> { diff --git a/jyafn/src/lib.rs b/jyafn/src/lib.rs index e2dfbeb..195972f 100644 --- a/jyafn/src/lib.rs +++ b/jyafn/src/lib.rs @@ -20,10 +20,13 @@ pub use dataset::Dataset; pub use function::{FnError, Function, FunctionData, RawFn}; pub use graph::{Graph, IndexedList, Node, Ref, Type}; pub use op::Op; +pub use graph::size; pub use r#const::Const; use std::{ + borrow::Cow, error::Error as StdError, + ffi::CStr, fmt::{self, Debug, Display}, process::ExitStatus, }; @@ -37,7 +40,7 @@ pub enum Error { #[error("reference for {0:?} has already been defined")] AlreadyDefined(String), #[error("io error: {0}")] - Io(std::io::Error), + Io(#[from] std::io::Error), #[error("found illegal instruction: {0}")] IllegalInstruction(String), #[error("qbe failed with {status}: {err}")] @@ -47,9 +50,9 @@ pub enum Error { #[error("linker failed with status {status}: {err}")] Linker { status: ExitStatus, err: String }, #[error("loader error: {0}")] - Loader(libloading::Error), - #[error("function raised status: {0}")] - StatusRaised(String), + Loader(#[from] libloading::Error), + #[error("function raised status: {0:?}")] + StatusRaised(Cow<'static, CStr>), #[error("encode error: {0}")] EncodeError(Box), #[error("wrong layout: expected {expected}, got {got}")] @@ -63,11 +66,11 @@ pub enum Error { got: layout::RefValue, }, #[error("bincode error: {0}")] - Bincode(bincode::Error), + Bincode(#[from] bincode::Error), #[error("json error: {0}")] - Json(serde_json::Error), + Json(#[from] serde_json::Error), #[error("zip error: {0}")] - Zip(zip::result::ZipError), + Zip(#[from] zip::result::ZipError), #[error("{0}")] Other(String), #[error("{error}\n\n{context}")] @@ -77,24 +80,6 @@ pub enum Error { }, } -impl From for Error { - fn from(err: std::io::Error) -> Error { - Error::Io(err) - } -} - -impl From for Error { - fn from(err: libloading::Error) -> Error { - Error::Loader(err) - } -} - -impl From for Error { - fn from(err: zip::result::ZipError) -> Error { - Error::Zip(err) - } -} - impl From for Error { fn from(err: String) -> Error { Error::Other(err) @@ -184,7 +169,7 @@ mod test { #[test] fn test_render_simple_graph() { let graph = create_simple_graph(); - println!("{}", graph.render()); + println!("{}", graph.render().unwrap()); } #[test] @@ -203,7 +188,7 @@ mod test { fn test_run_simple_graph() { let graph = create_simple_graph(); let func = graph.compile().unwrap(); - println!("{}", graph.render()); + println!("{}", graph.render().unwrap()); println!("{}", graph.render_assembly().unwrap()); let i = [5.0, 6.0]; @@ -231,7 +216,7 @@ mod test { fn test_run_pfunc() { let graph = create_pfunc_graph(); let func = graph.compile().unwrap(); - println!("{}", graph.render()); + println!("{}", graph.render().unwrap()); println!("{:?}", func); let num = 4.0; @@ -262,7 +247,7 @@ mod test { fn test_run_abs() { let graph = create_abs_graph(); let func = graph.compile().unwrap(); - println!("{}", graph.render()); + println!("{}", graph.render().unwrap()); println!("{:?}", func); let num = 4.0; diff --git a/jyafn/src/op/arithmetic.rs b/jyafn/src/op/arithmetic.rs index f3a7b33..71f12e9 100644 --- a/jyafn/src/op/arithmetic.rs +++ b/jyafn/src/op/arithmetic.rs @@ -43,6 +43,10 @@ impl Op for Add { return Some(args[0]); } + if let Some((x, y)) = args[0].as_f64().zip(args[1].as_f64()) { + return Some((x + y).into()); + } + None } } @@ -82,6 +86,10 @@ impl Op for Sub { return Some(args[0]); } + if let Some((x, y)) = args[0].as_f64().zip(args[1].as_f64()) { + return Some((x - y).into()); + } + None } } @@ -125,6 +133,10 @@ impl Op for Mul { return Some(args[0]); } + if let Some((x, y)) = args[0].as_f64().zip(args[1].as_f64()) { + return Some((x * y).into()); + } + None } } @@ -164,6 +176,10 @@ impl Op for Div { return Some(args[0]); } + if let Some((x, y)) = args[0].as_f64().zip(args[1].as_f64()) { + return Some((x / y).into()); + } + None } } @@ -196,8 +212,8 @@ impl Op for Rem { } fn const_eval(&self, args: &[Ref]) -> Option { - if Ref::from(1.0) == args[1] { - return Some(args[0]); + if let Some((x, y)) = args[0].as_f64().zip(args[1].as_f64()) { + return Some((x % y).into()); } None @@ -235,8 +251,8 @@ impl Op for Neg { } fn const_eval(&self, args: &[Ref]) -> Option { - if Ref::from(0.0) == args[0] { - return Some(Ref::from(0.0)); + if let Some(x) = args[0].as_f64() { + return Some((-x).into()); } None @@ -305,4 +321,12 @@ impl Op for Abs { func.add_block(end_side); } + + fn const_eval(&self, args: &[Ref]) -> Option { + if let Some(x) = args[0].as_f64() { + return Some(x.abs().into()); + } + + None + } } diff --git a/jyafn/src/op/call.rs b/jyafn/src/op/call.rs index 091374e..09297cf 100644 --- a/jyafn/src/op/call.rs +++ b/jyafn/src/op/call.rs @@ -1,7 +1,7 @@ use get_size::GetSize; use serde_derive::{Deserialize, Serialize}; -use crate::{impl_is_eq, impl_op, pfunc, Graph, Ref, Type}; +use crate::{graph::SLOT_SIZE, impl_is_eq, impl_op, pfunc, Graph, Ref, Type}; use super::{unique_for, Op}; @@ -101,14 +101,14 @@ impl Op for CallGraph { graph.subgraphs[self.0] .inputs .iter() - .map(|ty| ty.size()) + .map(|ty| SLOT_SIZE.in_bytes()) .sum::() as u64, ), ); func.assign_instr( output_ptr.clone(), qbe::Type::Long, - qbe::Instr::Alloc8(graph.subgraphs[self.0].output_layout.size() as u64 * 8), + qbe::Instr::Alloc8(graph.subgraphs[self.0].output_layout.size().in_bytes() as u64), ); func.assign_instr( @@ -128,7 +128,7 @@ impl Op for CallGraph { qbe::Type::Long, qbe::Instr::Add( data_ptr.clone(), - qbe::Value::Const(subgraph.type_of(arg).size() as u64), + qbe::Value::Const(SLOT_SIZE.in_bytes() as u64), ), ); } diff --git a/jyafn/src/op/compare.rs b/jyafn/src/op/compare.rs index cc7fdee..d9870a2 100644 --- a/jyafn/src/op/compare.rs +++ b/jyafn/src/op/compare.rs @@ -51,8 +51,8 @@ impl Op for Eq { } fn const_eval(&self, args: &[Ref]) -> Option { - if let (Some(a), Some(b)) = (args[0].as_f64(), args[1].as_f64()) { - Some(Ref::from(a == b)) + if let Some((x, y)) = args[0].as_f64().zip(args[1].as_f64()) { + Some(Ref::from(x == y)) } else { None } @@ -95,8 +95,8 @@ impl Op for Gt { } fn const_eval(&self, args: &[Ref]) -> Option { - if let (Some(a), Some(b)) = (args[0].as_f64(), args[1].as_f64()) { - Some(Ref::from(a > b)) + if let Some((x, y)) = args[0].as_f64().zip(args[1].as_f64()) { + Some(Ref::from(x > y)) } else { None } @@ -139,8 +139,8 @@ impl Op for Lt { } fn const_eval(&self, args: &[Ref]) -> Option { - if let (Some(a), Some(b)) = (args[0].as_f64(), args[1].as_f64()) { - Some(Ref::from(a < b)) + if let Some((x, y)) = args[0].as_f64().zip(args[1].as_f64()) { + Some(Ref::from(x < y)) } else { None } @@ -183,8 +183,8 @@ impl Op for Ge { } fn const_eval(&self, args: &[Ref]) -> Option { - if let (Some(a), Some(b)) = (args[0].as_f64(), args[1].as_f64()) { - Some(Ref::from(a >= b)) + if let Some((x, y)) = args[0].as_f64().zip(args[1].as_f64()) { + Some(Ref::from(x >= y)) } else { None } @@ -227,8 +227,8 @@ impl Op for Le { } fn const_eval(&self, args: &[Ref]) -> Option { - if let (Some(a), Some(b)) = (args[0].as_f64(), args[1].as_f64()) { - Some(Ref::from(a <= b)) + if let Some((x, y)) = args[0].as_f64().zip(args[1].as_f64()) { + Some(Ref::from(x <= y)) } else { None } diff --git a/jyafn/src/op/convert.rs b/jyafn/src/op/convert.rs index d29a836..db53179 100644 --- a/jyafn/src/op/convert.rs +++ b/jyafn/src/op/convert.rs @@ -4,7 +4,7 @@ use crate::{impl_op, Graph, Ref, Type}; use super::Op; -/// Converts a float to a boolean. This is equivalent to `a == 1`. +/// Converts a float to a boolean. This is equivalent to `a != 0`. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct ToBool; @@ -32,12 +32,20 @@ impl Op for ToBool { Type::Bool.render(), qbe::Instr::Cmp( Type::Float.render(), - qbe::Cmp::Eq, + qbe::Cmp::Ne, args[0].render(), - qbe::Value::Const(1), + qbe::Value::Const(0), ), ) } + + fn const_eval(&self, args: &[Ref]) -> Option { + if let Some(x) = args[0].as_f64() { + return Some((x != 0.0).into()); + } + + None + } } /// Converts a boolean to a float. This is equivalent to `if a then 1.0 else 0.0`. @@ -69,4 +77,12 @@ impl Op for ToFloat { qbe::Instr::Ultof(args[0].render()), ) } + + fn const_eval(&self, args: &[Ref]) -> Option { + if let Some(x) = args[0].as_bool() { + return Some((x as i64 as f64).into()); + } + + None + } } diff --git a/jyafn/src/op/list.rs b/jyafn/src/op/list.rs index 4dcfcef..a65deab 100644 --- a/jyafn/src/op/list.rs +++ b/jyafn/src/op/list.rs @@ -1,6 +1,6 @@ use serde_derive::{Deserialize, Serialize}; -use crate::{impl_op, Graph, Ref, Type}; +use crate::{graph::SLOT_SIZE, impl_op, Graph, Ref, Type}; use super::{unique_for, Op}; @@ -34,7 +34,7 @@ impl Op for List { func.assign_instr( output.clone(), qbe::Type::Long, - qbe::Instr::Alloc8((self.element.size() * self.n_elements) as u64), + qbe::Instr::Alloc8((self.n_elements * SLOT_SIZE).in_bytes() as u64), ); func.assign_instr( data_ptr.clone(), @@ -53,7 +53,7 @@ impl Op for List { qbe::Type::Long, qbe::Instr::Add( data_ptr.clone(), - qbe::Value::Const(self.element.size() as u64), + qbe::Value::Const(SLOT_SIZE.in_bytes() as u64), ), ) } @@ -134,7 +134,7 @@ impl Op for Index { qbe::Type::Long, qbe::Instr::Mul( displacement.clone(), - qbe::Value::Const(self.element.size() as u64), + qbe::Value::Const(SLOT_SIZE.in_bytes() as u64), ), ); func.assign_instr( @@ -236,7 +236,7 @@ impl Op for IndexOf { qbe::Type::Long, qbe::Instr::Add( displacement.clone(), - qbe::Value::Const(self.element.size() as u64), + qbe::Value::Const(SLOT_SIZE.in_bytes() as u64), ), ); } diff --git a/jyafn/src/op/logic.rs b/jyafn/src/op/logic.rs index 630d5f5..07cb63c 100644 --- a/jyafn/src/op/logic.rs +++ b/jyafn/src/op/logic.rs @@ -202,7 +202,7 @@ impl Op for And { } fn const_eval(&self, args: &[Ref]) -> Option { - if let (Some(a), Some(b)) = (args[0].as_bool(), args[1].as_bool()) { + if let Some((a, b)) = args[0].as_bool().zip(args[1].as_bool()) { Some(Ref::from(a && b)) } else { None @@ -241,7 +241,7 @@ impl Op for Or { } fn const_eval(&self, args: &[Ref]) -> Option { - if let (Some(a), Some(b)) = (args[0].as_bool(), args[1].as_bool()) { + if let Some((a, b)) = args[0].as_bool().zip(args[1].as_bool()) { Some(Ref::from(a || b)) } else { None diff --git a/jyafn/src/op/mod.rs b/jyafn/src/op/mod.rs index 09c51bc..83aeade 100644 --- a/jyafn/src/op/mod.rs +++ b/jyafn/src/op/mod.rs @@ -138,7 +138,7 @@ fn unique_for(v: qbe::Value, prefix: &str) -> String { format!("{prefix}_{name}") } -/// Renders the call to create an [`FnError`] out of a static string in jyafn code. +/// Renders the call to create an [`FnError`] out of a static C-Style string in jyafn code. pub(crate) fn render_return_error(func: &mut qbe::Function, error: qbe::Value) { let error_ptr = qbe::Value::Temporary("__error_ptr".to_string()); func.assign_instr( @@ -151,3 +151,18 @@ pub(crate) fn render_return_error(func: &mut qbe::Function, error: qbe::Value) { ); func.add_instr(qbe::Instr::Ret(Some(error_ptr))); } + +/// Renders the call to create an [`FnError`] out of an allocated C-style string in jyafn +/// code. +pub(crate) fn render_return_allocated_error(func: &mut qbe::Function, error: qbe::Value) { + let error_ptr = qbe::Value::Temporary("__error_ptr".to_string()); + func.assign_instr( + error_ptr.clone(), + qbe::Type::Long, + qbe::Instr::Call( + qbe::Value::Const(FnError::make_allocated as usize as u64), + vec![(qbe::Type::Long, error)], + ), + ); + func.add_instr(qbe::Instr::Ret(Some(error_ptr))); +} diff --git a/jyafn/src/op/resource.rs b/jyafn/src/op/resource.rs index bf13613..9bfb6e3 100644 --- a/jyafn/src/op/resource.rs +++ b/jyafn/src/op/resource.rs @@ -1,7 +1,7 @@ use get_size::GetSize; use serde_derive::{Deserialize, Serialize}; -use crate::{impl_is_eq, impl_op, Graph, Ref, Type}; +use crate::{graph::SLOT_SIZE, impl_is_eq, impl_op, Graph, Ref, Type}; use super::{unique_for, Op}; @@ -80,7 +80,7 @@ impl Op for CallResource { qbe::Type::Long, qbe::Instr::Add( data_ptr.clone(), - qbe::Value::Const(graph.type_of(arg).size() as u64), + qbe::Value::Const(SLOT_SIZE.in_bytes() as u64), ), ); } @@ -109,8 +109,7 @@ impl Op for CallResource { end_side.clone(), )); func.add_block(raise_side); - // This status is already a `*mut FnError`. So, no need to make. - func.add_instr(qbe::Instr::Ret(Some(status))); + super::render_return_allocated_error(func, status); func.add_block(end_side); func.assign_instr(output, qbe::Type::Long, qbe::Instr::Copy(output_ptr)); } diff --git a/jyafn/src/pfunc.rs b/jyafn/src/pfunc.rs index 5e8dcf2..d4b0180 100644 --- a/jyafn/src/pfunc.rs +++ b/jyafn/src/pfunc.rs @@ -75,7 +75,7 @@ pub struct PFunc { fn_ptr: ThreadsafePointer, /// The input types of the function. signature: &'static [Type], - /// The return type of the function. /// The return type of the function + /// The return type of the function. returns: Type, /// Provides compile-time evaluation behavior. pub(crate) const_eval: ConstEval, diff --git a/jyafn/src/resource/external.rs b/jyafn/src/resource/external.rs index d171e30..710b821 100644 --- a/jyafn/src/resource/external.rs +++ b/jyafn/src/resource/external.rs @@ -156,6 +156,7 @@ impl Resource for ExternalResource { fn get_method(&self, method: &str) -> Option { let c_method = CString::new(method.as_bytes()).expect("method cannot contain nul bytes"); + let extension = self.r#type.extension(); let resource = self.r#type.resource(); let external_method = unsafe { @@ -165,7 +166,7 @@ impl Resource for ExternalResource { return None; } scopeguard::defer! { - (resource.fn_drop_method_def)(maybe_method) + (extension.string.fn_drop)(maybe_method) } serde_json::from_slice::(CStr::from_ptr(maybe_method).to_bytes()) diff --git a/jyafn/src/resource/mod.rs b/jyafn/src/resource/mod.rs index 9fd0785..3628372 100644 --- a/jyafn/src/resource/mod.rs +++ b/jyafn/src/resource/mod.rs @@ -15,11 +15,11 @@ use std::sync::Arc; use zip::read::ZipFile; use crate::layout::{Layout, Struct}; -use crate::{Error, FnError}; +use crate::Error; /// The signature of the function that will be invoked from inside the function code. pub type RawResourceMethod = - unsafe extern "C" fn(*const (), *const u8, u64, *mut u8, u64) -> *mut FnError; + unsafe extern "C" fn(*const (), *const u8, u64, *mut u8, u64) -> *mut u8; /// A method from a resource. #[derive(Debug)] @@ -298,7 +298,7 @@ macro_rules! safe_method { input_slots: u64, output_ptr: *mut u8, output_slots: u64, - ) -> *mut $crate::FnError { + ) -> *mut u8 { match std::panic::catch_unwind(|| { unsafe { // Safety: all this stuff came from jyafn code. The jyafn code should @@ -315,15 +315,9 @@ macro_rules! safe_method { } }) { Ok(Ok(())) => std::ptr::null_mut(), - Ok(Err(err)) => { - let boxed = Box::new(err.to_string().into()); - Box::leak(boxed) - } - // DON'T forget the nul character when working with bytes directly! - Err(_) => { - let boxed = Box::new("method panicked. See stderr".to_string().into()); - Box::leak(boxed) - } + Ok(Err(err)) => crate::utils::make_safe_c_str(err).into_raw() as *mut u8, + Err(_) => crate::utils::make_safe_c_str("method panicked. See stderr".to_string()) + .into_raw() as *mut u8, } } diff --git a/jyafn/src/utils.rs b/jyafn/src/utils.rs index 67ab750..72e3e72 100644 --- a/jyafn/src/utils.rs +++ b/jyafn/src/utils.rs @@ -1,5 +1,7 @@ //! Utilities for this crate. +use std::ffi::CString; + use chrono::{ format::{ParseError, ParseErrorKind}, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, Utc, @@ -77,6 +79,31 @@ pub fn int_to_datetime(i: i64) -> DateTime { DateTime::::from(Timestamp::from(i)) } +/// Creates a C-style string out of a `String` in a way that doesn't produce errors. This +/// function substitutes nul characters by the ` ` (space) character. This avoids an +/// allocation. +/// +/// This method **leaks** the string. So, don't forget to guarantee that somene somewhere +/// is freeing it. +/// +/// # Note +/// +/// Yes, I know! It's a pretty lousy implementation that is even... O(n^2) (!!). You can +/// do better than I in 10mins. +pub(crate) fn make_safe_c_str(s: String) -> CString { + let mut v = s.into_bytes(); + loop { + match std::ffi::CString::new(v) { + Ok(c_str) => return c_str, + Err(err) => { + let nul_position = err.nul_position(); + v = err.into_vec(); + v[nul_position] = b' '; + } + } + } +} + #[cfg(test)] mod test { use super::*;