Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion connector/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ var brokenAuthHeaderDomains = []string{
// connectorData stores information for sessions authenticated by this connector
type connectorData struct {
RefreshToken []byte
IDToken []byte // raw upstream id_token JWT for RP-Initiated logout
}

// Detect auth header provider issues for known providers. This lets users
Expand Down Expand Up @@ -736,6 +737,9 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I
cd := connectorData{
RefreshToken: []byte(token.RefreshToken),
}
if rawIDToken, ok := token.Extra("id_token").(string); ok {
cd.IDToken = []byte(rawIDToken)
}

connData, err := json.Marshal(&cd)
if err != nil {
Expand Down Expand Up @@ -766,7 +770,7 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I
// LogoutURL returns the upstream OIDC provider's end_session_endpoint URL.
// Per the OIDC RP-Initiated Logout spec, the post_logout_redirect_uri parameter
// tells the upstream where to redirect after logout.
func (c *oidcConnector) LogoutURL(_ context.Context, _ []byte, postLogoutRedirectURI string) (string, error) {
func (c *oidcConnector) LogoutURL(_ context.Context, rawConnectorData []byte, postLogoutRedirectURI string) (string, error) {
if c.endSessionURL == "" {
return "", nil
}
Expand All @@ -781,6 +785,16 @@ func (c *oidcConnector) LogoutURL(_ context.Context, _ []byte, postLogoutRedirec
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
q.Set("client_id", c.oauth2Config.ClientID)
}
// Per the RP-Initiated Logout spec, id_token_hint is independently valid
// of post_logout_redirect_uri — include it whenever we have one.
if len(rawConnectorData) > 0 {
var cd connectorData
if err := json.Unmarshal(rawConnectorData, &cd); err == nil {
if len(cd.IDToken) > 0 {
q.Set("id_token_hint", string(cd.IDToken))
}
}
}
u.RawQuery = q.Encode()

return u.String(), nil
Expand Down
41 changes: 40 additions & 1 deletion connector/oidc/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -979,10 +979,22 @@ func expectEquals(t *testing.T, a interface{}, b interface{}) {
}

func TestLogoutURL(t *testing.T) {
idTokenConnData, err := json.Marshal(connectorData{
RefreshToken: []byte("refresh"),
IDToken: []byte("id-token-jwt"),
})
require.NoError(t, err)

noIDTokenConnData, err := json.Marshal(connectorData{
RefreshToken: []byte("refresh"),
})
require.NoError(t, err)

tests := []struct {
name string
endSessionURL string
postLogoutRedirectURI string
connectorData []byte
wantURL string
wantEmpty bool
}{
Expand All @@ -1008,6 +1020,33 @@ func TestLogoutURL(t *testing.T) {
postLogoutRedirectURI: "https://dex.example.com/callback",
wantURL: "https://provider.example.com/logout?client_id=clientID&existing=param&post_logout_redirect_uri=https%3A%2F%2Fdex.example.com%2Fcallback",
},
{
name: "with id_token_hint from connector data",
endSessionURL: "https://provider.example.com/logout",
postLogoutRedirectURI: "https://dex.example.com/logout/callback",
connectorData: idTokenConnData,
wantURL: "https://provider.example.com/logout?client_id=clientID&id_token_hint=id-token-jwt&post_logout_redirect_uri=https%3A%2F%2Fdex.example.com%2Flogout%2Fcallback",
},
{
name: "id_token_hint included without post_logout_redirect_uri",
endSessionURL: "https://provider.example.com/logout",
connectorData: idTokenConnData,
wantURL: "https://provider.example.com/logout?id_token_hint=id-token-jwt",
},
{
name: "connector data without IDToken omits id_token_hint",
endSessionURL: "https://provider.example.com/logout",
postLogoutRedirectURI: "https://dex.example.com/logout/callback",
connectorData: noIDTokenConnData,
wantURL: "https://provider.example.com/logout?client_id=clientID&post_logout_redirect_uri=https%3A%2F%2Fdex.example.com%2Flogout%2Fcallback",
},
{
name: "malformed connector data is ignored",
endSessionURL: "https://provider.example.com/logout",
postLogoutRedirectURI: "https://dex.example.com/logout/callback",
connectorData: []byte("not-json"),
wantURL: "https://provider.example.com/logout?client_id=clientID&post_logout_redirect_uri=https%3A%2F%2Fdex.example.com%2Flogout%2Fcallback",
},
}

for _, tc := range tests {
Expand All @@ -1019,7 +1058,7 @@ func TestLogoutURL(t *testing.T) {
},
}

got, err := conn.LogoutURL(context.Background(), nil, tc.postLogoutRedirectURI)
got, err := conn.LogoutURL(context.Background(), tc.connectorData, tc.postLogoutRedirectURI)
require.NoError(t, err)

if tc.wantEmpty {
Expand Down
10 changes: 9 additions & 1 deletion server/logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,21 @@ func (s *Server) tryUpstreamLogout(ctx context.Context, userID, connectorID stri
}

// Check that the session exists — we need it to store logout state.
_, err = s.storage.GetAuthSession(ctx, userID, connectorID)
session, err := s.storage.GetAuthSession(ctx, userID, connectorID)
if err != nil {
s.logger.DebugContext(ctx, "logout: no auth session for upstream logout, skipping",
"user_id", userID, "connector_id", connectorID)
return "", false
}

// The auth session connector data should keep an id_token that will be used as hint for RP-Initiated logout
if len(session.ConnectorData) > 0 {
connectorData = session.ConnectorData
s.logger.DebugContext(ctx, "logout: using auth_session.ConnectorData", "connector_id", connectorID)
} else if len(connectorData) == 0 {
s.logger.DebugContext(ctx, "logout: no connector data available", "connector_id", connectorID)
}

// Store logout parameters in the session.
if err := s.storage.UpdateAuthSession(ctx, userID, connectorID, func(old storage.AuthSession) (storage.AuthSession, error) {
old.LogoutState = &storage.LogoutState{
Expand Down
85 changes: 85 additions & 0 deletions server/logout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,29 @@ import (

"github.com/stretchr/testify/require"

"github.com/dexidp/dex/connector"
"github.com/dexidp/dex/storage"
)

// recordingLogoutConnector implements connector.LogoutCallbackConnector and
// records the connectorData it was invoked with so tests can assert what was
// passed down.
type recordingLogoutConnector struct {
gotConnectorData []byte
returnURL string
}

func (c *recordingLogoutConnector) LogoutURL(_ context.Context, connectorData []byte, _ string) (string, error) {
c.gotConnectorData = connectorData
return c.returnURL, nil
}

func (c *recordingLogoutConnector) HandleLogoutCallback(_ context.Context, _ *http.Request) error {
return nil
}

var _ connector.LogoutCallbackConnector = (*recordingLogoutConnector)(nil)

func TestHandleLogoutNoSessions(t *testing.T) {
httpServer, server := newTestServer(t, nil)
defer httpServer.Close()
Expand Down Expand Up @@ -380,3 +400,68 @@ func TestRevokeRefreshTokensReturnsConnectorData(t *testing.T) {
require.Empty(t, os.Refresh)
require.Equal(t, expectedConnData, os.ConnectorData)
}

// TestTryUpstreamLogoutPrefersSessionConnectorData verifies that when the auth
// session has ConnectorData stored (from login), it takes precedence over the
// connectorData the caller passes in (which originates from the offline session).
func TestTryUpstreamLogoutPrefersSessionConnectorData(t *testing.T) {
tests := []struct {
name string
sessionConnData []byte
callerConnData []byte
wantConnData []byte
}{
{
name: "session data wins over caller data",
sessionConnData: []byte(`{"IDToken":"session-token"}`),
callerConnData: []byte(`{"IDToken":"caller-token"}`),
wantConnData: []byte(`{"IDToken":"session-token"}`),
},
{
name: "caller data used when session data is empty",
sessionConnData: nil,
callerConnData: []byte(`{"IDToken":"caller-token"}`),
wantConnData: []byte(`{"IDToken":"caller-token"}`),
},
{
name: "empty when neither source has data",
sessionConnData: nil,
callerConnData: nil,
wantConnData: nil,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
httpServer, server := newTestServerWithSessions(t, nil)
defer httpServer.Close()

ctx := t.Context()
userID := "test-user"
connectorID := "mock"

// Inject a recording connector with matching ResourceVersion so
// getConnector returns our mock instead of re-opening from storage.
rec := &recordingLogoutConnector{returnURL: "https://upstream.example.com/logout"}
server.mu.Lock()
server.connectors[connectorID] = Connector{
Type: "mockCallback",
ResourceVersion: "1",
Connector: rec,
}
server.mu.Unlock()

require.NoError(t, server.storage.CreateAuthSession(ctx, storage.AuthSession{
UserID: userID, ConnectorID: connectorID, Nonce: "nonce",
CreatedAt: time.Now(), LastActivity: time.Now(),
ConnectorData: tc.sessionConnData,
}))

redirectURL, ok := server.tryUpstreamLogout(ctx, userID, connectorID, tc.callerConnData,
"https://dex.example.com/cb", "state-123", "client-123")
require.True(t, ok)
require.Equal(t, "https://upstream.example.com/logout", redirectURL)
require.Equal(t, tc.wantConnData, rec.gotConnectorData)
})
}
}
2 changes: 2 additions & 0 deletions server/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ func (s *Server) createOrUpdateAuthSession(ctx context.Context, r *http.Request,
old.ClientStates = make(map[string]*storage.ClientAuthState)
}
old.ClientStates[authReq.ClientID] = clientState
old.ConnectorData = authReq.ConnectorData
return old, nil
}); err != nil {
return fmt.Errorf("update auth session: %w", err)
Expand Down Expand Up @@ -289,6 +290,7 @@ func (s *Server) createOrUpdateAuthSession(ctx context.Context, r *http.Request,
UserAgent: r.UserAgent(),
AbsoluteExpiry: now.Add(s.sessionConfig.AbsoluteLifetime),
IdleExpiry: now.Add(s.sessionConfig.ValidIfNotUsedFor),
ConnectorData: authReq.ConnectorData,
}

if err := s.storage.CreateAuthSession(ctx, newSession); err != nil {
Expand Down
9 changes: 8 additions & 1 deletion storage/conformance/conformance.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package conformance

import (
"bytes"
"crypto/ecdsa"
"reflect"
"sort"
Expand Down Expand Up @@ -1388,6 +1389,7 @@ func testAuthSessionCRUD(t *testing.T, s storage.Storage) {
UserAgent: "TestBrowser/1.0",
AbsoluteExpiry: now.Add(24 * time.Hour),
IdleExpiry: now.Add(1 * time.Hour),
ConnectorData: []byte(`{"RefreshToken":"dGVzdA==","IDToken":"ZXlKaGJHY21PaUpTVXpJMU5pSjk="}`),
}

// Create.
Expand Down Expand Up @@ -1418,15 +1420,17 @@ func testAuthSessionCRUD(t *testing.T, s storage.Storage) {
t.Errorf("auth session retrieved from storage did not match: %s", diff)
}

// Update: add a new client state.
// Update: add a new client state and rotate connector data.
newNow := now.Add(time.Minute)
updatedConnectorData := []byte(`{"RefreshToken":"bmV3","IDToken":"bmV3LWlk"}`)
if err := s.UpdateAuthSession(ctx, session.UserID, session.ConnectorID, func(old storage.AuthSession) (storage.AuthSession, error) {
old.ClientStates["client2"] = &storage.ClientAuthState{
Active: true,
ExpiresAt: newNow.Add(24 * time.Hour),
LastActivity: newNow,
}
old.LastActivity = newNow
old.ConnectorData = updatedConnectorData
return old, nil
}); err != nil {
t.Fatalf("update auth session: %v", err)
Expand All @@ -1443,6 +1447,9 @@ func testAuthSessionCRUD(t *testing.T, s storage.Storage) {
if got.ClientStates["client2"] == nil {
t.Fatal("expected client2 state to exist")
}
if !bytes.Equal(got.ConnectorData, updatedConnectorData) {
t.Fatalf("expected updated connector data %q, got %q", updatedConnectorData, got.ConnectorData)
}

// List and verify.
sessions, err := s.ListAuthSessions(ctx)
Expand Down
2 changes: 2 additions & 0 deletions storage/ent/client/authsession.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func (d *Database) CreateAuthSession(ctx context.Context, session storage.AuthSe
SetUserAgent(session.UserAgent).
SetAbsoluteExpiry(session.AbsoluteExpiry.UTC()).
SetIdleExpiry(session.IdleExpiry.UTC()).
SetConnectorData(session.ConnectorData).
Save(ctx)
if err != nil {
return convertDBError("create auth session: %w", err)
Expand Down Expand Up @@ -106,6 +107,7 @@ func (d *Database) UpdateAuthSession(ctx context.Context, userID, connectorID st
SetUserAgent(newSession.UserAgent).
SetAbsoluteExpiry(newSession.AbsoluteExpiry.UTC()).
SetIdleExpiry(newSession.IdleExpiry.UTC()).
SetConnectorData(newSession.ConnectorData).
Save(ctx)
if err != nil {
return rollback(tx, "update auth session updating: %w", err)
Expand Down
4 changes: 4 additions & 0 deletions storage/ent/client/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,10 @@ func toStorageAuthSession(s *db.AuthSession) storage.AuthSession {
IdleExpiry: s.IdleExpiry,
}

if s.ConnectorData != nil {
result.ConnectorData = *s.ConnectorData
}

if s.ClientStates != nil {
if err := json.Unmarshal(s.ClientStates, &result.ClientStates); err != nil {
panic(err)
Expand Down
19 changes: 16 additions & 3 deletions storage/ent/db/authsession.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions storage/ent/db/authsession/authsession.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading