Skip to content

Commit 619c9dc

Browse files
committed
chore: apply the default interface/mark of the dialer in the final stage
1 parent 9c5067e commit 619c9dc

File tree

2 files changed

+43
-48
lines changed

2 files changed

+43
-48
lines changed

component/dialer/dialer.go

Lines changed: 36 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -20,34 +20,16 @@ const (
2020
DefaultUDPTimeout = DefaultTCPTimeout
2121
)
2222

23-
type dialFunc func(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error)
23+
type dialFunc func(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error)
2424

2525
var (
2626
dialMux sync.Mutex
27-
IP4PEnable bool
2827
actualSingleStackDialContext = serialSingleStackDialContext
2928
actualDualStackDialContext = serialDualStackDialContext
3029
tcpConcurrent = false
3130
fallbackTimeout = 300 * time.Millisecond
3231
)
3332

34-
func applyOptions(options ...Option) *option {
35-
opt := &option{
36-
interfaceName: DefaultInterface.Load(),
37-
routingMark: int(DefaultRoutingMark.Load()),
38-
}
39-
40-
for _, o := range DefaultOptions {
41-
o(opt)
42-
}
43-
44-
for _, o := range options {
45-
o(opt)
46-
}
47-
48-
return opt
49-
}
50-
5133
func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) {
5234
opt := applyOptions(options...)
5335

@@ -77,38 +59,43 @@ func DialContext(ctx context.Context, network, address string, options ...Option
7759
}
7860

7961
func ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort, options ...Option) (net.PacketConn, error) {
80-
cfg := applyOptions(options...)
62+
opt := applyOptions(options...)
8163

8264
lc := &net.ListenConfig{}
83-
if cfg.addrReuse {
65+
if opt.addrReuse {
8466
addrReuseToListenConfig(lc)
8567
}
8668
if DefaultSocketHook != nil { // ignore interfaceName, routingMark when DefaultSocketHook not null (in CMFA)
8769
socketHookToListenConfig(lc)
8870
} else {
89-
interfaceName := cfg.interfaceName
90-
if interfaceName == "" {
71+
if opt.interfaceName == "" {
72+
opt.interfaceName = DefaultInterface.Load()
73+
}
74+
if opt.interfaceName == "" {
9175
if finder := DefaultInterfaceFinder.Load(); finder != nil {
92-
interfaceName = finder.FindInterfaceName(rAddrPort.Addr())
76+
opt.interfaceName = finder.FindInterfaceName(rAddrPort.Addr())
9377
}
9478
}
9579
if rAddrPort.Addr().Unmap().IsLoopback() {
9680
// avoid "The requested address is not valid in its context."
97-
interfaceName = ""
81+
opt.interfaceName = ""
9882
}
99-
if interfaceName != "" {
83+
if opt.interfaceName != "" {
10084
bind := bindIfaceToListenConfig
101-
if cfg.fallbackBind {
85+
if opt.fallbackBind {
10286
bind = fallbackBindIfaceToListenConfig
10387
}
104-
addr, err := bind(interfaceName, lc, network, address, rAddrPort)
88+
addr, err := bind(opt.interfaceName, lc, network, address, rAddrPort)
10589
if err != nil {
10690
return nil, err
10791
}
10892
address = addr
10993
}
110-
if cfg.routingMark != 0 {
111-
bindMarkToListenConfig(cfg.routingMark, lc, network, address)
94+
if opt.routingMark == 0 {
95+
opt.routingMark = int(DefaultRoutingMark.Load())
96+
}
97+
if opt.routingMark != 0 {
98+
bindMarkToListenConfig(opt.routingMark, lc, network, address)
11299
}
113100
}
114101

@@ -134,7 +121,7 @@ func GetTcpConcurrent() bool {
134121
return tcpConcurrent
135122
}
136123

137-
func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) {
124+
func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt option) (net.Conn, error) {
138125
var address string
139126
destination, port = resolver.LookupIP4P(destination, port)
140127
address = net.JoinHostPort(destination.String(), port)
@@ -159,21 +146,26 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
159146
if DefaultSocketHook != nil { // ignore interfaceName, routingMark and tfo when DefaultSocketHook not null (in CMFA)
160147
socketHookToToDialer(dialer)
161148
} else {
162-
interfaceName := opt.interfaceName // don't change the "opt", it's a pointer
163-
if interfaceName == "" {
149+
if opt.interfaceName == "" {
150+
opt.interfaceName = DefaultInterface.Load()
151+
}
152+
if opt.interfaceName == "" {
164153
if finder := DefaultInterfaceFinder.Load(); finder != nil {
165-
interfaceName = finder.FindInterfaceName(destination)
154+
opt.interfaceName = finder.FindInterfaceName(destination)
166155
}
167156
}
168-
if interfaceName != "" {
157+
if opt.interfaceName != "" {
169158
bind := bindIfaceToDialer
170159
if opt.fallbackBind {
171160
bind = fallbackBindIfaceToDialer
172161
}
173-
if err := bind(interfaceName, dialer, network, destination); err != nil {
162+
if err := bind(opt.interfaceName, dialer, network, destination); err != nil {
174163
return nil, err
175164
}
176165
}
166+
if opt.routingMark == 0 {
167+
opt.routingMark = int(DefaultRoutingMark.Load())
168+
}
177169
if opt.routingMark != 0 {
178170
bindMarkToDialer(opt.routingMark, dialer, network, destination)
179171
}
@@ -185,26 +177,26 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
185177
return dialer.DialContext(ctx, network, address)
186178
}
187179

188-
func serialSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
180+
func serialSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
189181
return serialDialContext(ctx, network, ips, port, opt)
190182
}
191183

192-
func serialDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
184+
func serialDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
193185
return dualStackDialContext(ctx, serialDialContext, network, ips, port, opt)
194186
}
195187

196-
func concurrentSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
188+
func concurrentSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
197189
return parallelDialContext(ctx, network, ips, port, opt)
198190
}
199191

200-
func concurrentDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
192+
func concurrentDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
201193
if opt.prefer != 4 && opt.prefer != 6 {
202194
return parallelDialContext(ctx, network, ips, port, opt)
203195
}
204196
return dualStackDialContext(ctx, parallelDialContext, network, ips, port, opt)
205197
}
206198

207-
func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
199+
func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
208200
ipv4s, ipv6s := resolver.SortationAddr(ips)
209201
if len(ipv4s) == 0 && len(ipv6s) == 0 {
210202
return nil, ErrorNoIpAddress
@@ -285,7 +277,7 @@ loop:
285277
return nil, errors.Join(errs...)
286278
}
287279

288-
func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
280+
func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
289281
if len(ips) == 0 {
290282
return nil, ErrorNoIpAddress
291283
}
@@ -324,7 +316,7 @@ func parallelDialContext(ctx context.Context, network string, ips []netip.Addr,
324316
return nil, os.ErrDeadlineExceeded
325317
}
326318

327-
func serialDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
319+
func serialDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
328320
if len(ips) == 0 {
329321
return nil, ErrorNoIpAddress
330322
}
@@ -390,5 +382,5 @@ func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddr
390382

391383
func NewDialer(options ...Option) Dialer {
392384
opt := applyOptions(options...)
393-
return Dialer{Opt: *opt}
385+
return Dialer{Opt: opt}
394386
}

component/dialer/options.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
)
1111

1212
var (
13-
DefaultOptions []Option
1413
DefaultInterface = atomic.NewTypedValue[string]("")
1514
DefaultRoutingMark = atomic.NewInt32(0)
1615

@@ -117,9 +116,13 @@ func WithOption(o option) Option {
117116
}
118117

119118
func IsZeroOptions(opts []Option) bool {
120-
var opt option
121-
for _, o := range opts {
119+
return applyOptions(opts...) == option{}
120+
}
121+
122+
func applyOptions(options ...Option) option {
123+
opt := option{}
124+
for _, o := range options {
122125
o(&opt)
123126
}
124-
return opt == option{}
127+
return opt
125128
}

0 commit comments

Comments
 (0)