From 1b3477c922f69b49224a413043d2b94587f86f79 Mon Sep 17 00:00:00 2001 From: Marcus Pousette Date: Wed, 17 Jan 2024 10:41:11 +0100 Subject: [PATCH] feat: custom class serialization --- README.md | 40 ++++++++- package.json | 2 +- src/__tests__/index.test.ts | 157 +++++++++++++++++++++++++++++++----- src/binary.ts | 7 ++ src/index.ts | 59 ++++++++++---- src/types.ts | 1 + 6 files changed, 228 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index b3412677..885acb1f 100644 --- a/README.md +++ b/README.md @@ -227,6 +227,8 @@ class TestStruct { ``` **Custom serialization and deserialization** +Override how one field is handled + ```typescript class TestStruct { @@ -245,12 +247,48 @@ class TestStruct { } } -validate(TestStruct); const serialized = serialize(new TestStruct(3)); const deserialied = deserialize(serialized, TestStruct); expect(deserialied.number).toEqual(3); ``` +Override how one class is serialized + +```typescript +import { serializer } from '@dao-xyz/borsh' +class TestStruct { + + @field({type: 'u8'}) + public number: number; + + constructor(number: number) { + this.number = number; + } + + cache: Uint8Array | undefined; + + @serializer() + override(writer: BinaryWriter, serialize: (obj: this) => Uint8Array) { + if (this.cache) { + writer.set(this.cache) + } + else { + this.cache = serialize(this) + writer.set(this.cache) + } + } +} + +const obj = new TestStruct(3); +const serialized = serialize(obj); +const deserialied = deserialize(serialized, TestStruct); +expect(deserialied.number).toEqual(3); +expect(obj.cache).toBeDefined() +``` + + + + ## Inheritance Schema generation is supported if deserialization is deterministic. In other words, all classes extending some super class needs to use discriminators/variants of the same type. diff --git a/package.json b/package.json index 39affc35..2d55cf35 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@dao-xyz/borsh", - "version": "5.1.8", + "version": "5.2.0", "readme": "README.md", "homepage": "https://github.com/dao-xyz/borsh-ts#README", "description": "Binary Object Representation Serializer for Hashing simplified with decorators", diff --git a/src/__tests__/index.test.ts b/src/__tests__/index.test.ts index 46900e4d..609e3ead 100644 --- a/src/__tests__/index.test.ts +++ b/src/__tests__/index.test.ts @@ -13,6 +13,7 @@ import { getSchema, BorshError, string, + serializer, } from "../index.js"; import crypto from "crypto"; @@ -1584,34 +1585,146 @@ describe("string", () => { }); }); -describe("bool", () => { - test("field bolean", () => { +describe("options", () => { + class Test { + @field({ type: "u8" }) + number: number; + constructor(number: number) { + this.number = number; + } + } + test("pass writer", () => { + const writer = new BinaryWriter(); + writer.u8(1); + expect(new Uint8Array(serialize(new Test(123), writer))).toEqual( + new Uint8Array([1, 123]) + ); + }); +}); + +describe("override", () => { + describe("serializer", () => { class TestStruct { - @field({ type: "bool" }) - public a: boolean; + @serializer() + override(writer: BinaryWriter) { + writer.u8(1); + } + } - constructor(a: boolean) { - this.a = a; + class TestStructMixed { + @field({ type: TestStruct }) + nested: TestStruct; + + @field({ type: "u8" }) + number: number; + + cached: Uint8Array; + + constructor(number: number) { + this.nested = new TestStruct(); + this.number = number; + } + + @serializer() + override(writer: BinaryWriter, serialize: (obj: this) => Uint8Array) { + if (this.cached) { + writer.set(this.cached); + } else { + this.cached = serialize(this); + writer.set(this.cached); + } } } - validate(TestStruct); - const expectedResult: StructKind = new StructKind({ - fields: [ - { - key: "a", - type: "bool", - }, - ], + + class TestStructMixedNested { + @field({ type: TestStructMixed }) + nested: TestStructMixed; + + @field({ type: "u8" }) + number: number; + + cached: Uint8Array; + + constructor(number: number) { + this.nested = new TestStructMixed(number); + this.number = number; + } + + @serializer() + override(writer: BinaryWriter, serialize: (obj: this) => Uint8Array) { + if (this.cached) { + writer.set(this.cached); + } else { + this.cached = serialize(this); + writer.set(this.cached); + } + } + } + + class TestStructNested { + @field({ type: TestStruct }) + nested: TestStruct; + + @field({ type: "u8" }) + number: number; + constructor() { + this.nested = new TestStruct(); + this.number = 2; + } + } + + class TestBaseClass { + @serializer() + override(writer: BinaryWriter) { + writer.u8(3); + } + } + @variant(2) + class TestStructInherited extends TestBaseClass { + @field({ type: TestStruct }) + struct: TestStruct; + + @field({ type: "u8" }) + number: number; + constructor() { + super(); + this.struct = new TestStruct(); + this.number = 0; + } + } + + test("struct", () => { + expect(new Uint8Array(serialize(new TestStruct()))).toEqual( + new Uint8Array([1]) + ); }); - expect(getSchema(TestStruct)).toEqual(expectedResult); - const buf = serialize(new TestStruct(true)); - expect(new Uint8Array(buf)).toEqual(new Uint8Array([1])); - const deserializedSome = deserialize(new Uint8Array(buf), TestStruct); - expect(deserializedSome.a).toEqual(true); - }); -}); + test("recursive call", () => { + const obj = new TestStructMixed(2); + expect(new Uint8Array(serialize(obj))).toEqual(new Uint8Array([1, 2])); + expect(new Uint8Array(obj.cached)).toEqual(new Uint8Array([1, 2])); + expect(new Uint8Array(serialize(obj))).toEqual(new Uint8Array([1, 2])); + }); + test("recursive call nested", () => { + const obj = new TestStructMixedNested(2); + expect(new Uint8Array(serialize(obj))).toEqual(new Uint8Array([1, 2, 2])); + expect(new Uint8Array(obj.cached)).toEqual(new Uint8Array([1, 2, 2])); + expect(new Uint8Array(obj.nested.cached)).toEqual(new Uint8Array([1, 2])); -describe("override", () => { + expect(new Uint8Array(serialize(obj))).toEqual(new Uint8Array([1, 2, 2])); + }); + + test("nested", () => { + expect(new Uint8Array(serialize(new TestStructNested()))).toEqual( + new Uint8Array([1, 2]) + ); + }); + + test("inherited", () => { + expect(new Uint8Array(serialize(new TestStructInherited()))).toEqual( + new Uint8Array([3, 2, 1, 0]) + ); + }); + }); test("serialize/deserialize", () => { /** * Serialize field with custom serializer and deserializer diff --git a/src/binary.ts b/src/binary.ts index a73b58f0..e5e7aae0 100644 --- a/src/binary.ts +++ b/src/binary.ts @@ -186,6 +186,13 @@ export class BinaryWriter { writer.totalSize += lengthSize + len; } + public set(array: Uint8Array) { + let offset = this.totalSize; + this._writes = this._writes.next = () => { + this._buf.set(array, offset); + } + this.totalSize += array.length + } public uint8Array(array: Uint8Array) { return BinaryWriter.uint8Array(array, this) diff --git a/src/index.ts b/src/index.ts index 94e074da..a678b4e5 100644 --- a/src/index.ts +++ b/src/index.ts @@ -37,10 +37,15 @@ const PROTOTYPE_SCHEMA_OFFSET = PROTOTYPE_DESERIALIZATION_HANDLER_OFFSET + PROTO * @returns bytes */ export function serialize( - obj: any + obj: any, + writer: BinaryWriter = new BinaryWriter() ): Uint8Array { - const writer = new BinaryWriter(); - (obj.constructor._borsh_serialize || (obj.constructor._borsh_serialize = serializeStruct(obj.constructor)))(obj, writer) + (obj.constructor._borsh_serialize || (obj.constructor._borsh_serialize = serializeStruct(obj.constructor, true)))(obj, writer) + return writer.finalize(); +} + +function recursiveSerialize(obj: any, writer: BinaryWriter = new BinaryWriter()) { + (obj.constructor._borsh_serialize_recursive || (obj.constructor._borsh_serialize_recursive = serializeStruct(obj.constructor, false)))(obj, writer) return writer.finalize(); } @@ -196,7 +201,8 @@ function serializeField( function serializeStruct( - ctor: Function + ctor: Function, + allowCustomSerializer = true ) { let handle: (obj: any, writer: BinaryWriter) => any = undefined; var i = 0; @@ -236,20 +242,29 @@ function serializeStruct( } : (_obj, writer) => writer.string(index); } } - for (const field of schema.fields) { + if (allowCustomSerializer && schema.serializer) { let prev = handle; - const fieldHandle = serializeField(field.key, field.type); - if (prev) { - handle = (obj, writer) => { - prev(obj, writer); - fieldHandle(obj[field.key], writer) + handle = prev ? (obj, writer) => { + prev(obj, writer); + schema.serializer(obj, writer, (obj: any) => recursiveSerialize(obj)) + } : (obj, writer) => schema.serializer(obj, writer, (obj: any) => recursiveSerialize(obj)) + } + else { + for (const field of schema.fields) { + let prev = handle; + const fieldHandle = serializeField(field.key, field.type); + if (prev) { + handle = (obj, writer) => { + prev(obj, writer); + fieldHandle(obj[field.key], writer) + } + } + else { + handle = (obj, writer) => fieldHandle(obj[field.key], writer) } } - else { - handle = (obj, writer) => fieldHandle(obj[field.key], writer) - } - } + } else if (once && !getDependencies(ctor, i)?.length) { @@ -740,6 +755,22 @@ export function field(properties: SimpleField | CustomField) { }; } + +/** + * @param properties, the properties of the field mapping to schema + * @returns + */ +export function serializer() { + return function (target: any, propertyKey: string) { + const offset = getOffset(target.constructor); + const schemas = getOrCreateStructMeta(target.constructor, offset); + schemas.serializer = (obj, writer, serialize) => obj[propertyKey](writer, serialize) + }; +} + + + + /** * @param clazzes * @param validate, run validation? diff --git a/src/types.ts b/src/types.ts index 7083a587..caefaf3f 100644 --- a/src/types.ts +++ b/src/types.ts @@ -129,6 +129,7 @@ export interface Field { export class StructKind { variant?: number | number[] | string + serializer?: (any: any, writer: BinaryWriter, serialize: (obj: any) => Uint8Array) => void fields: Field[]; constructor(properties?: { variant?: number | number[] | string, fields: Field[] }) { if (properties) {