diff --git a/client/client.go b/client/client.go index c2260966f..42ef81cdb 100644 --- a/client/client.go +++ b/client/client.go @@ -4,6 +4,7 @@ package client import ( "bufio" "context" + "crypto/tls" "encoding/json" "errors" "fmt" @@ -11,10 +12,13 @@ import ( "heckel.io/ntfy/v2/util" "io" "net/http" + "os" "regexp" "strings" "sync" "time" + + "software.sslmate.com/src/go-pkcs12" ) const ( @@ -34,6 +38,7 @@ var ( type Client struct { Messages chan *Message config *Config + http *http.Client subscriptions map[string]*subscription mu sync.Mutex } @@ -75,12 +80,45 @@ type subscription struct { } // New creates a new Client using a given Config -func New(config *Config) *Client { +func New(config *Config) (*Client, error) { + httpClient, err := newHTTPClient(config) + if err != nil { + return nil, err + } return &Client{ Messages: make(chan *Message, 50), // Allow reading a few messages config: config, + http: httpClient, subscriptions: make(map[string]*subscription), + }, nil +} + +// newHTTPClient creates an HTTP client, optionally configured with a PKCS#12 client certificate +// for mTLS when config.CertFile is set. +func newHTTPClient(config *Config) (*http.Client, error) { + if config.CertFile == "" { + return &http.Client{}, nil + } + p12Data, err := os.ReadFile(config.CertFile) + if err != nil { + return nil, fmt.Errorf("reading cert file %s: %w", config.CertFile, err) + } + privateKey, cert, caCerts, err := pkcs12.DecodeChain(p12Data, config.CertPassword) + if err != nil { + return nil, fmt.Errorf("decoding cert file %s: %w", config.CertFile, err) + } + tlsCert := tls.Certificate{ + Certificate: [][]byte{cert.Raw}, + PrivateKey: privateKey, + Leaf: cert, + } + for _, ca := range caCerts { + tlsCert.Certificate = append(tlsCert.Certificate, ca.Raw) + } + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, } + return &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}}, nil } // Publish sends a message to a specific topic, optionally using options. @@ -112,7 +150,7 @@ func (c *Client) PublishReader(topic string, body io.Reader, options ...PublishO } } log.Debug("%s Publishing message with headers %s", util.ShortTopicURL(topicURL), req.Header) - resp, err := http.DefaultClient.Do(req) + resp, err := c.http.Do(req) if err != nil { return nil, err } @@ -152,7 +190,7 @@ func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, err log.Debug("%s Polling from topic", util.ShortTopicURL(topicURL)) options = append(options, WithPoll()) go func() { - err := performSubscribeRequest(ctx, msgChan, topicURL, "", options...) + err := performSubscribeRequest(ctx, c.http, msgChan, topicURL, "", options...) close(msgChan) errChan <- err }() @@ -196,7 +234,7 @@ func (c *Client) Subscribe(topic string, options ...SubscribeOption) (string, er topicURL: topicURL, cancel: cancel, } - go handleSubscribeConnLoop(ctx, c.Messages, topicURL, subscriptionID, options...) + go handleSubscribeConnLoop(ctx, c.http, c.Messages, topicURL, subscriptionID, options...) return subscriptionID, nil } @@ -225,11 +263,11 @@ func (c *Client) expandTopicURL(topic string) (string, error) { return fmt.Sprintf("%s/%s", c.config.DefaultHost, topic), nil } -func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicURL, subcriptionID string, options ...SubscribeOption) { +func handleSubscribeConnLoop(ctx context.Context, httpClient *http.Client, msgChan chan *Message, topicURL, subcriptionID string, options ...SubscribeOption) { for { // TODO The retry logic is crude and may lose messages. It should record the last message like the // Android client, use since=, and do incremental backoff too - if err := performSubscribeRequest(ctx, msgChan, topicURL, subcriptionID, options...); err != nil { + if err := performSubscribeRequest(ctx, httpClient, msgChan, topicURL, subcriptionID, options...); err != nil { log.Warn("%s Connection failed: %s", util.ShortTopicURL(topicURL), err.Error()) } select { @@ -241,7 +279,7 @@ func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicUR } } -func performSubscribeRequest(ctx context.Context, msgChan chan *Message, topicURL string, subscriptionID string, options ...SubscribeOption) error { +func performSubscribeRequest(ctx context.Context, httpClient *http.Client, msgChan chan *Message, topicURL string, subscriptionID string, options ...SubscribeOption) error { streamURL := fmt.Sprintf("%s/json", topicURL) log.Debug("%s Listening to %s", util.ShortTopicURL(topicURL), streamURL) req, err := http.NewRequestWithContext(ctx, http.MethodGet, streamURL, nil) @@ -253,7 +291,7 @@ func performSubscribeRequest(ctx context.Context, msgChan chan *Message, topicUR return err } } - resp, err := http.DefaultClient.Do(req) + resp, err := httpClient.Do(req) if err != nil { return err } diff --git a/client/client_test.go b/client/client_test.go index a6784ff8a..15287205e 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -19,7 +19,7 @@ func TestMain(m *testing.M) { func TestClient_Publish_Subscribe(t *testing.T) { s, port := test.StartServer(t) defer test.StopServer(t, s, port) - c := client.New(newTestConfig(port)) + c, _ := client.New(newTestConfig(port)) subscriptionID, _ := c.Subscribe("mytopic") time.Sleep(time.Second) @@ -74,7 +74,7 @@ func TestClient_Publish_Subscribe(t *testing.T) { func TestClient_Publish_Poll(t *testing.T) { s, port := test.StartServer(t) defer test.StopServer(t, s, port) - c := client.New(newTestConfig(port)) + c, _ := client.New(newTestConfig(port)) msg, err := c.Publish("mytopic", "some message", client.WithNoFirebase(), client.WithTagsList("tag1,tag2")) require.Nil(t, err) diff --git a/client/config.go b/client/config.go index 444460d63..861e68438 100644 --- a/client/config.go +++ b/client/config.go @@ -22,6 +22,11 @@ type Config struct { DefaultToken string `yaml:"default-token"` DefaultCommand string `yaml:"default-command"` Subscribe []Subscribe `yaml:"subscribe"` + + // CertFile is the path to a PKCS#12 (.p12) file used for mTLS client authentication. + // CertPassword is the password for the PKCS#12 file (may be empty). + CertFile string `yaml:"cert-file"` + CertPassword string `yaml:"cert-password"` } // Subscribe is the struct for a Subscription within Config @@ -43,6 +48,8 @@ func NewConfig() *Config { DefaultToken: "", DefaultCommand: "", Subscribe: nil, + CertFile: "", + CertPassword: "", } } diff --git a/cmd/config_loader.go b/cmd/config_loader.go index e6180bed7..61dbfdebc 100644 --- a/cmd/config_loader.go +++ b/cmd/config_loader.go @@ -5,8 +5,10 @@ import ( "github.com/urfave/cli/v2" "github.com/urfave/cli/v2/altsrc" "gopkg.in/yaml.v2" + "heckel.io/ntfy/v2/client" "heckel.io/ntfy/v2/util" "os" + "strings" ) // initConfigFileInputSourceFunc is like altsrc.InitInputSourceWithContext and altsrc.NewYamlSourceFromFlagFunc, but checks @@ -58,3 +60,21 @@ func newYamlSourceFromFile(file string, flags []cli.Flag) (altsrc.InputSourceCon } return altsrc.NewMapInputSource(file, rawConfig), nil } + +// parseCertFlag reads the --cert flag (format: file.p12 or file.p12:password) and +// populates conf.CertFile and conf.CertPassword. It is a no-op when the flag is not set. +func parseCertFlag(c *cli.Context, conf *client.Config) error { + certArg := c.String("cert") + if certArg == "" { + return nil + } + // Split on the LAST colon to allow Windows paths like C:\path\to\file.p12:pass + if idx := strings.LastIndex(certArg, ":"); idx != -1 { + conf.CertFile = certArg[:idx] + conf.CertPassword = certArg[idx+1:] + } else { + conf.CertFile = certArg + conf.CertPassword = "" + } + return nil +} diff --git a/cmd/publish.go b/cmd/publish.go index c80c140b3..75e9aaf72 100644 --- a/cmd/publish.go +++ b/cmd/publish.go @@ -39,6 +39,7 @@ var flagsPublish = append( &cli.StringFlag{Name: "email", Aliases: []string{"mail", "e"}, EnvVars: []string{"NTFY_EMAIL"}, Usage: "also send to e-mail address"}, &cli.StringFlag{Name: "user", Aliases: []string{"u"}, EnvVars: []string{"NTFY_USER"}, Usage: "username[:password] used to auth against the server"}, &cli.StringFlag{Name: "token", Aliases: []string{"k"}, EnvVars: []string{"NTFY_TOKEN"}, Usage: "access token used to auth against the server"}, + &cli.StringFlag{Name: "cert", EnvVars: []string{"NTFY_CERT"}, Usage: "PKCS#12 client certificate for mTLS, in the format file.p12[:password]"}, &cli.IntFlag{Name: "wait-pid", Aliases: []string{"wait_pid", "pid"}, EnvVars: []string{"NTFY_WAIT_PID"}, Usage: "wait until PID exits before publishing"}, &cli.BoolFlag{Name: "wait-cmd", Aliases: []string{"wait_cmd", "cmd", "done"}, EnvVars: []string{"NTFY_WAIT_CMD"}, Usage: "run command and wait until it finishes before publishing"}, &cli.BoolFlag{Name: "no-cache", Aliases: []string{"no_cache", "C"}, EnvVars: []string{"NTFY_NO_CACHE"}, Usage: "do not cache message server-side"}, @@ -92,6 +93,9 @@ func execPublish(c *cli.Context) error { if err != nil { return err } + if err := parseCertFlag(c, conf); err != nil { + return err + } title := c.String("title") priority := c.String("priority") tags := c.String("tags") @@ -229,7 +233,10 @@ func execPublish(c *cli.Context) error { } } } - cl := client.New(conf) + cl, err := client.New(conf) + if err != nil { + return err + } m, err := cl.PublishReader(topic, body, options...) if err != nil { return err diff --git a/cmd/serve_test.go b/cmd/serve_test.go index b89efa8ad..be7ca2501 100644 --- a/cmd/serve_test.go +++ b/cmd/serve_test.go @@ -507,7 +507,7 @@ func TestCLI_Serve_WebSocket(t *testing.T) { require.Equal(t, websocket.TextMessage, messageType) require.Equal(t, "open", toMessage(t, string(data)).Event) - c := client.New(client.NewConfig()) + c, _ := client.New(client.NewConfig()) _, err = c.Publish(fmt.Sprintf("http://127.0.0.1:%d/mytopic", port), "my message") require.Nil(t, err) diff --git a/cmd/subscribe.go b/cmd/subscribe.go index 844509272..d548a1736 100644 --- a/cmd/subscribe.go +++ b/cmd/subscribe.go @@ -24,6 +24,7 @@ var flagsSubscribe = append( &cli.StringFlag{Name: "since", Aliases: []string{"s"}, Usage: "return events since `SINCE` (Unix timestamp, or all)"}, &cli.StringFlag{Name: "user", Aliases: []string{"u"}, EnvVars: []string{"NTFY_USER"}, Usage: "username[:password] used to auth against the server"}, &cli.StringFlag{Name: "token", Aliases: []string{"k"}, EnvVars: []string{"NTFY_TOKEN"}, Usage: "access token used to auth against the server"}, + &cli.StringFlag{Name: "cert", EnvVars: []string{"NTFY_CERT"}, Usage: "PKCS#12 client certificate for mTLS, in the format file.p12[:password]"}, &cli.BoolFlag{Name: "from-config", Aliases: []string{"from_config", "C"}, Usage: "read subscriptions from config file (service mode)"}, &cli.BoolFlag{Name: "poll", Aliases: []string{"p"}, Usage: "return events and exit, do not listen for new events"}, &cli.BoolFlag{Name: "scheduled", Aliases: []string{"sched", "S"}, Usage: "also return scheduled/delayed events"}, @@ -88,7 +89,13 @@ func execSubscribe(c *cli.Context) error { if err != nil { return err } - cl := client.New(conf) + if err := parseCertFlag(c, conf); err != nil { + return err + } + cl, err := client.New(conf) + if err != nil { + return err + } since := c.String("since") user := c.String("user") token := c.String("token") diff --git a/go.mod b/go.mod index 992aced83..6c06d6cc4 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/stripe/stripe-go/v74 v74.30.0 golang.org/x/sys v0.40.0 golang.org/x/text v0.33.0 + software.sslmate.com/src/go-pkcs12 v0.7.0 ) require ( diff --git a/go.sum b/go.sum index c42210b40..c8a56e2c8 100644 --- a/go.sum +++ b/go.sum @@ -287,3 +287,5 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +software.sslmate.com/src/go-pkcs12 v0.7.0 h1:Db8W44cB54TWD7stUFFSWxdfpdn6fZVcDl0w3R4RVM0= +software.sslmate.com/src/go-pkcs12 v0.7.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=