Skip to content

Commit

Permalink
refactor: ServeFile 现在直接处理所有的错误不再返回
Browse files Browse the repository at this point in the history
caixw committed Feb 29, 2024

Verified

This commit was signed with the committer’s verified signature.
caixw caixw
1 parent c0110b8 commit 9995155
Showing 3 changed files with 28 additions and 16 deletions.
5 changes: 1 addition & 4 deletions bench_test.go
Original file line number Diff line number Diff line change
@@ -20,9 +20,6 @@ func BenchmarkServeFile(b *testing.B) {
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
r, _ := http.NewRequest(http.MethodGet, "/assets/mux.go", nil)
err := ServeFile(fsys, "mux.go", "", w, r)
if err != nil {
b.Errorf("测试出错,返回了以下错误 %s", err)
}
ServeFile(fsys, "mux.go", "", w, r)
}
}
26 changes: 21 additions & 5 deletions mux.go
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@ package mux

import (
"bytes"
"errors"
"expvar"
"html"
"io"
@@ -95,7 +96,9 @@ func Trace(w http.ResponseWriter, r *http.Request, body bool) error {
//
// p 表示需要读取的文件名;
// index 表示 p 为目录时,默认读取的文件,为空表示 index.html;
func ServeFile(fsys fs.FS, p, index string, w http.ResponseWriter, r *http.Request) error {
//
// 如果不需要自定义 index 的功能,则可以直接使用 [http.ServeFile] 或是 [http.ServeFileFS]。
func ServeFile(fsys fs.FS, p, index string, w http.ResponseWriter, r *http.Request) {
if index == "" {
index = "index.html"
}
@@ -107,13 +110,15 @@ func ServeFile(fsys fs.FS, p, index string, w http.ResponseWriter, r *http.Reque
STAT:
f, err := fsys.Open(p)
if err != nil {
return err
toHTTPError(w, err) // 保持与 http.ServeContent 相同的处理方式
return
}
defer f.Close()

stat, err := f.Stat()
if err != nil {
return err
toHTTPError(w, err)
return
}
if stat.IsDir() {
p = path.Join(p, index)
@@ -125,13 +130,24 @@ STAT:
data := make([]byte, stat.Size())
size, err := f.Read(data)
if err != nil {
return err
toHTTPError(w, err)
return
}
rs = bytes.NewReader(data[:size])
}

http.ServeContent(w, r, filepath.Base(p), stat.ModTime(), rs)
return nil
}

func toHTTPError(w http.ResponseWriter, err error) {
switch {
case errors.Is(err, fs.ErrNotExist):
http.Error(w, err.Error(), http.StatusNotFound)
case errors.Is(err, fs.ErrPermission):
http.Error(w, err.Error(), http.StatusForbidden)
default:
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}

// Debug 输出调试信息
13 changes: 6 additions & 7 deletions mux_test.go
Original file line number Diff line number Diff line change
@@ -5,7 +5,6 @@
package mux

import (
"io/fs"
"net/http"
"net/http/httptest"
"os"
@@ -74,23 +73,23 @@ func TestServeFile(t *testing.T) {

w := httptest.NewRecorder()
r := rest.Get(a, "/assets/").Request()
a.NotError(ServeFile(fsys, "", "go.mod", w, r))
ServeFile(fsys, "", "go.mod", w, r)
a.Contains(w.Body.String(), "module github.com/issue9/mux")

w = httptest.NewRecorder()
r = rest.Get(a, "/assets/").Request()
a.NotError(ServeFile(fsys, "types/types.go", "", w, r))
ServeFile(fsys, "types/types.go", "", w, r)
a.NotEmpty(w.Body.String())

w = httptest.NewRecorder()
r = rest.Get(a, "/assets/").Request()
a.ErrorIs(ServeFile(fsys, "types/", "", w, r), fs.ErrNotExist)
a.Empty(w.Body.String())
ServeFile(fsys, "types/", "", w, r)
a.Equal(w.Result().StatusCode, http.StatusNotFound)

w = httptest.NewRecorder()
r = rest.Get(a, "/assets/").Request()
a.ErrorIs(ServeFile(fsys, "not-exists", "", w, r), fs.ErrNotExist)
a.Empty(w.Body.String())
ServeFile(fsys, "not-exists", "", w, r)
a.Equal(w.Result().StatusCode, http.StatusNotFound)
}

func TestDebug(t *testing.T) {

0 comments on commit 9995155

Please sign in to comment.