From cf42742296014459466a55b0c9193b8514ba149f Mon Sep 17 00:00:00 2001 From: lyric Date: Wed, 19 Jul 2017 01:07:57 +0800 Subject: [PATCH] fixed uploadDo --- README.md | 10 +++++-- fuh_test.go | 75 +++++++++++++++++++++++++++++++++++++++++------------ handler.go | 8 ++++++ 3 files changed, 75 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 15866d4..4ad090e 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ import ( "context" "encoding/json" "net/http" + "path/filepath" "github.com/LyricTian/fuh" ) @@ -28,11 +29,16 @@ import ( func main() { uploader := fuh.NewUploader(&fuh.Config{ BasePath: "attach", - SizeLimit: 10 << 20, + SizeLimit: 1 << 20, }, fuh.NewFileStore()) http.HandleFunc("/fileupload", func(w http.ResponseWriter, r *http.Request) { - finfos, err := uploader.Upload(context.Background(), r, "file") + + ctx := fuh.NewFileNameContext(context.Background(), func(ci fuh.ContextInfo) string { + return filepath.Join(ci.BasePath(), ci.FileName()) + }) + + finfos, err := uploader.Upload(ctx, r, "file") if err != nil { w.WriteHeader(500) return diff --git a/fuh_test.go b/fuh_test.go index 09caa1c..7cd1bec 100644 --- a/fuh_test.go +++ b/fuh_test.go @@ -4,12 +4,14 @@ import ( "bytes" "context" "io/ioutil" + "log" "mime/multipart" "net/http" "net/http/httptest" "os" "path/filepath" "testing" + "time" "github.com/LyricTian/fuh" . "github.com/smartystreets/goconvey/convey" @@ -22,7 +24,8 @@ func TestFileUpload(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { Convey("single file upload test", t, func() { - uploader := fuh.NewUploader(&fuh.Config{BasePath: basePath}, fuh.NewFileStore()) + config := &fuh.Config{BasePath: basePath, SizeLimit: 1 << 20, MaxMemory: 10 << 20} + uploader := fuh.NewUploader(config, fuh.NewFileStore()) fileInfos, err := uploader.Upload(nil, r, "file") So(err, ShouldBeNil) @@ -43,10 +46,7 @@ func TestFileUpload(t *testing.T) { })) defer srv.Close() - err := postFile(srv.URL, buf, filename) - if err != nil { - t.Error(err.Error()) - } + postFile(srv.URL, buf, filename) } func TestCustomFileName(t *testing.T) { @@ -82,10 +82,7 @@ func TestCustomFileName(t *testing.T) { })) defer srv.Close() - err := postFile(srv.URL, buf, filename) - if err != nil { - t.Error(err.Error()) - } + postFile(srv.URL, buf, filename) } func TestFileSizeLimit(t *testing.T) { @@ -109,20 +106,66 @@ func TestFileSizeLimit(t *testing.T) { })) defer srv.Close() - err := postFile(srv.URL, buf, filename) - if err != nil { - t.Error(err.Error()) + postFile(srv.URL, buf, filename) +} + +func TestUploadTimeout(t *testing.T) { + basePath := "testdatas/" + filename := "uploadtimeout_test.txt" + buf := []byte("abc") + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + Convey("file upload timeout test", t, func() { + uploader := fuh.NewUploader(&fuh.Config{BasePath: basePath}, fuh.NewFileStore()) + + ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond*10) + defer cancel() + + fileInfos, err := uploader.Upload(ctx, r, "file") + So(fileInfos, ShouldBeNil) + So(err, ShouldNotBeNil) + So(err.Error(), ShouldEqual, context.DeadlineExceeded.Error()) + }) + })) + defer srv.Close() + + postFile(srv.URL, buf, filename) +} + +func TestFileSize(t *testing.T) { + basePath := "testdatas/" + filename := "filesize_test.txt" + buf := make([]byte, 1024) + for i := 0; i < 1024; i++ { + buf[i] = '0' } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + Convey("file size test", t, func() { + uploader := fuh.NewUploader(&fuh.Config{BasePath: basePath, MaxMemory: 128}, fuh.NewFileStore()) + + defer os.Remove(filepath.Join(basePath, filename)) + + fileInfos, err := uploader.Upload(context.Background(), r, "file") + So(err, ShouldBeNil) + So(len(fileInfos), ShouldEqual, 1) + So(fileInfos[0].Size(), ShouldEqual, len(buf)) + + }) + })) + defer srv.Close() + + postFile(srv.URL, buf, filename) } -func postFile(targetURL string, data []byte, filenames ...string) (err error) { +func postFile(targetURL string, data []byte, filenames ...string) { bodyBuf := &bytes.Buffer{} bodyWriter := multipart.NewWriter(bodyBuf) for _, filename := range filenames { - fileWriter, verr := bodyWriter.CreateFormFile("file", filename) - if verr != nil { - err = verr + fileWriter, err := bodyWriter.CreateFormFile("file", filename) + if err != nil { + log.Println(err.Error()) return } fileWriter.Write(data) diff --git a/handler.go b/handler.go index a2d185a..40178d2 100644 --- a/handler.go +++ b/handler.go @@ -92,6 +92,7 @@ func (u *uploadHandle) maxMemory() int64 { func (u *uploadHandle) uploadDo(ctx context.Context, r *http.Request, fheader *multipart.FileHeader, f func(FileInfo)) error { c := make(chan error, 1) + ctxDone := make(chan struct{}) go func() { file, err := fheader.Open() @@ -139,6 +140,12 @@ func (u *uploadHandle) uploadDo(ctx context.Context, r *http.Request, fheader *m return } + select { + case <-ctxDone: + return + default: + } + var fInfo FileInfo = &fileInfo{ fullName: fullName, size: fsize, @@ -150,6 +157,7 @@ func (u *uploadHandle) uploadDo(ctx context.Context, r *http.Request, fheader *m select { case <-ctx.Done(): + close(ctxDone) return ctx.Err() case err := <-c: return err