diff --git a/.github/workflows/runtests.yaml b/.github/workflows/runtests.yaml index bcf5a0e..af78f13 100644 --- a/.github/workflows/runtests.yaml +++ b/.github/workflows/runtests.yaml @@ -12,7 +12,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v4 with: - go-version: '1.20' + go-version: '1.21' - name: Build run: go build ./... diff --git a/cluster/agent.go b/cluster/agent.go index 01fc054..3df7a7f 100644 --- a/cluster/agent.go +++ b/cluster/agent.go @@ -7,11 +7,16 @@ package cluster import ( "bytes" "context" + "net" + "path" + "path/filepath" + "strconv" + "github.com/panjf2000/ants/v2" "github.com/wind-c/comqtt/v2/cluster/discovery" "github.com/wind-c/comqtt/v2/cluster/discovery/mlist" "github.com/wind-c/comqtt/v2/cluster/discovery/serf" - "github.com/wind-c/comqtt/v2/cluster/log/zero" + "github.com/wind-c/comqtt/v2/cluster/log" "github.com/wind-c/comqtt/v2/cluster/message" "github.com/wind-c/comqtt/v2/cluster/raft" "github.com/wind-c/comqtt/v2/cluster/raft/etcd" @@ -21,10 +26,6 @@ import ( "github.com/wind-c/comqtt/v2/config" "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/packets" - "net" - "path" - "path/filepath" - "strconv" ) const ( @@ -119,7 +120,7 @@ func (a *Agent) Start() (err error) { if err := a.grpcService.StartRpcServer(); err != nil { return err } - zero.Info().Str("addr", net.JoinHostPort(a.Config.BindAddr, strconv.Itoa(a.Config.GrpcPort))).Msg("grpc listen at") + log.Info("grpc listen at", "addr", net.JoinHostPort(a.Config.BindAddr, strconv.Itoa(a.Config.GrpcPort))) } // init goroutine pool @@ -176,17 +177,16 @@ func (a *Agent) Stop() { } // stop raft - zero.Info().Msg("stopping raft...") + log.Info("stopping raft...") a.raftPeer.Stop() - zero.Info().Msg("raft stopped") + log.Info("raft stopped") // stop node - zero.Info().Msg("stopping node...") + log.Info("stopping node...") a.membership.Stop() a.grpcService.StopRpcServer() - zero.Info().Msg("grpc server stopped") - zero.Info().Msg("node stopped") - zero.Close() + log.Info("grpc server stopped") + log.Info("node stopped") } func (a *Agent) BindMqttServer(server *mqtt.Server) { @@ -264,7 +264,7 @@ func (a *Agent) raftApplyListener() { } else { continue } - zero.Info().Str("from", msg.NodeID).Str("filter", filter).Uint8("type", msg.Type).Msg("apply listening") + log.Info("apply listening", "from", msg.NodeID, "filter", filter, "type", msg.Type) case <-a.ctx.Done(): return } @@ -306,10 +306,10 @@ func (a *Agent) getPeersFile() string { func (a *Agent) genNodesFile() { if err := discovery.GenNodesFile(a.getNodesFile(), a.membership.Members()); err != nil { - zero.Error().Err(err).Msg("gen nodes file") + log.Error("gen nodes file", "error", err) } if err := a.raftPeer.GenPeersFile(a.getPeersFile()); err != nil { - zero.Error().Err(err).Msg("gen peers file") + log.Error("gen peers file", "error", err) } } @@ -461,39 +461,27 @@ func (a *Agent) processOutboundPacket(pk *packets.Packet) { } func OnJoinLog(nodeId, addr, prompt string, err error) { - logEvent := zero.Info() - if err != nil { - logEvent.Err(err) - } - logEvent.Str("node", nodeId).Str("addr", addr).Msg(prompt) + log.Info(prompt, "error", err, "addr", addr) } func OnApplyLog(leaderId, nodeId string, tp byte, filter []byte, prompt string, err error) { - logEvent := zero.Info() - if err != nil { - logEvent.Err(err) - } - logEvent.Str("leader", leaderId).Str("from", nodeId).Uint8("type", tp).Bytes("filter", filter).Msg(prompt) + log.Info(prompt, "error", err, "leader", leaderId, "from", nodeId, "type", tp, "filter", filter) } func OnPublishPacketLog(direction byte, nodeId, cid, topic string, pid uint16) { - logEvent := zero.Info() if direction == DirectionInbound { - logEvent.Str("d", "inbound").Str("from", nodeId) + log.Info("publish message", "d", "inbound", "from", nodeId, "cid", cid, "pid", pid, "topic", topic) } else { - logEvent.Str("d", "outbound").Str("to", nodeId) + log.Info("publish message", "d", "outbound", "to", nodeId, "cid", cid, "pid", pid, "topic", topic) } - logEvent.Str("cid", cid).Uint16("pid", pid).Str("topic", topic).Msg("publish message") } func OnConnectPacketLog(direction byte, node, clientId string) { - logEvent := zero.Info() if direction == DirectionInbound { - logEvent.Str("d", "inbound").Str("from", node) + log.Info("connection notification", "d", "inbound", "from", node, "cid", clientId) } else { - logEvent.Str("d", "outbound").Str("to", node) + log.Info("connection notification", "d", "outbound", "to", node, "cid", clientId) } - logEvent.Str("cid", clientId).Msg("connection notification") } func (a *Agent) Join(nodeName, addr string) error { @@ -513,15 +501,15 @@ func (a *Agent) Leave() error { func (a *Agent) AddRaftPeer(id, addr string) { a.raftPeer.Join(id, addr) - zero.Info().Str("nid", id).Str("addr", addr).Msg("add peer") + log.Info("add peer", "nid", id, "addr", addr) } func (a *Agent) RemoveRaftPeer(id string) { a.raftPeer.Leave(id) - zero.Info().Str("nid", id).Msg("remove peer") + log.Info("remove peer", "nid", id) } func (a *Agent) GetValue(key string) []string { - zero.Info().Str("key", key).Msg("get value") + log.Info("get value", "key", key) return a.raftPeer.Lookup(key) } diff --git a/cluster/discovery/mlist/delegate.go b/cluster/discovery/mlist/delegate.go index 675c7dc..74b1c9e 100644 --- a/cluster/discovery/mlist/delegate.go +++ b/cluster/discovery/mlist/delegate.go @@ -6,11 +6,12 @@ package mlist import ( "encoding/json" - "github.com/hashicorp/memberlist" - "github.com/wind-c/comqtt/v2/cluster/log/zero" - mqtt "github.com/wind-c/comqtt/v2/mqtt" "sync" "time" + + "github.com/hashicorp/memberlist" + "github.com/wind-c/comqtt/v2/cluster/log" + mqtt "github.com/wind-c/comqtt/v2/mqtt" ) // Maximum number of messages to be held in the queue. @@ -126,7 +127,7 @@ func (d *Delegate) handleQueueDepth() { case <-time.After(15 * time.Minute): n := d.Broadcasts.NumQueued() if n > maxQueueSize { - zero.Info().Int("current", n).Int("limit", maxQueueSize).Msg("delete messages") + log.Info("delete messages", "current", n, "limit", maxQueueSize) d.Broadcasts.Prune(maxQueueSize) } } diff --git a/cluster/discovery/mlist/events.go b/cluster/discovery/mlist/events.go index ec484af..94b1d1c 100644 --- a/cluster/discovery/mlist/events.go +++ b/cluster/discovery/mlist/events.go @@ -7,7 +7,7 @@ package mlist import ( "github.com/hashicorp/memberlist" "github.com/wind-c/comqtt/v2/cluster/discovery" - "github.com/wind-c/comqtt/v2/cluster/log/zero" + "github.com/wind-c/comqtt/v2/cluster/log" ) //type Event struct { @@ -50,5 +50,5 @@ func (n *NodeEvents) NotifyUpdate(node *memberlist.Node) { } func onLog(node *memberlist.Node, prompt string) { - zero.Info().Str("node", node.Name).Str("addr", node.Addr.String()).Msg(prompt) + log.Info(prompt, "node", node.Name, "addr", node.Addr.String()) } diff --git a/cluster/discovery/mlist/membership.go b/cluster/discovery/mlist/membership.go index 257d811..358f528 100644 --- a/cluster/discovery/mlist/membership.go +++ b/cluster/discovery/mlist/membership.go @@ -5,13 +5,14 @@ package mlist import ( + "net" + "time" + "github.com/hashicorp/memberlist" mb "github.com/wind-c/comqtt/v2/cluster/discovery" - "github.com/wind-c/comqtt/v2/cluster/log/zero" + "github.com/wind-c/comqtt/v2/cluster/log" "github.com/wind-c/comqtt/v2/config" "github.com/wind-c/comqtt/v2/mqtt" - "net" - "time" ) type Membership struct { @@ -25,7 +26,7 @@ type Membership struct { func wrapOptions(conf *config.Cluster) *memberlist.Config { opts := make([]Option, 3) - opts[0] = WithLogOutput(zero.Logger(), LogLevelInfo) //Used to filter memberlist logs + opts[0] = WithLogOutput(log.Writer(), LogLevelInfo) //Used to filter memberlist logs opts[1] = WithBindPort(conf.BindPort) opts[2] = WithHandoffQueueDepth(conf.QueueDepth) if conf.NodeName != "" { @@ -62,7 +63,7 @@ func (m *Membership) Setup() error { return err } } - zero.Info().Str("addr", m.LocalAddr()).Int("port", m.config.BindPort).Msg("local member") + log.Info("local member", "addr", m.LocalAddr(), "port", m.config.BindPort) return nil } @@ -153,7 +154,7 @@ func (m *Membership) SendToOthers(msg []byte) { continue // skip self } if err := m.send(node, msg); err != nil { - zero.Error().Err(err).Str("from", m.config.NodeName).Str("to", node.Name).Msg("send to others") + log.Error("send to others", "error", err, "from", m.config.NodeName, "to", node.Name) } } } @@ -163,7 +164,7 @@ func (m *Membership) SendToNode(nodeName string, msg []byte) error { for _, node := range m.aliveMembers() { if node.Name == nodeName { if err := m.send(node, msg); err != nil { - zero.Error().Err(err).Str("from", m.config.NodeName).Str("to", nodeName).Msg("send to node") + log.Error("send to others", "error", err, "from", m.config.NodeName, "to", nodeName) return err } } diff --git a/cluster/discovery/serf/membership.go b/cluster/discovery/serf/membership.go index 440d359..a2e9cd7 100644 --- a/cluster/discovery/serf/membership.go +++ b/cluster/discovery/serf/membership.go @@ -5,16 +5,15 @@ package serf import ( + "strconv" + "github.com/hashicorp/logutils" "github.com/hashicorp/memberlist" "github.com/hashicorp/serf/serf" mb "github.com/wind-c/comqtt/v2/cluster/discovery" - "github.com/wind-c/comqtt/v2/cluster/log/zero" + "github.com/wind-c/comqtt/v2/cluster/log" "github.com/wind-c/comqtt/v2/config" "github.com/wind-c/comqtt/v2/mqtt" - "io" - "os" - "strconv" ) const ( @@ -42,16 +41,11 @@ func wrapOptions(conf *config.Cluster, ech chan serf.Event) *serf.Config { if conf.QueueDepth != 0 { config.MaxQueueDepth = conf.QueueDepth } - var logger io.Writer - if zero.Logger() != nil { - logger = zero.Logger() - } else { - logger = os.Stderr - } + filter := &logutils.LevelFilter{ Levels: []logutils.LogLevel{LogLevelDebug, LogLevelWarn, LogLevelError, LogLevelInfo}, MinLevel: logutils.LogLevel(LogLevelError), - Writer: logger, + Writer: log.Writer(), } config.MemberlistConfig.LogOutput = filter config.LogOutput = filter @@ -93,8 +87,7 @@ func (m *Membership) Setup() (err error) { } } } - zero.Info().Str("addr", m.LocalAddr()).Int("port", m.config.BindPort).Msg("local member") - + log.Info("local member", "addr", m.LocalAddr(), "port", m.config.BindPort) return } @@ -125,11 +118,11 @@ func (m *Membership) Stat() map[string]int64 { func (m *Membership) Stop() { err := m.serf.Leave() if err != nil { - zero.Error().Err(err).Msg("serf leave") + log.Error("serf leave", "error", err) } err = m.serf.Shutdown() if err != nil { - zero.Error().Err(err).Msg("serf shutdown") + log.Error("serf shutdown", "error", err) } // this shuts down the event loop, note that this can't be called multiple times // if we need to do so, we could use a bool, sync.Once or recover from the panic @@ -189,7 +182,6 @@ func (m *Membership) eventLoop() { continue } m.msgCh <- ue.Payload - //zero.Info().Str("name", ue.Name).Msg("serf message") case serf.EventQuery: q := e.(*serf.Query) if q.SourceNode() == m.config.NodeName { @@ -217,7 +209,7 @@ func (m *Membership) SendToNode(nodeName string, msg []byte) error { qp := m.serf.DefaultQueryParams() qp.FilterNodes = m.otherNames(nodeName) if _, err := m.serf.Query(m.config.NodeName, msg, qp); err != nil { - zero.Error().Err(err).Str("from", m.config.NodeName).Str("to", nodeName).Msg("send to node") + log.Error("send to node", "error", err, "from", m.config.NodeName, "to", nodeName) return err } @@ -229,7 +221,7 @@ func (m *Membership) Broadcast(msg []byte) { } func onLog(node *serf.Member, prompt string) { - zero.Info().Str("node", node.Name).Str("addr", node.Addr.String()).Msg(prompt) + log.Info(prompt, "node", node.Name, "addr", node.Addr.String()) } func (m *Membership) isLocal(member serf.Member) bool { diff --git a/cluster/log/default_logger.go b/cluster/log/default_logger.go new file mode 100644 index 0000000..ea077b3 --- /dev/null +++ b/cluster/log/default_logger.go @@ -0,0 +1,76 @@ +package log + +import ( + "context" + "io" + "log/slog" + "os" + "sync" +) + +var defaultLogger *Logger // Singleton instance of the logger. +var mu sync.Mutex // Mutex for ensuring thread safety when initializing the logger. + +// Init initializes the logger with the provided options. +func Init(opt *Options) { + mu.Lock() + defer mu.Unlock() + defaultLogger = New(opt) +} + +// Default returns the default logger instance. +func Default() *slog.Logger { + return defaultLogger.Logger +} + +// Writer returns the writer associated with the logger. +func Writer() io.Writer { + if defaultLogger != nil { + return defaultLogger.writer + } + return nil +} + +// Info logs an informational message with optional arguments. +func Info(msg string, args ...interface{}) { + if defaultLogger != nil { + defaultLogger.Info(msg, args...) + } +} + +// Warn logs a warning message with optional arguments. +func Warn(msg string, args ...interface{}) { + if defaultLogger != nil { + defaultLogger.Warn(msg, args...) + } +} + +// Error logs an error message with optional arguments. +func Error(msg string, args ...interface{}) { + if defaultLogger != nil { + defaultLogger.Error(msg, args...) + } +} + +// Fatal logs a fatal error message with optional arguments and exits the program. +func Fatal(msg string, args ...interface{}) { + if defaultLogger != nil { + defaultLogger.Error(msg, args...) + os.Exit(1) + } +} + +// Debug logs a debug message with optional arguments. +func Debug(msg string, args ...interface{}) { + if defaultLogger != nil { + defaultLogger.Debug(msg, args...) + } +} + +// Log logs a message with the specified log level and optional arguments. +func Log(level slog.Level, msg string, args ...interface{}) { + if defaultLogger != nil { + slevel := slog.Level(level) + defaultLogger.Log(context.TODO(), slevel, msg, args...) + } +} diff --git a/cluster/log/logger.go b/cluster/log/logger.go new file mode 100644 index 0000000..2a93df1 --- /dev/null +++ b/cluster/log/logger.go @@ -0,0 +1,159 @@ +package log + +import ( + "context" + "io" + "log/slog" + "os" + + "gopkg.in/natefinch/lumberjack.v2" +) + +// Constants for log formats +const ( + Text = iota // Log format is TEXT. + Json // Log format is JSON. +) + +// Format represents the log format type. +type Format int + +// Options defines configuration options for the logger. +type Options struct { + + // Indicates whether logging is enabled. + Disable bool `json:"disable" yaml:"disable"` + + // Log format, currently supports Text: 0 and JSON: 1, with Text as the default. + Format Format `json:"format" yaml:"format"` + + // Log level, with supported values LevelDebug: 4, LevelInfo: 0, LevelWarn: 4, and LevelError: 8. + Level int `json:"level" yaml:"level"` + + // Filename is the file to write logs to. Backup log files will be retained + // in the same directory. If empty, logs will not be written to a file. + Filename string `json:"filename" yaml:"filename"` + + // MaxSize is the maximum size in megabytes of the log file before it gets + // rotated. It defaults to 100 megabytes. + MaxSize int `json:"maxsize" yaml:"maxsize"` + + // MaxAge is the maximum number of days to retain old log files based on the + // timestamp encoded in their filename. Note that a day is defined as 24 + // hours and may not exactly correspond to calendar days due to daylight + // savings, leap seconds, etc. The default is not to remove old log files + // based on age. + MaxAge int `json:"maxage" yaml:"maxage"` + + // MaxBackups is the maximum number of old log files to retain. The default + // is to retain all old log files (though MaxAge may still cause them to get + // deleted.) + MaxBackups int `json:"maxbackups" yaml:"maxbackups"` + + // Compress determines if the rotated log files should be compressed + // using gzip. The default is not to perform compression. + Compress bool `json:"compress" yaml:"compress"` +} + +// Options defines configuration options for the logger. +func DefaultOptions() *Options { + return &Options{ + MaxSize: 100, + MaxAge: 30, + MaxBackups: 1, + Format: Text, + } +} + +// New creates a new Logger based on the provided options. +func New(opt *Options) *Logger { + if opt == nil { + opt = DefaultOptions() + } + + var writer io.Writer + writer = os.Stdout + + if len(opt.Filename) != 0 { + fileWriter := &lumberjack.Logger{ + Filename: opt.Filename, + MaxSize: opt.MaxSize, + MaxBackups: opt.MaxBackups, + MaxAge: opt.MaxAge, + Compress: opt.Compress, + } + writer = io.MultiWriter(os.Stdout, fileWriter) + } + + return &Logger{ + writer: writer, + Logger: slog.New(NewHandler(opt, writer)), + opt: opt, + } +} + +// Logger is a wrapper for slog.Logger. +type Logger struct { + *slog.Logger + opt *Options + writer io.Writer +} + +// Handler is a wrapper for slog.Handler. +type Handler struct { + opt *Options + internal slog.Handler +} + +// NewHandler creates a new handler based on the provided options and writer. +func NewHandler(opt *Options, writer io.Writer) *Handler { + var handler slog.Handler + + switch opt.Format { + case Text: + handler = slog.NewTextHandler(writer, &slog.HandlerOptions{ + Level: slog.Level(opt.Level), + }) + + case Json: + handler = slog.NewJSONHandler(writer, &slog.HandlerOptions{ + Level: slog.Level(opt.Level), + }) + + default: + handler = slog.NewTextHandler(writer, &slog.HandlerOptions{ + Level: slog.Level(opt.Level), + }) + } + + return &Handler{ + opt: opt, + internal: handler, + } +} + +// Enabled reports whether the handler handles records at the given level. +// The handler ignores records whose level is lower. +func (h *Handler) Enabled(ctx context.Context, level slog.Level) bool { + return !h.opt.Disable && h.internal.Enabled(ctx, level) +} + +// Handle handles the Record. +// It will only be called when Enabled returns true. +func (h *Handler) Handle(ctx context.Context, record slog.Record) error { + return h.internal.Handle(ctx, record) +} + +// WithAttrs returns a new Handler whose attributes consist of +// both the receiver's attributes and the arguments. +func (h *Handler) WithAttrs(attrs []slog.Attr) slog.Handler { + return h.internal.WithAttrs(attrs) +} + +// WithGroup returns a new Handler with the given group appended to +// the receiver's existing groups. +// The keys of all subsequent attributes, whether added by With or in a +// Record, should be qualified by the sequence of group names. +func (h *Handler) WithGroup(name string) slog.Handler { + return h.internal.WithGroup(name) +} diff --git a/cluster/log/zap/zaplog.go b/cluster/log/zap/zaplog.go deleted file mode 100644 index 1f66c9b..0000000 --- a/cluster/log/zap/zaplog.go +++ /dev/null @@ -1,249 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 wind -// SPDX-FileContributor: wind (573966@qq.com) - -package zap - -import ( - "github.com/wind-c/comqtt/v2/config" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" - "gopkg.in/natefinch/lumberjack.v2" - "os" - "time" -) - -const ( - console = iota - json -) - -var logger *zap.Logger -var cfg config.Log - -func Init(c config.Log) *zap.Logger { - cfg = c - level := zapcore.Level(cfg.Level) - if !cfg.Enable { - level = zapcore.Level(7) - } - encoder := getEncoder(cfg.Format) - - infoLevel := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { - return lvl < zapcore.WarnLevel && lvl >= level - }) - - warnLevel := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { - return lvl >= zapcore.WarnLevel && lvl >= level - }) - - infoWriter := getLogWriter(cfg.InfoFile) - errorWriter := getLogWriter(cfg.ErrorFile) - - core := zapcore.NewTee( - zapcore.NewCore(encoder, infoWriter, infoLevel), - zapcore.NewCore(encoder, errorWriter, warnLevel), - ) - - ops := make([]zap.Option, 0) - ops = append(ops, zap.AddCaller()) //zap.AddCaller() 才会显示打日志点的文件名和行数 - ops = append(ops, zap.AddCallerSkip(1)) - if cfg.Env == 0 { - ops = append(ops, zap.Development()) - ops = append(ops, zap.AddStacktrace(zapcore.WarnLevel)) //warn以上输出调用堆栈 - } - if cfg.NodeName != "" { - ops = append(ops, zap.Fields(zap.String("nn", cfg.NodeName))) - } - logger = zap.New(core, ops...) - zap.ReplaceGlobals(logger) - - return logger -} - -func getEncoder(tp int) zapcore.Encoder { - conf := zapcore.EncoderConfig{ - MessageKey: "m", - LevelKey: "l", - TimeKey: "t", - CallerKey: "f", - EncodeLevel: zapcore.CapitalLevelEncoder, - EncodeCaller: zapcore.ShortCallerEncoder, - EncodeTime: func(t time.Time, enc zapcore.PrimitiveArrayEncoder) { - enc.AppendString(t.Format("2006-01-02 15:04:05")) - }, - EncodeDuration: func(d time.Duration, enc zapcore.PrimitiveArrayEncoder) { - enc.AppendInt64(int64(d) / 1000000) - }, - } - - encoder := zapcore.NewConsoleEncoder(conf) - if tp == json { - encoder = zapcore.NewJSONEncoder(conf) - } - - return encoder -} - -func getLogWriter(filename string) zapcore.WriteSyncer { - writeSyncer := zapcore.AddSync(os.Stdout) - if filename != "" { - lumberJackLogger := &lumberjack.Logger{ - Filename: filename, // ⽇志⽂件路径 - MaxSize: 100, // 1M=1024KB=1024000byte - MaxBackups: 10, // 最多保留10个备份 - MaxAge: 30, // days - Compress: true, // 是否压缩 disabled by default - } - if cfg.MaxSize != 0 { - lumberJackLogger.MaxSize = cfg.MaxSize - } - if cfg.MaxBackups != 0 { - lumberJackLogger.MaxBackups = cfg.MaxBackups - } - if cfg.MaxAge != 0 { - lumberJackLogger.MaxAge = cfg.MaxAge - } - if !cfg.Compress { - lumberJackLogger.Compress = false - } - - writeSyncer = zapcore.AddSync(lumberJackLogger) - } - - return writeSyncer -} - -func Sync() { - zap.L().Sync() -} - -func getLevel(level string) zapcore.Level { - logLevel := zap.DebugLevel - switch level { - case "debug": - logLevel = zap.DebugLevel - case "info": - logLevel = zap.InfoLevel - case "warn": - logLevel = zap.WarnLevel - case "error": - logLevel = zap.ErrorLevel - case "panic": - logLevel = zap.PanicLevel - case "fatal": - logLevel = zap.FatalLevel - default: - logLevel = zap.InfoLevel - } - return logLevel -} - -// Field isolate the reference -type Field = zap.Field - -func Debug(msg string, fields ...Field) { - zap.L().Debug(msg, fields...) -} - -func Info(msg string, fields ...Field) { - zap.L().Info(msg, fields...) -} -func Infof(template string, args ...interface{}) { - zap.S().Infof(template, args...) -} - -func Warn(msg string, fields ...Field) { - zap.L().Warn(msg, fields...) -} -func Warnf(template string, args ...interface{}) { - zap.S().Warnf(template, args...) -} - -func Error(err error, fields ...Field) { - zap.L().Error(err.Error(), fields...) -} -func Errorf(template string, args ...interface{}) { - zap.S().Errorf(template, args...) -} - -func Panic(msg error, fields ...Field) { - zap.L().Panic(msg.Error(), fields...) -} -func Panicf(template string, args ...interface{}) { - zap.S().Panicf(template, args) -} -func DPanic(msg string, fields ...Field) { - zap.L().DPanic(msg, fields...) -} - -func Fatal(msg string, fields ...Field) { - zap.L().Fatal(msg, fields...) -} -func Fatalf(template string, args ...interface{}) { - zap.S().Fatalf(template, args...) -} - -func String(key string, val string) Field { - return zap.String(key, val) -} -func Strings(key string, val []string) Field { - return zap.Strings(key, val) -} -func ByteString(key string, val []byte) Field { - return zap.ByteString(key, val) -} -func Binary(key string, val []byte) Field { - return zap.Binary(key, val) -} -func Int(key string, val int) Field { - return zap.Int(key, val) -} -func Int8(key string, val int8) Field { - return zap.Int8(key, val) -} -func Int16(key string, val int16) Field { - return zap.Int16(key, val) -} -func Int32(key string, val int32) Field { - return zap.Int32(key, val) -} -func Int64(key string, val int64) Field { - return zap.Int64(key, val) -} -func Uint(key string, val uint) Field { - return zap.Uint(key, val) -} -func Uint8(key string, val uint8) Field { - return zap.Uint8(key, val) -} -func Uint16(key string, val uint16) Field { - return zap.Uint16(key, val) -} -func Uint32(key string, val uint32) Field { - return zap.Uint32(key, val) -} -func Uint64(key string, val uint64) Field { - return zap.Uint64(key, val) -} -func Bool(key string, val bool) Field { - return zap.Bool(key, val) -} -func Duration(key string, val time.Duration) Field { - return zap.Duration(key, val) -} -func Time(key string, val time.Time) Field { - return zap.Time(key, val) -} -func Float32(key string, val float32) Field { - return zap.Float32(key, val) -} -func Float64(key string, val float64) Field { - return zap.Float64(key, val) -} -func Uintptr(key string, val uintptr) Field { - return zap.Uintptr(key, val) -} -func Any(key string, value interface{}) Field { - return zap.Any(key, value) -} diff --git a/cluster/log/zap/zaplog_test.go b/cluster/log/zap/zaplog_test.go deleted file mode 100644 index b3bbe0e..0000000 --- a/cluster/log/zap/zaplog_test.go +++ /dev/null @@ -1,31 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 wind -// SPDX-FileContributor: wind (573966@qq.com) - -package zap - -import ( - "errors" - "github.com/wind-c/comqtt/v2/config" - "go.uber.org/zap" - "os" - "testing" -) - -func Test_setupLogger(t *testing.T) { - logger := Init(config.Log{ - Enable: true, - Env: 1, - Level: 1, - InfoFile: "./logs/co-info.log", - ErrorFile: "./logs/co-error.log", - }) - logger.Info("in main args:", zap.String("args", "os.Args")) - logger.Sugar().Infof("in main args:%v", os.Args) - logger.Error("error ", zap.Error(errors.New("test"))) - logger.Sugar().Errorf("error %v", "error") - logger.Warn("warn ", zap.String("kk", "warn 123")) - logger.Sugar().Warnf("warn %v", "warn 123") - logger.Sugar().Infof("env is %v", 123) - logger.Sugar().Infof("ip=%v, port=%v, env=%v", "127.0.0.1", 8080, "prod") -} diff --git a/cluster/log/zero/zerolog.go b/cluster/log/zero/zerolog.go deleted file mode 100644 index 4a1b8c2..0000000 --- a/cluster/log/zero/zerolog.go +++ /dev/null @@ -1,171 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 wind -// SPDX-FileContributor: wind (573966@qq.com) - -package zero - -import ( - "github.com/rs/zerolog" - "github.com/wind-c/comqtt/v2/config" - "gopkg.in/natefinch/lumberjack.v2" - "io" - "os" - "time" -) - -const ( - FormatConsole = iota - FormatJson -) - -var logger *zerolog.Logger -var cfg config.Log -var lws []*levelWriter - -func Init(c config.Log) *zerolog.Logger { - cfg = c - if !cfg.Enable { - zerolog.SetGlobalLevel(zerolog.Disabled) - } - - // init zerolog format - zerolog.TimeFieldFormat = "2006-01-02 15:04:05.000" - zerolog.DurationFieldInteger = true - zerolog.TimestampFieldName = "timestamp" - zerolog.DurationFieldUnit = time.Millisecond - zerolog.TimestampFieldName = "t" - zerolog.LevelFieldName = "l" - zerolog.MessageFieldName = "m" - writers := zerolog.MultiLevelWriter() - if cfg.Format == FormatJson { - if w := getLogWriter(cfg.InfoFile); w != nil { - lw := newLevelWriter(w, zerolog.Level(cfg.Level), zerolog.WarnLevel) - lws = append(lws, lw) - writers = zerolog.MultiLevelWriter(lw) - } - if w := getLogWriter(cfg.ErrorFile); w != nil { - lw := newLevelWriter(w, zerolog.ErrorLevel, zerolog.PanicLevel) - lws = append(lws, lw) - writers = zerolog.MultiLevelWriter(writers, lw) - } - if w := getLogWriter(cfg.ThirdpartyFile); w != nil { - lw := newLevelWriter(w, zerolog.NoLevel, zerolog.NoLevel) - lws = append(lws, lw) - writers = zerolog.MultiLevelWriter(writers, lw) - } - } else { - writers = zerolog.MultiLevelWriter(stdErrWriter()) - } - - lg := zerolog.New(writers).With().Timestamp().Logger() - if cfg.Caller { - lg = zerolog.New(writers).With().Timestamp().Caller().Logger() - } - logger = &lg - return logger -} - -func Logger() *zerolog.Logger { - return logger -} - -func Close() { - for _, w := range lws { - w.Close() - } -} - -func getLogWriter(filename string) io.Writer { - if filename != "" { - lumberJackLogger := &lumberjack.Logger{ - Filename: filename, // ⽇志⽂件路径 - MaxSize: 100, // 1M=1024KB=1024000byte - MaxBackups: 10, // 最多保留10个备份 - MaxAge: 30, // days - Compress: true, // 是否压缩 disabled by default - } - if cfg.MaxSize != 0 { - lumberJackLogger.MaxSize = cfg.MaxSize - } - if cfg.MaxBackups != 0 { - lumberJackLogger.MaxBackups = cfg.MaxBackups - } - if cfg.MaxAge != 0 { - lumberJackLogger.MaxAge = cfg.MaxAge - } - if !cfg.Compress { - lumberJackLogger.Compress = false - } - - return lumberJackLogger - } - - return nil -} - -type levelWriter struct { - writer io.Writer - minLevel zerolog.Level - maxLevel zerolog.Level -} - -func newLevelWriter(w io.Writer, minLevel zerolog.Level, maxLevel zerolog.Level) *levelWriter { - return &levelWriter{ - writer: w, - minLevel: minLevel, - maxLevel: maxLevel, - } -} - -func (lw *levelWriter) Close() error { - if c, ok := lw.writer.(io.Closer); ok { - return c.Close() - } - return nil -} - -func (lw *levelWriter) WriteLevel(level zerolog.Level, p []byte) (n int, err error) { - if level >= lw.minLevel && level <= lw.maxLevel { - return lw.Write(p) - } - return len(p), nil -} - -func (lw *levelWriter) Write(p []byte) (n int, err error) { - return lw.writer.Write(p) -} - -func stdOutWriter() zerolog.LevelWriter { - cw := zerolog.NewConsoleWriter() - cw.NoColor = false - return newLevelWriter(cw, zerolog.DebugLevel, zerolog.WarnLevel) -} - -func stdErrWriter() zerolog.LevelWriter { - cw := zerolog.NewConsoleWriter() - cw.Out = os.Stderr - cw.NoColor = false - return newLevelWriter(cw, zerolog.DebugLevel, zerolog.Disabled) -} - -func Trace() *zerolog.Event { - return logger.Trace() -} -func Debug() *zerolog.Event { - return logger.Debug() -} -func Info() *zerolog.Event { - return logger.Info() -} -func Warn() *zerolog.Event { - return logger.Warn() -} -func Error() *zerolog.Event { - return logger.Error() -} -func Fatal() *zerolog.Event { - return logger.Fatal() -} -func Panic() *zerolog.Event { - return logger.Panic() -} diff --git a/cluster/raft/etcd/kvstore.go b/cluster/raft/etcd/kvstore.go index c0c51bf..e2864dc 100644 --- a/cluster/raft/etcd/kvstore.go +++ b/cluster/raft/etcd/kvstore.go @@ -7,11 +7,11 @@ package etcd import ( "bytes" "encoding/gob" - "github.com/wind-c/comqtt/v2/cluster/log/zero" - "github.com/wind-c/comqtt/v2/cluster/message" - "github.com/wind-c/comqtt/v2/mqtt/packets" "sync" + "github.com/wind-c/comqtt/v2/cluster/log" + "github.com/wind-c/comqtt/v2/cluster/message" + "github.com/wind-c/comqtt/v2/mqtt/packets" "go.etcd.io/etcd/raft/v3/raftpb" "go.etcd.io/etcd/server/v3/etcdserver/api/snap" ) @@ -36,12 +36,12 @@ func newKVStore(snapshotter *snap.Snapshotter, commitC <-chan *commit, errorC <- } snapshot, err := s.loadSnapshot() if err != nil { - zero.Fatal().Err(err).Msg("[store] load snapshot") + log.Fatal("[store] load snapshot", "error", err) } if snapshot != nil { - zero.Info().Uint64("term", snapshot.Metadata.Term).Uint64("index", snapshot.Metadata.Index).Msg("[store] loading snapshot at term and index") + log.Info("[store] loading snapshot at term and index", "term", snapshot.Metadata.Term, "index", snapshot.Metadata.Index) if err := s.recoverFromSnapshot(snapshot.Data); err != nil { - zero.Fatal().Err(err).Msg("[store] recover snapshot") + log.Fatal("[store] recover snapshot", "error", err) } } // read commits from raft into kvStore map until error @@ -117,12 +117,12 @@ func (s *KVStore) readCommits() { // signaled to load snapshot snapshot, err := s.loadSnapshot() if err != nil { - zero.Fatal().Err(err).Msg("[store] load snapshot") + log.Fatal("[store] load snapshot", "error", err) } if snapshot != nil { - zero.Info().Uint64("term", snapshot.Metadata.Term).Uint64("index", snapshot.Metadata.Index).Msg("[store] loading snapshot at term and index") + log.Info("[store] loading snapshot at term and index", "term", snapshot.Metadata.Term, "index", snapshot.Metadata.Index) if err := s.recoverFromSnapshot(snapshot.Data); err != nil { - zero.Fatal().Err(err).Msg("[store] recover snapshot") + log.Fatal("[store] recover snapshot", "error", err) } } continue @@ -146,7 +146,7 @@ func (s *KVStore) readCommits() { close(commit.applyDoneC) } if err, ok := <-s.errorC; ok { - zero.Fatal().Err(err).Msg("[store] read commit") + log.Fatal("[store] read commit", "error", err) } } diff --git a/cluster/raft/etcd/peer.go b/cluster/raft/etcd/peer.go index 60bac16..eefc37e 100644 --- a/cluster/raft/etcd/peer.go +++ b/cluster/raft/etcd/peer.go @@ -8,17 +8,18 @@ import ( "context" "errors" "fmt" - "github.com/wind-c/comqtt/v2/cluster/log/zero" - "github.com/wind-c/comqtt/v2/cluster/message" - "github.com/wind-c/comqtt/v2/config" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" "net" "net/http" "os" "strconv" "time" + "github.com/wind-c/comqtt/v2/cluster/log" + "github.com/wind-c/comqtt/v2/cluster/message" + "github.com/wind-c/comqtt/v2/config" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.etcd.io/etcd/client/pkg/v3/fileutil" "go.etcd.io/etcd/client/pkg/v3/transport" "go.etcd.io/etcd/client/pkg/v3/types" @@ -185,14 +186,14 @@ func getZapLogger() *zap.Logger { EncodeTime: zapcore.ISO8601TimeEncoder, EncodeDuration: zapcore.StringDurationEncoder, } - core := zapcore.NewCore(zapcore.NewJSONEncoder(encoderCfg), zapcore.AddSync(zero.Logger()), zapcore.ErrorLevel) + core := zapcore.NewCore(zapcore.NewJSONEncoder(encoderCfg), zapcore.AddSync(log.Writer()), zapcore.ErrorLevel) return zap.New(core) } func (p *Peer) startRaft() { if !fileutil.Exist(p.snapDir) { if err := os.MkdirAll(p.snapDir, 0750); err != nil { - zero.Fatal().Err(err).Msg("[raft] failed to create dir for dnapshot") + log.Fatal("[raft] failed to create dir for dnapshot", "error", err) } } p.snapshotter = snap.New(p.logger, p.snapDir) @@ -236,7 +237,7 @@ func (p *Peer) startRaft() { } if err := p.transport.Start(); err != nil { - zero.Fatal().Err(err).Msg("[raft] transport start") + log.Fatal("[raft] transport start", "error", err) } for i := range p.peers { @@ -253,15 +254,15 @@ func (p *Peer) serveRaft() { addr := net.JoinHostPort(p.conf.BindAddr, strconv.Itoa(p.conf.RaftPort)) listener, err := newStoppableListener(addr, p.httpStopC) if err != nil { - zero.Fatal().Err(err).Msg("[raft] failed ti listen rafthttp") + log.Fatal("[raft] failed ti listen rafthttp", "error", err) } - zero.Info().Str("host", p.genLocalAddr()).Msg("[raft] http is listeing at") + log.Info("[raft] http is listeing at", "host", p.genLocalAddr()) err = (&http.Server{Handler: p.transport.Handler()}).Serve(listener) select { case <-p.httpStopC: default: - zero.Fatal().Err(err).Msg("[raft] failed to serve rafthttp") + log.Fatal("[raft] failed to serve rafthttp", "error", err) } close(p.httpStopC) } @@ -269,7 +270,7 @@ func (p *Peer) serveRaft() { func (p *Peer) serveChannels() { snapshot, err := p.raftStorage.Snapshot() if err != nil { - zero.Panic().Err(err).Msg("[raft] snapshot") + log.Fatal("[raft] snapshot", "error", err) } p.confState = snapshot.Metadata.ConfState p.snapshotIndex = snapshot.Metadata.Index @@ -291,7 +292,7 @@ func (p *Peer) serveChannels() { p.proposeC = nil } else { if err := p.node.Propose(context.TODO(), prop.MsgpackBytes()); err != nil { - zero.Error().Err(err).Bytes("filter", prop.Payload).Str("nid", prop.NodeID).Uint8("type", prop.Type).Msg("Propose") + log.Error("Propose", "error", err, "filter", prop.Payload, "nid", prop.NodeID, "type", prop.Type) } } @@ -302,7 +303,7 @@ func (p *Peer) serveChannels() { confChangeCount++ cc.ID = confChangeCount if err := p.node.ProposeConfChange(context.TODO(), cc); err != nil { - zero.Error().Err(err).Uint64("nid", cc.NodeID).Int32("type", int32(cc.Type)).Msg("ProposeConfChange") + log.Error("ProposeConfChange", "error", err, "nid", cc.NodeID, "type", cc.Type) } } } @@ -324,11 +325,11 @@ func (p *Peer) serveChannels() { if !raft.IsEmptySnap(rd.Snapshot) { if err := p.saveSnap(rd.Snapshot); err != nil { - zero.Error().Err(err).Msg("saveSnap") + log.Error("saveSnap", "error", err) } } if err := p.wal.Save(rd.HardState, rd.Entries); err != nil { - zero.Error().Err(err).Msg("wal.Save") + log.Error("wal.Save", "error", err) } if !raft.IsEmptySnap(rd.Snapshot) { p.raftStorage.ApplySnapshot(rd.Snapshot) @@ -357,12 +358,13 @@ func (p *Peer) serveChannels() { // replayWAL replays WAL entries into the raft instance. func (p *Peer) replayWAL() *wal.WAL { - zero.Info().Uint64("id", p.id).Msg("[raft] replaying WAL of member") + + log.Info("[raft] replaying WAL of membe", "id", p.id) snapshot := p.loadSnapshot() w := p.openWAL(snapshot) _, st, ents, err := w.ReadAll() if err != nil { - zero.Fatal().Err(err).Msg("[raft] failed to read WAL") + log.Error("[raft] failed to read WAL", "error", err) } p.raftStorage = raft.NewMemoryStorage() if snapshot != nil { @@ -380,11 +382,11 @@ func (p *Peer) loadSnapshot() *raftpb.Snapshot { if wal.Exist(p.walDir) { walSnaps, err := wal.ValidSnapshotEntries(p.logger, p.walDir) if err != nil { - zero.Fatal().Err(err).Msg("[raft] error listening snapshots") + log.Fatal("[raft] error listening snapshots", "error", err) } snapshot, err := p.snapshotter.LoadNewestAvailable(walSnaps) if err != nil && err != snap.ErrNoSnapshot { - zero.Fatal().Err(err).Msg("[raft] error loading snapshit") + log.Fatal("[raft] error loading snapshit", "error", err) } return snapshot } @@ -396,11 +398,11 @@ func (p *Peer) openWAL(snapshot *raftpb.Snapshot) *wal.WAL { if !wal.Exist(p.walDir) { err := os.Mkdir(p.walDir, 0750) if err != nil { - zero.Fatal().Err(err).Msg("[raft] cannot create dir for wal") + log.Fatal("[raft] cannot create dir for wal", "error", err) } w, err := wal.Create(p.logger, p.walDir, nil) if err != nil { - zero.Fatal().Err(err).Msg("[raft] create dir for error") + log.Fatal("[raft] create dir for error", "error", err) } w.Close() } @@ -409,10 +411,10 @@ func (p *Peer) openWAL(snapshot *raftpb.Snapshot) *wal.WAL { if snapshot != nil { walSnap.Index, walSnap.Term = snapshot.Metadata.Index, snapshot.Metadata.Term } - zero.Info().Uint64("term", walSnap.Term).Uint64("index", walSnap.Index).Msg("[raft] loading WAL") + log.Info("[raft] loading WAL", "term", walSnap.Term, "index", walSnap.Index) w, err := wal.Open(p.logger, p.walDir, walSnap) if err != nil { - zero.Fatal().Err(err).Msg("[raft] error loading wal") + log.Fatal("[raft] error loading wal", "error", err) } return w } @@ -438,7 +440,7 @@ func (p *Peer) entriesToApply(ents []raftpb.Entry) (nents []raftpb.Entry) { } firstIdx := ents[0].Index if firstIdx > p.appliedIndex+1 { - zero.Fatal().Uint64("first-idx", firstIdx).Uint64("fapplied-idx", p.appliedIndex).Msg("[raft] fisrt index of committed entry should <= progress.appliedIndex+1") + log.Fatal("[raft] fisrt index of committed entry should <= progress.appliedIndex+1", "first-idx", firstIdx, "fapplied-idx", p.appliedIndex) } if p.appliedIndex-firstIdx+1 < uint64(len(ents)) { nents = ents[p.appliedIndex-firstIdx+1:] @@ -451,11 +453,10 @@ func (p *Peer) publishSnapshot(snapshotToSave raftpb.Snapshot) { return } - zero.Info().Uint64("snap-idx", p.snapshotIndex).Msg("[raft] publishing snapshot at index") - defer zero.Info().Uint64("snap-idx", p.snapshotIndex).Msg("[raft] finished publishing snapshot at index") - + log.Info("[raft] publishing snapshot at index", "snap-idx", p.snapshotIndex) + defer log.Info("[raft] finished publishing snapshot at index", "snap-idx", p.snapshotIndex) if snapshotToSave.Metadata.Index <= p.appliedIndex { - zero.Info().Uint64("snap-idx", snapshotToSave.Metadata.Index).Uint64("applied-idx", p.appliedIndex).Msg("[raft] snapshot index shuold > progress.appliedIndex") + log.Info("[raft] snapshot index shuold > progress.appliedIndex", "snap-idx", snapshotToSave.Metadata.Index, "applied-idx", p.appliedIndex) } p.commitC <- nil // trigger kvstore to load snapshot @@ -491,7 +492,7 @@ func (p *Peer) publishEntries(ents []raftpb.Entry) (<-chan struct{}, bool) { case raftpb.ConfChangeAddNode: if len(cc.Context) > 0 { p.transport.AddPeer(types.ID(cc.NodeID), []string{string(cc.Context)}) - zero.Info().Uint64("node", cc.NodeID).Msg("[raft] node is added to the cluster") + log.Info("[raft] node is added to the cluster", "node", cc.NodeID) } case raftpb.ConfChangeRemoveNode: //if cc.NodeID == p.id { @@ -499,7 +500,7 @@ func (p *Peer) publishEntries(ents []raftpb.Entry) (<-chan struct{}, bool) { // return nil, false //} p.transport.RemovePeer(types.ID(cc.NodeID)) - zero.Info().Uint64("node", cc.NodeID).Msg("[raft] node is removed to the cluster") + log.Info("[raft] node is removed to the cluster", "node", cc.NodeID) } } } @@ -537,17 +538,17 @@ func (p *Peer) maybeTriggerSnapshot(applyDoneC <-chan struct{}) { } } - zero.Info().Uint64("applied-idx", p.appliedIndex).Uint64("snapshot-idx", p.snapshotIndex).Msg("[raft] start snapshot") + log.Info("[raft] start snapshot", "applied-idx", p.appliedIndex, "snapshot-idx", p.snapshotIndex) data, err := p.getSnapshot() if err != nil { - zero.Panic().Err(err).Msg("[raft] get snapshot") + log.Fatal("[raft] get snapshot", "error", err) } snapshot, err := p.raftStorage.CreateSnapshot(p.appliedIndex, &p.confState, data) if err != nil { - zero.Panic().Err(err).Msg("[raft] create snapshot") + log.Fatal("[raft] create snapshot", "error", err) } if err = p.saveSnap(snapshot); err != nil { - zero.Panic().Err(err).Msg("[raft] save snapshot") + log.Fatal("[raft] save snapshot", "error", err) } compactIndex := uint64(1) @@ -555,10 +556,10 @@ func (p *Peer) maybeTriggerSnapshot(applyDoneC <-chan struct{}) { compactIndex = p.appliedIndex - snapshotCatchUpEntriesN } if err = p.raftStorage.Compact(compactIndex); err != nil { - zero.Panic().Err(err).Msg("[raft] compact snapshot") + log.Fatal("[raft] compact snapshot", "error", err) } - zero.Info().Uint64("compact-idx", compactIndex).Msg("compacted log at index") + log.Info("compacted log at index", "compact-idx", compactIndex) p.snapshotIndex = p.appliedIndex } diff --git a/cluster/raft/hashicorp/fsm.go b/cluster/raft/hashicorp/fsm.go index a985e9f..065c07c 100644 --- a/cluster/raft/hashicorp/fsm.go +++ b/cluster/raft/hashicorp/fsm.go @@ -7,11 +7,12 @@ package hashicorp import ( "bytes" "encoding/gob" - "github.com/wind-c/comqtt/v2/cluster/message" - "github.com/wind-c/comqtt/v2/mqtt/packets" "io" "sync" + "github.com/wind-c/comqtt/v2/cluster/message" + "github.com/wind-c/comqtt/v2/mqtt/packets" + "github.com/hashicorp/raft" ) @@ -36,10 +37,8 @@ func (f *Fsm) Apply(l *raft.Log) interface{} { filter := string(msg.Payload) if msg.Type == packets.Subscribe { f.kv.Set(filter, msg.NodeID) - //zero.Info().Str("from", msg.NodeID).Str("event", "subscribe").Str("filter", filter).Msg("apply") } else if msg.Type == packets.Unsubscribe { f.kv.Del(filter, msg.NodeID) - //zero.Info().Str("from", msg.NodeID).Str("event", "unsubscribe").Str("filter", filter).Msg("apply") } else { return nil } diff --git a/cluster/raft/hashicorp/peer.go b/cluster/raft/hashicorp/peer.go index ceed30a..54ac4ca 100644 --- a/cluster/raft/hashicorp/peer.go +++ b/cluster/raft/hashicorp/peer.go @@ -8,16 +8,17 @@ import ( "encoding/json" "errors" "fmt" - "github.com/wind-c/comqtt/v2/cluster/log/zero" - "github.com/wind-c/comqtt/v2/cluster/message" - "github.com/wind-c/comqtt/v2/cluster/utils" - "github.com/wind-c/comqtt/v2/config" "net" "os" "path/filepath" "strconv" "time" + "github.com/wind-c/comqtt/v2/cluster/log" + "github.com/wind-c/comqtt/v2/cluster/message" + "github.com/wind-c/comqtt/v2/cluster/utils" + "github.com/wind-c/comqtt/v2/config" + "github.com/hashicorp/raft" raftdb "github.com/hashicorp/raft-boltdb/v2" ) @@ -74,7 +75,7 @@ func Setup(conf *config.Cluster, notifyCh chan<- *message.Message) (*Peer, error config := raft.DefaultConfig() config.LocalID = raft.ServerID(conf.NodeName) config.LogLevel = "ERROR" - config.LogOutput = zero.Logger() + config.LogOutput = log.Writer() config.ShutdownOnRemove = false // Disable shutdown on removal config.SnapshotInterval = 30 * time.Second // Check every 5 seconds to see if there are enough new entries for a snapshot, can be overridden config.SnapshotThreshold = 16384 // Snapshots are created every 16384 entries by default, can be overridden @@ -130,7 +131,7 @@ func Setup(conf *config.Cluster, notifyCh chan<- *message.Message) (*Peer, error */ peersFile := filepath.Join(conf.RaftDir, peersFIle) if utils.PathExists(peersFile) { - zero.Info().Msg("found peers.json file, recovering Raft configuration...") + log.Info("found peers.json file, recovering Raft configuration...") var configuration raft.Configuration configuration, err = raft.ReadConfigJSON(peersFile) @@ -146,7 +147,7 @@ func Setup(conf *config.Cluster, notifyCh chan<- *message.Message) (*Peer, error return nil, fmt.Errorf("recovery failed to delete peers.json, please delete manually (see peers.info for details): %v", err) } - zero.Info().Msg("deleted peers.json file after successful recovery") + log.Info("deleted peers.json file after successful recovery") } if conf.RaftBootstrap { @@ -165,7 +166,7 @@ func Setup(conf *config.Cluster, notifyCh chan<- *message.Message) (*Peer, error }, } if err := raft.BootstrapCluster(config, logCache, stable, snapshot, transport, configuration); err != nil { - zero.Error().Err(err).Msg("raft bootstrap cluster") + log.Error("raft bootstrap cluster", "error", err) return nil, err } } @@ -178,9 +179,9 @@ func Setup(conf *config.Cluster, notifyCh chan<- *message.Message) (*Peer, error peer := &Peer{config, rf, fm, store, transport} if id, err := peer.waitForLeader(peer.electionTimeout() * 3); err != nil { - zero.Warn().Str("leader", "unknow").Msg("timeout waiting for raft leader") + log.Warn("timeout waiting for raft leader", "leader", "unknow") } else { - zero.Info().Str("leader", id).Msg("found raft leader") + log.Info("found raft leader", "leader", id) } return peer, nil @@ -194,13 +195,13 @@ func (p *Peer) Stop() { // snapshot if err := p.snapshot().Error(); err != "" { - zero.Warn().Msg("failed to create snapshot!") + log.Warn("failed to create snapshot!") } // close raft shutdownFuture := p.raft.Shutdown() if err := shutdownFuture.Error(); err != nil { - zero.Error().Err(err).Msg("shutdown raft") + log.Error("shutdown raft", "error", err) } // close store p.store.Close() @@ -264,7 +265,7 @@ func (p *Peer) Join(nodeId, addr string) error { // However if *both* the ID and the address are the same, then nothing -- not even // a join operation -- is needed. if srv.Address == raft.ServerAddress(addr) && srv.ID == raft.ServerID(nodeId) { - zero.Warn().Str("node", nodeId).Str("addr", addr).Msg("it is already a cluster member, ignoring join request") + log.Warn("it is already a cluster member, ignoring join request", "node", nodeId, "addr", addr) return nil } diff --git a/cluster/raft/hashicorp/transport.go b/cluster/raft/hashicorp/transport.go index 7043971..340d4f5 100644 --- a/cluster/raft/hashicorp/transport.go +++ b/cluster/raft/hashicorp/transport.go @@ -5,10 +5,11 @@ package hashicorp import ( - "github.com/hashicorp/raft" - "github.com/wind-c/comqtt/v2/cluster/log/zero" "net" "time" + + "github.com/hashicorp/raft" + "github.com/wind-c/comqtt/v2/cluster/log" ) func newRaftTrans(ln net.Listener) *raft.NetworkTransport { @@ -16,17 +17,17 @@ func newRaftTrans(ln net.Listener) *raft.NetworkTransport { addr, ok := layer.Addr().(*net.TCPAddr) if !ok { if err := ln.Close(); err != nil { - zero.Error().Err(err).Msg("raft addr is not tcp addr") + log.Error("raft addr is not tcp addr", "error", err) } return nil } if addr.IP == nil || addr.IP.IsUnspecified() { if err := ln.Close(); err != nil { - zero.Error().Err(err).Msg("raft addr is not valid") + log.Error("raft addr is not valid", "error", err) } return nil } - return raft.NewNetworkTransport(layer, maxPool, DefaultRaftTimeout, zero.Logger()) + return raft.NewNetworkTransport(layer, maxPool, DefaultRaftTimeout, log.Writer()) } type raftLayer struct { diff --git a/cluster/service.go b/cluster/service.go index dc51d76..8a79249 100644 --- a/cluster/service.go +++ b/cluster/service.go @@ -8,8 +8,13 @@ import ( "context" "errors" "fmt" + "net" + "strconv" + "sync" + "time" + grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" - "github.com/wind-c/comqtt/v2/cluster/log/zero" + "github.com/wind-c/comqtt/v2/cluster/log" "github.com/wind-c/comqtt/v2/cluster/message" crpc "github.com/wind-c/comqtt/v2/cluster/rpc" "github.com/wind-c/comqtt/v2/mqtt/packets" @@ -17,10 +22,6 @@ import ( "google.golang.org/grpc/credentials/insecure" _ "google.golang.org/grpc/health" "google.golang.org/grpc/keepalive" - "net" - "strconv" - "sync" - "time" ) const ( @@ -71,7 +72,7 @@ func (s *RpcService) StartRpcServer() error { // serve grpc go func() { if err := grpcServer.Serve(grpcListen); err != nil { - zero.Error().Err(err).Msg("grpc server serve") + log.Error("grpc server serve", "error", err) } }() @@ -216,7 +217,7 @@ func (c *ClientManager) getClient(nodeId string) (*client, error) { func (c *ClientManager) RelayPublishPacket(nodeId string, msg *message.Message) { client, err := c.getClient(nodeId) if err != nil { - zero.Error().Err(err).Msg("get grpc client") + log.Error("get grpc client", "error", err) return } @@ -229,7 +230,7 @@ func (c *ClientManager) RelayPublishPacket(nodeId string, msg *message.Message) Payload: msg.Payload, } if _, err := client.PublishPacket(ctx, &req); err != nil { - zero.Error().Err(err).Str("to", nodeId).Str("cid", msg.ClientID).Msg("relay publish packet") + log.Error("relay publish packet", "error", err, "to", nodeId, "cid", msg.ClientID) } } @@ -247,7 +248,7 @@ func (c *ClientManager) ConnectNotifyToNode(nodeId, clientId string) { } OnConnectPacketLog(DirectionOutbound, nodeId, clientId) if _, err := client.ConnectNotify(ctx, &req); err != nil { - zero.Error().Err(err).Str("to", nodeId).Str("cid", clientId).Msg("connection notification") + log.Error("connection notification", "error", err, "to", nodeId, "cid", clientId) } } @@ -264,7 +265,7 @@ func (c *ClientManager) ConnectNotifyToOthers(msg *message.Message) { func (c *ClientManager) RelayRaftApply(nodeId string, msg *message.Message) { client, err := c.getClient(nodeId) if err != nil { - zero.Error().Err(err).Msg("get grpc client") + log.Error("get grpc client", "error", err) return } @@ -293,7 +294,7 @@ func (c *ClientManager) RaftApplyToOthers(msg *message.Message) { func (c *ClientManager) RelayRaftJoin(nodeId string) { client, err := c.getClient(nodeId) if err != nil { - zero.Error().Err(err).Msg("get grpc client") + log.Error("get grpc client", "error", err) return } diff --git a/cluster/storage/redis/redis.go b/cluster/storage/redis/redis.go index 30523a9..9438ad1 100644 --- a/cluster/storage/redis/redis.go +++ b/cluster/storage/redis/redis.go @@ -9,6 +9,7 @@ import ( "context" "errors" "fmt" + redis "github.com/go-redis/redis/v8" "github.com/wind-c/comqtt/v2/cluster/utils" "github.com/wind-c/comqtt/v2/mqtt" @@ -120,12 +121,11 @@ func (s *Storage) Init(config any) error { } s.config.HPrefix += ":" - s.Log.Info(). - Str("address", s.config.Options.Addr). - Str("username", s.config.Options.Username). - Int("password-len", len(s.config.Options.Password)). - Int("db", s.config.Options.DB). - Msg("connecting to redis service") + s.Log.Info("connecting to redis service", + "address", s.config.Options.Addr, + "username", s.config.Options.Username, + "password-len", len(s.config.Options.Password), + "db", s.config.Options.DB) s.db = redis.NewClient(s.config.Options) _, err := s.db.Ping(context.Background()).Result() @@ -133,14 +133,14 @@ func (s *Storage) Init(config any) error { return fmt.Errorf("failed to ping service: %w", err) } - s.Log.Info().Msg("connected to redis service") + s.Log.Info("connected to redis service") return nil } // Stop closes the redis connection. func (s *Storage) Stop() error { - s.Log.Info().Msg("disconnecting from redis service") + s.Log.Info("disconnecting from redis service") return s.db.Close() } @@ -158,7 +158,7 @@ func (s *Storage) OnWillSent(cl *mqtt.Client, pk packets.Packet) { // updateClient writes the client data to the store. func (s *Storage) updateClient(cl *mqtt.Client) { if s.db == nil { - s.Log.Error().Err(storage.ErrDBFileNotOpen) + s.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -186,14 +186,14 @@ func (s *Storage) updateClient(cl *mqtt.Client) { err := s.db.HSet(s.ctx, s.hKey(storage.ClientKey), clientKey(cl), in).Err() if err != nil { - s.Log.Error().Err(err).Interface("data", in).Msg("failed to hset client data") + s.Log.Error("failed to hset client data", "error", storage.ErrDBFileNotOpen, "data", in) } } // OnDisconnect removes a client from the store if they were using a clean session. func (s *Storage) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { if s.db == nil { - s.Log.Error().Err(storage.ErrDBFileNotOpen) + s.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -203,14 +203,14 @@ func (s *Storage) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { err := s.db.HDel(s.ctx, s.hKey(storage.ClientKey), clientKey(cl)).Err() if err != nil { - s.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client") + s.Log.Error("failed to delete client", "error", err, "id", clientKey(cl)) } } // OnSubscribed adds one or more client subscriptions to the store. func (s *Storage) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte, counts []int) { if s.db == nil { - s.Log.Error().Err(storage.ErrDBFileNotOpen) + s.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -229,7 +229,7 @@ func (s *Storage) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes [ err := s.db.HSet(s.ctx, s.hKey(utils.JoinStrings(storage.SubscriptionKey, cl.ID)), pk.Filters[i].Filter, in).Err() if err != nil { - s.Log.Error().Err(err).Interface("data", in).Msg("failed to hset subscription data") + s.Log.Error("failed to hset subscription data", "error", err, "data", in) } } } @@ -237,14 +237,14 @@ func (s *Storage) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes [ // OnUnsubscribed removes one or more client subscriptions from the store. func (s *Storage) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte, counts []int) { if s.db == nil { - s.Log.Error().Err(storage.ErrDBFileNotOpen) + s.Log.Error("", "error", storage.ErrDBFileNotOpen) return } for i := 0; i < len(pk.Filters); i++ { err := s.db.HDel(s.ctx, s.hKey(utils.JoinStrings(storage.SubscriptionKey, cl.ID)), pk.Filters[i].Filter).Err() if err != nil { - s.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete subscription data") + s.Log.Error("failed to delete subscription data", "error", err, "id", clientKey(cl)) } } } @@ -252,14 +252,14 @@ func (s *Storage) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes // OnRetainMessage adds a retained message for a topic to the store. func (s *Storage) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { if s.db == nil { - s.Log.Error().Err(storage.ErrDBFileNotOpen) + s.Log.Error("", "error", storage.ErrDBFileNotOpen) return } if r == -1 { err := s.db.HDel(s.ctx, s.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Err() if err != nil { - s.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete retained message data") + s.Log.Error("failed to delete retained message data", "error", err, "id", clientKey(cl)) } return @@ -285,14 +285,14 @@ func (s *Storage) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { err := s.db.HSet(s.ctx, s.hKey(storage.RetainedKey), retainedKey(pk.TopicName), in).Err() if err != nil { - s.Log.Error().Err(err).Interface("data", in).Msg("failed to hset retained message data") + s.Log.Error("failed to hset retained message data", "error", err, "data", in) } } // OnQosPublish adds or updates an inflight message in the store. func (s *Storage) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) { if s.db == nil { - s.Log.Error().Err(storage.ErrDBFileNotOpen) + s.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -319,27 +319,27 @@ func (s *Storage) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, r err := s.db.HSet(s.ctx, s.hKey(utils.JoinStrings(storage.InflightKey, cl.ID)), inflightKey(cl, pk), in).Err() if err != nil { - s.Log.Error().Err(err).Interface("data", in).Msg("failed to hset qos inflight message data") + s.Log.Error("failed to hset qos inflight message data", "error", err, "data", in) } } // OnQosComplete removes a resolved inflight message from the store. func (s *Storage) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { if s.db == nil { - s.Log.Error().Err(storage.ErrDBFileNotOpen) + s.Log.Error("", "error", storage.ErrDBFileNotOpen) return } err := s.db.HDel(s.ctx, s.hKey(utils.JoinStrings(storage.InflightKey, cl.ID)), inflightKey(cl, pk)).Err() if err != nil { - s.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data") + s.Log.Error("failed to delete inflight message data", "error", err, "id", clientKey(cl)) } } // OnQosDropped removes a dropped inflight message from the store. func (s *Storage) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { if s.db == nil { - s.Log.Error().Err(storage.ErrDBFileNotOpen) + s.Log.Error("", "error", storage.ErrDBFileNotOpen) } s.OnQosComplete(cl, pk) @@ -348,7 +348,7 @@ func (s *Storage) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { // OnSysInfoTick stores the latest system info in the store. func (s *Storage) OnSysInfoTick(sys *system.Info) { if s.db == nil { - s.Log.Error().Err(storage.ErrDBFileNotOpen) + s.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -358,40 +358,40 @@ func (s *Storage) OnSysInfoTick(sys *system.Info) { err := s.db.HSet(s.ctx, s.hKey(storage.SysInfoKey), sysInfoKey(), in).Err() if err != nil { - s.Log.Error().Err(err).Interface("data", in).Msg("failed to hset server info data") + s.Log.Error("failed to hset server info data", "error", err, "data", in) } } // OnRetainedExpired deletes expired retained messages from the store. func (s *Storage) OnRetainedExpired(filter string) { if s.db == nil { - s.Log.Error().Err(storage.ErrDBFileNotOpen) + s.Log.Error("", "error", storage.ErrDBFileNotOpen) return } err := s.db.HDel(s.ctx, s.hKey(storage.RetainedKey), retainedKey(filter)).Err() if err != nil { - s.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data") + s.Log.Error("failed to delete retained message data", "error", err, "id", retainedKey(filter)) } } // OnClientExpired deleted expired clients from the store. func (s *Storage) OnClientExpired(cl *mqtt.Client) { if s.db == nil { - s.Log.Error().Err(storage.ErrDBFileNotOpen) + s.Log.Error("", "error", storage.ErrDBFileNotOpen) return } err := s.db.HDel(s.ctx, s.hKey(storage.ClientKey), clientKey(cl)).Err() if err != nil { - s.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client") + s.Log.Error("failed to delete expired client", "error", err, "id", clientKey(cl)) } } // StoredSysInfo returns the system info from the store. func (s *Storage) StoredSysInfo() (v storage.SystemInfo, err error) { if s.db == nil { - s.Log.Error().Err(storage.ErrDBFileNotOpen) + s.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -401,7 +401,7 @@ func (s *Storage) StoredSysInfo() (v storage.SystemInfo, err error) { } if err = v.UnmarshalBinary([]byte(row)); err != nil { - s.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal sys info data") + s.Log.Error("failed to unmarshal sys info data", "error", err, "data", row) } return v, nil @@ -410,17 +410,17 @@ func (s *Storage) StoredSysInfo() (v storage.SystemInfo, err error) { // StoredClientByCid returns a stored client from the store. func (s *Storage) StoredClientByCid(cid string) (v storage.Client, err error) { if s.db == nil { - s.Log.Error().Err(storage.ErrDBFileNotOpen) + s.Log.Error("", "error", storage.ErrDBFileNotOpen) return } row, err := s.db.HGet(s.ctx, s.hKey(storage.ClientKey), cid).Result() if err != nil && !errors.Is(err, redis.Nil) { - s.Log.Error().Err(err).Msg("failed to HGet client data") + s.Log.Error("failed to HGet client data", "error", err) return } if err = v.UnmarshalBinary([]byte(row)); err != nil { - s.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal client data") + s.Log.Error("failed to unmarshal client data", "error", err, "data", row) } return v, nil @@ -429,20 +429,20 @@ func (s *Storage) StoredClientByCid(cid string) (v storage.Client, err error) { // StoredSubscriptionsByCid returns all stored subscriptions of client from the store. func (s *Storage) StoredSubscriptionsByCid(cid string) (v []storage.Subscription, err error) { if s.db == nil { - s.Log.Error().Err(storage.ErrDBFileNotOpen) + s.Log.Error("", "error", storage.ErrDBFileNotOpen) return } rows, err := s.db.HGetAll(s.ctx, s.hKey(utils.JoinStrings(storage.SubscriptionKey, cid))).Result() if err != nil && !errors.Is(err, redis.Nil) { - s.Log.Error().Err(err).Msg("failed to HGetAll subscription data") + s.Log.Error("failed to HGetAll subscription data", "error", err) return } for filter, row := range rows { var d storage.Subscription if err = d.UnmarshalBinary([]byte(row)); err != nil { - s.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal subscription data") + s.Log.Error("failed to unmarshal subscription data", "error", err, "data", row) } if d.Filter == "" { @@ -458,18 +458,18 @@ func (s *Storage) StoredSubscriptionsByCid(cid string) (v []storage.Subscription // StoredRetainedMessageByTopic returns a stored retained message of topic from the store. func (s *Storage) StoredRetainedMessageByTopic(topic string) (v storage.Message, err error) { if s.db == nil { - s.Log.Error().Err(storage.ErrDBFileNotOpen) + s.Log.Error("", "error", storage.ErrDBFileNotOpen) return } row, err := s.db.HGet(s.ctx, s.hKey(storage.RetainedKey), retainedKey(topic)).Result() if err != nil && !errors.Is(err, redis.Nil) { - s.Log.Error().Err(err).Msg("failed to HGetAll retained message data") + s.Log.Error("failed to HGetAll retained message data", "error", err) return } if err = v.UnmarshalBinary([]byte(row)); err != nil { - s.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal retained message data") + s.Log.Error("failed to unmarshal retained message dat", "error", err, "data", row) } if v.TopicName == "" { @@ -482,20 +482,20 @@ func (s *Storage) StoredRetainedMessageByTopic(topic string) (v storage.Message, // StoredInflightMessagesByCid returns all stored inflight messages of client from the store. func (s *Storage) StoredInflightMessagesByCid(cid string) (v []storage.Message, err error) { if s.db == nil { - s.Log.Error().Err(storage.ErrDBFileNotOpen) + s.Log.Error("", "error", storage.ErrDBFileNotOpen) return } rows, err := s.db.HGetAll(s.ctx, s.hKey(utils.JoinStrings(storage.InflightKey, cid))).Result() if err != nil && !errors.Is(err, redis.Nil) { - s.Log.Error().Err(err).Msg("failed to HGetAll inflight message data") + s.Log.Error("failed to HGetAll inflight message data", "error", err) return } for _, row := range rows { var d storage.Message if err = d.UnmarshalBinary([]byte(row)); err != nil { - s.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data") + s.Log.Error("failed to unmarshal inflight message data", "error", err, "data", row) } v = append(v, d) diff --git a/cluster/storage/redis/redis_test.go b/cluster/storage/redis/redis_test.go index 473c450..0567047 100644 --- a/cluster/storage/redis/redis_test.go +++ b/cluster/storage/redis/redis_test.go @@ -5,24 +5,28 @@ package redis import ( + "io" + "sort" + "testing" + "time" + + "log/slog" + "github.com/wind-c/comqtt/v2/cluster/utils" "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/hooks/storage" "github.com/wind-c/comqtt/v2/mqtt/packets" "github.com/wind-c/comqtt/v2/mqtt/system" - "os" - "sort" - "testing" - "time" miniredis "github.com/alicebob/miniredis/v2" redis "github.com/go-redis/redis/v8" - "github.com/rs/zerolog" "github.com/stretchr/testify/require" ) var ( - logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) + // Currently, the input is directed to /dev/null. If you need to + // output to stdout, just modify 'io.Discard' here to 'os.Stdout'. + logger = slog.New(slog.NewTextHandler(io.Discard, nil)) client = &mqtt.Client{ ID: "test", @@ -41,7 +45,7 @@ var ( func newHook(t *testing.T, addr string) *Storage { s := new(Storage) - s.SetOpts(&logger, nil) + s.SetOpts(logger, nil) err := s.Init(&Options{ Options: &redis.Options{ @@ -83,13 +87,13 @@ func TestInflightKey(t *testing.T) { func TestID(t *testing.T) { s := new(Storage) - s.SetOpts(&logger, nil) + s.SetOpts(logger, nil) require.Equal(t, "redis-db", s.ID()) } func TestProvides(t *testing.T) { s := new(Storage) - s.SetOpts(&logger, nil) + s.SetOpts(logger, nil) require.True(t, s.Provides(mqtt.OnSessionEstablished)) require.True(t, s.Provides(mqtt.OnDisconnect)) require.True(t, s.Provides(mqtt.OnSubscribed)) @@ -112,7 +116,7 @@ func TestHKey(t *testing.T) { m := miniredis.RunT(t) defer m.Close() s := newHook(t, m.Addr()) - s.SetOpts(&logger, nil) + s.SetOpts(logger, nil) require.Equal(t, defaultHPrefix+":test", s.hKey("test")) } @@ -122,7 +126,7 @@ func TestInitUseDefaults(t *testing.T) { defer m.Close() s := newHook(t, defaultAddr) - s.SetOpts(&logger, nil) + s.SetOpts(logger, nil) err := s.Init(nil) require.NoError(t, err) defer teardown(t, s) @@ -133,7 +137,7 @@ func TestInitUseDefaults(t *testing.T) { func TestInitBadConfig(t *testing.T) { s := new(Storage) - s.SetOpts(&logger, nil) + s.SetOpts(logger, nil) err := s.Init(map[string]any{}) require.Error(t, err) @@ -141,7 +145,7 @@ func TestInitBadConfig(t *testing.T) { func TestInitBadAddr(t *testing.T) { s := new(Storage) - s.SetOpts(&logger, nil) + s.SetOpts(logger, nil) err := s.Init(&Options{ Options: &redis.Options{ Addr: "abc:123", diff --git a/cmd/cluster/main.go b/cmd/cluster/main.go index 3fbd914..c3799f5 100644 --- a/cmd/cluster/main.go +++ b/cmd/cluster/main.go @@ -9,10 +9,19 @@ import ( "encoding/json" "flag" "fmt" + "io" + "net" + "net/http" + _ "net/http/pprof" + "os" + "os/signal" + "strconv" + "strings" + "syscall" + "github.com/go-redis/redis/v8" - "github.com/rs/zerolog" cs "github.com/wind-c/comqtt/v2/cluster" - colog "github.com/wind-c/comqtt/v2/cluster/log/zero" + "github.com/wind-c/comqtt/v2/cluster/log" coredis "github.com/wind-c/comqtt/v2/cluster/storage/redis" "github.com/wind-c/comqtt/v2/config" mqtt "github.com/wind-c/comqtt/v2/mqtt" @@ -24,24 +33,13 @@ import ( pauth "github.com/wind-c/comqtt/v2/plugin/auth/postgresql" rauth "github.com/wind-c/comqtt/v2/plugin/auth/redis" cokafka "github.com/wind-c/comqtt/v2/plugin/bridge/kafka" - "io" - "log" - "net" - "net/http" - _ "net/http/pprof" - "os" - "os/signal" - "strconv" - "strings" - "syscall" ) var agent *cs.Agent -var logger *zerolog.Logger func pprof() { go func() { - log.Println(http.ListenAndServe(":6060", nil)) + log.Info("listen pprof", "error", http.ListenAndServe(":6060", nil)) }() } @@ -49,9 +47,7 @@ func main() { sigCtx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer cancel() err := realMain(sigCtx) - if err != nil { - log.Fatal(err) - } + onError(err, "") } func realMain(ctx context.Context) error { @@ -79,11 +75,8 @@ func realMain(ctx context.Context) error { flag.StringVar(&cfg.Redis.Options.Addr, "redis", "127.0.0.1:6379", "redis address for cluster mode") flag.StringVar(&cfg.Redis.Options.Password, "redis-pass", "", "redis password for cluster mode") flag.IntVar(&cfg.Redis.Options.DB, "redis-db", 0, "redis db for cluster mode") - flag.BoolVar(&cfg.Log.Enable, "log-enable", true, "log enabled or not") - flag.IntVar(&cfg.Log.Env, "env", 0, "app running environment:0 development or 1 production") - flag.IntVar(&cfg.Log.Level, "level", 1, "log level options:0Debug,1Info, 2Warn, 3Error, 4Fatal, 5Panic, 6NoLevel, 7Off") - flag.StringVar(&cfg.Log.InfoFile, "info-file", "./logs/co-info.log", "info log filename") - flag.StringVar(&cfg.Log.ErrorFile, "error-file", "./logs/co-err.log", "error log filename") + flag.BoolVar(&cfg.Log.Disable, "log-disable", false, "log disabled or not") + flag.StringVar(&cfg.Log.Filename, "log-file", "./logs/comqtt.log", "log filename") //parse arguments flag.Parse() //load config file @@ -99,27 +92,18 @@ func realMain(ctx context.Context) error { } } + //init log + log.Init(&cfg.Log) + //enable pprof if cfg.PprofEnable { pprof() } - //init log - if cfg.Cluster.NodeName == "" { - if hn, err := os.Hostname(); err == nil { - cfg.Log.NodeName = hn - } - } - - logger = colog.Init(cfg.Log) - if cfg.Log.Enable && cfg.Log.Format == 1 { - log.Println("log output to the files, please check") - } - // create server instance and init hooks - cfg.Mqtt.Options.Logger = logger + cfg.Mqtt.Options.Logger = log.Default() server := mqtt.New(&cfg.Mqtt.Options) - logger.Info().Msg("comqtt server initializing...") + log.Info("comqtt server initializing...") initStorage(server, cfg) initAuth(server, cfg) initBridge(server, cfg) @@ -165,22 +149,18 @@ func realMain(ctx context.Context) error { errCh <- err } }() - - if cfg.Log.Format == 1 { - log.Println("comqtt server started") - } + log.Info("comqtt server started") // exit select { case err := <-errCh: - log.Fatalf("server error: %s", err.Error()) // todo: change the formatting so the error is logged - server.Log.Fatal().Err(err).Msg("server error") // <-- this swallows the error. + onError(err, "server error") + case <-ctx.Done(): - server.Log.Warn().Msg("caught signal, stopping...") + server.Log.Warn("caught signal, stopping...") } agent.Stop() server.Close() - colog.Close() return nil } @@ -253,14 +233,14 @@ func initClusterNode(server *mqtt.Server, conf *config.Config) { agent = cs.NewAgent(&conf.Cluster) agent.BindMqttServer(server) onError(agent.Start(), "create node and join cluster") - - logger.Info().Msg("cluster node created") + log.Info("cluster node created") } // onError handle errors and simplify code func onError(err error, msg string) { if err != nil { - logger.Fatal().Err(err).Msg(msg) + log.Error(msg, "error", err) + os.Exit(1) } } diff --git a/cmd/single/main.go b/cmd/single/main.go index d150eff..ae97a68 100644 --- a/cmd/single/main.go +++ b/cmd/single/main.go @@ -7,9 +7,14 @@ package main import ( "context" "flag" + "net/http" + "os" + "os/signal" + "syscall" + "time" + rv8 "github.com/go-redis/redis/v8" - "github.com/rs/zerolog" - colog "github.com/wind-c/comqtt/v2/cluster/log/zero" + "github.com/wind-c/comqtt/v2/cluster/log" "github.com/wind-c/comqtt/v2/config" "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/hooks/auth" @@ -24,19 +29,11 @@ import ( rauth "github.com/wind-c/comqtt/v2/plugin/auth/redis" cokafka "github.com/wind-c/comqtt/v2/plugin/bridge/kafka" "go.etcd.io/bbolt" - "log" - "net/http" - "os" - "os/signal" - "syscall" - "time" ) -var logger *zerolog.Logger - func pprof() { go func() { - log.Println(http.ListenAndServe(":6060", nil)) + log.Info("listen pprof", "error", http.ListenAndServe(":6060", nil)) }() } @@ -44,9 +41,7 @@ func main() { sigCtx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer cancel() err := realMain(sigCtx) - if err != nil { - log.Fatal(err) - } + onError(err, "") } func realMain(ctx context.Context) error { @@ -62,38 +57,29 @@ func realMain(ctx context.Context) error { flag.StringVar(&cfg.Mqtt.TCP, "tcp", ":1883", "network address for Mqtt TCP listener") flag.StringVar(&cfg.Mqtt.WS, "ws", ":1882", "network address for Mqtt Websocket listener") flag.StringVar(&cfg.Mqtt.HTTP, "http", ":8080", "network address for web info dashboard listener") - flag.BoolVar(&cfg.Log.Enable, "log-enable", true, "log enabled or not") - flag.IntVar(&cfg.Log.Env, "env", 0, "app running environment,0 development or 1 production") - flag.IntVar(&cfg.Log.Level, "level", 1, "log level options:0Debug,1Info, 2Warn, 3Error, 4Fatal, 5Panic, 6NoLevel, 7Off") - flag.StringVar(&cfg.Log.InfoFile, "info-file", "./logs/co-info.log", "info log filename") - flag.StringVar(&cfg.Log.ErrorFile, "error-file", "./logs/co-err.log", "error log filename") + flag.BoolVar(&cfg.Log.Disable, "log-disable", true, "log disabled or not") + flag.StringVar(&cfg.Log.Filename, "log-file", "./logs/comqtt.log", "log filename") //parse arguments flag.Parse() //load config file if confFile != "" { if cfg, err = config.Load(confFile); err != nil { - log.Fatal(err) + onError(err, "") } } + //init log + log.Init(&cfg.Log) + //enable pprof if cfg.PprofEnable { pprof() } - //init log - if hn, err := os.Hostname(); err == nil { - cfg.Log.NodeName = hn - } - logger = colog.Init(cfg.Log) - if cfg.Log.Enable && cfg.Log.Format == 1 { - log.Println("log output to the files, please check") - } - // create server instance and init hooks - cfg.Mqtt.Options.Logger = logger + cfg.Mqtt.Options.Logger = log.Default() server := mqtt.New(&cfg.Mqtt.Options) - logger.Info().Msg("comqtt server initializing...") + log.Info("comqtt server initializing...") initStorage(server, cfg) initAuth(server, cfg) initBridge(server, cfg) @@ -101,7 +87,7 @@ func realMain(ctx context.Context) error { // gen tls config var listenerConfig *listeners.Config if tlsConfig, err := config.GenTlsConfig(cfg); err != nil { - server.Log.Fatal().Err(err) + onError(err, "") } else { if tlsConfig != nil { listenerConfig = &listeners.Config{TLSConfig: tlsConfig} @@ -129,19 +115,16 @@ func realMain(ctx context.Context) error { } }() - if cfg.Log.Format == 1 { - log.Println("comqtt server started") - } + log.Info("comqtt server started") + select { case err := <-errCh: - log.Fatalf("server error: %s", err.Error()) - server.Log.Fatal().Err(err).Msg("server error") + onError(err, "server error") case <-ctx.Done(): - server.Log.Warn().Msg("caught signal, stopping...") + log.Warn("caught signal, stopping...") } server.Close() - server.Log.Info().Msg("main.go finished") - colog.Close() + log.Info("main.go finished") return nil } @@ -213,6 +196,7 @@ func initBridge(server *mqtt.Server, conf *config.Config) { // onError handle errors and simplify code func onError(err error, msg string) { if err != nil { - logger.Fatal().Err(err).Msg(msg) + log.Error(msg, "error", err) + os.Exit(1) } } diff --git a/config/conf.yml b/config/conf.yml index 526a4c5..a864a19 100644 --- a/config/conf.yml +++ b/config/conf.yml @@ -69,19 +69,11 @@ redis: prefix: comqtt log: - enable: true - env: 0 #0 dev or 1 prod - format: 1 #output format 0console or 1json - caller: false #whether to display code line number - info-file: ./logs/comqtt-info.log - error-file: ./logs/comqtt-error.log - thirdparty-file: ./logs/thirdparty.log # level 6NoLevel,logs of the third-party library - maxsize: 100 #100M - max-age: 30 #30day - max-backups: 10 #number of log files - localtime: true #true or false - compress: true #true or false - level: 1 #-1Trace 0Debug 1Info 2Warn 3Error(default) 4Fatal 5Panic 6NoLevel 7Off - sampler: #a maximum of three logs can be output every second - burst: 3 #log count - period: 1 #second + disable: false #Indicates whether logging is enabled. + format: 1 #Log format, currently supports Text: 0 and JSON: 1, with Text as the default. + filename: ./logs/comqtt.log #Filename is the file to write logs to + maxsize: 100 #MaxSize is the maximum size in megabytes of the log file before it gets rotated. It defaults to 100 megabytes. + max-age: 30 #MaxAge is the maximum number of days to retain old log files based on the timestamp encoded in their filename + max-backups: 10 #MaxBackups is the maximum number of old log files to retain + compress: true #Compress determines if the rotated log files should be compressed using gzip + level: 0 #Log level, with supported values LevelDebug: 4, LevelInfo: 0, LevelWarn: 4, and LevelError: 8. diff --git a/config/config.go b/config/config.go index 0cc60d2..a4a5cea 100644 --- a/config/config.go +++ b/config/config.go @@ -8,9 +8,11 @@ import ( tls2 "crypto/tls" "crypto/x509" "errors" + "os" + + "github.com/wind-c/comqtt/v2/cluster/log" comqtt "github.com/wind-c/comqtt/v2/mqtt" "gopkg.in/yaml.v3" - "os" ) const ( @@ -82,16 +84,16 @@ func parse(buf []byte) (*Config, error) { } type Config struct { - StorageWay uint `yaml:"storage-way"` - StoragePath string `yaml:"storage-path"` - BridgeWay uint `yaml:"bridge-way"` - BridgePath string `yaml:"bridge-path"` - Auth auth `yaml:"auth"` - Mqtt mqtt `yaml:"mqtt"` - Cluster Cluster `yaml:"cluster"` - Redis redis `yaml:"redis"` - Log Log `yaml:"log"` - PprofEnable bool `yaml:"pprof-enable"` + StorageWay uint `yaml:"storage-way"` + StoragePath string `yaml:"storage-path"` + BridgeWay uint `yaml:"bridge-way"` + BridgePath string `yaml:"bridge-path"` + Auth auth `yaml:"auth"` + Mqtt mqtt `yaml:"mqtt"` + Cluster Cluster `yaml:"cluster"` + Redis redis `yaml:"redis"` + Log log.Options `yaml:"log"` + PprofEnable bool `yaml:"pprof-enable"` } type auth struct { @@ -185,60 +187,6 @@ func GenTlsConfig(conf *Config) (*tls2.Config, error) { return tlsConfig, nil } -type Log struct { - // Enable Log enabled or not - Enable bool `json:"enable" yaml:"enable"` - - // Env app running environment,0 development or 1 production - Env int `json:"env" yaml:"env"` - - // NodeName used in a cluster environment to distinguish nodes - NodeName string `json:"node-name" yaml:"node-name"` - - // Format output format 0 console or 1 json - Format int `json:"format" yaml:"format"` - - // Whether to display code line number - Caller bool `json:"caller" yaml:"caller"` - - // Filename is the file to write logs to. Backup log files will be retained - // in the same directory. It uses -lumberjack.log in - // os.TempDir() if empty. - InfoFile string `json:"info-file" yaml:"info-file"` - ErrorFile string `json:"error-file" yaml:"error-file"` - ThirdpartyFile string `json:"thirdparty-file" yaml:"thirdparty-file"` - - // MaxSize is the maximum size in megabytes of the log file before it gets - // rotated. It defaults to 100 megabytes. - MaxSize int `json:"maxsize" yaml:"maxsize"` - - // MaxAge is the maximum number of days to retain old log files based on the - // timestamp encoded in their filename. Note that a day is defined as 24 - // hours and may not exactly correspond to calendar days due to daylight - // savings, leap seconds, etc. The default is not to remove old log files - // based on age. - MaxAge int `json:"max-age" yaml:"max-age"` - - // MaxBackups is the maximum number of old log files to retain. The default - // is to retain all old log files (though MaxAge may still cause them to get - // deleted.) - MaxBackups int `json:"max-backups" yaml:"max-backups"` - - // LocalTime determines if the time used for formatting the timestamps in - // backup files is the computer's local time. The default is to use UTC - // time. - Localtime bool `json:"localtime" yaml:"localtime"` - - // Compress determines if the rotated log files should be compressed - // using gzip. The default is not to perform compression. - Compress bool `json:"compress" yaml:"compress"` - // contains filtered or unexported fields - - // Log Level: -1Trace 0Debug 1Info 2Warn 3Error(default) 4Fatal 5Panic 6NoLevel 7Off - Level int `json:"level" yaml:"level"` - Sampler -} - type Sampler struct { Burst int `json:"burst" yaml:"burst"` Period int `json:"period" yaml:"period"` diff --git a/config/config_test.go b/config/config_test.go index 3e3c0ac..f352150 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -6,8 +6,9 @@ package config import ( "fmt" - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) var buf = []byte(` @@ -72,7 +73,6 @@ func TestLoadConfigFromFile(t *testing.T) { require.Equal(t, 7946, cfg.Cluster.BindPort) require.Equal(t, "127.0.0.1:6379", cfg.Redis.Options.Addr) require.Equal(t, 10240, cfg.Cluster.QueueDepth) - require.Equal(t, 3, cfg.Log.Sampler.Burst) fmt.Println(cfg) } @@ -84,5 +84,4 @@ func TestParse(t *testing.T) { require.Equal(t, 7946, cfg.Cluster.BindPort) require.Equal(t, "127.0.0.1:6379", cfg.Redis.Options.Addr) require.Equal(t, 10240, cfg.Cluster.QueueDepth) - require.Equal(t, 3, cfg.Log.Sampler.Burst) } diff --git a/go.mod b/go.mod index a9ef393..ca794d0 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/wind-c/comqtt/v2 -go 1.20 +go 1.21 require ( github.com/alicebob/miniredis/v2 v2.23.0 @@ -22,7 +22,6 @@ require ( github.com/lib/pq v1.2.0 github.com/panjf2000/ants/v2 v2.7.1 github.com/rs/xid v1.4.0 - github.com/rs/zerolog v1.28.0 github.com/satori/go.uuid v1.2.0 github.com/segmentio/kafka-go v0.4.38 github.com/stretchr/testify v1.8.2 diff --git a/go.sum b/go.sum index a9f3853..e060b80 100644 --- a/go.sum +++ b/go.sum @@ -90,7 +90,6 @@ github.com/coreos/go-semver v0.3.0 h1:wkHLiw0WNATZnSG7epLsujiMCgPAc9xhjJ4tgnAxmf github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= @@ -122,6 +121,7 @@ github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYF github.com/form3tech-oss/jwt-go v3.2.3+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/getsentry/raven-go v0.2.0 h1:no+xWJRb5ZI7eE8TWgIq1jLulQiIoLG0IfYxv5JYMGs= github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= @@ -179,6 +179,7 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= @@ -330,9 +331,12 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRW github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4= github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= +github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/panjf2000/ants/v2 v2.7.1 h1:qBy5lfSdbxvrR0yUnZfaEDjf0FlCw4ufsbcsxmE7r+M= github.com/panjf2000/ants/v2 v2.7.1/go.mod h1:KIBmYG9QQX5U2qzFP/yQJaq/nSb6rahS9iEHkrCMgM8= @@ -390,8 +394,6 @@ github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZV github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY= github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY= -github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= @@ -689,6 +691,7 @@ google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7 google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= +google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= @@ -743,6 +746,7 @@ gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXL gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/mqtt/clients.go b/mqtt/clients.go index 39ddbd4..3497df5 100644 --- a/mqtt/clients.go +++ b/mqtt/clients.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co, wind package mqtt @@ -104,13 +104,13 @@ func (cl *Clients) GetByListener(id string) []*Client { // Client contains information about a client known by the broker. type Client struct { - ops *ops // ops provides a reference to server ops. + Properties ClientProperties // client properties State ClientState // the operational state of the client. + Net ClientConnection // network connection state of the client ID string // the client id. - Net ClientConnection // network connection state of the clinet - Properties ClientProperties // client properties - InheritWay int // session inheritance way + ops *ops // ops provides a reference to server ops. sync.RWMutex // mutex + InheritWay int // session inheritance way } // ClientConnection contains the connection transport and metadata for the client. @@ -119,42 +119,42 @@ type ClientConnection struct { bconn *bufio.ReadWriter // a buffered net.Conn for reading packets Remote string // the remote address of the client Listener string // listener id of the client - Inline bool // client is an inline programmetic client + Inline bool // if true, the client is the built-in 'inline' embedded client } // ClientProperties contains the properties which define the client behaviour. type ClientProperties struct { - Username []byte - Will Will Props packets.Properties + Will Will + Username []byte ProtocolVersion byte Clean bool } // Will contains the last will and testament details for a client connection. type Will struct { - TopicName string // - Payload []byte // - User []packets.UserProperty // - + TopicName string // - Flag uint32 // 0,1 WillDelayInterval uint32 // - Qos byte // - Retain bool // - } -// State tracks the state of the client. +// ClientState tracks the state of the client. type ClientState struct { TopicAliases TopicAliases // a map of topic aliases stopCause atomic.Value // reason for stopping - open context.Context // indicate that the client is open for packet exchange - Subscriptions *Subscriptions // a map of the subscription filters a client maintains - outbound chan *packets.Packet // queue for pending outbound packets Inflight *Inflight // a map of in-flight qos messages - cancelOpen context.CancelFunc // cancel function for open context + Subscriptions *Subscriptions // a map of the subscription filters a client maintains disconnected int64 // the time the client disconnected in unix time, for calculating expiry + outbound chan *packets.Packet // queue for pending outbound packets endOnce sync.Once // only end once isTakenOver uint32 // used to identify orphaned clients packetID uint32 // the current highest packetID + open context.Context // indicate that the client is open for packet exchange + cancelOpen context.CancelFunc // cancel function for open context outboundQty int32 // number of messages currently in the outbound queue Keepalive uint16 // the number of seconds the connection can wait ServerKeepalive bool // keepalive was set by the server @@ -200,7 +200,8 @@ func (cl *Client) WriteLoop() { select { case pk := <-cl.State.outbound: if err := cl.WritePacket(*pk); err != nil { - cl.ops.log.Debug().Err(err).Str("client", cl.ID).Interface("packet", pk).Msg("failed publishing packet") + // TODO : Figure out what to do with error + cl.ops.log.Debug("failed publishing packet", "error", err, "client", cl.ID, "packet", pk) } atomic.AddInt32(&cl.State.outboundQty, -1) case <-cl.State.open.Done(): @@ -318,7 +319,7 @@ func (cl *Client) ResendInflightMessages(force bool) error { return nil } -// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session. +// ClearInflights deletes all inflight messages for the client, e.g. for a disconnected user with a clean session. func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 { deleted := []uint16{} for _, tk := range cl.State.Inflight.GetAll(false) { diff --git a/mqtt/clients_test.go b/mqtt/clients_test.go index a6748f1..957a9ba 100644 --- a/mqtt/clients_test.go +++ b/mqtt/clients_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package mqtt @@ -29,7 +29,7 @@ func newTestClient() (cl *Client, r net.Conn, w net.Conn) { cl = newClient(w, &ops{ info: new(system.Info), hooks: new(Hooks), - log: &logger, + log: logger, options: &Options{ Capabilities: &Capabilities{ ReceiveMaximum: 10, @@ -263,7 +263,7 @@ func TestClientNextPacketIDOverflow(t *testing.T) { cl.State.Inflight.internal[uint16(i)] = packets.Packet{} } - cl.State.packetID = uint32(cl.ops.options.Capabilities.maximumPacketID - 1) + cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID - 1 i, err := cl.NextPacketID() require.NoError(t, err) require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i) @@ -303,7 +303,7 @@ func TestClientResendInflightMessages(t *testing.T) { err := cl.ResendInflightMessages(true) require.NoError(t, err) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -315,7 +315,7 @@ func TestClientResendInflightMessages(t *testing.T) { func TestClientResendInflightMessagesWriteFailure(t *testing.T) { pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup) cl, r, _ := newTestClient() - r.Close() + _ = r.Close() cl.State.Inflight.Set(*pk1.Packet) require.Equal(t, 1, cl.State.Inflight.Len()) @@ -342,8 +342,8 @@ func TestClientReadFixedHeader(t *testing.T) { defer cl.Stop(errClientStop) go func() { - r.Write([]byte{packets.Connect << 4, 0x00}) - r.Close() + _, _ = r.Write([]byte{packets.Connect << 4, 0x00}) + _ = r.Close() }() fh := new(packets.FixedHeader) @@ -357,8 +357,8 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) { defer cl.Stop(errClientStop) go func() { - r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00}) - r.Close() + _, _ = r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00}) + _ = r.Close() }() fh := new(packets.FixedHeader) @@ -372,8 +372,8 @@ func TestClientReadFixedHeaderPacketOversized(t *testing.T) { defer cl.Stop(errClientStop) go func() { - r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes) - r.Close() + _, _ = r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes) + _ = r.Close() }() fh := new(packets.FixedHeader) @@ -387,7 +387,7 @@ func TestClientReadFixedHeaderReadEOF(t *testing.T) { defer cl.Stop(errClientStop) go func() { - r.Close() + _ = r.Close() }() fh := new(packets.FixedHeader) @@ -401,8 +401,8 @@ func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) { defer cl.Stop(errClientStop) go func() { - r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}) - r.Close() + _, _ = r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}) + _ = r.Close() }() fh := new(packets.FixedHeader) @@ -414,7 +414,7 @@ func TestClientReadOK(t *testing.T) { cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { - r.Write([]byte{ + _, _ = r.Write([]byte{ packets.Publish << 4, 18, // Fixed header 0, 5, // Topic Name - LSB+MSB 'a', '/', 'b', '/', 'c', // Topic Name @@ -424,7 +424,7 @@ func TestClientReadOK(t *testing.T) { 'd', '/', 'e', '/', 'f', // Topic Name 'y', 'e', 'a', 'h', // Payload }) - r.Close() + _ = r.Close() }() var pks []packets.Packet @@ -499,10 +499,10 @@ func TestClientReadFixedHeaderError(t *testing.T) { cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { - r.Write([]byte{ + _, _ = r.Write([]byte{ packets.Publish << 4, 11, // Fixed header }) - r.Close() + _ = r.Close() }() cl.Net.bconn = nil @@ -516,13 +516,13 @@ func TestClientReadReadHandlerErr(t *testing.T) { cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { - r.Write([]byte{ + _, _ = r.Write([]byte{ packets.Publish << 4, 11, // Fixed header 0, 5, // Topic Name - LSB+MSB 'd', '/', 'e', '/', 'f', // Topic Name 'y', 'e', 'a', 'h', // Payload }) - r.Close() + _ = r.Close() }() err := cl.Read(func(cl *Client, pk packets.Packet) error { @@ -536,13 +536,13 @@ func TestClientReadReadPacketOK(t *testing.T) { cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { - r.Write([]byte{ + _, _ = r.Write([]byte{ packets.Publish << 4, 11, // Fixed header 0, 5, 'd', '/', 'e', '/', 'f', 'y', 'e', 'a', 'h', }) - r.Close() + _ = r.Close() }() fh := new(packets.FixedHeader) @@ -573,7 +573,7 @@ func TestClientReadPacket(t *testing.T) { t.Run(tt.Desc, func(t *testing.T) { atomic.StoreInt64(&cl.ops.info.PacketsReceived, 0) go func() { - r.Write(tt.RawBytes) + _, _ = r.Write(tt.RawBytes) }() fh := new(packets.FixedHeader) @@ -600,7 +600,7 @@ func TestClientReadPacket(t *testing.T) { func TestClientReadPacketInvalidTypeError(t *testing.T) { cl, _, _ := newTestClient() - cl.Net.Conn.Close() + _ = cl.Net.Conn.Close() _, err := cl.ReadPacket(&packets.FixedHeader{}) require.Error(t, err) require.Contains(t, err.Error(), "invalid packet type") @@ -624,7 +624,7 @@ func TestClientWritePacket(t *testing.T) { require.NoError(t, err, pkInfo, tt.Case, tt.Desc) time.Sleep(2 * time.Millisecond) - cl.Net.Conn.Close() + _ = cl.Net.Conn.Close() require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc) @@ -660,13 +660,13 @@ func TestClientReadPacketReadingError(t *testing.T) { cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { - r.Write([]byte{ + _, _ = r.Write([]byte{ 0, 11, // Fixed header 0, 5, 'd', '/', 'e', '/', 'f', 'y', 'e', 'a', 'h', }) - r.Close() + _ = r.Close() }() _, err := cl.ReadPacket(&packets.FixedHeader{ @@ -680,13 +680,13 @@ func TestClientReadPacketReadUnknown(t *testing.T) { cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { - r.Write([]byte{ + _, _ = r.Write([]byte{ 0, 11, // Fixed header 0, 5, 'd', '/', 'e', '/', 'f', 'y', 'e', 'a', 'h', }) - r.Close() + _ = r.Close() }() _, err := cl.ReadPacket(&packets.FixedHeader{ @@ -706,7 +706,7 @@ func TestClientWritePacketWriteNoConn(t *testing.T) { func TestClientWritePacketWriteError(t *testing.T) { cl, _, _ := newTestClient() - cl.Net.Conn.Close() + _ = cl.Net.Conn.Close() err := cl.WritePacket(*pkTable[1].Packet) require.Error(t, err) diff --git a/mqtt/cmd/main.go b/mqtt/cmd/main.go index 6435940..050fcc4 100644 --- a/mqtt/cmd/main.go +++ b/mqtt/cmd/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -68,7 +68,8 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") + } diff --git a/mqtt/examples/auth/basic/main.go b/mqtt/examples/auth/basic/main.go index 1946131..84be23f 100644 --- a/mqtt/examples/auth/basic/main.go +++ b/mqtt/examples/auth/basic/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -77,7 +77,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/examples/auth/encoded/main.go b/mqtt/examples/auth/encoded/main.go index 5a708cc..8d0095f 100644 --- a/mqtt/examples/auth/encoded/main.go +++ b/mqtt/examples/auth/encoded/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -59,7 +59,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/examples/benchmark/main.go b/mqtt/examples/benchmark/main.go index b30fe45..b1accd2 100644 --- a/mqtt/examples/benchmark/main.go +++ b/mqtt/examples/benchmark/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -45,7 +45,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/examples/debug/main.go b/mqtt/examples/debug/main.go index 798d1bc..6c05f58 100644 --- a/mqtt/examples/debug/main.go +++ b/mqtt/examples/debug/main.go @@ -1,16 +1,16 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main import ( "log" + "log/slog" "os" "os/signal" "syscall" - "github.com/rs/zerolog" "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/hooks/auth" "github.com/wind-c/comqtt/v2/mqtt/hooks/debug" @@ -27,8 +27,12 @@ func main() { }() server := mqtt.New(nil) - l := server.Log.Level(zerolog.DebugLevel) - server.Log = &l + + level := new(slog.LevelVar) + server.Log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: level, + })) + level.Set(slog.LevelDebug) err := server.AddHook(new(debug.Hook), &debug.Options{ // ShowPacketData: true, @@ -56,7 +60,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/examples/direct/main.go b/mqtt/examples/direct/main.go new file mode 100644 index 0000000..885a338 --- /dev/null +++ b/mqtt/examples/direct/main.go @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package main + +import ( + "log" + "os" + "os/signal" + "syscall" + "time" + + "github.com/wind-c/comqtt/v2/mqtt/hooks/auth" + + mqtt "github.com/wind-c/comqtt/v2/mqtt" + "github.com/wind-c/comqtt/v2/mqtt/packets" +) + +func main() { + sigs := make(chan os.Signal, 1) + done := make(chan bool, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigs + done <- true + }() + + server := mqtt.New(&mqtt.Options{ + InlineClient: true, // you must enable inline client to use direct publishing and subscribing. + }) + _ = server.AddHook(new(auth.AllowHook), nil) + + // Start the server + go func() { + err := server.Serve() + if err != nil { + log.Fatal(err) + } + }() + + // Demonstration of using an inline client to directly subscribe to a topic and receive a message when + // that subscription is activated. The inline subscription method uses the same internal subscription logic + // as used for external (normal) clients. + go func() { + // Inline subscriptions can also receive retained messages on subscription. + _ = server.Publish("direct/retained", []byte("retained message"), true, 0) + _ = server.Publish("direct/alternate/retained", []byte("some other retained message"), true, 0) + + // Subscribe to a filter and handle any received messages via a callback function. + callbackFn := func(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) { + server.Log.Info("inline client received message from subscription", "client", cl.ID, "subscriptionId", sub.Identifier, "topic", pk.TopicName, "payload", string(pk.Payload)) + } + server.Log.Info("inline client subscribing") + _ = server.Subscribe("direct/#", 1, callbackFn) + _ = server.Subscribe("direct/#", 2, callbackFn) + }() + + // There is a shorthand convenience function, Publish, for easily sending publish packets if you are not + // concerned with creating your own packets. If you want to have more control over your packets, you can + //directly inject a packet of any kind into the broker. See examples/hooks/main.go for usage. + go func() { + for range time.Tick(time.Second * 3) { + err := server.Publish("direct/publish", []byte("scheduled message"), false, 0) + if err != nil { + server.Log.Error("server.Publish", "error", err) + } + server.Log.Info("main.go issued direct message to direct/publish") + } + }() + + go func() { + time.Sleep(time.Second * 10) + // Unsubscribe from the same filter to stop receiving messages. + server.Log.Info("inline client unsubscribing") + _ = server.Unsubscribe("direct/#", 1) + }() + + <-done + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") +} diff --git a/mqtt/examples/hooks/main.go b/mqtt/examples/hooks/main.go index f4be075..f5835d7 100644 --- a/mqtt/examples/hooks/main.go +++ b/mqtt/examples/hooks/main.go @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main import ( "bytes" + "fmt" "log" "os" "os/signal" @@ -62,9 +63,9 @@ func main() { Payload: []byte("injected scheduled message"), }) if err != nil { - server.Log.Error().Err(err).Msg("server.InjectPacket") + server.Log.Error("server.InjectPacket", "error", err) } - server.Log.Info().Msgf("main.go injected packet to direct/publish") + server.Log.Info("main.go injected packet to direct/publish") } }() @@ -74,16 +75,16 @@ func main() { for range time.Tick(time.Second * 5) { err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0) if err != nil { - server.Log.Error().Err(err).Msg("server.Publish") + server.Log.Error("server.Publish", "error", err) } - server.Log.Info().Msgf("main.go issued direct message to direct/publish") + server.Log.Info("main.go issued direct message to direct/publish") } }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } type ExampleHook struct { @@ -106,39 +107,44 @@ func (h *ExampleHook) Provides(b byte) bool { } func (h *ExampleHook) Init(config any) error { - h.Log.Info().Msg("initialised") + h.Log.Info("initialised") return nil } func (h *ExampleHook) OnConnect(cl *mqtt.Client, pk packets.Packet) error { - h.Log.Info().Str("client", cl.ID).Msgf("client connected") + h.Log.Info("client connected", "client", cl.ID) return nil } func (h *ExampleHook) OnDisconnect(cl *mqtt.Client, err error, expire bool) { - h.Log.Info().Str("client", cl.ID).Bool("expire", expire).Err(err).Msg("client disconnected") + if err != nil { + h.Log.Info("client disconnected", "client", cl.ID, "expire", expire, "error", err) + } else { + h.Log.Info("client disconnected", "client", cl.ID, "expire", expire) + } + } func (h *ExampleHook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte, counts []int) { - h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msgf("subscribed qos=%v", reasonCodes) + h.Log.Info(fmt.Sprintf("subscribed qos=%v", reasonCodes), "client", cl.ID, "filters", pk.Filters) } func (h *ExampleHook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte, counts []int) { - h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msg("unsubscribed") + h.Log.Info("unsubscribed", "client", cl.ID, "filters", pk.Filters) } func (h *ExampleHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) { - h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("received from client") + h.Log.Info("received from client", "client", cl.ID, "payload", string(pk.Payload)) pkx := pk if string(pk.Payload) == "hello" { pkx.Payload = []byte("hello world") - h.Log.Info().Str("client", cl.ID).Str("payload", string(pkx.Payload)).Msg("received modified packet from client") + h.Log.Info("received modified packet from client", "client", cl.ID, "payload", string(pkx.Payload)) } return pkx, nil } func (h *ExampleHook) OnPublished(cl *mqtt.Client, pk packets.Packet) { - h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("published to client") + h.Log.Info("published to client", "client", cl.ID, "payload", string(pk.Payload)) } diff --git a/mqtt/examples/paho.testing/main.go b/mqtt/examples/paho.testing/main.go index 03694f4..66ce193 100644 --- a/mqtt/examples/paho.testing/main.go +++ b/mqtt/examples/paho.testing/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -45,9 +45,9 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } type pahoAuthHook struct { diff --git a/mqtt/examples/persistence/badger/main.go b/mqtt/examples/persistence/badger/main.go index a27fd09..c3a6e38 100644 --- a/mqtt/examples/persistence/badger/main.go +++ b/mqtt/examples/persistence/badger/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -52,8 +52,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") - + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/examples/persistence/bolt/main.go b/mqtt/examples/persistence/bolt/main.go index 2b2c973..685d59a 100644 --- a/mqtt/examples/persistence/bolt/main.go +++ b/mqtt/examples/persistence/bolt/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -30,12 +30,15 @@ func main() { server := mqtt.New(nil) _ = server.AddHook(new(auth.AllowHook), nil) - err := server.AddHook(new(bolt.Hook), bolt.Options{ + err := server.AddHook(new(bolt.Hook), &bolt.Options{ Path: "bolt.db", Options: &bbolt.Options{ Timeout: 500 * time.Millisecond, }, }) + if err != nil { + log.Fatal(err) + } tcp := listeners.NewTCP("t1", ":1883", nil) err = server.AddListener(tcp) @@ -51,7 +54,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/examples/persistence/redis/main.go b/mqtt/examples/persistence/redis/main.go index 4d0e14b..53d8926 100644 --- a/mqtt/examples/persistence/redis/main.go +++ b/mqtt/examples/persistence/redis/main.go @@ -1,16 +1,16 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main import ( "log" + "log/slog" "os" "os/signal" "syscall" - "github.com/rs/zerolog" "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/hooks/auth" "github.com/wind-c/comqtt/v2/mqtt/hooks/storage/redis" @@ -30,8 +30,12 @@ func main() { server := mqtt.New(nil) _ = server.AddHook(new(auth.AllowHook), nil) - l := server.Log.Level(zerolog.DebugLevel) - server.Log = &l + + level := new(slog.LevelVar) + server.Log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: level, + })) + level.Set(slog.LevelDebug) err := server.AddHook(new(redis.Hook), &redis.Options{ Options: &rv8.Options{ @@ -58,8 +62,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") - + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/examples/tcp/main.go b/mqtt/examples/tcp/main.go index 1060b95..8a51a70 100644 --- a/mqtt/examples/tcp/main.go +++ b/mqtt/examples/tcp/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -52,7 +52,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/examples/tls/main.go b/mqtt/examples/tls/main.go index 52a4b62..837158e 100644 --- a/mqtt/examples/tls/main.go +++ b/mqtt/examples/tls/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -111,7 +111,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/examples/websocket/main.go b/mqtt/examples/websocket/main.go index 1a48d1e..1757bcc 100644 --- a/mqtt/examples/websocket/main.go +++ b/mqtt/examples/websocket/main.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package main @@ -41,7 +41,7 @@ func main() { }() <-done - server.Log.Warn().Msg("caught signal, stopping...") - server.Close() - server.Log.Info().Msg("main.go finished") + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") } diff --git a/mqtt/hooks.go b/mqtt/hooks.go index e6f8a8b..9746675 100644 --- a/mqtt/hooks.go +++ b/mqtt/hooks.go @@ -1,20 +1,19 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co -// SPDX-FileContributor: mochi-co, wind +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co, wind, thedevop, dgduncan package mqtt import ( "errors" "fmt" + "log/slog" "sync" "sync/atomic" "github.com/wind-c/comqtt/v2/mqtt/hooks/storage" "github.com/wind-c/comqtt/v2/mqtt/packets" "github.com/wind-c/comqtt/v2/mqtt/system" - - "github.com/rs/zerolog" ) const ( @@ -74,7 +73,7 @@ type Hook interface { Provides(b byte) bool Init(config any) error Stop() error - SetOpts(l *zerolog.Logger, o *HookOptions) + SetOpts(l *slog.Logger, o *HookOptions) OnStarted() OnStopped() OnConnectAuthenticate(cl *Client, pk packets.Packet) bool @@ -125,11 +124,11 @@ type HookOptions struct { // Hooks is a slice of Hook interfaces to be called in sequence. type Hooks struct { - Log *zerolog.Logger // a logger for the hook (from the server) - internal atomic.Value // a slice of []Hook - wg sync.WaitGroup // a waitgroup for syncing hook shutdown - qty int64 // the number of hooks in use - sync.Mutex // a mutex for locking when adding hooks + Log *slog.Logger // a logger for the hook (from the server) + internal atomic.Value // a slice of []Hook + wg sync.WaitGroup // a waitgroup for syncing hook shutdown + qty int64 // the number of hooks in use + sync.Mutex // a mutex for locking when adding hooks } // Len returns the number of hooks added. @@ -187,9 +186,9 @@ func (h *Hooks) GetAll() []Hook { func (h *Hooks) Stop() { go func() { for _, hook := range h.GetAll() { - h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook") + h.Log.Info("stopping hook", "hook", hook.ID()) if err := hook.Stop(); err != nil { - h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook") + h.Log.Debug("problem stopping hook", "error", err, "hook", hook.ID()) } h.wg.Done() @@ -274,7 +273,7 @@ func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, if hook.Provides(OnPacketRead) { npk, err := hook.OnPacketRead(cl, pkx) if err != nil && errors.Is(err, packets.ErrRejectPacket) { - h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("packet rejected") + h.Log.Debug("packet rejected", "hook", hook.ID(), "packet", pkx) return pk, err } else if err != nil { continue @@ -402,10 +401,16 @@ func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, er npk, err := hook.OnPublish(cl, pkx) if err != nil { if errors.Is(err, packets.ErrRejectPacket) { - h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet rejected") + h.Log.Debug("publish packet rejected", + "error", err, + "hook", hook.ID(), + "packet", pkx) return pk, err } - h.Log.Error().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet error") + h.Log.Error("publish packet error", + "error", err, + "hook", hook.ID(), + "packet", pkx) return pk, err } pkx = npk @@ -504,7 +509,10 @@ func (h *Hooks) OnWill(cl *Client, will Will) Will { if hook.Provides(OnWill) { mlwt, err := hook.OnWill(cl, will) if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Interface("will", will).Msg("parse will error") + h.Log.Error("parse will error", + "error", err, + "hook", hook.ID(), + "will", will) continue } will = mlwt @@ -548,7 +556,7 @@ func (h *Hooks) StoredClients() (v []storage.Client, err error) { if hook.Provides(StoredClients) { v, err := hook.StoredClients() if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load clients") + h.Log.Error("failed to load clients", "error", err, "hook", hook.ID()) return v, err } @@ -568,7 +576,7 @@ func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) { if hook.Provides(StoredSubscriptions) { v, err := hook.StoredSubscriptions() if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load subscriptions") + h.Log.Error("failed to load subscriptions", "error", err, "hook", hook.ID()) return v, err } @@ -588,7 +596,7 @@ func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) { if hook.Provides(StoredInflightMessages) { v, err := hook.StoredInflightMessages() if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load inflight messages") + h.Log.Error("failed to load inflight messages", "error", err, "hook", hook.ID()) return v, err } @@ -608,7 +616,7 @@ func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) { if hook.Provides(StoredRetainedMessages) { v, err := hook.StoredRetainedMessages() if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load retained messages") + h.Log.Error("failed to load retained messages", "error", err, "hook", hook.ID()) return v, err } @@ -627,7 +635,7 @@ func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) { if hook.Provides(StoredSysInfo) { v, err := hook.StoredSysInfo() if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load $SYS info") + h.Log.Error("failed to load $SYS info", "error", err, "hook", hook.ID()) return v, err } @@ -646,7 +654,7 @@ func (h *Hooks) StoredClientByCid(cid string) (v storage.Client, err error) { if hook.Provides(StoredClientByCid) { v, err := hook.StoredClientByCid(cid) if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load clients") + h.Log.Error("failed to load clients", "error", err, "hook", hook.ID()) return v, err } @@ -665,7 +673,7 @@ func (h *Hooks) StoredSubscriptionsByCid(cid string) (v []storage.Subscription, if hook.Provides(StoredSubscriptionsByCid) { v, err := hook.StoredSubscriptionsByCid(cid) if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to get subscriptions") + h.Log.Error("failed to get subscriptions", "error", err, "hook", hook.ID()) return v, err } @@ -684,7 +692,7 @@ func (h *Hooks) StoredInflightMessagesByCid(cid string) (v []storage.Message, er if hook.Provides(StoredInflightMessagesByCid) { v, err := hook.StoredInflightMessagesByCid(cid) if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to get inflight messages") + h.Log.Error("failed to get inflight messages", "error", err, "hook", hook.ID()) return v, err } @@ -703,7 +711,7 @@ func (h *Hooks) StoredRetainedMessageByTopic(topic string) (v storage.Message, e if hook.Provides(StoredRetainedMessageByTopic) { v, err := hook.StoredRetainedMessageByTopic(topic) if err != nil { - h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to get retained message") + h.Log.Error("failed to get retained message", "error", err, "hook", hook.ID()) return v, err } @@ -752,7 +760,7 @@ func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool { // all hooks. type HookBase struct { Hook - Log *zerolog.Logger + Log *slog.Logger Opts *HookOptions } @@ -775,12 +783,12 @@ func (h *HookBase) Init(config any) error { // SetOpts is called by the server to propagate internal values and generally should // not be called manually. -func (h *HookBase) SetOpts(l *zerolog.Logger, opts *HookOptions) { +func (h *HookBase) SetOpts(l *slog.Logger, opts *HookOptions) { h.Log = l h.Opts = opts } -// Stop is called to gracefully shutdown the hook. +// Stop is called to gracefully shut down the hook. func (h *HookBase) Stop() error { return nil } diff --git a/mqtt/hooks/auth/allow_all.go b/mqtt/hooks/auth/allow_all.go index dd059ce..1e8e71c 100644 --- a/mqtt/hooks/auth/allow_all.go +++ b/mqtt/hooks/auth/allow_all.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package auth diff --git a/mqtt/hooks/auth/allow_all_test.go b/mqtt/hooks/auth/allow_all_test.go index 4815365..90fd48e 100644 --- a/mqtt/hooks/auth/allow_all_test.go +++ b/mqtt/hooks/auth/allow_all_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package auth diff --git a/mqtt/hooks/auth/auth.go b/mqtt/hooks/auth/auth.go index a093d2b..c4eae8d 100644 --- a/mqtt/hooks/auth/auth.go +++ b/mqtt/hooks/auth/auth.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package auth @@ -67,10 +67,9 @@ func (h *Hook) Init(config any) error { } } - h.Log.Info(). - Int("authentication", len(h.ledger.Auth)). - Int("acl", len(h.ledger.ACL)). - Msg("loaded auth rules") + h.Log.Info("loaded auth rules", + "authentication", len(h.ledger.Auth), + "acl", len(h.ledger.ACL)) return nil } @@ -82,11 +81,9 @@ func (h *Hook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool { return true } - h.Log.Info(). - Str("username", string(pk.Connect.Username)). - Str("remote", cl.Net.Remote). - Msg("client failed authentication check") - + h.Log.Info("client failed authentication check", + "username", string(pk.Connect.Username), + "remote", cl.Net.Remote) return false } @@ -97,11 +94,10 @@ func (h *Hook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool { return true } - h.Log.Debug(). - Str("client", cl.ID). - Str("username", string(cl.Properties.Username)). - Str("topic", topic). - Msg("client failed allowed ACL check") + h.Log.Debug("client failed allowed ACL check", + "client", cl.ID, + "username", string(cl.Properties.Username), + "topic", topic) return false } diff --git a/mqtt/hooks/auth/auth_test.go b/mqtt/hooks/auth/auth_test.go index 60059e5..ee6bd09 100644 --- a/mqtt/hooks/auth/auth_test.go +++ b/mqtt/hooks/auth/auth_test.go @@ -1,20 +1,19 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package auth import ( + "log/slog" "os" "testing" - - "github.com/rs/zerolog" "github.com/stretchr/testify/require" "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/packets" ) -var logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) +var logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) // func teardown(t *testing.T, path string, h *Hook) { // h.Stop() @@ -34,7 +33,7 @@ func TestBasicProvides(t *testing.T) { func TestBasicInitBadConfig(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(map[string]any{}) require.Error(t, err) @@ -42,7 +41,7 @@ func TestBasicInitBadConfig(t *testing.T) { func TestBasicInitDefaultConfig(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) @@ -50,7 +49,7 @@ func TestBasicInitDefaultConfig(t *testing.T) { func TestBasicInitWithLedgerPointer(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) ln := &Ledger{ Auth: []AuthRule{ @@ -79,7 +78,7 @@ func TestBasicInitWithLedgerPointer(t *testing.T) { func TestBasicInitWithLedgerJSON(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) require.Nil(t, h.ledger) err := h.Init(&Options{ @@ -93,7 +92,7 @@ func TestBasicInitWithLedgerJSON(t *testing.T) { func TestBasicInitWithLedgerYAML(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) require.Nil(t, h.ledger) err := h.Init(&Options{ @@ -107,7 +106,7 @@ func TestBasicInitWithLedgerYAML(t *testing.T) { func TestBasicInitWithLedgerBadDAta(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) require.Nil(t, h.ledger) err := h.Init(&Options{ @@ -119,7 +118,7 @@ func TestBasicInitWithLedgerBadDAta(t *testing.T) { func TestOnConnectAuthenticate(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) ln := new(Ledger) ln.Auth = checkLedger.Auth @@ -158,7 +157,7 @@ func TestOnConnectAuthenticate(t *testing.T) { func TestOnACL(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) ln := new(Ledger) ln.Auth = checkLedger.Auth diff --git a/mqtt/hooks/auth/ledger.go b/mqtt/hooks/auth/ledger.go index e47d4ce..0a85490 100644 --- a/mqtt/hooks/auth/ledger.go +++ b/mqtt/hooks/auth/ledger.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package auth @@ -124,8 +124,8 @@ func (r RString) Matches(a string) bool { } // FilterMatches returns true if a filter matches a topic rule. -func (f RString) FilterMatches(a string) bool { - _, ok := MatchTopic(string(f), a) +func (r RString) FilterMatches(a string) bool { + _, ok := MatchTopic(string(r), a) return ok } @@ -205,7 +205,7 @@ func (l *Ledger) AuthOk(cl *mqtt.Client, pk packets.Packet) (n int, ok bool) { } // ACLOk returns true if the rules indicate the user is allowed to read or write to -// a specific filter or topic respectively, based on the write bool. +// a specific filter or topic respectively, based on the `write` bool. func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok bool) { // If the users map is set, always check for a predefined user first instead // of iterating through global rules. @@ -233,17 +233,31 @@ func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok boo return n, true } - for filter, access := range rule.Filters { - if filter.FilterMatches(topic) { - if !write && (access == ReadOnly || access == ReadWrite) { - return n, true - } else if write && (access == WriteOnly || access == ReadWrite) { - return n, true - } else { - return n, false + if write { + for filter, access := range rule.Filters { + if access == WriteOnly || access == ReadWrite { + if filter.FilterMatches(topic) { + return n, true + } } } } + + if !write { + for filter, access := range rule.Filters { + if access == ReadOnly || access == ReadWrite { + if filter.FilterMatches(topic) { + return n, true + } + } + } + } + + for filter := range rule.Filters { + if filter.FilterMatches(topic) { + return n, false + } + } } } diff --git a/mqtt/hooks/auth/ledger_test.go b/mqtt/hooks/auth/ledger_test.go index 6ac8ac9..18004d1 100644 --- a/mqtt/hooks/auth/ledger_test.go +++ b/mqtt/hooks/auth/ledger_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package auth @@ -561,17 +561,17 @@ func TestLedgerUpdate(t *testing.T) { }, } - new := &Ledger{ + n := &Ledger{ Auth: AuthRules{ {Remote: "127.0.0.1", Allow: true}, {Remote: "192.168.*", Allow: true}, }, } - old.Update(new) + old.Update(n) require.Len(t, old.Auth, 2) require.Equal(t, RString("192.168.*"), old.Auth[1].Remote) - require.NotSame(t, new, old) + require.NotSame(t, n, old) } func TestLedgerToJSON(t *testing.T) { diff --git a/mqtt/hooks/debug/debug.go b/mqtt/hooks/debug/debug.go index f8c5f86..03921ea 100644 --- a/mqtt/hooks/debug/debug.go +++ b/mqtt/hooks/debug/debug.go @@ -1,17 +1,17 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package debug import ( + "fmt" + "log/slog" "strings" "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/hooks/storage" "github.com/wind-c/comqtt/v2/mqtt/packets" - - "github.com/rs/zerolog" ) // Options contains configuration settings for the debug output. @@ -25,7 +25,7 @@ type Options struct { type Hook struct { mqtt.HookBase config *Options - Log *zerolog.Logger + Log *slog.Logger } // ID returns the ID of the hook. @@ -54,25 +54,25 @@ func (h *Hook) Init(config any) error { } // SetOpts is called when the hook receives inheritable server parameters. -func (h *Hook) SetOpts(l *zerolog.Logger, opts *mqtt.HookOptions) { +func (h *Hook) SetOpts(l *slog.Logger, opts *mqtt.HookOptions) { h.Log = l - h.Log.Debug().Interface("opts", opts).Str("method", "SetOpts").Send() + h.Log.Debug("", "method", "SetOpts") } // Stop is called when the hook is stopped. func (h *Hook) Stop() error { - h.Log.Debug().Str("method", "Stop").Send() + h.Log.Debug("", "method", "Stop") return nil } // OnStarted is called when the server starts. func (h *Hook) OnStarted() { - h.Log.Debug().Str("method", "OnStarted").Send() + h.Log.Debug("", "method", "OnStarted") } // OnStopped is called when the server stops. func (h *Hook) OnStopped() { - h.Log.Debug().Str("method", "OnStopped").Send() + h.Log.Debug("", "method", "OnStopped") } // OnPacketRead is called when a new packet is received from a client. @@ -81,8 +81,7 @@ func (h *Hook) OnPacketRead(cl *mqtt.Client, pk packets.Packet) (packets.Packet, return pk, nil } - h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID) - + h.Log.Debug(fmt.Sprintf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk)) return pk, nil } @@ -92,85 +91,72 @@ func (h *Hook) OnPacketSent(cl *mqtt.Client, pk packets.Packet, b []byte) { return } - h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID) + h.Log.Debug(fmt.Sprintf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk)) } // OnRetainMessage is called when a published message is retained (or retain deleted/modified). func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { - h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("retained message on topic") + h.Log.Debug("retained message on topic", "m", h.packetMeta(pk)) } // OnQosPublish is called when a publish packet with Qos is issued to a subscriber. func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) { - h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight out") + h.Log.Debug("inflight out", "m", h.packetMeta(pk)) } // OnQosComplete is called when the Qos flow for a message has been completed. func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { - h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight complete") + h.Log.Debug("inflight complete", "m", h.packetMeta(pk)) } // OnQosDropped is called the Qos flow for a message expires. func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { - h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight dropped") + h.Log.Debug("inflight dropped", "m", h.packetMeta(pk)) } -// OnLWTSent is called when a will message has been issued from a disconnecting client. +// OnLWTSent is called when a Will Message has been issued from a disconnecting client. func (h *Hook) OnLWTSent(cl *mqtt.Client, pk packets.Packet) { - h.Log.Debug().Str("method", "OnLWTSent").Str("client", cl.ID).Msg("sent lwt for client") + h.Log.Debug("sent lwt for client", "method", "OnLWTSent", "client", cl.ID) } // OnRetainedExpired is called when the server clears expired retained messages. func (h *Hook) OnRetainedExpired(filter string) { - h.Log.Debug().Str("method", "OnRetainedExpired").Str("topic", filter).Msg("retained message expired") + h.Log.Debug("retained message expired", "method", "OnRetainedExpired", "topic", filter) } // OnClientExpired is called when the server clears an expired client. func (h *Hook) OnClientExpired(cl *mqtt.Client) { - h.Log.Debug().Str("method", "OnClientExpired").Str("client", cl.ID).Msg("client session expired") + h.Log.Debug("client session expired", "method", "OnClientExpired", "client", cl.ID) } // StoredClients is called when the server restores clients from a store. func (h *Hook) StoredClients() (v []storage.Client, err error) { - h.Log.Debug(). - Str("method", "StoredClients"). - Send() + h.Log.Debug("", "method", "StoredClients") return v, nil } -// StoredClients is called when the server restores subscriptions from a store. +// StoredSubscriptions is called when the server restores subscriptions from a store. func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { - h.Log.Debug(). - Str("method", "StoredSubscriptions"). - Send() - + h.Log.Debug("", "method", "StoredSubscriptions") return v, nil } -// StoredClients is called when the server restores retained messages from a store. +// StoredRetainedMessages is called when the server restores retained messages from a store. func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { - h.Log.Debug(). - Str("method", "StoredRetainedMessages"). - Send() - + h.Log.Debug("", "method", "StoredRetainedMessages") return v, nil } -// StoredClients is called when the server restores inflight messages from a store. +// StoredInflightMessages is called when the server restores inflight messages from a store. func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { - h.Log.Debug(). - Str("method", "StoredInflightMessages"). - Send() - + h.Log.Debug("", "method", "StoredInflightMessages") return v, nil } -// StoredClients is called when the server restores system info from a store. +// StoredSysInfo is called when the server restores system info from a store. func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { - h.Log.Debug(). - Str("method", "StoredClients"). - Send() + h.Log.Debug("", "method", "StoredSysInfo") return v, nil } diff --git a/mqtt/hooks/storage/badger/badger.go b/mqtt/hooks/storage/badger/badger.go index 701d844..f714d62 100644 --- a/mqtt/hooks/storage/badger/badger.go +++ b/mqtt/hooks/storage/badger/badger.go @@ -1,12 +1,13 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co -// SPDX-FileContributor: mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co, gsagula package badger import ( "bytes" "errors" + "fmt" "strings" "github.com/wind-c/comqtt/v2/mqtt" @@ -127,8 +128,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) { h.updateClient(cl) } -// OnWillSent is called when a client sends a will message and the will message is removed -// from the client record. +// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record. func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { h.updateClient(cl) } @@ -136,7 +136,7 @@ func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { // updateClient writes the client data to the store. func (h *Hook) updateClient(cl *mqtt.Client) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -165,14 +165,14 @@ func (h *Hook) updateClient(cl *mqtt.Client) { err := h.db.Upsert(in.ID, in) if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert client data") + h.Log.Error("failed to upsert client data", "error", err, "data", in) } } // OnDisconnect removes a client from the store if their session has expired. func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -188,14 +188,14 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { err := h.db.Delete(clientKey(cl), new(storage.Client)) if err != nil { - h.Log.Error().Err(err).Interface("data", clientKey(cl)).Msg("failed to delete client data") + h.Log.Error("failed to delete client data", "error", err, "data", clientKey(cl)) } } // OnSubscribed adds one or more client subscriptions to the store. func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte, counts []int) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -217,7 +217,7 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by err := h.db.Upsert(in.ID, in) if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert subscription data") + h.Log.Error("failed to upsert subscription data", "error", err, "data", in) } } } @@ -225,14 +225,14 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by // OnUnsubscribed removes one or more client subscriptions from the store. func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte, counts []int) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } for i := 0; i < len(pk.Filters); i++ { err := h.db.Delete(subscriptionKey(cl, pk.Filters[i].Filter), new(storage.Subscription)) if err != nil { - h.Log.Error().Err(err).Interface("data", subscriptionKey(cl, pk.Filters[i].Filter)).Msg("failed to delete subscription data") + h.Log.Error("failed to delete subscription data", "error", err, "data", subscriptionKey(cl, pk.Filters[i].Filter)) } } } @@ -240,14 +240,14 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes [] // OnRetainMessage adds a retained message for a topic to the store. func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } if r == -1 { err := h.db.Delete(retainedKey(pk.TopicName), new(storage.Message)) if err != nil { - h.Log.Error().Err(err).Interface("data", retainedKey(pk.TopicName)).Msg("failed to delete retained message data") + h.Log.Error("failed to delete retained message data", "error", err, "data", retainedKey(pk.TopicName)) } return @@ -276,14 +276,14 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { err := h.db.Upsert(in.ID, in) if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert retained message data") + h.Log.Error("failed to upsert retained message data", "error", err, "data", in) } } // OnQosPublish adds or updates an inflight message in the store. func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -312,27 +312,27 @@ func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, rese err := h.db.Upsert(in.ID, in) if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert qos inflight data") + h.Log.Error("failed to upsert qos inflight data", "error", err, "data", in) } } // OnQosComplete removes a resolved inflight message from the store. func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } err := h.db.Delete(inflightKey(cl, pk), new(storage.Message)) if err != nil { - h.Log.Error().Err(err).Interface("data", inflightKey(cl, pk)).Msg("failed to delete inflight message data") + h.Log.Error("failed to delete inflight message data", "error", err, "data", inflightKey(cl, pk)) } } // OnQosDropped removes a dropped inflight message from the store. func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) } h.OnQosComplete(cl, pk) @@ -341,7 +341,7 @@ func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { // OnSysInfoTick stores the latest system info in the store. func (h *Hook) OnSysInfoTick(sys *system.Info) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -353,40 +353,40 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) { err := h.db.Upsert(in.ID, in) if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert $SYS data") + h.Log.Error("failed to upsert $SYS data", "error", err, "data", in) } } // OnRetainedExpired deletes expired retained messages from the store. func (h *Hook) OnRetainedExpired(filter string) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } err := h.db.Delete(retainedKey(filter), new(storage.Message)) if err != nil { - h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data") + h.Log.Error("failed to delete expired retained message data", "error", err, "id", retainedKey(filter)) } } // OnClientExpired deleted expired clients from the store. func (h *Hook) OnClientExpired(cl *mqtt.Client) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } err := h.db.Delete(clientKey(cl), new(storage.Client)) if err != nil { - h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data") + h.Log.Error("failed to delete expired client data", "error", err, "id", clientKey(cl)) } } // StoredClients returns all stored clients from the store. func (h *Hook) StoredClients() (v []storage.Client, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -401,7 +401,7 @@ func (h *Hook) StoredClients() (v []storage.Client, err error) { // StoredSubscriptions returns all stored subscriptions from the store. func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -416,7 +416,7 @@ func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { // StoredRetainedMessages returns all stored retained messages from the store. func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -431,7 +431,7 @@ func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { // StoredInflightMessages returns all stored inflight messages from the store. func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -446,7 +446,7 @@ func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { // StoredSysInfo returns the system info from the store. func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -460,20 +460,21 @@ func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { // Errorf satisfies the badger interface for an error logger. func (h *Hook) Errorf(m string, v ...interface{}) { - h.Log.Error().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...) + h.Log.Error(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v) + } // Warningf satisfies the badger interface for a warning logger. func (h *Hook) Warningf(m string, v ...interface{}) { - h.Log.Warn().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...) + h.Log.Warn(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v) } // Infof satisfies the badger interface for an info logger. func (h *Hook) Infof(m string, v ...interface{}) { - h.Log.Info().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...) + h.Log.Info(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v) } // Debugf satisfies the badger interface for a debug logger. func (h *Hook) Debugf(m string, v ...interface{}) { - h.Log.Debug().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...) + h.Log.Debug(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v) } diff --git a/mqtt/hooks/storage/badger/badger_test.go b/mqtt/hooks/storage/badger/badger_test.go index a657797..47a6fd1 100644 --- a/mqtt/hooks/storage/badger/badger_test.go +++ b/mqtt/hooks/storage/badger/badger_test.go @@ -1,16 +1,16 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package badger import ( + "log/slog" "os" "strings" "testing" "time" - "github.com/rs/zerolog" "github.com/stretchr/testify/require" "github.com/timshannon/badgerhold" "github.com/wind-c/comqtt/v2/mqtt" @@ -20,7 +20,7 @@ import ( ) var ( - logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) + logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) client = &mqtt.Client{ ID: "test", @@ -38,8 +38,8 @@ var ( ) func teardown(t *testing.T, path string, h *Hook) { - h.Stop() - h.db.Badger().Close() + _ = h.Stop() + _ = h.db.Badger().Close() err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1)) require.NoError(t, err) } @@ -95,7 +95,7 @@ func TestProvides(t *testing.T) { func TestInitBadConfig(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(map[string]any{}) require.Error(t, err) @@ -103,7 +103,7 @@ func TestInitBadConfig(t *testing.T) { func TestInitUseDefaults(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -113,7 +113,7 @@ func TestInitUseDefaults(t *testing.T) { func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -146,7 +146,7 @@ func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) { func TestOnClientExpired(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -170,13 +170,13 @@ func TestOnClientExpired(t *testing.T) { func TestOnClientExpiredNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnClientExpired(client) } func TestOnClientExpiredClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -185,13 +185,13 @@ func TestOnClientExpiredClosedDB(t *testing.T) { func TestOnSessionEstablishedNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnSessionEstablished(client, packets.Packet{}) } func TestOnSessionEstablishedClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -200,7 +200,7 @@ func TestOnSessionEstablishedClosedDB(t *testing.T) { func TestOnWillSent(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -219,13 +219,13 @@ func TestOnWillSent(t *testing.T) { func TestOnDisconnectNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnDisconnect(client, nil, false) } func TestOnDisconnectClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -234,7 +234,7 @@ func TestOnDisconnectClosedDB(t *testing.T) { func TestOnDisconnectSessionTakenOver(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) @@ -257,12 +257,12 @@ func TestOnDisconnectSessionTakenOver(t *testing.T) { func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) - h.OnSubscribed(client, pkf, []byte{0}, nil) + h.OnSubscribed(client, pkf, []byte{0}, []int{0}) r := new(storage.Subscription) err = h.db.Get(subscriptionKey(client, pkf.Filters[0].Filter), r) @@ -271,7 +271,7 @@ func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { require.Equal(t, pkf.Filters[0].Filter, r.Filter) require.Equal(t, byte(0), r.Qos) - h.OnUnsubscribed(client, pkf, nil, nil) + h.OnUnsubscribed(client, pkf, []byte{0}, []int{0}) err = h.db.Get(subscriptionKey(client, pkf.Filters[0].Filter), r) require.Error(t, err) require.Equal(t, badgerhold.ErrNotFound, err) @@ -279,37 +279,37 @@ func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { func TestOnSubscribedNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) - h.OnSubscribed(client, pkf, []byte{0}, nil) + h.SetOpts(logger, nil) + h.OnSubscribed(client, pkf, []byte{0}, []int{0}) } func TestOnSubscribedClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) - h.OnSubscribed(client, pkf, []byte{0}, nil) + h.OnSubscribed(client, pkf, []byte{0}, []int{0}) } func TestOnUnsubscribedNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) - h.OnUnsubscribed(client, pkf, nil, nil) + h.SetOpts(logger, nil) + h.OnUnsubscribed(client, pkf, []byte{0}, []int{0}) } func TestOnUnsubscribedClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) - h.OnUnsubscribed(client, pkf, nil, nil) + h.OnUnsubscribed(client, pkf, []byte{0}, []int{0}) } func TestOnRetainMessageThenUnset(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -344,7 +344,7 @@ func TestOnRetainMessageThenUnset(t *testing.T) { func TestOnRetainedExpired(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -371,13 +371,13 @@ func TestOnRetainedExpired(t *testing.T) { func TestOnRetainExpiredNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnRetainedExpired("a/b/c") } func TestOnRetainExpiredClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -386,13 +386,13 @@ func TestOnRetainExpiredClosedDB(t *testing.T) { func TestOnRetainMessageNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnRetainMessage(client, packets.Packet{}, 0) } func TestOnRetainMessageClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -401,7 +401,7 @@ func TestOnRetainMessageClosedDB(t *testing.T) { func TestOnQosPublishThenQOSComplete(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -436,13 +436,13 @@ func TestOnQosPublishThenQOSComplete(t *testing.T) { func TestOnQosPublishNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0) } func TestOnQosPublishClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -451,13 +451,13 @@ func TestOnQosPublishClosedDB(t *testing.T) { func TestOnQosCompleteNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnQosComplete(client, packets.Packet{}) } func TestOnQosCompleteClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -466,13 +466,13 @@ func TestOnQosCompleteClosedDB(t *testing.T) { func TestOnQosDroppedNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnQosDropped(client, packets.Packet{}) } func TestOnSysInfoTick(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -494,13 +494,13 @@ func TestOnSysInfoTick(t *testing.T) { func TestOnSysInfoTickNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnSysInfoTick(new(system.Info)) } func TestOnSysInfoTickClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -509,7 +509,7 @@ func TestOnSysInfoTickClosedDB(t *testing.T) { func TestStoredClients(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -534,7 +534,7 @@ func TestStoredClients(t *testing.T) { func TestStoredClientsNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredClients() require.Empty(t, v) require.NoError(t, err) @@ -542,7 +542,7 @@ func TestStoredClientsNoDB(t *testing.T) { func TestStoredSubscriptions(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -567,7 +567,7 @@ func TestStoredSubscriptions(t *testing.T) { func TestStoredSubscriptionsNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredSubscriptions() require.Empty(t, v) require.NoError(t, err) @@ -575,7 +575,7 @@ func TestStoredSubscriptionsNoDB(t *testing.T) { func TestStoredRetainedMessages(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -603,7 +603,7 @@ func TestStoredRetainedMessages(t *testing.T) { func TestStoredRetainedMessagesNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredRetainedMessages() require.Empty(t, v) require.NoError(t, err) @@ -611,7 +611,7 @@ func TestStoredRetainedMessagesNoDB(t *testing.T) { func TestStoredInflightMessages(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -639,7 +639,7 @@ func TestStoredInflightMessages(t *testing.T) { func TestStoredInflightMessagesNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredInflightMessages() require.Empty(t, v) require.NoError(t, err) @@ -647,7 +647,7 @@ func TestStoredInflightMessagesNoDB(t *testing.T) { func TestStoredSysInfo(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -669,7 +669,7 @@ func TestStoredSysInfo(t *testing.T) { func TestStoredSysInfoNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredSysInfo() require.Empty(t, v) require.NoError(t, err) @@ -678,27 +678,27 @@ func TestStoredSysInfoNoDB(t *testing.T) { func TestErrorf(t *testing.T) { // coverage: one day check log hook h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.Errorf("test", 1, 2, 3) } func TestWarningf(t *testing.T) { // coverage: one day check log hook h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.Warningf("test", 1, 2, 3) } func TestInfof(t *testing.T) { // coverage: one day check log hook h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.Infof("test", 1, 2, 3) } func TestDebugf(t *testing.T) { // coverage: one day check log hook h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.Debugf("test", 1, 2, 3) } diff --git a/mqtt/hooks/storage/bolt/bolt.go b/mqtt/hooks/storage/bolt/bolt.go index 1e42f3e..374e493 100644 --- a/mqtt/hooks/storage/bolt/bolt.go +++ b/mqtt/hooks/storage/bolt/bolt.go @@ -1,7 +1,8 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co -// package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead. + +// Package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead. package bolt import ( @@ -132,8 +133,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) { h.updateClient(cl) } -// OnWillSent is called when a client sends a will message and the will message is removed -// from the client record. +// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record. func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { h.updateClient(cl) } @@ -141,7 +141,7 @@ func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { // updateClient writes the client data to the store. func (h *Hook) updateClient(cl *mqtt.Client) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -169,14 +169,14 @@ func (h *Hook) updateClient(cl *mqtt.Client) { } err := h.db.Save(in) if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to save client data") + h.Log.Error("failed to save client data", "error", err, "data", in) } } // OnDisconnect removes a client from the store if they were using a clean session. func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -190,14 +190,14 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)}) if err != nil && !errors.Is(err, storm.ErrNotFound) { - h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client") + h.Log.Error("failed to delete client", "error", err, "id", clientKey(cl)) } } // OnSubscribed adds one or more client subscriptions to the store. func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte, counts []int) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -219,10 +219,7 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by err := h.db.Save(in) if err != nil { - h.Log.Error().Err(err). - Str("client", cl.ID). - Interface("data", in). - Msg("failed to save subscription data") + h.Log.Error("failed to save subscription data", "error", err, "client", cl.ID, "data", in) } } } @@ -230,7 +227,7 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by // OnUnsubscribed removes one or more client subscriptions from the store. func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte, counts []int) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -239,9 +236,7 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes [] ID: subscriptionKey(cl, pk.Filters[i].Filter), }) if err != nil { - h.Log.Error().Err(err). - Str("id", subscriptionKey(cl, pk.Filters[i].Filter)). - Msg("failed to delete client") + h.Log.Error("failed to delete client", "error", err, "id", subscriptionKey(cl, pk.Filters[i].Filter)) } } } @@ -249,7 +244,7 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes [] // OnRetainMessage adds a retained message for a topic to the store. func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -258,9 +253,7 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { ID: retainedKey(pk.TopicName), }) if err != nil { - h.Log.Error().Err(err). - Str("id", retainedKey(pk.TopicName)). - Msg("failed to delete retained publish") + h.Log.Error("failed to delete retained publish", "error", err, "id", retainedKey(pk.TopicName)) } return } @@ -287,17 +280,14 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { } err := h.db.Save(in) if err != nil { - h.Log.Error().Err(err). - Str("client", cl.ID). - Interface("data", in). - Msg("failed to save retained publish data") + h.Log.Error("failed to save retained publish data", "error", err, "client", cl.ID, "data", in) } } // OnQosPublish adds or updates an inflight message in the store. func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -325,17 +315,14 @@ func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, rese err := h.db.Save(in) if err != nil { - h.Log.Error().Err(err). - Str("client", cl.ID). - Interface("data", in). - Msg("failed to save qos inflight data") + h.Log.Error("failed to save qos inflight data", "error", err, "client", cl.ID, "data", in) } } // OnQosComplete removes a resolved inflight message from the store. func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -343,16 +330,14 @@ func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { ID: inflightKey(cl, pk), }) if err != nil { - h.Log.Error().Err(err). - Str("id", inflightKey(cl, pk)). - Msg("failed to delete inflight data") + h.Log.Error("failed to delete inflight data", "error", err, "id", inflightKey(cl, pk)) } } // OnQosDropped removes a dropped inflight message from the store. func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) } h.OnQosComplete(cl, pk) @@ -361,7 +346,7 @@ func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { // OnSysInfoTick stores the latest system info in the store. func (h *Hook) OnSysInfoTick(sys *system.Info) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -373,41 +358,39 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) { err := h.db.Save(in) if err != nil { - h.Log.Error().Err(err). - Interface("data", in). - Msg("failed to save $SYS data") + h.Log.Error("failed to save $SYS data", "error", err, "data", in) } } // OnRetainedExpired deletes expired retained messages from the store. func (h *Hook) OnRetainedExpired(filter string) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil { - h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish") + h.Log.Error("failed to delete retained publish", "error", err, "id", retainedKey(filter)) } } // OnClientExpired deleted expired clients from the store. func (h *Hook) OnClientExpired(cl *mqtt.Client) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)}) if err != nil && !errors.Is(err, storm.ErrNotFound) { - h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client") + h.Log.Error("failed to delete expired client", "error", err, "id", clientKey(cl)) } } // StoredClients returns all stored clients from the store. func (h *Hook) StoredClients() (v []storage.Client, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -422,7 +405,7 @@ func (h *Hook) StoredClients() (v []storage.Client, err error) { // StoredSubscriptions returns all stored subscriptions from the store. func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -437,7 +420,7 @@ func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { // StoredRetainedMessages returns all stored retained messages from the store. func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -452,7 +435,7 @@ func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { // StoredInflightMessages returns all stored inflight messages from the store. func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -467,7 +450,7 @@ func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { // StoredSysInfo returns the system info from the store. func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } diff --git a/mqtt/hooks/storage/bolt/bolt_test.go b/mqtt/hooks/storage/bolt/bolt_test.go index f9175a4..4d7a8f7 100644 --- a/mqtt/hooks/storage/bolt/bolt_test.go +++ b/mqtt/hooks/storage/bolt/bolt_test.go @@ -1,10 +1,11 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package bolt import ( + "log/slog" "os" "testing" "time" @@ -15,12 +16,11 @@ import ( "github.com/wind-c/comqtt/v2/mqtt/system" "github.com/asdine/storm/v3" - "github.com/rs/zerolog" "github.com/stretchr/testify/require" ) var ( - logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) + logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) client = &mqtt.Client{ ID: "test", @@ -38,8 +38,8 @@ var ( ) func teardown(t *testing.T, path string, h *Hook) { - h.Stop() - err := os.RemoveAll(path) + _ = h.Stop() + err := os.Remove(path) require.NoError(t, err) } @@ -94,7 +94,7 @@ func TestProvides(t *testing.T) { func TestInitBadConfig(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(map[string]any{}) require.Error(t, err) @@ -102,7 +102,7 @@ func TestInitBadConfig(t *testing.T) { func TestInitUseDefaults(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -113,7 +113,7 @@ func TestInitUseDefaults(t *testing.T) { func TestInitBadPath(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(&Options{ Path: "..", }) @@ -122,7 +122,7 @@ func TestInitBadPath(t *testing.T) { func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -155,13 +155,13 @@ func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) { func TestOnSessionEstablishedNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnSessionEstablished(client, packets.Packet{}) } func TestOnSessionEstablishedClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -170,7 +170,7 @@ func TestOnSessionEstablishedClosedDB(t *testing.T) { func TestOnWillSent(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -189,7 +189,7 @@ func TestOnWillSent(t *testing.T) { func TestOnClientExpired(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -213,7 +213,7 @@ func TestOnClientExpired(t *testing.T) { func TestOnClientExpiredClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -222,19 +222,19 @@ func TestOnClientExpiredClosedDB(t *testing.T) { func TestOnClientExpiredNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnClientExpired(client) } func TestOnDisconnectNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnDisconnect(client, nil, false) } func TestOnDisconnectClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -243,7 +243,7 @@ func TestOnDisconnectClosedDB(t *testing.T) { func TestOnDisconnectSessionTakenOver(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) @@ -266,7 +266,7 @@ func TestOnDisconnectSessionTakenOver(t *testing.T) { func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -288,13 +288,13 @@ func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { func TestOnSubscribedNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnSubscribed(client, pkf, []byte{0}, nil) } func TestOnSubscribedClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -303,13 +303,13 @@ func TestOnSubscribedClosedDB(t *testing.T) { func TestOnUnsubscribedNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnUnsubscribed(client, pkf, nil, nil) } func TestOnUnsubscribedClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -318,7 +318,7 @@ func TestOnUnsubscribedClosedDB(t *testing.T) { func TestOnRetainMessageThenUnset(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -353,7 +353,7 @@ func TestOnRetainMessageThenUnset(t *testing.T) { func TestOnRetainedExpired(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -380,7 +380,7 @@ func TestOnRetainedExpired(t *testing.T) { func TestOnRetainedExpiredClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -389,19 +389,19 @@ func TestOnRetainedExpiredClosedDB(t *testing.T) { func TestOnRetainedExpiredNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnRetainedExpired("a/b/c") } func TestOnRetainMessageNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnRetainMessage(client, packets.Packet{}, 0) } func TestOnRetainMessageClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -410,7 +410,7 @@ func TestOnRetainMessageClosedDB(t *testing.T) { func TestOnQosPublishThenQOSComplete(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -445,13 +445,13 @@ func TestOnQosPublishThenQOSComplete(t *testing.T) { func TestOnQosPublishNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0) } func TestOnQosPublishClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -460,13 +460,13 @@ func TestOnQosPublishClosedDB(t *testing.T) { func TestOnQosCompleteNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnQosComplete(client, packets.Packet{}) } func TestOnQosCompleteClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -475,13 +475,13 @@ func TestOnQosCompleteClosedDB(t *testing.T) { func TestOnQosDroppedNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnQosDropped(client, packets.Packet{}) } func TestOnSysInfoTick(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -503,13 +503,13 @@ func TestOnSysInfoTick(t *testing.T) { func TestOnSysInfoTickNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) h.OnSysInfoTick(new(system.Info)) } func TestOnSysInfoTickClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -518,7 +518,7 @@ func TestOnSysInfoTickClosedDB(t *testing.T) { func TestStoredClients(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -543,7 +543,7 @@ func TestStoredClients(t *testing.T) { func TestStoredClientsNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredClients() require.Empty(t, v) require.NoError(t, err) @@ -551,7 +551,7 @@ func TestStoredClientsNoDB(t *testing.T) { func TestStoredClientsClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -562,7 +562,7 @@ func TestStoredClientsClosedDB(t *testing.T) { func TestStoredSubscriptions(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -587,7 +587,7 @@ func TestStoredSubscriptions(t *testing.T) { func TestStoredSubscriptionsNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredSubscriptions() require.Empty(t, v) require.NoError(t, err) @@ -595,7 +595,7 @@ func TestStoredSubscriptionsNoDB(t *testing.T) { func TestStoredSubscriptionsClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -606,7 +606,7 @@ func TestStoredSubscriptionsClosedDB(t *testing.T) { func TestStoredRetainedMessages(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -634,7 +634,7 @@ func TestStoredRetainedMessages(t *testing.T) { func TestStoredRetainedMessagesNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredRetainedMessages() require.Empty(t, v) require.NoError(t, err) @@ -642,7 +642,7 @@ func TestStoredRetainedMessagesNoDB(t *testing.T) { func TestStoredRetainedMessagesClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -653,7 +653,7 @@ func TestStoredRetainedMessagesClosedDB(t *testing.T) { func TestStoredInflightMessages(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -681,7 +681,7 @@ func TestStoredInflightMessages(t *testing.T) { func TestStoredInflightMessagesNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredInflightMessages() require.Empty(t, v) require.NoError(t, err) @@ -689,7 +689,7 @@ func TestStoredInflightMessagesNoDB(t *testing.T) { func TestStoredInflightMessagesClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) @@ -700,7 +700,7 @@ func TestStoredInflightMessagesClosedDB(t *testing.T) { func TestStoredSysInfo(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h.config.Path, h) @@ -722,7 +722,7 @@ func TestStoredSysInfo(t *testing.T) { func TestStoredSysInfoNoDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) v, err := h.StoredSysInfo() require.Empty(t, v) require.NoError(t, err) @@ -730,7 +730,7 @@ func TestStoredSysInfoNoDB(t *testing.T) { func TestStoredSysInfoClosedDB(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) teardown(t, h.config.Path, h) diff --git a/mqtt/hooks/storage/redis/redis.go b/mqtt/hooks/storage/redis/redis.go index 49241a1..4ecb025 100644 --- a/mqtt/hooks/storage/redis/redis.go +++ b/mqtt/hooks/storage/redis/redis.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package redis @@ -103,7 +103,7 @@ func (h *Hook) Init(config any) error { } h.ctx = context.Background() - h.ctx.Deadline() + if config == nil { config = &Options{ Options: &redis.Options{ @@ -117,12 +117,11 @@ func (h *Hook) Init(config any) error { h.config.HPrefix = defaultHPrefix } - h.Log.Info(). - Str("address", h.config.Options.Addr). - Str("username", h.config.Options.Username). - Int("password-len", len(h.config.Options.Password)). - Int("db", h.config.Options.DB). - Msg("connecting to redis service") + h.Log.Info("connecting to redis service", + "address", h.config.Options.Addr, + "username", h.config.Options.Username, + "password-len", len(h.config.Options.Password), + "db", h.config.Options.DB) h.db = redis.NewClient(h.config.Options) _, err := h.db.Ping(context.Background()).Result() @@ -130,14 +129,15 @@ func (h *Hook) Init(config any) error { return fmt.Errorf("failed to ping service: %w", err) } - h.Log.Info().Msg("connected to redis service") + h.Log.Info("connected to redis service") return nil } // Stop closes the redis connection. func (h *Hook) Stop() error { - h.Log.Info().Msg("disconnecting from redis service") + h.Log.Info("disconnecting from redis service") + return h.db.Close() } @@ -146,8 +146,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) { h.updateClient(cl) } -// OnWillSent is called when a client sends a will message and the will message is removed -// from the client record. +// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record. func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { h.updateClient(cl) } @@ -155,7 +154,7 @@ func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { // updateClient writes the client data to the store. func (h *Hook) updateClient(cl *mqtt.Client) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -184,14 +183,14 @@ func (h *Hook) updateClient(cl *mqtt.Client) { err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey(cl), in).Err() if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset client data") + h.Log.Error("failed to hset client data", "error", err, "data", in) } } // OnDisconnect removes a client from the store if they were using a clean session. func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -205,14 +204,14 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err() if err != nil { - h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client") + h.Log.Error("failed to delete client", "error", err, "id", clientKey(cl)) } } // OnSubscribed adds one or more client subscriptions to the store. func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte, counts []int) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -234,7 +233,7 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err() if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset subscription data") + h.Log.Error("failed to hset subscription data", "error", err, "data", in) } } } @@ -242,14 +241,14 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by // OnUnsubscribed removes one or more client subscriptions from the store. func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte, counts []int) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } for i := 0; i < len(pk.Filters); i++ { err := h.db.HDel(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter)).Err() if err != nil { - h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete subscription data") + h.Log.Error("failed to delete subscription data", "error", err, "id", clientKey(cl)) } } } @@ -257,14 +256,14 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes [] // OnRetainMessage adds a retained message for a topic to the store. func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } if r == -1 { err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Err() if err != nil { - h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete retained message data") + h.Log.Error("failed to delete retained message data", "error", err, "id", retainedKey(pk.TopicName)) } return @@ -293,14 +292,14 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName), in).Err() if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset retained message data") + h.Log.Error("failed to hset retained message data", "error", err, "data", in) } } // OnQosPublish adds or updates an inflight message in the store. func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -328,27 +327,27 @@ func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, rese err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk), in).Err() if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset qos inflight message data") + h.Log.Error("failed to hset qos inflight message data", "error", err, "data", in) } } // OnQosComplete removes a resolved inflight message from the store. func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk)).Err() if err != nil { - h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data") + h.Log.Error("failed to delete qos inflight message data", "error", err, "id", inflightKey(cl, pk)) } } // OnQosDropped removes a dropped inflight message from the store. func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) } h.OnQosComplete(cl, pk) @@ -357,7 +356,7 @@ func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { // OnSysInfoTick stores the latest system info in the store. func (h *Hook) OnSysInfoTick(sys *system.Info) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -369,53 +368,53 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) { err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), sysInfoKey(), in).Err() if err != nil { - h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset server info data") + h.Log.Error("failed to hset server info data", "error", err, "data", in) } } // OnRetainedExpired deletes expired retained messages from the store. func (h *Hook) OnRetainedExpired(filter string) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err() if err != nil { - h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data") + h.Log.Error("failed to delete expired retained message", "error", err, "id", retainedKey(filter)) } } // OnClientExpired deleted expired clients from the store. func (h *Hook) OnClientExpired(cl *mqtt.Client) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err() if err != nil { - h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client") + h.Log.Error("failed to delete expired client", "error", err, "id", clientKey(cl)) } } // StoredClients returns all stored clients from the store. func (h *Hook) StoredClients() (v []storage.Client, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.ClientKey)).Result() if err != nil && !errors.Is(err, redis.Nil) { - h.Log.Error().Err(err).Msg("failed to HGetAll client data") + h.Log.Error("failed to HGetAll client data", "error", err) return } for _, row := range rows { var d storage.Client if err = d.UnmarshalBinary([]byte(row)); err != nil { - h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal client data") + h.Log.Error("failed to unmarshal client data", "error", err, "data", row) } v = append(v, d) @@ -427,20 +426,20 @@ func (h *Hook) StoredClients() (v []storage.Client, err error) { // StoredSubscriptions returns all stored subscriptions from the store. func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.SubscriptionKey)).Result() if err != nil && !errors.Is(err, redis.Nil) { - h.Log.Error().Err(err).Msg("failed to HGetAll subscription data") + h.Log.Error("failed to HGetAll subscription data", "error", err) return } for _, row := range rows { var d storage.Subscription if err = d.UnmarshalBinary([]byte(row)); err != nil { - h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal subscription data") + h.Log.Error("failed to unmarshal subscription data", "error", err, "data", row) } v = append(v, d) @@ -452,20 +451,20 @@ func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { // StoredRetainedMessages returns all stored retained messages from the store. func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.RetainedKey)).Result() if err != nil && !errors.Is(err, redis.Nil) { - h.Log.Error().Err(err).Msg("failed to HGetAll retained message data") + h.Log.Error("failed to HGetAll retained message data", "error", err) return } for _, row := range rows { var d storage.Message if err = d.UnmarshalBinary([]byte(row)); err != nil { - h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal retained message data") + h.Log.Error("failed to unmarshal retained message data", "error", err, "data", row) } v = append(v, d) @@ -477,20 +476,20 @@ func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { // StoredInflightMessages returns all stored inflight messages from the store. func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result() if err != nil && !errors.Is(err, redis.Nil) { - h.Log.Error().Err(err).Msg("failed to HGetAll inflight message data") + h.Log.Error("failed to HGetAll inflight message data", "error", err) return } for _, row := range rows { var d storage.Message if err = d.UnmarshalBinary([]byte(row)); err != nil { - h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data") + h.Log.Error("failed to unmarshal inflight message data", "error", err, "data", row) } v = append(v, d) @@ -502,7 +501,7 @@ func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { // StoredSysInfo returns the system info from the store. func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { if h.db == nil { - h.Log.Error().Err(storage.ErrDBFileNotOpen) + h.Log.Error("", "error", storage.ErrDBFileNotOpen) return } @@ -512,7 +511,7 @@ func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { } if err = v.UnmarshalBinary([]byte(row)); err != nil { - h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal sys info data") + h.Log.Error("failed to unmarshal sys info data", "error", err, "data", row) } return v, nil diff --git a/mqtt/hooks/storage/redis/redis_test.go b/mqtt/hooks/storage/redis/redis_test.go index 0226806..9a41af4 100644 --- a/mqtt/hooks/storage/redis/redis_test.go +++ b/mqtt/hooks/storage/redis/redis_test.go @@ -5,6 +5,7 @@ package redis import ( + "log/slog" "os" "sort" "testing" @@ -17,12 +18,11 @@ import ( miniredis "github.com/alicebob/miniredis/v2" redis "github.com/go-redis/redis/v8" - "github.com/rs/zerolog" "github.com/stretchr/testify/require" ) var ( - logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) + logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) client = &mqtt.Client{ ID: "test", @@ -41,7 +41,7 @@ var ( func newHook(t *testing.T, addr string) *Hook { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(&Options{ Options: &redis.Options{ @@ -87,13 +87,13 @@ func TestSysInfoKey(t *testing.T) { func TestID(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) require.Equal(t, "redis-db", h.ID()) } func TestProvides(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) require.True(t, h.Provides(mqtt.OnSessionEstablished)) require.True(t, h.Provides(mqtt.OnDisconnect)) require.True(t, h.Provides(mqtt.OnSubscribed)) @@ -116,7 +116,7 @@ func TestHKey(t *testing.T) { s := miniredis.RunT(t) defer s.Close() h := newHook(t, s.Addr()) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) require.Equal(t, defaultHPrefix+"test", h.hKey("test")) } @@ -126,7 +126,7 @@ func TestInitUseDefaults(t *testing.T) { defer s.Close() h := newHook(t, defaultAddr) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(nil) require.NoError(t, err) defer teardown(t, h) @@ -137,7 +137,7 @@ func TestInitUseDefaults(t *testing.T) { func TestInitBadConfig(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(map[string]any{}) require.Error(t, err) @@ -145,7 +145,7 @@ func TestInitBadConfig(t *testing.T) { func TestInitBadAddr(t *testing.T) { h := new(Hook) - h.SetOpts(&logger, nil) + h.SetOpts(logger, nil) err := h.Init(&Options{ Options: &redis.Options{ Addr: "abc:123", diff --git a/mqtt/hooks/storage/storage.go b/mqtt/hooks/storage/storage.go index 8cabf01..b7f6ce4 100644 --- a/mqtt/hooks/storage/storage.go +++ b/mqtt/hooks/storage/storage.go @@ -25,7 +25,7 @@ var ( ErrDBFileNotOpen = errors.New("db file not open") ) -// Client is a storable representation of an mqtt client. +// Client is a storable representation of an MQTT client. type Client struct { Will ClientWill `json:"will,omitempty"` // will topic and payload data if applicable Properties ClientProperties `json:"properties,omitempty"` // the connect properties for the client @@ -55,9 +55,9 @@ type ClientProperties struct { // ClientWill contains a will message for a client, and limited mqtt v5 properties. type ClientWill struct { - TopicName string `json:"topicName,omitempty"` Payload []byte `json:"payload,omitempty"` User []packets.UserProperty `json:"user,omitempty"` + TopicName string `json:"topicName,omitempty"` Flag uint32 `json:"flag,omitempty"` WillDelayInterval uint32 `json:"willDelayInterval,omitempty"` Qos byte `json:"qos,omitempty"` @@ -147,7 +147,7 @@ func (d *Message) ToPacket() packets.Packet { return pk } -// Subscription is a storable representation of an mqtt subscription. +// Subscription is a storable representation of an MQTT subscription. type Subscription struct { T string `json:"t,omitempty"` ID string `json:"id,omitempty" storm:"id"` diff --git a/mqtt/hooks/storage/storage_test.go b/mqtt/hooks/storage/storage_test.go index daee469..1cdc346 100644 --- a/mqtt/hooks/storage/storage_test.go +++ b/mqtt/hooks/storage/storage_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package storage diff --git a/mqtt/hooks_test.go b/mqtt/hooks_test.go index 35b1fde..94ed049 100644 --- a/mqtt/hooks_test.go +++ b/mqtt/hooks_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package mqtt @@ -215,7 +215,7 @@ func TestHooksAddInitFailure(t *testing.T) { func TestHooksStop(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger err := h.Add(new(HookBase), nil) require.NoError(t, err) @@ -334,7 +334,7 @@ func TestHooksOnUnsubscribe(t *testing.T) { func TestHooksOnPublish(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger hook := new(modifiedHookBase) err := h.Add(hook, nil) @@ -360,7 +360,7 @@ func TestHooksOnPublish(t *testing.T) { func TestHooksOnPacketRead(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger hook := new(modifiedHookBase) err := h.Add(hook, nil) @@ -386,7 +386,7 @@ func TestHooksOnPacketRead(t *testing.T) { func TestHooksOnAuthPacket(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger hook := new(modifiedHookBase) err := h.Add(hook, nil) @@ -404,7 +404,7 @@ func TestHooksOnAuthPacket(t *testing.T) { func TestHooksOnConnect(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger hook := new(modifiedHookBase) err := h.Add(hook, nil) @@ -420,7 +420,7 @@ func TestHooksOnConnect(t *testing.T) { func TestHooksOnPacketEncode(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger hook := new(modifiedHookBase) err := h.Add(hook, nil) @@ -432,7 +432,7 @@ func TestHooksOnPacketEncode(t *testing.T) { func TestHooksOnLWT(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger hook := new(modifiedHookBase) err := h.Add(hook, nil) @@ -449,7 +449,7 @@ func TestHooksOnLWT(t *testing.T) { func TestHooksStoredClients(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger v, err := h.StoredClients() require.NoError(t, err) @@ -471,7 +471,7 @@ func TestHooksStoredClients(t *testing.T) { func TestHooksStoredSubscriptions(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger v, err := h.StoredSubscriptions() require.NoError(t, err) @@ -493,7 +493,7 @@ func TestHooksStoredSubscriptions(t *testing.T) { func TestHooksStoredRetainedMessages(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger v, err := h.StoredRetainedMessages() require.NoError(t, err) @@ -515,7 +515,7 @@ func TestHooksStoredRetainedMessages(t *testing.T) { func TestHooksStoredInflightMessages(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger v, err := h.StoredInflightMessages() require.NoError(t, err) @@ -537,7 +537,7 @@ func TestHooksStoredInflightMessages(t *testing.T) { func TestHooksStoredSysInfo(t *testing.T) { h := new(Hooks) - h.Log = &logger + h.Log = logger v, err := h.StoredSysInfo() require.NoError(t, err) @@ -575,7 +575,7 @@ func TestHookBaseInit(t *testing.T) { func TestHookBaseSetOpts(t *testing.T) { h := new(HookBase) - h.SetOpts(&logger, new(HookOptions)) + h.SetOpts(logger, new(HookOptions)) require.NotNil(t, h.Log) require.NotNil(t, h.Opts) } diff --git a/mqtt/inflight.go b/mqtt/inflight.go index 51e77fc..b94b2fd 100644 --- a/mqtt/inflight.go +++ b/mqtt/inflight.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package mqtt @@ -14,12 +14,12 @@ import ( // Inflight is a map of InflightMessage keyed on packet id. type Inflight struct { - internal map[uint16]packets.Packet // internal contains the inflight packets sync.RWMutex - receiveQuota int32 // remaining inbound qos quota for flow control - sendQuota int32 // remaining outbound qos quota for flow control - maximumReceiveQuota int32 // maximum allowed receive quota - maximumSendQuota int32 // maximum allowed send quota + internal map[uint16]packets.Packet // internal contains the inflight packets + receiveQuota int32 // remaining inbound qos quota for flow control + sendQuota int32 // remaining outbound qos quota for flow control + maximumReceiveQuota int32 // maximum allowed receive quota + maximumSendQuota int32 // maximum allowed send quota } // NewInflights returns a new instance of an Inflight packets map. diff --git a/mqtt/inflight_test.go b/mqtt/inflight_test.go index 9b55d12..8de6e65 100644 --- a/mqtt/inflight_test.go +++ b/mqtt/inflight_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package mqtt diff --git a/mqtt/listeners/http_healthcheck.go b/mqtt/listeners/http_healthcheck.go new file mode 100644 index 0000000..a82e2e3 --- /dev/null +++ b/mqtt/listeners/http_healthcheck.go @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co +// SPDX-FileContributor: Derek Duncan + +package listeners + +import ( + "context" + "log/slog" + "net/http" + "sync" + "sync/atomic" + "time" +) + +// HTTPHealthCheck is a listener for providing an HTTP healthcheck endpoint. +type HTTPHealthCheck struct { + sync.RWMutex + id string // the internal id of the listener + address string // the network address to bind to + config *Config // configuration values for the listener + listen *http.Server // the http server + end uint32 // ensure the close methods are only called once +} + +// NewHTTPHealthCheck initialises and returns a new HTTP listener, listening on an address. +func NewHTTPHealthCheck(id, address string, config *Config) *HTTPHealthCheck { + if config == nil { + config = new(Config) + } + return &HTTPHealthCheck{ + id: id, + address: address, + config: config, + } +} + +// ID returns the id of the listener. +func (l *HTTPHealthCheck) ID() string { + return l.id +} + +// Address returns the address of the listener. +func (l *HTTPHealthCheck) Address() string { + return l.address +} + +// Protocol returns the address of the listener. +func (l *HTTPHealthCheck) Protocol() string { + if l.listen != nil && l.listen.TLSConfig != nil { + return "https" + } + + return "http" +} + +// Init initializes the listener. +func (l *HTTPHealthCheck) Init(_ *slog.Logger) error { + mux := http.NewServeMux() + mux.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + } + }) + l.listen = &http.Server{ + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + Addr: l.address, + Handler: mux, + } + + if l.config.TLSConfig != nil { + l.listen.TLSConfig = l.config.TLSConfig + } + + return nil +} + +// Serve starts listening for new connections and serving responses. +func (l *HTTPHealthCheck) Serve(establish EstablishFn) { + if l.listen.TLSConfig != nil { + _ = l.listen.ListenAndServeTLS("", "") + } else { + _ = l.listen.ListenAndServe() + } +} + +// Close closes the listener and any client connections. +func (l *HTTPHealthCheck) Close(closeClients CloseFn) { + l.Lock() + defer l.Unlock() + + if atomic.CompareAndSwapUint32(&l.end, 0, 1) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = l.listen.Shutdown(ctx) + } + + closeClients(l.id) +} diff --git a/mqtt/listeners/http_healthcheck_test.go b/mqtt/listeners/http_healthcheck_test.go new file mode 100644 index 0000000..1c753c1 --- /dev/null +++ b/mqtt/listeners/http_healthcheck_test.go @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co +// SPDX-FileContributor: Derek Duncan + +package listeners + +import ( + "io" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNewHTTPHealthCheck(t *testing.T) { + l := NewHTTPHealthCheck("healthcheck", testAddr, nil) + require.Equal(t, "healthcheck", l.id) + require.Equal(t, testAddr, l.address) +} + +func TestHTTPHealthCheckID(t *testing.T) { + l := NewHTTPHealthCheck("healthcheck", testAddr, nil) + require.Equal(t, "healthcheck", l.ID()) +} + +func TestHTTPHealthCheckAddress(t *testing.T) { + l := NewHTTPHealthCheck("healthcheck", testAddr, nil) + require.Equal(t, testAddr, l.Address()) +} + +func TestHTTPHealthCheckProtocol(t *testing.T) { + l := NewHTTPHealthCheck("healthcheck", testAddr, nil) + require.Equal(t, "http", l.Protocol()) +} + +func TestHTTPHealthCheckTLSProtocol(t *testing.T) { + l := NewHTTPHealthCheck("healthcheck", testAddr, &Config{ + TLSConfig: tlsConfigBasic, + }) + + _ = l.Init(logger) + require.Equal(t, "https", l.Protocol()) +} + +func TestHTTPHealthCheckInit(t *testing.T) { + l := NewHTTPHealthCheck("healthcheck", testAddr, nil) + err := l.Init(logger) + require.NoError(t, err) + + require.NotNil(t, l.listen) + require.Equal(t, testAddr, l.listen.Addr) +} + +func TestHTTPHealthCheckServeAndClose(t *testing.T) { + // setup http stats listener + l := NewHTTPHealthCheck("healthcheck", testAddr, nil) + err := l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + + // call healthcheck + resp, err := http.Get("http://localhost" + testAddr + "/healthcheck") + require.NoError(t, err) + require.NotNil(t, resp) + + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + // ensure listening is closed + var closed bool + l.Close(func(id string) { + closed = true + }) + + require.Equal(t, true, closed) + + _, err = http.Get("http://localhost/healthcheck" + testAddr + "/healthcheck") + require.Error(t, err) + <-o +} + +func TestHTTPHealthCheckServeAndCloseMethodNotAllowed(t *testing.T) { + // setup http stats listener + l := NewHTTPHealthCheck("healthcheck", testAddr, nil) + err := l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + + // make disallowed method type http request + resp, err := http.Post("http://localhost"+testAddr+"/healthcheck", "application/json", http.NoBody) + require.NoError(t, err) + require.NotNil(t, resp) + + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + // ensure listening is closed + var closed bool + l.Close(func(id string) { + closed = true + }) + + require.Equal(t, true, closed) + + _, err = http.Post("http://localhost/healthcheck"+testAddr+"/healthcheck", "application/json", http.NoBody) + require.Error(t, err) + <-o +} + +func TestHTTPHealthCheckServeTLSAndClose(t *testing.T) { + l := NewHTTPHealthCheck("healthcheck", testAddr, &Config{ + TLSConfig: tlsConfigBasic, + }) + + err := l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + l.Close(MockCloser) +} diff --git a/mqtt/listeners/http_sysinfo.go b/mqtt/listeners/http_sysinfo.go index 0454f32..63d7c28 100644 --- a/mqtt/listeners/http_sysinfo.go +++ b/mqtt/listeners/http_sysinfo.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -8,26 +8,24 @@ import ( "context" "encoding/json" "io" + "log/slog" "net/http" "sync" "sync/atomic" "time" "github.com/wind-c/comqtt/v2/mqtt/system" - - "github.com/rs/zerolog" ) // HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint. type HTTPStats struct { sync.RWMutex - id string // the internal id of the listener - address string // the network address to bind to - config *Config // configuration values for the listener - listen *http.Server // the http server - log *zerolog.Logger // server logger - sysInfo *system.Info // pointers to the server data - end uint32 // ensure the close methods are only called once + id string // the internal id of the listener + address string // the network address to bind to + config *Config // configuration values for the listener + listen *http.Server // the http server + sysInfo *system.Info // pointers to the server data + end uint32 // ensure the close methods are only called once handlers Handlers } @@ -48,7 +46,15 @@ func NewHTTP(id, address string, config *Config, sysInfo *system.Info, handlers // NewHTTPStats initialises and returns a new HTTP listener, listening on an address. func NewHTTPStats(id, address string, config *Config, sysInfo *system.Info) *HTTPStats { - return NewHTTP(id, address, config, sysInfo, nil) + if config == nil { + config = new(Config) + } + return &HTTPStats{ + id: id, + address: address, + sysInfo: sysInfo, + config: config, + } } // ID returns the id of the listener. @@ -71,16 +77,9 @@ func (l *HTTPStats) Protocol() string { } // Init initializes the listener. -func (l *HTTPStats) Init(log *zerolog.Logger) error { - l.log = log - +func (l *HTTPStats) Init(_ *slog.Logger) error { mux := http.NewServeMux() - mux.HandleFunc("/mqtt/stats", l.jsonHandler) - - for path, handler := range l.handlers { - mux.HandleFunc(path, handler) - } - + mux.HandleFunc("/", l.jsonHandler) l.listen = &http.Server{ ReadTimeout: 5 * time.Second, WriteTimeout: 5 * time.Second, @@ -98,9 +97,9 @@ func (l *HTTPStats) Init(log *zerolog.Logger) error { // Serve starts listening for new connections and serving responses. func (l *HTTPStats) Serve(establish EstablishFn) { if l.listen.TLSConfig != nil { - l.listen.ListenAndServeTLS("", "") + _ = l.listen.ListenAndServeTLS("", "") } else { - l.listen.ListenAndServe() + _ = l.listen.ListenAndServe() } } @@ -112,7 +111,7 @@ func (l *HTTPStats) Close(closeClients CloseFn) { if atomic.CompareAndSwapUint32(&l.end, 0, 1) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - l.listen.Shutdown(ctx) + _ = l.listen.Shutdown(ctx) } closeClients(l.id) @@ -124,8 +123,8 @@ func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) { out, err := json.MarshalIndent(info, "", "\t") if err != nil { - io.WriteString(w, err.Error()) + _, _ = io.WriteString(w, err.Error()) } - w.Write(out) + _, _ = w.Write(out) } diff --git a/mqtt/listeners/http_sysinfo_test.go b/mqtt/listeners/http_sysinfo_test.go index 34f29fc..76dbe3c 100644 --- a/mqtt/listeners/http_sysinfo_test.go +++ b/mqtt/listeners/http_sysinfo_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -42,14 +42,14 @@ func TestHTTPStatsTLSProtocol(t *testing.T) { TLSConfig: tlsConfigBasic, }, nil) - l.Init(nil) + _ = l.Init(logger) require.Equal(t, "https", l.Protocol()) } func TestHTTPStatsInit(t *testing.T) { sysInfo := new(system.Info) l := NewHTTPStats("t1", testAddr, nil, sysInfo) - err := l.Init(nil) + err := l.Init(logger) require.NoError(t, err) require.NotNil(t, l.sysInfo) @@ -65,7 +65,7 @@ func TestHTTPStatsServeAndClose(t *testing.T) { // setup http stats listener l := NewHTTPStats("t1", testAddr, nil, sysInfo) - err := l.Init(nil) + err := l.Init(logger) require.NoError(t, err) o := make(chan bool) @@ -113,7 +113,7 @@ func TestHTTPStatsServeTLSAndClose(t *testing.T) { TLSConfig: tlsConfigBasic, }, sysInfo) - err := l.Init(nil) + err := l.Init(logger) require.NoError(t, err) o := make(chan bool) diff --git a/mqtt/listeners/listeners.go b/mqtt/listeners/listeners.go index 0dd8f15..429f497 100644 --- a/mqtt/listeners/listeners.go +++ b/mqtt/listeners/listeners.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -9,7 +9,7 @@ import ( "net" "sync" - "github.com/rs/zerolog" + "log/slog" ) // Config contains configuration values for a listener. @@ -22,18 +22,18 @@ type Config struct { // EstablishFn is a callback function for establishing new clients. type EstablishFn func(id string, c net.Conn) error -// CloseFunc is a callback function for closing all listener clients. +// CloseFn is a callback function for closing all listener clients. type CloseFn func(id string) // Listener is an interface for network listeners. A network listener listens // for incoming client connections and adds them to the server. type Listener interface { - Init(*zerolog.Logger) error // open the network address - Serve(EstablishFn) // starting actively listening for new connections - ID() string // return the id of the listener - Address() string // the address of the listener - Protocol() string // the protocol in use by the listener - Close(CloseFn) // stop and close the listener + Init(*slog.Logger) error // open the network address + Serve(EstablishFn) // starting actively listening for new connections + ID() string // return the id of the listener + Address() string // the address of the listener + Protocol() string // the protocol in use by the listener + Close(CloseFn) // stop and close the listener } // Listeners contains the network listeners for the broker. diff --git a/mqtt/listeners/listeners_test.go b/mqtt/listeners/listeners_test.go index aabc9f2..63e841c 100644 --- a/mqtt/listeners/listeners_test.go +++ b/mqtt/listeners/listeners_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -11,14 +11,15 @@ import ( "testing" "time" - "github.com/rs/zerolog" + "log/slog" + "github.com/stretchr/testify/require" ) const testAddr = ":22222" var ( - logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) + logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) testCertificate = []byte(`-----BEGIN CERTIFICATE----- MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB diff --git a/mqtt/listeners/mock.go b/mqtt/listeners/mock.go index 778c8e5..826f80c 100644 --- a/mqtt/listeners/mock.go +++ b/mqtt/listeners/mock.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -9,7 +9,7 @@ import ( "net" "sync" - "github.com/rs/zerolog" + "log/slog" ) // MockEstablisher is a function signature which can be used in testing. @@ -53,7 +53,7 @@ func (l *MockListener) Serve(establisher EstablishFn) { } // Init initializes the listener. -func (l *MockListener) Init(log *zerolog.Logger) error { +func (l *MockListener) Init(log *slog.Logger) error { if l.ErrListen { return fmt.Errorf("listen failure") } diff --git a/mqtt/listeners/mock_test.go b/mqtt/listeners/mock_test.go index c2170ce..46aa922 100644 --- a/mqtt/listeners/mock_test.go +++ b/mqtt/listeners/mock_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -16,7 +16,7 @@ func TestMockEstablisher(t *testing.T) { _, w := net.Pipe() err := MockEstablisher("t1", w) require.NoError(t, err) - w.Close() + _ = w.Close() } func TestNewMockListener(t *testing.T) { @@ -86,7 +86,7 @@ func TestMockListenerServe(t *testing.T) { require.Equal(t, true, closed) <-o - mocked.Init(nil) + _ = mocked.Init(nil) } func TestMockListenerClose(t *testing.T) { diff --git a/mqtt/listeners/net.go b/mqtt/listeners/net.go new file mode 100644 index 0000000..fa4ef3d --- /dev/null +++ b/mqtt/listeners/net.go @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co +// SPDX-FileContributor: Jeroen Rinzema + +package listeners + +import ( + "net" + "sync" + "sync/atomic" + + "log/slog" +) + +// Net is a listener for establishing client connections on basic TCP protocol. +type Net struct { // [MQTT-4.2.0-1] + mu sync.Mutex + listener net.Listener // a net.Listener which will listen for new clients + id string // the internal id of the listener + log *slog.Logger // server logger + end uint32 // ensure the close methods are only called once +} + +// NewNet initialises and returns a listener serving incoming connections on the given net.Listener +func NewNet(id string, listener net.Listener) *Net { + return &Net{ + id: id, + listener: listener, + } +} + +// ID returns the id of the listener. +func (l *Net) ID() string { + return l.id +} + +// Address returns the address of the listener. +func (l *Net) Address() string { + return l.listener.Addr().String() +} + +// Protocol returns the network of the listener. +func (l *Net) Protocol() string { + return l.listener.Addr().Network() +} + +// Init initializes the listener. +func (l *Net) Init(log *slog.Logger) error { + l.log = log + return nil +} + +// Serve starts waiting for new TCP connections, and calls the establish +// connection callback for any received. +func (l *Net) Serve(establish EstablishFn) { + for { + if atomic.LoadUint32(&l.end) == 1 { + return + } + + conn, err := l.listener.Accept() + if err != nil { + return + } + + if atomic.LoadUint32(&l.end) == 0 { + go func() { + err = establish(l.id, conn) + if err != nil { + l.log.Warn("", "error", err) + } + }() + } + } +} + +// Close closes the listener and any client connections. +func (l *Net) Close(closeClients CloseFn) { + l.mu.Lock() + defer l.mu.Unlock() + + if atomic.CompareAndSwapUint32(&l.end, 0, 1) { + closeClients(l.id) + } + + if l.listener != nil { + err := l.listener.Close() + if err != nil { + return + } + } +} diff --git a/mqtt/listeners/net_test.go b/mqtt/listeners/net_test.go new file mode 100644 index 0000000..14a1ad6 --- /dev/null +++ b/mqtt/listeners/net_test.go @@ -0,0 +1,105 @@ +package listeners + +import ( + "errors" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNewNet(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + require.Equal(t, "t1", l.id) +} + +func TestNetID(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + require.Equal(t, "t1", l.ID()) +} + +func TestNetAddress(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + require.Equal(t, n.Addr().String(), l.Address()) +} + +func TestNetProtocol(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + require.Equal(t, "tcp", l.Protocol()) +} + +func TestNetInit(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + err = l.Init(logger) + l.Close(MockCloser) + require.NoError(t, err) +} + +func TestNetServeAndClose(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + err = l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + + var closed bool + l.Close(func(id string) { + closed = true + }) + + require.True(t, closed) + <-o + + l.Close(MockCloser) // coverage: close closed + l.Serve(MockEstablisher) // coverage: serve closed +} + +func TestNetEstablishThenEnd(t *testing.T) { + n, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + l := NewNet("t1", n) + err = l.Init(logger) + require.NoError(t, err) + + o := make(chan bool) + established := make(chan bool) + go func() { + l.Serve(func(id string, c net.Conn) error { + established <- true + return errors.New("ending") // return an error to exit immediately + }) + o <- true + }() + + time.Sleep(time.Millisecond) + _, _ = net.Dial("tcp", n.Addr().String()) + require.Equal(t, true, <-established) + l.Close(MockCloser) + <-o +} diff --git a/mqtt/listeners/tcp.go b/mqtt/listeners/tcp.go index ca25a67..1682734 100644 --- a/mqtt/listeners/tcp.go +++ b/mqtt/listeners/tcp.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -10,18 +10,18 @@ import ( "sync" "sync/atomic" - "github.com/rs/zerolog" + "log/slog" ) // TCP is a listener for establishing client connections on basic TCP protocol. type TCP struct { // [MQTT-4.2.0-1] sync.RWMutex - id string // the internal id of the listener - address string // the network address to bind to - listen net.Listener // a net.Listener which will listen for new clients - config *Config // configuration values for the listener - log *zerolog.Logger // server logger - end uint32 // ensure the close methods are only called once + id string // the internal id of the listener + address string // the network address to bind to + listen net.Listener // a net.Listener which will listen for new clients + config *Config // configuration values for the listener + log *slog.Logger // server logger + end uint32 // ensure the close methods are only called once } // NewTCP initialises and returns a new TCP listener, listening on an address. @@ -53,7 +53,7 @@ func (l *TCP) Protocol() string { } // Init initializes the listener. -func (l *TCP) Init(log *zerolog.Logger) error { +func (l *TCP) Init(log *slog.Logger) error { l.log = log var err error @@ -83,7 +83,7 @@ func (l *TCP) Serve(establish EstablishFn) { go func() { err = establish(l.id, conn) if err != nil { - l.log.Warn().Err(err).Send() + l.log.Warn("", "error", err) } }() } diff --git a/mqtt/listeners/tcp_test.go b/mqtt/listeners/tcp_test.go index 6e577ed..636c8ab 100644 --- a/mqtt/listeners/tcp_test.go +++ b/mqtt/listeners/tcp_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -35,35 +35,33 @@ func TestTCPProtocol(t *testing.T) { } func TestTCPProtocolTLS(t *testing.T) { - // pick a random port: - l := NewTCP("t1", ":0", &Config{ + l := NewTCP("t1", testAddr, &Config{ TLSConfig: tlsConfigBasic, }) - err := l.Init(&logger) - require.NoError(t, err) + + _ = l.Init(logger) + defer l.listen.Close() require.Equal(t, "tcp", l.Protocol()) - err = l.listen.Close() - require.NoError(t, err) } func TestTCPInit(t *testing.T) { - l := NewTCP("t1", ":0", nil) - err := l.Init(&logger) + l := NewTCP("t1", testAddr, nil) + err := l.Init(logger) l.Close(MockCloser) require.NoError(t, err) - l2 := NewTCP("t2", ":0", &Config{ + l2 := NewTCP("t2", testAddr, &Config{ TLSConfig: tlsConfigBasic, }) - err = l2.Init(&logger) + err = l2.Init(logger) l2.Close(MockCloser) require.NoError(t, err) require.NotNil(t, l2.config.TLSConfig) } func TestTCPServeAndClose(t *testing.T) { - l := NewTCP("t1", ":0", nil) - err := l.Init(&logger) + l := NewTCP("t1", testAddr, nil) + err := l.Init(logger) require.NoError(t, err) o := make(chan bool) @@ -87,10 +85,10 @@ func TestTCPServeAndClose(t *testing.T) { } func TestTCPServeTLSAndClose(t *testing.T) { - l := NewTCP("t1", ":0", &Config{ + l := NewTCP("t1", testAddr, &Config{ TLSConfig: tlsConfigBasic, }) - err := l.Init(&logger) + err := l.Init(logger) require.NoError(t, err) o := make(chan bool) @@ -111,8 +109,8 @@ func TestTCPServeTLSAndClose(t *testing.T) { } func TestTCPEstablishThenEnd(t *testing.T) { - l := NewTCP("t1", ":0", nil) - err := l.Init(&logger) + l := NewTCP("t1", testAddr, nil) + err := l.Init(logger) require.NoError(t, err) o := make(chan bool) @@ -126,7 +124,7 @@ func TestTCPEstablishThenEnd(t *testing.T) { }() time.Sleep(time.Millisecond) - net.Dial("tcp", l.listen.Addr().String()) + _, _ = net.Dial("tcp", l.listen.Addr().String()) require.Equal(t, true, <-established) l.Close(MockCloser) <-o diff --git a/mqtt/listeners/unixsock.go b/mqtt/listeners/unixsock.go index 1ceaf99..5892fc9 100644 --- a/mqtt/listeners/unixsock.go +++ b/mqtt/listeners/unixsock.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: jason@zgwit.com package listeners @@ -10,17 +10,17 @@ import ( "sync" "sync/atomic" - "github.com/rs/zerolog" + "log/slog" ) // UnixSock is a listener for establishing client connections on basic UnixSock protocol. type UnixSock struct { sync.RWMutex - id string // the internal id of the listener. - address string // the network address to bind to. - listen net.Listener // a net.Listener which will listen for new clients. - log *zerolog.Logger // server logger - end uint32 // ensure the close methods are only called once. + id string // the internal id of the listener. + address string // the network address to bind to. + listen net.Listener // a net.Listener which will listen for new clients. + log *slog.Logger // server logger + end uint32 // ensure the close methods are only called once. } // NewUnixSock initialises and returns a new UnixSock listener, listening on an address. @@ -47,11 +47,11 @@ func (l *UnixSock) Protocol() string { } // Init initializes the listener. -func (l *UnixSock) Init(log *zerolog.Logger) error { +func (l *UnixSock) Init(log *slog.Logger) error { l.log = log var err error - _ = os.RemoveAll(l.address) + _ = os.Remove(l.address) l.listen, err = net.Listen("unix", l.address) return err } @@ -73,7 +73,7 @@ func (l *UnixSock) Serve(establish EstablishFn) { go func() { err = establish(l.id, conn) if err != nil { - l.log.Warn().Err(err).Send() + l.log.Warn("", "error", err) } }() } diff --git a/mqtt/listeners/unixsock_test.go b/mqtt/listeners/unixsock_test.go index d09f776..06ce24d 100644 --- a/mqtt/listeners/unixsock_test.go +++ b/mqtt/listeners/unixsock_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: jason@zgwit.com package listeners @@ -38,19 +38,19 @@ func TestUnixSockProtocol(t *testing.T) { func TestUnixSockInit(t *testing.T) { l := NewUnixSock("t1", testUnixAddr) - err := l.Init(&logger) + err := l.Init(logger) l.Close(MockCloser) require.NoError(t, err) l2 := NewUnixSock("t2", testUnixAddr) - err = l2.Init(&logger) + err = l2.Init(logger) l2.Close(MockCloser) require.NoError(t, err) } func TestUnixSockServeAndClose(t *testing.T) { l := NewUnixSock("t1", testUnixAddr) - err := l.Init(&logger) + err := l.Init(logger) require.NoError(t, err) o := make(chan bool) @@ -75,7 +75,7 @@ func TestUnixSockServeAndClose(t *testing.T) { func TestUnixSockEstablishThenEnd(t *testing.T) { l := NewUnixSock("t1", testUnixAddr) - err := l.Init(&logger) + err := l.Init(logger) require.NoError(t, err) o := make(chan bool) @@ -89,7 +89,7 @@ func TestUnixSockEstablishThenEnd(t *testing.T) { }() time.Sleep(time.Millisecond) - net.Dial("unix", l.listen.Addr().String()) + _, _ = net.Dial("unix", l.listen.Addr().String()) require.Equal(t, true, <-established) l.Close(MockCloser) <-o diff --git a/mqtt/listeners/websocket.go b/mqtt/listeners/websocket.go index 4e1f4d8..50715fc 100644 --- a/mqtt/listeners/websocket.go +++ b/mqtt/listeners/websocket.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -14,8 +14,9 @@ import ( "sync/atomic" "time" + "log/slog" + "github.com/gorilla/websocket" - "github.com/rs/zerolog" ) var ( @@ -29,8 +30,8 @@ type Websocket struct { // [MQTT-4.2.0-1] id string // the internal id of the listener address string // the network address to bind to config *Config // configuration values for the listener - listen *http.Server // an http server for serving websocket connections - log *zerolog.Logger // server logger + listen *http.Server // a http server for serving websocket connections + log *slog.Logger // server logger establish EstablishFn // the server's establish connection handler upgrader *websocket.Upgrader // upgrade the incoming http/tcp connection to a websocket compliant connection. end uint32 // ensure the close methods are only called once @@ -75,7 +76,7 @@ func (l *Websocket) Protocol() string { } // Init initializes the listener. -func (l *Websocket) Init(log *zerolog.Logger) error { +func (l *Websocket) Init(log *slog.Logger) error { l.log = log mux := http.NewServeMux() @@ -101,7 +102,7 @@ func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) { err = l.establish(l.id, &wsConn{Conn: c.UnderlyingConn(), c: c}) if err != nil { - l.log.Warn().Err(err).Send() + l.log.Warn("", "error", err) } } @@ -111,9 +112,9 @@ func (l *Websocket) Serve(establish EstablishFn) { l.establish = establish if l.listen.TLSConfig != nil { - l.listen.ListenAndServeTLS("", "") + _ = l.listen.ListenAndServeTLS("", "") } else { - l.listen.ListenAndServe() + _ = l.listen.ListenAndServe() } } @@ -125,7 +126,7 @@ func (l *Websocket) Close(closeClients CloseFn) { if atomic.CompareAndSwapUint32(&l.end, 0, 1) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - l.listen.Shutdown(ctx) + _ = l.listen.Shutdown(ctx) } closeClients(l.id) @@ -136,7 +137,7 @@ type wsConn struct { net.Conn c *websocket.Conn - // reader for the current message (may be nil) + // reader for the current message (can be nil) r io.Reader } diff --git a/mqtt/listeners/websocket_test.go b/mqtt/listeners/websocket_test.go index ee91f81..a2db1bb 100644 --- a/mqtt/listeners/websocket_test.go +++ b/mqtt/listeners/websocket_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package listeners @@ -37,24 +37,24 @@ func TestWebsocketProtocol(t *testing.T) { require.Equal(t, "ws", l.Protocol()) } -func TestWebsocketProtocoTLS(t *testing.T) { +func TestWebsocketProtocolTLS(t *testing.T) { l := NewWebsocket("t1", testAddr, &Config{ TLSConfig: tlsConfigBasic, }) require.Equal(t, "wss", l.Protocol()) } -func TestWebsockeInit(t *testing.T) { +func TestWebsocketInit(t *testing.T) { l := NewWebsocket("t1", testAddr, nil) require.Nil(t, l.listen) - err := l.Init(nil) + err := l.Init(logger) require.NoError(t, err) require.NotNil(t, l.listen) } func TestWebsocketServeAndClose(t *testing.T) { l := NewWebsocket("t1", testAddr, nil) - l.Init(nil) + _ = l.Init(logger) o := make(chan bool) go func(o chan bool) { @@ -77,7 +77,7 @@ func TestWebsocketServeTLSAndClose(t *testing.T) { l := NewWebsocket("t1", testAddr, &Config{ TLSConfig: tlsConfigBasic, }) - err := l.Init(nil) + err := l.Init(logger) require.NoError(t, err) o := make(chan bool) @@ -96,7 +96,7 @@ func TestWebsocketServeTLSAndClose(t *testing.T) { func TestWebsocketUpgrade(t *testing.T) { l := NewWebsocket("t1", testAddr, nil) - l.Init(nil) + _ = l.Init(logger) e := make(chan bool) l.establish = func(id string, c net.Conn) error { @@ -110,12 +110,12 @@ func TestWebsocketUpgrade(t *testing.T) { require.Equal(t, true, <-e) s.Close() - ws.Close() + _ = ws.Close() } func TestWebsocketConnectionReads(t *testing.T) { l := NewWebsocket("t1", testAddr, nil) - l.Init(nil) + _ = l.Init(nil) recv := make(chan []byte) l.establish = func(id string, c net.Conn) error { @@ -151,5 +151,5 @@ func TestWebsocketConnectionReads(t *testing.T) { require.Equal(t, pkt, got) s.Close() - ws.Close() + _ = ws.Close() } diff --git a/mqtt/packets/codec.go b/mqtt/packets/codec.go index 029cfa7..152d777 100644 --- a/mqtt/packets/codec.go +++ b/mqtt/packets/codec.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets diff --git a/mqtt/packets/codec_test.go b/mqtt/packets/codec_test.go index 8b10126..9129721 100644 --- a/mqtt/packets/codec_test.go +++ b/mqtt/packets/codec_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets diff --git a/mqtt/packets/codes.go b/mqtt/packets/codes.go index 7e314de..5af1b74 100644 --- a/mqtt/packets/codes.go +++ b/mqtt/packets/codes.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets @@ -126,6 +126,7 @@ var ( ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"} ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"} ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"} + ErrInlineSubscriptionHandlerInvalid = Code{Code: 0xA3, Reason: "inline subscription handler not valid."} // MQTTv3 specific bytes. Err3UnsupportedProtocolVersion = Code{Code: 0x01} diff --git a/mqtt/packets/codes_test.go b/mqtt/packets/codes_test.go index 694f47e..aed8e57 100644 --- a/mqtt/packets/codes_test.go +++ b/mqtt/packets/codes_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets @@ -19,7 +19,7 @@ func TestCodesString(t *testing.T) { require.Equal(t, "test", c.String()) } -func TestCodesErrorr(t *testing.T) { +func TestCodesError(t *testing.T) { c := Code{ Reason: "error", Code: 0x1, diff --git a/mqtt/packets/fixedheader.go b/mqtt/packets/fixedheader.go index ddf68ca..eb20451 100644 --- a/mqtt/packets/fixedheader.go +++ b/mqtt/packets/fixedheader.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets diff --git a/mqtt/packets/fixedheader_test.go b/mqtt/packets/fixedheader_test.go index 8f7acf4..fe8c497 100644 --- a/mqtt/packets/fixedheader_test.go +++ b/mqtt/packets/fixedheader_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets diff --git a/mqtt/packets/packets.go b/mqtt/packets/packets.go index e53fe12..ff5930b 100644 --- a/mqtt/packets/packets.go +++ b/mqtt/packets/packets.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets @@ -14,7 +14,7 @@ import ( "sync" ) -// All of the valid packet types and their packet identifier. +// All valid packet types and their packet identifiers. const ( Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets. Connect // 1 @@ -37,9 +37,9 @@ const ( var ( // ErrNoValidPacketAvailable indicates the packet type byte provided does not exist in the mqtt specification. - ErrNoValidPacketAvailable error = errors.New("no valid packet available") + ErrNoValidPacketAvailable = errors.New("no valid packet available") - // PacketNames is a map of packet bytes to human readable names, for easier debugging. + // PacketNames is a map of packet bytes to human-readable names, for easier debugging. PacketNames = map[byte]string{ 0: "Reserved", 1: "Connect", @@ -272,28 +272,28 @@ func (s Subscription) Merge(n Subscription) Subscription { } // encode encodes a subscription and properties into bytes. -func (p Subscription) encode() byte { +func (s Subscription) encode() byte { var flag byte - flag |= p.Qos + flag |= s.Qos - if p.NoLocal { + if s.NoLocal { flag |= 1 << 2 } - if p.RetainAsPublished { + if s.RetainAsPublished { flag |= 1 << 3 } - flag |= p.RetainHandling << 4 + flag |= s.RetainHandling << 4 return flag } // decode decodes subscription bytes into a subscription struct. -func (p *Subscription) decode(b byte) { - p.Qos = b & 3 // byte - p.NoLocal = 1&(b>>2) > 0 // bool - p.RetainAsPublished = 1&(b>>3) > 0 // bool - p.RetainHandling = 3 & (b >> 4) // byte +func (s *Subscription) decode(b byte) { + s.Qos = b & 3 // byte + s.NoLocal = 1&(b>>2) > 0 // bool + s.RetainAsPublished = 1&(b>>3) > 0 // bool + s.RetainHandling = 3 & (b >> 4) // byte } // ConnectEncode encodes a connect packet. @@ -343,7 +343,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -505,7 +505,7 @@ func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -548,7 +548,7 @@ func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -619,7 +619,7 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -707,7 +707,7 @@ func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -844,7 +844,7 @@ func (pk *Packet) SubackEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -901,7 +901,7 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -996,7 +996,7 @@ func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -1049,7 +1049,7 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } @@ -1109,7 +1109,7 @@ func (pk *Packet) AuthEncode(buf *bytes.Buffer) error { pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Encode(buf) - nb.WriteTo(buf) + _, _ = nb.WriteTo(buf) return nil } diff --git a/mqtt/packets/packets_test.go b/mqtt/packets/packets_test.go index c08ff10..1e18f1f 100644 --- a/mqtt/packets/packets_test.go +++ b/mqtt/packets/packets_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets @@ -150,7 +150,7 @@ func TestPacketEncode(t *testing.T) { } pk := new(Packet) - copier.Copy(pk, wanted.Packet) + _ = copier.Copy(pk, wanted.Packet) require.Equal(t, pkt, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc) pk.Mods.AllowResponseInfo = true @@ -218,7 +218,7 @@ func TestPacketDecode(t *testing.T) { pk := &Packet{FixedHeader: FixedHeader{Type: pkt}} pk.Mods.AllowResponseInfo = true - pk.FixedHeader.Decode(wanted.RawBytes[0]) + _ = pk.FixedHeader.Decode(wanted.RawBytes[0]) if len(wanted.RawBytes) > 0 { pk.FixedHeader.Remaining = int(wanted.RawBytes[1]) } diff --git a/mqtt/packets/properties.go b/mqtt/packets/properties.go index ea77e2b..1fc02fd 100644 --- a/mqtt/packets/properties.go +++ b/mqtt/packets/properties.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets @@ -77,7 +77,7 @@ type UserProperty struct { // [MQTT-1.5.7-1] Val string `json:"v"` } -// Properties contains all of the mqtt v5 properties available for a packet. +// Properties contains all mqtt v5 properties available for a packet. // Some properties have valid values of 0 or not-present. In this case, we opt for // property flags to indicate the usage of property. // Refer to mqtt v5 2.2.2.2 Property spec for more information. @@ -355,7 +355,7 @@ func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) { } encodeLength(b, int64(buf.Len())) - buf.WriteTo(b) // [MQTT-3.1.3-10] + _, _ = buf.WriteTo(b) // [MQTT-3.1.3-10] } // Decode decodes property bytes into a properties struct. diff --git a/mqtt/packets/properties_test.go b/mqtt/packets/properties_test.go index 8d326ba..b0a2f10 100644 --- a/mqtt/packets/properties_test.go +++ b/mqtt/packets/properties_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets diff --git a/mqtt/packets/tpackets.go b/mqtt/packets/tpackets.go index 8c21dbd..267721e 100644 --- a/mqtt/packets/tpackets.go +++ b/mqtt/packets/tpackets.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets @@ -40,7 +40,6 @@ const ( TConnectMqtt5 TConnectMqtt5LWT TConnectClean - TConnectCleanLWT TConnectUserPass TConnectUserPassLWT TConnectMalProtocolName @@ -61,7 +60,6 @@ const ( TConnectInvalidProtocolVersion2 TConnectInvalidReservedBit TConnectInvalidClientIDTooLong - TConnectInvalidPasswordNoUsername TConnectInvalidFlagNoUsername TConnectInvalidFlagNoPassword TConnectInvalidUsernameNoFlag @@ -131,12 +129,14 @@ const ( TPublishSpecDenySysTopic TPuback TPubackMqtt5 + TPubackMqtt5NotAuthorized TPubackMalPacketID TPubackMalProperties TPubackUnexpectedError TPubrec TPubrecMqtt5 TPubrecMqtt5IDInUse + TPubrecMqtt5NotAuthorized TPubrecMalPacketID TPubrecMalProperties TPubrecMalReasonCode @@ -184,7 +184,6 @@ const ( TUnsubscribe TUnsubscribeMany TUnsubscribeMqtt5 - TUnsubscribeDropProperties TUnsubscribeMalPacketID TUnsubscribeMalTopicName TUnsubscribeMalProperties @@ -202,7 +201,6 @@ const ( TDisconnect TDisconnectTakeover TDisconnectMqtt5 - TDisconnectNormalMqtt5 TDisconnectSecondConnect TDisconnectReceiveMaximum TDisconnectDropProperties @@ -2274,6 +2272,40 @@ var TPacketData = map[byte]TPacketCases{ }, }, }, + { + Case: TPubackMqtt5NotAuthorized, + Desc: "QOS 1 publish not authorized mqtt5", + Primary: true, + RawBytes: []byte{ + Puback << 4, 37, // Fixed header + 0, 7, // Packet ID - LSB+MSB + ErrNotAuthorized.Code, // Reason Code + 33, // Properties Length + 31, 0, 14, 'n', 'o', 't', ' ', 'a', 'u', + 't', 'h', 'o', 'r', 'i', 'z', 'e', 'd', // Reason String (31) + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Puback, + Remaining: 31, + }, + PacketID: 7, + ReasonCode: ErrNotAuthorized.Code, + Properties: Properties{ + ReasonString: ErrNotAuthorized.Reason, + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, { Case: TPubackUnexpectedError, Desc: "unexpected error", @@ -2412,6 +2444,40 @@ var TPacketData = map[byte]TPacketCases{ }, }, }, + { + Case: TPubrecMqtt5NotAuthorized, + Desc: "QOS 2 publish not authorized mqtt5", + Primary: true, + RawBytes: []byte{ + Pubrec << 4, 37, // Fixed header + 0, 7, // Packet ID - LSB+MSB + ErrNotAuthorized.Code, // Reason Code + 33, // Properties Length + 31, 0, 14, 'n', 'o', 't', ' ', 'a', 'u', + 't', 'h', 'o', 'r', 'i', 'z', 'e', 'd', // Reason String (31) + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Pubrec, + Remaining: 31, + }, + PacketID: 7, + ReasonCode: ErrNotAuthorized.Code, + Properties: Properties{ + ReasonString: ErrNotAuthorized.Reason, + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + }, + }, { Case: TPubrecMalReasonCode, Desc: "malformed reason code", diff --git a/mqtt/packets/tpackets_test.go b/mqtt/packets/tpackets_test.go index c50bb55..8114207 100644 --- a/mqtt/packets/tpackets_test.go +++ b/mqtt/packets/tpackets_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package packets diff --git a/mqtt/server.go b/mqtt/server.go index 30fa937..f74b1e0 100644 --- a/mqtt/server.go +++ b/mqtt/server.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co, wind // package mqtt provides a high performance, fully compliant MQTT v5 broker server with v3.1.1 backward compatibility. @@ -22,12 +22,14 @@ import ( "github.com/wind-c/comqtt/v2/mqtt/packets" "github.com/wind-c/comqtt/v2/mqtt/system" - "github.com/rs/zerolog" + "log/slog" ) const ( - Version = "2.3.0" // the current server version. + Version = "2.4.0" // the current server version. defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes + LocalListener = "local" + InlineClientId = "inline" ) var ( @@ -36,7 +38,7 @@ var ( MaximumSessionExpiryInterval: math.MaxUint32, // maximum number of seconds to keep disconnected sessions MaximumMessageExpiryInterval: 60 * 60 * 24, // maximum message expiry if message expiry is 0 or over ReceiveMaximum: 1024, // maximum number of concurrent qos messages per client - MaximumQos: 2, // maxmimum qos value available to clients + MaximumQos: 2, // maximum qos value available to clients RetainAvailable: 1, // retain messages is available MaximumPacketSize: 0, // no maximum packet size TopicAliasMaximum: math.MaxUint16, // maximum topic alias value @@ -47,15 +49,16 @@ var ( MaximumClientWritesPending: 1024 * 8, // maximum number of pending message writes for a client } - ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists. - ErrConnectionClosed = errors.New("connection not open") // connection is closed + ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists + ErrConnectionClosed = errors.New("connection not open") // connection is closed + ErrInlineClientNotEnabled = errors.New("please set Options.InlineClient=true to use this feature") // inline client is not enabled by default ) // Capabilities indicates the capabilities and features provided by the server. type Capabilities struct { MaximumMessageExpiryInterval int64 `yaml:"maximum-message-expiry-interval"` - MaximumSessionExpiryInterval uint32 `yaml:"maximum-session-expiry-interval"` MaximumClientWritesPending int32 `yaml:"maximum-client-writes-pending"` + MaximumSessionExpiryInterval uint32 `yaml:"maximum-session-expiry-interval"` MaximumPacketSize uint32 `yaml:"maximum-packet-size"` maximumPacketID uint32 // unexported, used for testing only ReceiveMaximum uint16 `yaml:"receive-maximum"` @@ -85,36 +88,44 @@ type Options struct { // server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024 Capabilities *Capabilities + // ClientNetWriteBufferSize specifies the size of the client *bufio.Writer write buffer. + ClientNetWriteBufferSize int + + // ClientNetReadBufferSize specifies the size of the client *bufio.Reader read buffer. + ClientNetReadBufferSize int + // Logger specifies a custom configured implementation of zerolog to override // the servers default logger configuration. If you wish to change the log level, // of the default logger, you can do so by setting // server := mqtt.New(nil) - // l := server.Log.Level(zerolog.DebugLevel) - // server.Log = &l - Logger *zerolog.Logger - - // ClientNetWriteBufferSize specifies the size of the client *bufio.Writer write buffer. - ClientNetWriteBufferSize int `yaml:"client-write-buffer-size"` - - // ClientNetReadBufferSize specifies the size of the client *bufio.Reader read buffer. - ClientNetReadBufferSize int `yaml:"client-read-buffer-size"` + // level := new(slog.LevelVar) + // server.Slog = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + // Level: level, + // })) + // level.Set(slog.LevelDebug) + Logger *slog.Logger // SysTopicResendInterval specifies the interval between $SYS topic updates in seconds. - SysTopicResendInterval int64 `yaml:"sys-topic-resend-interval"` + SysTopicResendInterval int64 + + // Enable Inline client to allow direct subscribing and publishing from the parent codebase, + // with negligible performance difference (disabled by default to prevent confusion in statistics). + InlineClient bool } // Server is an MQTT broker server. It should be created with server.New() // in order to ensure all the internal fields are correctly populated. type Server struct { - Options *Options // configurable server options - Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections - Clients *Clients // clients known to the broker - Topics *TopicsIndex // an index of topic filter subscriptions and retained messages - Info *system.Info // values about the server commonly known as $SYS topics - loop *loop // loop contains tickers for the system event loop - done chan bool // indicate that the server is ending - Log *zerolog.Logger // minimal no-alloc logger - hooks *Hooks // hooks contains hooks for extra functionality such as auth and persistent storage. + Options *Options // configurable server options + Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections + Clients *Clients // clients known to the broker + Topics *TopicsIndex // an index of topic filter subscriptions and retained messages + Info *system.Info // values about the server commonly known as $SYS topics + loop *loop // loop contains tickers for the system event loop + done chan bool // indicate that the server is ending + Log *slog.Logger // minimal no-alloc logger + hooks *Hooks // hooks contains hooks for extra functionality such as auth and persistent storage + inlineClient *Client // inlineClient is a special client used for inline subscriptions and inline Publish } // loop contains interval tickers for the system events loop. @@ -123,16 +134,16 @@ type loop struct { clientExpiry *time.Ticker // interval ticker for cleaning expired clients inflightExpiry *time.Ticker // interval ticker for cleaning up expired inflight messages retainedExpiry *time.Ticker // interval ticker for cleaning retained messages - willDelaySend *time.Ticker // interval ticker for sending will messages with a delay + willDelaySend *time.Ticker // interval ticker for sending Will Messages with a delay willDelayed *packets.Packets // activate LWT packets which will be sent after a delay } // ops contains server values which can be propagated to other structs. type ops struct { - options *Options // a pointer to the server options and capabilities, for referencing in clients - info *system.Info // pointers to server system info - hooks *Hooks // pointer to the server hooks - log *zerolog.Logger // a structured logger for the client + options *Options // a pointer to the server options and capabilities, for referencing in clients + info *system.Info // pointers to server system info + hooks *Hooks // pointer to the server hooks + log *slog.Logger // a structured logger for the client } // New returns a new instance of comqtt broker. Optional parameters @@ -168,6 +179,11 @@ func New(opts *Options) *Server { }, } + if s.Options.InlineClient { + s.inlineClient = s.NewClient(nil, LocalListener, InlineClientId, true) + s.Clients.Add(s.inlineClient) + } + return s } @@ -192,8 +208,8 @@ func (o *Options) ensureDefaults() { } if o.Logger == nil { - log := zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.InfoLevel).Output(zerolog.ConsoleWriter{Out: os.Stderr}) - o.Logger = &log + log := slog.New(slog.NewTextHandler(os.Stdout, nil)) + o.Logger = log } } @@ -227,12 +243,12 @@ func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool) // AddHook attaches a new Hook to the server. Ideally, this should be called // before the server is started with s.Serve(). func (s *Server) AddHook(hook Hook, config any) error { - nl := s.Log.With().Str("hook", hook.ID()).Logger() - hook.SetOpts(&nl, &HookOptions{ + nl := s.Log.With("hook", hook.ID()) + hook.SetOpts(nl, &HookOptions{ Capabilities: s.Options.Capabilities, }) - s.Log.Info().Str("hook", hook.ID()).Msg("added hook") + s.Log.Info("added hook", "hook", hook.ID()) return s.hooks.Add(hook, config) } @@ -242,23 +258,23 @@ func (s *Server) AddListener(l listeners.Listener) error { return ErrListenerIDExists } - nl := s.Log.With().Str("listener", l.ID()).Logger() - err := l.Init(&nl) + nl := s.Log.With(slog.String("listener", l.ID())) + err := l.Init(nl) if err != nil { return err } s.Listeners.Add(l) - s.Log.Info().Str("id", l.ID()).Str("protocol", l.Protocol()).Str("address", l.Address()).Msg("attached listener") + s.Log.Info("attached listener", "id", l.ID(), "protocol", l.Protocol(), "address", l.Address()) return nil } // Serve starts the event loops responsible for establishing client connections // on all attached listeners, publishing the system topics, and starting all hooks. func (s *Server) Serve() error { - //s.Log.Info().Str("version", Version).Msg("comqtt starting") - defer s.Log.Info().Msg("comqtt server started") + //s.Log.Info("version", Version).Msg("comqtt starting") + defer s.Log.Info("comqtt server started") if s.hooks.Provides( StoredClients, @@ -283,8 +299,8 @@ func (s *Server) Serve() error { // eventLoop loops forever, running various server housekeeping methods at different intervals. func (s *Server) eventLoop() { - s.Log.Debug().Msg("system event loop started") - defer s.Log.Debug().Msg("system event loop halted") + s.Log.Debug("system event loop started") + defer s.Log.Debug("system event loop halted") for { select { @@ -375,8 +391,8 @@ func (s *Server) attachClient(cl *Client, listener string) error { } else { cl.Properties.Will = Will{} // [MQTT-3.14.4-3] [MQTT-3.1.2-10] } + s.Log.Debug("client disconnected", "error", err, "client", cl.ID, "remote", cl.Net.Remote, "listener", listener) - s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", listener).Msg("client disconnected") expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean) s.hooks.OnDisconnect(cl, err, expire) @@ -418,10 +434,10 @@ func (s *Server) receivePacket(cl *Client, pk packets.Packet) error { if code, ok := err.(packets.Code); ok && cl.Properties.ProtocolVersion == 5 && code.Code >= packets.ErrUnspecifiedError.Code { - s.DisconnectClient(cl, code) + _ = s.DisconnectClient(cl, code) } - s.Log.Warn().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Interface("pk", pk).Msg("error processing packet") + s.Log.Warn("error processing packet", "error", err, "client", cl.ID, "listener", cl.Net.Listener, "pk", pk) return err } @@ -456,7 +472,7 @@ func (s *Server) validateConnect(cl *Client, pk packets.Packet) packets.Code { // session is abandoned. func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool { if existing, ok := s.Clients.Get(pk.Connect.ClientIdentifier); ok { - s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3] + _ = s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3] if pk.Connect.Clean || (existing.Properties.Clean && existing.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4] s.UnsubscribeClient(existing) existing.ClearInflights(math.MaxInt64, 0) @@ -487,10 +503,8 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool { // from increasing memory usage by inflights + subs * client-id. s.UnsubscribeClient(existing) existing.ClearInflights(math.MaxInt64, 0) - s.Log.Debug().Str("client", cl.ID). - Str("old_remote", existing.Net.Remote). - Str("new_remote", cl.Net.Remote). - Msg("session taken over") + + s.Log.Debug("session taken over", "client", cl.ID, "old_remote", existing.Net.Remote, "new_remote", cl.Net.Remote) cl.InheritWay = InheritWayLocal return true // [MQTT-3.2.2-3] @@ -676,13 +690,16 @@ func (s *Server) processPingreq(cl *Client, _ packets.Packet) error { }) } -// Publish publishes a publish packet into the broker as if it were sent from the speicfied client. +// Publish publishes a publish packet into the broker as if it were sent from the specified client. // This is a convenience function which wraps InjectPacket. As such, this method can publish packets // to any topic (including $SYS) and bypass ACL checks. The qos byte is used for limiting the // outbound qos (mqtt v5) rather than issuing to the broker (we assume qos 2 complete). func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) error { - cl := s.NewClient(nil, "local", "inline", true) - return s.InjectPacket(cl, packets.Packet{ + if !s.Options.InlineClient { + return ErrInlineClientNotEnabled + } + + return s.InjectPacket(s.inlineClient, packets.Packet{ FixedHeader: packets.FixedHeader{ Type: packets.Publish, Qos: qos, @@ -694,6 +711,75 @@ func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) er }) } +// Subscribe adds an inline subscription for the specified topic filter and subscription identifier +// with the provided handler function. +func (s *Server) Subscribe(filter string, subscriptionId int, handler InlineSubFn) error { + if !s.Options.InlineClient { + return ErrInlineClientNotEnabled + } + + if handler == nil { + return packets.ErrInlineSubscriptionHandlerInvalid + } + + if !IsValidFilter(filter, false) { + return packets.ErrTopicFilterInvalid + } + + subscription := packets.Subscription{ + Identifier: subscriptionId, + Filter: filter, + } + + pk := s.hooks.OnSubscribe(s.inlineClient, packets.Packet{ // subscribe like a normal client. + Origin: s.inlineClient.ID, + FixedHeader: packets.FixedHeader{Type: packets.Subscribe}, + Filters: packets.Subscriptions{subscription}, + }) + + inlineSubscription := InlineSubscription{ + Subscription: subscription, + Handler: handler, + } + + _, count := s.Topics.InlineSubscribe(inlineSubscription) + s.hooks.OnSubscribed(s.inlineClient, pk, []byte{packets.CodeSuccess.Code}, []int{count}) + + // Handling retained messages. + for _, pkv := range s.Topics.Messages(filter) { // [MQTT-3.8.4-4] + handler(s.inlineClient, inlineSubscription.Subscription, pkv) + } + return nil +} + +// Unsubscribe removes an inline subscription for the specified subscription and topic filter. +// It allows you to unsubscribe a specific subscription from the internal subscription +// associated with the given topic filter. +func (s *Server) Unsubscribe(filter string, subscriptionId int) error { + if !s.Options.InlineClient { + return ErrInlineClientNotEnabled + } + + if !IsValidFilter(filter, false) { + return packets.ErrTopicFilterInvalid + } + + pk := s.hooks.OnUnsubscribe(s.inlineClient, packets.Packet{ + Origin: s.inlineClient.ID, + FixedHeader: packets.FixedHeader{Type: packets.Unsubscribe}, + Filters: packets.Subscriptions{ + { + Identifier: subscriptionId, + Filter: filter, + }, + }, + }) + + _, count := s.Topics.InlineUnsubscribe(subscriptionId, filter) + s.hooks.OnUnsubscribed(s.inlineClient, pk, []byte{packets.CodeSuccess.Code}, []int{count}) + return nil +} + // InjectPacket injects a packet into the broker as if it were sent from the specified client. // InlineClients using this method can publish packets to any topic (including $SYS) and bypass ACL checks. func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error { @@ -723,7 +809,21 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { } if !cl.Net.Inline && !s.hooks.OnACLCheck(cl, pk.TopicName, true) { - return nil + if pk.FixedHeader.Qos == 0 { + return nil + } + + if cl.Properties.ProtocolVersion != 5 { + return s.DisconnectClient(cl, packets.ErrNotAuthorized) + } + + ackType := packets.Puback + if pk.FixedHeader.Qos == 2 { + ackType = packets.Pubrec + } + + ack := s.buildAck(pk.PacketID, ackType, 0, pk.Properties, packets.ErrNotAuthorized) + return cl.WritePacket(ack) } pk.Origin = cl.ID @@ -746,7 +846,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { } if pk.FixedHeader.Qos > s.Options.Capabilities.MaximumQos { - pk.FixedHeader.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] Reduce Qos based on server max qos capability + pk.FixedHeader.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] Reduce qos based on server max qos capability } pkx, err := s.hooks.OnPublish(cl, pk) @@ -768,7 +868,10 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { s.retainMessage(cl, pk) } - if pk.FixedHeader.Qos == 0 { + // If it's inlineClient, it can't handle PUBREC and PUBREL. + // When it publishes a package with a qos > 0, the server treats + // the package as qos=0, and the client receives it as qos=1 or 2. + if pk.FixedHeader.Qos == 0 || cl.Net.Inline { s.PublishToSubscribers(pk) s.hooks.OnPublished(cl, pk) return nil @@ -841,11 +944,15 @@ func (s *Server) PublishToSubscribers(pk packets.Packet) { subscribers.MergeSharedSelected() } + for _, inlineSubscription := range subscribers.InlineSubscriptions { + inlineSubscription.Handler(s.inlineClient, inlineSubscription.Subscription, pk) + } + for id, subs := range subscribers.Subscriptions { if cl, ok := s.Clients.Get(id); ok { _, err := s.publishToClient(cl, subs, pk) if err != nil { - s.Log.Debug().Err(err).Str("client", cl.ID).Interface("packet", pk).Msg("failed publishing packet") + s.Log.Debug("failed publishing packet", "error", err, "client", cl.ID, "packet", pk) } } } @@ -857,6 +964,9 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet } out := pk.Copy(false) + if !s.hooks.OnACLCheck(cl, pk.TopicName, false) { + return out, packets.ErrNotAuthorized + } if !sub.FwdRetainedFlag && ((cl.Properties.ProtocolVersion == 5 && !sub.RetainAsPublished) || cl.Properties.ProtocolVersion < 5) { // ![MQTT-3.3.1-13] [v3 MQTT-3.3.1-9] out.FixedHeader.Retain = false // [MQTT-3.3.1-12] } @@ -892,7 +1002,7 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet i, err := cl.NextPacketID() // [MQTT-4.3.2-1] [MQTT-4.3.3-1] if err != nil { s.hooks.OnPacketIDExhausted(cl, pk) - s.Log.Warn().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Msg("packet ids exhausted") + s.Log.Warn("packet ids exhausted", "error", err, "client", cl.ID, "listener", cl.Net.Listener) return out, packets.ErrQuotaExceeded } @@ -943,7 +1053,7 @@ func (s *Server) publishRetainedToClient(cl *Client, sub packets.Subscription, e for _, pkv := range s.Topics.Messages(sub.Filter) { // [MQTT-3.8.4-4] _, err := s.publishToClient(cl, sub, pkv) if err != nil { - s.Log.Debug().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Interface("packet", pkv).Msg("failed to publish retained message") + s.Log.Debug("failed to publish retained message", "error", err, "client", cl.ID, "listener", cl.Net.Listener, "packet", pkv) continue } s.hooks.OnRetainPublished(cl, pkv) @@ -1182,12 +1292,20 @@ func (s *Server) processUnsubscribe(cl *Client, pk packets.Packet) error { func (s *Server) UnsubscribeClient(cl *Client) { i := 0 filterMap := cl.State.Subscriptions.GetAll() + + for k := range filterMap { + cl.State.Subscriptions.Delete(k) + } + + if atomic.LoadUint32(&cl.State.isTakenOver) == 1 { + return + } + length := len(filterMap) filters := make([]packets.Subscription, length) reasonCodes := make([]byte, length) counts := make([]int, length) // An array of the number of subscribers for the same filter for k, v := range filterMap { - cl.State.Subscriptions.Delete(k) q, count := s.Topics.Unsubscribe(k, cl.ID) if q { atomic.AddInt64(&s.Info.Subscriptions, -1) @@ -1200,7 +1318,7 @@ func (s *Server) UnsubscribeClient(cl *Client) { i++ } - s.hooks.OnUnsubscribed(cl, packets.Packet{Filters: filters}, reasonCodes, counts) + s.hooks.OnUnsubscribed(cl, packets.Packet{FixedHeader: packets.FixedHeader{Type: packets.Unsubscribe}, Filters: filters}, reasonCodes, counts) } // processAuth processes an Auth packet. @@ -1318,7 +1436,7 @@ func (s *Server) Close() error { s.hooks.OnStopped() s.hooks.Stop() - s.Log.Info().Msg("comqtt server stopped") + s.Log.Info("comqtt server stopped") return nil } @@ -1326,7 +1444,7 @@ func (s *Server) Close() error { func (s *Server) closeListenerClients(listener string) { clients := s.Clients.GetByListener(listener) for _, cl := range clients { - s.DisconnectClient(cl, packets.ErrServerShuttingDown) + _ = s.DisconnectClient(cl, packets.ErrServerShuttingDown) } } @@ -1377,9 +1495,7 @@ func (s *Server) readStore() error { return fmt.Errorf("failed to load clients; %w", err) } s.loadClients(clients) - s.Log.Debug(). - Int("len", len(clients)). - Msg("loaded clients from store") + s.Log.Debug("loaded clients from store", "len", len(clients)) } if s.hooks.Provides(StoredSubscriptions) { @@ -1388,9 +1504,7 @@ func (s *Server) readStore() error { return fmt.Errorf("load subscriptions; %w", err) } s.loadSubscriptions(subs) - s.Log.Debug(). - Int("len", len(subs)). - Msg("loaded subscriptions from store") + s.Log.Debug("loaded subscriptions from store", "len", len(subs)) } if s.hooks.Provides(StoredInflightMessages) { @@ -1399,9 +1513,7 @@ func (s *Server) readStore() error { return fmt.Errorf("load inflight; %w", err) } s.loadInflight(inflight) - s.Log.Debug(). - Int("len", len(inflight)). - Msg("loaded inflights from store") + s.Log.Debug("loaded inflights from store", "len", len(inflight)) } if s.hooks.Provides(StoredRetainedMessages) { @@ -1410,9 +1522,7 @@ func (s *Server) readStore() error { return fmt.Errorf("load retained; %w", err) } s.loadRetained(retained) - s.Log.Debug(). - Int("len", len(retained)). - Msg("loaded retained messages from store") + s.Log.Debug("loaded retained messages from store", "len", len(retained)) } if s.hooks.Provides(StoredSysInfo) { @@ -1421,8 +1531,7 @@ func (s *Server) readStore() error { return fmt.Errorf("load server info; %w", err) } s.loadServerInfo(sysInfo.Info) - s.Log.Debug(). - Msg("loaded $SYS info from store") + s.Log.Debug("loaded $SYS info from store") } return nil diff --git a/mqtt/server_test.go b/mqtt/server_test.go index 3234911..3171bb0 100644 --- a/mqtt/server_test.go +++ b/mqtt/server_test.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package mqtt @@ -8,9 +8,10 @@ import ( "bytes" "encoding/binary" "io" + "log/slog" "net" - "os" "strconv" + "sync" "sync/atomic" "testing" "time" @@ -20,11 +21,10 @@ import ( "github.com/wind-c/comqtt/v2/mqtt/packets" "github.com/wind-c/comqtt/v2/mqtt/system" - "github.com/rs/zerolog" "github.com/stretchr/testify/require" ) -var logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) +var logger = slog.New(slog.NewTextHandler(io.Discard, nil)) type ProtocolTest []struct { protocolVersion byte @@ -37,6 +37,11 @@ type AllowHook struct { HookBase } +func (h *AllowHook) SetOpts(l *slog.Logger, opts *HookOptions) { + h.Log = l + h.Opts = opts +} + func (h *AllowHook) ID() string { return "allow-all-auth" } @@ -48,11 +53,36 @@ func (h *AllowHook) Provides(b byte) bool { func (h *AllowHook) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { return true } func (h *AllowHook) OnACLCheck(cl *Client, topic string, write bool) bool { return true } +type DenyHook struct { + HookBase +} + +func (h *DenyHook) SetOpts(l *slog.Logger, opts *HookOptions) { + h.Log = l + h.Opts = opts +} + +func (h *DenyHook) ID() string { + return "deny-all-auth" +} + +func (h *DenyHook) Provides(b byte) bool { + return bytes.Contains([]byte{OnConnectAuthenticate, OnACLCheck}, []byte{b}) +} + +func (h *DenyHook) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { return false } +func (h *DenyHook) OnACLCheck(cl *Client, topic string, write bool) bool { return false } + type DelayHook struct { HookBase DisconnectDelay time.Duration } +func (h *DelayHook) SetOpts(l *slog.Logger, opts *HookOptions) { + h.Log = l + h.Opts = opts +} + func (h *DelayHook) ID() string { return "delay-hook" } @@ -69,12 +99,24 @@ func newServer() *Server { cc := *DefaultServerCapabilities cc.MaximumMessageExpiryInterval = 0 cc.ReceiveMaximum = 0 + s := New(&Options{ + Logger: logger, + Capabilities: &cc, + }) + _ = s.AddHook(new(AllowHook), nil) + return s +} +func newServerWithInlineClient() *Server { + cc := *DefaultServerCapabilities + cc.MaximumMessageExpiryInterval = 0 + cc.ReceiveMaximum = 0 s := New(&Options{ - Logger: &logger, + Logger: logger, Capabilities: &cc, + InlineClient: true, }) - s.AddHook(new(AllowHook), nil) + _ = s.AddHook(new(AllowHook), nil) return s } @@ -106,6 +148,16 @@ func TestNew(t *testing.T) { require.NotNil(t, s.hooks) require.NotNil(t, s.hooks.Log) require.NotNil(t, s.done) + require.Nil(t, s.inlineClient) + require.Equal(t, 0, s.Clients.Len()) +} + +func TestNewWithInlineClient(t *testing.T) { + s := New(&Options{ + InlineClient: true, + }) + require.NotNil(t, s.inlineClient) + require.Equal(t, 1, s.Clients.Len()) } func TestNewNilOpts(t *testing.T) { @@ -116,7 +168,7 @@ func TestNewNilOpts(t *testing.T) { func TestServerNewClient(t *testing.T) { s := New(nil) - s.Log = &logger + s.Log = logger r, _ := net.Pipe() cl := s.NewClient(r, "testing", "test", false) @@ -143,7 +195,8 @@ func TestServerNewClientInline(t *testing.T) { func TestServerAddHook(t *testing.T) { s := New(nil) - s.Log = &logger + + s.Log = logger require.NotNil(t, s) require.Equal(t, int64(0), s.hooks.Len()) @@ -247,8 +300,8 @@ func TestServerReadConnectionPacket(t *testing.T) { }() go func() { - r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) - r.Close() + _, _ = r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) + _ = r.Close() }() require.Equal(t, *packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet, <-o) @@ -268,8 +321,8 @@ func TestServerReadConnectionPacketBadFixedHeader(t *testing.T) { }() go func() { - r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalFixedHeader).RawBytes) - r.Close() + _, _ = r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalFixedHeader).RawBytes) + _ = r.Close() }() err := <-o @@ -285,8 +338,8 @@ func TestServerReadConnectionPacketBadPacketType(t *testing.T) { s.Clients.Add(cl) go func() { - r.Write(packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes) - r.Close() + _, _ = r.Write(packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes) + _ = r.Close() }() _, err := s.readConnectionPacket(cl) @@ -302,8 +355,8 @@ func TestServerReadConnectionPacketBadPacket(t *testing.T) { s.Clients.Add(cl) go func() { - r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalProtocolName).RawBytes) - r.Close() + _, _ = r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalProtocolName).RawBytes) + _ = r.Close() }() _, err := s.readConnectionPacket(cl) @@ -322,8 +375,8 @@ func TestEstablishConnection(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() // receive the connack @@ -336,14 +389,18 @@ func TestEstablishConnection(t *testing.T) { err := <-o require.NoError(t, err) - for _, v := range s.Clients.GetAll() { - require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect - } + + // Todo: + // s.Clients is already empty here. Is it necessary to check v.StopCause()? + + // for _, v := range s.Clients.GetAll() { + // require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect + // } require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv) - w.Close() - r.Close() + _ = w.Close() + _ = r.Close() // client must be deleted on session close if Clean = true _, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).Packet.Connect.ClientIdentifier) @@ -361,15 +418,15 @@ func TestEstablishConnectionAckFailure(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) - w.Close() + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + _ = w.Close() }() err := <-o require.Error(t, err) require.ErrorIs(t, err, io.ErrClosedPipe) - r.Close() + _ = r.Close() } func TestEstablishConnectionReadError(t *testing.T) { @@ -383,8 +440,8 @@ func TestEstablishConnectionReadError(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).RawBytes) - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) // second connect error + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) // second connect error }() // receive the connack @@ -397,9 +454,11 @@ func TestEstablishConnectionReadError(t *testing.T) { err := <-o require.Error(t, err) - for _, v := range s.Clients.GetAll() { - require.ErrorIs(t, v.StopCause(), packets.ErrProtocolViolationSecondConnect) // true error is disconnect - } + + // Retrieve the client corresponding to the Client Identifier. + retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).Packet.Connect.ClientIdentifier) + require.True(t, ok) + require.ErrorIs(t, retrievedCl.StopCause(), packets.ErrProtocolViolationSecondConnect) // true error is disconnect ret := <-recv require.Equal(t, append( @@ -408,8 +467,8 @@ func TestEstablishConnectionReadError(t *testing.T) { ret, ) - w.Close() - r.Close() + _ = w.Close() + _ = r.Close() } func TestEstablishConnectionInheritExisting(t *testing.T) { @@ -432,9 +491,9 @@ func TestEstablishConnectionInheritExisting(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) time.Sleep(time.Millisecond) // we want to receive the queued inflight, so we need to wait a moment before sending the disconnect. - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() // receive the disconnect session takeover @@ -455,9 +514,11 @@ func TestEstablishConnectionInheritExisting(t *testing.T) { err := <-o require.NoError(t, err) - for _, v := range s.Clients.GetAll() { - require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect - } + + // Retrieve the client corresponding to the Client Identifier. + retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier) + require.True(t, ok) + require.ErrorIs(t, retrievedCl.StopCause(), packets.CodeDisconnect) // true error is disconnect connackPlusPacket := append( packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedSessionExists).RawBytes, @@ -467,8 +528,8 @@ func TestEstablishConnectionInheritExisting(t *testing.T) { require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectTakeover).RawBytes, <-takeover) time.Sleep(time.Microsecond * 100) - w.Close() - r.Close() + _ = w.Close() + _ = r.Close() clw, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier) require.True(t, ok) @@ -478,12 +539,12 @@ func TestEstablishConnectionInheritExisting(t *testing.T) { require.Empty(t, cl.State.Subscriptions.GetAll()) } -// See https://github.com/mochi-co/mqtt/issues/173 +// See https://github.com/mochi-mqtt/server/issues/173 func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) { s := newServer() d := new(DelayHook) d.DisconnectDelay = time.Millisecond * 200 - s.AddHook(d, nil) + _ = s.AddHook(d, nil) defer s.Close() // Clean session, 0 session expiry interval @@ -508,7 +569,7 @@ func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) { o1 <- err }() go func() { - w1.Write(cl1RawBytes) + _, _ = w1.Write(cl1RawBytes) }() // receive the first connack @@ -537,7 +598,7 @@ func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) { go func() { x := packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes[:] x[19] = '.' // differentiate username bytes in debugging - w2.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes) + _, _ = w2.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes) }() // receive the second connack @@ -565,7 +626,7 @@ func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) { require.NotEmpty(t, clp2.State.Subscriptions.GetAll()) require.Empty(t, clp1.State.Subscriptions.GetAll()) - w2.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w2.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) require.NoError(t, <-o2) } @@ -588,7 +649,7 @@ func TestEstablishConnectionResentPendingInflightsError(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) }() go func() { @@ -623,8 +684,8 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() // receive the disconnect @@ -645,15 +706,17 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) { err := <-o require.NoError(t, err) - for _, v := range s.Clients.GetAll() { - require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect - } + + // Retrieve the client corresponding to the Client Identifier. + retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier) + require.True(t, ok) + require.ErrorIs(t, retrievedCl.StopCause(), packets.CodeDisconnect) // true error is disconnect require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv) require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, <-takeover) - w.Close() - r.Close() + _ = w.Close() + _ = r.Close() clw, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier) require.True(t, ok) @@ -662,7 +725,7 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) { func TestEstablishConnectionBadAuthentication(t *testing.T) { s := New(&Options{ - Logger: &logger, + Logger: logger, }) defer s.Close() @@ -673,8 +736,8 @@ func TestEstablishConnectionBadAuthentication(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() // receive the connack @@ -690,13 +753,13 @@ func TestEstablishConnectionBadAuthentication(t *testing.T) { require.ErrorIs(t, err, packets.ErrBadUsernameOrPassword) require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackBadUsernamePasswordNoSession).RawBytes, <-recv) - w.Close() - r.Close() + _ = w.Close() + _ = r.Close() } func TestEstablishConnectionBadAuthenticationAckFailure(t *testing.T) { s := New(&Options{ - Logger: &logger, + Logger: logger, }) defer s.Close() @@ -707,15 +770,15 @@ func TestEstablishConnectionBadAuthenticationAckFailure(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) - w.Close() + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + _ = w.Close() }() err := <-o require.Error(t, err) require.ErrorIs(t, err, io.ErrClosedPipe) - r.Close() + _ = r.Close() } func TestServerEstablishConnectionInvalidConnect(t *testing.T) { @@ -728,8 +791,8 @@ func TestServerEstablishConnectionInvalidConnect(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalReservedBit).RawBytes) - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalReservedBit).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() // receive the connack @@ -745,10 +808,10 @@ func TestServerEstablishConnectionInvalidConnect(t *testing.T) { require.ErrorIs(t, packets.ErrProtocolViolationReservedBit, err) require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackProtocolViolationNoSession).RawBytes, <-recv) - r.Close() + _ = r.Close() } -// See https://github.com/mochi-co/mqtt/issues/178 +// See https://github.com/mochi-mqtt/server/issues/178 func TestServerEstablishConnectionZeroByteUsernameIsValid(t *testing.T) { s := newServer() @@ -759,8 +822,8 @@ func TestServerEstablishConnectionZeroByteUsernameIsValid(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectZeroByteUsername).RawBytes) - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectZeroByteUsername).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() // receive the connack error @@ -772,7 +835,7 @@ func TestServerEstablishConnectionZeroByteUsernameIsValid(t *testing.T) { err := <-o require.NoError(t, err) - r.Close() + _ = r.Close() } func TestServerEstablishConnectionInvalidConnectAckFailure(t *testing.T) { @@ -785,15 +848,15 @@ func TestServerEstablishConnectionInvalidConnectAckFailure(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalReservedBit).RawBytes) - w.Close() + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalReservedBit).RawBytes) + _ = w.Close() }() err := <-o require.Error(t, err) require.ErrorIs(t, err, io.ErrClosedPipe) - r.Close() + _ = r.Close() } func TestServerEstablishConnectionBadPacket(t *testing.T) { @@ -806,15 +869,15 @@ func TestServerEstablishConnectionBadPacket(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnackBadProtocolVersion).RawBytes) - w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnackBadProtocolVersion).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) }() err := <-o require.Error(t, err) require.ErrorIs(t, err, packets.ErrProtocolViolationRequireFirstConnect) - r.Close() + _ = r.Close() } func TestServerEstablishConnectionOnConnectError(t *testing.T) { @@ -831,14 +894,14 @@ func TestServerEstablishConnectionOnConnectError(t *testing.T) { }() go func() { - w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) + _, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) }() err = <-o require.Error(t, err) require.ErrorIs(t, err, errTestHook) - r.Close() + _ = r.Close() } func TestServerSendConnack(t *testing.T) { @@ -852,7 +915,7 @@ func TestServerSendConnack(t *testing.T) { go func() { err := s.SendConnack(cl, packets.CodeSuccess, true, nil) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -867,7 +930,7 @@ func TestServerSendConnackFailureReason(t *testing.T) { go func() { err := s.SendConnack(cl, packets.ErrUnspecifiedError, true, nil) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -884,7 +947,7 @@ func TestServerSendConnackWithServerKeepalive(t *testing.T) { go func() { err := s.SendConnack(cl, packets.CodeSuccess, true, nil) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -963,7 +1026,7 @@ func TestServerSendConnackAdjustedExpiryInterval(t *testing.T) { go func() { err := s.SendConnack(cl, packets.CodeSuccess, false, nil) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1042,7 +1105,7 @@ func TestServerProcessPacketPingreq(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Pingreq].Get(packets.TPingreq).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1071,7 +1134,7 @@ func TestServerProcessPacketPublishInvalid(t *testing.T) { func TestInjectPacketPublishAndReceive(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() sender, _, w1 := newTestClient() @@ -1096,17 +1159,18 @@ func TestInjectPacketPublishAndReceive(t *testing.T) { go func() { err := s.InjectPacket(sender, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) require.NoError(t, err) - w1.Close() + _ = w1.Close() time.Sleep(time.Millisecond * 10) - w2.Close() + _ = w2.Close() }() require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) } -func TestServerDirectPublishAndReceive(t *testing.T) { - s := newServer() - s.Serve() +func TestServerPublishAndReceive(t *testing.T) { + s := newServerWithInlineClient() + + _ = s.Serve() defer s.Close() sender, _, w1 := newTestClient() @@ -1132,14 +1196,22 @@ func TestServerDirectPublishAndReceive(t *testing.T) { pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet err := s.Publish(pkx.TopicName, pkx.Payload, pkx.FixedHeader.Retain, pkx.FixedHeader.Qos) require.NoError(t, err) - w1.Close() + _ = w1.Close() time.Sleep(time.Millisecond * 10) - w2.Close() + _ = w2.Close() }() require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) } +func TestServerPublishNoInlineClient(t *testing.T) { + s := newServer() + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + err := s.Publish(pkx.TopicName, pkx.Payload, pkx.FixedHeader.Retain, pkx.FixedHeader.Qos) + require.Error(t, err) + require.ErrorIs(t, err, ErrInlineClientNotEnabled) +} + func TestInjectPacketError(t *testing.T) { s := newServer() defer s.Close() @@ -1164,7 +1236,7 @@ func TestInjectPacketPublishInvalidTopic(t *testing.T) { func TestServerProcessPacketPublishAndReceive(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() sender, _, w1 := newTestClient() @@ -1190,8 +1262,8 @@ func TestServerProcessPacketPublishAndReceive(t *testing.T) { err := s.processPacket(sender, *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) require.NoError(t, err) time.Sleep(time.Millisecond * 10) - w1.Close() - w2.Close() + _ = w1.Close() + _ = w2.Close() }() require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) @@ -1256,7 +1328,7 @@ func TestServerProcessPacketAndNextImmediate(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1266,15 +1338,15 @@ func TestServerProcessPacketAndNextImmediate(t *testing.T) { require.Equal(t, int32(4), cl.State.Inflight.sendQuota) } -func TestServerProcessPacketPublishAckFailure(t *testing.T) { +func TestServerProcessPublishAckFailure(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() cl, _, w := newTestClient() s.Clients.Add(cl) - w.Close() + _ = w.Close() err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet) require.Error(t, err) require.ErrorIs(t, err, io.ErrClosedPipe) @@ -1291,7 +1363,7 @@ func TestServerProcessPublishOnPublishAckErrorRWError(t *testing.T) { cl, _, w := newTestClient() cl.Properties.ProtocolVersion = 5 s.Clients.Add(cl) - w.Close() + _ = w.Close() err = s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) require.Error(t, err) @@ -1305,7 +1377,7 @@ func TestServerProcessPublishOnPublishAckErrorContinue(t *testing.T) { hook.err = packets.ErrPayloadFormatInvalid err := s.AddHook(hook, nil) require.NoError(t, err) - s.Serve() + _ = s.Serve() defer s.Close() cl, r, w := newTestClient() @@ -1315,7 +1387,7 @@ func TestServerProcessPublishOnPublishAckErrorContinue(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1330,7 +1402,7 @@ func TestServerProcessPublishOnPublishPkIgnore(t *testing.T) { hook.err = packets.CodeSuccessIgnore err := s.AddHook(hook, nil) require.NoError(t, err) - s.Serve() + _ = s.Serve() defer s.Close() cl, r, w := newTestClient() @@ -1354,8 +1426,8 @@ func TestServerProcessPublishOnPublishPkIgnore(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) require.NoError(t, err) - w.Close() - w2.Close() + _ = w.Close() + _ = w2.Close() }() buf, err := io.ReadAll(r) @@ -1367,7 +1439,7 @@ func TestServerProcessPublishOnPublishPkIgnore(t *testing.T) { func TestServerProcessPacketPublishMaximumReceive(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() cl, r, w := newTestClient() @@ -1379,7 +1451,7 @@ func TestServerProcessPacketPublishMaximumReceive(t *testing.T) { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) require.Error(t, err) require.ErrorIs(t, err, packets.ErrReceiveMaximum) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1389,22 +1461,107 @@ func TestServerProcessPacketPublishMaximumReceive(t *testing.T) { func TestServerProcessPublishInvalidTopic(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() cl, _, _ := newTestClient() err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishSpecDenySysTopic).Packet) - require.NoError(t, err) // $SYS topics should be ignored? + require.NoError(t, err) // $SYS Topics should be ignored? } func TestServerProcessPublishACLCheckDeny(t *testing.T) { - s := New(&Options{ - Logger: &logger, - }) - s.Serve() - defer s.Close() - cl, _, _ := newTestClient() - err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) - require.NoError(t, err) // ACL check fails silently + tt := []struct { + name string + protocolVersion byte + pk packets.Packet + expectErr error + expectReponse []byte + expectDisconnect bool + }{ + { + name: "v4_QOS0", + protocolVersion: 4, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet, + expectErr: nil, + expectReponse: nil, + expectDisconnect: false, + }, + { + name: "v4_QOS1", + protocolVersion: 4, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet, + expectErr: packets.ErrNotAuthorized, + expectReponse: nil, + expectDisconnect: true, + }, + { + name: "v4_QOS2", + protocolVersion: 4, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet, + expectErr: packets.ErrNotAuthorized, + expectReponse: nil, + expectDisconnect: true, + }, + { + name: "v5_QOS0", + protocolVersion: 5, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).Packet, + expectErr: nil, + expectReponse: nil, + expectDisconnect: false, + }, + { + name: "v5_QOS1", + protocolVersion: 5, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Mqtt5).Packet, + expectErr: nil, + expectReponse: packets.TPacketData[packets.Puback].Get(packets.TPubrecMqtt5NotAuthorized).RawBytes, + expectDisconnect: false, + }, + { + name: "v5_QOS2", + protocolVersion: 5, + pk: *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2Mqtt5).Packet, + expectErr: nil, + expectReponse: packets.TPacketData[packets.Pubrec].Get(packets.TPubrecMqtt5NotAuthorized).RawBytes, + expectDisconnect: false, + }, + } + + for _, tx := range tt { + t.Run(tx.name, func(t *testing.T) { + cc := *DefaultServerCapabilities + s := New(&Options{ + Logger: logger, + Capabilities: &cc, + }) + _ = s.AddHook(new(DenyHook), nil) + _ = s.Serve() + defer s.Close() + + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = tx.protocolVersion + s.Clients.Add(cl) + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + err := s.processPublish(cl, tx.pk) + require.ErrorIs(t, err, tx.expectErr) + _ = w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + + if tx.expectReponse != nil { + require.Equal(t, tx.expectReponse, buf) + } + + require.Equal(t, tx.expectDisconnect, cl.Closed()) + wg.Wait() + }) + } } func TestServerProcessPublishOnMessageRecvRejected(t *testing.T) { @@ -1417,7 +1574,7 @@ func TestServerProcessPublishOnMessageRecvRejected(t *testing.T) { err := s.AddHook(hook, nil) require.NoError(t, err) - s.Serve() + _ = s.Serve() defer s.Close() cl, _, _ := newTestClient() err = s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) @@ -1431,7 +1588,7 @@ func TestServerProcessPacketPublishQos0(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1448,7 +1605,7 @@ func TestServerProcessPacketPublishQos1PacketIDInUse(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1467,7 +1624,7 @@ func TestServerProcessPacketPublishQos2PacketIDInUse(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2Mqtt5).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1483,7 +1640,7 @@ func TestServerProcessPacketPublishQos1(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1498,7 +1655,7 @@ func TestServerProcessPacketPublishQos2(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1514,7 +1671,7 @@ func TestServerProcessPacketPublishDowngradeQos(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1534,7 +1691,7 @@ func TestPublishToSubscribersSelfNoLocal(t *testing.T) { pkx.Origin = cl.ID s.PublishToSubscribers(pkx) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() receiverBuf := make(chan []byte) @@ -1592,9 +1749,9 @@ func TestPublishToSubscribers(t *testing.T) { go func() { s.PublishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) time.Sleep(time.Millisecond) - w1.Close() - w2.Close() - w3.Close() + _ = w1.Close() + _ = w2.Close() + _ = w3.Close() }() require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-cl1Recv) @@ -1636,7 +1793,7 @@ func TestPublishToSubscribersMessageExpiryDelta(t *testing.T) { pkx.Created = time.Now().Unix() - 30 s.PublishToSubscribers(pkx) time.Sleep(time.Millisecond) - w1.Close() + _ = w1.Close() }() b := <-cl1Recv @@ -1660,7 +1817,7 @@ func TestPublishToSubscribersIdentifiers(t *testing.T) { go func() { s.PublishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() receiverBuf := make(chan []byte) @@ -1685,7 +1842,7 @@ func TestPublishToSubscribersPkIgnore(t *testing.T) { pk.Ignore = true s.PublishToSubscribers(pk) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() receiverBuf := make(chan []byte) @@ -1712,9 +1869,9 @@ func TestPublishToClientServerDowngradeQos(t *testing.T) { go func() { pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet pkx.FixedHeader.Qos = 2 - s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, pkx) + _, _ = s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, pkx) time.Sleep(time.Microsecond * 100) - w.Close() + _ = w.Close() }() receiverBuf := make(chan []byte) @@ -1741,9 +1898,9 @@ func TestPublishToClientSubscriptionDowngradeQos(t *testing.T) { go func() { pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet pkx.FixedHeader.Qos = 2 - s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 1}, pkx) + _, _ = s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 1}, pkx) time.Sleep(time.Microsecond * 100) - w.Close() + _ = w.Close() }() receiverBuf := make(chan []byte) @@ -1763,7 +1920,7 @@ func TestPublishToClientExceedClientWritesPending(t *testing.T) { cl := newClient(w, &ops{ info: new(system.Info), hooks: new(Hooks), - log: &logger, + log: logger, options: &Options{ Capabilities: &Capabilities{ MaximumClientWritesPending: 3, @@ -1792,10 +1949,10 @@ func TestPublishToClientServerTopicAlias(t *testing.T) { go func() { pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).Packet - s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx) - s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx) + _, _ = s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx) + _, _ = s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() receiverBuf := make(chan []byte) @@ -1848,6 +2005,19 @@ func TestPublishToClientExhaustedPacketID(t *testing.T) { require.ErrorIs(t, err, packets.ErrQuotaExceeded) } +func TestPublishToClientACLNotAuthorized(t *testing.T) { + s := New(&Options{ + Logger: logger, + }) + err := s.AddHook(new(DenyHook), nil) + require.NoError(t, err) + cl, _, _ := newTestClient() + + _, err = s.publishToClient(cl, packets.Subscription{Filter: "a/b/c"}, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrNotAuthorized) +} + func TestPublishToClientNoConn(t *testing.T) { s := newServer() cl, _, _ := newTestClient() @@ -1875,10 +2045,10 @@ func TestProcessPublishWithTopicAlias(t *testing.T) { pkx.Properties.SubscriptionIdentifier = []int{} // must not contain from client to server pkx.TopicName = "" pkx.Properties.TopicAlias = 1 - s.processPacket(cl2, pkx) + _ = s.processPacket(cl2, pkx) time.Sleep(time.Millisecond) - w2.Close() - w.Close() + _ = w2.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1898,12 +2068,12 @@ func TestPublishToSubscribersExhaustedSendQuota(t *testing.T) { // coverage: subscriber publish errors are non-returnable // can we hook into zerolog ? - r.Close() + _ = r.Close() pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet pkx.PacketID = 0 s.PublishToSubscribers(pkx) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() } func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) { @@ -1920,12 +2090,12 @@ func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) { // coverage: subscriber publish errors are non-returnable // can we hook into zerolog ? - r.Close() + _ = r.Close() pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet pkx.PacketID = 0 s.PublishToSubscribers(pkx) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() } func TestPublishToSubscribersNoConnection(t *testing.T) { @@ -1938,10 +2108,10 @@ func TestPublishToSubscribersNoConnection(t *testing.T) { // coverage: subscriber publish errors are non-returnable // can we hook into zerolog ? - r.Close() + _ = r.Close() s.PublishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() } func TestPublishRetainedToClient(t *testing.T) { @@ -1959,7 +2129,7 @@ func TestPublishRetainedToClient(t *testing.T) { go func() { s.publishRetainedToClient(cl, packets.Subscription{Filter: "a/b/c"}, false) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -1979,7 +2149,7 @@ func TestPublishRetainedToClientIsShared(t *testing.T) { go func() { s.publishRetainedToClient(cl, sub, false) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2000,7 +2170,7 @@ func TestPublishRetainedToClientError(t *testing.T) { retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) require.Equal(t, int64(1), retained) - w.Close() + _ = w.Close() s.publishRetainedToClient(cl, sub, false) } @@ -2103,7 +2273,7 @@ func TestServerProcessPacketPubrec(t *testing.T) { err := s.processPacket(cl, *packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() require.Equal(t, packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).RawBytes, <-recv) @@ -2131,7 +2301,7 @@ func TestServerProcessPacketPubrecNoPacketID(t *testing.T) { pk := *packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).Packet // not sending properties err := s.processPacket(cl, pk) require.NoError(t, err) - w.Close() + _ = w.Close() require.Equal(t, packets.TPacketData[packets.Pubrel].Get(packets.TPubrelMqtt5AckNoPacket).RawBytes, <-recv) @@ -2181,7 +2351,7 @@ func TestServerProcessPacketPubrel(t *testing.T) { err := s.processPacket(cl, *packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() require.Equal(t, int32(4), atomic.LoadInt32(&cl.State.Inflight.receiveQuota)) require.Equal(t, int32(4), atomic.LoadInt32(&cl.State.Inflight.sendQuota)) @@ -2210,7 +2380,7 @@ func TestServerProcessPacketPubrelNoPacketID(t *testing.T) { pk := *packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).Packet // not sending properties err := s.processPacket(cl, pk) require.NoError(t, err) - w.Close() + _ = w.Close() require.Equal(t, packets.TPacketData[packets.Pubcomp].Get(packets.TPubcompMqtt5AckNoPacket).RawBytes, <-recv) @@ -2322,7 +2492,7 @@ func TestServerProcessInboundQos2Flow(t *testing.T) { err := s.processPacket(cl, *tx.in.Packet) require.NoError(t, err) - w.Close() + _ = w.Close() require.Equal(t, tx.out.RawBytes, <-recv) if i == 0 { @@ -2403,7 +2573,7 @@ func TestServerProcessOutboundQos2Flow(t *testing.T) { } time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() if i != 2 { require.Equal(t, tx.out.RawBytes, <-recv) @@ -2426,7 +2596,7 @@ func TestServerProcessPacketSubscribe(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2445,7 +2615,7 @@ func TestServerProcessPacketSubscribePacketIDInUse(t *testing.T) { go func() { err := s.processPacket(cl, pkx) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2471,7 +2641,7 @@ func TestServerProcessPacketSubscribeInvalidFilter(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeInvalidFilter).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2487,7 +2657,7 @@ func TestServerProcessPacketSubscribeInvalidSharedNoLocal(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeInvalidSharedNoLocal).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2507,7 +2677,7 @@ func TestServerProcessSubscribeWithRetain(t *testing.T) { require.NoError(t, err) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2528,7 +2698,7 @@ func TestServerProcessSubscribeDowngradeQos(t *testing.T) { require.NoError(t, err) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2550,7 +2720,7 @@ func TestServerProcessSubscribeWithRetainHandling1(t *testing.T) { require.NoError(t, err) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2571,7 +2741,7 @@ func TestServerProcessSubscribeWithRetainHandling2(t *testing.T) { require.NoError(t, err) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2592,7 +2762,7 @@ func TestServerProcessSubscribeWithNotRetainAsPublished(t *testing.T) { require.NoError(t, err) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2606,7 +2776,7 @@ func TestServerProcessSubscribeWithNotRetainAsPublished(t *testing.T) { func TestServerProcessSubscribeNoConnection(t *testing.T) { s := newServer() cl, r, _ := newTestClient() - r.Close() + _ = r.Close() err := s.processSubscribe(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet) require.Error(t, err) require.ErrorIs(t, err, io.ErrClosedPipe) @@ -2614,16 +2784,16 @@ func TestServerProcessSubscribeNoConnection(t *testing.T) { func TestServerProcessSubscribeACLCheckDeny(t *testing.T) { s := New(&Options{ - Logger: &logger, + Logger: logger, }) - s.Serve() + _ = s.Serve() cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 go func() { err := s.processSubscribe(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2633,9 +2803,9 @@ func TestServerProcessSubscribeACLCheckDeny(t *testing.T) { func TestServerProcessSubscribeACLCheckDenyObscure(t *testing.T) { s := New(&Options{ - Logger: &logger, + Logger: logger, }) - s.Serve() + _ = s.Serve() s.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 @@ -2643,7 +2813,7 @@ func TestServerProcessSubscribeACLCheckDenyObscure(t *testing.T) { go func() { err := s.processSubscribe(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2660,7 +2830,7 @@ func TestServerProcessSubscribeErrorDowngrade(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeInvalidSharedNoLocal).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2676,7 +2846,7 @@ func TestServerProcessPacketUnsubscribe(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2693,7 +2863,7 @@ func TestServerProcessPacketUnsubscribePackedIDInUse(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2729,7 +2899,7 @@ func TestServerRecievePacketDisconnectClientZeroNonZero(t *testing.T) { err := s.receivePacket(cl, *packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5).Packet) require.Error(t, err) require.ErrorIs(t, err, packets.ErrProtocolViolationZeroNonZeroExpiry) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2744,7 +2914,7 @@ func TestServerRecievePacketDisconnectClient(t *testing.T) { go func() { err := s.DisconnectClient(cl, packets.CodeDisconnect) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2789,7 +2959,7 @@ func TestServerProcessPacketAuth(t *testing.T) { go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet) require.NoError(t, err) - w.Close() + _ = w.Close() }() buf, err := io.ReadAll(r) @@ -2823,7 +2993,7 @@ func TestServerProcessPacketAuthFailure(t *testing.T) { func TestServerSendLWT(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() sender, _, w1 := newTestClient() @@ -2853,8 +3023,8 @@ func TestServerSendLWT(t *testing.T) { go func() { s.sendLWT(sender) time.Sleep(time.Millisecond * 10) - w1.Close() - w2.Close() + _ = w1.Close() + _ = w2.Close() }() require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) @@ -2862,7 +3032,7 @@ func TestServerSendLWT(t *testing.T) { func TestServerSendLWTRetain(t *testing.T) { s := newServer() - s.Serve() + _ = s.Serve() defer s.Close() sender, _, w1 := newTestClient() @@ -2893,8 +3063,8 @@ func TestServerSendLWTRetain(t *testing.T) { go func() { s.sendLWT(sender) time.Sleep(time.Millisecond * 10) - w1.Close() - w2.Close() + _ = w1.Close() + _ = w2.Close() }() require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) @@ -2930,7 +3100,7 @@ func TestServerSendLWTDelayed(t *testing.T) { s.sendDelayedLWT(time.Now().Unix()) require.Equal(t, 0, s.loop.willDelayed.Len()) time.Sleep(time.Millisecond) - w.Close() + _ = w.Close() }() recv := make(chan []byte) @@ -2946,7 +3116,7 @@ func TestServerSendLWTDelayed(t *testing.T) { func TestServerReadStore(t *testing.T) { s := newServer() hook := new(modifiedHookBase) - s.AddHook(hook, nil) + _ = s.AddHook(hook, nil) hook.failAt = 1 // clients err := s.readStore() @@ -3007,6 +3177,7 @@ func TestServerLoadInflightMessages(t *testing.T) { {ID: "zen"}, {ID: "mochi-co"}, }) + require.Equal(t, 3, s.Clients.Len()) v := []storage.Message{ @@ -3051,7 +3222,7 @@ func TestServerClose(t *testing.T) { s := newServer() hook := new(modifiedHookBase) - s.AddHook(hook, nil) + _ = s.AddHook(hook, nil) cl, r, _ := newTestClient() cl.Net.Listener = "t1" @@ -3060,7 +3231,7 @@ func TestServerClose(t *testing.T) { err := s.AddListener(listeners.NewMockListener("t1", ":1882")) require.NoError(t, err) - s.Serve() + _ = s.Serve() // receive the disconnect recv := make(chan []byte) @@ -3077,7 +3248,7 @@ func TestServerClose(t *testing.T) { require.Equal(t, true, ok) require.Equal(t, true, listener.(*listeners.MockListener).IsServing()) - s.Close() + _ = s.Close() time.Sleep(time.Millisecond) require.Equal(t, false, listener.(*listeners.MockListener).IsServing()) require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectShuttingDown).RawBytes, <-recv) @@ -3166,7 +3337,6 @@ func TestServerClearExpiredClients(t *testing.T) { require.Equal(t, 4, s.Clients.Len()) s.clearExpiredClients(n) - require.Equal(t, 2, s.Clients.Len()) } @@ -3186,3 +3356,308 @@ func TestAtomicItoa(t *testing.T) { ip := &i require.Equal(t, "22", AtomicItoa(ip)) } + +func TestServerSubscribe(t *testing.T) { + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {} + + s := newServerWithInlineClient() + require.NotNil(t, s) + + tt := []struct { + desc string + filter string + identifier int + handler InlineSubFn + expect error + }{ + { + desc: "subscribe", + filter: "a/b/c", + identifier: 1, + handler: handler, + expect: nil, + }, + { + desc: "re-subscribe", + filter: "a/b/c", + identifier: 1, + handler: handler, + expect: nil, + }, + { + desc: "subscribe d/e/f", + filter: "d/e/f", + identifier: 1, + handler: handler, + expect: nil, + }, + { + desc: "re-subscribe d/e/f by different identifier", + filter: "d/e/f", + identifier: 2, + handler: handler, + expect: nil, + }, + { + desc: "subscribe different handler", + filter: "a/b/c", + identifier: 1, + handler: func(cl *Client, sub packets.Subscription, pk packets.Packet) {}, + expect: nil, + }, + { + desc: "subscribe $SYS/info", + filter: "$SYS/info", + identifier: 1, + handler: handler, + expect: nil, + }, + { + desc: "subscribe invalid ###", + filter: "###", + identifier: 1, + handler: handler, + expect: packets.ErrTopicFilterInvalid, + }, + { + desc: "subscribe invalid handler", + filter: "a/b/c", + identifier: 1, + handler: nil, + expect: packets.ErrInlineSubscriptionHandlerInvalid, + }, + } + + for _, tx := range tt { + t.Run(tx.desc, func(t *testing.T) { + require.Equal(t, tx.expect, s.Subscribe(tx.filter, tx.identifier, tx.handler)) + }) + } +} + +func TestServerSubscribeNoInlineClient(t *testing.T) { + s := newServer() + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {}) + require.Error(t, err) + require.ErrorIs(t, err, ErrInlineClientNotEnabled) +} + +func TestServerUnsubscribe(t *testing.T) { + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + s := newServerWithInlineClient() + err := s.Subscribe("a/b/c", 1, handler) + require.Nil(t, err) + + err = s.Subscribe("d/e/f", 1, handler) + require.Nil(t, err) + + err = s.Subscribe("d/e/f", 2, handler) + require.Nil(t, err) + + err = s.Unsubscribe("a/b/c", 1) + require.Nil(t, err) + + err = s.Unsubscribe("d/e/f", 1) + require.Nil(t, err) + + err = s.Unsubscribe("d/e/f", 2) + require.Nil(t, err) + + err = s.Unsubscribe("not/exist", 1) + require.Nil(t, err) + + err = s.Unsubscribe("#/#/invalid", 1) + require.Equal(t, packets.ErrTopicFilterInvalid, err) +} + +func TestServerUnsubscribeNoInlineClient(t *testing.T) { + s := newServer() + err := s.Unsubscribe("a/b/c", 1) + require.Error(t, err) + require.ErrorIs(t, err, ErrInlineClientNotEnabled) +} + +func TestPublishToInlineSubscriber(t *testing.T) { + s := newServerWithInlineClient() + finishCh := make(chan bool) + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + s.PublishToSubscribers(pkx) + }() + + require.Equal(t, true, <-finishCh) +} + +func TestPublishToInlineSubscribersDifferentFilter(t *testing.T) { + s := newServerWithInlineClient() + subNumber := 2 + finishCh := make(chan bool, subNumber) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + err = s.Subscribe("z/e/n", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("mochi mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "z/e/n", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + s.PublishToSubscribers(pkx) + + pkx = *packets.TPacketData[packets.Publish].Get(packets.TPublishCopyBasic).Packet + s.PublishToSubscribers(pkx) + }() + + for i := 0; i < subNumber; i++ { + require.Equal(t, true, <-finishCh) + } +} + +func TestPublishToInlineSubscribersDifferentIdentifier(t *testing.T) { + s := newServerWithInlineClient() + subNumber := 2 + finishCh := make(chan bool, subNumber) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + err = s.Subscribe("a/b/c", 2, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 2, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + s.PublishToSubscribers(pkx) + }() + + for i := 0; i < subNumber; i++ { + require.Equal(t, true, <-finishCh) + } +} + +func TestServerSubscribeWithRetain(t *testing.T) { + s := newServerWithInlineClient() + subNumber := 1 + finishCh := make(chan bool, subNumber) + + retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(1), retained) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + require.Equal(t, true, <-finishCh) +} + +func TestServerSubscribeWithRetainDifferentFilter(t *testing.T) { + s := newServerWithInlineClient() + subNumber := 2 + finishCh := make(chan bool, subNumber) + + retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(1), retained) + retained = s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishCopyBasic).Packet) + require.Equal(t, int64(1), retained) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + err = s.Subscribe("z/e/n", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("mochi mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "z/e/n", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + for i := 0; i < subNumber; i++ { + require.Equal(t, true, <-finishCh) + } +} + +func TestServerSubscribeWithRetainDifferentIdentifier(t *testing.T) { + s := newServerWithInlineClient() + subNumber := 2 + finishCh := make(chan bool, subNumber) + + retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(1), retained) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + err = s.Subscribe("a/b/c", 2, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 2, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + for i := 0; i < subNumber; i++ { + require.Equal(t, true, <-finishCh) + } +} diff --git a/mqtt/system/system.go b/mqtt/system/system.go index 647ae00..2ed47d0 100644 --- a/mqtt/system/system.go +++ b/mqtt/system/system.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package system diff --git a/mqtt/topics.go b/mqtt/topics.go index 0cc80d5..f9d122e 100644 --- a/mqtt/topics.go +++ b/mqtt/topics.go @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package mqtt @@ -186,6 +186,65 @@ func (s *SharedSubscriptions) GetAll() map[string]map[string]packets.Subscriptio return m } +// InlineSubFn is the signature for a callback function which will be called +// when an inline client receives a message on a topic it is subscribed to. +// The sub argument contains information about the subscription that was matched for any filters. +type InlineSubFn func(cl *Client, sub packets.Subscription, pk packets.Packet) + +// InlineSubscriptions represents a map of internal subscriptions keyed on client. +type InlineSubscriptions struct { + internal map[int]InlineSubscription + sync.RWMutex +} + +// NewInlineSubscriptions returns a new instance of InlineSubscriptions. +func NewInlineSubscriptions() *InlineSubscriptions { + return &InlineSubscriptions{ + internal: map[int]InlineSubscription{}, + } +} + +// Add adds a new internal subscription for a client id. +func (s *InlineSubscriptions) Add(val InlineSubscription) { + s.Lock() + defer s.Unlock() + s.internal[val.Identifier] = val +} + +// GetAll returns all internal subscriptions. +func (s *InlineSubscriptions) GetAll() map[int]InlineSubscription { + s.RLock() + defer s.RUnlock() + m := map[int]InlineSubscription{} + for k, v := range s.internal { + m[k] = v + } + return m +} + +// Get returns an internal subscription for a client id. +func (s *InlineSubscriptions) Get(id int) (val InlineSubscription, ok bool) { + s.RLock() + defer s.RUnlock() + val, ok = s.internal[id] + return val, ok +} + +// Len returns the number of internal subscriptions. +func (s *InlineSubscriptions) Len() int { + s.RLock() + defer s.RUnlock() + val := len(s.internal) + return val +} + +// Delete removes an internal subscription by the client id. +func (s *InlineSubscriptions) Delete(id int) { + s.Lock() + defer s.Unlock() + delete(s.internal, id) +} + // Subscriptions is a map of subscriptions keyed on client. type Subscriptions struct { internal map[string]packets.Subscription @@ -244,11 +303,17 @@ func (s *Subscriptions) Delete(id string) { // ClientSubscriptions is a map of aggregated subscriptions for a client. type ClientSubscriptions map[string]packets.Subscription +type InlineSubscription struct { + packets.Subscription + Handler InlineSubFn +} + // Subscribers contains the shared and non-shared subscribers matching a topic. type Subscribers struct { - Shared map[string]map[string]packets.Subscription - SharedSelected map[string]packets.Subscription - Subscriptions map[string]packets.Subscription + Shared map[string]map[string]packets.Subscription + SharedSelected map[string]packets.Subscription + Subscriptions map[string]packets.Subscription + InlineSubscriptions map[int]InlineSubscription } // SelectShared returns one subscriber for each shared subscription group. @@ -298,6 +363,39 @@ func NewTopicsIndex() *TopicsIndex { } } +// InlineSubscribe adds a new internal subscription for a topic filter, returning +// true if the subscription was new. +func (x *TopicsIndex) InlineSubscribe(subscription InlineSubscription) (bool, int) { + x.root.Lock() + defer x.root.Unlock() + + var existed bool + n := x.set(subscription.Filter, 0) + _, existed = n.inlineSubscriptions.Get(subscription.Identifier) + n.inlineSubscriptions.Add(subscription) + + return !existed, n.inlineSubscriptions.Len() +} + +// InlineUnsubscribe removes an internal subscription for a topic filter associated with a specific client, +// returning true if the subscription existed. +func (x *TopicsIndex) InlineUnsubscribe(id int, filter string) (bool, int) { + x.root.Lock() + defer x.root.Unlock() + + particle := x.seek(filter, 0) + if particle == nil { + return false, 0 + } + + particle.inlineSubscriptions.Delete(id) + + if particle.inlineSubscriptions.Len() == 0 { + x.trim(particle) + } + return true, particle.inlineSubscriptions.Len() +} + // Subscribe adds a new subscription for a client to a topic filter, returning // true if the subscription was new. func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) (bool, int) { @@ -490,9 +588,10 @@ func (x *TopicsIndex) scanMessages(filter string, d int, n *particle, pks []pack // their subscription ids and highest qos. func (x *TopicsIndex) Subscribers(topic string) *Subscribers { return x.scanSubscribers(topic, 0, nil, &Subscribers{ - Shared: map[string]map[string]packets.Subscription{}, - SharedSelected: map[string]packets.Subscription{}, - Subscriptions: map[string]packets.Subscription{}, + Shared: map[string]map[string]packets.Subscription{}, + SharedSelected: map[string]packets.Subscription{}, + Subscriptions: map[string]packets.Subscription{}, + InlineSubscriptions: map[int]InlineSubscription{}, }) } @@ -514,10 +613,12 @@ func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Su } else { x.gatherSubscriptions(topic, particle, subs) x.gatherSharedSubscriptions(particle, subs) + x.gatherInlineSubscriptions(particle, subs) if wild := particle.particles.get("#"); wild != nil && partKey != "+" { x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2 x.gatherSharedSubscriptions(wild, subs) + x.gatherInlineSubscriptions(particle, subs) } } } @@ -526,6 +627,7 @@ func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Su if particle := n.particles.get("#"); particle != nil { x.gatherSubscriptions(topic, particle, subs) x.gatherSharedSubscriptions(particle, subs) + x.gatherInlineSubscriptions(particle, subs) } return subs @@ -568,6 +670,17 @@ func (x *TopicsIndex) gatherSharedSubscriptions(particle *particle, subs *Subscr } } +// gatherSharedSubscriptions gathers all inline subscriptions for a particle. +func (x *TopicsIndex) gatherInlineSubscriptions(particle *particle, subs *Subscribers) { + if subs.InlineSubscriptions == nil { + subs.InlineSubscriptions = map[int]InlineSubscription{} + } + + for id, inline := range particle.inlineSubscriptions.GetAll() { + subs.InlineSubscriptions[id] = inline + } +} + // isolateParticle extracts a particle between d / and d+1 / without allocations. func isolateParticle(filter string, d int) (particle string, hasNext bool) { var next, end int @@ -598,7 +711,7 @@ func IsSharedFilter(filter string) bool { // IsValidFilter returns true if the filter is valid. func IsValidFilter(filter string, forPublish bool) bool { - if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publihs. + if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publish. return false // [MQTT-4.7.3-1] } @@ -639,23 +752,25 @@ func IsValidFilter(filter string, forPublish bool) bool { // particle is a child node on the tree. type particle struct { - parent *particle // a pointer to the parent of the particle - subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address - shared *SharedSubscriptions // a map of shared subscriptions keyed on group name - key string // the key of the particle - retainPath string // path of a retained message - particles particles // a map of child particles - sync.Mutex // mutex for when making changes to the particle + key string // the key of the particle + parent *particle // a pointer to the parent of the particle + particles particles // a map of child particles + subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address + shared *SharedSubscriptions // a map of shared subscriptions keyed on group name + inlineSubscriptions *InlineSubscriptions // a map of inline subscriptions for this particle + retainPath string // path of a retained message + sync.Mutex // mutex for when making changes to the particle } // newParticle returns a pointer to a new instance of particle. func newParticle(key string, parent *particle) *particle { return &particle{ - key: key, - parent: parent, - particles: newParticles(), - subscriptions: NewSubscriptions(), - shared: NewSharedSubscriptions(), + key: key, + parent: parent, + particles: newParticles(), + subscriptions: NewSubscriptions(), + shared: NewSharedSubscriptions(), + inlineSubscriptions: NewInlineSubscriptions(), } } diff --git a/mqtt/topics_test.go b/mqtt/topics_test.go index 9de4bfe..8a5c5dc 100644 --- a/mqtt/topics_test.go +++ b/mqtt/topics_test.go @@ -1,10 +1,11 @@ // SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co +// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co // SPDX-FileContributor: mochi-co package mqtt import ( + "fmt" "testing" "github.com/stretchr/testify/require" @@ -857,3 +858,227 @@ func TestNewTopicAliases(t *testing.T) { require.NotNil(t, a.Outbound) require.Equal(t, uint16(5), a.Outbound.maximum) } + +func TestNewInlineSubscriptions(t *testing.T) { + subscriptions := NewInlineSubscriptions() + require.NotNil(t, subscriptions) + require.NotNil(t, subscriptions.internal) + require.Equal(t, 0, subscriptions.Len()) +} + +func TestInlineSubscriptionAdd(t *testing.T) { + subscriptions := NewInlineSubscriptions() + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + subscription := InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + Handler: handler, + } + subscriptions.Add(subscription) + + sub, ok := subscriptions.Get(1) + require.True(t, ok) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler)) +} + +func TestInlineSubscriptionGet(t *testing.T) { + subscriptions := NewInlineSubscriptions() + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + subscription := InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + Handler: handler, + } + subscriptions.Add(subscription) + + sub, ok := subscriptions.Get(1) + require.True(t, ok) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler)) + + _, ok = subscriptions.Get(999) + require.False(t, ok) +} + +func TestInlineSubscriptionsGetAll(t *testing.T) { + subscriptions := NewInlineSubscriptions() + + subscriptions.Add(InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + }) + subscriptions.Add(InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + }) + subscriptions.Add(InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2}, + }) + subscriptions.Add(InlineSubscription{ + Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 3}, + }) + + allSubs := subscriptions.GetAll() + require.Len(t, allSubs, 3) + require.Contains(t, allSubs, 1) + require.Contains(t, allSubs, 2) + require.Contains(t, allSubs, 3) +} + +func TestInlineSubscriptionDelete(t *testing.T) { + subscriptions := NewInlineSubscriptions() + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + subscription := InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + Handler: handler, + } + subscriptions.Add(subscription) + + subscriptions.Delete(1) + _, ok := subscriptions.Get(1) + require.False(t, ok) + require.Empty(t, subscriptions.GetAll()) + require.Zero(t, subscriptions.Len()) +} + +func TestInlineSubscribe(t *testing.T) { + + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + tt := []struct { + desc string + filter string + subscription InlineSubscription + wasNew bool + count int + }{ + { + desc: "subscribe", + filter: "a/b/c", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}}, + wasNew: true, + count: 1, + }, + { + desc: "subscribe existed", + filter: "a/b/c", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}}, + wasNew: false, + count: 1, + }, + { + desc: "subscribe different identifier", + filter: "a/b/c", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2}}, + wasNew: true, + count: 2, + }, + { + desc: "subscribe case sensitive didnt exist", + filter: "A/B/c", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "A/B/c", Identifier: 1}}, + wasNew: true, + count: 1, + }, + { + desc: "wildcard+ sub", + filter: "d/+", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/+", Identifier: 1}}, + wasNew: true, + count: 1, + }, + { + desc: "wildcard# sub", + filter: "d/e/#", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/#", Identifier: 1}}, + wasNew: true, + count: 1, + }, + } + + index := NewTopicsIndex() + for _, tx := range tt { + t.Run(tx.desc, func(t *testing.T) { + exist, count := index.InlineSubscribe(tx.subscription) + require.Equal(t, tx.wasNew, exist) + require.Equal(t, tx.count, count) + }) + } + + final := index.root.particles.get("a").particles.get("b").particles.get("c") + require.NotNil(t, final) +} + +func TestInlineUnsubscribe(t *testing.T) { + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + index := NewTopicsIndex() + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}}) + sub, exists := index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index = NewTopicsIndex() + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}}) + sub, exists = index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}}) + sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 2}}) + sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(2) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/+/d", Identifier: 1}}) + sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}}) + sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}}) + sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "#", Identifier: 1}}) + sub, exists = index.root.particles.get("#").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + ok, count := index.InlineUnsubscribe(1, "a/b/c/d") + require.Equal(t, 0, count) + require.True(t, ok) + require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c")) + + sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + ok, _ = index.InlineUnsubscribe(1, "d/e/f") + require.Equal(t, 0, count) + require.True(t, ok) + require.NotNil(t, index.root.particles.get("d").particles.get("e").particles.get("f")) + + ok, _ = index.InlineUnsubscribe(1, "not/exist") + require.Equal(t, 0, count) + require.False(t, ok) +} diff --git a/plugin/auth/http/http.go b/plugin/auth/http/http.go index 0f2ca98..a96bc69 100644 --- a/plugin/auth/http/http.go +++ b/plugin/auth/http/http.go @@ -4,15 +4,16 @@ import ( "bytes" "encoding/json" "fmt" + "io" + "net/http" + "net/url" + "strings" + "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/hooks/auth" "github.com/wind-c/comqtt/v2/mqtt/packets" "github.com/wind-c/comqtt/v2/plugin" pa "github.com/wind-c/comqtt/v2/plugin/auth" - "io" - "net/http" - "net/url" - "strings" ) const ( @@ -58,9 +59,7 @@ func (a *Auth) Init(config any) error { } a.config = config.(*Options) - a.Log.Info(). - Str("auth-url", a.config.AuthUrl). - Str("acl-url", a.config.AclUrl) + a.Log.Info("", "auth-url", a.config.AuthUrl, "acl-url", a.config.AclUrl) return nil } diff --git a/plugin/auth/http/http_test.go b/plugin/auth/http/http_test.go index fe44d5c..968a9bf 100644 --- a/plugin/auth/http/http_test.go +++ b/plugin/auth/http/http_test.go @@ -2,21 +2,25 @@ package http import ( "encoding/json" - "github.com/rs/zerolog" + "io" + "testing" + + "log/slog" + "github.com/stretchr/testify/require" "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/hooks/auth" "github.com/wind-c/comqtt/v2/mqtt/packets" "github.com/wind-c/comqtt/v2/plugin" "gopkg.in/h2non/gock.v1" - "os" - "testing" ) const path = "./conf.yml" var ( - logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) + // Currently, the input is directed to /dev/null. If you need to + // output to stdout, just modify 'io.Discard' here to 'os.Stdout'. + logger = slog.New(slog.NewTextHandler(io.Discard, nil)) client = &mqtt.Client{ ID: "test", @@ -37,7 +41,7 @@ var ( func newAuth(t *testing.T) *Auth { a := new(Auth) - a.SetOpts(&logger, nil) + a.SetOpts(logger, nil) err := a.Init(&Options{ AuthMode: byte(auth.AuthUsername), @@ -54,7 +58,7 @@ func newAuth(t *testing.T) *Auth { func TestInitFromConfFile(t *testing.T) { a := new(Auth) - a.SetOpts(&logger, nil) + a.SetOpts(logger, nil) opts := Options{} err := plugin.LoadYaml(path, &opts) require.NoError(t, err) diff --git a/plugin/auth/mysql/mysql.go b/plugin/auth/mysql/mysql.go index 06a977c..6e18e7e 100644 --- a/plugin/auth/mysql/mysql.go +++ b/plugin/auth/mysql/mysql.go @@ -3,6 +3,7 @@ package mysql import ( "bytes" "fmt" + _ "github.com/go-sql-driver/mysql" "github.com/jmoiron/sqlx" "github.com/wind-c/comqtt/v2/mqtt" @@ -76,12 +77,11 @@ func (a *Auth) Init(config any) error { } a.config = config.(*Options) - a.Log.Info(). - Str("host", a.config.Dsn.Host). - Str("username", a.config.Dsn.LoginName). - Int("password-len", len(a.config.Dsn.LoginPassword)). - Str("db", a.config.Dsn.Schema). - Msg("connecting to mysql") + a.Log.Info("connecting to mysql", + "host", a.config.Dsn.Host, + "username", a.config.Dsn.LoginName, + "password-len", len(a.config.Dsn.LoginPassword), + "db", a.config.Dsn.Schema) dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=UTC", a.config.Dsn.LoginName, a.config.Dsn.LoginPassword, a.config.Dsn.Host, a.config.Dsn.Port, a.config.Dsn.Schema, a.config.Dsn.Charset) @@ -104,7 +104,7 @@ func (a *Auth) Init(config any) error { // Stop closes the mysql connection. func (a *Auth) Stop() error { - a.Log.Info().Msg("disconnecting from mysql") + a.Log.Info("disconnecting from mysql") a.authStmt.Close() a.aclStmt.Close() return a.db.Close() diff --git a/plugin/auth/mysql/mysql_test.go b/plugin/auth/mysql/mysql_test.go index de443cc..2b93c8e 100644 --- a/plugin/auth/mysql/mysql_test.go +++ b/plugin/auth/mysql/mysql_test.go @@ -1,21 +1,25 @@ package mysql import ( - "github.com/rs/zerolog" + "io" + "net" + "testing" + + "log/slog" + "github.com/stretchr/testify/require" "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/hooks/auth" "github.com/wind-c/comqtt/v2/mqtt/packets" "github.com/wind-c/comqtt/v2/plugin" - "net" - "os" - "testing" ) const path = "./conf.yml" var ( - logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) + // Currently, the input is directed to /dev/null. If you need to + // output to stdout, just modify 'io.Discard' here to 'os.Stdout'. + logger = slog.New(slog.NewTextHandler(io.Discard, nil)) client = &mqtt.Client{ ID: "test", @@ -42,7 +46,7 @@ func teardown(a *Auth, t *testing.T) { func newAuth(t *testing.T) *Auth { a := new(Auth) - a.SetOpts(&logger, nil) + a.SetOpts(logger, nil) err := a.Init(&Options{ AuthMode: byte(auth.AuthUsername), @@ -80,7 +84,7 @@ func TestInitFromConfFile(t *testing.T) { t.SkipNow() } a := new(Auth) - a.SetOpts(&logger, nil) + a.SetOpts(logger, nil) opts := Options{} err := plugin.LoadYaml(path, &opts) require.NoError(t, err) @@ -91,7 +95,7 @@ func TestInitFromConfFile(t *testing.T) { func TestInitBadConfig(t *testing.T) { a := new(Auth) - a.SetOpts(&logger, nil) + a.SetOpts(logger, nil) err := a.Init(map[string]any{}) require.Error(t, err) diff --git a/plugin/auth/postgresql/postgresql.go b/plugin/auth/postgresql/postgresql.go index e0c9956..63a4f05 100644 --- a/plugin/auth/postgresql/postgresql.go +++ b/plugin/auth/postgresql/postgresql.go @@ -77,12 +77,11 @@ func (a *Auth) Init(config any) error { } a.config = config.(*Options) - a.Log.Info(). - Str("host", a.config.Dsn.Host). - Str("username", a.config.Dsn.LoginName). - Int("password-len", len(a.config.Dsn.LoginPassword)). - Str("db", a.config.Dsn.Schema). - Msg("connecting to postgresql") + a.Log.Info("connecting to postgresql", + "host", a.config.Dsn.Host, + "username", a.config.Dsn.LoginName, + "password-len", len(a.config.Dsn.LoginPassword), + "db", a.config.Dsn.Schema) dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", a.config.Dsn.Host, a.config.Dsn.Port, a.config.Dsn.LoginName, a.config.Dsn.LoginPassword, a.config.Dsn.Schema, a.config.Dsn.SslMode) @@ -99,12 +98,12 @@ func (a *Auth) Init(config any) error { a.config.Acl.TopicColumn, a.config.Acl.AccessColumn, a.config.Acl.Table, a.config.Acl.UserColumn) a.authStmt, err = sqlxDB.Preparex(authSql) if err != nil { - a.Log.Error().Str("authSql", authSql).Msg("Unable to create prepared statement for auth-sql") + a.Log.Error("Unable to create prepared statement for auth-sql", "authSql", authSql) return err } a.aclStmt, err = sqlxDB.Preparex(aclSql) if err != nil { - a.Log.Error().Str("aclStmt", aclSql).Msg("Unable to create prepared statement for acl-sql") + a.Log.Error("Unable to create prepared statement for acl-sql", "aclStmt", aclSql) return err } a.db = sqlxDB @@ -113,7 +112,7 @@ func (a *Auth) Init(config any) error { // Stop closes the postgresql connection. func (a *Auth) Stop() error { - a.Log.Info().Msg("disconnecting from postgresql") + a.Log.Info("disconnecting from postgresql") a.authStmt.Close() a.aclStmt.Close() return a.db.Close() diff --git a/plugin/auth/postgresql/postgresql_test.go b/plugin/auth/postgresql/postgresql_test.go index 33c813e..6820a60 100644 --- a/plugin/auth/postgresql/postgresql_test.go +++ b/plugin/auth/postgresql/postgresql_test.go @@ -1,21 +1,24 @@ package postgresql import ( - "github.com/rs/zerolog" + "io" + "log/slog" + "net" + "testing" + "github.com/stretchr/testify/require" "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/hooks/auth" "github.com/wind-c/comqtt/v2/mqtt/packets" "github.com/wind-c/comqtt/v2/plugin" - "net" - "os" - "testing" ) const path = "./conf.yml" var ( - logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) + // Currently, the input is directed to /dev/null. If you need to + // output to stdout, just modify 'io.Discard' here to 'os.Stdout'. + logger = slog.New(slog.NewTextHandler(io.Discard, nil)) client = &mqtt.Client{ ID: "test", @@ -42,7 +45,7 @@ func teardown(a *Auth, t *testing.T) { func newAuth(t *testing.T) *Auth { a := new(Auth) - a.SetOpts(&logger, nil) + a.SetOpts(logger, nil) err := a.Init(&Options{ AuthMode: byte(auth.AuthUsername), @@ -78,7 +81,7 @@ func TestInitFromConfFile(t *testing.T) { t.Skip("no postgresql server running") } a := new(Auth) - a.SetOpts(&logger, nil) + a.SetOpts(logger, nil) opts := Options{} err := plugin.LoadYaml(path, &opts) require.NoError(t, err) @@ -89,7 +92,7 @@ func TestInitFromConfFile(t *testing.T) { func TestInitBadConfig(t *testing.T) { a := new(Auth) - a.SetOpts(&logger, nil) + a.SetOpts(logger, nil) err := a.Init(map[string]any{}) require.Error(t, err) diff --git a/plugin/auth/redis/redis.go b/plugin/auth/redis/redis.go index 3b034d6..701ea89 100644 --- a/plugin/auth/redis/redis.go +++ b/plugin/auth/redis/redis.go @@ -5,13 +5,14 @@ import ( "context" "encoding/json" "fmt" + "strconv" + "github.com/go-redis/redis/v8" "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/hooks/auth" "github.com/wind-c/comqtt/v2/mqtt/packets" "github.com/wind-c/comqtt/v2/plugin" pa "github.com/wind-c/comqtt/v2/plugin/auth" - "strconv" ) // defaultAddr is the default address to the redis service. @@ -86,12 +87,10 @@ func (a *Auth) Init(config any) error { a.config.AclKeyPrefix = defaultAclKeyPrefix } - a.Log.Info(). - Str("address", a.config.RedisOptions.Addr). - Str("username", a.config.RedisOptions.Username). - Int("password-len", len(a.config.RedisOptions.Password)). - Int("db", a.config.RedisOptions.DB). - Msg("connecting to redis service") + a.Log.Info("connecting to redis service", + "address", a.config.RedisOptions.Addr, "username", a.config.RedisOptions.Username, + "password-len", len(a.config.RedisOptions.Password), + "db", a.config.RedisOptions.DB) a.db = redis.NewClient(&redis.Options{ Addr: a.config.RedisOptions.Addr, @@ -104,13 +103,13 @@ func (a *Auth) Init(config any) error { return fmt.Errorf("failed to ping service: %w", err) } - a.Log.Info().Msg("connected to redis service") + a.Log.Info("connected to redis service") return nil } // Stop closes the redis connection. func (a *Auth) Stop() error { - a.Log.Info().Msg("disconnecting from redis service") + a.Log.Info("disconnecting from redis service") return a.db.Close() } @@ -151,7 +150,7 @@ func (a *Auth) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool { var ar auth.AuthRule if err = json.Unmarshal([]byte(res), &ar); err != nil { - a.Log.Error().Err(err).Str("data", res).Msg("failed to unmarshal redis auth data") + a.Log.Error("failed to unmarshal redis auth data", "error", err, "data", res) return false } diff --git a/plugin/auth/redis/redis_test.go b/plugin/auth/redis/redis_test.go index 5c79658..2379e3f 100644 --- a/plugin/auth/redis/redis_test.go +++ b/plugin/auth/redis/redis_test.go @@ -2,18 +2,21 @@ package redis import ( "context" + "io" + "log/slog" + "testing" + "github.com/alicebob/miniredis/v2" - "github.com/rs/zerolog" "github.com/stretchr/testify/require" "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/hooks/auth" "github.com/wind-c/comqtt/v2/mqtt/packets" - "os" - "testing" ) var ( - logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) + // Currently, the input is directed to /dev/null. If you need to + // output to stdout, just modify 'io.Discard' here to 'os.Stdout'. + logger = slog.New(slog.NewTextHandler(io.Discard, nil)) client = &mqtt.Client{ ID: "test", @@ -34,7 +37,7 @@ var ( func newAuth(t *testing.T, addr string) *Auth { a := new(Auth) - a.SetOpts(&logger, nil) + a.SetOpts(logger, nil) err := a.Init(&Options{ AuthMode: byte(auth.AuthUsername), @@ -62,7 +65,7 @@ func TestInitUseDefaults(t *testing.T) { defer s.Close() a := newAuth(t, defaultAddr) - a.SetOpts(&logger, nil) + a.SetOpts(logger, nil) err := a.Init(nil) require.NoError(t, err) defer teardown(t, a) @@ -73,7 +76,7 @@ func TestInitUseDefaults(t *testing.T) { func TestInitBadConfig(t *testing.T) { a := new(Auth) - a.SetOpts(&logger, nil) + a.SetOpts(logger, nil) err := a.Init(map[string]any{}) require.Error(t, err) @@ -81,7 +84,7 @@ func TestInitBadConfig(t *testing.T) { func TestInitBadAddr(t *testing.T) { a := new(Auth) - a.SetOpts(&logger, nil) + a.SetOpts(logger, nil) err := a.Init(&Options{ RedisOptions: &redisOptions{ Addr: "abc:123", diff --git a/plugin/bridge/kafka/kafka.go b/plugin/bridge/kafka/kafka.go index ec4f873..fdaf052 100644 --- a/plugin/bridge/kafka/kafka.go +++ b/plugin/bridge/kafka/kafka.go @@ -9,12 +9,13 @@ import ( "context" "encoding/json" "fmt" + "strings" + "time" + "github.com/segmentio/kafka-go" "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/packets" "github.com/wind-c/comqtt/v2/plugin" - "strings" - "time" ) const defaultAddr = "localhost:9092" @@ -133,11 +134,10 @@ func (b *Bridge) Init(config any) error { } b.config = config.(*Options) - b.Log.Info(). - Str("brokers", strings.Join(b.config.KafkaOptions.Brokers, ",")). - Str("topic", b.config.KafkaOptions.Topic). - Bool("async", b.config.KafkaOptions.Async). - Msg("connecting to kafka service") + b.Log.Info("connecting to kafka service", + "brokers", strings.Join(b.config.KafkaOptions.Brokers, ","), + "topic", b.config.KafkaOptions.Topic, + "async", b.config.KafkaOptions.Async) var balancer kafka.Balancer switch b.config.KafkaOptions.Balancer { @@ -169,9 +169,9 @@ func (b *Bridge) Init(config any) error { // verify connect if _, err := b.kafkaTopics(); err != nil { - b.Log.Error().Err(err).Msg("cannot connect to kafka service") + b.Log.Error("cannot connect to kafka service", "error", err) } else { - b.Log.Info().Msg("connected to kafka service") + b.Log.Error("connected to kafka service", "error", err) } return nil @@ -200,7 +200,7 @@ func (b *Bridge) kafkaTopics() (map[string]struct{}, error) { // Stop closes the kafka connection. func (b *Bridge) Stop() error { - b.Log.Info().Msg("disconnecting from kafka service") + b.Log.Info("disconnecting from kafka service") return b.writer.Close() } @@ -210,7 +210,7 @@ func (b *Bridge) handler(messages []kafka.Message, err error) { for _, msg := range messages { keys = append(keys, string(msg.Key)) } - b.Log.Err(err).Strs("keys", keys).Msg("write msg to kafka") + b.Log.Error("write msg to kafka", "error", err, "keys", keys) } } @@ -255,7 +255,7 @@ func (b *Bridge) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) { } data, err := msg.MarshalBinary() if err != nil { - b.Log.Error().Err(err).Msg("bridge-kafka:OnSessionEstablished") + b.Log.Error("bridge-kafka:OnSessionEstablished", "error", err) return } @@ -264,7 +264,7 @@ func (b *Bridge) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) { Value: data, }) if err != nil { - b.Log.Error().Err(err).Msg("bridge-kafka:OnSessionEstablished") + b.Log.Error("bridge-kafka:OnSessionEstablished", "error", err) } } @@ -284,7 +284,7 @@ func (b *Bridge) OnDisconnect(cl *mqtt.Client, err error, expire bool) { data, err := msg.MarshalBinary() if err != nil { - b.Log.Error().Err(err).Msg("bridge-kafka:OnDisconnect") + b.Log.Error("bridge-kafka:OnDisconnect", "error", err) return } @@ -293,7 +293,7 @@ func (b *Bridge) OnDisconnect(cl *mqtt.Client, err error, expire bool) { Value: data, }) if err != nil { - b.Log.Error().Err(err).Msg("bridge-kafka:OnDisconnect") + b.Log.Error("bridge-kafka:OnDisconnect", "error", err) } } @@ -314,7 +314,7 @@ func (b *Bridge) OnPublished(cl *mqtt.Client, pk packets.Packet) { } data, err := msg.MarshalBinary() if err != nil { - b.Log.Error().Err(err).Msg("bridge-kafka:OnPublished") + b.Log.Error("bridge-kafka:OnPublished", "error", err) return } @@ -323,7 +323,7 @@ func (b *Bridge) OnPublished(cl *mqtt.Client, pk packets.Packet) { Value: data, }) if err != nil { - b.Log.Error().Err(err).Msg("bridge-kafka:OnPublished") + b.Log.Error("bridge-kafka:OnPublished", "error", err) } } @@ -352,7 +352,7 @@ func (b *Bridge) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes [] } data, err := msg.MarshalBinary() if err != nil { - b.Log.Error().Err(err).Msg("bridge-kafka:OnSubscribed") + b.Log.Error("bridge-kafka:OnSubscribed", "error", err) return } @@ -361,7 +361,7 @@ func (b *Bridge) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes [] Value: data, }) if err != nil { - b.Log.Error().Err(err).Msg("bridge-kafka:OnSubscribed") + b.Log.Error("bridge-kafka:OnSubscribed", "error", err) } } @@ -386,7 +386,7 @@ func (b *Bridge) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes } data, err := msg.MarshalBinary() if err != nil { - b.Log.Error().Err(err).Msg("bridge-kafka:OnUnsubscribed") + b.Log.Error("bridge-kafka:OnUnsubscribed", "error", err) return } @@ -395,7 +395,7 @@ func (b *Bridge) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes Value: data, }) if err != nil { - b.Log.Error().Err(err).Msg("bridge-kafka:OnUnsubscribed") + b.Log.Error("bridge-kafka:OnUnsubscribed", "error", err) } } diff --git a/plugin/bridge/kafka/kafka_test.go b/plugin/bridge/kafka/kafka_test.go index 1a7456e..5fc5f18 100644 --- a/plugin/bridge/kafka/kafka_test.go +++ b/plugin/bridge/kafka/kafka_test.go @@ -4,19 +4,22 @@ import ( "context" "errors" "fmt" - "github.com/rs/zerolog" + "io" + "log/slog" + "sync" + "testing" + "github.com/segmentio/kafka-go" "github.com/stretchr/testify/require" "github.com/wind-c/comqtt/v2/mqtt" "github.com/wind-c/comqtt/v2/mqtt/packets" "github.com/wind-c/comqtt/v2/plugin" - "os" - "sync" - "testing" ) var ( - logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) + // Currently, the input is directed to /dev/null. If you need to + // output to stdout, just modify 'io.Discard' here to 'os.Stdout'. + logger = slog.New(slog.NewTextHandler(io.Discard, nil)) client = &mqtt.Client{ ID: "test", @@ -44,7 +47,7 @@ func teardown(t *testing.T, b *Bridge) { func newBridge(t *testing.T) *Bridge { b := new(Bridge) - b.SetOpts(&logger, nil) + b.SetOpts(logger, nil) opts := &Options{} err := plugin.LoadYaml("./conf.yml", opts) require.NoError(t, err) diff --git a/plugin/bridge/kafka/log.go b/plugin/bridge/kafka/log.go index c062a76..a827224 100644 --- a/plugin/bridge/kafka/log.go +++ b/plugin/bridge/kafka/log.go @@ -1,13 +1,16 @@ package kafka -import "github.com/rs/zerolog" +import ( + "fmt" + "log/slog" +) type kafkaLogger struct { - logger *zerolog.Logger + logger *slog.Logger prefix string } -func newKafkaLogger(logger *zerolog.Logger) *kafkaLogger { +func newKafkaLogger(logger *slog.Logger) *kafkaLogger { return &kafkaLogger{ logger: logger, prefix: "kafka: ", @@ -15,5 +18,5 @@ func newKafkaLogger(logger *zerolog.Logger) *kafkaLogger { } func (k kafkaLogger) Printf(format string, v ...interface{}) { - k.logger.Printf(k.prefix+format, v...) + k.logger.Info(fmt.Sprintf(k.prefix+format, v...)) }