From ef04a8a99e0b7404b574e371dc2a341acd5a017d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Efe=20=C3=87etin?= Date: Tue, 31 Dec 2024 18:34:28 +0300 Subject: [PATCH] :bug: bug: Fix square bracket notation in Multipart FormData (#3235) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * :bug: bug: add square bracket notation support to BindMultipart * Fix golangci-lint issues * Fixing undef variable * Fix more lint issues * test * update1 * improve coverage * fix linter * reduce code duplication * reduce code duplications in bindMultipart --------- Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Co-authored-by: René --- bind_test.go | 103 ++++++++++++++++++++++++++++++++++++++++-- binder/cookie.go | 13 +----- binder/form.go | 29 +++++------- binder/form_test.go | 16 ++++++- binder/header.go | 22 ++++----- binder/mapping.go | 38 +++++++++++++++- binder/query.go | 17 +------ binder/resp_header.go | 23 +++++----- 8 files changed, 186 insertions(+), 75 deletions(-) diff --git a/bind_test.go b/bind_test.go index 52c9004c61..b01086e623 100644 --- a/bind_test.go +++ b/bind_test.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "fmt" + "mime/multipart" "net/http/httptest" "reflect" "testing" @@ -886,7 +887,8 @@ func Test_Bind_Body(t *testing.T) { reqBody := []byte(`{"name":"john"}`) type Demo struct { - Name string `json:"name" xml:"name" form:"name" query:"name"` + Name string `json:"name" xml:"name" form:"name" query:"name"` + Names []string `json:"names" xml:"names" form:"names" query:"names"` } // Helper function to test compressed bodies @@ -996,6 +998,48 @@ func Test_Bind_Body(t *testing.T) { Data []Demo `query:"data"` } + t.Run("MultipartCollectionQueryDotNotation", func(t *testing.T) { + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + c.Request().Reset() + + buf := &bytes.Buffer{} + writer := multipart.NewWriter(buf) + require.NoError(t, writer.WriteField("data.0.name", "john")) + require.NoError(t, writer.WriteField("data.1.name", "doe")) + require.NoError(t, writer.Close()) + + c.Request().Header.SetContentType(writer.FormDataContentType()) + c.Request().SetBody(buf.Bytes()) + c.Request().Header.SetContentLength(len(c.Body())) + + cq := new(CollectionQuery) + require.NoError(t, c.Bind().Body(cq)) + require.Len(t, cq.Data, 2) + require.Equal(t, "john", cq.Data[0].Name) + require.Equal(t, "doe", cq.Data[1].Name) + }) + + t.Run("MultipartCollectionQuerySquareBrackets", func(t *testing.T) { + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + c.Request().Reset() + + buf := &bytes.Buffer{} + writer := multipart.NewWriter(buf) + require.NoError(t, writer.WriteField("data[0][name]", "john")) + require.NoError(t, writer.WriteField("data[1][name]", "doe")) + require.NoError(t, writer.Close()) + + c.Request().Header.SetContentType(writer.FormDataContentType()) + c.Request().SetBody(buf.Bytes()) + c.Request().Header.SetContentLength(len(c.Body())) + + cq := new(CollectionQuery) + require.NoError(t, c.Bind().Body(cq)) + require.Len(t, cq.Data, 2) + require.Equal(t, "john", cq.Data[0].Name) + require.Equal(t, "doe", cq.Data[1].Name) + }) + t.Run("CollectionQuerySquareBrackets", func(t *testing.T) { c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Reset() @@ -1192,9 +1236,57 @@ func Benchmark_Bind_Body_MultipartForm(b *testing.B) { Name string `form:"name"` } - body := []byte("--b\r\nContent-Disposition: form-data; name=\"name\"\r\n\r\njohn\r\n--b--") + buf := &bytes.Buffer{} + writer := multipart.NewWriter(buf) + require.NoError(b, writer.WriteField("name", "john")) + require.NoError(b, writer.Close()) + body := buf.Bytes() + + c.Request().SetBody(body) + c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary()) + c.Request().Header.SetContentLength(len(body)) + d := new(Demo) + + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + err = c.Bind().Body(d) + } + + require.NoError(b, err) + require.Equal(b, "john", d.Name) +} + +// go test -v -run=^$ -bench=Benchmark_Bind_Body_MultipartForm_Nested -benchmem -count=4 +func Benchmark_Bind_Body_MultipartForm_Nested(b *testing.B) { + var err error + + app := New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + + type Person struct { + Name string `form:"name"` + Age int `form:"age"` + } + + type Demo struct { + Name string `form:"name"` + Persons []Person `form:"persons"` + } + + buf := &bytes.Buffer{} + writer := multipart.NewWriter(buf) + require.NoError(b, writer.WriteField("name", "john")) + require.NoError(b, writer.WriteField("persons.0.name", "john")) + require.NoError(b, writer.WriteField("persons[0][age]", "10")) + require.NoError(b, writer.WriteField("persons[1][name]", "doe")) + require.NoError(b, writer.WriteField("persons.1.age", "20")) + require.NoError(b, writer.Close()) + body := buf.Bytes() + c.Request().SetBody(body) - c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary="b"`) + c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary()) c.Request().Header.SetContentLength(len(body)) d := new(Demo) @@ -1204,8 +1296,13 @@ func Benchmark_Bind_Body_MultipartForm(b *testing.B) { for n := 0; n < b.N; n++ { err = c.Bind().Body(d) } + require.NoError(b, err) require.Equal(b, "john", d.Name) + require.Equal(b, "john", d.Persons[0].Name) + require.Equal(b, 10, d.Persons[0].Age) + require.Equal(b, "doe", d.Persons[1].Name) + require.Equal(b, 20, d.Persons[1].Age) } // go test -v -run=^$ -bench=Benchmark_Bind_Body_Form_Map -benchmem -count=4 diff --git a/binder/cookie.go b/binder/cookie.go index 230794f45a..5b9ccf1ed3 100644 --- a/binder/cookie.go +++ b/binder/cookie.go @@ -1,9 +1,6 @@ package binder import ( - "reflect" - "strings" - "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -30,15 +27,7 @@ func (b *CookieBinding) Bind(req *fasthttp.Request, out any) error { k := utils.UnsafeString(key) v := utils.UnsafeString(val) - - if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + err = formatBindData(out, data, k, v, b.EnableSplitting, false) }) if err != nil { diff --git a/binder/form.go b/binder/form.go index 7ab0b1b258..a8f5b85270 100644 --- a/binder/form.go +++ b/binder/form.go @@ -1,9 +1,6 @@ package binder import ( - "reflect" - "strings" - "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -37,19 +34,7 @@ func (b *FormBinding) Bind(req *fasthttp.Request, out any) error { k := utils.UnsafeString(key) v := utils.UnsafeString(val) - - if strings.Contains(k, "[") { - k, err = parseParamSquareBrackets(k) - } - - if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + err = formatBindData(out, data, k, v, b.EnableSplitting, true) }) if err != nil { @@ -61,12 +46,20 @@ func (b *FormBinding) Bind(req *fasthttp.Request, out any) error { // bindMultipart parses the request body and returns the result. func (b *FormBinding) bindMultipart(req *fasthttp.Request, out any) error { - data, err := req.MultipartForm() + multipartForm, err := req.MultipartForm() if err != nil { return err } - return parse(b.Name(), out, data.Value) + data := make(map[string][]string) + for key, values := range multipartForm.Value { + err = formatBindData(out, data, key, values, b.EnableSplitting, true) + if err != nil { + return err + } + } + + return parse(b.Name(), out, data) } // Reset resets the FormBinding binder. diff --git a/binder/form_test.go b/binder/form_test.go index c3c52c73fd..55023cb30f 100644 --- a/binder/form_test.go +++ b/binder/form_test.go @@ -93,9 +93,14 @@ func Test_FormBinder_BindMultipart(t *testing.T) { } require.Equal(t, "form", b.Name()) + type Post struct { + Title string `form:"title"` + } + type User struct { Name string `form:"name"` Names []string `form:"names"` + Posts []Post `form:"posts"` Age int `form:"age"` } var user User @@ -106,9 +111,13 @@ func Test_FormBinder_BindMultipart(t *testing.T) { mw := multipart.NewWriter(buf) require.NoError(t, mw.WriteField("name", "john")) - require.NoError(t, mw.WriteField("names", "john")) + require.NoError(t, mw.WriteField("names", "john,eric")) require.NoError(t, mw.WriteField("names", "doe")) require.NoError(t, mw.WriteField("age", "42")) + require.NoError(t, mw.WriteField("posts[0][title]", "post1")) + require.NoError(t, mw.WriteField("posts[1][title]", "post2")) + require.NoError(t, mw.WriteField("posts[2][title]", "post3")) + require.NoError(t, mw.Close()) req.Header.SetContentType(mw.FormDataContentType()) @@ -125,6 +134,11 @@ func Test_FormBinder_BindMultipart(t *testing.T) { require.Equal(t, 42, user.Age) require.Contains(t, user.Names, "john") require.Contains(t, user.Names, "doe") + require.Contains(t, user.Names, "eric") + require.Len(t, user.Posts, 3) + require.Equal(t, "post1", user.Posts[0].Title) + require.Equal(t, "post2", user.Posts[1].Title) + require.Equal(t, "post3", user.Posts[2].Title) } func Benchmark_FormBinder_BindMultipart(b *testing.B) { diff --git a/binder/header.go b/binder/header.go index b04ce9add3..763be56795 100644 --- a/binder/header.go +++ b/binder/header.go @@ -1,9 +1,6 @@ package binder import ( - "reflect" - "strings" - "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -21,20 +18,21 @@ func (*HeaderBinding) Name() string { // Bind parses the request header and returns the result. func (b *HeaderBinding) Bind(req *fasthttp.Request, out any) error { data := make(map[string][]string) + var err error req.Header.VisitAll(func(key, val []byte) { + if err != nil { + return + } + k := utils.UnsafeString(key) v := utils.UnsafeString(val) - - if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + err = formatBindData(out, data, k, v, b.EnableSplitting, false) }) + if err != nil { + return err + } + return parse(b.Name(), out, data) } diff --git a/binder/mapping.go b/binder/mapping.go index 29b5550b10..70cb9cbc2d 100644 --- a/binder/mapping.go +++ b/binder/mapping.go @@ -107,7 +107,7 @@ func parseToStruct(aliasTag string, out any, data map[string][]string) error { func parseToMap(ptr any, data map[string][]string) error { elem := reflect.TypeOf(ptr).Elem() - switch elem.Kind() { //nolint:exhaustive // it's not necessary to check all types + switch elem.Kind() { case reflect.Slice: newMap, ok := ptr.(map[string][]string) if !ok { @@ -130,6 +130,8 @@ func parseToMap(ptr any, data map[string][]string) error { } newMap[k] = v[len(v)-1] } + default: + return nil // it's not necessary to check all types } return nil @@ -247,3 +249,37 @@ func FilterFlags(content string) string { } return content } + +func formatBindData[T any](out any, data map[string][]string, key string, value T, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay + var err error + if supportBracketNotation && strings.Contains(key, "[") { + key, err = parseParamSquareBrackets(key) + if err != nil { + return err + } + } + + switch v := any(value).(type) { + case string: + assignBindData(out, data, key, v, enableSplitting) + case []string: + for _, val := range v { + assignBindData(out, data, key, val, enableSplitting) + } + default: + return fmt.Errorf("unsupported value type: %T", value) + } + + return err +} + +func assignBindData(out any, data map[string][]string, key, value string, enableSplitting bool) { //nolint:revive // it's okay + if enableSplitting && strings.Contains(value, ",") && equalFieldType(out, reflect.Slice, key) { + values := strings.Split(value, ",") + for i := 0; i < len(values); i++ { + data[key] = append(data[key], values[i]) + } + } else { + data[key] = append(data[key], value) + } +} diff --git a/binder/query.go b/binder/query.go index 9ee500ba63..d2ac309215 100644 --- a/binder/query.go +++ b/binder/query.go @@ -1,9 +1,6 @@ package binder import ( - "reflect" - "strings" - "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -30,19 +27,7 @@ func (b *QueryBinding) Bind(reqCtx *fasthttp.Request, out any) error { k := utils.UnsafeString(key) v := utils.UnsafeString(val) - - if strings.Contains(k, "[") { - k, err = parseParamSquareBrackets(k) - } - - if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + err = formatBindData(out, data, k, v, b.EnableSplitting, true) }) if err != nil { diff --git a/binder/resp_header.go b/binder/resp_header.go index fc84d01402..cb29e99d6f 100644 --- a/binder/resp_header.go +++ b/binder/resp_header.go @@ -1,9 +1,6 @@ package binder import ( - "reflect" - "strings" - "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -21,20 +18,22 @@ func (*RespHeaderBinding) Name() string { // Bind parses the response header and returns the result. func (b *RespHeaderBinding) Bind(resp *fasthttp.Response, out any) error { data := make(map[string][]string) + var err error + resp.Header.VisitAll(func(key, val []byte) { + if err != nil { + return + } + k := utils.UnsafeString(key) v := utils.UnsafeString(val) - - if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) { - values := strings.Split(v, ",") - for i := 0; i < len(values); i++ { - data[k] = append(data[k], values[i]) - } - } else { - data[k] = append(data[k], v) - } + err = formatBindData(out, data, k, v, b.EnableSplitting, false) }) + if err != nil { + return err + } + return parse(b.Name(), out, data) }