Skip to content

Commit a0d5e75

Browse files
authored
Token federation for Go driver (2/3) (#291)
Adds token federation for databricks sql go driver
1 parent 66608b7 commit a0d5e75

File tree

4 files changed

+764
-0
lines changed

4 files changed

+764
-0
lines changed

auth/tokenprovider/cached.go

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
package tokenprovider
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"sync"
7+
"time"
8+
9+
"github.com/rs/zerolog/log"
10+
)
11+
12+
// CachedTokenProvider wraps another provider and caches tokens
13+
type CachedTokenProvider struct {
14+
provider TokenProvider
15+
cache *Token
16+
mutex sync.RWMutex
17+
refreshing bool // prevents thundering herd
18+
// RefreshThreshold determines when to refresh (default 5 minutes before expiry)
19+
RefreshThreshold time.Duration
20+
}
21+
22+
// NewCachedTokenProvider creates a caching wrapper around any token provider
23+
func NewCachedTokenProvider(provider TokenProvider) *CachedTokenProvider {
24+
return &CachedTokenProvider{
25+
provider: provider,
26+
RefreshThreshold: 5 * time.Minute,
27+
}
28+
}
29+
30+
// GetToken retrieves a token, using cache if available and valid
31+
func (p *CachedTokenProvider) GetToken(ctx context.Context) (*Token, error) {
32+
// Check if context is already cancelled
33+
if err := ctx.Err(); err != nil {
34+
return nil, fmt.Errorf("cached token provider: context cancelled: %w", err)
35+
}
36+
37+
// Try to get from cache first
38+
p.mutex.RLock()
39+
cached := p.cache
40+
needsRefresh := p.shouldRefresh(cached)
41+
isRefreshing := p.refreshing
42+
p.mutex.RUnlock()
43+
44+
// If cache is valid and not being refreshed, return a copy
45+
if cached != nil && !needsRefresh {
46+
log.Debug().Msgf("cached token provider: using cached token for provider %s", p.provider.Name())
47+
// Return a copy to avoid concurrent modification issues
48+
return copyToken(cached), nil
49+
}
50+
51+
// If another goroutine is already refreshing, wait briefly and retry
52+
if isRefreshing {
53+
time.Sleep(50 * time.Millisecond)
54+
p.mutex.RLock()
55+
cached = p.cache
56+
needsRefresh = p.shouldRefresh(cached)
57+
p.mutex.RUnlock()
58+
59+
if cached != nil && !needsRefresh {
60+
return copyToken(cached), nil
61+
}
62+
}
63+
64+
// Need to refresh - acquire write lock
65+
p.mutex.Lock()
66+
67+
// Double-check after acquiring write lock
68+
if p.cache != nil && !p.shouldRefresh(p.cache) {
69+
p.mutex.Unlock()
70+
return copyToken(p.cache), nil
71+
}
72+
73+
// Mark as refreshing to prevent other goroutines from also refreshing
74+
p.refreshing = true
75+
p.mutex.Unlock()
76+
77+
// Fetch new token WITHOUT holding the lock
78+
log.Debug().Msgf("cached token provider: fetching new token from provider %s", p.provider.Name())
79+
token, err := p.provider.GetToken(ctx)
80+
81+
// Update cache with result
82+
p.mutex.Lock()
83+
p.refreshing = false
84+
if err != nil {
85+
p.mutex.Unlock()
86+
return nil, fmt.Errorf("cached token provider: failed to get token: %w", err)
87+
}
88+
89+
p.cache = token
90+
p.mutex.Unlock()
91+
92+
return copyToken(token), nil
93+
}
94+
95+
// copyToken creates a copy of a token to avoid concurrent modification issues
96+
func copyToken(t *Token) *Token {
97+
if t == nil {
98+
return nil
99+
}
100+
101+
scopesCopy := make([]string, len(t.Scopes))
102+
copy(scopesCopy, t.Scopes)
103+
104+
return &Token{
105+
AccessToken: t.AccessToken,
106+
TokenType: t.TokenType,
107+
ExpiresAt: t.ExpiresAt,
108+
Scopes: scopesCopy,
109+
}
110+
}
111+
112+
// shouldRefresh determines if a token should be refreshed based on expiry time.
113+
// Returns true if:
114+
// - token is nil
115+
// - token has expired
116+
// - token will expire within RefreshThreshold (default 5 minutes)
117+
//
118+
// Returns false if:
119+
// - token has no expiry time (never expires)
120+
// - token is still valid and not close to expiry
121+
func (p *CachedTokenProvider) shouldRefresh(token *Token) bool {
122+
if token == nil {
123+
return true
124+
}
125+
126+
// If no expiry time, assume token doesn't expire
127+
if token.ExpiresAt.IsZero() {
128+
return false
129+
}
130+
131+
// Refresh if within threshold of expiry
132+
refreshAt := token.ExpiresAt.Add(-p.RefreshThreshold)
133+
return time.Now().After(refreshAt)
134+
}
135+
136+
// Name returns the provider name
137+
func (p *CachedTokenProvider) Name() string {
138+
return fmt.Sprintf("cached[%s]", p.provider.Name())
139+
}
140+
141+
// ClearCache clears the cached token
142+
func (p *CachedTokenProvider) ClearCache() {
143+
p.mutex.Lock()
144+
p.cache = nil
145+
p.mutex.Unlock()
146+
}

