Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ func (c *client) attemptConnection() (net.Conn, byte, bool, error) {
tlsCfg = c.options.OnConnectAttempt(broker, c.options.TLSConfig)
}
// Start by opening the network connection (tcp, tls, ws) etc
conn, err = openConnection(broker, tlsCfg, c.options.ConnectTimeout, c.options.HTTPHeaders, c.options.WebsocketOptions)
conn, err = openConnection(broker, tlsCfg, c.options.ConnectTimeout, c.options.HTTPHeaders, c.options.WebsocketOptions, c.options.Dialer)
if err != nil {
ERROR.Println(CLI, err.Error())
WARN.Println(CLI, "failed to connect to broker, trying next")
Expand Down
11 changes: 5 additions & 6 deletions netconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import (

// openConnection opens a network connection using the protocol indicated in the URL.
// Does not carry out any MQTT specific handshakes.
func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, headers http.Header, websocketOptions *WebsocketOptions) (net.Conn, error) {
func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, headers http.Header, websocketOptions *WebsocketOptions, dialer *net.Dialer) (net.Conn, error) {
switch uri.Scheme {
case "ws":
conn, err := NewWebsocket(uri.String(), nil, timeout, headers, websocketOptions)
Expand All @@ -48,7 +48,7 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, heade
case "mqtt", "tcp":
allProxy := os.Getenv("all_proxy")
if len(allProxy) == 0 {
conn, err := net.DialTimeout("tcp", uri.Host, timeout)
conn, err := dialer.Dial("tcp", uri.Host)
if err != nil {
return nil, err
}
Expand All @@ -68,9 +68,9 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, heade
// this check is preserved for compatibility with older versions
// which used uri.Host only (it works for local paths, e.g. unix://socket.sock in current dir)
if len(uri.Host) > 0 {
conn, err = net.DialTimeout("unix", uri.Host, timeout)
conn, err = dialer.Dial("unix", uri.Host)
} else {
conn, err = net.DialTimeout("unix", uri.Path, timeout)
conn, err = dialer.Dial("unix", uri.Path)
}

if err != nil {
Expand All @@ -80,14 +80,13 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, heade
case "ssl", "tls", "mqtts", "mqtt+ssl", "tcps":
allProxy := os.Getenv("all_proxy")
if len(allProxy) == 0 {
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: timeout}, "tcp", uri.Host, tlsc)
conn, err := tls.DialWithDialer(dialer, "tcp", uri.Host, tlsc)
if err != nil {
return nil, err
}
return conn, nil
}
proxyDialer := proxy.FromEnvironment()

conn, err := proxyDialer.Dial("tcp", uri.Host)
if err != nil {
return nil, err
Expand Down
10 changes: 10 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package mqtt

import (
"crypto/tls"
"net"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -96,6 +97,7 @@ type ClientOptions struct {
HTTPHeaders http.Header
WebsocketOptions *WebsocketOptions
MaxResumePubInFlight int // // 0 = no limit; otherwise this is the maximum simultaneous messages sent while resuming
Dialer *net.Dialer
}

// NewClientOptions will create a new ClientClientOptions type with some
Expand Down Expand Up @@ -137,6 +139,7 @@ func NewClientOptions() *ClientOptions {
ResumeSubs: false,
HTTPHeaders: make(map[string][]string),
WebsocketOptions: &WebsocketOptions{},
Dialer: &net.Dialer{Timeout: 30 * time.Second},
}
return o
}
Expand Down Expand Up @@ -355,6 +358,7 @@ func (o *ClientOptions) SetWriteTimeout(t time.Duration) *ClientOptions {
// Default 30 seconds. Currently only operational on TCP/TLS connections.
func (o *ClientOptions) SetConnectTimeout(t time.Duration) *ClientOptions {
o.ConnectTimeout = t
o.Dialer.Timeout = t
return o
}

Expand Down Expand Up @@ -419,3 +423,9 @@ func (o *ClientOptions) SetMaxResumePubInFlight(MaxResumePubInFlight int) *Clien
o.MaxResumePubInFlight = MaxResumePubInFlight
return o
}

// SetDialer sets the tcp dialer options used in a tcp connection
func (o *ClientOptions) SetDialer(dialer *net.Dialer) *ClientOptions {
o.Dialer = dialer
return o
}