Skip to content

Commit 24fa4d0

Browse files
committed
CSRF: support older token-based CSRF protection handler that want to render token into template
(cherry picked from commit 9183f1e)
1 parent 482bb46 commit 24fa4d0

File tree

2 files changed

+44
-7
lines changed

2 files changed

+44
-7
lines changed

middleware/csrf.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@ import (
1313
"github.com/labstack/echo/v4"
1414
)
1515

16+
// CSRFUsingSecFetchSite is a context key for CSRF middleware what is set when the client browser is using Sec-Fetch-Site
17+
// header and the request is deemed safe.
18+
// It is a dummy token value that can be used to render CSRF token for form by handlers.
19+
//
20+
// We know that the client is using a browser that supports Sec-Fetch-Site header, so when the form is submitted in
21+
// the future with this dummy token value it is OK. Although the request is safe, the template rendered by the
22+
// handler may need this value to render CSRF token for form.
23+
const CSRFUsingSecFetchSite = "_echo_csrf_using_sec_fetch_site_"
24+
1625
// CSRFConfig defines the config for CSRF middleware.
1726
type CSRFConfig struct {
1827
// Skipper defines a function to skip middleware.
@@ -83,6 +92,8 @@ type CSRFConfig struct {
8392

8493
// ErrorHandler defines a function which is executed for returning custom errors.
8594
ErrorHandler CSRFErrorHandler
95+
96+
generator func(length uint8) string
8697
}
8798

8899
// CSRFErrorHandler is a function which is executed for creating custom errors.
@@ -145,6 +156,10 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
145156
}
146157
config.TrustedOrigins = append([]string(nil), config.TrustedOrigins...)
147158
}
159+
tokenGenerator := randomString
160+
if config.generator != nil {
161+
tokenGenerator = config.generator
162+
}
148163

149164
extractors, cErr := CreateExtractors(config.TokenLookup)
150165
if cErr != nil {
@@ -170,7 +185,7 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
170185

171186
token := ""
172187
if k, err := c.Cookie(config.CookieName); err != nil {
173-
token = randomString(config.TokenLength)
188+
token = tokenGenerator(config.TokenLength)
174189
} else {
175190
token = k.Value // Reuse token
176191
}
@@ -287,6 +302,11 @@ func (config CSRFConfig) checkSecFetchSiteRequest(c echo.Context) (bool, error)
287302
}
288303

289304
if isSafe {
305+
// This helps handlers that support older token-based CSRF protection.
306+
// We know that the client is using a browser that supports Sec-Fetch-Site header, so when the form is submitted in
307+
// the future with this dummy token value it is OK. Although the request is safe, the template rendered by the
308+
// handler may need this value to render CSRF token for form.
309+
c.Set(config.ContextKey, CSRFUsingSecFetchSite)
290310
return true, nil
291311
}
292312
// we are here when request is state-changing and `cross-site` or `same-site`

middleware/csrf_test.go

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ func TestCSRF_tokenExtractors(t *testing.T) {
8686
},
8787
},
8888
{
89-
name: "ok, token from POST header, second token passes",
89+
name: "nok, token from POST header, tokens limited to 1, second token would pass",
9090
whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
9191
givenCSRFCookie: "token",
9292
givenMethod: http.MethodPost,
@@ -122,7 +122,7 @@ func TestCSRF_tokenExtractors(t *testing.T) {
122122
},
123123
},
124124
{
125-
name: "ok, token from PUT query form, second token passes",
125+
name: "nok, token from PUT query form, second token would pass",
126126
whenTokenLookup: "query:csrf",
127127
givenCSRFCookie: "token",
128128
givenMethod: http.MethodPut,
@@ -235,12 +235,14 @@ func TestCSRFWithConfig(t *testing.T) {
235235
expectEmptyBody bool
236236
expectMWError string
237237
expectCookieContains string
238+
expectTokenInContext string
238239
expectErr string
239240
}{
240241
{
241242
name: "ok, GET",
242243
whenMethod: http.MethodGet,
243244
expectCookieContains: "_csrf",
245+
expectTokenInContext: "TESTTOKEN",
244246
},
245247
{
246248
name: "ok, POST valid token",
@@ -250,6 +252,7 @@ func TestCSRFWithConfig(t *testing.T) {
250252
},
251253
whenMethod: http.MethodPost,
252254
expectCookieContains: "_csrf",
255+
expectTokenInContext: token,
253256
},
254257
{
255258
name: "nok, POST without token",
@@ -278,13 +281,23 @@ func TestCSRFWithConfig(t *testing.T) {
278281
},
279282
whenMethod: http.MethodGet,
280283
expectCookieContains: "_csrf",
284+
expectTokenInContext: "TESTTOKEN",
281285
},
282286
{
283287
name: "ok, unsafe method + SecFetchSite=same-origin passes",
284288
whenHeaders: map[string]string{
285289
echo.HeaderSecFetchSite: "same-origin",
286290
},
287-
whenMethod: http.MethodPost,
291+
whenMethod: http.MethodPost,
292+
expectTokenInContext: "_echo_csrf_using_sec_fetch_site_",
293+
},
294+
{
295+
name: "ok, safe method + SecFetchSite=same-origin passes",
296+
whenHeaders: map[string]string{
297+
echo.HeaderSecFetchSite: "same-origin",
298+
},
299+
whenMethod: http.MethodGet,
300+
expectTokenInContext: "_echo_csrf_using_sec_fetch_site_",
288301
},
289302
{
290303
name: "nok, unsafe method + SecFetchSite=same-cross blocked",
@@ -312,6 +325,11 @@ func TestCSRFWithConfig(t *testing.T) {
312325
if tc.givenConfig != nil {
313326
config = *tc.givenConfig
314327
}
328+
if config.generator == nil {
329+
config.generator = func(_ uint8) string {
330+
return "TESTTOKEN"
331+
}
332+
}
315333
mw, err := config.ToMiddleware()
316334
if tc.expectMWError != "" {
317335
assert.EqualError(t, err, tc.expectMWError)
@@ -320,6 +338,8 @@ func TestCSRFWithConfig(t *testing.T) {
320338
assert.NoError(t, err)
321339

322340
h := mw(func(c echo.Context) error {
341+
cToken := c.Get(cmp.Or(config.ContextKey, DefaultCSRFConfig.ContextKey))
342+
assert.Equal(t, tc.expectTokenInContext, cToken)
323343
return c.String(http.StatusOK, "test")
324344
})
325345

@@ -559,7 +579,6 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
559579
whenMethod: http.MethodPost,
560580
whenSecFetchSite: "same-site",
561581
expectAllow: false,
562-
expectErr: ``,
563582
},
564583
{
565584
name: "ok, unsafe POST + same-origin passes",
@@ -617,7 +636,6 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
617636
whenMethod: http.MethodPut,
618637
whenSecFetchSite: "same-site",
619638
expectAllow: false,
620-
expectErr: ``,
621639
},
622640
{
623641
name: "nok, unsafe DELETE + cross-site is blocked",
@@ -633,7 +651,6 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
633651
whenMethod: http.MethodDelete,
634652
whenSecFetchSite: "same-site",
635653
expectAllow: false,
636-
expectErr: ``,
637654
},
638655
{
639656
name: "nok, unsafe PATCH + cross-site is blocked",

0 commit comments

Comments
 (0)