Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions common/rpc/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/uber/cadence/common/config"
"github.com/uber/cadence/common/isolationgroup"
"github.com/uber/cadence/common/metrics"
"github.com/uber/cadence/common/types"
)

type authOutboundMiddleware struct {
Expand Down Expand Up @@ -115,6 +116,14 @@ func (m *InboundMetricsMiddleware) Handle(ctx context.Context, req *transport.Re
return h.Handle(ctx, req, resw)
}

// CallerInfoMiddleware extracts caller information from headers and adds it to the context.
type CallerInfoMiddleware struct{}

func (m *CallerInfoMiddleware) Handle(ctx context.Context, req *transport.Request, resw transport.ResponseWriter, h transport.UnaryHandler) error {
ctx = types.GetContextWithCallerInfoFromHeaders(ctx, req.Headers)
return h.Handle(ctx, req, resw)
}

type overrideCallerMiddleware struct {
caller string
}
Expand Down
52 changes: 52 additions & 0 deletions common/rpc/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/uber/cadence/common/config"
"github.com/uber/cadence/common/isolationgroup"
"github.com/uber/cadence/common/metrics"
"github.com/uber/cadence/common/types"
)

func TestAuthOubboundMiddleware(t *testing.T) {
Expand Down Expand Up @@ -329,6 +330,57 @@ func TestClientPartitionConfigMiddleware(t *testing.T) {
})
}

func TestCallerInfoMiddleware(t *testing.T) {
t.Run("extracts caller type from header", func(t *testing.T) {
m := &CallerInfoMiddleware{}
h := &fakeHandler{}
headers := transport.NewHeaders().With(types.CallerTypeHeaderName, "cli")
err := m.Handle(context.Background(), &transport.Request{Headers: headers}, nil, h)
assert.NoError(t, err)

callerInfo := types.GetCallerInfoFromContext(h.ctx)
assert.Equal(t, types.CallerTypeCLI, callerInfo.GetCallerType())
})

t.Run("sets unknown caller type when header is missing", func(t *testing.T) {
m := &CallerInfoMiddleware{}
h := &fakeHandler{}
headers := transport.NewHeaders()
err := m.Handle(context.Background(), &transport.Request{Headers: headers}, nil, h)
assert.NoError(t, err)

callerInfo := types.GetCallerInfoFromContext(h.ctx)
assert.Equal(t, types.CallerTypeUnknown, callerInfo.GetCallerType())
})

t.Run("extracts different caller types", func(t *testing.T) {
tests := []struct {
name string
headerValue string
expectedCaller types.CallerType
}{
{"CLI", "cli", types.CallerTypeCLI},
{"UI", "ui", types.CallerTypeUI},
{"SDK", "sdk", types.CallerTypeSDK},
{"Internal", "internal", types.CallerTypeInternal},
{"Empty", "", types.CallerTypeUnknown},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := &CallerInfoMiddleware{}
h := &fakeHandler{}
headers := transport.NewHeaders().With(types.CallerTypeHeaderName, tt.headerValue)
err := m.Handle(context.Background(), &transport.Request{Headers: headers}, nil, h)
assert.NoError(t, err)

callerInfo := types.GetCallerInfoFromContext(h.ctx)
assert.Equal(t, tt.expectedCaller, callerInfo.GetCallerType())
})
}
})
}

