Skip to content

Commit

Permalink
HPCC-29914 Add embedded wasm support
Browse files Browse the repository at this point in the history
Move definitions into their own cpp file(s)
Added a thread safe map
Removed std::runtime usage
Refactored up some string/unicode calls
Revisited ABI code

Signed-off-by: Gordon Smith <GordonJSmith@gmail.com>
  • Loading branch information
GordonSmith committed Oct 5, 2023
1 parent cb6bfde commit d2e4d48
Show file tree
Hide file tree
Showing 23 changed files with 715 additions and 500 deletions.
22 changes: 20 additions & 2 deletions plugins/wasmembed/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
project(wasmembed)

set(CMAKE_CXX_STANDARD 20)

if(WASMEMBED)
ADD_PLUGIN(wasmembed)
if(MAKE_WASMEMBED)

add_subdirectory(secure-enclave)
find_path(WASMTIME_CPP_API_INCLUDE_DIRS "wasmtime-cpp-api/wasmtime.hh"
PATHS ${VCPKG_INSTALLED_DIR}/${VCPKG_TARGET_TRIPLET}
)
if (WIN32)
find_library(WASMTIME_LIB NAMES wasmtime.dll
PATHS ${VCPKG_INSTALLED_DIR}/${VCPKG_TARGET_TRIPLET}
)
else()
find_library(WASMTIME_LIB NAMES wasmtime
PATHS ${VCPKG_INSTALLED_DIR}/${VCPKG_TARGET_TRIPLET}
)
endif()

