Skip to content

Commit

Permalink
🐛 bug: Fix square bracket notation in Multipart FormData (#3235)
Browse files Browse the repository at this point in the history
* 🐛 bug: add square bracket notation support to BindMultipart

* Fix golangci-lint issues

* Fixing undef variable

* Fix more lint issues

* test

* update1

* improve coverage

* fix linter

* reduce code duplication

* reduce code duplications in bindMultipart

---------

Co-authored-by: Juan Calderon-Perez <[email protected]>
Co-authored-by: René <[email protected]>
  • Loading branch information
3 people authored Dec 31, 2024
1 parent d0e767f commit ef04a8a
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 75 deletions.
103 changes: 100 additions & 3 deletions bind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
"mime/multipart"
"net/http/httptest"
"reflect"
"testing"
Expand Down Expand Up @@ -886,7 +887,8 @@ func Test_Bind_Body(t *testing.T) {
reqBody := []byte(`{"name":"john"}`)

type Demo struct {
Name string `json:"name" xml:"name" form:"name" query:"name"`
Name string `json:"name" xml:"name" form:"name" query:"name"`
Names []string `json:"names" xml:"names" form:"names" query:"names"`
}

// Helper function to test compressed bodies
Expand Down Expand Up @@ -996,6 +998,48 @@ func Test_Bind_Body(t *testing.T) {
Data []Demo `query:"data"`
}

t.Run("MultipartCollectionQueryDotNotation", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Reset()

buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(t, writer.WriteField("data.0.name", "john"))
require.NoError(t, writer.WriteField("data.1.name", "doe"))
require.NoError(t, writer.Close())

c.Request().Header.SetContentType(writer.FormDataContentType())
c.Request().SetBody(buf.Bytes())
c.Request().Header.SetContentLength(len(c.Body()))

cq := new(CollectionQuery)
require.NoError(t, c.Bind().Body(cq))
require.Len(t, cq.Data, 2)
require.Equal(t, "john", cq.Data[0].Name)
require.Equal(t, "doe", cq.Data[1].Name)
})

t.Run("MultipartCollectionQuerySquareBrackets", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Reset()

buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(t, writer.WriteField("data[0][name]", "john"))
require.NoError(t, writer.WriteField("data[1][name]", "doe"))
require.NoError(t, writer.Close())

c.Request().Header.SetContentType(writer.FormDataContentType())
c.Request().SetBody(buf.Bytes())
c.Request().Header.SetContentLength(len(c.Body()))

cq := new(CollectionQuery)
require.NoError(t, c.Bind().Body(cq))
require.Len(t, cq.Data, 2)
require.Equal(t, "john", cq.Data[0].Name)
require.Equal(t, "doe", cq.Data[1].Name)
})

t.Run("CollectionQuerySquareBrackets", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Reset()
Expand Down Expand Up @@ -1192,9 +1236,57 @@ func Benchmark_Bind_Body_MultipartForm(b *testing.B) {
Name string `form:"name"`
}

body := []byte("--b\r\nContent-Disposition: form-data; name=\"name\"\r\n\r\njohn\r\n--b--")
buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(b, writer.WriteField("name", "john"))
require.NoError(b, writer.Close())
body := buf.Bytes()

c.Request().SetBody(body)
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary())
c.Request().Header.SetContentLength(len(body))
d := new(Demo)

b.ReportAllocs()
b.ResetTimer()

for n := 0; n < b.N; n++ {
err = c.Bind().Body(d)
}

require.NoError(b, err)
require.Equal(b, "john", d.Name)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Body_MultipartForm_Nested -benchmem -count=4
func Benchmark_Bind_Body_MultipartForm_Nested(b *testing.B) {
var err error

app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})

type Person struct {
Name string `form:"name"`
Age int `form:"age"`
}

type Demo struct {
Name string `form:"name"`
Persons []Person `form:"persons"`
}

buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(b, writer.WriteField("name", "john"))
require.NoError(b, writer.WriteField("persons.0.name", "john"))
require.NoError(b, writer.WriteField("persons[0][age]", "10"))
require.NoError(b, writer.WriteField("persons[1][name]", "doe"))
require.NoError(b, writer.WriteField("persons.1.age", "20"))
require.NoError(b, writer.Close())
body := buf.Bytes()

