Skip to content

Commit

Permalink
fix optional encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
Konstantinos Gaitanis committed Aug 9, 2024
1 parent 365da6d commit b53964e
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ extension CandidKeyedValue: CustomStringConvertible {

extension CandidVariant: CustomStringConvertible {
public var description: String {
return "\(value)"
return "\(key.stringValue ?? String(key.intValue)): \(value)"
}
}

Expand Down
101 changes: 76 additions & 25 deletions Sources/Candid/Encoding/CandidEncoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,48 +87,72 @@ private class CandidValueEncoder: Encoder {
guard let record = candidValue.recordValue else {
throw EncodingError.invalidValue(T.self, .init(codingPath: codingPath, debugDescription: "Enums should be encoded to records"))
}
let variant = try convertRecordToVariant(record)
let variant = try convertRecordToVariant(record, using: mirror)
encodingValue = CandidSingleEncodingValue(variant)
}
addOptionals(using: mirror)
encodingValue = CandidSingleEncodingValue(addOptionals(to: encodingValue.candidValue, using: mirror))
}
}

private func addOptionals(using mirror: Mirror) {
private func addOptionals(to value: CandidValue, using mirror: Mirror) -> CandidValue {
// The Swift encoding system skips the optional wrapper when a value is present, trying to directly encode the contained value
// We add optional wrappers where needed according to the Type defined in the mirror
if let record = encodingValue.candidValue.recordValue {
switch value {
case .record(let record):
let newRecordItems = addOptionals(in: record.candidSortedItems, using: mirror, intMarker: "_")
encodingValue = CandidSingleEncodingValue(.record(newRecordItems))
return .record(newRecordItems)

} else if let variant = encodingValue.candidValue.variantValue,
let child = mirror.children.first {
if variant.value != .null && variant.value.candidType.primitiveType != .record {
// single value enum case
if let _ = child.value as? any CandidOptionalMarker {
encodingValue = CandidSingleEncodingValue(.variant(.init(variant.key, .option(variant.value))))
}

} else if let associatedValues = variant.value.recordValue {
case .variant(let variant):
if let child = mirror.children.first {
let associatedMirror = Mirror(reflecting: child.value)
let newRecordItems = addOptionals(in: associatedValues.candidSortedItems, using: associatedMirror, intMarker: ".")
encodingValue = CandidSingleEncodingValue(.variant(.init(variant.key, .record(newRecordItems))))
if let associatedValues = variant.value.recordValue {
let newRecordItems = addOptionals(in: associatedValues.candidSortedItems, using: associatedMirror, intMarker: ".")
return .variant(.init(variant.key, .record(newRecordItems)))
} else if associatedMirror.displayStyle == .tuple {
// case of single named argument in an enum
let tupleMirror = Mirror(reflecting: associatedMirror.children.first!)
if tupleMirror.children.count == 2,
let label = tupleMirror.descendant("label") as? String,
let value = tupleMirror.descendant("value"),
let optional = value as? any CandidOptionalMarker {
if let _ = optional.value {
return .variant(.init(variant.key, .record([label: .option(variant.value)])))
}
return .variant(.init(variant.key, .record([label: .option(candidType(optional.wrappedType))])))
}

} else if let optional = child.value as? any CandidOptionalMarker {
// single value enum case
if let _ = optional.value {
return .variant(.init(variant.key, .option(variant.value)))
}
return .variant(.init(variant.key, .option(candidType(optional.wrappedType))))
}
}
default: break
}
return value
}

private func convertRecordToVariant(_ record: CandidRecord) throws -> CandidValue {
private func convertRecordToVariant(_ record: CandidRecord, using mirror: Mirror) throws -> CandidValue {
guard record.candidSortedItems.count == 1,
let value = record.candidSortedItems.first,
let associatedValues = value.value.recordValue else {
throw CandidEncoderError.enumeratorCanNotBeEncoded
}
let variantValueKey = value.key
if associatedValues.candidSortedItems.isEmpty {
// if let child = mirror.children.first,
// let optional = child.value as? any CandidOptionalMarker {
// return .variant(.init(variantValueKey, .option(candidType(optional.wrappedType))))
// }
return .variant(CandidKeyedValue(variantValueKey))

} else if associatedValues.candidSortedItems.count == 1,
let associatedValue = associatedValues.candidSortedItems.first {
let associatedValue = associatedValues.candidSortedItems.first,
associatedValue.key.stringValue == "_0" {
return .variant(CandidKeyedValue(variantValueKey, associatedValue.value))

} else {
let variantValues = try associatedValues.candidSortedItems.map {
return CandidKeyedValue(try $0.key.toVariantKey(), $0.value)
Expand All @@ -138,13 +162,21 @@ private class CandidValueEncoder: Encoder {
}

private func addOptionals(in keyedValues: [CandidKeyedValue], using mirror: Mirror, intMarker: String) -> [CandidKeyedValue] {
var newRecordItems: [CandidKeyedValue] = []
for keyedItem in keyedValues {
if let child = mirror.children.first(where: { $0.label == keyedItem.key.stringValue ?? "\(intMarker)\(keyedItem.key.intValue)" || CandidKey.candidHash($0.label ?? "?") == keyedItem.key.intValue }),
let _ = child.value as? any CandidOptionalMarker {
newRecordItems.append(.init(keyedItem.key, .option(keyedItem.value)))
var newRecordItems = keyedValues
for child in mirror.children {
if let existing = keyedValues.first(where: {
$0.key.stringValue == child.label ||
"\(intMarker)\($0.key.intValue)" == child.label ||
$0.key.intValue == CandidKey.candidHash(child.label ?? "?")
}) {
if child.value is any CandidOptionalMarker {
newRecordItems.replace(existing.key, with: .option(existing.value))
}
} else {
newRecordItems.append(keyedItem)
if let optional = child.value as? any CandidOptionalMarker,
let label = child.label {
newRecordItems.append(.init(label, .option(candidType(optional.wrappedType))))
}
}
}
return newRecordItems
Expand Down Expand Up @@ -259,7 +291,7 @@ private class CandidKeyedEncodingValue<Key>: CandidEncodingValue where Key: Codi
func set(_ value: CandidValue, for key: Key) { set(CandidSingleEncodingValue(value), for: key) }

private func candidKey(for key: Key) -> CandidKey {
if let int = key.intValue {
if let int = key.intValue, int != CandidKey.candidHash(key.stringValue) {
return CandidKey(int)
}
return CandidKey(key.stringValue)
Expand Down Expand Up @@ -433,3 +465,22 @@ private extension CandidKey {

static let unnamedEnumRegex = try! Regex(#"_(?'number'\d+)"#)
}

private extension Array<CandidKeyedValue> {
mutating func replace(_ key: CandidKey, with newValue: CandidValue) {
replace(key.intValue, with: newValue)
}

mutating func replace(_ stringKey: String, with newValue: CandidValue) {
replace(CandidKey.candidHash(stringKey), with: newValue)
}

mutating func replace(_ intKey: Int, with newValue: CandidValue) {
guard let index = firstIndex(where: { $0.key.intValue == intKey }) else {
return
}
let key = self[index].key
remove(at: index)
insert(.init(key, newValue), at: index)
}
}
23 changes: 16 additions & 7 deletions Tests/CandidTests/CandidEncoderTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ private let decodingTestVectors: [(any Codable, CandidValue, any Decodable.Type)
(TestDataRecord3(data: Data([1])), .record(["data": .option(.option(.blob(Data([1]))))]), TestDataRecord3.self),
(TestDataRecord(vector: [.init(data: Data([1]))]), .record(["vector": try! .vector([.record(["data": .option(.blob(Data([1])))])])]), TestDataRecord.self),
(TestRecord(a: 1, b: 2), .record([97: .natural8(1), 98: .integer64(2)]), TestRecord.self),
(TestDataRecord2(data: nil), .record(["data": .option(.blob)]), TestDataRecord2.self),

(TestEnum.a, .variant(.init("a", .null)), TestEnum.self),
(TestEnum.b(2), .variant(.init("b", .option(.natural8(2)))), TestEnum.self),
Expand All @@ -90,6 +91,7 @@ private let decodingTestVectors: [(any Codable, CandidValue, any Decodable.Type)
(TestEnum.e(a: 1, b: 2), .variant(.init(101, .record([97:.option(.natural8(1)), 98: .natural16(2)]))), TestEnum.self),
(TestEnum.f(1, 2), .variant(.init("f", .record([0: .option(.natural8(1)), 1: .natural16(2)]))), TestEnum.self),
(TestEnum.g(SingleValueRecord(value: 2)), .variant(.init("g", .record(["value": .integer8(2)]))), TestEnum.self),
(TestEnum.h(a: 2), .variant(.init("h", .record(["a": .option(.natural8(2))]))), TestEnum.self),

]

Expand Down Expand Up @@ -150,8 +152,8 @@ private let encodingTestVectors: [(any Encodable, CandidValue)] = [
([UInt8]([8, 2]), try! .vector([.natural8(8), .natural8(2)])),
([UInt8?]([8, nil, 9]), try! .vector([.option(.natural8(8)), .option(.natural8), .option(.natural8(9))])),

(["a":0], .record(["a": .integer64(0)])),
(["a":0, "b": UInt8(8)], .record(["a": .natural8(0), "b": .natural8(8)])),
(["a": 0], .record(["a": .integer64(0)])),
(["a": 0, "b": UInt8(8)], .record(["a": .natural8(0), "b": .natural8(8)])),
(TestRecord(a: 1, b: 2), .record(["a": .natural8(1), "b": .integer64(2)])),
(TestRecord2(a: [1, nil]), .record(["a": try! .vector([.option(.natural8(1)), .option(.natural8)])])),
(TestRecord3(record: TestRecord(a: 1, b: 2), records2: [TestRecord2(a: [1, nil])]), .record([
Expand All @@ -164,18 +166,23 @@ private let encodingTestVectors: [(any Encodable, CandidValue)] = [
(TestDataRecord(vector: [.init(data: Data([1]))]), .record(["vector": try! .vector([.record(["data": .option(.blob(Data([1])))])])])),
// this fails because we don't correctly identify the recursive type.
//(TestRecursiveRecord(a: []), .record(["a": .vector(.record())])),

(TestDataRecord2(data: nil), .record(["data": .option(.blob)])),

(TestEnum.a, .variant(.init("a", .null))),
(TestEnum.b(2), .variant(.init("b", .option(.natural8(2))))),
(TestEnum.b(nil), .variant(.init("b", .option(.natural8)))),
(TestEnum.c(.a), .variant(.init("c", .variant(.init("a"))))),
(TestEnum.d([1,2]), .variant(.init("d", try! .vector([.natural8(1), .natural8(2)])))),
(TestEnum.e(a: 1, b: 2), .variant(.init("e", .record(["a":.option(.natural8(1)), "b": .natural16(2)])))),
(TestEnum.e(a: nil, b: 2), .variant(.init("e", .record(["a":.option(.natural8), "b": .natural16(2)])))),
(TestEnum.f(1, 2), .variant(.init("f", .record([0: .option(.natural8(1)), 1: .natural16(2)])))),
(TestEnum.g(SingleValueRecord(value: 2)), .variant(.init("g", .record(["value": .integer8(2)])))),

(CandidFunction(signature: .init([CandidType](), []), method: nil), .function(CandidFunction(signature: .init([CandidType](), []), method: nil))),
(try! CandidPrincipal("aaaaa-aa"), try! .principal("aaaaa-aa")),
(CandidService(principal: nil, signature: .init([])), .service(CandidService(principal: nil, signature: .init([])))),
(TestEnum.h(a: 2), .variant(.init("h", .record(["a": .option(.natural8(2))])))),
(TestEnum.h(a: nil), .variant(.init("h", .record(["a": .option(.natural8)])))),
//
// (CandidFunction(signature: .init([CandidType](), []), method: nil), .function(CandidFunction(signature: .init([CandidType](), []), method: nil))),
// (try! CandidPrincipal("aaaaa-aa"), try! .principal("aaaaa-aa")),
// (CandidService(principal: nil, signature: .init([])), .service(CandidService(principal: nil, signature: .init([])))),
]

private struct TestRecord: Codable {
Expand Down Expand Up @@ -226,6 +233,7 @@ private indirect enum TestEnum: Codable {
case e(a: UInt8?, b: UInt16)
case f(UInt8?, UInt16)
case g(SingleValueRecord)
case h(a: UInt8?)

enum CodingKeys: Int, CodingKey {
case a = 97
Expand All @@ -235,6 +243,7 @@ private indirect enum TestEnum: Codable {
case e = 101
case f = 102
case g = 103
case h = 104
}

enum ECodingKeys: Int, CodingKey {
Expand Down

0 comments on commit b53964e

Please sign in to comment.