Skip to content

Commit

Permalink
Support Square Bracket Notation in Multipart Form data (#3268)
Browse files Browse the repository at this point in the history
* Feature Request: Support Square Bracket Notation in Multipart Form Data #3224

* Feature Request: Support Square Bracket Notation in Multipart Form Data #3224
  • Loading branch information
ReneWerner87 authored Dec 31, 2024
1 parent 47be681 commit 7eb9d25
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 92 deletions.
113 changes: 29 additions & 84 deletions ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,28 +406,30 @@ func (c *Ctx) BodyParser(out interface{}) error {
k := c.app.getString(key)
v := c.app.getString(val)

if strings.Contains(k, "[") {
k, err = parseParamSquareBrackets(k)
}

if c.app.config.EnableSplittingOnParsers && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k, bodyTag) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
err = formatParserData(out, data, bodyTag, k, v, c.app.config.EnableSplittingOnParsers, true)
})

if err != nil {
return err
}

return c.parseToStruct(bodyTag, out, data)
}
if strings.HasPrefix(ctype, MIMEMultipartForm) {
data, err := c.fasthttp.MultipartForm()
multipartForm, err := c.fasthttp.MultipartForm()
if err != nil {
return err
}
return c.parseToStruct(bodyTag, out, data.Value)

data := make(map[string][]string)
for key, values := range multipartForm.Value {
err = formatParserData(out, data, bodyTag, key, values, c.app.config.EnableSplittingOnParsers, true)
if err != nil {
return err
}
}

return c.parseToStruct(bodyTag, out, data)
}
if strings.HasPrefix(ctype, MIMETextXML) || strings.HasPrefix(ctype, MIMEApplicationXML) {
if err := xml.Unmarshal(c.Body(), out); err != nil {
Expand Down Expand Up @@ -531,18 +533,7 @@ func (c *Ctx) CookieParser(out interface{}) error {
k := c.app.getString(key)
v := c.app.getString(val)

if strings.Contains(k, "[") {
k, err = parseParamSquareBrackets(k)
}

if c.app.config.EnableSplittingOnParsers && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k, cookieTag) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
err = formatParserData(out, data, cookieTag, k, v, c.app.config.EnableSplittingOnParsers, true)
})
if err != nil {
return err
Expand Down Expand Up @@ -1283,18 +1274,7 @@ func (c *Ctx) QueryParser(out interface{}) error {
k := c.app.getString(key)
v := c.app.getString(val)

if strings.Contains(k, "[") {
k, err = parseParamSquareBrackets(k)
}

if c.app.config.EnableSplittingOnParsers && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k, queryTag) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
err = formatParserData(out, data, queryTag, k, v, c.app.config.EnableSplittingOnParsers, true)
})

if err != nil {
Expand All @@ -1304,61 +1284,26 @@ func (c *Ctx) QueryParser(out interface{}) error {
return c.parseToStruct(queryTag, out, data)
}

func parseParamSquareBrackets(k string) (string, error) {
bb := bytebufferpool.Get()
defer bytebufferpool.Put(bb)

kbytes := []byte(k)
openBracketsCount := 0

for i, b := range kbytes {
if b == '[' {
openBracketsCount++
if i+1 < len(kbytes) && kbytes[i+1] != ']' {
if err := bb.WriteByte('.'); err != nil {
return "", fmt.Errorf("failed to write: %w", err)
}
}
continue
}

if b == ']' {
openBracketsCount--
if openBracketsCount < 0 {
return "", errors.New("unmatched brackets")
}
continue
}

if err := bb.WriteByte(b); err != nil {
return "", fmt.Errorf("failed to write: %w", err)
}
}

if openBracketsCount > 0 {
return "", errors.New("unmatched brackets")
}

return bb.String(), nil
}

// ReqHeaderParser binds the request header strings to a struct.
func (c *Ctx) ReqHeaderParser(out interface{}) error {
data := make(map[string][]string)
var err error

c.fasthttp.Request.Header.VisitAll(func(key, val []byte) {
if err != nil {
return
}

k := c.app.getString(key)
v := c.app.getString(val)

if c.app.config.EnableSplittingOnParsers && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k, reqHeaderTag) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
err = formatParserData(out, data, reqHeaderTag, k, v, c.app.config.EnableSplittingOnParsers, false)
})

if err != nil {
return err
}

return c.parseToStruct(reqHeaderTag, out, data)
}

