diff --git a/plugins/wasmembed/secure-enclave/CMakeLists.txt b/plugins/wasmembed/secure-enclave/CMakeLists.txt index b4247b79d9b..a669ef21048 100644 --- a/plugins/wasmembed/secure-enclave/CMakeLists.txt +++ b/plugins/wasmembed/secure-enclave/CMakeLists.txt @@ -27,10 +27,17 @@ add_definitions(-D_USRDLL -DSECUREENCLAVE_EXPORTS) add_library(secure-enclave SHARED secure-enclave.cpp + secure-enclave.hpp + abi.cpp + abi.hpp + util.cpp + util.hpp ) target_link_libraries(secure-enclave PRIVATE ${WASMTIME_LIB} + eclrtl + jlib ) install( diff --git a/plugins/wasmembed/secure-enclave/abi.cpp b/plugins/wasmembed/secure-enclave/abi.cpp new file mode 100644 index 00000000000..1f8d1b717e7 --- /dev/null +++ b/plugins/wasmembed/secure-enclave/abi.cpp @@ -0,0 +1,138 @@ +/* + See: https://github.com/WebAssembly/component-model/blob/main/design/mvp/CanonicalABI.md + https://github.com/WebAssembly/component-model/blob/main/design/mvp/canonical-abi/definitions.py +*/ + +#include "abi.hpp" +#include +#include +#include +#include +#include + +auto UTF16_TAG = 1 << 31; + +int align_to(int ptr, int alignment) +{ + return std::ceil(ptr / alignment) * alignment; +} + +// loading --- + +/* canonical load_int (python) + +def load_int(cx, ptr, nbytes, signed = False): + return int.from_bytes(cx.opts.memory[ptr : ptr+nbytes], 'little', signed=signed) + +*/ + +template +T load_int(const wasmtime::Span &data, int32_t ptr) +{ + T retVal = 0; + auto nbytes = sizeof(retVal); + for (int i = 0; i < nbytes; ++i) + { + uint8_t b = data[ptr + i]; + if (i == nbytes - 1 && std::is_signed::value && b >= 0x80) + b -= 0x100; + retVal += b << (i * 8); + } + return retVal; +} + +/* canonical load_string_from_range (python) + +def load_string_from_range(cx, ptr, tagged_code_units): + match cx.opts.string_encoding: + case 'utf8': + alignment = 1 + byte_length = tagged_code_units + encoding = 'utf-8' + case 'utf16': + alignment = 2 + byte_length = 2 * tagged_code_units + encoding = 'utf-16-le' + case 'latin1+utf16': + alignment = 2 + if bool(tagged_code_units & UTF16_TAG): + byte_length = 2 * (tagged_code_units ^ UTF16_TAG) + encoding = 'utf-16-le' + else: + byte_length = tagged_code_units + encoding = 'latin-1' + + trap_if(ptr != align_to(ptr, alignment)) + trap_if(ptr + byte_length > len(cx.opts.memory)) + try: + s = cx.opts.memory[ptr : ptr+byte_length].decode(encoding) + except UnicodeError: + trap() + + return (s, cx.opts.string_encoding, tagged_code_units) + +*/ + +// More: Not currently available from the wasmtime::context object, see https://github.com/bytecodealliance/wasmtime/issues/6719 +std::string global_encoding = "utf8"; + +std::pair load_string_from_range(const wasmtime::Span &data, uint32_t ptr, uint32_t tagged_code_units) +{ + std::string encoding = "utf-8"; + uint32_t byte_length = tagged_code_units; + uint32_t alignment = 1; + if (global_encoding.compare("utf8") == 0) + { + alignment = 1; + byte_length = tagged_code_units; + encoding = "utf-8"; + } + else if (global_encoding.compare("utf16") == 0) + { + alignment = 2; + byte_length = 2 * tagged_code_units; + encoding = "utf-16-le"; + } + else if (global_encoding.compare("latin1+utf16") == 0) + { + alignment = 2; + if (tagged_code_units & UTF16_TAG) + { + byte_length = 2 * (tagged_code_units ^ UTF16_TAG); + encoding = "utf-16-le"; + } + else + { + byte_length = tagged_code_units; + encoding = "latin-1"; + } + } + + if (ptr != align_to(ptr, alignment)) + { + throw std::runtime_error("Invalid alignment"); + } + if (ptr + byte_length > data.size()) + { + throw std::runtime_error("Out of bounds"); + } + + return std::make_pair(ptr, byte_length); +} + +/* canonical load_string (python) + +def load_string(cx, ptr): + begin = load_int(cx, ptr, 4) + tagged_code_units = load_int(cx, ptr + 4, 4) + return load_string_from_range(cx, begin, tagged_code_units) + +*/ +std::pair load_string(const wasmtime::Span &data, uint32_t ptr) +{ + uint32_t begin = load_int(data, ptr); + uint32_t tagged_code_units = load_int(data, ptr + 4); + return load_string_from_range(data, begin, tagged_code_units); +} + +// Storing --- diff --git a/plugins/wasmembed/secure-enclave/abi.hpp b/plugins/wasmembed/secure-enclave/abi.hpp index c547a9fd29a..6b27fe85350 100644 --- a/plugins/wasmembed/secure-enclave/abi.hpp +++ b/plugins/wasmembed/secure-enclave/abi.hpp @@ -1,157 +1,12 @@ -/* - See: https://github.com/WebAssembly/component-model/blob/main/design/mvp/CanonicalABI.md - https://github.com/WebAssembly/component-model/blob/main/design/mvp/canonical-abi/definitions.py -*/ - #include -#include -#include -#include -#include - -auto UTF16_TAG = 1 << 31; - -int align_to(int ptr, int alignment) -{ - return std::ceil(ptr / alignment) * alignment; -} - -// loading --- - -int load_int(const wasmtime::Span &data, int32_t ptr, int32_t nbytes, bool is_signed = false) -{ - int result = 0; - for (int i = 0; i < nbytes; i++) - { - int b = data[ptr + i]; - if (i == 3 && is_signed && b >= 0x80) - { - b -= 0x100; - } - result += b << (i * 8); - } - return result; -} - -std::string global_encoding = "utf8"; -std::string load_string_from_range(const wasmtime::Span &data, uint32_t ptr, uint32_t tagged_code_units) -{ - std::string encoding = "utf-8"; - uint32_t byte_length = tagged_code_units; - uint32_t alignment = 1; - if (global_encoding.compare("utf8") == 0) - { - alignment = 1; - byte_length = tagged_code_units; - encoding = "utf-8"; - } - else if (global_encoding.compare("utf16") == 0) - { - alignment = 2; - byte_length = 2 * tagged_code_units; - encoding = "utf-16-le"; - } - else if (global_encoding.compare("latin1+utf16") == 0) - { - alignment = 2; - if (tagged_code_units & UTF16_TAG) - { - byte_length = 2 * (tagged_code_units ^ UTF16_TAG); - encoding = "utf-16-le"; - } - else - { - byte_length = tagged_code_units; - encoding = "latin-1"; - } - } - - if (ptr != align_to(ptr, alignment)) - { - throw std::runtime_error("Invalid alignment"); - } - if (ptr + byte_length > data.size()) - { - throw std::runtime_error("Out of bounds"); - } - - std::string s; - s.resize(byte_length); - memcpy(&s[0], &data[ptr], byte_length); - return s; -} - -std::string load_string(const wasmtime::Span &data, uint32_t ptr) -{ - uint32_t begin = load_int(data, ptr, 4); - uint32_t tagged_code_units = load_int(data, ptr + 4, 4); - return load_string_from_range(data, begin, tagged_code_units); -} - -// Storing --- -void store_int(const wasmtime::Span &data, int64_t v, size_t ptr, size_t nbytes, bool _signed = false) -{ - // convert v to little-endian byte array - std::vector bytes(nbytes); - for (size_t i = 0; i < nbytes; i++) - { - bytes[i] = (v >> (i * 8)) & 0xFF; - } - // copy bytes to memory - memcpy(&data[ptr], bytes.data(), nbytes); -} - -// Other --- -std::vector read_wasm_binary_to_buffer(const std::string &filename) -{ - std::ifstream file(filename, std::ios::binary | std::ios::ate); - if (!file) - { - throw std::runtime_error("Failed to open file"); - } - - std::streamsize size = file.tellg(); - file.seekg(0, std::ios::beg); - - std::vector buffer(size); - if (!file.read(reinterpret_cast(buffer.data()), size)) - { - throw std::runtime_error("Failed to read file"); - } - - return buffer; -} - -std::string extractContentInDoubleQuotes(const std::string &input) -{ - int firstQuote = input.find_first_of('"'); - int secondQuote = input.find('"', firstQuote + 1); - if (firstQuote == std::string::npos || secondQuote == std::string::npos) - { - return ""; - } - return input.substr(firstQuote + 1, secondQuote - firstQuote - 1); -} +// ABI +int align_to(int ptr, int alignment); -std::pair splitQualifiedID(const std::string &qualifiedName) -{ - std::istringstream iss(qualifiedName); - std::vector tokens; - std::string token; +// ABI Loading --- +template +T load_int(const wasmtime::Span &data, int32_t ptr); - while (std::getline(iss, token, '.')) - { - tokens.push_back(token); - } - if (tokens.size() != 2) - { - throw std::runtime_error("Invalid import function " + qualifiedName + ", expected format: ."); - } - return std::make_pair(tokens[0], tokens[1]); -} +std::pair load_string(const wasmtime::Span &data, uint32_t ptr); -std::string createQualifiedID(const std::string &wasmName, const std::string &funcName) -{ - return wasmName + "." + funcName; -} \ No newline at end of file +// ABI Storing --- diff --git a/plugins/wasmembed/secure-enclave/secure-enclave.cpp b/plugins/wasmembed/secure-enclave/secure-enclave.cpp index d8dd157f21c..31503b70ef9 100644 --- a/plugins/wasmembed/secure-enclave/secure-enclave.cpp +++ b/plugins/wasmembed/secure-enclave/secure-enclave.cpp @@ -1,14 +1,18 @@ #include "secure-enclave.hpp" +#include "rtlconst.hpp" + +// From deftype.hpp in common +#define UNKNOWN_LENGTH 0xFFFFFFF1 #include "abi.hpp" +#include "util.hpp" #include #include +#include +#include -std::shared_ptr embedContextCallbacks; - -#define NENABLE_TRACE - +#define ENABLE_TRACE #ifdef ENABLE_TRACE #define TRACE(format, ...) embedContextCallbacks->DBGLOG(format, ##__VA_ARGS__) #else @@ -18,18 +22,80 @@ std::shared_ptr embedContextCallbacks; } while (0) #endif -class WasmEngine +template +class ThreadSafeMap { protected: - wasmtime::Engine engine; - - std::map wasmInstances; - std::map wasmMems; - std::map wasmFuncs; + std::unordered_map map; + mutable std::shared_mutex mutex; public: + ThreadSafeMap() {} + ~ThreadSafeMap() {} + + void clear() + { + std::unique_lock lock(mutex); + map.clear(); + } + + void insertIfMissing(const K &key, std::function &valueCallback) + { + std::unique_lock lock(mutex); + if (map.find(key) == map.end()) + map.insert(std::make_pair(key, valueCallback())); + } + + void erase(const K &key) + { + std::unique_lock lock(mutex); + map.erase(key); + } + + bool find(const K &key, std::optional &value) const + { + std::shared_lock lock(mutex); + auto it = map.find(key); + if (it != map.end()) + { + value = it->second; + return true; + } + return false; + } + + bool has(const K &key) const + { + std::shared_lock lock(mutex); + return map.find(key) != map.end(); + } + + void for_each(std::function func) const + { + std::shared_lock lock(mutex); + for (auto it = map.begin(); it != map.end(); ++it) + { + func(it->first, it->second); + } + } +}; + +std::shared_ptr embedContextCallbacks; + +class WasmEngine +{ +private: + wasmtime::Engine engine; wasmtime::Store store; + ThreadSafeMap wasmInstances; + // wasmMems and wasmFuncs are only written to during createInstance, so no need for a mutex + std::unordered_map wasmMems; + std::unordered_map wasmFuncs; + + mutable std::shared_mutex store_mutex; + +public: WasmEngine() : store(engine) { } @@ -38,102 +104,104 @@ class WasmEngine { } - bool hasInstance(const std::string &wasmName) + bool hasInstance(const std::string &wasmName) const { - return wasmInstances.find(wasmName) != wasmInstances.end(); + return wasmInstances.has(wasmName); } - wasmtime::Instance getInstance(const std::string &wasmName) + wasmtime::Instance getInstance(const std::string &wasmName) const { - auto instanceItr = wasmInstances.find(wasmName); - if (instanceItr == wasmInstances.end()) - throw std::runtime_error("Wasm instance not found: " + wasmName); - return instanceItr->second; + std::optional instance; + if (!wasmInstances.find(wasmName, instance)) + embedContextCallbacks->throwStringException(-1, "Wasm instance not found: %s", wasmName.c_str()); + return instance.value(); } - void registerInstance(const std::string &wasmName, const std::variant> &wasm) + wasmtime::Instance createInstance(const std::string &wasmName, const std::variant> &wasm) { - TRACE("registerInstance %s", wasmName.c_str()); - auto instanceItr = wasmInstances.find(wasmName); - if (instanceItr == wasmInstances.end()) - { - TRACE("resolveModule %s", wasmName.c_str()); - auto module = std::holds_alternative(wasm) ? wasmtime::Module::compile(engine, std::get(wasm)).unwrap() : wasmtime::Module::compile(engine, std::get>(wasm)).unwrap(); - TRACE("resolveModule2 %s", wasmName.c_str()); - - wasmtime::WasiConfig wasi; - wasi.inherit_argv(); - wasi.inherit_env(); - wasi.inherit_stdin(); - wasi.inherit_stdout(); - wasi.inherit_stderr(); - store.context().set_wasi(std::move(wasi)).unwrap(); - TRACE("resolveModule3 %s", wasmName.c_str()); - - wasmtime::Linker linker(engine); - linker.define_wasi().unwrap(); - TRACE("resolveModule4 %s", wasmName.c_str()); - - auto callback = [this, wasmName](wasmtime::Caller caller, uint32_t msg, uint32_t msg_len) - { - TRACE("callback: %i %i", msg_len, msg); + TRACE("resolveModule %s", wasmName.c_str()); + auto module = std::holds_alternative(wasm) ? wasmtime::Module::compile(engine, std::get(wasm)).unwrap() : wasmtime::Module::compile(engine, std::get>(wasm)).unwrap(); + TRACE("resolveModule2 %s", wasmName.c_str()); - auto data = this->getData(wasmName); - auto msg_ptr = (char *)&data[msg]; - std::string str(msg_ptr, msg_len); - embedContextCallbacks->DBGLOG("from wasm: %s", str.c_str()); - }; - auto host_func = linker.func_wrap("$root", "dbglog", callback).unwrap(); + wasmtime::WasiConfig wasi; + wasi.inherit_argv(); + wasi.inherit_env(); + wasi.inherit_stdin(); + wasi.inherit_stdout(); + wasi.inherit_stderr(); + store.context().set_wasi(std::move(wasi)).unwrap(); + TRACE("resolveModule3 %s", wasmName.c_str()); - auto newInstance = linker.instantiate(store, module).unwrap(); - linker.define_instance(store, "linking2", newInstance).unwrap(); + wasmtime::Linker linker(engine); + linker.define_wasi().unwrap(); + TRACE("resolveModule4 %s", wasmName.c_str()); - TRACE("resolveModule5 %s", wasmName.c_str()); + auto callback = [this, wasmName](wasmtime::Caller caller, uint32_t msg, uint32_t msg_len) + { + TRACE("callback: %i %i", msg_len, msg); + + auto data = this->getData(wasmName); + auto msg_ptr = (char *)&data[msg]; + std::string str(msg_ptr, msg_len); + embedContextCallbacks->DBGLOG("from wasm: %s", str.c_str()); + }; + auto host_func = linker.func_wrap("$root", "dbglog", callback).unwrap(); - wasmInstances.insert(std::make_pair(wasmName, newInstance)); + auto newInstance = linker.instantiate(store, module).unwrap(); + linker.define_instance(store, "linking2", newInstance).unwrap(); - for (auto exportItem : module.exports()) + for (auto exportItem : module.exports()) + { + auto externType = wasmtime::ExternType::from_export(exportItem); + std::string name(exportItem.name()); + if (std::holds_alternative(externType)) + { + TRACE("Exported function: %s", name.c_str()); + auto func = std::get(*newInstance.get(store, name)); + wasmFuncs.insert(std::make_pair(wasmName + "." + name, func)); + } + else if (std::holds_alternative(externType)) + { + TRACE("Exported memory: %s", name.c_str()); + auto memory = std::get(*newInstance.get(store, name)); + wasmMems.insert(std::make_pair(wasmName + "." + name, memory)); + } + else if (std::holds_alternative(externType)) + { + TRACE("Exported table: %s", name.c_str()); + } + else if (std::holds_alternative(externType)) + { + TRACE("Exported global: %s", name.c_str()); + } + else { - auto externType = wasmtime::ExternType::from_export(exportItem); - std::string name(exportItem.name()); - if (std::holds_alternative(externType)) - { - TRACE("Exported function: %s", name.c_str()); - auto func = std::get(*newInstance.get(store, name)); - wasmFuncs.insert(std::make_pair(wasmName + "." + name, func)); - } - else if (std::holds_alternative(externType)) - { - TRACE("Exported memory: %s", name.c_str()); - auto memory = std::get(*newInstance.get(store, name)); - wasmMems.insert(std::make_pair(wasmName + "." + name, memory)); - } - else if (std::holds_alternative(externType)) - { - TRACE("Exported table: %s", name.c_str()); - } - else if (std::holds_alternative(externType)) - { - TRACE("Exported global: %s", name.c_str()); - } - else - { - TRACE("Unknown export type"); - } + TRACE("Unknown export type"); } } + + return newInstance; + } + + void registerInstance(const std::string &wasmName, const std::variant> &wasm) + { + std::function createInstanceCallback = [this, wasmName, wasm]() + { + return createInstance(wasmName, wasm); + }; + wasmInstances.insertIfMissing(wasmName, createInstanceCallback); } - bool hasFunc(const std::string &qualifiedID) + bool hasFunc(const std::string &qualifiedID) const { return wasmFuncs.find(qualifiedID) != wasmFuncs.end(); } - wasmtime::Func getFunc(const std::string &qualifiedID) + wasmtime::Func getFunc(const std::string &qualifiedID) const { auto found = wasmFuncs.find(qualifiedID); if (found == wasmFuncs.end()) - throw std::runtime_error("Wasm function not found: " + qualifiedID); + embedContextCallbacks->throwStringException(-1, "Wasm function not found: %s", qualifiedID.c_str()); return found->second; } @@ -165,7 +233,7 @@ class WasmEngine { auto found = wasmMems.find(createQualifiedID(wasmName, "memory")); if (found == wasmMems.end()) - throw std::runtime_error("Wasm memory not found: " + wasmName); + embedContextCallbacks->throwStringException(-1, "Wasm memory not found: %s", wasmName.c_str()); return found->second.data(store.context()); } }; @@ -196,10 +264,9 @@ class SecureFunction : public ISecureEnclave auto gc_func_name = createQualifiedID(wasmName, "cabi_post_" + funcName); if (wasmEngine->hasFunc(gc_func_name)) { - auto func = wasmEngine->getFunc(gc_func_name); for (auto &result : results) { - func.call(wasmEngine->store, {result}).unwrap(); + wasmEngine->call(gc_func_name, {result}); } } } @@ -277,70 +344,52 @@ class SecureFunction : public ISecureEnclave TRACE("bindUnsignedParam %s %llu", name, val); args.push_back(static_cast(val)); } - virtual void bindStringParam(const char *paramName, size32_t len, const char *val) + virtual void bindStringParam(const char *name, size32_t code_units, const char *val) { - TRACE("bindStringParam %s %d %s", paramName, len, val); - auto memIdxVar = wasmEngine->callRealloc(wasmName, {0, 0, 1, (int32_t)len}); - auto memIdx = memIdxVar[0].i32(); - auto mem = wasmEngine->getData(wasmName); - for (int i = 0; i < len; i++) - { - mem[memIdx + i] = val[i]; - } - args.push_back(memIdx); - args.push_back((int32_t)len); + TRACE("bindStringParam %s %d %s", name, code_units, val); + size32_t utfCharCount; + rtlDataAttr utfText; + rtlStrToUtf8X(utfCharCount, utfText.refstr(), code_units, val); + bindUTF8Param(name, utfCharCount, utfText.getstr()); } virtual void bindVStringParam(const char *name, const char *val) { TRACE("bindVStringParam %s %s", name, val); - auto len = strlen(val); - auto memIdxVar = wasmEngine->callRealloc(wasmName, {0, 0, 1, (int32_t)len}); - auto memIdx = memIdxVar[0].i32(); - auto mem = wasmEngine->getData(wasmName); - for (int i = 0; i < len; i++) - { - mem[memIdx + i] = val[i]; - } - args.push_back(memIdx); - args.push_back((int32_t)len); + bindStringParam(name, strlen(val), val); } - virtual void bindUTF8Param(const char *name, size32_t chars, const char *val) + virtual void bindUTF8Param(const char *name, size32_t code_points, const char *val) { - TRACE("bindUTF8Param %s %d %s", name, chars, val); - auto memIdxVar = wasmEngine->callRealloc(wasmName, {0, 0, 1, (int32_t)chars}); + TRACE("bindUTF8Param %s %d %s", name, code_points, val); + auto code_units = rtlUtf8Size(code_points, val); + auto memIdxVar = wasmEngine->callRealloc(wasmName, {0, 0, 1, (int32_t)code_units}); auto memIdx = memIdxVar[0].i32(); auto mem = wasmEngine->getData(wasmName); - for (int i = 0; i < chars; i++) - { - mem[memIdx + i] = val[i]; - } + memcpy(&mem[memIdx], val, code_units); args.push_back(memIdx); - args.push_back((int32_t)chars); + args.push_back((int32_t)code_units); } - virtual void bindUnicodeParam(const char *name, size32_t chars, const UChar *val) + virtual void bindUnicodeParam(const char *name, size32_t code_points, const UChar *val) { - TRACE("bindUnicodeParam %s %d %S", name, chars, reinterpret_cast(val)); - auto memIdxVar = wasmEngine->callRealloc(wasmName, {0, 0, 2, (int32_t)chars * 2}); - auto memIdx = memIdxVar[0].i32(); - auto mem = wasmEngine->getData(wasmName); - for (int i = 0; i < chars * 2; i += 2) - { - mem[memIdx + i] = val[i]; - } - args.push_back(memIdx); - args.push_back((int32_t)chars); + TRACE("bindUnicodeParam %s %d", name, code_points); + size32_t utfCharCount; + rtlDataAttr utfText; + rtlUnicodeToUtf8X(utfCharCount, utfText.refstr(), code_points, val); + bindUTF8Param(name, utfCharCount, utfText.getstr()); } virtual void bindSetParam(const char *name, int elemType, size32_t elemSize, bool isAll, size32_t totalBytes, const void *setData) { TRACE("bindSetParam %s %d %d %d %d %p", name, elemType, elemSize, isAll, totalBytes, setData); + embedContextCallbacks->throwStringException(-1, "bindSetParam not implemented"); } virtual void bindRowParam(const char *name, IOutputMetaData &metaVal, const byte *val) override { TRACE("bindRowParam %s %p", name, val); + embedContextCallbacks->throwStringException(-1, "bindRowParam not implemented"); } virtual void bindDatasetParam(const char *name, IOutputMetaData &metaVal, IRowStream *val) { TRACE("bindDatasetParam %s %p", name, val); + embedContextCallbacks->throwStringException(-1, "bindDatasetParam not implemented"); } virtual bool getBooleanResult() { @@ -350,26 +399,27 @@ class SecureFunction : public ISecureEnclave virtual void getDataResult(size32_t &__len, void *&__result) { TRACE("getDataResult"); + embedContextCallbacks->throwStringException(-1, "getDataResult not implemented"); } virtual double getRealResult() { TRACE("getRealResult"); if (results[0].kind() == wasmtime::ValKind::F64) - return (int32_t)results[0].f64(); + return results[0].f64(); return results[0].f32(); } virtual __int64 getSignedResult() { TRACE("getSignedResult"); if (results[0].kind() == wasmtime::ValKind::I64) - return (int32_t)results[0].i64(); + return results[0].i64(); return results[0].i32(); } virtual unsigned __int64 getUnsignedResult() { TRACE("getUnsignedResult"); if (results[0].kind() == wasmtime::ValKind::I64) - return (int32_t)results[0].i64(); + return results[0].i64(); return results[0].i32(); } virtual void getStringResult(size32_t &__chars, char *&__result) @@ -377,47 +427,65 @@ class SecureFunction : public ISecureEnclave TRACE("getStringResult %zu", results.size()); auto ptr = results[0].i32(); auto data = wasmEngine->getData(wasmName); - - uint32_t begin = load_int(data, ptr, 4); - TRACE("begin %u", begin); - uint32_t tagged_code_units = load_int(data, ptr + 4, 4); - TRACE("tagged_code_units %u", tagged_code_units); - std::string s = load_string(data, ptr); - TRACE("load_string %s", s.c_str()); - __chars = s.length(); - __result = reinterpret_cast(embedContextCallbacks->rtlMalloc(__chars)); - s.copy(__result, __chars); + uint32_t strPtr; + uint32_t code_units; + std::tie(strPtr, code_units) = load_string(data, ptr); + rtlStrToStrX(__chars, __result, code_units, reinterpret_cast(&data[strPtr])); } virtual void getUTF8Result(size32_t &__chars, char *&__result) { TRACE("getUTF8Result"); + auto ptr = results[0].i32(); + auto data = wasmEngine->getData(wasmName); + uint32_t strPtr; + uint32_t code_units; + std::tie(strPtr, code_units) = load_string(data, ptr); + __chars = rtlUtf8Length(code_units, &data[strPtr]); + TRACE("getUTF8Result %d %d", code_units, __chars); + __result = (char *)rtlMalloc(code_units); + memcpy(__result, &data[strPtr], code_units); } virtual void getUnicodeResult(size32_t &__chars, UChar *&__result) { TRACE("getUnicodeResult"); + auto ptr = results[0].i32(); + auto data = wasmEngine->getData(wasmName); + uint32_t strPtr; + uint32_t code_units; + std::tie(strPtr, code_units) = load_string(data, ptr); + unsigned numchars = rtlUtf8Length(code_units, &data[strPtr]); + rtlUtf8ToUnicodeX(__chars, __result, numchars, reinterpret_cast(&data[strPtr])); } virtual void getSetResult(bool &__isAllResult, size32_t &__resultBytes, void *&__result, int elemType, size32_t elemSize) { - TRACE("getSetResult"); + TRACE("getSetResult %d %d %zu", elemType, elemSize, results.size()); + auto ptr = results[0].i32(); + auto data = wasmEngine->getData(wasmName); + + embedContextCallbacks->throwStringException(-1, "getSetResult not implemented"); } virtual IRowStream *getDatasetResult(IEngineRowAllocator *_resultAllocator) { TRACE("getDatasetResult"); + embedContextCallbacks->throwStringException(-1, "getDatasetResult not implemented"); return NULL; } virtual byte *getRowResult(IEngineRowAllocator *_resultAllocator) { TRACE("getRowResult"); + embedContextCallbacks->throwStringException(-1, "getRowResult not implemented"); return NULL; } virtual size32_t getTransformResult(ARowBuilder &builder) { TRACE("getTransformResult"); + embedContextCallbacks->throwStringException(-1, "getTransformResult not implemented"); return 0; } virtual void loadCompiledScript(size32_t chars, const void *_script) override { TRACE("loadCompiledScript %p", _script); + embedContextCallbacks->throwStringException(-1, "loadCompiledScript not implemented"); } virtual void enter() override { @@ -497,6 +565,6 @@ SECUREENCLAVE_API void syntaxCheck(size32_t &__lenResult, char *&__result, const } __lenResult = errMsg.length(); - __result = reinterpret_cast(embedContextCallbacks->rtlMalloc(__lenResult)); + __result = reinterpret_cast(rtlMalloc(__lenResult)); errMsg.copy(__result, __lenResult); } diff --git a/plugins/wasmembed/secure-enclave/secure-enclave.hpp b/plugins/wasmembed/secure-enclave/secure-enclave.hpp index 82e8c800fb2..d763265ce7d 100644 --- a/plugins/wasmembed/secure-enclave/secure-enclave.hpp +++ b/plugins/wasmembed/secure-enclave/secure-enclave.hpp @@ -1,18 +1,19 @@ #include "platform.h" #include "eclrtl.hpp" +#include "eclrtl_imp.hpp" #ifdef SECUREENCLAVE_EXPORTS - #define SECUREENCLAVE_API DECL_EXPORT +#define SECUREENCLAVE_API DECL_EXPORT #else - #define SECUREENCLAVE_API DECL_IMPORT +#define SECUREENCLAVE_API DECL_IMPORT #endif #include interface IWasmEmbedCallback { - virtual inline void DBGLOG(char const *format, ...) __attribute__((format(printf, 2, 3))) = 0; - virtual void *rtlMalloc(size32_t size) = 0; + virtual inline void DBGLOG(const char *format, ...) __attribute__((format(printf, 2, 3))) = 0; + virtual void throwStringException(int code, const char *format, ...) __attribute__((format(printf, 3, 4))) = 0; virtual const char *resolveManifestPath(const char *leafName) = 0; }; diff --git a/plugins/wasmembed/secure-enclave/util.cpp b/plugins/wasmembed/secure-enclave/util.cpp new file mode 100644 index 00000000000..6c8294b5f44 --- /dev/null +++ b/plugins/wasmembed/secure-enclave/util.cpp @@ -0,0 +1,58 @@ +#include "util.hpp" + +#include +#include + +std::vector read_wasm_binary_to_buffer(const std::string &filename) +{ + std::ifstream file(filename, std::ios::binary | std::ios::ate); + if (!file) + { + throw std::runtime_error("Failed to open file"); + } + + std::streamsize size = file.tellg(); + file.seekg(0, std::ios::beg); + + std::vector buffer(size); + if (!file.read(reinterpret_cast(buffer.data()), size)) + { + throw std::runtime_error("Failed to read file"); + } + + return buffer; +} + +std::string extractContentInDoubleQuotes(const std::string &input) +{ + + auto firstQuote = input.find_first_of('"'); + auto secondQuote = input.find('"', firstQuote + 1); + if (firstQuote == std::string::npos || secondQuote == std::string::npos) + { + return ""; + } + return input.substr(firstQuote + 1, secondQuote - firstQuote - 1); +} + +std::pair splitQualifiedID(const std::string &qualifiedName) +{ + std::istringstream iss(qualifiedName); + std::vector tokens; + std::string token; + + while (std::getline(iss, token, '.')) + { + tokens.push_back(token); + } + if (tokens.size() != 2) + { + throw std::runtime_error("Invalid import function " + qualifiedName + ", expected format: ."); + } + return std::make_pair(tokens[0], tokens[1]); +} + +std::string createQualifiedID(const std::string &wasmName, const std::string &funcName) +{ + return wasmName + "." + funcName; +} \ No newline at end of file diff --git a/plugins/wasmembed/secure-enclave/util.hpp b/plugins/wasmembed/secure-enclave/util.hpp new file mode 100644 index 00000000000..e9c5e9e0f47 --- /dev/null +++ b/plugins/wasmembed/secure-enclave/util.hpp @@ -0,0 +1,7 @@ +#include +#include + +std::vector read_wasm_binary_to_buffer(const std::string &filename); +std::string extractContentInDoubleQuotes(const std::string &input); +std::pair splitQualifiedID(const std::string &qualifiedName); +std::string createQualifiedID(const std::string &wasmName, const std::string &funcName); diff --git a/plugins/wasmembed/wasmembed.cpp b/plugins/wasmembed/wasmembed.cpp index 717e4958d95..a98b36ee64b 100644 --- a/plugins/wasmembed/wasmembed.cpp +++ b/plugins/wasmembed/wasmembed.cpp @@ -58,7 +58,7 @@ namespace wasmLanguageHelper } // IWasmEmbedCallback --- - virtual inline void DBGLOG(char const *format, ...) override + virtual inline void DBGLOG(const char *format, ...) override { va_list args; va_start(args, format); @@ -66,9 +66,13 @@ namespace wasmLanguageHelper va_end(args); } - virtual void *rtlMalloc(size32_t size) override + virtual void throwStringException(int code, const char *format, ...) override { - return ::rtlMalloc(size); + va_list args; + va_start(args, format); + IException *ret = makeStringExceptionVA(code, format, args); + va_end(args); + throw ret; } virtual const char *resolveManifestPath(const char *leafName) override diff --git a/vcpkg-linux.code-workspace b/vcpkg-linux.code-workspace index 4039eca856d..54777686d50 100644 --- a/vcpkg-linux.code-workspace +++ b/vcpkg-linux.code-workspace @@ -10,9 +10,10 @@ "-DCONTAINERIZED=OFF", "-DUSE_OPTIONAL=OFF", "-DUSE_CPPUNIT=ON", - "-DINCLUDE_PLUGINS=ON", - "-DSUPPRESS_V8EMBED=ON", - "-DSUPPRESS_REMBED=ON" + "-DWASMEMBED=ON" + // "-DINCLUDE_PLUGINS=ON", + // "-DSUPPRESS_V8EMBED=ON", + // "-DSUPPRESS_REMBED=ON" ], "editor.formatOnSave": false },