Skip to content

Commit

Permalink
enhancement: Return encryption key from GetBundle (#240)
Browse files Browse the repository at this point in the history
Signed-off-by: Oğuzhan Durgun <oguzhandurgun95@gmail.com>
  • Loading branch information
oguzhand95 authored Dec 16, 2024
1 parent adbc05f commit 6e4dc1d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 14 deletions.
11 changes: 8 additions & 3 deletions bundle/v2/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func (c *Client) parseBundleResponse(bundleResponseBytes []byte) (*bundlev2.GetB
}

// GetBundle returns the path to the bundle with the given label.
func (c *Client) GetBundle(ctx context.Context, bundleLabel string) (string, error) {
func (c *Client) GetBundle(ctx context.Context, bundleLabel string) (string, []byte, error) {
log := c.Logger.WithValues("bundle", bundleLabel)
log.V(1).Info("Calling GetBundle RPC")

Expand All @@ -153,12 +153,17 @@ func (c *Client) GetBundle(ctx context.Context, bundleLabel string) (string, err
}))
if err != nil {
log.Error(err, "GetBundle RPC failed")
return "", err
return "", nil, err
}

base.LogResponsePayload(log, resp.Msg)

return c.getBundleFile(logr.NewContext(ctx, log), resp.Msg.BundleInfo)
path, err := c.getBundleFile(logr.NewContext(ctx, log), resp.Msg.BundleInfo)
if err != nil {
return "", nil, err
}

return path, resp.Msg.BundleInfo.EncryptionKey, nil
}

func (c *Client) WatchBundle(ctx context.Context, bundleLabel string) (bundle.WatchHandle, error) {
Expand Down
24 changes: 13 additions & 11 deletions bundle/v2/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,15 @@ func TestGetBundle(t *testing.T) {

expectIssueAccessToken(mockAPIKeySvc)

wantEncryptionKey := []byte("secret")
mockBundleSvc.EXPECT().
GetBundle(mock.Anything, mock.MatchedBy(getBundleReq("label"))).
Return(connect.NewResponse(&bundlev2.GetBundleResponse{
BundleInfo: &bundlev2.BundleInfo{
Label: "label",
InputHash: hash("input"),
OutputHash: wantChecksum,
EncryptionKey: []byte("secret"),
EncryptionKey: wantEncryptionKey,
Segments: []*bundlev2.BundleInfo_Segment{
{
SegmentId: 1,
Expand All @@ -174,8 +175,9 @@ func TestGetBundle(t *testing.T) {
}), nil).Times(3)

for i := 0; i < 3; i++ {
file, err := client.GetBundle(context.Background(), "label")
file, encryptionKey, err := client.GetBundle(context.Background(), "label")
require.NoError(t, err)
require.Equal(t, wantEncryptionKey, encryptionKey)

haveChecksum := checksum(t, file)
require.Equal(t, wantChecksum, haveChecksum, "Checksum does not match")
Expand Down Expand Up @@ -221,7 +223,7 @@ func TestGetBundle(t *testing.T) {
},
}), nil)

file, err := client.GetBundle(context.Background(), "label")
file, _, err := client.GetBundle(context.Background(), "label")
require.NoError(t, err)

haveChecksum := checksum(t, file)
Expand Down Expand Up @@ -270,7 +272,7 @@ func TestGetBundle(t *testing.T) {
}), nil).Times(3)

for i := 0; i < 3; i++ {
file, err := client.GetBundle(context.Background(), "label")
file, _, err := client.GetBundle(context.Background(), "label")
require.NoError(t, err)

haveChecksum := checksum(t, file)
Expand Down Expand Up @@ -328,7 +330,7 @@ func TestGetBundle(t *testing.T) {
}), nil).Times(3)

for i := 0; i < 3; i++ {
file1, err := client.GetBundle(context.Background(), "label")
file1, _, err := client.GetBundle(context.Background(), "label")
require.NoError(t, err)

haveChecksum1 := checksum(t, file1)
Expand Down Expand Up @@ -380,7 +382,7 @@ func TestGetBundle(t *testing.T) {
}), nil).Times(3)

for i := 0; i < 3; i++ {
file2, err := client.GetBundle(context.Background(), "label")
file2, _, err := client.GetBundle(context.Background(), "label")
require.NoError(t, err)

haveChecksum2 := checksum(t, file2)
Expand Down Expand Up @@ -435,7 +437,7 @@ func TestGetBundle(t *testing.T) {
},
}), nil).Once()

_, err := client.GetBundle(context.Background(), "label")
_, _, err := client.GetBundle(context.Background(), "label")
require.Error(t, err)

require.Equal(t, 3, counter.getTotal(), "Total download count does not match")
Expand Down Expand Up @@ -483,7 +485,7 @@ func TestGetBundle(t *testing.T) {
},
}), nil).Once()

_, err := client.GetBundle(context.Background(), "label")
_, _, err := client.GetBundle(context.Background(), "label")
require.Error(t, err)

require.Equal(t, 3, counter.getTotal(), "Total download count does not match")
Expand Down Expand Up @@ -520,7 +522,7 @@ func TestGetBundle(t *testing.T) {
},
}), nil).Once()

_, err := client.GetBundle(context.Background(), "label")
_, _, err := client.GetBundle(context.Background(), "label")
require.Error(t, err)

require.Equal(t, 1, counter.getTotal(), "Total download count does not match")
Expand Down Expand Up @@ -555,7 +557,7 @@ func TestGetBundle(t *testing.T) {
},
}), nil).Once()

_, err := client.GetBundle(context.Background(), "label")
_, _, err := client.GetBundle(context.Background(), "label")
require.Error(t, err)
})

Expand All @@ -571,7 +573,7 @@ func TestGetBundle(t *testing.T) {
IssueAccessToken(mock.Anything, mock.MatchedBy(issueAccessTokenRequest())).
Return(nil, connect.NewError(connect.CodeUnauthenticated, errors.New("🙅")))

_, err := client.GetBundle(context.Background(), "label")
_, _, err := client.GetBundle(context.Background(), "label")
require.Error(t, err)
require.ErrorIs(t, err, base.ErrAuthenticationFailed)
})
Expand Down

0 comments on commit 6e4dc1d

Please sign in to comment.