Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 46 additions & 8 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@ package client
import (
"bufio"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/util"
"io"
"net/http"
"os"
"regexp"
"strings"
"sync"
"time"

"software.sslmate.com/src/go-pkcs12"
)

const (
Expand All @@ -34,6 +38,7 @@ var (
type Client struct {
Messages chan *Message
config *Config
http *http.Client
subscriptions map[string]*subscription
mu sync.Mutex
}
Expand Down Expand Up @@ -75,12 +80,45 @@ type subscription struct {
}

// New creates a new Client using a given Config
func New(config *Config) *Client {
func New(config *Config) (*Client, error) {
httpClient, err := newHTTPClient(config)
if err != nil {
return nil, err
}
return &Client{
Messages: make(chan *Message, 50), // Allow reading a few messages
config: config,
http: httpClient,
subscriptions: make(map[string]*subscription),
}, nil
}

// newHTTPClient creates an HTTP client, optionally configured with a PKCS#12 client certificate
// for mTLS when config.CertFile is set.
func newHTTPClient(config *Config) (*http.Client, error) {
if config.CertFile == "" {
return &http.Client{}, nil
}
p12Data, err := os.ReadFile(config.CertFile)
if err != nil {
return nil, fmt.Errorf("reading cert file %s: %w", config.CertFile, err)
}
privateKey, cert, caCerts, err := pkcs12.DecodeChain(p12Data, config.CertPassword)
if err != nil {
return nil, fmt.Errorf("decoding cert file %s: %w", config.CertFile, err)
}
tlsCert := tls.Certificate{
Certificate: [][]byte{cert.Raw},
PrivateKey: privateKey,
Leaf: cert,
}
for _, ca := range caCerts {
tlsCert.Certificate = append(tlsCert.Certificate, ca.Raw)
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{tlsCert},
}
return &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}}, nil
}

// Publish sends a message to a specific topic, optionally using options.
Expand Down Expand Up @@ -112,7 +150,7 @@ func (c *Client) PublishReader(topic string, body io.Reader, options ...PublishO
}
}
log.Debug("%s Publishing message with headers %s", util.ShortTopicURL(topicURL), req.Header)
resp, err := http.DefaultClient.Do(req)
resp, err := c.http.Do(req)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -152,7 +190,7 @@ func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, err
log.Debug("%s Polling from topic", util.ShortTopicURL(topicURL))
options = append(options, WithPoll())
go func() {
err := performSubscribeRequest(ctx, msgChan, topicURL, "", options...)
err := performSubscribeRequest(ctx, c.http, msgChan, topicURL, "", options...)
close(msgChan)
errChan <- err
}()
Expand Down Expand Up @@ -196,7 +234,7 @@ func (c *Client) Subscribe(topic string, options ...SubscribeOption) (string, er
topicURL: topicURL,
cancel: cancel,
}
go handleSubscribeConnLoop(ctx, c.Messages, topicURL, subscriptionID, options...)
go handleSubscribeConnLoop(ctx, c.http, c.Messages, topicURL, subscriptionID, options...)
return subscriptionID, nil
}

Expand Down Expand Up @@ -225,11 +263,11 @@ func (c *Client) expandTopicURL(topic string) (string, error) {
return fmt.Sprintf("%s/%s", c.config.DefaultHost, topic), nil
}

