-
-
Notifications
You must be signed in to change notification settings - Fork 45
/
keyfunc.go
177 lines (160 loc) · 4.88 KB
/
keyfunc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
package keyfunc
import (
"context"
"crypto"
"encoding/json"
"errors"
"fmt"
"github.com/MicahParks/jwkset"
"github.com/golang-jwt/jwt/v5"
)
// ErrKeyfunc is returned, usually wrapped, when a keyfunc error occurs. Match it with errors.Is.
var ErrKeyfunc = errors.New("failed keyfunc")
// Keyfunc adapts a github.com/MicahParks/jwkset JWK Set storage to the jwt.Keyfunc signature expected by
// github.com/golang-jwt/jwt/v5.
type Keyfunc interface {
	// KeyfuncCtx returns a jwt.Keyfunc whose key lookups are performed with ctx.
	KeyfuncCtx(ctx context.Context) jwt.Keyfunc
	// Keyfunc can be passed directly wherever a jwt.Keyfunc is expected; it resolves the key for token.
	Keyfunc(token *jwt.Token) (any, error)
	// Storage exposes the JWK Set storage keys are read from.
	Storage() jwkset.Storage
}
// Options are used to create a new Keyfunc.
type Options struct {
	// Ctx is the context the Keyfunc method uses for storage reads. New substitutes context.Background() when it
	// is nil.
	Ctx context.Context
	// Storage is the JWK Set storage keys are read from. It is required: New returns an error wrapping ErrKeyfunc
	// when it is nil.
	Storage jwkset.Storage
	// UseWhitelist, when non-empty, restricts accepted JWKs to those whose "use" parameter equals one of the listed
	// values. A nil or empty slice disables the check.
	UseWhitelist []jwkset.USE
}
// keyfunc is this package's implementation of the Keyfunc interface; New builds it from an Options value.
type keyfunc struct {
	// ctx is the context the Keyfunc method passes to KeyfuncCtx. New ensures it is never nil.
	ctx context.Context
	// storage is the JWK Set storage keys are read from.
	storage jwkset.Storage
	// useWhitelist lists the accepted JWK "use" values; when empty, any "use" value is accepted.
	useWhitelist []jwkset.USE
}
// New creates a new Keyfunc from options.
//
// options.Storage is required; when it is nil, New returns an error wrapping ErrKeyfunc. A nil options.Ctx is
// replaced with context.Background(). options.UseWhitelist is copied, so later changes to the caller's slice
// neither alter nor race with the whitelist consulted during key lookups.
func New(options Options) (Keyfunc, error) {
	if options.Storage == nil {
		return nil, fmt.Errorf("%w: no JWK Set storage given in options", ErrKeyfunc)
	}

	ctx := options.Ctx
	if ctx == nil {
		ctx = context.Background()
	}

	k := keyfunc{
		ctx:     ctx,
		storage: options.Storage,
		// Defensive copy: the returned Keyfunc is typically shared by many goroutines, so it must not alias
		// caller-owned memory. Appending onto a nil slice leaves a nil/empty input as nil.
		useWhitelist: append([]jwkset.USE(nil), options.UseWhitelist...),
	}
	return k, nil
}
// NewDefault is NewDefaultCtx called with context.Background(). It builds a Keyfunc backed by jwkset's default HTTP
// client for urls.
//
// This will launch "refresh goroutine" to automatically refresh the remote HTTP resources. context.Background() is
// never cancelled, so that goroutine cannot be stopped through this constructor. Use NewDefaultCtx to control its
// lifetime.
func NewDefault(urls []string) (Keyfunc, error) {
	ctx := context.Background()
	return NewDefaultCtx(ctx, urls)
}
// NewDefaultCtx builds a Keyfunc whose storage is jwkset's default HTTP client for urls. ctx controls the client's
// "refresh goroutine", which is launched to automatically refresh the remote HTTP resources; cancel ctx to end it.
//
// ctx is not forwarded as Options.Ctx, so the returned Keyfunc's Keyfunc method performs lookups with
// context.Background(). Errors from creating the HTTP client are returned as-is.
func NewDefaultCtx(ctx context.Context, urls []string) (Keyfunc, error) {
	httpStorage, err := jwkset.NewDefaultHTTPClientCtx(ctx, urls)
	if err != nil {
		return nil, err
	}
	return New(Options{Storage: httpStorage})
}
// NewJWKJSON creates a new Keyfunc from a single raw JWK in JSON form.
//
// The JWK is parsed with JWKMarshalOptions{Private: true} and zero-value validation options, then written into a
// fresh in-memory jwkset storage. Parse and write errors are joined with ErrKeyfunc.
func NewJWKJSON(raw json.RawMessage) (Keyfunc, error) {
	jwk, err := jwkset.NewJWKFromRawJSON(raw, jwkset.JWKMarshalOptions{Private: true}, jwkset.JWKValidateOptions{})
	if err != nil {
		return nil, fmt.Errorf("%w: could not create JWK from raw JSON", errors.Join(err, ErrKeyfunc))
	}

	memory := jwkset.NewMemoryStorage()
	if err := memory.KeyWrite(context.Background(), jwk); err != nil {
		return nil, fmt.Errorf("%w: could not write JWK to storage", errors.Join(err, ErrKeyfunc))
	}

	return New(Options{Storage: memory})
}
// NewJWKSetJSON creates a new Keyfunc from a raw JWK Set in JSON form.
//
// The JSON is decoded into a jwkset.JWKSMarshal and converted to a storage with its ToStorage method. Decode and
// conversion errors are joined with ErrKeyfunc.
func NewJWKSetJSON(raw json.RawMessage) (Keyfunc, error) {
	var set jwkset.JWKSMarshal
	if err := json.Unmarshal(raw, &set); err != nil {
		return nil, fmt.Errorf("%w: could not unmarshal raw JWK Set JSON", errors.Join(err, ErrKeyfunc))
	}

	memory, err := set.ToStorage()
	if err != nil {
		return nil, fmt.Errorf("%w: could not create JWK Set storage", errors.Join(err, ErrKeyfunc))
	}

	return New(Options{Storage: memory})
}
// KeyfuncCtx returns a jwt.Keyfunc that resolves a token's verification key from the JWK Set storage, using ctx for
// the storage read.
//
// The returned function:
//   - requires "kid" and "alg" header parameters, both strings;
//   - reads the JWK with that key ID from storage;
//   - rejects the JWK if it declares an "alg" different from the token's "alg";
//   - when a "use" whitelist is configured, rejects a JWK whose "use" is not listed;
//   - returns the key's public counterpart when the stored key implements Public() crypto.PublicKey, otherwise the
//     key as stored.
//
// Every error it returns wraps ErrKeyfunc.
func (k keyfunc) KeyfuncCtx(ctx context.Context) jwt.Keyfunc {
	return func(token *jwt.Token) (any, error) {
		kidInter, ok := token.Header[jwkset.HeaderKID]
		if !ok {
			return nil, fmt.Errorf("%w: could not find kid in JWT header", ErrKeyfunc)
		}
		kid, ok := kidInter.(string)
		if !ok {
			return nil, fmt.Errorf("%w: could not convert kid in JWT header to string", ErrKeyfunc)
		}

		// "alg" is required by RFC 7515 section 4.1.1. golang-jwt's parser already rejects a token whose "alg" is
		// missing or not a string before calling the jwt.Keyfunc. For test coverage purposes these two branches are
		// therefore unreachable through jwt.Parse; they only guard direct callers.
		algInter, ok := token.Header["alg"]
		if !ok {
			return nil, fmt.Errorf("%w: could not find alg in JWT header", ErrKeyfunc)
		}
		alg, ok := algInter.(string)
		if !ok {
			// The parameter is present but mistyped, so report a conversion failure rather than a missing parameter.
			return nil, fmt.Errorf("%w: could not convert alg in JWT header to string", ErrKeyfunc)
		}

		jwk, err := k.storage.KeyRead(ctx, kid)
		if err != nil {
			return nil, fmt.Errorf("%w: could not read JWK from storage", errors.Join(err, ErrKeyfunc))
		}
		// Take the JWK's marshaled parameters once rather than once per check and per whitelist entry.
		params := jwk.Marshal()

		// A JWK that omits "alg" is accepted for any algorithm; one that declares it must match the token.
		if a := params.ALG.String(); a != "" && a != alg {
			return nil, fmt.Errorf(`%w: JWK "alg" parameter value %q does not match token "alg" parameter value %q`, ErrKeyfunc, a, alg)
		}

		if len(k.useWhitelist) > 0 {
			allowed := false
			for _, u := range k.useWhitelist {
				if params.USE == u {
					allowed = true
					break
				}
			}
			if !allowed {
				return nil, fmt.Errorf(`%w: JWK "use" parameter value %q is not in whitelist`, ErrKeyfunc, params.USE)
			}
		}

		// Signature verification only needs public material. If the stored key can produce its public counterpart
		// (as crypto private keys such as *rsa.PrivateKey do), hand that out instead of the private key.
		type publicKeyer interface {
			Public() crypto.PublicKey
		}
		key := jwk.Key()
		if pk, ok := key.(publicKeyer); ok {
			key = pk.Public()
		}
		return key, nil
	}
}
// Keyfunc satisfies jwt.Keyfunc. It resolves token's key exactly as the function returned by KeyfuncCtx does, using
// the context stored at construction (Options.Ctx, or context.Background() when none was given).
func (k keyfunc) Keyfunc(token *jwt.Token) (any, error) {
	return k.KeyfuncCtx(k.ctx)(token)
}
// Storage returns the JWK Set storage this Keyfunc reads keys from.
func (k keyfunc) Storage() jwkset.Storage { return k.storage }