c.Request().SetBody(body)
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary="b"`)
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary())
c.Request().Header.SetContentLength(len(body))
d := new(Demo)

Expand All @@ -1204,8 +1296,13 @@ func Benchmark_Bind_Body_MultipartForm(b *testing.B) {
for n := 0; n < b.N; n++ {
err = c.Bind().Body(d)
}

require.NoError(b, err)
require.Equal(b, "john", d.Name)
require.Equal(b, "john", d.Persons[0].Name)
require.Equal(b, 10, d.Persons[0].Age)
require.Equal(b, "doe", d.Persons[1].Name)
require.Equal(b, 20, d.Persons[1].Age)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Body_Form_Map -benchmem -count=4
Expand Down
13 changes: 1 addition & 12 deletions binder/cookie.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package binder

import (
"reflect"
"strings"

"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
Expand All @@ -30,15 +27,7 @@ func (b *CookieBinding) Bind(req *fasthttp.Request, out any) error {

k := utils.UnsafeString(key)
v := utils.UnsafeString(val)

if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
err = formatBindData(out, data, k, v, b.EnableSplitting, false)
})

if err != nil {
Expand Down
29 changes: 11 additions & 18 deletions binder/form.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package binder

import (
"reflect"
"strings"

"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
Expand Down Expand Up @@ -37,19 +34,7 @@ func (b *FormBinding) Bind(req *fasthttp.Request, out any) error {

k := utils.UnsafeString(key)
v := utils.UnsafeString(val)

if strings.Contains(k, "[") {
k, err = parseParamSquareBrackets(k)
}

if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
err = formatBindData(out, data, k, v, b.EnableSplitting, true)
})

if err != nil {
Expand All @@ -61,12 +46,20 @@ func (b *FormBinding) Bind(req *fasthttp.Request, out any) error {

// bindMultipart parses the request body and returns the result.
func (b *FormBinding) bindMultipart(req *fasthttp.Request, out any) error {
data, err := req.MultipartForm()
multipartForm, err := req.MultipartForm()
if err != nil {
return err
}

return parse(b.Name(), out, data.Value)
data := make(map[string][]string)
for key, values := range multipartForm.Value {
err = formatBindData(out, data, key, values, b.EnableSplitting, true)
if err != nil {
return err
}
}

return parse(b.Name(), out, data)
}

// Reset resets the FormBinding binder.
Expand Down
16 changes: 15 additions & 1 deletion binder/form_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,14 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
}
require.Equal(t, "form", b.Name())

type Post struct {
Title string `form:"title"`
}

type User struct {
Name string `form:"name"`
Names []string `form:"names"`
Posts []Post `form:"posts"`
Age int `form:"age"`
}
var user User
Expand All @@ -106,9 +111,13 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
mw := multipart.NewWriter(buf)

require.NoError(t, mw.WriteField("name", "john"))
require.NoError(t, mw.WriteField("names", "john"))
require.NoError(t, mw.WriteField("names", "john,eric"))
require.NoError(t, mw.WriteField("names", "doe"))
require.NoError(t, mw.WriteField("age", "42"))
require.NoError(t, mw.WriteField("posts[0][title]", "post1"))
require.NoError(t, mw.WriteField("posts[1][title]", "post2"))
require.NoError(t, mw.WriteField("posts[2][title]", "post3"))

require.NoError(t, mw.Close())

req.Header.SetContentType(mw.FormDataContentType())
Expand All @@ -125,6 +134,11 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
require.Equal(t, 42, user.Age)
require.Contains(t, user.Names, "john")
require.Contains(t, user.Names, "doe")
require.Contains(t, user.Names, "eric")
require.Len(t, user.Posts, 3)
require.Equal(t, "post1", user.Posts[0].Title)
require.Equal(t, "post2", user.Posts[1].Title)
require.Equal(t, "post3", user.Posts[2].Title)
}

func Benchmark_FormBinder_BindMultipart(b *testing.B) {
Expand Down
22 changes: 10 additions & 12 deletions binder/header.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package binder

import (
"reflect"
"strings"

"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
Expand All @@ -21,20 +18,21 @@ func (*HeaderBinding) Name() string {
// Bind parses the request header and returns the result.
func (b *HeaderBinding) Bind(req *fasthttp.Request, out any) error {
data := make(map[string][]string)
var err error
req.Header.VisitAll(func(key, val []byte) {
if err != nil {
return
}

k := utils.UnsafeString(key)
v := utils.UnsafeString(val)

if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
err = formatBindData(out, data, k, v, b.EnableSplitting, false)
})

if err != nil {
return err
}

return parse(b.Name(), out, data)
}

Expand Down
38 changes: 37 additions & 1 deletion binder/mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func parseToStruct(aliasTag string, out any, data map[string][]string) error {
func parseToMap(ptr any, data map[string][]string) error {
elem := reflect.TypeOf(ptr).Elem()

switch elem.Kind() { //nolint:exhaustive // it's not necessary to check all types
switch elem.Kind() {
case reflect.Slice:
newMap, ok := ptr.(map[string][]string)
if !ok {
Expand All @@ -130,6 +130,8 @@ func parseToMap(ptr any, data map[string][]string) error {
}
newMap[k] = v[len(v)-1]
}
default:
return nil // it's not necessary to check all types
}

return nil
Expand Down Expand Up @@ -247,3 +249,37 @@ func FilterFlags(content string) string {
}
return content
}

func formatBindData[T any](out any, data map[string][]string, key string, value T, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay
var err error
if supportBracketNotation && strings.Contains(key, "[") {
key, err = parseParamSquareBrackets(key)
if err != nil {
return err
}
}

switch v := any(value).(type) {
case string:
assignBindData(out, data, key, v, enableSplitting)
case []string:
for _, val := range v {
assignBindData(out, data, key, val, enableSplitting)
}
default:
return fmt.Errorf("unsupported value type: %T", value)
}

return err
}

func assignBindData(out any, data map[string][]string, key, value string, enableSplitting bool) { //nolint:revive // it's okay
if enableSplitting && strings.Contains(value, ",") && equalFieldType(out, reflect.Slice, key) {
values := strings.Split(value, ",")
for i := 0; i < len(values); i++ {
data[key] = append(data[key], values[i])
}
} else {
data[key] = append(data[key], value)
}
}
17 changes: 1 addition & 16 deletions binder/query.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package binder

import (
"reflect"
"strings"

"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
Expand All @@ -30,19 +27,7 @@ func (b *QueryBinding) Bind(reqCtx *fasthttp.Request, out any) error {

k := utils.UnsafeString(key)
v := utils.UnsafeString(val)

if strings.Contains(k, "[") {
k, err = parseParamSquareBrackets(k)
}

if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
err = formatBindData(out, data, k, v, b.EnableSplitting, true)
})

if err != nil {
Expand Down
Loading

1 comment on commit ef04a8a

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 1.50.

Benchmark suite Current: ef04a8a Previous: 47be681 Ratio
Benchmark_Utils_GetOffer/1_parameter 223.9 ns/op 0 B/op 0 allocs/op 131.7 ns/op 0 B/op 0 allocs/op 1.70
Benchmark_Utils_GetOffer/1_parameter - ns/op 223.9 ns/op 131.7 ns/op 1.70
`Benchmark_RoutePatternMatch//api/:param/fixedEnd_ not_match _/api/abc/def/fixedEnd - allocs/op` 14 allocs/op
Benchmark_Middleware_BasicAuth - B/op 80 B/op 48 B/op 1.67
Benchmark_Middleware_BasicAuth - allocs/op 5 allocs/op 3 allocs/op 1.67
Benchmark_Middleware_BasicAuth_Upper - B/op 80 B/op 48 B/op 1.67
Benchmark_Middleware_BasicAuth_Upper - allocs/op 5 allocs/op 3 allocs/op 1.67
Benchmark_CORS_NewHandler - B/op 16 B/op 0 B/op +∞
Benchmark_CORS_NewHandler - allocs/op 1 allocs/op 0 allocs/op +∞
Benchmark_CORS_NewHandlerSingleOrigin - B/op 16 B/op 0 B/op +∞
Benchmark_CORS_NewHandlerSingleOrigin - allocs/op 1 allocs/op 0 allocs/op +∞
Benchmark_CORS_NewHandlerPreflight - B/op 104 B/op 0 B/op +∞
Benchmark_CORS_NewHandlerPreflight - allocs/op 5 allocs/op 0 allocs/op +∞
Benchmark_CORS_NewHandlerPreflightSingleOrigin - B/op 104 B/op 0 B/op +∞
Benchmark_CORS_NewHandlerPreflightSingleOrigin - allocs/op 5 allocs/op 0 allocs/op +∞
Benchmark_CORS_NewHandlerPreflightWildcard - B/op 104 B/op 0 B/op +∞
Benchmark_CORS_NewHandlerPreflightWildcard - allocs/op 5 allocs/op 0 allocs/op +∞
Benchmark_Middleware_CSRF_GenerateToken - B/op 517 B/op 332 B/op 1.56
Benchmark_Middleware_CSRF_GenerateToken - allocs/op 10 allocs/op 6 allocs/op 1.67

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.