Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ The `TokenGenerator` functionality allows you to create JWT tokens directly with
package main

import (
"context"
"fmt"
"log"
"time"
Expand All @@ -356,9 +357,12 @@ func main() {
log.Fatal("JWT Error:" + err.Error())
}

// Create context for token operations
ctx := context.Background()

// Generate a complete token pair (access + refresh tokens)
userData := "user123"
tokenPair, err := authMiddleware.TokenGenerator(userData)
tokenPair, err := authMiddleware.TokenGenerator(ctx, userData)
if err != nil {
log.Fatal("Failed to generate token pair:", err)
}
Expand Down Expand Up @@ -392,7 +396,7 @@ Use `TokenGeneratorWithRevocation` to refresh tokens and automatically revoke ol

```go
// Refresh with automatic revocation of old token
newTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(userData, oldRefreshToken)
newTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(ctx, userData, oldRefreshToken)
if err != nil {
log.Fatal("Failed to refresh token:", err)
}
Expand Down
6 changes: 5 additions & 1 deletion README.zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ import "github.com/appleboy/gin-jwt/v3"
package main

import (
"context"
"fmt"
"log"
"time"
Expand All @@ -191,9 +192,12 @@ func main() {
log.Fatal("JWT Error:" + err.Error())
}

// 创建 Token 操作的 context
ctx := context.Background()

// 生成完整的 Token 组(访问 + 刷新 Token)
userData := "user123"
tokenPair, err := authMiddleware.TokenGenerator(userData)
tokenPair, err := authMiddleware.TokenGenerator(ctx, userData)
if err != nil {
log.Fatal("Failed to generate token pair:", err)
}
Expand Down
8 changes: 6 additions & 2 deletions README.zh-TW.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ import "github.com/appleboy/gin-jwt/v3"
package main

import (
"context"
"fmt"
"log"
"time"
Expand All @@ -191,9 +192,12 @@ func main() {
log.Fatal("JWT Error:" + err.Error())
}

// 建立 Token 操作的 context
ctx := context.Background()

// 產生完整的 Token 組(存取 + 刷新 Token)
userData := "user123"
tokenPair, err := authMiddleware.TokenGenerator(userData)
tokenPair, err := authMiddleware.TokenGenerator(ctx, userData)
if err != nil {
log.Fatal("Failed to generate token pair:", err)
}
Expand Down Expand Up @@ -227,7 +231,7 @@ func (t *Token) ExpiresIn() int64 // 回傳到期前的秒數

```go
// 刷新並自動撤銷舊 Token
newTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(userData, oldRefreshToken)
newTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(ctx, userData, oldRefreshToken)
if err != nil {
log.Fatal("Failed to refresh token:", err)
}
Expand Down
10 changes: 7 additions & 3 deletions _example/token_generator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package main

import (
"context"
"fmt"
"log"
"time"
Expand Down Expand Up @@ -30,9 +31,12 @@ func main() {
// Example user data
userData := "user123"

// Create context for token operations
ctx := context.Background()

// Generate a complete token pair (access + refresh tokens)
fmt.Println("=== Generating Token Pair ===")
tokenPair, err := authMiddleware.TokenGenerator(userData)
tokenPair, err := authMiddleware.TokenGenerator(ctx, userData)
if err != nil {
log.Fatal("Failed to generate token pair:", err)
}
Expand All @@ -46,7 +50,7 @@ func main() {

// Simulate refresh token usage
fmt.Println("\n=== Refreshing Token Pair ===")
newTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(userData, tokenPair.RefreshToken)
newTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(ctx, userData, tokenPair.RefreshToken)
if err != nil {
log.Fatal("Failed to refresh token pair:", err)
}
Expand All @@ -57,7 +61,7 @@ func main() {

// Verify old refresh token is invalid
fmt.Println("\n=== Verifying Old Token Revocation ===")
_, err = authMiddleware.TokenGeneratorWithRevocation(userData, tokenPair.RefreshToken)
_, err = authMiddleware.TokenGeneratorWithRevocation(ctx, userData, tokenPair.RefreshToken)
if err != nil {
fmt.Printf("Old refresh token correctly rejected: %s\n", err)
}
Expand Down
31 changes: 16 additions & 15 deletions auth_jwt.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jwt

import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
Expand Down Expand Up @@ -568,7 +569,7 @@ func (mw *GinJWTMiddleware) LoginHandler(c *gin.Context) {
}

// Generate complete token pair
tokenPair, err := mw.TokenGenerator(data)
tokenPair, err := mw.TokenGenerator(c.Request.Context(), data)
if err != nil {
mw.unauthorized(c, http.StatusInternalServerError, mw.HTTPStatusMessageFunc(c,ErrFailedTokenCreation))
return
Expand Down Expand Up @@ -612,7 +613,7 @@ func (mw *GinJWTMiddleware) LogoutHandler(c *gin.Context) {
// Handle refresh token revocation (RFC 6749 compliant)
refreshToken := mw.extractRefreshToken(c)
if refreshToken != "" {
if err := mw.revokeRefreshToken(refreshToken); err != nil {
if err := mw.revokeRefreshToken(c.Request.Context(), refreshToken); err != nil {
log.Printf("Failed to revoke refresh token on logout: %v", err)
}
}
Expand Down Expand Up @@ -658,14 +659,14 @@ func (mw *GinJWTMiddleware) generateRefreshToken() (string, error) {
}

// storeRefreshToken stores a refresh token with user data
func (mw *GinJWTMiddleware) storeRefreshToken(token string, userData any) error {
func (mw *GinJWTMiddleware) storeRefreshToken(ctx context.Context, token string, userData any) error {
expiry := mw.TimeFunc().Add(mw.RefreshTokenTimeout)
return mw.RefreshTokenStore.Set(token, userData, expiry)
return mw.RefreshTokenStore.Set(ctx, token, userData, expiry)
}

// validateRefreshToken validates a refresh token and returns associated user data
func (mw *GinJWTMiddleware) validateRefreshToken(token string) (any, error) {
userData, err := mw.RefreshTokenStore.Get(token)
func (mw *GinJWTMiddleware) validateRefreshToken(ctx context.Context, token string) (any, error) {
userData, err := mw.RefreshTokenStore.Get(ctx, token)
if err != nil {
if err == core.ErrRefreshTokenNotFound {
return nil, ErrInvalidRefreshToken
Expand All @@ -676,8 +677,8 @@ func (mw *GinJWTMiddleware) validateRefreshToken(token string) (any, error) {
}

// revokeRefreshToken removes a refresh token from storage
func (mw *GinJWTMiddleware) revokeRefreshToken(token string) error {
return mw.RefreshTokenStore.Delete(token)
func (mw *GinJWTMiddleware) revokeRefreshToken(ctx context.Context, token string) error {
return mw.RefreshTokenStore.Delete(ctx, token)
}

// RefreshHandler can be used to refresh a token using RFC 6749 compliant refresh tokens.
Expand All @@ -692,14 +693,14 @@ func (mw *GinJWTMiddleware) RefreshHandler(c *gin.Context) {
}

// Validate refresh token
userData, err := mw.validateRefreshToken(refreshToken)
userData, err := mw.validateRefreshToken(c.Request.Context(), refreshToken)
if err != nil {
mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(c,err))
return
}

// Generate new token pair and revoke old refresh token
tokenPair, err := mw.TokenGeneratorWithRevocation(userData, refreshToken)
tokenPair, err := mw.TokenGeneratorWithRevocation(c.Request.Context(), userData, refreshToken)
if err != nil {
mw.unauthorized(c, http.StatusInternalServerError, mw.HTTPStatusMessageFunc(c,err))
return
Expand Down Expand Up @@ -795,7 +796,7 @@ func (mw *GinJWTMiddleware) generateAccessToken(data any) (string, time.Time, er
}

// TokenGenerator generates a complete token pair (access + refresh) with RFC 6749 compliance
func (mw *GinJWTMiddleware) TokenGenerator(data any) (*core.Token, error) {
func (mw *GinJWTMiddleware) TokenGenerator(ctx context.Context, data any) (*core.Token, error) {
// Generate access token
accessToken, expire, err := mw.generateAccessToken(data)
if err != nil {
Expand All @@ -809,7 +810,7 @@ func (mw *GinJWTMiddleware) TokenGenerator(data any) (*core.Token, error) {
}

// Store refresh token
if err := mw.storeRefreshToken(refreshToken, data); err != nil {
if err := mw.storeRefreshToken(ctx, refreshToken, data); err != nil {
return nil, err
}

Expand All @@ -824,15 +825,15 @@ func (mw *GinJWTMiddleware) TokenGenerator(data any) (*core.Token, error) {
}

// TokenGeneratorWithRevocation generates a new token pair and revokes the old refresh token
func (mw *GinJWTMiddleware) TokenGeneratorWithRevocation(data any, oldRefreshToken string) (*core.Token, error) {
func (mw *GinJWTMiddleware) TokenGeneratorWithRevocation(ctx context.Context, data any, oldRefreshToken string) (*core.Token, error) {
// Generate new token pair
tokenPair, err := mw.TokenGenerator(data)
tokenPair, err := mw.TokenGenerator(ctx, data)
if err != nil {
return nil, err
}

// Revoke old refresh token, ignore if token already doesn't exist
if err := mw.revokeRefreshToken(oldRefreshToken); err != nil && !errors.Is(err, core.ErrRefreshTokenNotFound) {
if err := mw.revokeRefreshToken(ctx, oldRefreshToken); err != nil && !errors.Is(err, core.ErrRefreshTokenNotFound) {
return nil, err
}

Expand Down
11 changes: 6 additions & 5 deletions auth_jwt_redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,33 +401,34 @@ func testRedisStoreOperations(t *testing.T, middleware *GinJWTMiddleware) {
require.True(t, ok, "should be using Redis store")

// Test store operations directly
ctx := context.Background()
testToken := "direct-test-token"
testData := map[string]any{"test": "data"}
expiry := time.Now().Add(time.Hour)

// Test Set
err := redisStore.Set(testToken, testData, expiry)
err := redisStore.Set(ctx, testToken, testData, expiry)
assert.NoError(t, err, "direct set should succeed")

// Test Get
retrievedData, err := redisStore.Get(testToken)
retrievedData, err := redisStore.Get(ctx, testToken)
assert.NoError(t, err, "direct get should succeed")
assert.Equal(t, testData, retrievedData, "retrieved data should match")

// Test Count
count, err := redisStore.Count()
count, err := redisStore.Count(ctx)
assert.NoError(t, err, "count should succeed")
assert.GreaterOrEqual(t, count, 1, "count should include our test token")

// Test Delete
err = redisStore.Delete(testToken)
err = redisStore.Delete(ctx, testToken)
assert.NoError(t, err, "direct delete should succeed")

// Verify deletion - wait for cache TTL to expire
time.Sleep(100 * time.Millisecond)

// The Get method should return an error for deleted tokens
_, err = redisStore.Get(testToken)
_, err = redisStore.Get(ctx, testToken)
assert.Error(t, err, "token should not exist after deletion")
}

Expand Down
22 changes: 13 additions & 9 deletions auth_jwt_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jwt

import (
"context"
"errors"
"fmt"
"log"
Expand Down Expand Up @@ -1470,7 +1471,8 @@ func TestTokenGenerator(t *testing.T) {
assert.NoError(t, err)

userData := "admin"
tokenPair, err := authMiddleware.TokenGenerator(userData)
ctx := context.Background()
tokenPair, err := authMiddleware.TokenGenerator(ctx, userData)

assert.NoError(t, err)
assert.NotNil(t, tokenPair)
Expand Down Expand Up @@ -1510,40 +1512,41 @@ func TestTokenGeneratorWithRevocation(t *testing.T) {
assert.NoError(t, err)

userData := "admin"
ctx := context.Background()

// Generate first token pair
oldTokenPair, err := authMiddleware.TokenGenerator(userData)
oldTokenPair, err := authMiddleware.TokenGenerator(ctx, userData)
assert.NoError(t, err)

// Verify old refresh token exists in store
storedData, err := authMiddleware.validateRefreshToken(oldTokenPair.RefreshToken)
storedData, err := authMiddleware.validateRefreshToken(ctx, oldTokenPair.RefreshToken)
assert.NoError(t, err)
assert.Equal(t, userData, storedData)

// Generate new token pair with revocation
newTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(userData, oldTokenPair.RefreshToken)
newTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(ctx, userData, oldTokenPair.RefreshToken)
assert.NoError(t, err)
assert.NotNil(t, newTokenPair)

// Verify refresh tokens are different (access tokens might be the same if generated in same second)
assert.NotEqual(t, oldTokenPair.RefreshToken, newTokenPair.RefreshToken)

// Verify old refresh token is revoked
_, err = authMiddleware.validateRefreshToken(oldTokenPair.RefreshToken)
_, err = authMiddleware.validateRefreshToken(ctx, oldTokenPair.RefreshToken)
assert.Error(t, err)

// Verify new refresh token works
storedData, err = authMiddleware.validateRefreshToken(newTokenPair.RefreshToken)
storedData, err = authMiddleware.validateRefreshToken(ctx, newTokenPair.RefreshToken)
assert.NoError(t, err)
assert.Equal(t, userData, storedData)

// Test revoking already revoked token (should not fail)
anotherTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(userData, oldTokenPair.RefreshToken)
anotherTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(ctx, userData, oldTokenPair.RefreshToken)
assert.NoError(t, err)
assert.NotNil(t, anotherTokenPair)

// Test revoking non-existent token (should not fail)
finalTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(userData, "non_existent_token")
finalTokenPair, err := authMiddleware.TokenGeneratorWithRevocation(ctx, userData, "non_existent_token")
assert.NoError(t, err)
assert.NotNil(t, finalTokenPair)
}
Expand All @@ -1562,7 +1565,8 @@ func TestTokenStruct(t *testing.T) {
assert.NoError(t, err)

userData := "admin"
tokenPair, err := authMiddleware.TokenGenerator(userData)
ctx := context.Background()
tokenPair, err := authMiddleware.TokenGenerator(ctx, userData)
assert.NoError(t, err)

// Test ExpiresIn method
Expand Down
11 changes: 6 additions & 5 deletions core/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package core

import (
"context"
"errors"
"time"
)
Expand All @@ -18,23 +19,23 @@ var (
type TokenStore interface {
// Set stores a refresh token with associated user data and expiration
// Returns an error if the operation fails
Set(token string, userData any, expiry time.Time) error
Set(ctx context.Context, token string, userData any, expiry time.Time) error

// Get retrieves user data associated with a refresh token
// Returns ErrRefreshTokenNotFound if token doesn't exist or is expired
Get(token string) (any, error)
Get(ctx context.Context, token string) (any, error)

// Delete removes a refresh token from storage
// Returns an error if the operation fails, but should not error if token doesn't exist
Delete(token string) error
Delete(ctx context.Context, token string) error

// Cleanup removes expired tokens (optional, for cleanup routines)
// Returns the number of tokens cleaned up and any error encountered
Cleanup() (int, error)
Cleanup(ctx context.Context) (int, error)

// Count returns the total number of active refresh tokens
// Useful for monitoring and debugging
Count() (int, error)
Count(ctx context.Context) (int, error)
}

// RefreshTokenData holds the data stored with each refresh token
Expand Down
Loading