func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicURL, subcriptionID string, options ...SubscribeOption) {
func handleSubscribeConnLoop(ctx context.Context, httpClient *http.Client, msgChan chan *Message, topicURL, subcriptionID string, options ...SubscribeOption) {
for {
// TODO The retry logic is crude and may lose messages. It should record the last message like the
// Android client, use since=, and do incremental backoff too
if err := performSubscribeRequest(ctx, msgChan, topicURL, subcriptionID, options...); err != nil {
if err := performSubscribeRequest(ctx, httpClient, msgChan, topicURL, subcriptionID, options...); err != nil {
log.Warn("%s Connection failed: %s", util.ShortTopicURL(topicURL), err.Error())
}
select {
Expand All @@ -241,7 +279,7 @@ func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicUR
}
}

func performSubscribeRequest(ctx context.Context, msgChan chan *Message, topicURL string, subscriptionID string, options ...SubscribeOption) error {
func performSubscribeRequest(ctx context.Context, httpClient *http.Client, msgChan chan *Message, topicURL string, subscriptionID string, options ...SubscribeOption) error {
streamURL := fmt.Sprintf("%s/json", topicURL)
log.Debug("%s Listening to %s", util.ShortTopicURL(topicURL), streamURL)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, streamURL, nil)
Expand All @@ -253,7 +291,7 @@ func performSubscribeRequest(ctx context.Context, msgChan chan *Message, topicUR
return err
}
}
resp, err := http.DefaultClient.Do(req)
resp, err := httpClient.Do(req)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func TestMain(m *testing.M) {
func TestClient_Publish_Subscribe(t *testing.T) {
s, port := test.StartServer(t)
defer test.StopServer(t, s, port)
c := client.New(newTestConfig(port))
c, _ := client.New(newTestConfig(port))

subscriptionID, _ := c.Subscribe("mytopic")
time.Sleep(time.Second)
Expand Down Expand Up @@ -74,7 +74,7 @@ func TestClient_Publish_Subscribe(t *testing.T) {
func TestClient_Publish_Poll(t *testing.T) {
s, port := test.StartServer(t)
defer test.StopServer(t, s, port)
c := client.New(newTestConfig(port))
c, _ := client.New(newTestConfig(port))

msg, err := c.Publish("mytopic", "some message", client.WithNoFirebase(), client.WithTagsList("tag1,tag2"))
require.Nil(t, err)
Expand Down
7 changes: 7 additions & 0 deletions client/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ type Config struct {
DefaultToken string `yaml:"default-token"`
DefaultCommand string `yaml:"default-command"`
Subscribe []Subscribe `yaml:"subscribe"`

// CertFile is the path to a PKCS#12 (.p12) file used for mTLS client authentication.
// CertPassword is the password for the PKCS#12 file (may be empty).
CertFile string `yaml:"cert-file"`
CertPassword string `yaml:"cert-password"`
}

// Subscribe is the struct for a Subscription within Config
Expand All @@ -43,6 +48,8 @@ func NewConfig() *Config {
DefaultToken: "",
DefaultCommand: "",
Subscribe: nil,
CertFile: "",
CertPassword: "",
}
}

Expand Down
20 changes: 20 additions & 0 deletions cmd/config_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc"
"gopkg.in/yaml.v2"
"heckel.io/ntfy/v2/client"
"heckel.io/ntfy/v2/util"
"os"
"strings"
)

// initConfigFileInputSourceFunc is like altsrc.InitInputSourceWithContext and altsrc.NewYamlSourceFromFlagFunc, but checks
Expand Down Expand Up @@ -58,3 +60,21 @@ func newYamlSourceFromFile(file string, flags []cli.Flag) (altsrc.InputSourceCon
}
return altsrc.NewMapInputSource(file, rawConfig), nil
}

// parseCertFlag reads the --cert flag (format: file.p12 or file.p12:password) and
// populates conf.CertFile and conf.CertPassword. It is a no-op when the flag is not set.
func parseCertFlag(c *cli.Context, conf *client.Config) error {
certArg := c.String("cert")
if certArg == "" {
return nil
}
// Split on the LAST colon to allow Windows paths like C:\path\to\file.p12:pass
if idx := strings.LastIndex(certArg, ":"); idx != -1 {
conf.CertFile = certArg[:idx]
conf.CertPassword = certArg[idx+1:]
} else {
conf.CertFile = certArg
conf.CertPassword = ""
}
return nil
}
9 changes: 8 additions & 1 deletion cmd/publish.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ var flagsPublish = append(
&cli.StringFlag{Name: "email", Aliases: []string{"mail", "e"}, EnvVars: []string{"NTFY_EMAIL"}, Usage: "also send to e-mail address"},
&cli.StringFlag{Name: "user", Aliases: []string{"u"}, EnvVars: []string{"NTFY_USER"}, Usage: "username[:password] used to auth against the server"},
&cli.StringFlag{Name: "token", Aliases: []string{"k"}, EnvVars: []string{"NTFY_TOKEN"}, Usage: "access token used to auth against the server"},
&cli.StringFlag{Name: "cert", EnvVars: []string{"NTFY_CERT"}, Usage: "PKCS#12 client certificate for mTLS, in the format file.p12[:password]"},
&cli.IntFlag{Name: "wait-pid", Aliases: []string{"wait_pid", "pid"}, EnvVars: []string{"NTFY_WAIT_PID"}, Usage: "wait until PID exits before publishing"},
&cli.BoolFlag{Name: "wait-cmd", Aliases: []string{"wait_cmd", "cmd", "done"}, EnvVars: []string{"NTFY_WAIT_CMD"}, Usage: "run command and wait until it finishes before publishing"},
&cli.BoolFlag{Name: "no-cache", Aliases: []string{"no_cache", "C"}, EnvVars: []string{"NTFY_NO_CACHE"}, Usage: "do not cache message server-side"},
Expand Down Expand Up @@ -92,6 +93,9 @@ func execPublish(c *cli.Context) error {
if err != nil {
return err
}
if err := parseCertFlag(c, conf); err != nil {
return err
}
title := c.String("title")
priority := c.String("priority")
tags := c.String("tags")
Expand Down Expand Up @@ -229,7 +233,10 @@ func execPublish(c *cli.Context) error {
}
}
}
cl := client.New(conf)
cl, err := client.New(conf)
if err != nil {
return err
}
m, err := cl.PublishReader(topic, body, options...)
if err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion cmd/serve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ func TestCLI_Serve_WebSocket(t *testing.T) {
require.Equal(t, websocket.TextMessage, messageType)
require.Equal(t, "open", toMessage(t, string(data)).Event)

c := client.New(client.NewConfig())
c, _ := client.New(client.NewConfig())
_, err = c.Publish(fmt.Sprintf("http://127.0.0.1:%d/mytopic", port), "my message")
require.Nil(t, err)

Expand Down
9 changes: 8 additions & 1 deletion cmd/subscribe.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ var flagsSubscribe = append(
&cli.StringFlag{Name: "since", Aliases: []string{"s"}, Usage: "return events since `SINCE` (Unix timestamp, or all)"},
&cli.StringFlag{Name: "user", Aliases: []string{"u"}, EnvVars: []string{"NTFY_USER"}, Usage: "username[:password] used to auth against the server"},
&cli.StringFlag{Name: "token", Aliases: []string{"k"}, EnvVars: []string{"NTFY_TOKEN"}, Usage: "access token used to auth against the server"},
&cli.StringFlag{Name: "cert", EnvVars: []string{"NTFY_CERT"}, Usage: "PKCS#12 client certificate for mTLS, in the format file.p12[:password]"},
&cli.BoolFlag{Name: "from-config", Aliases: []string{"from_config", "C"}, Usage: "read subscriptions from config file (service mode)"},
&cli.BoolFlag{Name: "poll", Aliases: []string{"p"}, Usage: "return events and exit, do not listen for new events"},
&cli.BoolFlag{Name: "scheduled", Aliases: []string{"sched", "S"}, Usage: "also return scheduled/delayed events"},
Expand Down Expand Up @@ -88,7 +89,13 @@ func execSubscribe(c *cli.Context) error {
if err != nil {
return err
}
cl := client.New(conf)
if err := parseCertFlag(c, conf); err != nil {
return err
}
cl, err := client.New(conf)
if err != nil {
return err
}
since := c.String("since")
user := c.String("user")
token := c.String("token")
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ require (
github.com/stripe/stripe-go/v74 v74.30.0
golang.org/x/sys v0.40.0
golang.org/x/text v0.33.0
software.sslmate.com/src/go-pkcs12 v0.7.0
)

require (
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,5 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
software.sslmate.com/src/go-pkcs12 v0.7.0 h1:Db8W44cB54TWD7stUFFSWxdfpdn6fZVcDl0w3R4RVM0=
software.sslmate.com/src/go-pkcs12 v0.7.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=