type fakeHandler struct {
ctx context.Context
}
Expand Down
2 changes: 1 addition & 1 deletion common/rpc/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func NewParams(serviceName string, config *config.Config, dc *dynamicconfig.Coll
OutboundTLS: outboundTLS,
InboundMiddleware: yarpc.InboundMiddleware{
// order matters: ForwardPartitionConfigMiddleware must be applied after ClientPartitionConfigMiddleware
Unary: yarpc.UnaryInboundMiddleware(&InboundMetricsMiddleware{}, &ClientPartitionConfigMiddleware{}, &ForwardPartitionConfigMiddleware{}),
Unary: yarpc.UnaryInboundMiddleware(&InboundMetricsMiddleware{}, &CallerInfoMiddleware{}, &ClientPartitionConfigMiddleware{}, &ForwardPartitionConfigMiddleware{}),
},
OutboundMiddleware: yarpc.OutboundMiddleware{
Unary: yarpc.UnaryOutboundMiddleware(&HeaderForwardingMiddleware{
Expand Down
59 changes: 23 additions & 36 deletions common/types/caller.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ package types

import (
"context"

"go.uber.org/yarpc"
)

const (
Expand Down Expand Up @@ -62,15 +60,12 @@ type CallerInfo struct {
}

// NewCallerInfo creates a new CallerInfo
func NewCallerInfo(callerType CallerType) *CallerInfo {
return &CallerInfo{callerType: callerType}
func NewCallerInfo(callerType CallerType) CallerInfo {
return CallerInfo{callerType: callerType}
}

// GetCallerType returns the CallerType, or CallerTypeUnknown if CallerInfo is nil
func (c *CallerInfo) GetCallerType() CallerType {
if c == nil {
return CallerTypeUnknown
}
// GetCallerType returns the CallerType
func (c CallerInfo) GetCallerType() CallerType {
return c.callerType
}

Expand All @@ -92,43 +87,35 @@ func ParseCallerType(s string) CallerType {
}

// ContextWithCallerInfo adds CallerInfo to context
func ContextWithCallerInfo(ctx context.Context, callerInfo *CallerInfo) context.Context {
if callerInfo == nil {
return ctx
}
func ContextWithCallerInfo(ctx context.Context, callerInfo CallerInfo) context.Context {
return context.WithValue(ctx, callerInfoKey, callerInfo)
}

// GetCallerInfoFromContext retrieves CallerInfo from context, returns nil if not set
func GetCallerInfoFromContext(ctx context.Context) *CallerInfo {
// GetCallerInfoFromContext retrieves CallerInfo from context
// Returns CallerInfo with CallerTypeUnknown if not set in context
func GetCallerInfoFromContext(ctx context.Context) CallerInfo {
if ctx == nil {
return nil
return NewCallerInfo(CallerTypeUnknown)
}
if callerInfo, ok := ctx.Value(callerInfoKey).(CallerInfo); ok {
return callerInfo
}
callerInfo, _ := ctx.Value(callerInfoKey).(*CallerInfo)
return callerInfo
return NewCallerInfo(CallerTypeUnknown)
}

// GetCallerInfoFromHeaders extracts CallerInfo from YARPC headers in the context
func GetCallerInfoFromHeaders(ctx context.Context) *CallerInfo {
if ctx == nil {
return nil
}
// NewCallerInfoFromTransportHeaders extracts CallerInfo from transport headers
// This is used by middleware to extract caller information from incoming requests
func NewCallerInfoFromTransportHeaders(headers interface{ Get(string) (string, bool) }) CallerInfo {
callerTypeStr, _ := headers.Get(CallerTypeHeaderName)

call := yarpc.CallFromContext(ctx)
if call == nil {
return nil
}
// Future: add more header extractions here
// version, _ := headers.Get("cadence-client-version")
// identity, _ := headers.Get("cadence-client-identity")

callerTypeStr := call.Header(CallerTypeHeaderName)
return NewCallerInfo(ParseCallerType(callerTypeStr))
}

// GetContextWithCallerInfoFromHeaders extracts CallerInfo from YARPC headers and adds it to the context
// Returns the original context if no caller info is found in headers
func GetContextWithCallerInfoFromHeaders(ctx context.Context) context.Context {
callerInfo := GetCallerInfoFromHeaders(ctx)
if callerInfo == nil {
return ctx
}
return ContextWithCallerInfo(ctx, callerInfo)
// GetContextWithCallerInfoFromHeaders extracts CallerInfo from transport headers and adds it to the context
func GetContextWithCallerInfoFromHeaders(ctx context.Context, headers interface{ Get(string) (string, bool) }) context.Context {
return ContextWithCallerInfo(ctx, NewCallerInfoFromTransportHeaders(headers))
}
Loading
Loading