auth/tokenprovider/exchange.go

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
package tokenprovider
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"net/url"
10+
"strings"
11+
"time"
12+
13+
"github.com/golang-jwt/jwt/v5"
14+
"github.com/rs/zerolog/log"
15+
)
16+
17+
// FederationProvider wraps another token provider and automatically handles token exchange
18+
type FederationProvider struct {
19+
baseProvider TokenProvider
20+
databricksHost string
21+
clientID string // For SP-wide federation
22+
httpClient *http.Client
23+
// Settings for token exchange
24+
returnOriginalTokenIfAuthenticated bool
25+
}
26+
27+
// NewFederationProvider creates a federation provider that wraps another provider
28+
// It automatically detects when token exchange is needed and falls back gracefully
29+
func NewFederationProvider(baseProvider TokenProvider, databricksHost string) *FederationProvider {
30+
return &FederationProvider{
31+
baseProvider: baseProvider,
32+
databricksHost: databricksHost,
33+
httpClient: &http.Client{Timeout: 30 * time.Second},
34+
returnOriginalTokenIfAuthenticated: true,
35+
}
36+
}
37+
38+
// NewFederationProviderWithClientID creates a provider for SP-wide federation (M2M)
39+
func NewFederationProviderWithClientID(baseProvider TokenProvider, databricksHost, clientID string) *FederationProvider {
40+
return &FederationProvider{
41+
baseProvider: baseProvider,
42+
databricksHost: databricksHost,
43+
clientID: clientID,
44+
httpClient: &http.Client{Timeout: 30 * time.Second},
45+
returnOriginalTokenIfAuthenticated: true,
46+
}
47+
}
48+
49+
// GetToken gets token from base provider and exchanges if needed
50+
func (p *FederationProvider) GetToken(ctx context.Context) (*Token, error) {
51+
// Check if context is already cancelled
52+
if err := ctx.Err(); err != nil {
53+
return nil, fmt.Errorf("federation provider: context cancelled: %w", err)
54+
}
55+
56+
// Get token from base provider
57+
baseToken, err := p.baseProvider.GetToken(ctx)
58+
if err != nil {
59+
return nil, fmt.Errorf("federation provider: failed to get base token: %w", err)
60+
}
61+
62+
// Check if token is a JWT and needs exchange
63+
if p.needsTokenExchange(baseToken.AccessToken) {
64+
log.Debug().Msgf("federation provider: attempting token exchange for %s", p.baseProvider.Name())
65+
66+
// Try token exchange
67+
exchangedToken, err := p.tryTokenExchange(ctx, baseToken.AccessToken)
68+
if err != nil {
69+
log.Warn().Err(err).Msg("federation provider: token exchange failed, using original token")
70+
return baseToken, nil // Fall back to original token
71+
}
72+
73+
log.Debug().Msg("federation provider: token exchange successful")
74+
return exchangedToken, nil
75+
}
76+
77+
// Use original token
78+
return baseToken, nil
79+
}
80+
81+
// needsTokenExchange determines if a token needs exchange by checking if it's from a different issuer
82+
func (p *FederationProvider) needsTokenExchange(tokenString string) bool {
83+
// Try to parse as JWT without verification
84+
// We use ParseUnverified because:
85+
// 1. We only need to inspect claims (issuer), not validate the signature
86+
// 2. We don't have the public key for external identity providers
87+
// 3. Token validation will be done by Databricks during exchange
88+
token, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{})
89+
if err != nil {
90+
log.Debug().Err(err).Msg("federation provider: not a JWT token, skipping exchange")
91+
return false
92+
}
93+
94+
claims, ok := token.Claims.(jwt.MapClaims)
95+
if !ok {
96+
return false
97+
}
98+
99+
issuer, ok := claims["iss"].(string)
100+
if !ok {
101+
return false
102+
}
103+
104+
// Check if issuer is different from Databricks host
105+
return !p.isSameHost(issuer, p.databricksHost)
106+
}
107+
108+
// tryTokenExchange attempts to exchange the token with Databricks
109+
func (p *FederationProvider) tryTokenExchange(ctx context.Context, subjectToken string) (*Token, error) {
110+
// Build exchange URL - add scheme if not present
111+
exchangeURL := p.databricksHost
112+
if !strings.HasPrefix(exchangeURL, "http://") && !strings.HasPrefix(exchangeURL, "https://") {
113+
// Default to HTTPS for security
114+
exchangeURL = "https://" + exchangeURL
115+
} else if strings.HasPrefix(exchangeURL, "http://") {
116+
// Warn if using insecure HTTP for token exchange
117+
log.Warn().Msgf("federation provider: using insecure HTTP for token exchange: %s", exchangeURL)
118+
}
119+
if !strings.HasSuffix(exchangeURL, "/") {
120+
exchangeURL += "/"
121+
}
122+
exchangeURL += "oidc/v1/token"
123+
124+
// Prepare form data for token exchange
125+
data := url.Values{}
126+
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
127+
data.Set("scope", "sql")
128+
data.Set("subject_token_type", "urn:ietf:params:oauth:token-type:jwt")
129+
data.Set("subject_token", subjectToken)
130+
131+
if p.returnOriginalTokenIfAuthenticated {
132+
data.Set("return_original_token_if_authenticated", "true")
133+
}
134+
135+
// Add client_id for SP-wide federation
136+
if p.clientID != "" {
137+
data.Set("client_id", p.clientID)
138+
}
139+
140+
// Create request
141+
req, err := http.NewRequestWithContext(ctx, "POST", exchangeURL, strings.NewReader(data.Encode()))
142+
if err != nil {
143+
return nil, fmt.Errorf("failed to create request: %w", err)
144+
}
145+
146+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
147+
req.Header.Set("Accept", "*/*")
148+
149+
// Make request
150+
resp, err := p.httpClient.Do(req)
151+
if err != nil {
152+
return nil, fmt.Errorf("request failed: %w", err)
153+
}
154+
defer resp.Body.Close()
155+
156+
body, err := io.ReadAll(resp.Body)
157+
if err != nil {
158+
return nil, fmt.Errorf("failed to read response: %w", err)
159+
}
160+
161+
if resp.StatusCode != http.StatusOK {
162+
return nil, fmt.Errorf("exchange failed with status %d: %s", resp.StatusCode, string(body))
163+
}
164+
165+
// Parse response
166+
var tokenResp struct {
167+
AccessToken string `json:"access_token"`
168+
TokenType string `json:"token_type"`
169+
ExpiresIn int `json:"expires_in"`
170+
Scope string `json:"scope"`
171+
}
172+
173+
if err := json.Unmarshal(body, &tokenResp); err != nil {
174+
return nil, fmt.Errorf("failed to parse response: %w", err)
175+
}
176+
177+
// Validate token response
178+
if tokenResp.AccessToken == "" {
179+
return nil, fmt.Errorf("token exchange returned empty access token")
180+
}
181+
if tokenResp.TokenType == "" {
182+
log.Debug().Msg("token exchange: token_type not specified, defaulting to Bearer")
183+
tokenResp.TokenType = "Bearer"
184+
}
185+
if tokenResp.ExpiresIn < 0 {
186+
return nil, fmt.Errorf("token exchange returned invalid expires_in: %d", tokenResp.ExpiresIn)
187+
}
188+
189+
token := &Token{
190+
AccessToken: tokenResp.AccessToken,
191+
TokenType: tokenResp.TokenType,
192+
Scopes: strings.Fields(tokenResp.Scope),
193+
}
194+
195+
if tokenResp.ExpiresIn > 0 {
196+
token.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
197+
}
198+
199+
return token, nil
200+
}
201+
202+
// isSameHost compares two URLs to see if they have the same host
203+
func (p *FederationProvider) isSameHost(url1, url2 string) bool {
204+
// Add scheme to url2 if it doesn't have one (databricksHost may not have scheme)
205+
parsedURL2 := url2
206+
if !strings.HasPrefix(url2, "http://") && !strings.HasPrefix(url2, "https://") {
207+
parsedURL2 = "https://" + url2
208+
}
209+
210+
u1, err1 := url.Parse(url1)
211+
u2, err2 := url.Parse(parsedURL2)
212+
213+
if err1 != nil || err2 != nil {
214+
log.Debug().Msgf("federation provider: failed to parse URLs for comparison: url1=%s err1=%v, url2=%s err2=%v",
215+
url1, err1, parsedURL2, err2)
216+
return false
217+
}
218+
219+
// Use Hostname() instead of Host to ignore port differences
220+
// This handles cases like "host.com:443" == "host.com" for HTTPS
221+
isSame := u1.Hostname() == u2.Hostname()
222+
log.Debug().Msgf("federation provider: host comparison: %s vs %s = %v", u1.Hostname(), u2.Hostname(), isSame)
223+
return isSame
224+
}
225+
226+
// Name returns the provider name
227+
func (p *FederationProvider) Name() string {
228+
baseName := p.baseProvider.Name()
229+
if p.clientID != "" {
230+
clientIDDisplay := p.clientID
231+
if len(p.clientID) > 8 {
232+
clientIDDisplay = p.clientID[:8]
233+
}
234+
return fmt.Sprintf("federation[%s,sp:%s]", baseName, clientIDDisplay) // Truncate client ID for readability
235+
}
236+
return fmt.Sprintf("federation[%s]", baseName)
237+
}

0 commit comments

Comments
 (0)