Skip to content

Commit

Permalink
Add support for func callbacks.
Browse files Browse the repository at this point in the history
  • Loading branch information
q-uint committed Feb 3, 2025
1 parent 3f43a73 commit ae746b0
Show file tree
Hide file tree
Showing 10 changed files with 171 additions and 64 deletions.
181 changes: 134 additions & 47 deletions agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"

"github.com/aviate-labs/agent-go"
"github.com/aviate-labs/agent-go/candid/idl"
"github.com/aviate-labs/agent-go/certification/hashtree"
"github.com/aviate-labs/agent-go/identity"
"github.com/aviate-labs/agent-go/principal"
Expand Down Expand Up @@ -60,7 +61,7 @@ func Example_json() {

func Example_query_ed25519() {
id, _ := identity.NewRandomEd25519Identity()
ledgerID, _ := principal.Decode("ryjl3-tyaaa-aaaaa-aaaba-cai")
ledgerID := principal.MustDecode("ryjl3-tyaaa-aaaaa-aaaba-cai")
a, _ := agent.New(agent.Config{Identity: id})
var balance struct {
E8S uint64 `ic:"e8s"`
Expand All @@ -75,7 +76,7 @@ func Example_query_ed25519() {

func Example_query_prime256v1() {
id, _ := identity.NewRandomPrime256v1Identity()
ledgerID, _ := principal.Decode("ryjl3-tyaaa-aaaaa-aaaba-cai")
ledgerID := principal.MustDecode("ryjl3-tyaaa-aaaaa-aaaba-cai")
a, _ := agent.New(agent.Config{Identity: id})
var balance struct {
E8S uint64 `ic:"e8s"`
Expand All @@ -90,7 +91,7 @@ func Example_query_prime256v1() {

func Example_query_secp256k1() {
id, _ := identity.NewRandomSecp256k1Identity()
ledgerID, _ := principal.Decode("ryjl3-tyaaa-aaaaa-aaaba-cai")
ledgerID := principal.MustDecode("ryjl3-tyaaa-aaaaa-aaaba-cai")
a, _ := agent.New(agent.Config{Identity: id})
var balance struct {
E8S uint64 `ic:"e8s"`
Expand Down Expand Up @@ -120,7 +121,49 @@ func TestAgent_Call(t *testing.T) {
}
}

func TestAgent_Query_Callback(t *testing.T) {
func TestAgent_Query_Ed25519(t *testing.T) {
id, err := identity.NewRandomEd25519Identity()
if err != nil {
t.Fatal(err)
}
a, _ := agent.New(agent.Config{
Identity: id,
})
type Account struct {
Account string `ic:"account"`
}
var balance struct {
E8S uint64 `ic:"e8s"`
}
if err := a.Query(LEDGER_PRINCIPAL, "account_balance_dfx", []any{
Account{"9523dc824aa062dcd9c91b98f4594ff9c6af661ac96747daef2090b7fe87037d"},
}, []any{&balance}); err != nil {
t.Fatal(err)
}
}

func TestAgent_Query_Secp256k1(t *testing.T) {
id, err := identity.NewRandomSecp256k1Identity()
if err != nil {
t.Fatal(err)
}
a, _ := agent.New(agent.Config{
Identity: id,
})
type Account struct {
Account string `ic:"account"`
}
var balance struct {
E8S uint64 `ic:"e8s"`
}
if err := a.Query(LEDGER_PRINCIPAL, "account_balance_dfx", []any{
Account{"9523dc824aa062dcd9c91b98f4594ff9c6af661ac96747daef2090b7fe87037d"},
}, []any{&balance}); err != nil {
t.Fatal(err)
}
}

func TestAgent_Query_callback(t *testing.T) {
a, err := agent.New(agent.DefaultConfig)
if err != nil {
t.Fatal(err)
Expand All @@ -132,14 +175,10 @@ func TestAgent_Query_Callback(t *testing.T) {
Length uint64 `ic:"length" json:"length"`
}

type QueryArchiveFn struct {
/* TODO! */
}

type ArchivedBlocksRange struct {
Start uint64 `ic:"start" json:"start"`
Length uint64 `ic:"length" json:"length"`
Callback QueryArchiveFn `ic:"callback" json:"callback"`
Start uint64 `ic:"start" json:"start"`
Length uint64 `ic:"length" json:"length"`
Callback idl.Function `ic:"callback" json:"callback"`
}

type QueryBlocksResponse struct {
Expand All @@ -150,14 +189,15 @@ func TestAgent_Query_Callback(t *testing.T) {
ArchivedBlocks []ArchivedBlocksRange `ic:"archived_blocks" json:"archived_blocks"`
}

args := GetBlocksArgs{
Start: 123,
Length: 1,
}
req, err := a.CreateCandidAPIRequest(
agent.RequestTypeQuery,
ledgerCanisterID,
"query_blocks",
GetBlocksArgs{
Start: 123,
Length: 1,
},
args,
)
if err != nil {
t.Fatal(err)
Expand All @@ -170,47 +210,94 @@ func TestAgent_Query_Callback(t *testing.T) {
if archive.Start != 123 || archive.Length != 1 {
t.Error(archive)
}
}
if !archive.Callback.Method.Principal.Equal(principal.MustDecode("qjdve-lqaaa-aaaaa-aaaeq-cai")) {
t.Error(archive.Callback.Method.Principal)
}
if archive.Callback.Method.Method != "get_blocks" {
t.Error(archive.Callback.Method.Method)
}

func TestAgent_Query_Ed25519(t *testing.T) {
id, err := identity.NewRandomEd25519Identity()
if err != nil {
t.Fatal(err)
type Timestamp struct {
TimestampNanos uint64 `ic:"timestamp_nanos" json:"timestamp_nanos"`
}
a, _ := agent.New(agent.Config{
Identity: id,
})
type Account struct {
Account string `ic:"account"`

type Tokens struct {
E8s uint64 `ic:"e8s" json:"e8s"`
}
var balance struct {
E8S uint64 `ic:"e8s"`

type Operation struct {
Mint *struct {
To []byte `ic:"to" json:"to"`
Amount Tokens `ic:"amount" json:"amount"`
} `ic:"Mint,variant"`
Burn *struct {
From []byte `ic:"from" json:"from"`
Spender *[]byte `ic:"spender,omitempty" json:"spender,omitempty"`
Amount Tokens `ic:"amount" json:"amount"`
} `ic:"Burn,variant"`
Transfer *struct {
From []byte `ic:"from" json:"from"`
To []byte `ic:"to" json:"to"`
Amount Tokens `ic:"amount" json:"amount"`
Fee Tokens `ic:"fee" json:"fee"`
Spender *[]uint8 `ic:"spender,omitempty" json:"spender,omitempty"`
} `ic:"Transfer,variant"`
Approve *struct {
From []byte `ic:"from" json:"from"`
Spender []byte `ic:"spender" json:"spender"`
AllowanceE8s idl.Int `ic:"allowance_e8s" json:"allowance_e8s"`
Allowance Tokens `ic:"allowance" json:"allowance"`
Fee Tokens `ic:"fee" json:"fee"`
ExpiresAt *Timestamp `ic:"expires_at,omitempty" json:"expires_at,omitempty"`
ExpectedAllowance *Tokens `ic:"expected_allowance,omitempty" json:"expected_allowance,omitempty"`
} `ic:"Approve,variant"`
}
if err := a.Query(LEDGER_PRINCIPAL, "account_balance_dfx", []any{
Account{"9523dc824aa062dcd9c91b98f4594ff9c6af661ac96747daef2090b7fe87037d"},
}, []any{&balance}); err != nil {
t.Fatal(err)

type Transaction struct {
Memo uint64 `ic:"memo" json:"memo"`
Icrc1Memo *[]byte `ic:"icrc1_memo,omitempty" json:"icrc1_memo,omitempty"`
Operation *Operation `ic:"operation,omitempty" json:"operation,omitempty"`
CreatedAtTime Timestamp `ic:"created_at_time" json:"created_at_time"`
}
}

func TestAgent_Query_Secp256k1(t *testing.T) {
id, err := identity.NewRandomSecp256k1Identity()
if err != nil {
t.Fatal(err)
type Block struct {
ParentHash *[]byte `ic:"parent_hash,omitempty" json:"parent_hash,omitempty"`
Transaction Transaction `ic:"transaction" json:"transaction"`
Timestamp Timestamp `ic:"timestamp" json:"timestamp"`
}
a, _ := agent.New(agent.Config{
Identity: id,
})
type Account struct {
Account string `ic:"account"`

type BlockRange struct {
Blocks []Block `ic:"blocks" json:"blocks"`
}
var balance struct {
E8S uint64 `ic:"e8s"`

type GetBlocksError struct {
BadFirstBlockIndex *struct {
RequestedIndex uint64 `ic:"requested_index" json:"requested_index"`
FirstValidIndex uint64 `ic:"first_valid_index" json:"first_valid_index"`
} `ic:"BadFirstBlockIndex,variant"`
Other *struct {
ErrorCode uint64 `ic:"error_code" json:"error_code"`
ErrorMessage string `ic:"error_message" json:"error_message"`
} `ic:"Other,variant"`
}
if err := a.Query(LEDGER_PRINCIPAL, "account_balance_dfx", []any{
Account{"9523dc824aa062dcd9c91b98f4594ff9c6af661ac96747daef2090b7fe87037d"},
}, []any{&balance}); err != nil {
t.Fatal(err)

type GetBlocksResult struct {
Ok *BlockRange `ic:"Ok,variant"`
Err *GetBlocksError `ic:"Err,variant"`
}

var blocks GetBlocksResult
if err := a.Query(
archive.Callback.Method.Principal,
archive.Callback.Method.Method,
[]any{args},
[]any{&blocks},
); err != nil {
t.Error(err)
}

if len(blocks.Ok.Blocks) != 1 {
t.Error(blocks)
}
}

Expand Down
4 changes: 4 additions & 0 deletions candid/did/func.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ const (
// it does not alter the state of its canister, and that it can be invoked
// using the cheaper “query call” mechanism.
AnnQuery FuncAnnotation = "query"
// AnnCompositeQuery is a special query function that has IC-specific
// features (and limitations). Eventually, query and composite_query
// functions will become the same thing.
AnnCompositeQuery FuncAnnotation = "composite_query"
)

// Tuple represents one or more arguments.
Expand Down
23 changes: 19 additions & 4 deletions candid/idl/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ func encodeTypes(ts []Type, tdt *TypeDefinitionTable) ([]byte, error) {
return concat(l, vs), nil
}

type Function struct {
Types FunctionType
Method PrincipalMethod
}

type FunctionParameter struct {
Type Type
Index OpCode
Expand Down Expand Up @@ -120,7 +125,7 @@ func (f FunctionType) Decode(r *bytes.Reader) (any, error) {
}
m := make([]byte, ml.Int64())
{
n, err := r.Read(pid)
n, err := r.Read(m)
if err != nil {
return nil, err
}
Expand All @@ -143,7 +148,7 @@ func (f FunctionType) EncodeType(tdt *TypeDefinitionTable) ([]byte, error) {
}

func (f FunctionType) EncodeValue(v any) ([]byte, error) {
pm, ok := v.(PrincipalMethod)
pm, ok := v.(*PrincipalMethod)
if !ok {
return nil, NewEncodeValueError(v, FuncOpCode)
}
Expand Down Expand Up @@ -174,8 +179,18 @@ func (f FunctionType) String() string {
return fmt.Sprintf("(%s) -> (%s)%s", strings.Join(args, ", "), strings.Join(rets, ", "), ann)
}

func (FunctionType) UnmarshalGo(raw any, _v any) error {
return nil // Unsupported.
func (f FunctionType) UnmarshalGo(raw any, _v any) error {
pm, ok := raw.(*PrincipalMethod)
if !ok {
return NewUnmarshalGoError(raw, _v)
}
v, ok := _v.(*Function)
if !ok {
return NewUnmarshalGoError(raw, _v)
}
v.Types = f
v.Method = *pm
return nil
}

type PrincipalMethod struct {
Expand Down
5 changes: 2 additions & 3 deletions candid/idl/function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
)

func ExampleFunctionType() {
p, _ := principal.Decode("w7x7r-cok77-xa")
test_(
[]idl.Type{
idl.NewFunctionType(
Expand All @@ -16,8 +15,8 @@ func ExampleFunctionType() {
),
},
[]any{
idl.PrincipalMethod{
Principal: p,
&idl.PrincipalMethod{
Principal: principal.MustDecode("w7x7r-cok77-xa"),
Method: "foo",
},
},
Expand Down
5 changes: 3 additions & 2 deletions candid/idl/principal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package idl_test

import (
"bytes"
"testing"

"github.com/aviate-labs/agent-go/candid/idl"
"github.com/aviate-labs/agent-go/principal"
"testing"
)

func ExamplePrincipal() {
p, _ := principal.Decode("aaaaa-aa")
p := principal.MustDecode("aaaaa-aa")
test([]idl.Type{idl.NewOptionalType(new(idl.PrincipalType))}, []any{p})
// Output:
// 4449444c016e680100010100
Expand Down
2 changes: 1 addition & 1 deletion candid/idl/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
)

func ExampleService() {
p, _ := principal.Decode("w7x7r-cok77-xa")
p := principal.MustDecode("w7x7r-cok77-xa")
test(
[]idl.Type{idl.NewServiceType(
map[string]*idl.FunctionType{
Expand Down
6 changes: 3 additions & 3 deletions certification/http/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

func TestAgent_HttpRequest(t *testing.T) {
canisterId, _ := principal.Decode("rdmx6-jaaaa-aaaaa-aaadq-cai")
canisterId := principal.MustDecode("rdmx6-jaaaa-aaaaa-aaadq-cai")
a, err := http.NewAgent(canisterId, agent.DefaultConfig)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -73,7 +73,7 @@ func TestCalculateRequestHash(t *testing.T) {
}

func TestResponse_Verify_V1(t *testing.T) {
canisterId, _ := principal.Decode("rdmx6-jaaaa-aaaaa-aaadq-cai")
canisterId := principal.MustDecode("rdmx6-jaaaa-aaaaa-aaadq-cai")
a, err := http.NewAgent(canisterId, agent.DefaultConfig)
if err != nil {
t.Fatal(err)
Expand All @@ -97,7 +97,7 @@ func TestResponse_Verify_V1(t *testing.T) {
}

func TestResponse_Verify_V2(t *testing.T) {
canisterId, _ := principal.Decode("rdmx6-jaaaa-aaaaa-aaadq-cai")
canisterId := principal.MustDecode("rdmx6-jaaaa-aaaaa-aaadq-cai")
a, err := http.NewAgent(canisterId, agent.DefaultConfig)
if err != nil {
t.Fatal(err)
Expand Down
2 changes: 1 addition & 1 deletion gen/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func (g *Generator) dataToString(prefix string, data did.Data) string {
case did.DataId:
return funcName(prefix, string(t))
case did.Func:
return "struct { /* NOT SUPPORTED */ }"
return "idl.Function"
case did.Optional:
return fmt.Sprintf("*%s", g.dataToString(prefix, t.Data))
case did.Primitive:
Expand Down
Loading

0 comments on commit ae746b0

Please sign in to comment.