From b53964ee6334fc3d260c8f4760745c86ffe2270d Mon Sep 17 00:00:00 2001 From: Konstantinos Gaitanis Date: Fri, 9 Aug 2024 17:33:18 +0300 Subject: [PATCH] fix optional encoding --- .../CandidValue+CustomStringConvertible.swift | 2 +- Sources/Candid/Encoding/CandidEncoder.swift | 101 +++++++++++++----- Tests/CandidTests/CandidEncoderTests.swift | 23 ++-- 3 files changed, 93 insertions(+), 33 deletions(-) diff --git a/Sources/Candid/CandidValue/CandidValue+CustomStringConvertible.swift b/Sources/Candid/CandidValue/CandidValue+CustomStringConvertible.swift index a69a4e4..b6d278b 100644 --- a/Sources/Candid/CandidValue/CandidValue+CustomStringConvertible.swift +++ b/Sources/Candid/CandidValue/CandidValue+CustomStringConvertible.swift @@ -76,7 +76,7 @@ extension CandidKeyedValue: CustomStringConvertible { extension CandidVariant: CustomStringConvertible { public var description: String { - return "\(value)" + return "\(key.stringValue ?? String(key.intValue)): \(value)" } } diff --git a/Sources/Candid/Encoding/CandidEncoder.swift b/Sources/Candid/Encoding/CandidEncoder.swift index 89c7a20..1b509e1 100644 --- a/Sources/Candid/Encoding/CandidEncoder.swift +++ b/Sources/Candid/Encoding/CandidEncoder.swift @@ -87,37 +87,54 @@ private class CandidValueEncoder: Encoder { guard let record = candidValue.recordValue else { throw EncodingError.invalidValue(T.self, .init(codingPath: codingPath, debugDescription: "Enums should be encoded to records")) } - let variant = try convertRecordToVariant(record) + let variant = try convertRecordToVariant(record, using: mirror) encodingValue = CandidSingleEncodingValue(variant) } - addOptionals(using: mirror) + encodingValue = CandidSingleEncodingValue(addOptionals(to: encodingValue.candidValue, using: mirror)) } } - private func addOptionals(using mirror: Mirror) { + private func addOptionals(to value: CandidValue, using mirror: Mirror) -> CandidValue { // The Swift encoding system skips the optional wrapper when a value is present, trying to directly encode the contained value // We add optional wrappers where needed according to the Type defined in the mirror - if let record = encodingValue.candidValue.recordValue { + switch value { + case .record(let record): let newRecordItems = addOptionals(in: record.candidSortedItems, using: mirror, intMarker: "_") - encodingValue = CandidSingleEncodingValue(.record(newRecordItems)) + return .record(newRecordItems) - } else if let variant = encodingValue.candidValue.variantValue, - let child = mirror.children.first { - if variant.value != .null && variant.value.candidType.primitiveType != .record { - // single value enum case - if let _ = child.value as? any CandidOptionalMarker { - encodingValue = CandidSingleEncodingValue(.variant(.init(variant.key, .option(variant.value)))) - } - - } else if let associatedValues = variant.value.recordValue { + case .variant(let variant): + if let child = mirror.children.first { let associatedMirror = Mirror(reflecting: child.value) - let newRecordItems = addOptionals(in: associatedValues.candidSortedItems, using: associatedMirror, intMarker: ".") - encodingValue = CandidSingleEncodingValue(.variant(.init(variant.key, .record(newRecordItems)))) + if let associatedValues = variant.value.recordValue { + let newRecordItems = addOptionals(in: associatedValues.candidSortedItems, using: associatedMirror, intMarker: ".") + return .variant(.init(variant.key, .record(newRecordItems))) + } else if associatedMirror.displayStyle == .tuple { + // case of single named argument in an enum + let tupleMirror = Mirror(reflecting: associatedMirror.children.first!) + if tupleMirror.children.count == 2, + let label = tupleMirror.descendant("label") as? String, + let value = tupleMirror.descendant("value"), + let optional = value as? any CandidOptionalMarker { + if let _ = optional.value { + return .variant(.init(variant.key, .record([label: .option(variant.value)]))) + } + return .variant(.init(variant.key, .record([label: .option(candidType(optional.wrappedType))]))) + } + + } else if let optional = child.value as? any CandidOptionalMarker { + // single value enum case + if let _ = optional.value { + return .variant(.init(variant.key, .option(variant.value))) + } + return .variant(.init(variant.key, .option(candidType(optional.wrappedType)))) + } } + default: break } + return value } - private func convertRecordToVariant(_ record: CandidRecord) throws -> CandidValue { + private func convertRecordToVariant(_ record: CandidRecord, using mirror: Mirror) throws -> CandidValue { guard record.candidSortedItems.count == 1, let value = record.candidSortedItems.first, let associatedValues = value.value.recordValue else { @@ -125,10 +142,17 @@ private class CandidValueEncoder: Encoder { } let variantValueKey = value.key if associatedValues.candidSortedItems.isEmpty { +// if let child = mirror.children.first, +// let optional = child.value as? any CandidOptionalMarker { +// return .variant(.init(variantValueKey, .option(candidType(optional.wrappedType)))) +// } return .variant(CandidKeyedValue(variantValueKey)) + } else if associatedValues.candidSortedItems.count == 1, - let associatedValue = associatedValues.candidSortedItems.first { + let associatedValue = associatedValues.candidSortedItems.first, + associatedValue.key.stringValue == "_0" { return .variant(CandidKeyedValue(variantValueKey, associatedValue.value)) + } else { let variantValues = try associatedValues.candidSortedItems.map { return CandidKeyedValue(try $0.key.toVariantKey(), $0.value) @@ -138,13 +162,21 @@ private class CandidValueEncoder: Encoder { } private func addOptionals(in keyedValues: [CandidKeyedValue], using mirror: Mirror, intMarker: String) -> [CandidKeyedValue] { - var newRecordItems: [CandidKeyedValue] = [] - for keyedItem in keyedValues { - if let child = mirror.children.first(where: { $0.label == keyedItem.key.stringValue ?? "\(intMarker)\(keyedItem.key.intValue)" || CandidKey.candidHash($0.label ?? "?") == keyedItem.key.intValue }), - let _ = child.value as? any CandidOptionalMarker { - newRecordItems.append(.init(keyedItem.key, .option(keyedItem.value))) + var newRecordItems = keyedValues + for child in mirror.children { + if let existing = keyedValues.first(where: { + $0.key.stringValue == child.label || + "\(intMarker)\($0.key.intValue)" == child.label || + $0.key.intValue == CandidKey.candidHash(child.label ?? "?") + }) { + if child.value is any CandidOptionalMarker { + newRecordItems.replace(existing.key, with: .option(existing.value)) + } } else { - newRecordItems.append(keyedItem) + if let optional = child.value as? any CandidOptionalMarker, + let label = child.label { + newRecordItems.append(.init(label, .option(candidType(optional.wrappedType)))) + } } } return newRecordItems @@ -259,7 +291,7 @@ private class CandidKeyedEncodingValue: CandidEncodingValue where Key: Codi func set(_ value: CandidValue, for key: Key) { set(CandidSingleEncodingValue(value), for: key) } private func candidKey(for key: Key) -> CandidKey { - if let int = key.intValue { + if let int = key.intValue, int != CandidKey.candidHash(key.stringValue) { return CandidKey(int) } return CandidKey(key.stringValue) @@ -433,3 +465,22 @@ private extension CandidKey { static let unnamedEnumRegex = try! Regex(#"_(?'number'\d+)"#) } + +private extension Array { + mutating func replace(_ key: CandidKey, with newValue: CandidValue) { + replace(key.intValue, with: newValue) + } + + mutating func replace(_ stringKey: String, with newValue: CandidValue) { + replace(CandidKey.candidHash(stringKey), with: newValue) + } + + mutating func replace(_ intKey: Int, with newValue: CandidValue) { + guard let index = firstIndex(where: { $0.key.intValue == intKey }) else { + return + } + let key = self[index].key + remove(at: index) + insert(.init(key, newValue), at: index) + } +} diff --git a/Tests/CandidTests/CandidEncoderTests.swift b/Tests/CandidTests/CandidEncoderTests.swift index de2ed0f..9e0adad 100644 --- a/Tests/CandidTests/CandidEncoderTests.swift +++ b/Tests/CandidTests/CandidEncoderTests.swift @@ -81,6 +81,7 @@ private let decodingTestVectors: [(any Codable, CandidValue, any Decodable.Type) (TestDataRecord3(data: Data([1])), .record(["data": .option(.option(.blob(Data([1]))))]), TestDataRecord3.self), (TestDataRecord(vector: [.init(data: Data([1]))]), .record(["vector": try! .vector([.record(["data": .option(.blob(Data([1])))])])]), TestDataRecord.self), (TestRecord(a: 1, b: 2), .record([97: .natural8(1), 98: .integer64(2)]), TestRecord.self), + (TestDataRecord2(data: nil), .record(["data": .option(.blob)]), TestDataRecord2.self), (TestEnum.a, .variant(.init("a", .null)), TestEnum.self), (TestEnum.b(2), .variant(.init("b", .option(.natural8(2)))), TestEnum.self), @@ -90,6 +91,7 @@ private let decodingTestVectors: [(any Codable, CandidValue, any Decodable.Type) (TestEnum.e(a: 1, b: 2), .variant(.init(101, .record([97:.option(.natural8(1)), 98: .natural16(2)]))), TestEnum.self), (TestEnum.f(1, 2), .variant(.init("f", .record([0: .option(.natural8(1)), 1: .natural16(2)]))), TestEnum.self), (TestEnum.g(SingleValueRecord(value: 2)), .variant(.init("g", .record(["value": .integer8(2)]))), TestEnum.self), + (TestEnum.h(a: 2), .variant(.init("h", .record(["a": .option(.natural8(2))]))), TestEnum.self), ] @@ -150,8 +152,8 @@ private let encodingTestVectors: [(any Encodable, CandidValue)] = [ ([UInt8]([8, 2]), try! .vector([.natural8(8), .natural8(2)])), ([UInt8?]([8, nil, 9]), try! .vector([.option(.natural8(8)), .option(.natural8), .option(.natural8(9))])), - (["a":0], .record(["a": .integer64(0)])), - (["a":0, "b": UInt8(8)], .record(["a": .natural8(0), "b": .natural8(8)])), + (["a": 0], .record(["a": .integer64(0)])), + (["a": 0, "b": UInt8(8)], .record(["a": .natural8(0), "b": .natural8(8)])), (TestRecord(a: 1, b: 2), .record(["a": .natural8(1), "b": .integer64(2)])), (TestRecord2(a: [1, nil]), .record(["a": try! .vector([.option(.natural8(1)), .option(.natural8)])])), (TestRecord3(record: TestRecord(a: 1, b: 2), records2: [TestRecord2(a: [1, nil])]), .record([ @@ -164,18 +166,23 @@ private let encodingTestVectors: [(any Encodable, CandidValue)] = [ (TestDataRecord(vector: [.init(data: Data([1]))]), .record(["vector": try! .vector([.record(["data": .option(.blob(Data([1])))])])])), // this fails because we don't correctly identify the recursive type. //(TestRecursiveRecord(a: []), .record(["a": .vector(.record())])), - + (TestDataRecord2(data: nil), .record(["data": .option(.blob)])), + (TestEnum.a, .variant(.init("a", .null))), (TestEnum.b(2), .variant(.init("b", .option(.natural8(2))))), + (TestEnum.b(nil), .variant(.init("b", .option(.natural8)))), (TestEnum.c(.a), .variant(.init("c", .variant(.init("a"))))), (TestEnum.d([1,2]), .variant(.init("d", try! .vector([.natural8(1), .natural8(2)])))), (TestEnum.e(a: 1, b: 2), .variant(.init("e", .record(["a":.option(.natural8(1)), "b": .natural16(2)])))), + (TestEnum.e(a: nil, b: 2), .variant(.init("e", .record(["a":.option(.natural8), "b": .natural16(2)])))), (TestEnum.f(1, 2), .variant(.init("f", .record([0: .option(.natural8(1)), 1: .natural16(2)])))), (TestEnum.g(SingleValueRecord(value: 2)), .variant(.init("g", .record(["value": .integer8(2)])))), - - (CandidFunction(signature: .init([CandidType](), []), method: nil), .function(CandidFunction(signature: .init([CandidType](), []), method: nil))), - (try! CandidPrincipal("aaaaa-aa"), try! .principal("aaaaa-aa")), - (CandidService(principal: nil, signature: .init([])), .service(CandidService(principal: nil, signature: .init([])))), + (TestEnum.h(a: 2), .variant(.init("h", .record(["a": .option(.natural8(2))])))), + (TestEnum.h(a: nil), .variant(.init("h", .record(["a": .option(.natural8)])))), +// +// (CandidFunction(signature: .init([CandidType](), []), method: nil), .function(CandidFunction(signature: .init([CandidType](), []), method: nil))), +// (try! CandidPrincipal("aaaaa-aa"), try! .principal("aaaaa-aa")), +// (CandidService(principal: nil, signature: .init([])), .service(CandidService(principal: nil, signature: .init([])))), ] private struct TestRecord: Codable { @@ -226,6 +233,7 @@ private indirect enum TestEnum: Codable { case e(a: UInt8?, b: UInt16) case f(UInt8?, UInt16) case g(SingleValueRecord) + case h(a: UInt8?) enum CodingKeys: Int, CodingKey { case a = 97 @@ -235,6 +243,7 @@ private indirect enum TestEnum: Codable { case e = 101 case f = 102 case g = 103 + case h = 104 } enum ECodingKeys: Int, CodingKey {