Skip to content

Commit

Permalink
Rewrite certificate package.
Browse files Browse the repository at this point in the history
  • Loading branch information
q-uint committed Mar 23, 2024
1 parent 1205f48 commit 05936cf
Show file tree
Hide file tree
Showing 14 changed files with 494 additions and 205 deletions.
11 changes: 8 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Supported identities are `Ed25519` and `Secp256k1`. By default, the agent uses t
```go
id, _ := identity.NewEd25519Identity(publicKey, privateKey)
config := agent.Config{
Identity: id,
Identity: id,
}
```

Expand All @@ -65,8 +65,8 @@ If you are running a local replica, you can use the `FetchRootKey` option to fet
```go
u, _ := url.Parse("http://localhost:8000")
config := agent.Config{
ClientConfig: &agent.ClientConfig{Host: u},
FetchRootKey: true,
ClientConfig: &agent.ClientConfig{Host: u},
FetchRootKey: true,
}
```

Expand Down Expand Up @@ -103,3 +103,8 @@ installed then those tests will be ignored.
```shell
go test -v ./...
```

## Reference Implementations

- [Rust Agent](https://github.com/dfinity/agent-rs/)
- [JavaScript Agent](https://github.com/dfinity/agent-js/)
62 changes: 42 additions & 20 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/binary"
"encoding/hex"
"fmt"
"github.com/aviate-labs/agent-go/certificate/hashtree"
"net/url"
"time"

Expand Down Expand Up @@ -129,37 +130,45 @@ func (a Agent) GetCanisterControllers(canisterID principal.Principal) ([]princip

// GetCanisterInfo returns the raw certificate for the given canister based on the given sub-path.
func (a Agent) GetCanisterInfo(canisterID principal.Principal, subPath string) ([]byte, error) {
path := [][]byte{[]byte("canister"), canisterID.Raw, []byte(subPath)}
c, err := a.readStateCertificate(canisterID, [][][]byte{path})
path := []hashtree.Label{hashtree.Label("canister"), canisterID.Raw, hashtree.Label(subPath)}
c, err := a.readStateCertificate(canisterID, [][]hashtree.Label{path})
if err != nil {
return nil, err
}
var state map[string]any
if err := cbor.Unmarshal(c, &state); err != nil {
return nil, err
}
node, err := certificate.DeserializeNode(state["tree"].([]any))
node, err := hashtree.DeserializeNode(state["tree"].([]any))
if err != nil {
return nil, err
}
return certificate.Lookup(path, node), nil
result := hashtree.NewHashTree(node).Lookup(path...)
if err := result.Found(); err != nil {
return nil, err
}
return result.Value, nil
}

func (a Agent) GetCanisterMetadata(canisterID principal.Principal, subPath string) ([]byte, error) {
path := [][]byte{[]byte("canister"), canisterID.Raw, []byte("metadata"), []byte(subPath)}
c, err := a.readStateCertificate(canisterID, [][][]byte{path})
path := []hashtree.Label{hashtree.Label("canister"), canisterID.Raw, hashtree.Label("metadata"), hashtree.Label(subPath)}
c, err := a.readStateCertificate(canisterID, [][]hashtree.Label{path})
if err != nil {
return nil, err
}
var state map[string]any
if err := cbor.Unmarshal(c, &state); err != nil {
return nil, err
}
node, err := certificate.DeserializeNode(state["tree"].([]any))
node, err := hashtree.DeserializeNode(state["tree"].([]any))
if err != nil {
return nil, err
}
return certificate.Lookup(path, node), nil
result := hashtree.NewHashTree(node).Lookup(path...)
if err := result.Found(); err != nil {
return nil, err
}
return result.Value, nil
}

// GetCanisterModuleHash returns the module hash for the given canister.
Expand Down Expand Up @@ -208,9 +217,9 @@ func (a Agent) Query(canisterID principal.Principal, methodName string, args []a
}

// RequestStatus returns the status of the request with the given ID.
func (a Agent) RequestStatus(canisterID principal.Principal, requestID RequestID) ([]byte, certificate.Node, error) {
path := [][]byte{[]byte("request_status"), requestID[:]}
c, err := a.readStateCertificate(canisterID, [][][]byte{path})
func (a Agent) RequestStatus(canisterID principal.Principal, requestID RequestID) ([]byte, hashtree.Node, error) {
path := []hashtree.Label{hashtree.Label("request_status"), requestID[:]}
c, err := a.readStateCertificate(canisterID, [][]hashtree.Label{path})
if err != nil {
return nil, nil, err
}
Expand All @@ -225,11 +234,15 @@ func (a Agent) RequestStatus(canisterID principal.Principal, requestID RequestID
if err := cert.Verify(); err != nil {
return nil, nil, err
}
node, err := certificate.DeserializeNode(state["tree"].([]any))
node, err := hashtree.DeserializeNode(state["tree"].([]any))
if err != nil {
return nil, nil, err
}
return certificate.Lookup(append(path, []byte("status")), node), node, nil
result := hashtree.NewHashTree(node).Lookup(append(path, hashtree.Label("status"))...)
if err := result.Found(); err != nil {
return nil, nil, err
}
return result.Value, node, nil
}

// Sender returns the principal that is sending the requests.
Expand All @@ -256,15 +269,24 @@ func (a Agent) poll(canisterID principal.Principal, requestID RequestID, delay,
return nil, err
}
if len(data) != 0 {
path := [][]byte{[]byte("request_status"), requestID[:]}
path := []hashtree.Label{hashtree.Label("request_status"), requestID[:]}
switch string(data) {
case "rejected":
code := certificate.Lookup(append(path, []byte("reject_code")), node)
rejectMessage := certificate.Lookup(append(path, []byte("reject_message")), node)
return nil, fmt.Errorf("(%d) %s", uint64FromBytes(code), string(rejectMessage))
tree := hashtree.NewHashTree(node)
codeResult := tree.Lookup(append(path, hashtree.Label("reject_code"))...)
messageResult := tree.Lookup(append(path, hashtree.Label("reject_message"))...)
if codeResult.Found() != nil || messageResult.Found() != nil {
return nil, fmt.Errorf("no reject code or message found")
}
return nil, fmt.Errorf("(%d) %s", uint64FromBytes(codeResult.Value), string(messageResult.Value))
case "replied":
path := [][]byte{[]byte("request_status"), requestID[:]}
return certificate.Lookup(append(path, []byte("reply")), node), nil
fmt.Println(node)
repliedResult := hashtree.NewHashTree(node).Lookup(append(path, hashtree.Label("reply"))...)
fmt.Println(repliedResult)
if repliedResult.Found() != nil {
return nil, fmt.Errorf("no reply found")
}
return repliedResult.Value, nil
}
}
case <-timer.C:
Expand All @@ -291,7 +313,7 @@ func (a Agent) readState(canisterID principal.Principal, data []byte) (map[strin
return m, cbor.Unmarshal(resp, &m)
}

func (a Agent) readStateCertificate(canisterID principal.Principal, paths [][][]byte) ([]byte, error) {
func (a Agent) readStateCertificate(canisterID principal.Principal, paths [][]hashtree.Label) ([]byte, error) {
_, data, err := a.sign(Request{
Type: RequestTypeReadState,
Sender: a.Sender(),
Expand Down
26 changes: 13 additions & 13 deletions certificate/certificate.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package certificate
import (
"fmt"
"github.com/aviate-labs/agent-go/certificate/bls"
"github.com/aviate-labs/agent-go/certificate/hashtree"
"github.com/aviate-labs/agent-go/principal"
"github.com/fxamacker/cbor/v2"
"slices"
Expand All @@ -11,15 +12,15 @@ import (
// Cert is a certificate gets returned by the IC.
type Cert struct {
// Tree is the certificate tree.
Tree HashTree `cbor:"tree"`
Tree hashtree.HashTree `cbor:"tree"`
// Signature is the signature of the certificate tree.
Signature []byte `cbor:"signature"`
// Delegation is the delegation of the certificate.
Delegation *Delegation `cbor:"delegation"`
}

// Certificate is a certificate gets returned by the IC and can be used to verify
// the state root based on the root key and canister ID.
// Certificate is a certificate that gets returned by the IC and can be used to verify the state root based on the root
// key and canister ID.
type Certificate struct {
Cert Cert
RootKey []byte
Expand Down Expand Up @@ -50,7 +51,7 @@ func (c Certificate) Verify() error {
return err
}
rootHash := c.Cert.Tree.Digest()
message := append(DomainSeparator("ic-state-root"), rootHash[:]...)
message := append(hashtree.DomainSeparator("ic-state-root"), rootHash[:]...)
if !signature.Verify(publicKey, string(message)) {
return fmt.Errorf("signature verification failed")
}
Expand All @@ -64,15 +65,14 @@ func (c Certificate) getPublicKey() (*bls.PublicKey, error) {
}

cert := c.Cert.Delegation
canisterRanges := Lookup(
LookupPath("subnet", string(cert.SubnetId.Raw), "canister_ranges"),
cert.Certificate.Cert.Tree.Root,
canisterRangesResult := cert.Certificate.Cert.Tree.Lookup(
hashtree.Label("subnet"), cert.SubnetId.Raw, hashtree.Label("canister_ranges"),
)
if canisterRanges == nil {
if canisterRangesResult.Found() != nil {
return nil, fmt.Errorf("no canister ranges found for subnet %s", cert.SubnetId)
}
var rawRanges [][][]byte
if err := cbor.Unmarshal(canisterRanges, &rawRanges); err != nil {
if err := cbor.Unmarshal(canisterRangesResult.Value, &rawRanges); err != nil {
return nil, err
}

Expand All @@ -90,14 +90,14 @@ func (c Certificate) getPublicKey() (*bls.PublicKey, error) {
return nil, fmt.Errorf("canister %s is not in range", c.CanisterID)
}

publicKey := Lookup(
LookupPath("subnet", string(cert.SubnetId.Raw), "public_key"),
cert.Certificate.Cert.Tree.Root,
publicKeyResult := cert.Certificate.Cert.Tree.Lookup(
hashtree.Label("subnet"), cert.SubnetId.Raw, hashtree.Label("public_key"),
)
if publicKey == nil {
if publicKeyResult.Found() != nil {
return nil, fmt.Errorf("no public key found for subnet %s", cert.SubnetId)
}

publicKey := publicKeyResult.Value
if len(publicKey) != len(derPrefix)+96 {
return nil, fmt.Errorf("invalid public key length: %d", len(publicKey))
}
Expand Down
12 changes: 11 additions & 1 deletion certificate/tree.go → certificate/hashtree/hashtree.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package certificate
package hashtree

// HashTree is a hash tree.
type HashTree struct {
Expand Down Expand Up @@ -29,3 +29,13 @@ func (t *HashTree) UnmarshalCBOR(bytes []byte) error {
t.Root = root
return nil
}

// Lookup looks up a path in the hash tree.
func (t HashTree) Lookup(path ...Label) LookupResult {
return lookupPath(t.Root, path...)
}

// LookupSubTree looks up a path in the hash tree and returns the sub-tree.
func (t HashTree) LookupSubTree(path ...Label) LookupSubTreeResult {
return lookupSubTree(t.Root, path...)
}
124 changes: 124 additions & 0 deletions certificate/hashtree/hashtree_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package hashtree

import (
"bytes"
"encoding/hex"
"fmt"
"testing"
)

func TestHashTree_simple(t *testing.T) {
tree := NewHashTree(Fork{
LeftTree: Labeled{
Label: Label("label 1"),
Tree: Empty{},
},
RightTree: Fork{
LeftTree: Pruned{
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
},
RightTree: Leaf{0x01, 0x02, 0x03, 0x04, 0x05, 0x06},
},
})
digest := tree.Digest()
if hex.EncodeToString(digest[:]) != "69cf325d0f20505b261821a7e77ff72fb9a8753a7964f0b587553bfb44e72532" {
t.Fatalf("unexpected digest: %x", digest)
}
}

func TestHashTree_Lookup(t *testing.T) {
t.Run("Empty Nodes", func(t *testing.T) {
tree := NewHashTree(Fork{
LeftTree: Labeled{
Label: Label("label 1"),
Tree: Empty{},
},
RightTree: Fork{
LeftTree: Pruned{},
RightTree: Fork{
LeftTree: Labeled{
Label: Label("label 3"),
Tree: Leaf{0x01, 0x02, 0x03, 0x04, 0x05, 0x06},
},
RightTree: Labeled{
Label: Label("label 5"),
Tree: Empty{},
},
},
},
})

for _, i := range []int{0, 1} {
if result := tree.Lookup(Label(fmt.Sprintf("label %d", i))); result.Type != LookupResultAbsent {
t.Fatalf("unexpected lookup result")
}
}
if result := tree.Lookup(Label("label 2")); result.Type != LookupResultUnknown {
t.Fatalf("unexpected lookup result")
}
if result := tree.Lookup(Label("label 3")); result.Type != LookupResultFound {
t.Fatalf("unexpected lookup result")
} else {
if !bytes.Equal(result.Value, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06}) {
t.Fatalf("unexpected node value")
}
}
for _, i := range []int{4, 5, 6} {
if result := tree.Lookup(Label(fmt.Sprintf("label %d", i))); result.Type != LookupResultAbsent {
t.Fatalf("unexpected lookup result")
}
}
})
t.Run("Nil Nodes", func(t *testing.T) {
// let tree: HashTree<Vec<u8>> = fork(
// label("label 1", empty()),
// fork(
// fork(
// label("label 3", leaf(vec![1, 2, 3, 4, 5, 6])),
// label("label 5", empty()),
// ),
// pruned([1; 32]),
// ),
// );
tree := NewHashTree(Fork{
LeftTree: Labeled{
Label: Label("label 1"),
},
RightTree: Fork{
LeftTree: Fork{
LeftTree: Labeled{
Label: Label("label 3"),
Tree: Leaf{0x01, 0x02, 0x03, 0x04, 0x05, 0x06},
},
RightTree: Labeled{
Label: Label("label 5"),
},
},
RightTree: Pruned{},
},
})
for _, i := range []int{0, 1, 2} {
if result := tree.Lookup(Label(fmt.Sprintf("label %d", i))); result.Type != LookupResultAbsent {
t.Fatalf("unexpected lookup result")
}
}
if result := tree.Lookup(Label("label 3")); result.Type != LookupResultFound {
t.Fatalf("unexpected lookup result")
} else {
if !bytes.Equal(result.Value, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06}) {
t.Fatalf("unexpected node value")
}
}
for _, i := range []int{4, 5} {
if result := tree.Lookup(Label(fmt.Sprintf("label %d", i))); result.Type != LookupResultAbsent {
t.Fatalf("unexpected lookup result")
}
}
if result := tree.Lookup(Label("label 6")); result.Type != LookupResultUnknown {
t.Fatalf("unexpected lookup result")
}
})
}
Loading

0 comments on commit 05936cf

Please sign in to comment.