Skip to content

Commit

Permalink
Merge pull request #19 from Chia-Network/multipart-upload
Browse files Browse the repository at this point in the history
Add s3 multipart upload function
  • Loading branch information
Starttoaster authored Apr 18, 2024
2 parents 1d6d681 + 55d5311 commit 913faf4
Show file tree
Hide file tree
Showing 5 changed files with 268 additions and 0 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/chia-network/go-modules
go 1.19

require (
github.com/aws/aws-sdk-go v1.51.22
github.com/lestrrat-go/jwx v1.2.29
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.19.0
Expand All @@ -13,6 +14,7 @@ require (
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
Expand Down
10 changes: 10 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/aws/aws-sdk-go v1.51.22 h1:VL2p2JgC32myt7DMEcbe1devdtgGSgMNvZpkcdvlxq4=
github.com/aws/aws-sdk-go v1.51.22/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
Expand All @@ -11,6 +13,10 @@ github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/lestrrat-go/backoff/v2 v2.0.8 h1:oNb5E5isby2kiro9AgdHLv5N5tint1AnDVVf2E2un5A=
github.com/lestrrat-go/backoff/v2 v2.0.8/go.mod h1:rHP/q/r9aT27n24JQLa7JhSQZCKBBOiM/uP402WwN8Y=
github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
Expand Down Expand Up @@ -59,6 +65,7 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
Expand All @@ -84,6 +91,7 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
Expand All @@ -93,6 +101,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
30 changes: 30 additions & 0 deletions pkg/amazon/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Amazon

This package provides some useful abstractions for interacting with AWS.

## MultPartUpload

### Usage

```go
// Get S3 client from region, and AWS API keypair
sess, err := amazon.NewS3Client("region-name", "aws-key-id", "aws-key-secret")
if err != nil {
return err
}

// Upload a local file in parts
err = amazon.MultiPartUpload(amazon.MultiPartUploadInput{
Ctx: context.Background(), // Required: The context for this request
CtxTimeout: 10 * time.Minute, // Optional: The request will time out after this duration (defaults to 60 minutes)
Svc: sess, // Required: An AWS S3 session service for the upload
Filepath: "./file.txt", // Required: A full path to a local file to PUT to S3
DestinationBucket: "my-bucket", // Required: The destination S3 bucket's name
DestinationKey: "file.txt", // Required: The destination path in the bucket to put the file
MaxConcurrent: 3, // Optional: The number of concurrent part uploads (defaults to 10)
PartSize: 8388608, // Optional: Number of bytes (defaults to 8MB)
})
if err != nil {
return err
}
```
202 changes: 202 additions & 0 deletions pkg/amazon/s3.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
package amazon

import (
"bytes"
"context"
"errors"
"fmt"
"log/slog"
"os"
"sync"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
)

// MultiPartUploadInput holds the inputs for a multipart upload
type MultiPartUploadInput struct {
Svc *s3.S3 // Required: An AWS S3 session service for the upload
Ctx context.Context // Required: The context for this request
CtxTimeout time.Duration // Optional: The request will time out after this duration (defaults to 60 minutes)
MaxConcurrent int // Optional: The number of concurrent part uploads (defaults to 10)
PartSize int64 // Optional: Number of bytes (defaults to 8MB)
Filepath string // Required: A full path to a local file to PUT to S3
DestinationBucket string // Required: The destination S3 bucket's name
DestinationKey string // Required: The destination path in the bucket to put the file
Logger *slog.Logger // Optional: Handles logging if supplied
}

// MultiPartUploadResult holds the result for an individual part upload
type MultiPartUploadResult struct {
Error error
Part *s3.CompletedPart
}

// MultiPartUpload uploads a local file in multiple parts to AWS S3
func MultiPartUpload(input MultiPartUploadInput) error {
// Exit if no S3 service given
if input.Svc == nil {
return fmt.Errorf("s3 service nil -- is a required option")
}
// Set part size to default 8MB if no part size specified or less than 5MB
if input.PartSize < 5242880 {
input.PartSize = 8 * 1024 * 1024
}
// Make sure max concurrent is at least 1, default to 10 if unspecified or less than 1
if input.MaxConcurrent < 1 {
input.MaxConcurrent = 10
}
// Set timeout to 60 minutes if not specified or zero value
if input.CtxTimeout == 0 {
input.CtxTimeout = 60 * time.Minute
}

// Set up context with timeout
ctx, cancelFn := context.WithTimeout(input.Ctx, input.CtxTimeout)
defer cancelFn()

// Open local file
file, err := os.Open(input.Filepath)
if err != nil {
return fmt.Errorf("error opening file: %w", err)
}
defer func() {
err := file.Close()
if err != nil {
if input.Logger != nil {
input.Logger.Error("encountered error closing file", "path", input.Filepath)
}
}
}()

// Get file and total file size
fileInfo, err := file.Stat()
if err != nil {
return fmt.Errorf("error getting file info: %w", err)
}
fileSize := fileInfo.Size()

// Initialize a multipart upload and get an upload ID back
multipartUpload, err := input.Svc.CreateMultipartUploadWithContext(ctx, &s3.CreateMultipartUploadInput{
Bucket: &input.DestinationBucket,
Key: &input.DestinationKey,
})
if err != nil {
return fmt.Errorf("error creating multipart upload: %w", err)
}

// Record the upload ID from the multipart upload
var uploadID string
if multipartUpload != nil {
if multipartUpload.UploadId != nil {
if *multipartUpload.UploadId == "" {
return errors.New("no upload ID returned in start upload request -- something wrong with the client or credentials?")
}

uploadID = *multipartUpload.UploadId
}
}

// Get the total number of parts we will upload
numParts := getTotalNumberParts(fileSize, input.PartSize)
if input.Logger != nil {
input.Logger.Debug("will upload file in parts", "file", input.Filepath, "parts", numParts)
}

var (
wg sync.WaitGroup
ch = make(chan error, numParts)
sem = make(chan struct{}, input.MaxConcurrent)
)

// Start the individual part uploads
orderedParts := make([]*s3.CompletedPart, numParts)
for i := int64(0); i < numParts; i++ {
partNumber := i + 1
offset := i * input.PartSize
bytesToRead := min(input.PartSize, fileSize-offset)

partBuffer := make([]byte, bytesToRead)
_, err := file.ReadAt(partBuffer, offset)
if err != nil {
return err
}

wg.Add(1)
go func(partNumber int64, partBuffer []byte) {
sem <- struct{}{}
defer func() {
<-sem
}()
defer wg.Done()

if input.Logger != nil {
input.Logger.Debug("uploading file part", "file", input.Filepath, "part", partNumber, "size", len(partBuffer))
}

resp, err := input.Svc.UploadPart(&s3.UploadPartInput{
Bucket: aws.String(input.DestinationBucket),
Key: aws.String(input.DestinationKey),
UploadId: &uploadID,
PartNumber: aws.Int64(partNumber),
Body: bytes.NewReader(partBuffer),
})
if err != nil {
ch <- fmt.Errorf("error uploading part %d : %w", partNumber, err)
return
}

// Store the completed part in the uploadParts slice
orderedParts[partNumber-1] = &s3.CompletedPart{
ETag: resp.ETag,
PartNumber: aws.Int64(partNumber),
}

if input.Logger != nil {
input.Logger.Debug("finished uploading file part", "file", input.Filepath, "part", partNumber, "size", len(partBuffer))
}
}(partNumber, partBuffer)
}

wg.Wait()

// Check for errors from goroutines
select {
case err := <-ch:
return err
default:
// No errors
}
close(ch)

// Make a final call to AWS to say the file upload is complete
// The file won't show up in S3 unless this is called
_, err = input.Svc.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
Bucket: &input.DestinationBucket,
Key: &input.DestinationKey,
UploadId: &uploadID,
MultipartUpload: &s3.CompletedMultipartUpload{
Parts: orderedParts,
},
})
if err != nil {
return fmt.Errorf("error completing upload: %w", err)
}

return nil
}

func min(a, b int64) int64 {
if a < b {
return a
}
return b
}

func getTotalNumberParts(filesize int64, partsize int64) int64 {
if filesize%partsize == 0 {
return filesize / partsize
}
return filesize/partsize + 1
}
24 changes: 24 additions & 0 deletions pkg/amazon/s3_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package amazon

import "testing"

func TestGetTotalNumberParts(t *testing.T) {
tests := []struct {
FileSize int64
PartSize int64
Expected int64
}{
{5, 2, 3}, // Test quotient where division produced a non-whole number
{2, 1, 2}, // Test quotient where division produced a whole number
// Test with a couple much larger numbers
{104857600, 5242880, 20}, // 100MB file, 5MB chunks, 20 total chunks
{132070244351, 8388200, 15745}, // ~123GB file, ~8MB chunks, 15745 total chunks
}

for _, test := range tests {
result := getTotalNumberParts(test.FileSize, test.PartSize)
if result != test.Expected {
t.Errorf("operation failed for %d / %d. Expected %d, got %d", test.FileSize, test.PartSize, test.Expected, result)
}
}
}

0 comments on commit 913faf4

Please sign in to comment.