From f82742b392bfc459dc767ead2636a9df60e76356 Mon Sep 17 00:00:00 2001 From: "fox.cpp" Date: Thu, 30 Jan 2025 01:45:19 +0300 Subject: [PATCH] Implement net.Listener hand-over during config reload Log library refactored a little to make it easier to enable debug logging. --- framework/log/log.go | 67 ++++++++------ framework/log/orderedjson.go | 2 +- framework/log/zap.go | 2 +- framework/module/lifetime.go | 4 +- framework/module/registry.go | 4 +- framework/resource/netresource/dup.go | 27 ++++++ framework/resource/netresource/fd.go | 47 ++++++++++ framework/resource/netresource/listen.go | 38 ++++++++ framework/resource/netresource/tracker.go | 91 +++++++++++++++++++ framework/resource/resource.go | 18 ++++ framework/resource/singleton.go | 64 +++++++++++++ framework/resource/tracker.go | 47 ++++++++++ internal/auth/netauth/netauth.go | 2 +- .../endpoint/dovecot_sasld/dovecot_sasl.go | 5 +- internal/endpoint/imap/imap.go | 5 +- internal/endpoint/openmetrics/om.go | 4 +- internal/endpoint/smtp/smtp.go | 5 +- internal/updatepipe/unix_pipe.go | 3 +- maddy.go | 3 + 19 files changed, 394 insertions(+), 44 deletions(-) create mode 100644 framework/resource/netresource/dup.go create mode 100644 framework/resource/netresource/fd.go create mode 100644 framework/resource/netresource/listen.go create mode 100644 framework/resource/netresource/tracker.go create mode 100644 framework/resource/resource.go create mode 100644 framework/resource/singleton.go create mode 100644 framework/resource/tracker.go diff --git a/framework/log/log.go b/framework/log/log.go index da8647f6..366cf156 100644 --- a/framework/log/log.go +++ b/framework/log/log.go @@ -42,6 +42,8 @@ import ( // No serialization is provided by Logger, its log.Output responsibility to // ensure goroutine-safety if necessary. type Logger struct { + Parent *Logger + Out Output Name string Debug bool @@ -51,30 +53,37 @@ type Logger struct { Fields map[string]interface{} } -func (l Logger) Zap() *zap.Logger { +func (l *Logger) Zap() *zap.Logger { // TODO: Migrate to using zap natively. return zap.New(zapLogger{L: l}) } -func (l Logger) Debugf(format string, val ...interface{}) { - if !l.Debug { +func (l *Logger) IsDebug() bool { + if l.Parent == nil { + return l.Debug + } + return l.Debug || l.Parent.IsDebug() +} + +func (l *Logger) Debugf(format string, val ...interface{}) { + if !l.IsDebug() { return } l.log(true, l.formatMsg(fmt.Sprintf(format, val...), nil)) } -func (l Logger) Debugln(val ...interface{}) { - if !l.Debug { +func (l *Logger) Debugln(val ...interface{}) { + if !l.IsDebug() { return } l.log(true, l.formatMsg(strings.TrimRight(fmt.Sprintln(val...), "\n"), nil)) } -func (l Logger) Printf(format string, val ...interface{}) { +func (l *Logger) Printf(format string, val ...interface{}) { l.log(false, l.formatMsg(fmt.Sprintf(format, val...), nil)) } -func (l Logger) Println(val ...interface{}) { +func (l *Logger) Println(val ...interface{}) { l.log(false, l.formatMsg(strings.TrimRight(fmt.Sprintln(val...), "\n"), nil)) } @@ -87,13 +96,13 @@ func (l Logger) Println(val ...interface{}) { // followed by corresponding values. That is, for example, []interface{"key", // "value", "key2", "value2"}. // -// If value in fields implements LogFormatter, it will be represented by the +// If value in fields implements Formatter, it will be represented by the // string returned by FormatLog method. Same goes for fmt.Stringer and error // interfaces. // // Additionally, time.Time is written as a string in ISO 8601 format. // time.Duration follows fmt.Stringer rule above. -func (l Logger) Msg(msg string, fields ...interface{}) { +func (l *Logger) Msg(msg string, fields ...interface{}) { m := make(map[string]interface{}, len(fields)/2) fieldsToMap(fields, m) l.log(false, l.formatMsg(msg, m)) @@ -112,7 +121,7 @@ func (l Logger) Msg(msg string, fields ...interface{}) { // In the context of Error method, "msg" typically indicates the top-level // context in which the error is *handled*. For example, if error leads to // rejection of SMTP DATA command, msg will probably be "DATA error". -func (l Logger) Error(msg string, err error, fields ...interface{}) { +func (l *Logger) Error(msg string, err error, fields ...interface{}) { if err == nil { return } @@ -133,8 +142,8 @@ func (l Logger) Error(msg string, err error, fields ...interface{}) { l.log(false, l.formatMsg(msg, allFields)) } -func (l Logger) DebugMsg(kind string, fields ...interface{}) { - if !l.Debug { +func (l *Logger) DebugMsg(kind string, fields ...interface{}) { + if !l.IsDebug() { return } m := make(map[string]interface{}, len(fields)/2) @@ -162,7 +171,7 @@ func fieldsToMap(fields []interface{}, out map[string]interface{}) { } } -func (l Logger) formatMsg(msg string, fields map[string]interface{}) string { +func (l *Logger) formatMsg(msg string, fields map[string]interface{}) string { formatted := strings.Builder{} formatted.WriteString(msg) @@ -184,14 +193,17 @@ func (l Logger) formatMsg(msg string, fields map[string]interface{}) string { return formatted.String() } -type LogFormatter interface { +type Formatter interface { FormatLog() string } // Write implements io.Writer, all bytes sent // to it will be written as a separate log messages. // No line-buffering is done. -func (l Logger) Write(s []byte) (int, error) { +func (l *Logger) Write(s []byte) (int, error) { + if !l.IsDebug() { + return len(s), nil + } l.log(false, strings.TrimRight(string(s), "\n")) return len(s), nil } @@ -199,15 +211,13 @@ func (l Logger) Write(s []byte) (int, error) { // DebugWriter returns a writer that will act like Logger.Write // but will use debug flag on messages. If Logger.Debug is false, // Write method of returned object will be no-op. -func (l Logger) DebugWriter() io.Writer { - if !l.Debug { - return io.Discard - } - l.Debug = true - return &l +func (l *Logger) DebugWriter() io.Writer { + l2 := l.Sublogger("") + l2.Debug = true + return l2 } -func (l Logger) log(debug bool, s string) { +func (l *Logger) log(debug bool, s string) { if l.Name != "" { s = l.Name + ": " + s } @@ -224,14 +234,15 @@ func (l Logger) log(debug bool, s string) { // Logging is disabled - do nothing. } -func (l Logger) Sublogger(name string) Logger { - if l.Name != "" { +func (l *Logger) Sublogger(name string) *Logger { + if l.Name != "" && name != "" { name = l.Name + "/" + name } - return Logger{ - Out: l.Out, - Name: name, - Debug: l.Debug, + return &Logger{ + Parent: l, + Out: l.Out, + Name: name, + Debug: l.Debug, } } diff --git a/framework/log/orderedjson.go b/framework/log/orderedjson.go index 834f3a6c..9de32baf 100644 --- a/framework/log/orderedjson.go +++ b/framework/log/orderedjson.go @@ -58,7 +58,7 @@ func marshalOrderedJSON(output *strings.Builder, m map[string]interface{}) error val = casted.Format("2006-01-02T15:04:05.000") case time.Duration: val = casted.String() - case LogFormatter: + case Formatter: val = casted.FormatLog() case fmt.Stringer: val = casted.String() diff --git a/framework/log/zap.go b/framework/log/zap.go index 23821f84..893bf484 100644 --- a/framework/log/zap.go +++ b/framework/log/zap.go @@ -7,7 +7,7 @@ import ( // TODO: Migrate to using actual zapcore to improve logging performance type zapLogger struct { - L Logger + L *Logger } func (l zapLogger) Enabled(level zapcore.Level) bool { diff --git a/framework/module/lifetime.go b/framework/module/lifetime.go index 1b0339da..ede5d2b4 100644 --- a/framework/module/lifetime.go +++ b/framework/module/lifetime.go @@ -38,7 +38,7 @@ type ReloadModule interface { } type LifetimeTracker struct { - logger log.Logger + logger *log.Logger instances []*struct { mod LifetimeModule started bool @@ -114,7 +114,7 @@ func (lt *LifetimeTracker) StopAll() error { return nil } -func NewLifetime(log log.Logger) *LifetimeTracker { +func NewLifetime(log *log.Logger) *LifetimeTracker { return &LifetimeTracker{ logger: log, } diff --git a/framework/module/registry.go b/framework/module/registry.go index ec62edf7..ce133402 100644 --- a/framework/module/registry.go +++ b/framework/module/registry.go @@ -35,14 +35,14 @@ type registryEntry struct { } type Registry struct { - logger log.Logger + logger *log.Logger instances map[string]registryEntry initialized map[string]struct{} started map[string]struct{} aliases map[string]string } -func NewRegistry(log log.Logger) *Registry { +func NewRegistry(log *log.Logger) *Registry { return &Registry{ logger: log, instances: make(map[string]registryEntry), diff --git a/framework/resource/netresource/dup.go b/framework/resource/netresource/dup.go new file mode 100644 index 00000000..00d047dc --- /dev/null +++ b/framework/resource/netresource/dup.go @@ -0,0 +1,27 @@ +package netresource + +import "net" + +func dupTCPListener(l *net.TCPListener) (*net.TCPListener, error) { + f, err := l.File() + if err != nil { + return nil, err + } + l2, err := net.FileListener(f) + if err != nil { + return nil, err + } + return l2.(*net.TCPListener), nil +} + +func dupUnixListener(l *net.UnixListener) (*net.UnixListener, error) { + f, err := l.File() + if err != nil { + return nil, err + } + l2, err := net.FileListener(f) + if err != nil { + return nil, err + } + return l2.(*net.UnixListener), nil +} diff --git a/framework/resource/netresource/fd.go b/framework/resource/netresource/fd.go new file mode 100644 index 00000000..395ddbcc --- /dev/null +++ b/framework/resource/netresource/fd.go @@ -0,0 +1,47 @@ +package netresource + +import ( + "errors" + "fmt" + "net" + "os" + "strconv" + "strings" +) + +func ListenFD(fd uint) (net.Listener, error) { + file := os.NewFile(uintptr(fd), strconv.FormatUint(uint64(fd), 10)) + defer file.Close() + return net.FileListener(file) +} + +func ListenFDName(name string) (net.Listener, error) { + listenPDStr := os.Getenv("LISTEN_PID") + if listenPDStr == "" { + return nil, errors.New("$LISTEN_PID is not set") + } + listenPid, err := strconv.Atoi(listenPDStr) + if err != nil { + return nil, errors.New("$LISTEN_PID is not integer") + } + if listenPid != os.Getpid() { + return nil, fmt.Errorf("$LISTEN_PID (%d) is not our PID (%d)", listenPid, os.Getpid()) + } + + names := strings.Split(os.Getenv("LISTEN_FDNAMES"), ":") + fd := uintptr(0) + for i, fdName := range names { + if fdName == name { + fd = uintptr(3 + i) + break + } + } + + if fd == 0 { + return nil, fmt.Errorf("name %s not found in $LISTEN_FDNAMES", name) + } + + file := os.NewFile(3+fd, name) + defer file.Close() + return net.FileListener(file) +} diff --git a/framework/resource/netresource/listen.go b/framework/resource/netresource/listen.go new file mode 100644 index 00000000..e2328002 --- /dev/null +++ b/framework/resource/netresource/listen.go @@ -0,0 +1,38 @@ +package netresource + +import ( + "fmt" + "net" + "strconv" + + "github.com/foxcpp/maddy/framework/log" +) + +var ( + tracker = NewListenerTracker(log.DefaultLogger.Sublogger("netresource")) +) + +func CloseUnusedListeners() error { + return tracker.Close() +} + +func ResetListenersUsage() { + tracker.ResetUsage() +} + +func Listen(network, addr string) (net.Listener, error) { + switch network { + case "fd": + fd, err := strconv.ParseUint(addr, 10, strconv.IntSize) + if err != nil { + return nil, fmt.Errorf("invalid FD number: %v", addr) + } + return ListenFD(uint(fd)) + case "fdname": + return ListenFDName(addr) + case "tcp", "tcp4", "tcp6", "unix": + return tracker.Get(network, addr) + default: + return nil, fmt.Errorf("unsupported network: %v", network) + } +} diff --git a/framework/resource/netresource/tracker.go b/framework/resource/netresource/tracker.go new file mode 100644 index 00000000..97989c62 --- /dev/null +++ b/framework/resource/netresource/tracker.go @@ -0,0 +1,91 @@ +package netresource + +import ( + "fmt" + "net" + "net/netip" + + "github.com/foxcpp/maddy/framework/log" + "github.com/foxcpp/maddy/framework/resource" +) + +type ListenerTracker struct { + logger *log.Logger + tcp *resource.Tracker[*net.TCPListener] + unix *resource.Tracker[*net.UnixListener] +} + +func (lt *ListenerTracker) Get(network, addr string) (net.Listener, error) { + switch network { + case "tcp", "tcp4", "tcp6": + l, err := lt.tcp.GetOpen(addr, func() (*net.TCPListener, error) { + addrPort, err := netip.ParseAddrPort(addr) + if err != nil { + return nil, err + } + lt.logger.DebugMsg("new listener", "network", network, "address", addr) + return net.ListenTCP(network, net.TCPAddrFromAddrPort(addrPort)) + }) + if err != nil { + return nil, err + } + + // We return duplicated listener so when listener is closed by user endpoint + // the tracked resource remains available and listening on the port doesn't + // actually stop. + l2, err := dupTCPListener(l) + if err != nil { + return nil, err + } + return l2, nil + case "unix": + l, err := lt.unix.GetOpen(addr, func() (*net.UnixListener, error) { + addr, err := net.ResolveUnixAddr(network, addr) + if err != nil { + return nil, err + } + lt.logger.DebugMsg("new listener", "network", network, "address", addr) + return net.ListenUnix(network, addr) + }) + if err != nil { + return nil, err + } + + l2, err := dupUnixListener(l) + if err != nil { + return nil, err + } + return l2, nil + default: + return nil, fmt.Errorf("unsupported network type: %s", network) + } +} + +func (lt *ListenerTracker) ResetUsage() { + lt.tcp.MarkAllUnused() + lt.unix.MarkAllUnused() +} + +func (lt *ListenerTracker) CloseUnused() error { + lt.tcp.CloseUnused(func(key string) bool { + return false + }) + lt.unix.CloseUnused(func(key string) bool { + return false + }) + return nil +} + +func (lt *ListenerTracker) Close() error { + lt.tcp.Close() + lt.unix.Close() + return nil +} + +func NewListenerTracker(log *log.Logger) *ListenerTracker { + return &ListenerTracker{ + logger: log, + tcp: resource.NewTracker[*net.TCPListener](resource.NewSingleton[*net.TCPListener]()), + unix: resource.NewTracker[*net.UnixListener](resource.NewSingleton[*net.UnixListener]()), + } +} diff --git a/framework/resource/resource.go b/framework/resource/resource.go new file mode 100644 index 00000000..a34942a6 --- /dev/null +++ b/framework/resource/resource.go @@ -0,0 +1,18 @@ +package resource + +import ( + "io" +) + +type Resource = io.Closer + +type CheckableResource interface { + Resource + IsUsable() bool +} + +type Container[T Resource] interface { + io.Closer + GetOpen(key string, open func() (T, error)) (T, error) + CloseUnused(isUsed func(key string) bool) error +} diff --git a/framework/resource/singleton.go b/framework/resource/singleton.go new file mode 100644 index 00000000..2614cee2 --- /dev/null +++ b/framework/resource/singleton.go @@ -0,0 +1,64 @@ +package resource + +import ( + "sync" +) + +// Singleton represents a set of resources identified by an unique key. +type Singleton[T Resource] struct { + lock sync.RWMutex + resources map[string]T +} + +func NewSingleton[T Resource]() *Singleton[T] { + return &Singleton[T]{ + resources: make(map[string]T), + } +} + +func (s *Singleton[T]) GetOpen(key string, open func() (T, error)) (T, error) { + s.lock.Lock() + defer s.lock.Unlock() + + existing, ok := s.resources[key] + if ok { + return existing, nil + } + + res, err := open() + if err != nil { + var empty T + return empty, err + } + + s.resources[key] = res + + return res, nil +} + +func (s *Singleton[T]) CloseUnused(isUsed func(key string) bool) error { + s.lock.Lock() + defer s.lock.Unlock() + + for key, res := range s.resources { + if isUsed(key) { + continue + } + res.Close() + delete(s.resources, key) + } + + return nil +} + +func (s *Singleton[T]) Close() error { + s.lock.Lock() + defer s.lock.Unlock() + + for key, res := range s.resources { + res.Close() + delete(s.resources, key) + } + + return nil +} diff --git a/framework/resource/tracker.go b/framework/resource/tracker.go new file mode 100644 index 00000000..318cf298 --- /dev/null +++ b/framework/resource/tracker.go @@ -0,0 +1,47 @@ +package resource + +import ( + "sync" +) + +// Tracker is a container wrapper that tracks whether resources were used since +// last MarkAllUnused call. +type Tracker[T Resource] struct { + C Container[T] + + usedLock sync.Mutex + used map[string]bool +} + +func NewTracker[T Resource](c Container[T]) *Tracker[T] { + return &Tracker[T]{C: c, used: make(map[string]bool)} +} + +func (t *Tracker[T]) Close() error { + return t.C.Close() +} + +func (t *Tracker[T]) MarkAllUnused() { + t.usedLock.Lock() + defer t.usedLock.Unlock() + + t.used = make(map[string]bool) +} + +func (t *Tracker[T]) GetOpen(key string, open func() (T, error)) (T, error) { + t.usedLock.Lock() + t.used[key] = true + t.usedLock.Unlock() + + return t.C.GetOpen(key, open) +} + +func (t *Tracker[T]) CloseUnused(isUsed func(key string) bool) error { + t.usedLock.Lock() + defer t.usedLock.Unlock() + + return t.C.CloseUnused(func(key string) bool { + used := t.used[key] + return used && isUsed(key) + }) +} diff --git a/internal/auth/netauth/netauth.go b/internal/auth/netauth/netauth.go index 62d6c6cc..db84cd18 100644 --- a/internal/auth/netauth/netauth.go +++ b/internal/auth/netauth/netauth.go @@ -42,7 +42,7 @@ func (a *Auth) Configure(inlineArgs []string, cfg *config.Map) error { return fmt.Errorf("%s: inline arguments are not used", modName) } - l := hclog.New(&hclog.LoggerOptions{Output: a.log}) + l := hclog.New(&hclog.LoggerOptions{Output: &a.log}) n, err := netauth.NewWithLog(l) if err != nil { return err diff --git a/internal/endpoint/dovecot_sasld/dovecot_sasl.go b/internal/endpoint/dovecot_sasld/dovecot_sasl.go index b1bcd159..f4dee1a7 100644 --- a/internal/endpoint/dovecot_sasld/dovecot_sasl.go +++ b/internal/endpoint/dovecot_sasld/dovecot_sasl.go @@ -31,6 +31,7 @@ import ( modconfig "github.com/foxcpp/maddy/framework/config/module" "github.com/foxcpp/maddy/framework/log" "github.com/foxcpp/maddy/framework/module" + "github.com/foxcpp/maddy/framework/resource/netresource" "github.com/foxcpp/maddy/internal/auth" "github.com/foxcpp/maddy/internal/authz" ) @@ -79,7 +80,7 @@ func (endp *Endpoint) Configure(_ []string, cfg *config.Map) error { } endp.srv = dovecotsasl.NewServer() - endp.srv.Log = stdlog.New(endp.log, "", 0) + endp.srv.Log = stdlog.New(&endp.log, "", 0) for _, mech := range endp.saslAuth.SASLMechanisms() { endp.srv.AddMechanism(mech, mechInfo[mech], func(req *dovecotsasl.AuthReq) sasl.Server { @@ -106,7 +107,7 @@ func (endp *Endpoint) Configure(_ []string, cfg *config.Map) error { func (endp *Endpoint) Start() error { for _, addr := range endp.endpoints { - l, err := net.Listen(addr.Network(), addr.Address()) + l, err := netresource.Listen(addr.Network(), addr.Address()) if err != nil { return fmt.Errorf("%s: %v", modName, err) } diff --git a/internal/endpoint/imap/imap.go b/internal/endpoint/imap/imap.go index 765c429f..13116225 100644 --- a/internal/endpoint/imap/imap.go +++ b/internal/endpoint/imap/imap.go @@ -42,6 +42,7 @@ import ( tls2 "github.com/foxcpp/maddy/framework/config/tls" "github.com/foxcpp/maddy/framework/log" "github.com/foxcpp/maddy/framework/module" + "github.com/foxcpp/maddy/framework/resource/netresource" "github.com/foxcpp/maddy/internal/auth" "github.com/foxcpp/maddy/internal/authz" "github.com/foxcpp/maddy/internal/proxy_protocol" @@ -126,7 +127,7 @@ func (endp *Endpoint) Configure(_ []string, cfg *config.Map) error { if ioErrors { endp.serv.ErrorLog = &endp.Log } else { - endp.serv.ErrorLog = log.Logger{Out: log.NopOutput{}} + endp.serv.ErrorLog = &log.Logger{Out: log.NopOutput{}} } if ioDebug { endp.serv.Debug = endp.Log.DebugWriter() @@ -174,7 +175,7 @@ func (endp *Endpoint) setupListeners(addresses []config.Endpoint) error { for _, addr := range addresses { var l net.Listener var err error - l, err = net.Listen(addr.Network(), addr.Address()) + l, err = netresource.Listen(addr.Network(), addr.Address()) if err != nil { return fmt.Errorf("imap: %v", err) } diff --git a/internal/endpoint/openmetrics/om.go b/internal/endpoint/openmetrics/om.go index 136d4af5..44ac64db 100644 --- a/internal/endpoint/openmetrics/om.go +++ b/internal/endpoint/openmetrics/om.go @@ -21,13 +21,13 @@ package openmetrics import ( "errors" "fmt" - "net" "net/http" "sync" "github.com/foxcpp/maddy/framework/config" "github.com/foxcpp/maddy/framework/log" "github.com/foxcpp/maddy/framework/module" + "github.com/foxcpp/maddy/framework/resource/netresource" "github.com/prometheus/client_golang/prometheus/promhttp" ) @@ -83,7 +83,7 @@ func (e *Endpoint) InstanceName() string { func (e *Endpoint) Start() error { for _, endp := range e.endpoints { - l, err := net.Listen(endp.Network(), endp.Address()) + l, err := netresource.Listen(endp.Network(), endp.Address()) if err != nil { e.Stop() return fmt.Errorf("%s: %v", modName, err) diff --git a/internal/endpoint/smtp/smtp.go b/internal/endpoint/smtp/smtp.go index 7e5f3ab0..ee433e86 100644 --- a/internal/endpoint/smtp/smtp.go +++ b/internal/endpoint/smtp/smtp.go @@ -41,6 +41,7 @@ import ( "github.com/foxcpp/maddy/framework/future" "github.com/foxcpp/maddy/framework/log" "github.com/foxcpp/maddy/framework/module" + "github.com/foxcpp/maddy/framework/resource/netresource" "github.com/foxcpp/maddy/internal/auth" "github.com/foxcpp/maddy/internal/authz" "github.com/foxcpp/maddy/internal/limits" @@ -107,7 +108,7 @@ func New(modName string, addrs []string) (module.LifetimeModule, error) { func (endp *Endpoint) Configure(_ []string, cfg *config.Map) error { endp.serv = smtp.NewServer(endp) - endp.serv.ErrorLog = endp.Log + endp.serv.ErrorLog = &endp.Log endp.serv.LMTP = endp.lmtp endp.serv.EnableSMTPUTF8 = true endp.serv.EnableREQUIRETLS = true @@ -328,7 +329,7 @@ func (endp *Endpoint) setupListeners(addresses []config.Endpoint) error { for _, addr := range addresses { var l net.Listener var err error - l, err = net.Listen(addr.Network(), addr.Address()) + l, err = netresource.Listen(addr.Network(), addr.Address()) if err != nil { return fmt.Errorf("%s: %w", endp.name, err) } diff --git a/internal/updatepipe/unix_pipe.go b/internal/updatepipe/unix_pipe.go index a8249f90..945c4f95 100644 --- a/internal/updatepipe/unix_pipe.go +++ b/internal/updatepipe/unix_pipe.go @@ -27,6 +27,7 @@ import ( mess "github.com/foxcpp/go-imap-mess" "github.com/foxcpp/maddy/framework/log" + "github.com/foxcpp/maddy/framework/resource/netresource" ) // UnixSockPipe implements the UpdatePipe interface by serializating updates @@ -74,7 +75,7 @@ func (usp *UnixSockPipe) readUpdates(conn net.Conn, updCh chan<- mess.Update) { } func (usp *UnixSockPipe) Listen(upd chan<- mess.Update) error { - l, err := net.Listen("unix", usp.SockPath) + l, err := netresource.Listen("unix", usp.SockPath) if err != nil { return err } diff --git a/maddy.go b/maddy.go index 829cdf76..e0a33e1d 100644 --- a/maddy.go +++ b/maddy.go @@ -36,6 +36,7 @@ import ( "github.com/foxcpp/maddy/framework/hooks" "github.com/foxcpp/maddy/framework/log" "github.com/foxcpp/maddy/framework/module" + "github.com/foxcpp/maddy/framework/resource/netresource" "github.com/foxcpp/maddy/internal/authz" maddycli "github.com/foxcpp/maddy/internal/cli" "github.com/urfave/cli/v2" @@ -422,12 +423,14 @@ func moduleReload(oldContainer *container.C, configPath string) *container.C { oldContainer.DefaultLogger.Msg("configuration loaded") + netresource.ResetListenersUsage() oldContainer.DefaultLogger.Msg("starting new server") if err := moduleStart(newContainer); err != nil { oldContainer.DefaultLogger.Error("failed to start new server", err) container.Global = oldContainer return oldContainer } + netresource.CloseUnusedListeners() newContainer.DefaultLogger.Msg("server started", "version", Version) oldContainer.DefaultLogger.Msg("stopping server")