From 6e4dc1dee6803633b25b993aef446bee375868ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?O=C4=9Fuzhan=20Durgun?= Date: Mon, 16 Dec 2024 17:54:26 +0300 Subject: [PATCH] enhancement: Return encryption key from GetBundle (#240) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Oğuzhan Durgun --- bundle/v2/client.go | 11 ++++++++--- bundle/v2/client_test.go | 24 +++++++++++++----------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/bundle/v2/client.go b/bundle/v2/client.go index d673615..a21d269 100644 --- a/bundle/v2/client.go +++ b/bundle/v2/client.go @@ -143,7 +143,7 @@ func (c *Client) parseBundleResponse(bundleResponseBytes []byte) (*bundlev2.GetB } // GetBundle returns the path to the bundle with the given label. -func (c *Client) GetBundle(ctx context.Context, bundleLabel string) (string, error) { +func (c *Client) GetBundle(ctx context.Context, bundleLabel string) (string, []byte, error) { log := c.Logger.WithValues("bundle", bundleLabel) log.V(1).Info("Calling GetBundle RPC") @@ -153,12 +153,17 @@ func (c *Client) GetBundle(ctx context.Context, bundleLabel string) (string, err })) if err != nil { log.Error(err, "GetBundle RPC failed") - return "", err + return "", nil, err } base.LogResponsePayload(log, resp.Msg) - return c.getBundleFile(logr.NewContext(ctx, log), resp.Msg.BundleInfo) + path, err := c.getBundleFile(logr.NewContext(ctx, log), resp.Msg.BundleInfo) + if err != nil { + return "", nil, err + } + + return path, resp.Msg.BundleInfo.EncryptionKey, nil } func (c *Client) WatchBundle(ctx context.Context, bundleLabel string) (bundle.WatchHandle, error) { diff --git a/bundle/v2/client_test.go b/bundle/v2/client_test.go index bacffae..ecfa0b5 100644 --- a/bundle/v2/client_test.go +++ b/bundle/v2/client_test.go @@ -155,6 +155,7 @@ func TestGetBundle(t *testing.T) { expectIssueAccessToken(mockAPIKeySvc) + wantEncryptionKey := []byte("secret") mockBundleSvc.EXPECT(). GetBundle(mock.Anything, mock.MatchedBy(getBundleReq("label"))). Return(connect.NewResponse(&bundlev2.GetBundleResponse{ @@ -162,7 +163,7 @@ func TestGetBundle(t *testing.T) { Label: "label", InputHash: hash("input"), OutputHash: wantChecksum, - EncryptionKey: []byte("secret"), + EncryptionKey: wantEncryptionKey, Segments: []*bundlev2.BundleInfo_Segment{ { SegmentId: 1, @@ -174,8 +175,9 @@ func TestGetBundle(t *testing.T) { }), nil).Times(3) for i := 0; i < 3; i++ { - file, err := client.GetBundle(context.Background(), "label") + file, encryptionKey, err := client.GetBundle(context.Background(), "label") require.NoError(t, err) + require.Equal(t, wantEncryptionKey, encryptionKey) haveChecksum := checksum(t, file) require.Equal(t, wantChecksum, haveChecksum, "Checksum does not match") @@ -221,7 +223,7 @@ func TestGetBundle(t *testing.T) { }, }), nil) - file, err := client.GetBundle(context.Background(), "label") + file, _, err := client.GetBundle(context.Background(), "label") require.NoError(t, err) haveChecksum := checksum(t, file) @@ -270,7 +272,7 @@ func TestGetBundle(t *testing.T) { }), nil).Times(3) for i := 0; i < 3; i++ { - file, err := client.GetBundle(context.Background(), "label") + file, _, err := client.GetBundle(context.Background(), "label") require.NoError(t, err) haveChecksum := checksum(t, file) @@ -328,7 +330,7 @@ func TestGetBundle(t *testing.T) { }), nil).Times(3) for i := 0; i < 3; i++ { - file1, err := client.GetBundle(context.Background(), "label") + file1, _, err := client.GetBundle(context.Background(), "label") require.NoError(t, err) haveChecksum1 := checksum(t, file1) @@ -380,7 +382,7 @@ func TestGetBundle(t *testing.T) { }), nil).Times(3) for i := 0; i < 3; i++ { - file2, err := client.GetBundle(context.Background(), "label") + file2, _, err := client.GetBundle(context.Background(), "label") require.NoError(t, err) haveChecksum2 := checksum(t, file2) @@ -435,7 +437,7 @@ func TestGetBundle(t *testing.T) { }, }), nil).Once() - _, err := client.GetBundle(context.Background(), "label") + _, _, err := client.GetBundle(context.Background(), "label") require.Error(t, err) require.Equal(t, 3, counter.getTotal(), "Total download count does not match") @@ -483,7 +485,7 @@ func TestGetBundle(t *testing.T) { }, }), nil).Once() - _, err := client.GetBundle(context.Background(), "label") + _, _, err := client.GetBundle(context.Background(), "label") require.Error(t, err) require.Equal(t, 3, counter.getTotal(), "Total download count does not match") @@ -520,7 +522,7 @@ func TestGetBundle(t *testing.T) { }, }), nil).Once() - _, err := client.GetBundle(context.Background(), "label") + _, _, err := client.GetBundle(context.Background(), "label") require.Error(t, err) require.Equal(t, 1, counter.getTotal(), "Total download count does not match") @@ -555,7 +557,7 @@ func TestGetBundle(t *testing.T) { }, }), nil).Once() - _, err := client.GetBundle(context.Background(), "label") + _, _, err := client.GetBundle(context.Background(), "label") require.Error(t, err) }) @@ -571,7 +573,7 @@ func TestGetBundle(t *testing.T) { IssueAccessToken(mock.Anything, mock.MatchedBy(issueAccessTokenRequest())). Return(nil, connect.NewError(connect.CodeUnauthenticated, errors.New("🙅"))) - _, err := client.GetBundle(context.Background(), "label") + _, _, err := client.GetBundle(context.Background(), "label") require.Error(t, err) require.ErrorIs(t, err, base.ErrAuthenticationFailed) })