include_directories(
${WASMTIME_CPP_API_INCLUDE_DIRS}/wasmtime-c-api
${WASMTIME_CPP_API_INCLUDE_DIRS}/wasmtime-cpp-api
./../../common/thorhelper
./../../dali/base
./../../rtl/eclrtl
Expand All @@ -21,13 +36,16 @@ if(WASMEMBED)

add_library(wasmembed SHARED
wasmembed.cpp
secure-enclave.cpp
abi.cpp
util.cpp
)

target_link_libraries(wasmembed
${WASMTIME_LIB}
roxiemem
eclrtl
jlib
secure-enclave
)

install(
Expand Down
268 changes: 268 additions & 0 deletions plugins/wasmembed/abi.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
/*
See: https://github.com/WebAssembly/component-model/blob/main/design/mvp/CanonicalABI.md
https://github.com/WebAssembly/component-model/blob/main/design/mvp/canonical-abi/definitions.py
*/

#include "abi.hpp"
#include "jexcept.hpp"

#include <cmath>
#include <fstream>
#include <iostream>
#include <sstream>
#include <vector>

auto UTF16_TAG = 1 << 31;

//
/* canonical despecialize (python) -------------------------------------------------------------
def despecialize(t):
match t:
case Tuple(ts) : return Record([ Field(str(i), t) for i,t in enumerate(ts) ])
case Union(ts) : return Variant([ Case(str(i), t) for i,t in enumerate(ts) ])
case Enum(labels) : return Variant([ Case(l, None) for l in labels ])
case Option(t) : return Variant([ Case("none", None), Case("some", t) ])
case Result(ok, error) : return Variant([ Case("ok", ok), Case("error", error) ])
case _ : return t
*/

// template <typename T>
// wasmtime::ValType despecialize(const T<wasmtime::Val> &t)
// {
// switch (t.kind())
// {
// case wasmtime::ValKind::I32:
// case wasmtime::ValKind::I64:
// case wasmtime::ValKind::F32:
// case wasmtime::ValKind::F64:
// case wasmtime::ValKind::V128:
// return t.kind();
// default:
// return wasmtime::ValType::i32();
// }
// }

/* canonical alignment (python) -------------------------------------------------------------
def alignment(t):
match despecialize(t):
case Bool() : return 1
case S8() | U8() : return 1
case S16() | U16() : return 2
case S32() | U32() : return 4
case S64() | U64() : return 8
case Float32() : return 4
case Float64() : return 8
case Char() : return 4
case String() | List(_) : return 4
case Record(fields) : return alignment_record(fields)
case Variant(cases) : return alignment_variant(cases)
case Flags(labels) : return alignment_flags(labels)
case Own(_) | Borrow(_) : return 4
*/

// int alignment(const wasmtime::ValType &t)
// {
// switch (t.kind())
// {
// case wasmtime::ValKind::I32:
// case wasmtime::ValKind::F32:
// return 4;
// case wasmtime::ValKind::I64:
// case wasmtime::ValKind::F64:
// return 8;
// case wasmtime::ValKind::V128:
// return 16;
// default:
// return 1;
// }
// }

/* canonical align_to (python) -------------------------------------------------------------
def align_to(ptr, alignment):
return math.ceil(ptr / alignment) * alignment
*/

uint32_t align_to(uint32_t ptr, uint32_t alignment)
{
return std::ceil(ptr / alignment) * alignment;
}

// loading ---

/* canonical load_int (python) -------------------------------------------------------------
def load_int(cx, ptr, nbytes, signed = False):
return int.from_bytes(cx.opts.memory[ptr : ptr+nbytes], 'little', signed=signed)
*/

#include <stdio.h>
#include <stdint.h>

// int load_int(const wasmtime::Span<uint8_t> &cx, int32_t *ptr, int nbytes, int signed)
// {
// int result = 0;
// for (int i = 0; i < nbytes; i++) {
// result += ptr[i] << (8*i);
// }
// if (signed) {
// result += (ptr[nbytes-1] & 0x80) ? -1 << (8*nbytes) : 0;
// }
// return result;
// }

template <typename T>
T load_int(const wasmtime::Span<uint8_t> &data, int32_t ptr)
{
T retVal = 0;
auto nbytes = sizeof(retVal);
for (int i = 0; i < nbytes; ++i)
{
uint8_t b = data[ptr + i];
retVal += b << (i * 8);
}
if (std::is_signed<T>::value)
{
retVal += (data[ptr + nbytes - 1] & 0x80) ? -1 << (8 * nbytes) : 0;
}
return retVal;
}

/* canonical load_string_from_range (python) -------------------------------------------------------------
def load_string_from_range(cx, ptr, tagged_code_units):
match cx.opts.string_encoding:
case 'utf8':
alignment = 1
byte_length = tagged_code_units
encoding = 'utf-8'
case 'utf16':
alignment = 2
byte_length = 2 * tagged_code_units
encoding = 'utf-16-le'
case 'latin1+utf16':
alignment = 2
if bool(tagged_code_units & UTF16_TAG):
byte_length = 2 * (tagged_code_units ^ UTF16_TAG)
encoding = 'utf-16-le'
else:
byte_length = tagged_code_units
encoding = 'latin-1'
trap_if(ptr != align_to(ptr, alignment))
trap_if(ptr + byte_length > len(cx.opts.memory))
try:
s = cx.opts.memory[ptr : ptr+byte_length].decode(encoding)
except UnicodeError:
trap()
return (s, cx.opts.string_encoding, tagged_code_units)
*/

// More: Not currently available from the wasmtime::context object, see https://github.com/bytecodealliance/wasmtime/issues/6719
std::string global_encoding = "utf8";

std::pair<uint32_t /*ptr*/, uint32_t /*byte length*/> load_string_from_range(const wasmtime::Span<uint8_t> &data, uint32_t ptr, uint32_t tagged_code_units)
{
std::string encoding = "utf-8";
uint32_t byte_length = tagged_code_units;
uint32_t alignment = 1;
if (global_encoding.compare("utf8") == 0)
{
alignment = 1;
byte_length = tagged_code_units;
encoding = "utf-8";
}
else if (global_encoding.compare("utf16") == 0)
{
alignment = 2;
byte_length = 2 * tagged_code_units;
encoding = "utf-16-le";
}
else if (global_encoding.compare("latin1+utf16") == 0)
{
alignment = 2;
if (tagged_code_units & UTF16_TAG)
{
byte_length = 2 * (tagged_code_units ^ UTF16_TAG);
encoding = "utf-16-le";
}
else
{
byte_length = tagged_code_units;
encoding = "latin-1";
}
}

if (ptr != align_to(ptr, alignment))
{
throw makeStringException(0, "Invalid alignment");
}

if (ptr + byte_length > data.size())
{
throw makeStringException(1, "Out of bounds");
}

return std::make_pair(ptr, byte_length);
}

/* canonical load_string (python) -------------------------------------------------------------
def load_string(cx, ptr):
begin = load_int(cx, ptr, 4)
tagged_code_units = load_int(cx, ptr + 4, 4)
return load_string_from_range(cx, begin, tagged_code_units)
*/
std::pair<uint32_t /*ptr*/, uint32_t /*byte length*/> load_string(const wasmtime::Span<uint8_t> &data, uint32_t ptr)
{
uint32_t begin = load_int<uint32_t>(data, ptr);
uint32_t tagged_code_units = load_int<uint32_t>(data, ptr + 4);
return load_string_from_range(data, begin, tagged_code_units);
}

/* canonical load_list_from_range (python) -------------------------------------------------------------
def load_list_from_range(cx, ptr, length, elem_type):
trap_if(ptr != align_to(ptr, alignment(elem_type)))
trap_if(ptr + length * size(elem_type) > len(cx.opts.memory))
a = []
for i in range(length):
a.append(load(cx, ptr + i * size(elem_type), elem_type))
return a
*/

template <typename T>
std::vector<T> load_list_from_range(const wasmtime::Span<uint8_t> &data, uint32_t ptr, uint32_t length)
{
if (ptr != align_to(ptr, alignment(T{})))
throw makeStringException(2, "Pointer is not aligned");
if (ptr + length * sizeof(T) > data.size())
throw makeStringException(1, "Out of bounds access");
std::vector<T> a;
for (uint32_t i = 0; i < length; i++)
{
a.push_back(load<T>(data, ptr + i * sizeof(T)));
}
return a;
}

/* canonical load_list (python) -------------------------------------------------------------
def load_list(cx, ptr, elem_type):
begin = load_int(cx, ptr, 4)
length = load_int(cx, ptr + 4, 4)
return load_list_from_range(cx, begin, length, elem_type)
*/

// Storing ---
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit d2e4d48

Please sign in to comment.