Expand Down
42 changes: 42 additions & 0 deletions ctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,48 @@ func Test_Ctx_BodyParser(t *testing.T) {
utils.AssertEqual(t, 2, len(cq.Data))
utils.AssertEqual(t, "john", cq.Data[0].Name)
utils.AssertEqual(t, "doe", cq.Data[1].Name)

t.Run("MultipartCollectionQueryDotNotation", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Reset()

buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
utils.AssertEqual(t, nil, writer.WriteField("data.0.name", "john"))
utils.AssertEqual(t, nil, writer.WriteField("data.1.name", "doe"))
utils.AssertEqual(t, nil, writer.Close())

c.Request().Header.SetContentType(writer.FormDataContentType())
c.Request().SetBody(buf.Bytes())
c.Request().Header.SetContentLength(len(c.Body()))

cq := new(CollectionQuery)
utils.AssertEqual(t, nil, c.BodyParser(cq))
utils.AssertEqual(t, len(cq.Data), 2)
utils.AssertEqual(t, "john", cq.Data[0].Name)
utils.AssertEqual(t, "doe", cq.Data[1].Name)
})

t.Run("MultipartCollectionQuerySquareBrackets", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Reset()

buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
utils.AssertEqual(t, nil, writer.WriteField("data[0][name]", "john"))
utils.AssertEqual(t, nil, writer.WriteField("data[1][name]", "doe"))
utils.AssertEqual(t, nil, writer.Close())

c.Request().Header.SetContentType(writer.FormDataContentType())
c.Request().SetBody(buf.Bytes())
c.Request().Header.SetContentLength(len(c.Body()))

cq := new(CollectionQuery)
utils.AssertEqual(t, nil, c.BodyParser(cq))
utils.AssertEqual(t, len(cq.Data), 2)
utils.AssertEqual(t, "john", cq.Data[0].Name)
utils.AssertEqual(t, "doe", cq.Data[1].Name)
})
}

func Test_Ctx_ParamParser(t *testing.T) {
Expand Down
2 changes: 0 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,4 @@ require (
github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/valyala/tcplisten v1.0.0 // indirect
golang.org/x/mod v0.18.0 // indirect
golang.org/x/tools v0.22.0 // indirect
)
6 changes: 0 additions & 6 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c h1:dAMKvw0MlJT1Gsh
github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/tinylib/msgp v1.1.3 h1:3giwAkmtaEDLSV0MdO1lDLuPgklgPzmk8H9+So2BVfA=
github.com/tinylib/msgp v1.1.3/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
github.com/tinylib/msgp v1.2.5 h1:WeQg1whrXRFiZusidTQqzETkRpGjFjcIhW6uqWH09po=
github.com/tinylib/msgp v1.2.5/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
Expand All @@ -25,11 +23,7 @@ github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1S
github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g=
github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8=
github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=
golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA=
golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c=
73 changes: 73 additions & 0 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package fiber
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"hash/crc32"
"io"
Expand Down Expand Up @@ -1151,3 +1152,75 @@ func IndexRune(str string, needle int32) bool {
}
return false
}

func parseParamSquareBrackets(k string) (string, error) {
bb := bytebufferpool.Get()
defer bytebufferpool.Put(bb)

kbytes := []byte(k)
openBracketsCount := 0

for i, b := range kbytes {
if b == '[' {
openBracketsCount++
if i+1 < len(kbytes) && kbytes[i+1] != ']' {
if err := bb.WriteByte('.'); err != nil {
return "", fmt.Errorf("failed to write: %w", err)
}
}
continue
}

if b == ']' {
openBracketsCount--
if openBracketsCount < 0 {
return "", errors.New("unmatched brackets")
}
continue
}

if err := bb.WriteByte(b); err != nil {
return "", fmt.Errorf("failed to write: %w", err)
}
}

if openBracketsCount > 0 {
return "", errors.New("unmatched brackets")
}

return bb.String(), nil
}

func formatParserData(out interface{}, data map[string][]string, aliasTag, key string, value interface{}, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay
var err error
if supportBracketNotation && strings.Contains(key, "[") {
key, err = parseParamSquareBrackets(key)
if err != nil {
return err
}
}

switch v := value.(type) {
case string:
assignBindData(out, data, aliasTag, key, v, enableSplitting)
case []string:
for _, val := range v {
assignBindData(out, data, aliasTag, key, val, enableSplitting)
}
default:
return fmt.Errorf("unsupported value type: %T", value)
}

return err
}

func assignBindData(out interface{}, data map[string][]string, aliasTag, key, value string, enableSplitting bool) { //nolint:revive // it's okay
if enableSplitting && strings.Contains(value, ",") && equalFieldType(out, reflect.Slice, key, aliasTag) {
values := strings.Split(value, ",")
for i := 0; i < len(values); i++ {
data[key] = append(data[key], values[i])
}
} else {
data[key] = append(data[key], value)
}
}

0 comments on commit 7eb9d25

Please sign in to comment.