From ca9aca9b19b82d2976099c92d9ec7d9de8eb44c4 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 6 Feb 2025 10:20:31 +0100 Subject: [PATCH 1/9] Fix nil pointer exception when load empty list and try to cast it (#3282) --- client/internal/engine.go | 3 +-- client/internal/peer/ice/StunTurn.go | 22 ++++++++++++++++++++++ client/internal/peer/ice/StunTurn_test.go | 13 +++++++++++++ client/internal/peer/ice/agent.go | 3 +-- client/internal/peer/ice/config.go | 4 +--- 5 files changed, 38 insertions(+), 7 deletions(-) create mode 100644 client/internal/peer/ice/StunTurn.go create mode 100644 client/internal/peer/ice/StunTurn_test.go diff --git a/client/internal/engine.go b/client/internal/engine.go index 7f7cdf376e7..335729d92f8 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -13,7 +13,6 @@ import ( "sort" "strings" "sync" - "sync/atomic" "time" "github.com/hashicorp/go-multierror" @@ -146,7 +145,7 @@ type Engine struct { STUNs []*stun.URI // TURNs is a list of STUN servers used by ICE TURNs []*stun.URI - stunTurn atomic.Value + stunTurn icemaker.StunTurn clientCtx context.Context clientCancel context.CancelFunc diff --git a/client/internal/peer/ice/StunTurn.go b/client/internal/peer/ice/StunTurn.go new file mode 100644 index 00000000000..63ee8c71384 --- /dev/null +++ b/client/internal/peer/ice/StunTurn.go @@ -0,0 +1,22 @@ +package ice + +import ( + "sync/atomic" + + "github.com/pion/stun/v2" +) + +type StunTurn atomic.Value + +func (s *StunTurn) Load() []*stun.URI { + v := (*atomic.Value)(s).Load() + if v == nil { + return nil + } + + return v.([]*stun.URI) +} + +func (s *StunTurn) Store(value []*stun.URI) { + (*atomic.Value)(s).Store(value) +} diff --git a/client/internal/peer/ice/StunTurn_test.go b/client/internal/peer/ice/StunTurn_test.go new file mode 100644 index 00000000000..7233afa6c9c --- /dev/null +++ b/client/internal/peer/ice/StunTurn_test.go @@ -0,0 +1,13 @@ +package ice + +import ( + "testing" +) + +func TestStunTurn_LoadEmpty(t *testing.T) { + var stStunTurn StunTurn + got := stStunTurn.Load() + if len(got) != 0 { + t.Errorf("StunTurn.Load() = %v, want %v", got, nil) + } +} diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go index dc4750f243a..af9e60f2d07 100644 --- a/client/internal/peer/ice/agent.go +++ b/client/internal/peer/ice/agent.go @@ -5,7 +5,6 @@ import ( "github.com/pion/ice/v3" "github.com/pion/randutil" - "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/stdnet" @@ -39,7 +38,7 @@ func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candida agentConfig := &ice.AgentConfig{ MulticastDNSMode: ice.MulticastDNSModeDisabled, NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, - Urls: config.StunTurn.Load().([]*stun.URI), + Urls: config.StunTurn.Load(), CandidateTypes: candidateTypes, InterfaceFilter: stdnet.InterfaceFilter(config.InterfaceBlackList), UDPMux: config.UDPMux, diff --git a/client/internal/peer/ice/config.go b/client/internal/peer/ice/config.go index 8abc842f0d2..dd854a605b7 100644 --- a/client/internal/peer/ice/config.go +++ b/client/internal/peer/ice/config.go @@ -1,14 +1,12 @@ package ice import ( - "sync/atomic" - "github.com/pion/ice/v3" ) type Config struct { // StunTurn is a list of STUN and TURN URLs - StunTurn *atomic.Value // []*stun.URI + StunTurn *StunTurn // []*stun.URI // InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering // (e.g. if eth0 is in the list, host candidate of this interface won't be used) From cee4aeea9e48376707974f301f280b3d813131a9 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 6 Feb 2025 13:36:57 +0100 Subject: [PATCH 2/9] [management] Check groups when counting peers on networks list (#3284) --- management/server/http/handlers/networks/handler.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go index 316b936115b..f716348d673 100644 --- a/management/server/http/handlers/networks/handler.go +++ b/management/server/http/handlers/networks/handler.go @@ -7,6 +7,7 @@ import ( "net/http" "github.com/gorilla/mux" + log "github.com/sirupsen/logrus" s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/groups" @@ -281,7 +282,12 @@ func (h *handler) collectIDsInNetwork(ctx context.Context, accountID, userID, ne } if len(router.PeerGroups) > 0 { for _, groupID := range router.PeerGroups { - peerCounter += len(groups[groupID].Peers) + group, ok := groups[groupID] + if !ok { + log.WithContext(ctx).Warnf("group %s not found", groupID) + continue + } + peerCounter += len(group.Peers) } } } From b7af53ea40b74f52ac57ec874690e93bd2521abb Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 7 Feb 2025 13:51:17 +0100 Subject: [PATCH 3/9] [management] add logs for grpc API (#3298) --- management/server/grpcserver.go | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index eec109ee970..e8e0c422ed6 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -15,6 +15,7 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc/codes" + "google.golang.org/grpc/peer" "google.golang.org/grpc/status" "github.com/netbirdio/netbird/encryption" @@ -114,6 +115,18 @@ func NewServer( } func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) { + ip := "" + p, ok := peer.FromContext(ctx) + if ok { + ip = p.Addr.String() + } + + log.WithContext(ctx).Tracef("GetServerKey request from %s", ip) + start := time.Now() + defer func() { + log.WithContext(ctx).Tracef("GetServerKey from %s took %v", ip, time.Since(start)) + }() + // todo introduce something more meaningful with the key expiration/rotation if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountGetKeyRequest() @@ -717,6 +730,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p // This is used for initiating an Oauth 2 device authorization grant flow // which will be used by our clients to Login func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { + log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey) + start := time.Now() + defer func() { + log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow for pubKey: %s took %v", req.WgPubKey, time.Since(start)) + }() + peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey) @@ -769,6 +788,12 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto. // This is used for initiating an Oauth 2 pkce authorization grant flow // which will be used by our clients to Login func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { + log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey) + start := time.Now() + defer func() { + log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow for pubKey %s took %v", req.WgPubKey, time.Since(start)) + }() + peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey) From 05415f72ec827835edd211f332a235f5b8e02ba6 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 7 Feb 2025 14:11:53 +0100 Subject: [PATCH 4/9] [client] Add experimental support for userspace routing (#3134) --- client/Dockerfile-rootless | 1 + client/cmd/trace.go | 137 +++ client/firewall/create.go | 4 +- client/firewall/create_linux.go | 14 +- client/firewall/iface.go | 4 + client/firewall/iptables/manager_linux.go | 5 + client/firewall/iptables/router_linux.go | 11 +- client/firewall/manager/firewall.go | 2 + client/firewall/nftables/manager_linux.go | 5 + .../firewall/nftables/manager_linux_test.go | 17 +- client/firewall/nftables/router_linux.go | 8 +- client/firewall/uspfilter/allow_netbird.go | 23 +- .../uspfilter/allow_netbird_windows.go | 20 +- client/firewall/uspfilter/common/iface.go | 16 + client/firewall/uspfilter/conntrack/common.go | 21 +- .../uspfilter/conntrack/common_test.go | 36 +- client/firewall/uspfilter/conntrack/icmp.go | 27 +- .../firewall/uspfilter/conntrack/icmp_test.go | 4 +- client/firewall/uspfilter/conntrack/tcp.go | 41 +- .../firewall/uspfilter/conntrack/tcp_test.go | 14 +- client/firewall/uspfilter/conntrack/udp.go | 19 +- .../firewall/uspfilter/conntrack/udp_test.go | 12 +- .../firewall/uspfilter/forwarder/endpoint.go | 81 ++ .../firewall/uspfilter/forwarder/forwarder.go | 166 +++ client/firewall/uspfilter/forwarder/icmp.go | 109 ++ client/firewall/uspfilter/forwarder/tcp.go | 90 ++ client/firewall/uspfilter/forwarder/udp.go | 288 +++++ client/firewall/uspfilter/localip.go | 134 +++ client/firewall/uspfilter/localip_test.go | 270 +++++ client/firewall/uspfilter/log/log.go | 196 ++++ client/firewall/uspfilter/log/ringbuffer.go | 85 ++ client/firewall/uspfilter/rule.go | 22 +- client/firewall/uspfilter/tracer.go | 390 +++++++ client/firewall/uspfilter/uspfilter.go | 512 +++++++-- .../uspfilter/uspfilter_bench_test.go | 122 +- .../uspfilter/uspfilter_filter_test.go | 1014 +++++++++++++++++ client/firewall/uspfilter/uspfilter_test.go | 66 +- client/iface/device.go | 3 + client/iface/device/device_darwin.go | 5 + client/iface/device/device_kernel_unix.go | 6 + client/iface/device/device_netstack.go | 5 + client/iface/device/device_usp_unix.go | 5 + client/iface/device/device_windows.go | 5 + client/iface/device_android.go | 3 + client/iface/iface.go | 7 + client/iface/iface_moc.go | 9 +- client/iface/iwginterface.go | 2 + client/iface/iwginterface_windows.go | 2 + client/internal/acl/manager_test.go | 6 +- client/internal/acl/mocks/iface_mapper.go | 30 + client/internal/dns/server_test.go | 2 +- client/internal/engine.go | 29 +- client/internal/routemanager/manager.go | 5 - client/proto/daemon.pb.go | 690 ++++++++--- client/proto/daemon.proto | 35 + client/proto/daemon_grpc.pb.go | 36 + client/server/debug.go | 17 + client/server/trace.go | 123 ++ go.mod | 4 +- go.sum | 8 +- 60 files changed, 4650 insertions(+), 373 deletions(-) create mode 100644 client/cmd/trace.go create mode 100644 client/firewall/uspfilter/common/iface.go create mode 100644 client/firewall/uspfilter/forwarder/endpoint.go create mode 100644 client/firewall/uspfilter/forwarder/forwarder.go create mode 100644 client/firewall/uspfilter/forwarder/icmp.go create mode 100644 client/firewall/uspfilter/forwarder/tcp.go create mode 100644 client/firewall/uspfilter/forwarder/udp.go create mode 100644 client/firewall/uspfilter/localip.go create mode 100644 client/firewall/uspfilter/localip_test.go create mode 100644 client/firewall/uspfilter/log/log.go create mode 100644 client/firewall/uspfilter/log/ringbuffer.go create mode 100644 client/firewall/uspfilter/tracer.go create mode 100644 client/firewall/uspfilter/uspfilter_filter_test.go create mode 100644 client/server/trace.go diff --git a/client/Dockerfile-rootless b/client/Dockerfile-rootless index 62bcaf964bd..78314ba121c 100644 --- a/client/Dockerfile-rootless +++ b/client/Dockerfile-rootless @@ -9,6 +9,7 @@ USER netbird:netbird ENV NB_FOREGROUND_MODE=true ENV NB_USE_NETSTACK_MODE=true +ENV NB_ENABLE_NETSTACK_LOCAL_FORWARDING=true ENV NB_CONFIG=config.json ENV NB_DAEMON_ADDR=unix://netbird.sock ENV NB_DISABLE_DNS=true diff --git a/client/cmd/trace.go b/client/cmd/trace.go new file mode 100644 index 00000000000..b2ff1f1b54e --- /dev/null +++ b/client/cmd/trace.go @@ -0,0 +1,137 @@ +package cmd + +import ( + "fmt" + "math/rand" + "strings" + + "github.com/spf13/cobra" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/proto" +) + +var traceCmd = &cobra.Command{ + Use: "trace ", + Short: "Trace a packet through the firewall", + Example: ` + netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack + netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53 + netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0 + netbird debug trace in 100.64.1.1 self -p tcp --dport 80`, + Args: cobra.ExactArgs(3), + RunE: tracePacket, +} + +func init() { + debugCmd.AddCommand(traceCmd) + + traceCmd.Flags().StringP("protocol", "p", "tcp", "Protocol (tcp/udp/icmp)") + traceCmd.Flags().Uint16("sport", 0, "Source port") + traceCmd.Flags().Uint16("dport", 0, "Destination port") + traceCmd.Flags().Uint8("icmp-type", 0, "ICMP type") + traceCmd.Flags().Uint8("icmp-code", 0, "ICMP code") + traceCmd.Flags().Bool("syn", false, "TCP SYN flag") + traceCmd.Flags().Bool("ack", false, "TCP ACK flag") + traceCmd.Flags().Bool("fin", false, "TCP FIN flag") + traceCmd.Flags().Bool("rst", false, "TCP RST flag") + traceCmd.Flags().Bool("psh", false, "TCP PSH flag") + traceCmd.Flags().Bool("urg", false, "TCP URG flag") +} + +func tracePacket(cmd *cobra.Command, args []string) error { + direction := strings.ToLower(args[0]) + if direction != "in" && direction != "out" { + return fmt.Errorf("invalid direction: use 'in' or 'out'") + } + + protocol := cmd.Flag("protocol").Value.String() + if protocol != "tcp" && protocol != "udp" && protocol != "icmp" { + return fmt.Errorf("invalid protocol: use tcp/udp/icmp") + } + + sport, err := cmd.Flags().GetUint16("sport") + if err != nil { + return fmt.Errorf("invalid source port: %v", err) + } + dport, err := cmd.Flags().GetUint16("dport") + if err != nil { + return fmt.Errorf("invalid destination port: %v", err) + } + + // For TCP/UDP, generate random ephemeral port (49152-65535) if not specified + if protocol != "icmp" { + if sport == 0 { + sport = uint16(rand.Intn(16383) + 49152) + } + if dport == 0 { + dport = uint16(rand.Intn(16383) + 49152) + } + } + + var tcpFlags *proto.TCPFlags + if protocol == "tcp" { + syn, _ := cmd.Flags().GetBool("syn") + ack, _ := cmd.Flags().GetBool("ack") + fin, _ := cmd.Flags().GetBool("fin") + rst, _ := cmd.Flags().GetBool("rst") + psh, _ := cmd.Flags().GetBool("psh") + urg, _ := cmd.Flags().GetBool("urg") + + tcpFlags = &proto.TCPFlags{ + Syn: syn, + Ack: ack, + Fin: fin, + Rst: rst, + Psh: psh, + Urg: urg, + } + } + + icmpType, _ := cmd.Flags().GetUint32("icmp-type") + icmpCode, _ := cmd.Flags().GetUint32("icmp-code") + + conn, err := getClient(cmd) + if err != nil { + return err + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + resp, err := client.TracePacket(cmd.Context(), &proto.TracePacketRequest{ + SourceIp: args[1], + DestinationIp: args[2], + Protocol: protocol, + SourcePort: uint32(sport), + DestinationPort: uint32(dport), + Direction: direction, + TcpFlags: tcpFlags, + IcmpType: &icmpType, + IcmpCode: &icmpCode, + }) + if err != nil { + return fmt.Errorf("trace failed: %v", status.Convert(err).Message()) + } + + printTrace(cmd, args[1], args[2], protocol, sport, dport, resp) + return nil +} + +func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) { + cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto)) + + for _, stage := range resp.Stages { + if stage.ForwardingDetails != nil { + cmd.Printf("%s: %s [%s]\n", stage.Name, stage.Message, *stage.ForwardingDetails) + } else { + cmd.Printf("%s: %s\n", stage.Name, stage.Message) + } + } + + disposition := map[bool]string{ + true: "\033[32mALLOWED\033[0m", // Green + false: "\033[31mDENIED\033[0m", // Red + }[resp.FinalDisposition] + + cmd.Printf("\nFinal disposition: %s\n", disposition) +} diff --git a/client/firewall/create.go b/client/firewall/create.go index 9466f4b4d6b..37ea5ceb3fa 100644 --- a/client/firewall/create.go +++ b/client/firewall/create.go @@ -14,13 +14,13 @@ import ( ) // NewFirewall creates a firewall manager instance -func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) { if !iface.IsUserspaceBind() { return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) } // use userspace packet filtering firewall - fm, err := uspfilter.Create(iface) + fm, err := uspfilter.Create(iface, disableServerRoutes) if err != nil { return nil, err } diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index 076d08ec27b..be1b37916bb 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -33,12 +33,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" // FWType is the type for the firewall type type FWType int -func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) { // on the linux system we try to user nftables or iptables // in any case, because we need to allow netbird interface traffic // so we use AllowNetbird traffic from these firewall managers // for the userspace packet filtering firewall - fm, err := createNativeFirewall(iface, stateManager) + fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes) if !iface.IsUserspaceBind() { return fm, err @@ -47,10 +47,10 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal if err != nil { log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err) } - return createUserspaceFirewall(iface, fm) + return createUserspaceFirewall(iface, fm, disableServerRoutes) } -func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { +func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) { fm, err := createFW(iface) if err != nil { return nil, fmt.Errorf("create firewall: %s", err) @@ -77,12 +77,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) { } } -func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) { +func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) { var errUsp error if fm != nil { - fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm) + fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes) } else { - fm, errUsp = uspfilter.Create(iface) + fm, errUsp = uspfilter.Create(iface, disableServerRoutes) } if errUsp != nil { diff --git a/client/firewall/iface.go b/client/firewall/iface.go index f349f9210a6..d842abaa124 100644 --- a/client/firewall/iface.go +++ b/client/firewall/iface.go @@ -1,6 +1,8 @@ package firewall import ( + wgdevice "golang.zx2c4.com/wireguard/device" + "github.com/netbirdio/netbird/client/iface/device" ) @@ -10,4 +12,6 @@ type IFaceMapper interface { Address() device.WGAddress IsUserspaceBind() bool SetFilter(device.PacketFilter) error + GetDevice() *device.FilteredDevice + GetWGDevice() *wgdevice.Device } diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 75f082fc4c0..679f288e32a 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -213,6 +213,11 @@ func (m *Manager) AllowNetbird() error { // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } +// SetLogLevel sets the log level for the firewall manager +func (m *Manager) SetLogLevel(log.Level) { + // not supported +} + func getConntrackEstablished() []string { return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} } diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index a47d3ffe698..6522daa3f41 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -135,7 +135,16 @@ func (r *router) AddRouteFiltering( } rule := genRouteFilteringRuleSpec(params) - if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { + // Insert DROP rules at the beginning, append ACCEPT rules at the end + var err error + if action == firewall.ActionDrop { + // after the established rule + err = r.iptablesClient.Insert(tableFilter, chainRTFWD, 2, rule...) + } else { + err = r.iptablesClient.Append(tableFilter, chainRTFWD, rule...) + } + + if err != nil { return nil, fmt.Errorf("add route rule: %v", err) } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index f46e5eb5d57..de25ff1f11c 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -99,6 +99,8 @@ type Manager interface { // Flush the changes to firewall controller Flush() error + + SetLogLevel(log.Level) } func GenKey(format string, pair RouterPair) string { diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index a78626dbcd5..4fe52bd5361 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -318,6 +318,11 @@ func (m *Manager) cleanupNetbirdTables() error { return nil } +// SetLogLevel sets the log level for the firewall manager +func (m *Manager) SetLogLevel(log.Level) { + // not supported +} + // Flush rule/chain/set operations from the buffer // // Method also get all rules after flush and refreshes handle values in the rulesets diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 8d693725a6d..eaa8ef1f571 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -107,7 +107,7 @@ func TestNftablesManager(t *testing.T) { Kind: expr.VerdictAccept, }, } - require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions") + compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1) ipToAdd, _ := netip.AddrFromSlice(ip) add := ipToAdd.Unmap() @@ -307,3 +307,18 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { stdout, stderr = runIptablesSave(t) verifyIptablesOutput(t, stdout, stderr) } + +func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) { + t.Helper() + require.Equal(t, len(got), len(want), "expression count mismatch") + + for i := range got { + if _, isCounter := got[i].(*expr.Counter); isCounter { + _, wantIsCounter := want[i].(*expr.Counter) + require.True(t, wantIsCounter, "expected Counter at index %d", i) + continue + } + + require.Equal(t, got[i], want[i], "expression mismatch at index %d", i) + } +} diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 19734673b72..92f81f39cfa 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -233,7 +233,13 @@ func (r *router) AddRouteFiltering( UserData: []byte(ruleKey), } - rule = r.conn.AddRule(rule) + // Insert DROP rules at the beginning, append ACCEPT rules at the end + if action == firewall.ActionDrop { + // TODO: Insert after the established rule + rule = r.conn.InsertRule(rule) + } else { + rule = r.conn.AddRule(rule) + } log.Tracef("Adding route rule %s", spew.Sdump(rule)) if err := r.conn.Flush(); err != nil { diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index cc07922559d..03f23f5e622 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -3,6 +3,11 @@ package uspfilter import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -17,17 +22,29 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error { if m.udpTracker != nil { m.udpTracker.Close() - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger) } if m.icmpTracker != nil { m.icmpTracker.Close() - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger) } if m.tcpTracker != nil { m.tcpTracker.Close() - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) + } + + if m.forwarder != nil { + m.forwarder.Stop() + } + + if m.logger != nil { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := m.logger.Stop(ctx); err != nil { + log.Errorf("failed to shutdown logger: %v", err) + } } if m.nativeFirewall != nil { diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index 0d55d62689c..37958597826 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -1,9 +1,11 @@ package uspfilter import ( + "context" "fmt" "os/exec" "syscall" + "time" log "github.com/sirupsen/logrus" @@ -29,17 +31,29 @@ func (m *Manager) Reset(*statemanager.Manager) error { if m.udpTracker != nil { m.udpTracker.Close() - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger) } if m.icmpTracker != nil { m.icmpTracker.Close() - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger) } if m.tcpTracker != nil { m.tcpTracker.Close() - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) + } + + if m.forwarder != nil { + m.forwarder.Stop() + } + + if m.logger != nil { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := m.logger.Stop(ctx); err != nil { + log.Errorf("failed to shutdown logger: %v", err) + } } if !isWindowsFirewallReachable() { diff --git a/client/firewall/uspfilter/common/iface.go b/client/firewall/uspfilter/common/iface.go new file mode 100644 index 00000000000..d44e7950936 --- /dev/null +++ b/client/firewall/uspfilter/common/iface.go @@ -0,0 +1,16 @@ +package common + +import ( + wgdevice "golang.zx2c4.com/wireguard/device" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" +) + +// IFaceMapper defines subset methods of interface required for manager +type IFaceMapper interface { + SetFilter(device.PacketFilter) error + Address() iface.WGAddress + GetWGDevice() *wgdevice.Device + GetDevice() *device.FilteredDevice +} diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index e459bc75ae1..f5f5025400f 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -10,12 +10,11 @@ import ( // BaseConnTrack provides common fields and locking for all connection types type BaseConnTrack struct { - SourceIP net.IP - DestIP net.IP - SourcePort uint16 - DestPort uint16 - lastSeen atomic.Int64 // Unix nano for atomic access - established atomic.Bool + SourceIP net.IP + DestIP net.IP + SourcePort uint16 + DestPort uint16 + lastSeen atomic.Int64 // Unix nano for atomic access } // these small methods will be inlined by the compiler @@ -25,16 +24,6 @@ func (b *BaseConnTrack) UpdateLastSeen() { b.lastSeen.Store(time.Now().UnixNano()) } -// IsEstablished safely checks if connection is established -func (b *BaseConnTrack) IsEstablished() bool { - return b.established.Load() -} - -// SetEstablished safely sets the established state -func (b *BaseConnTrack) SetEstablished(state bool) { - b.established.Store(state) -} - // GetLastSeen safely gets the last seen timestamp func (b *BaseConnTrack) GetLastSeen() time.Time { return time.Unix(0, b.lastSeen.Load()) diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go index 72d006def57..81fa64b19d7 100644 --- a/client/firewall/uspfilter/conntrack/common_test.go +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -3,8 +3,14 @@ package conntrack import ( "net" "testing" + + "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) +var logger = log.NewFromLogrus(logrus.StandardLogger()) + func BenchmarkIPOperations(b *testing.B) { b.Run("MakeIPAddr", func(b *testing.B) { ip := net.ParseIP("192.168.1.1") @@ -34,37 +40,11 @@ func BenchmarkIPOperations(b *testing.B) { }) } -func BenchmarkAtomicOperations(b *testing.B) { - conn := &BaseConnTrack{} - b.Run("UpdateLastSeen", func(b *testing.B) { - for i := 0; i < b.N; i++ { - conn.UpdateLastSeen() - } - }) - - b.Run("IsEstablished", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = conn.IsEstablished() - } - }) - - b.Run("SetEstablished", func(b *testing.B) { - for i := 0; i < b.N; i++ { - conn.SetEstablished(i%2 == 0) - } - }) - - b.Run("GetLastSeen", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = conn.GetLastSeen() - } - }) -} // Memory pressure tests func BenchmarkMemoryPressure(b *testing.B) { b.Run("TCPHighLoad", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, logger) defer tracker.Close() // Generate different IPs @@ -89,7 +69,7 @@ func BenchmarkMemoryPressure(b *testing.B) { }) b.Run("UDPHighLoad", func(b *testing.B) { - tracker := NewUDPTracker(DefaultUDPTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout, logger) defer tracker.Close() // Generate different IPs diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index e0a971678f1..25cd9e87d72 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -6,6 +6,8 @@ import ( "time" "github.com/google/gopacket/layers" + + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) const ( @@ -33,6 +35,7 @@ type ICMPConnTrack struct { // ICMPTracker manages ICMP connection states type ICMPTracker struct { + logger *nblog.Logger connections map[ICMPConnKey]*ICMPConnTrack timeout time.Duration cleanupTicker *time.Ticker @@ -42,12 +45,13 @@ type ICMPTracker struct { } // NewICMPTracker creates a new ICMP connection tracker -func NewICMPTracker(timeout time.Duration) *ICMPTracker { +func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker { if timeout == 0 { timeout = DefaultICMPTimeout } tracker := &ICMPTracker{ + logger: logger, connections: make(map[ICMPConnKey]*ICMPConnTrack), timeout: timeout, cleanupTicker: time.NewTicker(ICMPCleanupInterval), @@ -62,7 +66,6 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker { // TrackOutbound records an outbound ICMP Echo Request func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { key := makeICMPKey(srcIP, dstIP, id, seq) - now := time.Now().UnixNano() t.mutex.Lock() conn, exists := t.connections[key] @@ -80,24 +83,19 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u ID: id, Sequence: seq, } - conn.lastSeen.Store(now) - conn.established.Store(true) + conn.UpdateLastSeen() t.connections[key] = conn + + t.logger.Trace("New ICMP connection %v", key) } t.mutex.Unlock() - conn.lastSeen.Store(now) + conn.UpdateLastSeen() } // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool { - switch icmpType { - case uint8(layers.ICMPv4TypeDestinationUnreachable), - uint8(layers.ICMPv4TypeTimeExceeded): - return true - case uint8(layers.ICMPv4TypeEchoReply): - // continue processing - default: + if icmpType != uint8(layers.ICMPv4TypeEchoReply) { return false } @@ -115,8 +113,7 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq return false } - return conn.IsEstablished() && - ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && + return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && conn.ID == id && conn.Sequence == seq @@ -141,6 +138,8 @@ func (t *ICMPTracker) cleanup() { t.ipPool.Put(conn.SourceIP) t.ipPool.Put(conn.DestIP) delete(t.connections, key) + + t.logger.Debug("Removed ICMP connection %v (timeout)", key) } } } diff --git a/client/firewall/uspfilter/conntrack/icmp_test.go b/client/firewall/uspfilter/conntrack/icmp_test.go index 21176e719d4..32553c8360f 100644 --- a/client/firewall/uspfilter/conntrack/icmp_test.go +++ b/client/firewall/uspfilter/conntrack/icmp_test.go @@ -7,7 +7,7 @@ import ( func BenchmarkICMPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { - tracker := NewICMPTracker(DefaultICMPTimeout) + tracker := NewICMPTracker(DefaultICMPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) { }) b.Run("IsValidInbound", func(b *testing.B) { - tracker := NewICMPTracker(DefaultICMPTimeout) + tracker := NewICMPTracker(DefaultICMPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index a7968dc7375..7c12e8ad01f 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -5,7 +5,10 @@ package conntrack import ( "net" "sync" + "sync/atomic" "time" + + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) const ( @@ -61,12 +64,24 @@ type TCPConnKey struct { // TCPConnTrack represents a TCP connection state type TCPConnTrack struct { BaseConnTrack - State TCPState + State TCPState + established atomic.Bool sync.RWMutex } +// IsEstablished safely checks if connection is established +func (t *TCPConnTrack) IsEstablished() bool { + return t.established.Load() +} + +// SetEstablished safely sets the established state +func (t *TCPConnTrack) SetEstablished(state bool) { + t.established.Store(state) +} + // TCPTracker manages TCP connection states type TCPTracker struct { + logger *nblog.Logger connections map[ConnKey]*TCPConnTrack mutex sync.RWMutex cleanupTicker *time.Ticker @@ -76,8 +91,9 @@ type TCPTracker struct { } // NewTCPTracker creates a new TCP connection tracker -func NewTCPTracker(timeout time.Duration) *TCPTracker { +func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker { tracker := &TCPTracker{ + logger: logger, connections: make(map[ConnKey]*TCPConnTrack), cleanupTicker: time.NewTicker(TCPCleanupInterval), done: make(chan struct{}), @@ -93,7 +109,6 @@ func NewTCPTracker(timeout time.Duration) *TCPTracker { func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { // Create key before lock key := makeConnKey(srcIP, dstIP, srcPort, dstPort) - now := time.Now().UnixNano() t.mutex.Lock() conn, exists := t.connections[key] @@ -113,9 +128,11 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d }, State: TCPStateNew, } - conn.lastSeen.Store(now) + conn.UpdateLastSeen() conn.established.Store(false) t.connections[key] = conn + + t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort) } t.mutex.Unlock() @@ -123,7 +140,7 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d conn.Lock() t.updateState(conn, flags, true) conn.Unlock() - conn.lastSeen.Store(now) + conn.UpdateLastSeen() } // IsValidInbound checks if an inbound TCP packet matches a tracked connection @@ -171,6 +188,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo if flags&TCPRst != 0 { conn.State = TCPStateClosed conn.SetEstablished(false) + + t.logger.Trace("TCP connection reset: %s:%d -> %s:%d", + conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) return } @@ -227,6 +247,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo if flags&TCPAck != 0 { conn.State = TCPStateTimeWait // Keep established = false from previous state + + t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d", + conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) } case TCPStateCloseWait: @@ -237,11 +260,17 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo case TCPStateLastAck: if flags&TCPAck != 0 { conn.State = TCPStateClosed + + t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d", + conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) } case TCPStateTimeWait: // Stay in TIME-WAIT for 2MSL before transitioning to closed // This is handled by the cleanup routine + + t.logger.Trace("TCP connection completed - %s:%d -> %s:%d", + conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) } } @@ -318,6 +347,8 @@ func (t *TCPTracker) cleanup() { t.ipPool.Put(conn.SourceIP) t.ipPool.Put(conn.DestIP) delete(t.connections, key) + + t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) } } } diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index 6c8f82423bd..5f4c43915fb 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -9,7 +9,7 @@ import ( ) func TestTCPStateMachine(t *testing.T) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("100.64.0.1") @@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Helper() - tracker = NewTCPTracker(DefaultTCPTimeout) + tracker = NewTCPTracker(DefaultTCPTimeout, logger) tt.test(t) }) } @@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) { } func TestRSTHandling(t *testing.T) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("100.64.0.1") @@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, func BenchmarkTCPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) { }) b.Run("IsValidInbound", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) { }) b.Run("ConcurrentAccess", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) { // Benchmark connection cleanup func BenchmarkCleanup(b *testing.B) { b.Run("TCPCleanup", func(b *testing.B) { - tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing + tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing defer tracker.Close() // Pre-populate with expired connections diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index a969a4e8425..e73465e3195 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -4,6 +4,8 @@ import ( "net" "sync" "time" + + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) const ( @@ -20,6 +22,7 @@ type UDPConnTrack struct { // UDPTracker manages UDP connection states type UDPTracker struct { + logger *nblog.Logger connections map[ConnKey]*UDPConnTrack timeout time.Duration cleanupTicker *time.Ticker @@ -29,12 +32,13 @@ type UDPTracker struct { } // NewUDPTracker creates a new UDP connection tracker -func NewUDPTracker(timeout time.Duration) *UDPTracker { +func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker { if timeout == 0 { timeout = DefaultUDPTimeout } tracker := &UDPTracker{ + logger: logger, connections: make(map[ConnKey]*UDPConnTrack), timeout: timeout, cleanupTicker: time.NewTicker(UDPCleanupInterval), @@ -49,7 +53,6 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker { // TrackOutbound records an outbound UDP connection func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { key := makeConnKey(srcIP, dstIP, srcPort, dstPort) - now := time.Now().UnixNano() t.mutex.Lock() conn, exists := t.connections[key] @@ -67,13 +70,14 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d DestPort: dstPort, }, } - conn.lastSeen.Store(now) - conn.established.Store(true) + conn.UpdateLastSeen() t.connections[key] = conn + + t.logger.Trace("New UDP connection: %v", conn) } t.mutex.Unlock() - conn.lastSeen.Store(now) + conn.UpdateLastSeen() } // IsValidInbound checks if an inbound packet matches a tracked connection @@ -92,8 +96,7 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, return false } - return conn.IsEstablished() && - ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && + return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && conn.DestPort == srcPort && conn.SourcePort == dstPort @@ -120,6 +123,8 @@ func (t *UDPTracker) cleanup() { t.ipPool.Put(conn.SourceIP) t.ipPool.Put(conn.DestIP) delete(t.connections, key) + + t.logger.Trace("Removed UDP connection %v (timeout)", conn) } } } diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go index 67172189069..fa83ee356a3 100644 --- a/client/firewall/uspfilter/conntrack/udp_test.go +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tracker := NewUDPTracker(tt.timeout) + tracker := NewUDPTracker(tt.timeout, logger) assert.NotNil(t, tracker) assert.Equal(t, tt.wantTimeout, tracker.timeout) assert.NotNil(t, tracker.connections) @@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) { } func TestUDPTracker_TrackOutbound(t *testing.T) { - tracker := NewUDPTracker(DefaultUDPTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.2") @@ -58,12 +58,11 @@ func TestUDPTracker_TrackOutbound(t *testing.T) { assert.True(t, conn.DestIP.Equal(dstIP)) assert.Equal(t, srcPort, conn.SourcePort) assert.Equal(t, dstPort, conn.DestPort) - assert.True(t, conn.IsEstablished()) assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second) } func TestUDPTracker_IsValidInbound(t *testing.T) { - tracker := NewUDPTracker(1 * time.Second) + tracker := NewUDPTracker(1*time.Second, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.2") @@ -162,6 +161,7 @@ func TestUDPTracker_Cleanup(t *testing.T) { cleanupTicker: time.NewTicker(cleanupInterval), done: make(chan struct{}), ipPool: NewPreallocatedIPs(), + logger: logger, } // Start cleanup routine @@ -211,7 +211,7 @@ func TestUDPTracker_Cleanup(t *testing.T) { func BenchmarkUDPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { - tracker := NewUDPTracker(DefaultUDPTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -224,7 +224,7 @@ func BenchmarkUDPTracker(b *testing.B) { }) b.Run("IsValidInbound", func(b *testing.B) { - tracker := NewUDPTracker(DefaultUDPTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") diff --git a/client/firewall/uspfilter/forwarder/endpoint.go b/client/firewall/uspfilter/forwarder/endpoint.go new file mode 100644 index 00000000000..e8a265c94d5 --- /dev/null +++ b/client/firewall/uspfilter/forwarder/endpoint.go @@ -0,0 +1,81 @@ +package forwarder + +import ( + wgdevice "golang.zx2c4.com/wireguard/device" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" +) + +// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device +type endpoint struct { + logger *nblog.Logger + dispatcher stack.NetworkDispatcher + device *wgdevice.Device + mtu uint32 +} + +func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher +} + +func (e *endpoint) IsAttached() bool { + return e.dispatcher != nil +} + +func (e *endpoint) MTU() uint32 { + return e.mtu +} + +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return stack.CapabilityNone +} + +func (e *endpoint) MaxHeaderLength() uint16 { + return 0 +} + +func (e *endpoint) LinkAddress() tcpip.LinkAddress { + return "" +} + +func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { + var written int + for _, pkt := range pkts.AsSlice() { + netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice()) + + data := stack.PayloadSince(pkt.NetworkHeader()) + if data == nil { + continue + } + + // Send the packet through WireGuard + address := netHeader.DestinationAddress() + err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice()) + if err != nil { + e.logger.Error("CreateOutboundPacket: %v", err) + continue + } + written++ + } + + return written, nil +} + +func (e *endpoint) Wait() { + // not required +} + +func (e *endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareNone +} + +func (e *endpoint) AddHeader(*stack.PacketBuffer) { + // not required +} + +func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool { + return true +} diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go new file mode 100644 index 00000000000..4ed152b79c9 --- /dev/null +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -0,0 +1,166 @@ +package forwarder + +import ( + "context" + "fmt" + "net" + "runtime" + + log "github.com/sirupsen/logrus" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + + "github.com/netbirdio/netbird/client/firewall/uspfilter/common" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" +) + +const ( + defaultReceiveWindow = 32768 + defaultMaxInFlight = 1024 + iosReceiveWindow = 16384 + iosMaxInFlight = 256 +) + +type Forwarder struct { + logger *nblog.Logger + stack *stack.Stack + endpoint *endpoint + udpForwarder *udpForwarder + ctx context.Context + cancel context.CancelFunc + ip net.IP + netstack bool +} + +func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{ + tcp.NewProtocol, + udp.NewProtocol, + icmp.NewProtocol4, + }, + HandleLocal: false, + }) + + mtu, err := iface.GetDevice().MTU() + if err != nil { + return nil, fmt.Errorf("get MTU: %w", err) + } + nicID := tcpip.NICID(1) + endpoint := &endpoint{ + logger: logger, + device: iface.GetWGDevice(), + mtu: uint32(mtu), + } + + if err := s.CreateNIC(nicID, endpoint); err != nil { + return nil, fmt.Errorf("failed to create NIC: %v", err) + } + + ones, _ := iface.Address().Network.Mask.Size() + protoAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFromSlice(iface.Address().IP.To4()), + PrefixLen: ones, + }, + } + + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + return nil, fmt.Errorf("failed to add protocol address: %s", err) + } + + defaultSubnet, err := tcpip.NewSubnet( + tcpip.AddrFrom4([4]byte{0, 0, 0, 0}), + tcpip.MaskFromBytes([]byte{0, 0, 0, 0}), + ) + if err != nil { + return nil, fmt.Errorf("creating default subnet: %w", err) + } + + if err := s.SetPromiscuousMode(nicID, true); err != nil { + return nil, fmt.Errorf("set promiscuous mode: %s", err) + } + if err := s.SetSpoofing(nicID, true); err != nil { + return nil, fmt.Errorf("set spoofing: %s", err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: defaultSubnet, + NIC: nicID, + }, + }) + + ctx, cancel := context.WithCancel(context.Background()) + f := &Forwarder{ + logger: logger, + stack: s, + endpoint: endpoint, + udpForwarder: newUDPForwarder(mtu, logger), + ctx: ctx, + cancel: cancel, + netstack: netstack, + ip: iface.Address().IP, + } + + receiveWindow := defaultReceiveWindow + maxInFlight := defaultMaxInFlight + if runtime.GOOS == "ios" { + receiveWindow = iosReceiveWindow + maxInFlight = iosMaxInFlight + } + + tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP) + s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) + + udpForwarder := udp.NewForwarder(s, f.handleUDP) + s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) + + s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP) + + log.Debugf("forwarder: Initialization complete with NIC %d", nicID) + return f, nil +} + +func (f *Forwarder) InjectIncomingPacket(payload []byte) error { + if len(payload) < header.IPv4MinimumSize { + return fmt.Errorf("packet too small: %d bytes", len(payload)) + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(payload), + }) + defer pkt.DecRef() + + if f.endpoint.dispatcher != nil { + f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt) + } + return nil +} + +// Stop gracefully shuts down the forwarder +func (f *Forwarder) Stop() { + f.cancel() + + if f.udpForwarder != nil { + f.udpForwarder.Stop() + } + + f.stack.Close() + f.stack.Wait() +} + +func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { + if f.netstack && f.ip.Equal(addr.AsSlice()) { + return net.IPv4(127, 0, 0, 1) + } + return addr.AsSlice() +} diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go new file mode 100644 index 00000000000..14cdc37be85 --- /dev/null +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -0,0 +1,109 @@ +package forwarder + +import ( + "context" + "net" + "time" + + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// handleICMP handles ICMP packets from the network stack +func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool { + ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) + defer cancel() + + lc := net.ListenConfig{} + // TODO: support non-root + conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") + if err != nil { + f.logger.Error("Failed to create ICMP socket for %v: %v", id, err) + + // This will make netstack reply on behalf of the original destination, that's ok for now + return false + } + defer func() { + if err := conn.Close(); err != nil { + f.logger.Debug("Failed to close ICMP socket: %v", err) + } + }() + + dstIP := f.determineDialAddr(id.LocalAddress) + dst := &net.IPAddr{IP: dstIP} + + // Get the complete ICMP message (header + data) + fullPacket := stack.PayloadSince(pkt.TransportHeader()) + payload := fullPacket.AsSlice() + + icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice()) + + // For Echo Requests, send and handle response + switch icmpHdr.Type() { + case header.ICMPv4Echo: + return f.handleEchoResponse(icmpHdr, payload, dst, conn, id) + case header.ICMPv4EchoReply: + // dont process our own replies + return true + default: + } + + // For other ICMP types (Time Exceeded, Destination Unreachable, etc) + _, err = conn.WriteTo(payload, dst) + if err != nil { + f.logger.Error("Failed to write ICMP packet for %v: %v", id, err) + return true + } + + f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v", + id, icmpHdr.Type(), icmpHdr.Code()) + + return true +} + +func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool { + if _, err := conn.WriteTo(payload, dst); err != nil { + f.logger.Error("Failed to write ICMP packet for %v: %v", id, err) + return true + } + + f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v", + id, icmpHdr.Type(), icmpHdr.Code()) + + if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + f.logger.Error("Failed to set read deadline for ICMP response: %v", err) + return true + } + + response := make([]byte, f.endpoint.mtu) + n, _, err := conn.ReadFrom(response) + if err != nil { + if !isTimeout(err) { + f.logger.Error("Failed to read ICMP response: %v", err) + } + return true + } + + ipHdr := make([]byte, header.IPv4MinimumSize) + ip := header.IPv4(ipHdr) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(header.IPv4MinimumSize + n), + TTL: 64, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: id.LocalAddress, + DstAddr: id.RemoteAddress, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + fullPacket := make([]byte, 0, len(ipHdr)+n) + fullPacket = append(fullPacket, ipHdr...) + fullPacket = append(fullPacket, response[:n]...) + + if err := f.InjectIncomingPacket(fullPacket); err != nil { + f.logger.Error("Failed to inject ICMP response: %v", err) + return true + } + + f.logger.Trace("Forwarded ICMP echo reply for %v", id) + return true +} diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go new file mode 100644 index 00000000000..6d7cf3b6a70 --- /dev/null +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -0,0 +1,90 @@ +package forwarder + +import ( + "context" + "fmt" + "io" + "net" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/waiter" +) + +// handleTCP is called by the TCP forwarder for new connections. +func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { + id := r.ID() + + dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) + + outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) + if err != nil { + r.Complete(true) + f.logger.Trace("forwarder: dial error for %v: %v", id, err) + return + } + + // Create wait queue for blocking syscalls + wq := waiter.Queue{} + + ep, epErr := r.CreateEndpoint(&wq) + if epErr != nil { + f.logger.Error("forwarder: failed to create TCP endpoint: %v", epErr) + if err := outConn.Close(); err != nil { + f.logger.Debug("forwarder: outConn close error: %v", err) + } + r.Complete(true) + return + } + + // Complete the handshake + r.Complete(false) + + inConn := gonet.NewTCPConn(&wq, ep) + + f.logger.Trace("forwarder: established TCP connection %v", id) + + go f.proxyTCP(id, inConn, outConn, ep) +} + +func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) { + defer func() { + if err := inConn.Close(); err != nil { + f.logger.Debug("forwarder: inConn close error: %v", err) + } + if err := outConn.Close(); err != nil { + f.logger.Debug("forwarder: outConn close error: %v", err) + } + ep.Close() + }() + + // Create context for managing the proxy goroutines + ctx, cancel := context.WithCancel(f.ctx) + defer cancel() + + errChan := make(chan error, 2) + + go func() { + _, err := io.Copy(outConn, inConn) + errChan <- err + }() + + go func() { + _, err := io.Copy(inConn, outConn) + errChan <- err + }() + + select { + case <-ctx.Done(): + f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id) + return + case err := <-errChan: + if err != nil && !isClosedError(err) { + f.logger.Error("proxyTCP: copy error: %v", err) + } + f.logger.Trace("forwarder: tearing down TCP connection %v", id) + return + } +} diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go new file mode 100644 index 00000000000..97e4662fd39 --- /dev/null +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -0,0 +1,288 @@ +package forwarder + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" + + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" +) + +const ( + udpTimeout = 30 * time.Second +) + +type udpPacketConn struct { + conn *gonet.UDPConn + outConn net.Conn + lastSeen atomic.Int64 + cancel context.CancelFunc + ep tcpip.Endpoint +} + +type udpForwarder struct { + sync.RWMutex + logger *nblog.Logger + conns map[stack.TransportEndpointID]*udpPacketConn + bufPool sync.Pool + ctx context.Context + cancel context.CancelFunc +} + +type idleConn struct { + id stack.TransportEndpointID + conn *udpPacketConn +} + +func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder { + ctx, cancel := context.WithCancel(context.Background()) + f := &udpForwarder{ + logger: logger, + conns: make(map[stack.TransportEndpointID]*udpPacketConn), + ctx: ctx, + cancel: cancel, + bufPool: sync.Pool{ + New: func() any { + b := make([]byte, mtu) + return &b + }, + }, + } + go f.cleanup() + return f +} + +// Stop stops the UDP forwarder and all active connections +func (f *udpForwarder) Stop() { + f.cancel() + + f.Lock() + defer f.Unlock() + + for id, conn := range f.conns { + conn.cancel() + if err := conn.conn.Close(); err != nil { + f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err) + } + if err := conn.outConn.Close(); err != nil { + f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) + } + + conn.ep.Close() + delete(f.conns, id) + } +} + +// cleanup periodically removes idle UDP connections +func (f *udpForwarder) cleanup() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + + for { + select { + case <-f.ctx.Done(): + return + case <-ticker.C: + var idleConns []idleConn + + f.RLock() + for id, conn := range f.conns { + if conn.getIdleDuration() > udpTimeout { + idleConns = append(idleConns, idleConn{id, conn}) + } + } + f.RUnlock() + + for _, idle := range idleConns { + idle.conn.cancel() + if err := idle.conn.conn.Close(); err != nil { + f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err) + } + if err := idle.conn.outConn.Close(); err != nil { + f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err) + } + + idle.conn.ep.Close() + + f.Lock() + delete(f.conns, idle.id) + f.Unlock() + + f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id) + } + } + } +} + +// handleUDP is called by the UDP forwarder for new packets +func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { + if f.ctx.Err() != nil { + f.logger.Trace("forwarder: context done, dropping UDP packet") + return + } + + id := r.ID() + + f.udpForwarder.RLock() + _, exists := f.udpForwarder.conns[id] + f.udpForwarder.RUnlock() + if exists { + f.logger.Trace("forwarder: existing UDP connection for %v", id) + return + } + + dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) + outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) + if err != nil { + f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err) + // TODO: Send ICMP error message + return + } + + // Create wait queue for blocking syscalls + wq := waiter.Queue{} + ep, epErr := r.CreateEndpoint(&wq) + if epErr != nil { + f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr) + if err := outConn.Close(); err != nil { + f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) + } + return + } + + inConn := gonet.NewUDPConn(f.stack, &wq, ep) + connCtx, connCancel := context.WithCancel(f.ctx) + + pConn := &udpPacketConn{ + conn: inConn, + outConn: outConn, + cancel: connCancel, + ep: ep, + } + pConn.updateLastSeen() + + f.udpForwarder.Lock() + // Double-check no connection was created while we were setting up + if _, exists := f.udpForwarder.conns[id]; exists { + f.udpForwarder.Unlock() + pConn.cancel() + if err := inConn.Close(); err != nil { + f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err) + } + if err := outConn.Close(); err != nil { + f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) + } + return + } + f.udpForwarder.conns[id] = pConn + f.udpForwarder.Unlock() + + f.logger.Trace("forwarder: established UDP connection to %v", id) + go f.proxyUDP(connCtx, pConn, id, ep) +} + +func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) { + defer func() { + pConn.cancel() + if err := pConn.conn.Close(); err != nil { + f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err) + } + if err := pConn.outConn.Close(); err != nil { + f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) + } + + ep.Close() + + f.udpForwarder.Lock() + delete(f.udpForwarder.conns, id) + f.udpForwarder.Unlock() + }() + + errChan := make(chan error, 2) + + go func() { + errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound") + }() + + go func() { + errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound") + }() + + select { + case <-ctx.Done(): + f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id) + return + case err := <-errChan: + if err != nil && !isClosedError(err) { + f.logger.Error("proxyUDP: copy error: %v", err) + } + f.logger.Trace("forwarder: tearing down UDP connection %v", id) + return + } +} + +func (c *udpPacketConn) updateLastSeen() { + c.lastSeen.Store(time.Now().UnixNano()) +} + +func (c *udpPacketConn) getIdleDuration() time.Duration { + lastSeen := time.Unix(0, c.lastSeen.Load()) + return time.Since(lastSeen) +} + +func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error { + bufp := bufPool.Get().(*[]byte) + defer bufPool.Put(bufp) + buffer := *bufp + + if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil { + return fmt.Errorf("set read deadline: %w", err) + } + if err := src.SetWriteDeadline(time.Now().Add(udpTimeout)); err != nil { + return fmt.Errorf("set write deadline: %w", err) + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + n, err := src.Read(buffer) + if err != nil { + if isTimeout(err) { + continue + } + return fmt.Errorf("read from %s: %w", direction, err) + } + + _, err = dst.Write(buffer[:n]) + if err != nil { + return fmt.Errorf("write to %s: %w", direction, err) + } + + c.updateLastSeen() + } + } +} + +func isClosedError(err error) bool { + return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) +} + +func isTimeout(err error) bool { + var netErr net.Error + if errors.As(err, &netErr) { + return netErr.Timeout() + } + return false +} diff --git a/client/firewall/uspfilter/localip.go b/client/firewall/uspfilter/localip.go new file mode 100644 index 00000000000..7664b65d554 --- /dev/null +++ b/client/firewall/uspfilter/localip.go @@ -0,0 +1,134 @@ +package uspfilter + +import ( + "fmt" + "net" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/firewall/uspfilter/common" +) + +type localIPManager struct { + mu sync.RWMutex + + // Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory) + ipv4Bitmap [1 << 16]uint32 +} + +func newLocalIPManager() *localIPManager { + return &localIPManager{} +} + +func (m *localIPManager) setBitmapBit(ip net.IP) { + ipv4 := ip.To4() + if ipv4 == nil { + return + } + high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1]) + low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3]) + m.ipv4Bitmap[high] |= 1 << (low % 32) +} + +func (m *localIPManager) checkBitmapBit(ip net.IP) bool { + ipv4 := ip.To4() + if ipv4 == nil { + return false + } + high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1]) + low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3]) + return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0 +} + +func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error { + if ipv4 := ip.To4(); ipv4 != nil { + high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1]) + low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3]) + if int(high) >= len(*newIPv4Bitmap) { + return fmt.Errorf("invalid IPv4 address: %s", ip) + } + ipStr := ip.String() + if _, exists := ipv4Set[ipStr]; !exists { + ipv4Set[ipStr] = struct{}{} + *ipv4Addresses = append(*ipv4Addresses, ipStr) + newIPv4Bitmap[high] |= 1 << (low % 32) + } + } + return nil +} + +func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) { + addrs, err := iface.Addrs() + if err != nil { + log.Debugf("get addresses for interface %s failed: %v", iface.Name, err) + return + } + + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + default: + continue + } + + if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil { + log.Debugf("process IP failed: %v", err) + } + } +} + +func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic: %v", r) + } + }() + + var newIPv4Bitmap [1 << 16]uint32 + ipv4Set := make(map[string]struct{}) + var ipv4Addresses []string + + // 127.0.0.0/8 + high := uint16(127) << 8 + for i := uint16(0); i < 256; i++ { + newIPv4Bitmap[high|i] = 0xffffffff + } + + if iface != nil { + if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil { + return err + } + } + + interfaces, err := net.Interfaces() + if err != nil { + log.Warnf("failed to get interfaces: %v", err) + } else { + for _, intf := range interfaces { + m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses) + } + } + + m.mu.Lock() + m.ipv4Bitmap = newIPv4Bitmap + m.mu.Unlock() + + log.Debugf("Local IPv4 addresses: %v", ipv4Addresses) + return nil +} + +func (m *localIPManager) IsLocalIP(ip net.IP) bool { + m.mu.RLock() + defer m.mu.RUnlock() + + if ipv4 := ip.To4(); ipv4 != nil { + return m.checkBitmapBit(ipv4) + } + + return false +} diff --git a/client/firewall/uspfilter/localip_test.go b/client/firewall/uspfilter/localip_test.go new file mode 100644 index 00000000000..02f41bf4f61 --- /dev/null +++ b/client/firewall/uspfilter/localip_test.go @@ -0,0 +1,270 @@ +package uspfilter + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/iface" +) + +func TestLocalIPManager(t *testing.T) { + tests := []struct { + name string + setupAddr iface.WGAddress + testIP net.IP + expected bool + }{ + { + name: "Localhost range", + setupAddr: iface.WGAddress{ + IP: net.ParseIP("192.168.1.1"), + Network: &net.IPNet{ + IP: net.ParseIP("192.168.1.0"), + Mask: net.CIDRMask(24, 32), + }, + }, + testIP: net.ParseIP("127.0.0.2"), + expected: true, + }, + { + name: "Localhost standard address", + setupAddr: iface.WGAddress{ + IP: net.ParseIP("192.168.1.1"), + Network: &net.IPNet{ + IP: net.ParseIP("192.168.1.0"), + Mask: net.CIDRMask(24, 32), + }, + }, + testIP: net.ParseIP("127.0.0.1"), + expected: true, + }, + { + name: "Localhost range edge", + setupAddr: iface.WGAddress{ + IP: net.ParseIP("192.168.1.1"), + Network: &net.IPNet{ + IP: net.ParseIP("192.168.1.0"), + Mask: net.CIDRMask(24, 32), + }, + }, + testIP: net.ParseIP("127.255.255.255"), + expected: true, + }, + { + name: "Local IP matches", + setupAddr: iface.WGAddress{ + IP: net.ParseIP("192.168.1.1"), + Network: &net.IPNet{ + IP: net.ParseIP("192.168.1.0"), + Mask: net.CIDRMask(24, 32), + }, + }, + testIP: net.ParseIP("192.168.1.1"), + expected: true, + }, + { + name: "Local IP doesn't match", + setupAddr: iface.WGAddress{ + IP: net.ParseIP("192.168.1.1"), + Network: &net.IPNet{ + IP: net.ParseIP("192.168.1.0"), + Mask: net.CIDRMask(24, 32), + }, + }, + testIP: net.ParseIP("192.168.1.2"), + expected: false, + }, + { + name: "IPv6 address", + setupAddr: iface.WGAddress{ + IP: net.ParseIP("fe80::1"), + Network: &net.IPNet{ + IP: net.ParseIP("fe80::"), + Mask: net.CIDRMask(64, 128), + }, + }, + testIP: net.ParseIP("fe80::1"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager := newLocalIPManager() + + mock := &IFaceMock{ + AddressFunc: func() iface.WGAddress { + return tt.setupAddr + }, + } + + err := manager.UpdateLocalIPs(mock) + require.NoError(t, err) + + result := manager.IsLocalIP(tt.testIP) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestLocalIPManager_AllInterfaces(t *testing.T) { + manager := newLocalIPManager() + mock := &IFaceMock{} + + // Get actual local interfaces + interfaces, err := net.Interfaces() + require.NoError(t, err) + + var tests []struct { + ip string + expected bool + } + + // Add all local interface IPs to test cases + for _, iface := range interfaces { + addrs, err := iface.Addrs() + require.NoError(t, err) + + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + default: + continue + } + + if ip4 := ip.To4(); ip4 != nil { + tests = append(tests, struct { + ip string + expected bool + }{ + ip: ip4.String(), + expected: true, + }) + } + } + } + + // Add some external IPs as negative test cases + externalIPs := []string{ + "8.8.8.8", + "1.1.1.1", + "208.67.222.222", + } + for _, ip := range externalIPs { + tests = append(tests, struct { + ip string + expected bool + }{ + ip: ip, + expected: false, + }) + } + + require.NotEmpty(t, tests, "No test cases generated") + + err = manager.UpdateLocalIPs(mock) + require.NoError(t, err) + + t.Logf("Testing %d IPs", len(tests)) + for _, tt := range tests { + t.Run(tt.ip, func(t *testing.T) { + result := manager.IsLocalIP(net.ParseIP(tt.ip)) + require.Equal(t, tt.expected, result, "IP: %s", tt.ip) + }) + } +} + +// MapImplementation is a version using map[string]struct{} +type MapImplementation struct { + localIPs map[string]struct{} +} + +func BenchmarkIPChecks(b *testing.B) { + interfaces := make([]net.IP, 16) + for i := range interfaces { + interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i)) + } + + // Setup bitmap version + bitmapManager := &localIPManager{ + ipv4Bitmap: [1 << 16]uint32{}, + } + for _, ip := range interfaces[:8] { // Add half of IPs + bitmapManager.setBitmapBit(ip) + } + + // Setup map version + mapManager := &MapImplementation{ + localIPs: make(map[string]struct{}), + } + for _, ip := range interfaces[:8] { + mapManager.localIPs[ip.String()] = struct{}{} + } + + b.Run("Bitmap_Hit", func(b *testing.B) { + ip := interfaces[4] + b.ResetTimer() + for i := 0; i < b.N; i++ { + bitmapManager.checkBitmapBit(ip) + } + }) + + b.Run("Bitmap_Miss", func(b *testing.B) { + ip := interfaces[12] + b.ResetTimer() + for i := 0; i < b.N; i++ { + bitmapManager.checkBitmapBit(ip) + } + }) + + b.Run("Map_Hit", func(b *testing.B) { + ip := interfaces[4] + b.ResetTimer() + for i := 0; i < b.N; i++ { + // nolint:gosimple + _, _ = mapManager.localIPs[ip.String()] + } + }) + + b.Run("Map_Miss", func(b *testing.B) { + ip := interfaces[12] + b.ResetTimer() + for i := 0; i < b.N; i++ { + // nolint:gosimple + _, _ = mapManager.localIPs[ip.String()] + } + }) +} + +func BenchmarkWGPosition(b *testing.B) { + wgIP := net.ParseIP("10.10.0.1") + + // Create two managers - one checks WG IP first, other checks it last + b.Run("WG_First", func(b *testing.B) { + bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}} + bm.setBitmapBit(wgIP) + b.ResetTimer() + for i := 0; i < b.N; i++ { + bm.checkBitmapBit(wgIP) + } + }) + + b.Run("WG_Last", func(b *testing.B) { + bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}} + // Fill with other IPs first + for i := 0; i < 15; i++ { + bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i))) + } + bm.setBitmapBit(wgIP) // Add WG IP last + b.ResetTimer() + for i := 0; i < b.N; i++ { + bm.checkBitmapBit(wgIP) + } + }) +} diff --git a/client/firewall/uspfilter/log/log.go b/client/firewall/uspfilter/log/log.go new file mode 100644 index 00000000000..984b6ad08e1 --- /dev/null +++ b/client/firewall/uspfilter/log/log.go @@ -0,0 +1,196 @@ +// Package logger provides a high-performance, non-blocking logger for userspace networking +package log + +import ( + "context" + "fmt" + "io" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + maxBatchSize = 1024 * 16 // 16KB max batch size + maxMessageSize = 1024 * 2 // 2KB per message + bufferSize = 1024 * 256 // 256KB ring buffer + defaultFlushInterval = 2 * time.Second +) + +// Level represents log severity +type Level uint32 + +const ( + LevelPanic Level = iota + LevelFatal + LevelError + LevelWarn + LevelInfo + LevelDebug + LevelTrace +) + +var levelStrings = map[Level]string{ + LevelPanic: "PANC", + LevelFatal: "FATL", + LevelError: "ERRO", + LevelWarn: "WARN", + LevelInfo: "INFO", + LevelDebug: "DEBG", + LevelTrace: "TRAC", +} + +// Logger is a high-performance, non-blocking logger +type Logger struct { + output io.Writer + level atomic.Uint32 + buffer *ringBuffer + shutdown chan struct{} + closeOnce sync.Once + wg sync.WaitGroup + + // Reusable buffer pool for formatting messages + bufPool sync.Pool +} + +func NewFromLogrus(logrusLogger *log.Logger) *Logger { + l := &Logger{ + output: logrusLogger.Out, + buffer: newRingBuffer(bufferSize), + shutdown: make(chan struct{}), + bufPool: sync.Pool{ + New: func() interface{} { + // Pre-allocate buffer for message formatting + b := make([]byte, 0, maxMessageSize) + return &b + }, + }, + } + logrusLevel := logrusLogger.GetLevel() + l.level.Store(uint32(logrusLevel)) + level := levelStrings[Level(logrusLevel)] + log.Debugf("New uspfilter logger created with loglevel %v", level) + + l.wg.Add(1) + go l.worker() + + return l +} + +func (l *Logger) SetLevel(level Level) { + l.level.Store(uint32(level)) + + log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) +} + +func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) { + *buf = (*buf)[:0] + + // Timestamp + *buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00") + *buf = append(*buf, ' ') + + // Level + *buf = append(*buf, levelStrings[level]...) + *buf = append(*buf, ' ') + + // Message + if len(args) > 0 { + *buf = append(*buf, fmt.Sprintf(format, args...)...) + } else { + *buf = append(*buf, format...) + } + + *buf = append(*buf, '\n') +} + +func (l *Logger) log(level Level, format string, args ...interface{}) { + bufp := l.bufPool.Get().(*[]byte) + l.formatMessage(bufp, level, format, args...) + + if len(*bufp) > maxMessageSize { + *bufp = (*bufp)[:maxMessageSize] + } + _, _ = l.buffer.Write(*bufp) + + l.bufPool.Put(bufp) +} + +func (l *Logger) Error(format string, args ...interface{}) { + if l.level.Load() >= uint32(LevelError) { + l.log(LevelError, format, args...) + } +} + +func (l *Logger) Warn(format string, args ...interface{}) { + if l.level.Load() >= uint32(LevelWarn) { + l.log(LevelWarn, format, args...) + } +} + +func (l *Logger) Info(format string, args ...interface{}) { + if l.level.Load() >= uint32(LevelInfo) { + l.log(LevelInfo, format, args...) + } +} + +func (l *Logger) Debug(format string, args ...interface{}) { + if l.level.Load() >= uint32(LevelDebug) { + l.log(LevelDebug, format, args...) + } +} + +func (l *Logger) Trace(format string, args ...interface{}) { + if l.level.Load() >= uint32(LevelTrace) { + l.log(LevelTrace, format, args...) + } +} + +// worker periodically flushes the buffer +func (l *Logger) worker() { + defer l.wg.Done() + + ticker := time.NewTicker(defaultFlushInterval) + defer ticker.Stop() + + buf := make([]byte, 0, maxBatchSize) + + for { + select { + case <-l.shutdown: + return + case <-ticker.C: + // Read accumulated messages + n, _ := l.buffer.Read(buf[:cap(buf)]) + if n == 0 { + continue + } + + // Write batch + _, _ = l.output.Write(buf[:n]) + } + } +} + +// Stop gracefully shuts down the logger +func (l *Logger) Stop(ctx context.Context) error { + done := make(chan struct{}) + + l.closeOnce.Do(func() { + close(l.shutdown) + }) + + go func() { + l.wg.Wait() + close(done) + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-done: + return nil + } +} diff --git a/client/firewall/uspfilter/log/ringbuffer.go b/client/firewall/uspfilter/log/ringbuffer.go new file mode 100644 index 00000000000..dbc8f1289a7 --- /dev/null +++ b/client/firewall/uspfilter/log/ringbuffer.go @@ -0,0 +1,85 @@ +package log + +import "sync" + +// ringBuffer is a simple ring buffer implementation +type ringBuffer struct { + buf []byte + size int + r, w int64 // Read and write positions + mu sync.Mutex +} + +func newRingBuffer(size int) *ringBuffer { + return &ringBuffer{ + buf: make([]byte, size), + size: size, + } +} + +func (r *ringBuffer) Write(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + + r.mu.Lock() + defer r.mu.Unlock() + + if len(p) > r.size { + p = p[:r.size] + } + + n = len(p) + + // Write data, handling wrap-around + pos := int(r.w % int64(r.size)) + writeLen := min(len(p), r.size-pos) + copy(r.buf[pos:], p[:writeLen]) + + // If we have more data and need to wrap around + if writeLen < len(p) { + copy(r.buf, p[writeLen:]) + } + + // Update write position + r.w += int64(n) + + return n, nil +} + +func (r *ringBuffer) Read(p []byte) (n int, err error) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.w == r.r { + return 0, nil + } + + // Calculate available data accounting for wraparound + available := int(r.w - r.r) + if available < 0 { + available += r.size + } + available = min(available, r.size) + + // Limit read to buffer size + toRead := min(available, len(p)) + if toRead == 0 { + return 0, nil + } + + // Read data, handling wrap-around + pos := int(r.r % int64(r.size)) + readLen := min(toRead, r.size-pos) + n = copy(p, r.buf[pos:pos+readLen]) + + // If we need more data and need to wrap around + if readLen < toRead { + n += copy(p[readLen:toRead], r.buf[:toRead-readLen]) + } + + // Update read position + r.r += int64(n) + + return n, nil +} diff --git a/client/firewall/uspfilter/rule.go b/client/firewall/uspfilter/rule.go index c59d4b264ce..6a4415f7315 100644 --- a/client/firewall/uspfilter/rule.go +++ b/client/firewall/uspfilter/rule.go @@ -2,14 +2,15 @@ package uspfilter import ( "net" + "net/netip" "github.com/google/gopacket" firewall "github.com/netbirdio/netbird/client/firewall/manager" ) -// Rule to handle management of rules -type Rule struct { +// PeerRule to handle management of rules +type PeerRule struct { id string ip net.IP ipLayer gopacket.LayerType @@ -24,6 +25,21 @@ type Rule struct { } // GetRuleID returns the rule id -func (r *Rule) GetRuleID() string { +func (r *PeerRule) GetRuleID() string { + return r.id +} + +type RouteRule struct { + id string + sources []netip.Prefix + destination netip.Prefix + proto firewall.Protocol + srcPort *firewall.Port + dstPort *firewall.Port + action firewall.Action +} + +// GetRuleID returns the rule id +func (r *RouteRule) GetRuleID() string { return r.id } diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go new file mode 100644 index 00000000000..a4c653b3b4b --- /dev/null +++ b/client/firewall/uspfilter/tracer.go @@ -0,0 +1,390 @@ +package uspfilter + +import ( + "fmt" + "net" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + + fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" +) + +type PacketStage int + +const ( + StageReceived PacketStage = iota + StageConntrack + StagePeerACL + StageRouting + StageRouteACL + StageForwarding + StageCompleted +) + +const msgProcessingCompleted = "Processing completed" + +func (s PacketStage) String() string { + return map[PacketStage]string{ + StageReceived: "Received", + StageConntrack: "Connection Tracking", + StagePeerACL: "Peer ACL", + StageRouting: "Routing", + StageRouteACL: "Route ACL", + StageForwarding: "Forwarding", + StageCompleted: "Completed", + }[s] +} + +type ForwarderAction struct { + Action string + RemoteAddr string + Error error +} + +type TraceResult struct { + Timestamp time.Time + Stage PacketStage + Message string + Allowed bool + ForwarderAction *ForwarderAction +} + +type PacketTrace struct { + SourceIP net.IP + DestinationIP net.IP + Protocol string + SourcePort uint16 + DestinationPort uint16 + Direction fw.RuleDirection + Results []TraceResult +} + +type TCPState struct { + SYN bool + ACK bool + FIN bool + RST bool + PSH bool + URG bool +} + +type PacketBuilder struct { + SrcIP net.IP + DstIP net.IP + Protocol fw.Protocol + SrcPort uint16 + DstPort uint16 + ICMPType uint8 + ICMPCode uint8 + Direction fw.RuleDirection + PayloadSize int + TCPState *TCPState +} + +func (t *PacketTrace) AddResult(stage PacketStage, message string, allowed bool) { + t.Results = append(t.Results, TraceResult{ + Timestamp: time.Now(), + Stage: stage, + Message: message, + Allowed: allowed, + }) +} + +func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string, allowed bool, action *ForwarderAction) { + t.Results = append(t.Results, TraceResult{ + Timestamp: time.Now(), + Stage: stage, + Message: message, + Allowed: allowed, + ForwarderAction: action, + }) +} + +func (p *PacketBuilder) Build() ([]byte, error) { + ip := p.buildIPLayer() + pktLayers := []gopacket.SerializableLayer{ip} + + transportLayer, err := p.buildTransportLayer(ip) + if err != nil { + return nil, err + } + pktLayers = append(pktLayers, transportLayer...) + + if p.PayloadSize > 0 { + payload := make([]byte, p.PayloadSize) + pktLayers = append(pktLayers, gopacket.Payload(payload)) + } + + return serializePacket(pktLayers) +} + +func (p *PacketBuilder) buildIPLayer() *layers.IPv4 { + return &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)), + SrcIP: p.SrcIP, + DstIP: p.DstIP, + } +} + +func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { + switch p.Protocol { + case "tcp": + return p.buildTCPLayer(ip) + case "udp": + return p.buildUDPLayer(ip) + case "icmp": + return p.buildICMPLayer() + default: + return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol) + } +} + +func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(p.SrcPort), + DstPort: layers.TCPPort(p.DstPort), + Window: 65535, + SYN: p.TCPState != nil && p.TCPState.SYN, + ACK: p.TCPState != nil && p.TCPState.ACK, + FIN: p.TCPState != nil && p.TCPState.FIN, + RST: p.TCPState != nil && p.TCPState.RST, + PSH: p.TCPState != nil && p.TCPState.PSH, + URG: p.TCPState != nil && p.TCPState.URG, + } + if err := tcp.SetNetworkLayerForChecksum(ip); err != nil { + return nil, fmt.Errorf("set network layer for TCP checksum: %w", err) + } + return []gopacket.SerializableLayer{tcp}, nil +} + +func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { + udp := &layers.UDP{ + SrcPort: layers.UDPPort(p.SrcPort), + DstPort: layers.UDPPort(p.DstPort), + } + if err := udp.SetNetworkLayerForChecksum(ip); err != nil { + return nil, fmt.Errorf("set network layer for UDP checksum: %w", err) + } + return []gopacket.SerializableLayer{udp}, nil +} + +func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) { + icmp := &layers.ICMPv4{ + TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode), + } + if p.ICMPType == layers.ICMPv4TypeEchoRequest || p.ICMPType == layers.ICMPv4TypeEchoReply { + icmp.Id = uint16(1) + icmp.Seq = uint16(1) + } + return []gopacket.SerializableLayer{icmp}, nil +} + +func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) { + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + if err := gopacket.SerializeLayers(buf, opts, layers...); err != nil { + return nil, fmt.Errorf("serialize packet: %w", err) + } + return buf.Bytes(), nil +} + +func getIPProtocolNumber(protocol fw.Protocol) int { + switch protocol { + case fw.ProtocolTCP: + return int(layers.IPProtocolTCP) + case fw.ProtocolUDP: + return int(layers.IPProtocolUDP) + case fw.ProtocolICMP: + return int(layers.IPProtocolICMPv4) + default: + return 0 + } +} + +func (m *Manager) TracePacketFromBuilder(builder *PacketBuilder) (*PacketTrace, error) { + packetData, err := builder.Build() + if err != nil { + return nil, fmt.Errorf("build packet: %w", err) + } + + return m.TracePacket(packetData, builder.Direction), nil +} + +func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *PacketTrace { + + d := m.decoders.Get().(*decoder) + defer m.decoders.Put(d) + + trace := &PacketTrace{Direction: direction} + + // Initial packet decoding + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false) + return trace + } + + // Extract base packet info + srcIP, dstIP := m.extractIPs(d) + trace.SourceIP = srcIP + trace.DestinationIP = dstIP + + // Determine protocol and ports + switch d.decoded[1] { + case layers.LayerTypeTCP: + trace.Protocol = "TCP" + trace.SourcePort = uint16(d.tcp.SrcPort) + trace.DestinationPort = uint16(d.tcp.DstPort) + case layers.LayerTypeUDP: + trace.Protocol = "UDP" + trace.SourcePort = uint16(d.udp.SrcPort) + trace.DestinationPort = uint16(d.udp.DstPort) + case layers.LayerTypeICMPv4: + trace.Protocol = "ICMP" + } + + trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d", + trace.Protocol, srcIP, trace.SourcePort, dstIP, trace.DestinationPort), true) + + if direction == fw.RuleDirectionOUT { + return m.traceOutbound(packetData, trace) + } + + return m.traceInbound(packetData, trace, d, srcIP, dstIP) +} + +func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace { + if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { + return trace + } + + if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) { + return trace + } + + if !m.handleRouting(trace) { + return trace + } + + if m.nativeRouter { + return m.handleNativeRouter(trace) + } + + return m.handleRouteACLs(trace, d, srcIP, dstIP) +} + +func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool { + allowed := m.isValidTrackedConnection(d, srcIP, dstIP) + msg := "No existing connection found" + if allowed { + msg = m.buildConntrackStateMessage(d) + trace.AddResult(StageConntrack, msg, true) + trace.AddResult(StageCompleted, "Packet allowed by connection tracking", true) + return true + } + trace.AddResult(StageConntrack, msg, false) + return false +} + +func (m *Manager) buildConntrackStateMessage(d *decoder) string { + msg := "Matched existing connection state" + switch d.decoded[1] { + case layers.LayerTypeTCP: + flags := getTCPFlags(&d.tcp) + msg += fmt.Sprintf(" (TCP Flags: SYN=%v ACK=%v RST=%v FIN=%v)", + flags&conntrack.TCPSyn != 0, + flags&conntrack.TCPAck != 0, + flags&conntrack.TCPRst != 0, + flags&conntrack.TCPFin != 0) + case layers.LayerTypeICMPv4: + msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq) + } + return msg +} + +func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool { + if !m.localForwarding { + trace.AddResult(StageRouting, "Local forwarding disabled", false) + trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false) + return true + } + + trace.AddResult(StageRouting, "Packet destined for local delivery", true) + blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) + + msg := "Allowed by peer ACL rules" + if blocked { + msg = "Blocked by peer ACL rules" + } + trace.AddResult(StagePeerACL, msg, !blocked) + + if m.netstack { + m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked) + } + + trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked) + return true +} + +func (m *Manager) handleRouting(trace *PacketTrace) bool { + if !m.routingEnabled { + trace.AddResult(StageRouting, "Routing disabled", false) + trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false) + return false + } + trace.AddResult(StageRouting, "Routing enabled, checking ACLs", true) + return true +} + +func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace { + trace.AddResult(StageRouteACL, "Using native router, skipping ACL checks", true) + trace.AddResult(StageForwarding, "Forwarding via native router", true) + trace.AddResult(StageCompleted, msgProcessingCompleted, true) + return trace +} + +func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace { + proto := getProtocolFromPacket(d) + srcPort, dstPort := getPortsFromPacket(d) + allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) + + msg := "Allowed by route ACLs" + if !allowed { + msg = "Blocked by route ACLs" + } + trace.AddResult(StageRouteACL, msg, allowed) + + if allowed && m.forwarder != nil { + m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true) + } + + trace.AddResult(StageCompleted, msgProcessingCompleted, allowed) + return trace +} + +func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr string, allowed bool) { + fwdAction := &ForwarderAction{ + Action: action, + RemoteAddr: remoteAddr, + } + trace.AddResultWithForwarder(StageForwarding, + fmt.Sprintf("Forwarding to %s", fwdAction.Action), allowed, fwdAction) +} + +func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { + // will create or update the connection state + dropped := m.processOutgoingHooks(packetData) + if dropped { + trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) + } else { + trace.AddResult(StageCompleted, "Packet allowed (outgoing)", true) + } + return trace +} diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 757249b2dd5..889e4cbb1a9 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -1,11 +1,14 @@ package uspfilter import ( + "errors" "fmt" "net" "net/netip" "os" + "slices" "strconv" + "strings" "sync" "github.com/google/gopacket" @@ -14,28 +17,48 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/common" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" - "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" + "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/statemanager" ) const layerTypeAll = 0 -const EnvDisableConntrack = "NB_DISABLE_CONNTRACK" +const ( + // EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed. + EnvDisableConntrack = "NB_DISABLE_CONNTRACK" -var ( - errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall") -) + // EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped. + EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING" -// IFaceMapper defines subset methods of interface required for manager -type IFaceMapper interface { - SetFilter(device.PacketFilter) error - Address() iface.WGAddress -} + // EnvForceUserspaceRouter forces userspace routing even if native routing is available. + EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER" + + // EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack + // Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible + EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING" +) // RuleSet is a set of rules grouped by a string key -type RuleSet map[string]Rule +type RuleSet map[string]PeerRule + +type RouteRules []RouteRule + +func (r RouteRules) Sort() { + slices.SortStableFunc(r, func(a, b RouteRule) int { + // Deny rules come first + if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop { + return -1 + } + if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop { + return 1 + } + return strings.Compare(a.id, b.id) + }) +} // Manager userspace firewall manager type Manager struct { @@ -43,17 +66,32 @@ type Manager struct { outgoingRules map[string]RuleSet // incomingRules is used for filtering and hooks incomingRules map[string]RuleSet + routeRules RouteRules wgNetwork *net.IPNet decoders sync.Pool - wgIface IFaceMapper + wgIface common.IFaceMapper nativeFirewall firewall.Manager mutex sync.RWMutex - stateful bool + // indicates whether we forward packets not destined for ourselves + routingEnabled bool + // indicates whether we leave forwarding and filtering to the native firewall + nativeRouter bool + // indicates whether we track outbound connections + stateful bool + // indicates whether wireguards runs in netstack mode + netstack bool + // indicates whether we forward local traffic to the native stack + localForwarding bool + + localipmanager *localIPManager + udpTracker *conntrack.UDPTracker icmpTracker *conntrack.ICMPTracker tcpTracker *conntrack.TCPTracker + forwarder *forwarder.Forwarder + logger *nblog.Logger } // decoder for packages @@ -70,22 +108,32 @@ type decoder struct { } // Create userspace firewall manager constructor -func Create(iface IFaceMapper) (*Manager, error) { - return create(iface) +func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) { + return create(iface, nil, disableServerRoutes) } -func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) { - mgr, err := create(iface) +func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) { + if nativeFirewall == nil { + return nil, errors.New("native firewall is nil") + } + + mgr, err := create(iface, nativeFirewall, disableServerRoutes) if err != nil { return nil, err } - mgr.nativeFirewall = nativeFirewall return mgr, nil } -func create(iface IFaceMapper) (*Manager, error) { - disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack)) +func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) { + disableConntrack, err := strconv.ParseBool(os.Getenv(EnvDisableConntrack)) + if err != nil { + log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err) + } + enableLocalForwarding, err := strconv.ParseBool(os.Getenv(EnvEnableNetstackLocalForwarding)) + if err != nil { + log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err) + } m := &Manager{ decoders: sync.Pool{ @@ -101,52 +149,161 @@ func create(iface IFaceMapper) (*Manager, error) { return d }, }, - outgoingRules: make(map[string]RuleSet), - incomingRules: make(map[string]RuleSet), - wgIface: iface, - stateful: !disableConntrack, + nativeFirewall: nativeFirewall, + outgoingRules: make(map[string]RuleSet), + incomingRules: make(map[string]RuleSet), + wgIface: iface, + localipmanager: newLocalIPManager(), + routingEnabled: false, + stateful: !disableConntrack, + logger: nblog.NewFromLogrus(log.StandardLogger()), + netstack: netstack.IsEnabled(), + // default true for non-netstack, for netstack only if explicitly enabled + localForwarding: !netstack.IsEnabled() || enableLocalForwarding, + } + + if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { + return nil, fmt.Errorf("update local IPs: %w", err) } // Only initialize trackers if stateful mode is enabled if disableConntrack { log.Info("conntrack is disabled") } else { - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger) + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger) + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) + } + + m.determineRouting(iface, disableServerRoutes) + + if err := m.blockInvalidRouted(iface); err != nil { + log.Errorf("failed to block invalid routed traffic: %v", err) } if err := iface.SetFilter(m); err != nil { - return nil, err + return nil, fmt.Errorf("set filter: %w", err) } return m, nil } +func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error { + if m.forwarder == nil { + return nil + } + wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String()) + if err != nil { + return fmt.Errorf("parse wireguard network: %w", err) + } + log.Debugf("blocking invalid routed traffic for %s", wgPrefix) + + if _, err := m.AddRouteFiltering( + []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, + wgPrefix, + firewall.ProtocolALL, + nil, + nil, + firewall.ActionDrop, + ); err != nil { + return fmt.Errorf("block wg nte : %w", err) + } + + // TODO: Block networks that we're a client of + + return nil +} + +func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes bool) { + disableUspRouting, _ := strconv.ParseBool(os.Getenv(EnvDisableUserspaceRouting)) + forceUserspaceRouter, _ := strconv.ParseBool(os.Getenv(EnvForceUserspaceRouter)) + + switch { + case disableUspRouting: + m.routingEnabled = false + m.nativeRouter = false + log.Info("userspace routing is disabled") + + case disableServerRoutes: + // if server routes are disabled we will let packets pass to the native stack + m.routingEnabled = true + m.nativeRouter = true + + log.Info("server routes are disabled") + + case forceUserspaceRouter: + m.routingEnabled = true + m.nativeRouter = false + + log.Info("userspace routing is forced") + + case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported(): + // if the OS supports routing natively, then we don't need to filter/route ourselves + // netstack mode won't support native routing as there is no interface + + m.routingEnabled = true + m.nativeRouter = true + + log.Info("native routing is enabled") + + default: + m.routingEnabled = true + m.nativeRouter = false + + log.Info("userspace routing enabled by default") + } + + // netstack needs the forwarder for local traffic + if m.netstack && m.localForwarding || + m.routingEnabled && !m.nativeRouter { + + m.initForwarder(iface) + } +} + +// initForwarder initializes the forwarder, it disables routing on errors +func (m *Manager) initForwarder(iface common.IFaceMapper) { + // Only supported in userspace mode as we need to inject packets back into wireguard directly + intf := iface.GetWGDevice() + if intf == nil { + log.Info("forwarding not supported") + m.routingEnabled = false + return + } + + forwarder, err := forwarder.New(iface, m.logger, m.netstack) + if err != nil { + log.Errorf("failed to create forwarder: %v", err) + m.routingEnabled = false + return + } + + m.forwarder = forwarder +} + func (m *Manager) Init(*statemanager.Manager) error { return nil } func (m *Manager) IsServerRouteSupported() bool { - if m.nativeFirewall == nil { - return false - } else { - return true - } + return m.nativeFirewall != nil || m.routingEnabled && m.forwarder != nil } func (m *Manager) AddNatRule(pair firewall.RouterPair) error { - if m.nativeFirewall == nil { - return errRouteNotSupported + if m.nativeRouter && m.nativeFirewall != nil { + return m.nativeFirewall.AddNatRule(pair) } - return m.nativeFirewall.AddNatRule(pair) + + // userspace routed packets are always SNATed to the inbound direction + // TODO: implement outbound SNAT + return nil } // RemoveNatRule removes a routing firewall rule func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { - if m.nativeFirewall == nil { - return errRouteNotSupported + if m.nativeRouter && m.nativeFirewall != nil { + return m.nativeFirewall.RemoveNatRule(pair) } - return m.nativeFirewall.RemoveNatRule(pair) + return nil } // AddPeerFiltering rule to the firewall @@ -162,7 +319,7 @@ func (m *Manager) AddPeerFiltering( _ string, comment string, ) ([]firewall.Rule, error) { - r := Rule{ + r := PeerRule{ id: uuid.New().String(), ip: ip, ipLayer: layers.LayerTypeIPv6, @@ -205,18 +362,56 @@ func (m *Manager) AddPeerFiltering( return []firewall.Rule{&r}, nil } -func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { - if m.nativeFirewall == nil { - return nil, errRouteNotSupported +func (m *Manager) AddRouteFiltering( + sources []netip.Prefix, + destination netip.Prefix, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + action firewall.Action, +) (firewall.Rule, error) { + if m.nativeRouter && m.nativeFirewall != nil { + return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) + } + + m.mutex.Lock() + defer m.mutex.Unlock() + + ruleID := uuid.New().String() + rule := RouteRule{ + id: ruleID, + sources: sources, + destination: destination, + proto: proto, + srcPort: sPort, + dstPort: dPort, + action: action, } - return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) + + m.routeRules = append(m.routeRules, rule) + m.routeRules.Sort() + + return &rule, nil } func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { - if m.nativeFirewall == nil { - return errRouteNotSupported + if m.nativeRouter && m.nativeFirewall != nil { + return m.nativeFirewall.DeleteRouteRule(rule) } - return m.nativeFirewall.DeleteRouteRule(rule) + + m.mutex.Lock() + defer m.mutex.Unlock() + + ruleID := rule.GetRuleID() + idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool { + return r.id == ruleID + }) + if idx < 0 { + return fmt.Errorf("route rule not found: %s", ruleID) + } + + m.routeRules = slices.Delete(m.routeRules, idx, idx+1) + return nil } // DeletePeerRule from the firewall by rule definition @@ -224,7 +419,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() - r, ok := rule.(*Rule) + r, ok := rule.(*PeerRule) if !ok { return fmt.Errorf("delete rule: invalid rule type: %T", rule) } @@ -255,10 +450,14 @@ func (m *Manager) DropOutgoing(packetData []byte) bool { // DropIncoming filter incoming packets func (m *Manager) DropIncoming(packetData []byte) bool { - return m.dropFilter(packetData, m.incomingRules) + return m.dropFilter(packetData) +} + +// UpdateLocalIPs updates the list of local IPs +func (m *Manager) UpdateLocalIPs() error { + return m.localipmanager.UpdateLocalIPs(m.wgIface) } -// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP func (m *Manager) processOutgoingHooks(packetData []byte) bool { m.mutex.RLock() defer m.mutex.RUnlock() @@ -279,18 +478,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { return false } - // Always process UDP hooks - if d.decoded[1] == layers.LayerTypeUDP { - // Track UDP state only if enabled - if m.stateful { - m.trackUDPOutbound(d, srcIP, dstIP) - } - return m.checkUDPHooks(d, dstIP, packetData) - } - - // Track other protocols only if stateful mode is enabled + // Track all protocols if stateful mode is enabled if m.stateful { switch d.decoded[1] { + case layers.LayerTypeUDP: + m.trackUDPOutbound(d, srcIP, dstIP) case layers.LayerTypeTCP: m.trackTCPOutbound(d, srcIP, dstIP) case layers.LayerTypeICMPv4: @@ -298,6 +490,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { } } + // Process UDP hooks even if stateful mode is disabled + if d.decoded[1] == layers.LayerTypeUDP { + return m.checkUDPHooks(d, dstIP, packetData) + } + return false } @@ -379,10 +576,9 @@ func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) { } } -// dropFilter implements filtering logic for incoming packets -func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { - // TODO: Disable router if --disable-server-router is set - +// dropFilter implements filtering logic for incoming packets. +// If it returns true, the packet should be dropped. +func (m *Manager) dropFilter(packetData []byte) bool { m.mutex.RLock() defer m.mutex.RUnlock() @@ -390,25 +586,120 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { defer m.decoders.Put(d) if !m.isValidPacket(d, packetData) { + m.logger.Trace("Invalid packet structure") return true } srcIP, dstIP := m.extractIPs(d) if srcIP == nil { - log.Errorf("unknown layer: %v", d.decoded[0]) + m.logger.Error("Unknown network layer: %v", d.decoded[0]) return true } - if !m.isWireguardTraffic(srcIP, dstIP) { + // For all inbound traffic, first check if it matches a tracked connection. + // This must happen before any other filtering because the packets are statefully tracked. + if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) { return false } - // Check connection state only if enabled - if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) { + if m.localipmanager.IsLocalIP(dstIP) { + return m.handleLocalTraffic(d, srcIP, dstIP, packetData) + } + + return m.handleRoutedTraffic(d, srcIP, dstIP, packetData) +} + +// handleLocalTraffic handles local traffic. +// If it returns true, the packet should be dropped. +func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool { + if !m.localForwarding { + m.logger.Trace("Dropping local packet (local forwarding disabled): src=%s dst=%s", srcIP, dstIP) + return true + } + + if m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) { + m.logger.Trace("Dropping local packet (ACL denied): src=%s dst=%s", + srcIP, dstIP) + return true + } + + // if running in netstack mode we need to pass this to the forwarder + if m.netstack { + m.handleNetstackLocalTraffic(packetData) + + // don't process this packet further + return true + } + + return false +} +func (m *Manager) handleNetstackLocalTraffic(packetData []byte) { + if m.forwarder == nil { + return + } + + if err := m.forwarder.InjectIncomingPacket(packetData); err != nil { + m.logger.Error("Failed to inject local packet: %v", err) + } +} + +// handleRoutedTraffic handles routed traffic. +// If it returns true, the packet should be dropped. +func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool { + // Drop if routing is disabled + if !m.routingEnabled { + m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s", + srcIP, dstIP) + return true + } + + // Pass to native stack if native router is enabled or forced + if m.nativeRouter { return false } - return m.applyRules(srcIP, packetData, rules, d) + // Get protocol and ports for route ACL check + proto := getProtocolFromPacket(d) + srcPort, dstPort := getPortsFromPacket(d) + + // Check route ACLs + if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) { + m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v", + srcIP, srcPort, dstIP, dstPort, proto) + return true + } + + // Let forwarder handle the packet if it passed route ACLs + if err := m.forwarder.InjectIncomingPacket(packetData); err != nil { + m.logger.Error("Failed to inject incoming packet: %v", err) + } + + // Forwarded packets shouldn't reach the native stack, hence they won't be visible in a packet capture + return true +} + +func getProtocolFromPacket(d *decoder) firewall.Protocol { + switch d.decoded[1] { + case layers.LayerTypeTCP: + return firewall.ProtocolTCP + case layers.LayerTypeUDP: + return firewall.ProtocolUDP + case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: + return firewall.ProtocolICMP + default: + return firewall.ProtocolALL + } +} + +func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) { + switch d.decoded[1] { + case layers.LayerTypeTCP: + return uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort) + case layers.LayerTypeUDP: + return uint16(d.udp.SrcPort), uint16(d.udp.DstPort) + default: + return 0, 0 + } } func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool { @@ -424,10 +715,6 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool { return true } -func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool { - return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP) -} - func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool { switch d.decoded[1] { case layers.LayerTypeTCP: @@ -462,7 +749,22 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool return false } -func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool { +// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed +func (m *Manager) isSpecialICMP(d *decoder) bool { + if d.decoded[1] != layers.LayerTypeICMPv4 { + return false + } + + icmpType := d.icmp4.TypeCode.Type() + return icmpType == layers.ICMPv4TypeDestinationUnreachable || + icmpType == layers.ICMPv4TypeTimeExceeded +} + +func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool { + if m.isSpecialICMP(d) { + return false + } + if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok { return filter } @@ -496,7 +798,7 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool { return false } -func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) { +func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) { payloadLayer := d.decoded[1] for _, rule := range rules { if rule.matchByIP && !ip.Equal(rule.ip) { @@ -533,6 +835,51 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode return false, false } +// routeACLsPass returns treu if the packet is allowed by the route ACLs +func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + + srcAddr := netip.AddrFrom4([4]byte(srcIP.To4())) + dstAddr := netip.AddrFrom4([4]byte(dstIP.To4())) + + for _, rule := range m.routeRules { + if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) { + return rule.action == firewall.ActionAccept + } + } + return false +} + +func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool { + if !rule.destination.Contains(dstAddr) { + return false + } + + sourceMatched := false + for _, src := range rule.sources { + if src.Contains(srcAddr) { + sourceMatched = true + break + } + } + if !sourceMatched { + return false + } + + if rule.proto != firewall.ProtocolALL && rule.proto != proto { + return false + } + + if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP { + if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) { + return false + } + } + + return true +} + // SetNetwork of the wireguard interface to which filtering applied func (m *Manager) SetNetwork(network *net.IPNet) { m.wgNetwork = network @@ -544,7 +891,7 @@ func (m *Manager) SetNetwork(network *net.IPNet) { func (m *Manager) AddUDPPacketHook( in bool, ip net.IP, dPort uint16, hook func([]byte) bool, ) string { - r := Rule{ + r := PeerRule{ id: uuid.New().String(), ip: ip, protoLayer: layers.LayerTypeUDP, @@ -561,12 +908,12 @@ func (m *Manager) AddUDPPacketHook( m.mutex.Lock() if in { if _, ok := m.incomingRules[r.ip.String()]; !ok { - m.incomingRules[r.ip.String()] = make(map[string]Rule) + m.incomingRules[r.ip.String()] = make(map[string]PeerRule) } m.incomingRules[r.ip.String()][r.id] = r } else { if _, ok := m.outgoingRules[r.ip.String()]; !ok { - m.outgoingRules[r.ip.String()] = make(map[string]Rule) + m.outgoingRules[r.ip.String()] = make(map[string]PeerRule) } m.outgoingRules[r.ip.String()][r.id] = r } @@ -599,3 +946,10 @@ func (m *Manager) RemovePacketHook(hookID string) error { } return fmt.Errorf("hook with given id not found") } + +// SetLogLevel sets the log level for the firewall manager +func (m *Manager) SetLogLevel(level log.Level) { + if m.logger != nil { + m.logger.SetLevel(nblog.Level(level)) + } +} diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/uspfilter_bench_test.go index 46bc4439d83..875bb2425b1 100644 --- a/client/firewall/uspfilter/uspfilter_bench_test.go +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -1,9 +1,12 @@ +//go:build uspbench + package uspfilter import ( "fmt" "math/rand" "net" + "net/netip" "os" "strings" "testing" @@ -155,7 +158,7 @@ func BenchmarkCoreFiltering(b *testing.B) { // Create manager and basic setup manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -185,7 +188,7 @@ func BenchmarkCoreFiltering(b *testing.B) { // Measure inbound packet processing b.ResetTimer() for i := 0; i < b.N; i++ { - manager.dropFilter(inbound, manager.incomingRules) + manager.dropFilter(inbound) } }) } @@ -200,7 +203,7 @@ func BenchmarkStateScaling(b *testing.B) { b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -228,7 +231,7 @@ func BenchmarkStateScaling(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - manager.dropFilter(testIn, manager.incomingRules) + manager.dropFilter(testIn) } }) } @@ -248,7 +251,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -269,7 +272,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - manager.dropFilter(inbound, manager.incomingRules) + manager.dropFilter(inbound) } }) } @@ -447,7 +450,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -472,7 +475,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { manager.processOutgoingHooks(syn) // SYN-ACK synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) - manager.dropFilter(synack, manager.incomingRules) + manager.dropFilter(synack) // ACK ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) manager.processOutgoingHooks(ack) @@ -481,7 +484,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - manager.dropFilter(inbound, manager.incomingRules) + manager.dropFilter(inbound) } }) } @@ -574,7 +577,7 @@ func BenchmarkLongLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -618,7 +621,7 @@ func BenchmarkLongLivedConnections(b *testing.B) { // SYN-ACK synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) - manager.dropFilter(synack, manager.incomingRules) + manager.dropFilter(synack) // ACK ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], @@ -646,7 +649,7 @@ func BenchmarkLongLivedConnections(b *testing.B) { // First outbound data manager.processOutgoingHooks(outPackets[connIdx]) // Then inbound response - this is what we're actually measuring - manager.dropFilter(inPackets[connIdx], manager.incomingRules) + manager.dropFilter(inPackets[connIdx]) } }) } @@ -665,7 +668,7 @@ func BenchmarkShortLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -754,17 +757,17 @@ func BenchmarkShortLivedConnections(b *testing.B) { // Connection establishment manager.processOutgoingHooks(p.syn) - manager.dropFilter(p.synAck, manager.incomingRules) + manager.dropFilter(p.synAck) manager.processOutgoingHooks(p.ack) // Data transfer manager.processOutgoingHooks(p.request) - manager.dropFilter(p.response, manager.incomingRules) + manager.dropFilter(p.response) // Connection teardown manager.processOutgoingHooks(p.finClient) - manager.dropFilter(p.ackServer, manager.incomingRules) - manager.dropFilter(p.finServer, manager.incomingRules) + manager.dropFilter(p.ackServer) + manager.dropFilter(p.finServer) manager.processOutgoingHooks(p.ackClient) } }) @@ -784,7 +787,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -825,7 +828,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) - manager.dropFilter(synack, manager.incomingRules) + manager.dropFilter(synack) ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], uint16(1024+i), 80, uint16(conntrack.TCPAck)) @@ -852,7 +855,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { // Simulate bidirectional traffic manager.processOutgoingHooks(outPackets[connIdx]) - manager.dropFilter(inPackets[connIdx], manager.incomingRules) + manager.dropFilter(inPackets[connIdx]) } }) }) @@ -872,7 +875,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -949,15 +952,15 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { // Full connection lifecycle manager.processOutgoingHooks(p.syn) - manager.dropFilter(p.synAck, manager.incomingRules) + manager.dropFilter(p.synAck) manager.processOutgoingHooks(p.ack) manager.processOutgoingHooks(p.request) - manager.dropFilter(p.response, manager.incomingRules) + manager.dropFilter(p.response) manager.processOutgoingHooks(p.finClient) - manager.dropFilter(p.ackServer, manager.incomingRules) - manager.dropFilter(p.finServer, manager.incomingRules) + manager.dropFilter(p.ackServer) + manager.dropFilter(p.finServer) manager.processOutgoingHooks(p.ackClient) } }) @@ -996,3 +999,72 @@ func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstP require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))) return buf.Bytes() } + +func BenchmarkRouteACLs(b *testing.B) { + manager := setupRoutedManager(b, "10.10.0.100/16") + + // Add several route rules to simulate real-world scenario + rules := []struct { + sources []netip.Prefix + dest netip.Prefix + proto fw.Protocol + port *fw.Port + }{ + { + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + port: &fw.Port{Values: []uint16{80, 443}}, + }, + { + sources: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/12"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + dest: netip.MustParsePrefix("0.0.0.0/0"), + proto: fw.ProtocolICMP, + }, + { + sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + dest: netip.MustParsePrefix("192.168.0.0/16"), + proto: fw.ProtocolUDP, + port: &fw.Port{Values: []uint16{53}}, + }, + } + + for _, r := range rules { + _, err := manager.AddRouteFiltering( + r.sources, + r.dest, + r.proto, + nil, + r.port, + fw.ActionAccept, + ) + if err != nil { + b.Fatal(err) + } + } + + // Test cases that exercise different matching scenarios + cases := []struct { + srcIP string + dstIP string + proto fw.Protocol + dstPort uint16 + }{ + {"100.10.0.1", "192.168.1.100", fw.ProtocolTCP, 443}, // Match first rule + {"172.16.0.1", "8.8.8.8", fw.ProtocolICMP, 0}, // Match second rule + {"1.1.1.1", "192.168.1.53", fw.ProtocolUDP, 53}, // Match third rule + {"192.168.1.1", "10.0.0.1", fw.ProtocolTCP, 8080}, // No match + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, tc := range cases { + srcIP := net.ParseIP(tc.srcIP) + dstIP := net.ParseIP(tc.dstIP) + manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort) + } + } +} diff --git a/client/firewall/uspfilter/uspfilter_filter_test.go b/client/firewall/uspfilter/uspfilter_filter_test.go new file mode 100644 index 00000000000..d7aebb1aab0 --- /dev/null +++ b/client/firewall/uspfilter/uspfilter_filter_test.go @@ -0,0 +1,1014 @@ +package uspfilter + +import ( + "net" + "net/netip" + "testing" + + "github.com/golang/mock/gomock" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" + wgdevice "golang.zx2c4.com/wireguard/device" + + fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/mocks" +) + +func TestPeerACLFiltering(t *testing.T) { + localIP := net.ParseIP("100.10.0.100") + wgNet := &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + } + + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() iface.WGAddress { + return iface.WGAddress{ + IP: localIP, + Network: wgNet, + } + }, + } + + manager, err := Create(ifaceMock, false) + require.NoError(t, err) + require.NotNil(t, manager) + + t.Cleanup(func() { + require.NoError(t, manager.Reset(nil)) + }) + + manager.wgNetwork = wgNet + + err = manager.UpdateLocalIPs() + require.NoError(t, err) + + testCases := []struct { + name string + srcIP string + dstIP string + proto fw.Protocol + srcPort uint16 + dstPort uint16 + ruleIP string + ruleProto fw.Protocol + ruleSrcPort *fw.Port + ruleDstPort *fw.Port + ruleAction fw.Action + shouldBeBlocked bool + }{ + { + name: "Allow TCP traffic from WG peer", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Allow UDP traffic from WG peer", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 53, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolUDP, + ruleDstPort: &fw.Port{Values: []uint16{53}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Allow ICMP traffic from WG peer", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolICMP, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolICMP, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Allow all traffic from WG peer", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolALL, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Allow traffic from non-WG source", + srcIP: "192.168.1.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "192.168.1.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Allow all traffic with 0.0.0.0 rule", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "0.0.0.0", + ruleProto: fw.ProtocolALL, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Allow TCP traffic within port range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Block TCP traffic outside port range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 7999, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: true, + }, + { + name: "Allow TCP traffic with source port range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 32100, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleSrcPort: &fw.Port{IsRange: true, Values: []uint16{32000, 33000}}, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Block TCP traffic outside source port range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 31999, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleSrcPort: &fw.Port{IsRange: true, Values: []uint16{32000, 33000}}, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: true, + }, + } + + t.Run("Implicit DROP (no rules)", func(t *testing.T) { + packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443) + isDropped := manager.DropIncoming(packet) + require.True(t, isDropped, "Packet should be dropped when no rules exist") + }) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rules, err := manager.AddPeerFiltering( + net.ParseIP(tc.ruleIP), + tc.ruleProto, + tc.ruleSrcPort, + tc.ruleDstPort, + tc.ruleAction, + "", + tc.name, + ) + require.NoError(t, err) + require.NotEmpty(t, rules) + + t.Cleanup(func() { + for _, rule := range rules { + require.NoError(t, manager.DeletePeerRule(rule)) + } + }) + + packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort) + isDropped := manager.DropIncoming(packet) + require.Equal(t, tc.shouldBeBlocked, isDropped) + }) + } +} + +func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcPort, dstPort uint16) []byte { + t.Helper() + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + ipLayer := &layers.IPv4{ + Version: 4, + TTL: 64, + SrcIP: net.ParseIP(srcIP), + DstIP: net.ParseIP(dstIP), + } + + var err error + switch proto { + case fw.ProtocolTCP: + ipLayer.Protocol = layers.IPProtocolTCP + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + } + err = tcp.SetNetworkLayerForChecksum(ipLayer) + require.NoError(t, err) + err = gopacket.SerializeLayers(buf, opts, ipLayer, tcp) + + case fw.ProtocolUDP: + ipLayer.Protocol = layers.IPProtocolUDP + udp := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + err = udp.SetNetworkLayerForChecksum(ipLayer) + require.NoError(t, err) + err = gopacket.SerializeLayers(buf, opts, ipLayer, udp) + + case fw.ProtocolICMP: + ipLayer.Protocol = layers.IPProtocolICMPv4 + icmp := &layers.ICMPv4{ + TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0), + } + err = gopacket.SerializeLayers(buf, opts, ipLayer, icmp) + + default: + err = gopacket.SerializeLayers(buf, opts, ipLayer) + } + + require.NoError(t, err) + return buf.Bytes() +} + +func setupRoutedManager(tb testing.TB, network string) *Manager { + tb.Helper() + + ctrl := gomock.NewController(tb) + dev := mocks.NewMockDevice(ctrl) + dev.EXPECT().MTU().Return(1500, nil).AnyTimes() + + localIP, wgNet, err := net.ParseCIDR(network) + require.NoError(tb, err) + + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() iface.WGAddress { + return iface.WGAddress{ + IP: localIP, + Network: wgNet, + } + }, + GetDeviceFunc: func() *device.FilteredDevice { + return &device.FilteredDevice{Device: dev} + }, + GetWGDeviceFunc: func() *wgdevice.Device { + return &wgdevice.Device{} + }, + } + + manager, err := Create(ifaceMock, false) + require.NoError(tb, err) + require.NotNil(tb, manager) + require.True(tb, manager.routingEnabled) + require.False(tb, manager.nativeRouter) + + tb.Cleanup(func() { + require.NoError(tb, manager.Reset(nil)) + }) + + return manager +} + +func TestRouteACLFiltering(t *testing.T) { + manager := setupRoutedManager(t, "10.10.0.100/16") + + type rule struct { + sources []netip.Prefix + dest netip.Prefix + proto fw.Protocol + srcPort *fw.Port + dstPort *fw.Port + action fw.Action + } + + testCases := []struct { + name string + srcIP string + dstIP string + proto fw.Protocol + srcPort uint16 + dstPort uint16 + rule rule + shouldPass bool + }{ + { + name: "Allow TCP with specific source and destination", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{443}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow any source to specific destination", + srcIP: "172.16.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{443}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow any source to any destination", + srcIP: "172.16.0.1", + dstIP: "203.0.113.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + dest: netip.MustParsePrefix("0.0.0.0/0"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{443}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow UDP DNS traffic", + srcIP: "100.10.0.1", + dstIP: "192.168.1.53", + proto: fw.ProtocolUDP, + srcPort: 54321, + dstPort: 53, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolUDP, + dstPort: &fw.Port{Values: []uint16{53}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow ICMP to any destination", + srcIP: "100.10.0.1", + dstIP: "8.8.8.8", + proto: fw.ProtocolICMP, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("0.0.0.0/0"), + proto: fw.ProtocolICMP, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow all protocols but specific port", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolALL, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Implicit deny - wrong destination port", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: false, + }, + { + name: "Implicit deny - wrong protocol", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: false, + }, + { + name: "Implicit deny - wrong source network", + srcIP: "172.16.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: false, + }, + { + name: "Source port match", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + srcPort: &fw.Port{Values: []uint16{12345}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Multiple source networks", + srcIP: "172.16.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{ + netip.MustParsePrefix("100.10.0.0/16"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow ALL protocol without ports", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolICMP, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolALL, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow ALL protocol with specific ports", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolALL, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Multiple source networks with mismatched protocol", + srcIP: "172.16.0.1", + dstIP: "192.168.1.100", + // Should not match TCP rule + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{ + netip.MustParsePrefix("100.10.0.0/16"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: false, + }, + { + name: "Allow multiple destination ports", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80, 8080, 443}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow multiple source ports", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow ALL protocol with both src and dst ports", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolALL, + srcPort: &fw.Port{Values: []uint16{12345}}, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Port Range - Within Range", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{ + IsRange: true, + Values: []uint16{8000, 8100}, + }, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Port Range - Outside Range", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 7999, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{ + IsRange: true, + Values: []uint16{8000, 8100}, + }, + action: fw.ActionAccept, + }, + shouldPass: false, + }, + { + name: "Source Port Range - Within Range", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 32100, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + srcPort: &fw.Port{ + IsRange: true, + Values: []uint16{32000, 33000}, + }, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Mixed Port Specification - Range and Single", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 32100, + dstPort: 443, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + srcPort: &fw.Port{ + IsRange: true, + Values: []uint16{32000, 33000}, + }, + dstPort: &fw.Port{ + Values: []uint16{443}, + }, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Edge Case - Port at Range Boundary", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8100, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{ + IsRange: true, + Values: []uint16{8000, 8100}, + }, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "UDP Port Range", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 5060, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolUDP, + dstPort: &fw.Port{ + IsRange: true, + Values: []uint16{5060, 5070}, + }, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "ALL Protocol with Port Range", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolALL, + dstPort: &fw.Port{ + IsRange: true, + Values: []uint16{8000, 8100}, + }, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Drop TCP traffic to specific destination", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{443}}, + action: fw.ActionDrop, + }, + shouldPass: false, + }, + { + name: "Drop all traffic to specific destination", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolALL, + action: fw.ActionDrop, + }, + shouldPass: false, + }, + { + name: "Drop traffic from multiple source networks", + srcIP: "172.16.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{ + netip.MustParsePrefix("100.10.0.0/16"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionDrop, + }, + shouldPass: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rule, err := manager.AddRouteFiltering( + tc.rule.sources, + tc.rule.dest, + tc.rule.proto, + tc.rule.srcPort, + tc.rule.dstPort, + tc.rule.action, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + t.Cleanup(func() { + require.NoError(t, manager.DeleteRouteRule(rule)) + }) + + srcIP := net.ParseIP(tc.srcIP) + dstIP := net.ParseIP(tc.dstIP) + + // testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed + // to the forwarder + isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort) + require.Equal(t, tc.shouldPass, isAllowed) + }) + } +} + +func TestRouteACLOrder(t *testing.T) { + manager := setupRoutedManager(t, "10.10.0.100/16") + + type testCase struct { + name string + rules []struct { + sources []netip.Prefix + dest netip.Prefix + proto fw.Protocol + srcPort *fw.Port + dstPort *fw.Port + action fw.Action + } + packets []struct { + srcIP string + dstIP string + proto fw.Protocol + srcPort uint16 + dstPort uint16 + shouldPass bool + } + } + + testCases := []testCase{ + { + name: "Drop rules take precedence over accept", + rules: []struct { + sources []netip.Prefix + dest netip.Prefix + proto fw.Protocol + srcPort *fw.Port + dstPort *fw.Port + action fw.Action + }{ + { + // Accept rule added first + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80, 443}}, + action: fw.ActionAccept, + }, + { + // Drop rule added second but should be evaluated first + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{443}}, + action: fw.ActionDrop, + }, + }, + packets: []struct { + srcIP string + dstIP string + proto fw.Protocol + srcPort uint16 + dstPort uint16 + shouldPass bool + }{ + { + // Should be dropped by the drop rule + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + shouldPass: false, + }, + { + // Should be allowed by the accept rule (port 80 not in drop rule) + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + shouldPass: true, + }, + }, + }, + { + name: "Multiple drop rules take precedence", + rules: []struct { + sources []netip.Prefix + dest netip.Prefix + proto fw.Protocol + srcPort *fw.Port + dstPort *fw.Port + action fw.Action + }{ + { + // Accept all + sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + dest: netip.MustParsePrefix("0.0.0.0/0"), + proto: fw.ProtocolALL, + action: fw.ActionAccept, + }, + { + // Drop specific port + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{443}}, + action: fw.ActionDrop, + }, + { + // Drop different port + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionDrop, + }, + }, + packets: []struct { + srcIP string + dstIP string + proto fw.Protocol + srcPort uint16 + dstPort uint16 + shouldPass bool + }{ + { + // Should be dropped by first drop rule + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + shouldPass: false, + }, + { + // Should be dropped by second drop rule + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + shouldPass: false, + }, + { + // Should be allowed by the accept rule (different port) + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + shouldPass: true, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var rules []fw.Rule + for _, r := range tc.rules { + rule, err := manager.AddRouteFiltering( + r.sources, + r.dest, + r.proto, + r.srcPort, + r.dstPort, + r.action, + ) + require.NoError(t, err) + require.NotNil(t, rule) + rules = append(rules, rule) + } + + t.Cleanup(func() { + for _, rule := range rules { + require.NoError(t, manager.DeleteRouteRule(rule)) + } + }) + + for i, p := range tc.packets { + srcIP := net.ParseIP(p.srcIP) + dstIP := net.ParseIP(p.dstIP) + + isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort) + require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i) + } + }) + } +} diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 9d795de691f..089bf8f5531 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -9,17 +9,38 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" + wgdevice "golang.zx2c4.com/wireguard/device" fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" + "github.com/netbirdio/netbird/client/firewall/uspfilter/log" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) +var logger = log.NewFromLogrus(logrus.StandardLogger()) + type IFaceMock struct { - SetFilterFunc func(device.PacketFilter) error - AddressFunc func() iface.WGAddress + SetFilterFunc func(device.PacketFilter) error + AddressFunc func() iface.WGAddress + GetWGDeviceFunc func() *wgdevice.Device + GetDeviceFunc func() *device.FilteredDevice +} + +func (i *IFaceMock) GetWGDevice() *wgdevice.Device { + if i.GetWGDeviceFunc == nil { + return nil + } + return i.GetWGDeviceFunc() +} + +func (i *IFaceMock) GetDevice() *device.FilteredDevice { + if i.GetDeviceFunc == nil { + return nil + } + return i.GetDeviceFunc() } func (i *IFaceMock) SetFilter(iface device.PacketFilter) error { @@ -41,7 +62,7 @@ func TestManagerCreate(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock) + m, err := Create(ifaceMock, false) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -61,7 +82,7 @@ func TestManagerAddPeerFiltering(t *testing.T) { }, } - m, err := Create(ifaceMock) + m, err := Create(ifaceMock, false) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -95,7 +116,7 @@ func TestManagerDeleteRule(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock) + m, err := Create(ifaceMock, false) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -166,12 +187,12 @@ func TestAddUDPPacketHook(t *testing.T) { t.Run(tt.name, func(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) require.NoError(t, err) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) - var addedRule Rule + var addedRule PeerRule if tt.in { if len(manager.incomingRules[tt.ip.String()]) != 1 { t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) @@ -215,7 +236,7 @@ func TestManagerReset(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock) + m, err := Create(ifaceMock, false) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -247,9 +268,18 @@ func TestManagerReset(t *testing.T) { func TestNotMatchByIP(t *testing.T) { ifaceMock := &IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() iface.WGAddress { + return iface.WGAddress{ + IP: net.ParseIP("100.10.0.100"), + Network: &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + }, + } + }, } - m, err := Create(ifaceMock) + m, err := Create(ifaceMock, false) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -298,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) { return } - if m.dropFilter(buf.Bytes(), m.incomingRules) { + if m.dropFilter(buf.Bytes()) { t.Errorf("expected packet to be accepted") return } @@ -317,7 +347,7 @@ func TestRemovePacketHook(t *testing.T) { } // creating manager instance - manager, err := Create(iface) + manager, err := Create(iface, false) if err != nil { t.Fatalf("Failed to create Manager: %s", err) } @@ -363,7 +393,7 @@ func TestRemovePacketHook(t *testing.T) { func TestProcessOutgoingHooks(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) require.NoError(t, err) manager.wgNetwork = &net.IPNet{ @@ -371,7 +401,7 @@ func TestProcessOutgoingHooks(t *testing.T) { Mask: net.CIDRMask(16, 32), } manager.udpTracker.Close() - manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond) + manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger) defer func() { require.NoError(t, manager.Reset(nil)) }() @@ -449,7 +479,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { ifaceMock := &IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, } - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, false) require.NoError(t, err) time.Sleep(time.Second) @@ -476,7 +506,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { func TestStatefulFirewall_UDPTracking(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) require.NoError(t, err) manager.wgNetwork = &net.IPNet{ @@ -485,7 +515,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { } manager.udpTracker.Close() // Close the existing tracker - manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond) + manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger) manager.decoders = sync.Pool{ New: func() any { d := &decoder{ @@ -606,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { for _, cp := range checkPoints { time.Sleep(cp.sleep) - drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules) + drop = manager.dropFilter(inboundBuf.Bytes()) require.Equal(t, cp.shouldAllow, !drop, cp.description) // If the connection should still be valid, verify it exists @@ -677,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { require.NoError(t, err) // Verify the invalid packet is dropped - drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules) + drop = manager.dropFilter(testBuf.Bytes()) require.True(t, drop, tc.description) }) } diff --git a/client/iface/device.go b/client/iface/device.go index 0d4e6914554..2a170adfb41 100644 --- a/client/iface/device.go +++ b/client/iface/device.go @@ -3,6 +3,8 @@ package iface import ( + wgdevice "golang.zx2c4.com/wireguard/device" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" ) @@ -15,4 +17,5 @@ type WGTunDevice interface { DeviceName() string Close() error FilteredDevice() *device.FilteredDevice + Device() *wgdevice.Device } diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index b5a128bc1cc..fe7ed175207 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -117,6 +117,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice { return t.filteredDevice } +// Device returns the wireguard device +func (t *TunDevice) Device() *device.Device { + return t.device +} + // assignAddr Adds IP address to the tunnel interface and network route based on the range provided func (t *TunDevice) assignAddr() error { cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()) diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index 0dfed4d9071..3314b576b25 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -9,6 +9,7 @@ import ( "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/device" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" @@ -151,6 +152,11 @@ func (t *TunKernelDevice) DeviceName() string { return t.name } +// Device returns the wireguard device, not applicable for kernel devices +func (t *TunKernelDevice) Device() *device.Device { + return nil +} + func (t *TunKernelDevice) FilteredDevice() *FilteredDevice { return nil } diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index f5d39e9e074..c7d297187ed 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -117,3 +117,8 @@ func (t *TunNetstackDevice) DeviceName() string { func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice { return t.filteredDevice } + +// Device returns the wireguard device +func (t *TunNetstackDevice) Device() *device.Device { + return t.device +} diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 3562f312ded..4ac87aecbb8 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -124,6 +124,11 @@ func (t *USPDevice) FilteredDevice() *FilteredDevice { return t.filteredDevice } +// Device returns the wireguard device +func (t *USPDevice) Device() *device.Device { + return t.device +} + // assignAddr Adds IP address to the tunnel interface func (t *USPDevice) assignAddr() error { link := newWGLink(t.name) diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index 86968d06d7e..e603d7696f9 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -150,6 +150,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice { return t.filteredDevice } +// Device returns the wireguard device +func (t *TunDevice) Device() *device.Device { + return t.device +} + func (t *TunDevice) GetInterfaceGUIDString() (string, error) { if t.nativeTunDevice == nil { return "", fmt.Errorf("interface has not been initialized yet") diff --git a/client/iface/device_android.go b/client/iface/device_android.go index 3d15080fff4..028f6fa7d78 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -1,6 +1,8 @@ package iface import ( + wgdevice "golang.zx2c4.com/wireguard/device" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" ) @@ -13,4 +15,5 @@ type WGTunDevice interface { DeviceName() string Close() error FilteredDevice() *device.FilteredDevice + Device() *wgdevice.Device } diff --git a/client/iface/iface.go b/client/iface/iface.go index 1fb9c269179..64219975f5d 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -11,6 +11,8 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + wgdevice "golang.zx2c4.com/wireguard/device" + "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" @@ -203,6 +205,11 @@ func (w *WGIface) GetDevice() *device.FilteredDevice { return w.tun.FilteredDevice() } +// GetWGDevice returns the WireGuard device +func (w *WGIface) GetWGDevice() *wgdevice.Device { + return w.tun.Device() +} + // GetStats returns the last handshake time, rx and tx bytes for the given peer func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) { return w.configurer.GetStats(peerKey) diff --git a/client/iface/iface_moc.go b/client/iface/iface_moc.go index d91a7224ff2..5f57bc82159 100644 --- a/client/iface/iface_moc.go +++ b/client/iface/iface_moc.go @@ -4,6 +4,7 @@ import ( "net" "time" + wgdevice "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/bind" @@ -29,6 +30,7 @@ type MockWGIface struct { SetFilterFunc func(filter device.PacketFilter) error GetFilterFunc func() device.PacketFilter GetDeviceFunc func() *device.FilteredDevice + GetWGDeviceFunc func() *wgdevice.Device GetStatsFunc func(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDStringFunc func() (string, error) GetProxyFunc func() wgproxy.Proxy @@ -102,11 +104,14 @@ func (m *MockWGIface) GetDevice() *device.FilteredDevice { return m.GetDeviceFunc() } +func (m *MockWGIface) GetWGDevice() *wgdevice.Device { + return m.GetWGDeviceFunc() +} + func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) { return m.GetStatsFunc(peerKey) } func (m *MockWGIface) GetProxy() wgproxy.Proxy { - //TODO implement me - panic("implement me") + return m.GetProxyFunc() } diff --git a/client/iface/iwginterface.go b/client/iface/iwginterface.go index f5ab2953905..472ab45f9d8 100644 --- a/client/iface/iwginterface.go +++ b/client/iface/iwginterface.go @@ -6,6 +6,7 @@ import ( "net" "time" + wgdevice "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/bind" @@ -32,5 +33,6 @@ type IWGIface interface { SetFilter(filter device.PacketFilter) error GetFilter() device.PacketFilter GetDevice() *device.FilteredDevice + GetWGDevice() *wgdevice.Device GetStats(peerKey string) (configurer.WGStats, error) } diff --git a/client/iface/iwginterface_windows.go b/client/iface/iwginterface_windows.go index 96eec52a502..c9183cafdce 100644 --- a/client/iface/iwginterface_windows.go +++ b/client/iface/iwginterface_windows.go @@ -4,6 +4,7 @@ import ( "net" "time" + wgdevice "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/bind" @@ -30,6 +31,7 @@ type IWGIface interface { SetFilter(filter device.PacketFilter) error GetFilter() device.PacketFilter GetDevice() *device.FilteredDevice + GetWGDevice() *wgdevice.Device GetStats(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDString() (string, error) } diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 6049b4f48e2..217dbce9f45 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -49,9 +49,10 @@ func TestDefaultManager(t *testing.T) { IP: ip, Network: network, }).AnyTimes() + ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - fw, err := firewall.NewFirewall(ifaceMock, nil) + fw, err := firewall.NewFirewall(ifaceMock, nil, false) if err != nil { t.Errorf("create firewall: %v", err) return @@ -342,9 +343,10 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { IP: ip, Network: network, }).AnyTimes() + ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - fw, err := firewall.NewFirewall(ifaceMock, nil) + fw, err := firewall.NewFirewall(ifaceMock, nil, false) if err != nil { t.Errorf("create firewall: %v", err) return diff --git a/client/internal/acl/mocks/iface_mapper.go b/client/internal/acl/mocks/iface_mapper.go index 3ed12b6dd76..08aa4fd5a01 100644 --- a/client/internal/acl/mocks/iface_mapper.go +++ b/client/internal/acl/mocks/iface_mapper.go @@ -8,6 +8,8 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + wgdevice "golang.zx2c4.com/wireguard/device" + iface "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -90,3 +92,31 @@ func (mr *MockIFaceMapperMockRecorder) SetFilter(arg0 interface{}) *gomock.Call mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0) } + +// GetDevice mocks base method. +func (m *MockIFaceMapper) GetDevice() *device.FilteredDevice { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDevice") + ret0, _ := ret[0].(*device.FilteredDevice) + return ret0 +} + +// GetDevice indicates an expected call of GetDevice. +func (mr *MockIFaceMapperMockRecorder) GetDevice() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetDevice)) +} + +// GetWGDevice mocks base method. +func (m *MockIFaceMapper) GetWGDevice() *wgdevice.Device { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetWGDevice") + ret0, _ := ret[0].(*wgdevice.Device) + return ret0 +} + +// GetWGDevice indicates an expected call of GetWGDevice. +func (mr *MockIFaceMapperMockRecorder) GetWGDevice() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWGDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetWGDevice)) +} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index c166820c457..14ff1bb713e 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -849,7 +849,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { return nil, err } - pf, err := uspfilter.Create(wgIface) + pf, err := uspfilter.Create(wgIface, false) if err != nil { t.Fatalf("failed to create uspfilter: %v", err) return nil, err diff --git a/client/internal/engine.go b/client/internal/engine.go index 335729d92f8..14e0d348fda 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -42,13 +42,13 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/management/domain" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" mgm "github.com/netbirdio/netbird/management/client" - "github.com/netbirdio/netbird/management/domain" mgmProto "github.com/netbirdio/netbird/management/proto" auth "github.com/netbirdio/netbird/relay/auth/hmac" relayClient "github.com/netbirdio/netbird/relay/client" @@ -193,6 +193,10 @@ type Peer struct { WgAllowedIps string } +type localIpUpdater interface { + UpdateLocalIPs() error +} + // NewEngine creates a new Connection Engine with probes attached func NewEngine( clientCtx context.Context, @@ -433,7 +437,7 @@ func (e *Engine) createFirewall() error { } var err error - e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager) + e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.config.DisableServerRoutes) if err != nil || e.firewall == nil { log.Errorf("failed creating firewall manager: %s", err) return nil @@ -883,6 +887,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.acl.ApplyFiltering(networkMap) } + if e.firewall != nil { + if localipfw, ok := e.firewall.(localIpUpdater); ok { + if err := localipfw.UpdateLocalIPs(); err != nil { + log.Errorf("failed to update local IPs: %v", err) + } + } + } + // DNS forwarder dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes()) @@ -1446,6 +1458,11 @@ func (e *Engine) GetRouteManager() routemanager.Manager { return e.routeManager } +// GetFirewallManager returns the firewall manager +func (e *Engine) GetFirewallManager() manager.Manager { + return e.firewall +} + func findIPFromInterfaceName(ifaceName string) (net.IP, error) { iface, err := net.InterfaceByName(ifaceName) if err != nil { @@ -1657,6 +1674,14 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) { return nm, nil } +// GetWgAddr returns the wireguard address +func (e *Engine) GetWgAddr() net.IP { + if e.wgInterface == nil { + return nil + } + return e.wgInterface.Address().IP +} + // updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag func (e *Engine) updateDNSForwarder(enabled bool, domains []string) { if !enabled { diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 6f73fb166c6..34bd67893d3 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -422,11 +422,6 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID] haID := newRoute.GetHAUniqueID() if newRoute.Peer == m.pubKey { ownNetworkIDs[haID] = true - // only linux is supported for now - if runtime.GOOS != "linux" { - log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS) - continue - } newServerRoutesMap[newRoute.ID] = newRoute } } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 30f7473cd00..c9651efed9c 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -2571,6 +2571,330 @@ func (*SetNetworkMapPersistenceResponse) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{39} } +type TCPFlags struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Syn bool `protobuf:"varint,1,opt,name=syn,proto3" json:"syn,omitempty"` + Ack bool `protobuf:"varint,2,opt,name=ack,proto3" json:"ack,omitempty"` + Fin bool `protobuf:"varint,3,opt,name=fin,proto3" json:"fin,omitempty"` + Rst bool `protobuf:"varint,4,opt,name=rst,proto3" json:"rst,omitempty"` + Psh bool `protobuf:"varint,5,opt,name=psh,proto3" json:"psh,omitempty"` + Urg bool `protobuf:"varint,6,opt,name=urg,proto3" json:"urg,omitempty"` +} + +func (x *TCPFlags) Reset() { + *x = TCPFlags{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[40] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TCPFlags) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TCPFlags) ProtoMessage() {} + +func (x *TCPFlags) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[40] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TCPFlags.ProtoReflect.Descriptor instead. +func (*TCPFlags) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{40} +} + +func (x *TCPFlags) GetSyn() bool { + if x != nil { + return x.Syn + } + return false +} + +func (x *TCPFlags) GetAck() bool { + if x != nil { + return x.Ack + } + return false +} + +func (x *TCPFlags) GetFin() bool { + if x != nil { + return x.Fin + } + return false +} + +func (x *TCPFlags) GetRst() bool { + if x != nil { + return x.Rst + } + return false +} + +func (x *TCPFlags) GetPsh() bool { + if x != nil { + return x.Psh + } + return false +} + +func (x *TCPFlags) GetUrg() bool { + if x != nil { + return x.Urg + } + return false +} + +type TracePacketRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + SourceIp string `protobuf:"bytes,1,opt,name=source_ip,json=sourceIp,proto3" json:"source_ip,omitempty"` + DestinationIp string `protobuf:"bytes,2,opt,name=destination_ip,json=destinationIp,proto3" json:"destination_ip,omitempty"` + Protocol string `protobuf:"bytes,3,opt,name=protocol,proto3" json:"protocol,omitempty"` + SourcePort uint32 `protobuf:"varint,4,opt,name=source_port,json=sourcePort,proto3" json:"source_port,omitempty"` + DestinationPort uint32 `protobuf:"varint,5,opt,name=destination_port,json=destinationPort,proto3" json:"destination_port,omitempty"` + Direction string `protobuf:"bytes,6,opt,name=direction,proto3" json:"direction,omitempty"` + TcpFlags *TCPFlags `protobuf:"bytes,7,opt,name=tcp_flags,json=tcpFlags,proto3,oneof" json:"tcp_flags,omitempty"` + IcmpType *uint32 `protobuf:"varint,8,opt,name=icmp_type,json=icmpType,proto3,oneof" json:"icmp_type,omitempty"` + IcmpCode *uint32 `protobuf:"varint,9,opt,name=icmp_code,json=icmpCode,proto3,oneof" json:"icmp_code,omitempty"` +} + +func (x *TracePacketRequest) Reset() { + *x = TracePacketRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[41] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TracePacketRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TracePacketRequest) ProtoMessage() {} + +func (x *TracePacketRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[41] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TracePacketRequest.ProtoReflect.Descriptor instead. +func (*TracePacketRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{41} +} + +func (x *TracePacketRequest) GetSourceIp() string { + if x != nil { + return x.SourceIp + } + return "" +} + +func (x *TracePacketRequest) GetDestinationIp() string { + if x != nil { + return x.DestinationIp + } + return "" +} + +func (x *TracePacketRequest) GetProtocol() string { + if x != nil { + return x.Protocol + } + return "" +} + +func (x *TracePacketRequest) GetSourcePort() uint32 { + if x != nil { + return x.SourcePort + } + return 0 +} + +func (x *TracePacketRequest) GetDestinationPort() uint32 { + if x != nil { + return x.DestinationPort + } + return 0 +} + +func (x *TracePacketRequest) GetDirection() string { + if x != nil { + return x.Direction + } + return "" +} + +func (x *TracePacketRequest) GetTcpFlags() *TCPFlags { + if x != nil { + return x.TcpFlags + } + return nil +} + +func (x *TracePacketRequest) GetIcmpType() uint32 { + if x != nil && x.IcmpType != nil { + return *x.IcmpType + } + return 0 +} + +func (x *TracePacketRequest) GetIcmpCode() uint32 { + if x != nil && x.IcmpCode != nil { + return *x.IcmpCode + } + return 0 +} + +type TraceStage struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + Allowed bool `protobuf:"varint,3,opt,name=allowed,proto3" json:"allowed,omitempty"` + ForwardingDetails *string `protobuf:"bytes,4,opt,name=forwarding_details,json=forwardingDetails,proto3,oneof" json:"forwarding_details,omitempty"` +} + +func (x *TraceStage) Reset() { + *x = TraceStage{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[42] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TraceStage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TraceStage) ProtoMessage() {} + +func (x *TraceStage) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[42] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TraceStage.ProtoReflect.Descriptor instead. +func (*TraceStage) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{42} +} + +func (x *TraceStage) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *TraceStage) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +func (x *TraceStage) GetAllowed() bool { + if x != nil { + return x.Allowed + } + return false +} + +func (x *TraceStage) GetForwardingDetails() string { + if x != nil && x.ForwardingDetails != nil { + return *x.ForwardingDetails + } + return "" +} + +type TracePacketResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Stages []*TraceStage `protobuf:"bytes,1,rep,name=stages,proto3" json:"stages,omitempty"` + FinalDisposition bool `protobuf:"varint,2,opt,name=final_disposition,json=finalDisposition,proto3" json:"final_disposition,omitempty"` +} + +func (x *TracePacketResponse) Reset() { + *x = TracePacketResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[43] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TracePacketResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TracePacketResponse) ProtoMessage() {} + +func (x *TracePacketResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[43] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TracePacketResponse.ProtoReflect.Descriptor instead. +func (*TracePacketResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{43} +} + +func (x *TracePacketResponse) GetStages() []*TraceStage { + if x != nil { + return x.Stages + } + return nil +} + +func (x *TracePacketResponse) GetFinalDisposition() bool { + if x != nil { + return x.FinalDisposition + } + return false +} + var File_daemon_proto protoreflect.FileDescriptor var file_daemon_proto_rawDesc = []byte{ @@ -2920,87 +3244,141 @@ var file_daemon_proto_rawDesc = []byte{ 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, - 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, - 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, - 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, - 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, - 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, - 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0x93, 0x09, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, - 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, - 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, - 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, - 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, + 0x6e, 0x73, 0x65, 0x22, 0x76, 0x0a, 0x08, 0x54, 0x43, 0x50, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x12, + 0x10, 0x0a, 0x03, 0x73, 0x79, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x73, 0x79, + 0x6e, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x63, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, + 0x61, 0x63, 0x6b, 0x12, 0x10, 0x0a, 0x03, 0x66, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x03, 0x66, 0x69, 0x6e, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x73, 0x74, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x03, 0x72, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x73, 0x68, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x70, 0x73, 0x68, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x67, + 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x75, 0x72, 0x67, 0x22, 0x80, 0x03, 0x0a, 0x12, + 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12, + 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, + 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x70, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x70, 0x6f, 0x72, + 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50, + 0x6f, 0x72, 0x74, 0x12, 0x29, 0x0a, 0x10, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0f, 0x64, + 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x1c, + 0x0a, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x32, 0x0a, 0x09, + 0x74, 0x63, 0x70, 0x5f, 0x66, 0x6c, 0x61, 0x67, 0x73, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x43, 0x50, 0x46, 0x6c, 0x61, 0x67, + 0x73, 0x48, 0x00, 0x52, 0x08, 0x74, 0x63, 0x70, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x88, 0x01, 0x01, + 0x12, 0x20, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x08, 0x20, + 0x01, 0x28, 0x0d, 0x48, 0x01, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x54, 0x79, 0x70, 0x65, 0x88, + 0x01, 0x01, 0x12, 0x20, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, + 0x09, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x02, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x43, 0x6f, 0x64, + 0x65, 0x88, 0x01, 0x01, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x74, 0x63, 0x70, 0x5f, 0x66, 0x6c, 0x61, + 0x67, 0x73, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, + 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x22, 0x9f, + 0x01, 0x0a, 0x0a, 0x54, 0x72, 0x61, 0x63, 0x65, 0x53, 0x74, 0x61, 0x67, 0x65, 0x12, 0x12, 0x0a, + 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, + 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x61, + 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x61, 0x6c, + 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x12, 0x32, 0x0a, 0x12, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, + 0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x09, 0x48, 0x00, 0x52, 0x11, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x44, + 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x88, 0x01, 0x01, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x66, 0x6f, + 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, + 0x22, 0x6e, 0x0a, 0x13, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2a, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x67, 0x65, + 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x53, 0x74, 0x61, 0x67, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61, + 0x67, 0x65, 0x73, 0x12, 0x2b, 0x0a, 0x11, 0x66, 0x69, 0x6e, 0x61, 0x6c, 0x5f, 0x64, 0x69, 0x73, + 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, + 0x66, 0x69, 0x6e, 0x61, 0x6c, 0x44, 0x69, 0x73, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x6f, 0x6e, + 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, + 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, + 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, + 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, + 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, + 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, + 0x43, 0x45, 0x10, 0x07, 0x32, 0xdd, 0x09, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, + 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, + 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, + 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, + 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, - 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, - 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, - 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, - 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, - 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, - 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, - 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, - 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, - 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, + 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, + 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, + 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, + 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, + 0x0a, 0x0c, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x51, - 0x0a, 0x0e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, - 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, - 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, - 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x53, 0x0a, 0x10, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, - 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, - 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, - 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, - 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, - 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, - 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, - 0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, - 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, - 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, - 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, - 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, - 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, - 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x43, - 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, - 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, - 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, - 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x6f, 0x0a, 0x18, - 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, - 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, - 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, - 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, - 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x51, 0x0a, 0x0e, 0x53, + 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, + 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, + 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x53, + 0x0a, 0x10, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, + 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, + 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, + 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, + 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, + 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, + 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, + 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, + 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, + 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x12, + 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, + 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, + 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, + 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, + 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, + 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, + 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x54, 0x72, + 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, + 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -3016,7 +3394,7 @@ func file_daemon_proto_rawDescGZIP() []byte { } var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 41) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 45) var file_daemon_proto_goTypes = []interface{}{ (LogLevel)(0), // 0: daemon.LogLevel (*LoginRequest)(nil), // 1: daemon.LoginRequest @@ -3059,16 +3437,20 @@ var file_daemon_proto_goTypes = []interface{}{ (*DeleteStateResponse)(nil), // 38: daemon.DeleteStateResponse (*SetNetworkMapPersistenceRequest)(nil), // 39: daemon.SetNetworkMapPersistenceRequest (*SetNetworkMapPersistenceResponse)(nil), // 40: daemon.SetNetworkMapPersistenceResponse - nil, // 41: daemon.Network.ResolvedIPsEntry - (*durationpb.Duration)(nil), // 42: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 43: google.protobuf.Timestamp + (*TCPFlags)(nil), // 41: daemon.TCPFlags + (*TracePacketRequest)(nil), // 42: daemon.TracePacketRequest + (*TraceStage)(nil), // 43: daemon.TraceStage + (*TracePacketResponse)(nil), // 44: daemon.TracePacketResponse + nil, // 45: daemon.Network.ResolvedIPsEntry + (*durationpb.Duration)(nil), // 46: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 47: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ - 42, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 46, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 19, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 43, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 43, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 42, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 47, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 47, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 46, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration 16, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 15, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState 14, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState @@ -3076,48 +3458,52 @@ var file_daemon_proto_depIdxs = []int32{ 17, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState 18, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState 25, // 11: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 41, // 12: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 45, // 12: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry 0, // 13: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel 0, // 14: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel 32, // 15: daemon.ListStatesResponse.states:type_name -> daemon.State - 24, // 16: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList - 1, // 17: daemon.DaemonService.Login:input_type -> daemon.LoginRequest - 3, // 18: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest - 5, // 19: daemon.DaemonService.Up:input_type -> daemon.UpRequest - 7, // 20: daemon.DaemonService.Status:input_type -> daemon.StatusRequest - 9, // 21: daemon.DaemonService.Down:input_type -> daemon.DownRequest - 11, // 22: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 20, // 23: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest - 22, // 24: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest - 22, // 25: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest - 26, // 26: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest - 28, // 27: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest - 30, // 28: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest - 33, // 29: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest - 35, // 30: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest - 37, // 31: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest - 39, // 32: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest - 2, // 33: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 4, // 34: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 6, // 35: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 8, // 36: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 10, // 37: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 12, // 38: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 21, // 39: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 23, // 40: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 23, // 41: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 27, // 42: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 29, // 43: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 31, // 44: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 34, // 45: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 36, // 46: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 38, // 47: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 40, // 48: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse - 33, // [33:49] is the sub-list for method output_type - 17, // [17:33] is the sub-list for method input_type - 17, // [17:17] is the sub-list for extension type_name - 17, // [17:17] is the sub-list for extension extendee - 0, // [0:17] is the sub-list for field type_name + 41, // 16: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags + 43, // 17: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage + 24, // 18: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList + 1, // 19: daemon.DaemonService.Login:input_type -> daemon.LoginRequest + 3, // 20: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest + 5, // 21: daemon.DaemonService.Up:input_type -> daemon.UpRequest + 7, // 22: daemon.DaemonService.Status:input_type -> daemon.StatusRequest + 9, // 23: daemon.DaemonService.Down:input_type -> daemon.DownRequest + 11, // 24: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest + 20, // 25: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest + 22, // 26: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest + 22, // 27: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest + 26, // 28: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest + 28, // 29: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest + 30, // 30: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest + 33, // 31: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest + 35, // 32: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest + 37, // 33: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest + 39, // 34: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest + 42, // 35: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest + 2, // 36: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 4, // 37: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 6, // 38: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 8, // 39: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 10, // 40: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 12, // 41: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 21, // 42: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 23, // 43: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 23, // 44: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 27, // 45: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 29, // 46: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 31, // 47: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 34, // 48: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 36, // 49: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 38, // 50: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 40, // 51: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse + 44, // 52: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 36, // [36:53] is the sub-list for method output_type + 19, // [19:36] is the sub-list for method input_type + 19, // [19:19] is the sub-list for extension type_name + 19, // [19:19] is the sub-list for extension extendee + 0, // [0:19] is the sub-list for field type_name } func init() { file_daemon_proto_init() } @@ -3606,15 +3992,65 @@ func file_daemon_proto_init() { return nil } } + file_daemon_proto_msgTypes[40].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TCPFlags); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[41].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TracePacketRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[42].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TraceStage); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[43].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TracePacketResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{} + file_daemon_proto_msgTypes[41].OneofWrappers = []interface{}{} + file_daemon_proto_msgTypes[42].OneofWrappers = []interface{}{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_daemon_proto_rawDesc, NumEnums: 1, - NumMessages: 41, + NumMessages: 45, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 8db3add08f5..412449076e9 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -57,6 +57,8 @@ service DaemonService { // SetNetworkMapPersistence enables or disables network map persistence rpc SetNetworkMapPersistence(SetNetworkMapPersistenceRequest) returns (SetNetworkMapPersistenceResponse) {} + + rpc TracePacket(TracePacketRequest) returns (TracePacketResponse) {} } @@ -356,3 +358,36 @@ message SetNetworkMapPersistenceRequest { } message SetNetworkMapPersistenceResponse {} + +message TCPFlags { + bool syn = 1; + bool ack = 2; + bool fin = 3; + bool rst = 4; + bool psh = 5; + bool urg = 6; +} + +message TracePacketRequest { + string source_ip = 1; + string destination_ip = 2; + string protocol = 3; + uint32 source_port = 4; + uint32 destination_port = 5; + string direction = 6; + optional TCPFlags tcp_flags = 7; + optional uint32 icmp_type = 8; + optional uint32 icmp_code = 9; +} + +message TraceStage { + string name = 1; + string message = 2; + bool allowed = 3; + optional string forwarding_details = 4; +} + +message TracePacketResponse { + repeated TraceStage stages = 1; + bool final_disposition = 2; +} diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index 39424aee938..9dcb543a80c 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -51,6 +51,7 @@ type DaemonServiceClient interface { DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) // SetNetworkMapPersistence enables or disables network map persistence SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) + TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) } type daemonServiceClient struct { @@ -205,6 +206,15 @@ func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in * return out, nil } +func (c *daemonServiceClient) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) { + out := new(TracePacketResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/TracePacket", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer // for forward compatibility @@ -242,6 +252,7 @@ type DaemonServiceServer interface { DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error) // SetNetworkMapPersistence enables or disables network map persistence SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) + TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) mustEmbedUnimplementedDaemonServiceServer() } @@ -297,6 +308,9 @@ func (UnimplementedDaemonServiceServer) DeleteState(context.Context, *DeleteStat func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method SetNetworkMapPersistence not implemented") } +func (UnimplementedDaemonServiceServer) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method TracePacket not implemented") +} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. @@ -598,6 +612,24 @@ func _DaemonService_SetNetworkMapPersistence_Handler(srv interface{}, ctx contex return interceptor(ctx, in, info, handler) } +func _DaemonService_TracePacket_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(TracePacketRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).TracePacket(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/TracePacket", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).TracePacket(ctx, req.(*TracePacketRequest)) + } + return interceptor(ctx, in, info, handler) +} + // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -669,6 +701,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "SetNetworkMapPersistence", Handler: _DaemonService_SetNetworkMapPersistence_Handler, }, + { + MethodName: "TracePacket", + Handler: _DaemonService_TracePacket_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "daemon.proto", diff --git a/client/server/debug.go b/client/server/debug.go index a37195b290c..749220d62c8 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -538,7 +538,24 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) ( } log.SetLevel(level) + + if s.connectClient == nil { + return nil, fmt.Errorf("connect client not initialized") + } + engine := s.connectClient.Engine() + if engine == nil { + return nil, fmt.Errorf("engine not initialized") + } + + fwManager := engine.GetFirewallManager() + if fwManager == nil { + return nil, fmt.Errorf("firewall manager not initialized") + } + + fwManager.SetLogLevel(level) + log.Infof("Log level set to %s", level.String()) + return &proto.SetLogLevelResponse{}, nil } diff --git a/client/server/trace.go b/client/server/trace.go new file mode 100644 index 00000000000..66b83d8cf86 --- /dev/null +++ b/client/server/trace.go @@ -0,0 +1,123 @@ +package server + +import ( + "context" + "fmt" + "net" + + fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/proto" +) + +type packetTracer interface { + TracePacketFromBuilder(builder *uspfilter.PacketBuilder) (*uspfilter.PacketTrace, error) +} + +func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (*proto.TracePacketResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.connectClient == nil { + return nil, fmt.Errorf("connect client not initialized") + } + engine := s.connectClient.Engine() + if engine == nil { + return nil, fmt.Errorf("engine not initialized") + } + + fwManager := engine.GetFirewallManager() + if fwManager == nil { + return nil, fmt.Errorf("firewall manager not initialized") + } + + tracer, ok := fwManager.(packetTracer) + if !ok { + return nil, fmt.Errorf("firewall manager does not support packet tracing") + } + + srcIP := net.ParseIP(req.GetSourceIp()) + if req.GetSourceIp() == "self" { + srcIP = engine.GetWgAddr() + } + + dstIP := net.ParseIP(req.GetDestinationIp()) + if req.GetDestinationIp() == "self" { + dstIP = engine.GetWgAddr() + } + + if srcIP == nil || dstIP == nil { + return nil, fmt.Errorf("invalid IP address") + } + + var tcpState *uspfilter.TCPState + if flags := req.GetTcpFlags(); flags != nil { + tcpState = &uspfilter.TCPState{ + SYN: flags.GetSyn(), + ACK: flags.GetAck(), + FIN: flags.GetFin(), + RST: flags.GetRst(), + PSH: flags.GetPsh(), + URG: flags.GetUrg(), + } + } + + var dir fw.RuleDirection + switch req.GetDirection() { + case "in": + dir = fw.RuleDirectionIN + case "out": + dir = fw.RuleDirectionOUT + default: + return nil, fmt.Errorf("invalid direction") + } + + var protocol fw.Protocol + switch req.GetProtocol() { + case "tcp": + protocol = fw.ProtocolTCP + case "udp": + protocol = fw.ProtocolUDP + case "icmp": + protocol = fw.ProtocolICMP + default: + return nil, fmt.Errorf("invalid protocolcol") + } + + builder := &uspfilter.PacketBuilder{ + SrcIP: srcIP, + DstIP: dstIP, + Protocol: protocol, + SrcPort: uint16(req.GetSourcePort()), + DstPort: uint16(req.GetDestinationPort()), + Direction: dir, + TCPState: tcpState, + ICMPType: uint8(req.GetIcmpType()), + ICMPCode: uint8(req.GetIcmpCode()), + } + trace, err := tracer.TracePacketFromBuilder(builder) + if err != nil { + return nil, fmt.Errorf("trace packet: %w", err) + } + + resp := &proto.TracePacketResponse{} + + for _, result := range trace.Results { + stage := &proto.TraceStage{ + Name: result.Stage.String(), + Message: result.Message, + Allowed: result.Allowed, + } + if result.ForwarderAction != nil { + details := fmt.Sprintf("%s to %s", result.ForwarderAction.Action, result.ForwarderAction.RemoteAddr) + stage.ForwardingDetails = &details + } + resp.Stages = append(resp.Stages, stage) + } + + if len(trace.Results) > 0 { + resp.FinalDisposition = trace.Results[len(trace.Results)-1].Allowed + } + + return resp, nil +} diff --git a/go.mod b/go.mod index 77d570662fa..3e1208e5ac3 100644 --- a/go.mod +++ b/go.mod @@ -102,6 +102,7 @@ require ( gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.7 gorm.io/gorm v1.25.12 + gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 ) require ( @@ -237,7 +238,6 @@ require ( gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect - gvisor.dev/gvisor v0.0.0-20231020174304-db3d49b921f9 // indirect k8s.io/apimachinery v0.26.2 // indirect ) @@ -245,7 +245,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 -replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 +replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 diff --git a/go.sum b/go.sum index 4b9e90eba3d..54b77dbee11 100644 --- a/go.sum +++ b/go.sum @@ -535,8 +535,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= -github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY= -github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= +github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ= +github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= @@ -1250,8 +1250,8 @@ gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY= gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= -gvisor.dev/gvisor v0.0.0-20231020174304-db3d49b921f9 h1:sCEaoA7ZmkuFwa2IR61pl4+RYZPwCJOiaSYT0k+BRf8= -gvisor.dev/gvisor v0.0.0-20231020174304-db3d49b921f9/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8= +gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs= +gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= From 58b2eb4b92254ccd4bb921d95f1662d6bd474e2e Mon Sep 17 00:00:00 2001 From: ransomware <4thel00z@gmail.com> Date: Fri, 7 Feb 2025 15:05:41 +0100 Subject: [PATCH 5/9] [signal] Fix context propagation in signal server (#3251) --- signal/server/signal.go | 101 +++++++++++++++++++++------------------- 1 file changed, 53 insertions(+), 48 deletions(-) diff --git a/signal/server/signal.go b/signal/server/signal.go index 305fd052b2e..abc1c367bc5 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -52,13 +52,13 @@ func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) { return nil, fmt.Errorf("creating app metrics: %v", err) } - dispatcher, err := dispatcher.NewDispatcher(ctx, meter) + d, err := dispatcher.NewDispatcher(ctx, meter) if err != nil { return nil, fmt.Errorf("creating dispatcher: %v", err) } s := &Server{ - dispatcher: dispatcher, + dispatcher: d, registry: peer.NewRegistry(appMetrics), metrics: appMetrics, } @@ -75,7 +75,7 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto. return &proto.EncryptedMessage{}, nil } - return s.dispatcher.SendMessage(context.Background(), msg) + return s.dispatcher.SendMessage(ctx, msg) } // ConnectStream connects to the exchange stream @@ -98,76 +98,81 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) log.Debugf("peer connected [%s] [streamID %d] ", p.Id, p.StreamID) for { - // read incoming messages - msg, err := stream.Recv() - if err == io.EOF { - break - } else if err != nil { - return err - } - - log.Debugf("Received a response from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey) - - _, err = s.dispatcher.SendMessage(stream.Context(), msg) - if err != nil { - log.Debugf("error while sending message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err) + select { + case <-stream.Context().Done(): + log.Debugf("stream closed for peer [%s] [streamID %d] due to context cancellation", p.Id, p.StreamID) + return stream.Context().Err() + default: + // read incoming messages + msg, err := stream.Recv() + if err == io.EOF { + break + } else if err != nil { + return err + } + + log.Debugf("Received a response from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey) + + _, err = s.dispatcher.SendMessage(stream.Context(), msg) + if err != nil { + log.Debugf("error while sending message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err) + } } } - - <-stream.Context().Done() - return stream.Context().Err() } func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) { log.Debugf("registering new peer") - if meta, hasMeta := metadata.FromIncomingContext(stream.Context()); hasMeta { - if id, found := meta[proto.HeaderId]; found { - p := peer.NewPeer(id[0], stream) - - s.registry.Register(p) - s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer) - - return p, nil - } else { - s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId))) - return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: "+proto.HeaderId) - } - } else { + meta, hasMeta := metadata.FromIncomingContext(stream.Context()) + if !hasMeta { s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingMeta))) return nil, status.Errorf(codes.FailedPrecondition, "missing connection stream meta") } + + id, found := meta[proto.HeaderId] + if !found { + s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId))) + return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: %s", proto.HeaderId) + } + + p := peer.NewPeer(id[0], stream) + s.registry.Register(p) + s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer) + return p, nil } func (s *Server) DeregisterPeer(p *peer.Peer) { log.Debugf("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID) s.registry.Deregister(p) - s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds())) } func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) { log.Debugf("forwarding a new message from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey) - getRegistrationStart := time.Now() // lookup the target peer where the message is going to - if dstPeer, found := s.registry.Get(msg.RemoteKey); found { - s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound))) - start := time.Now() - // forward the message to the target peer - if err := dstPeer.Stream.Send(msg); err != nil { - log.Warnf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err) - // todo respond to the sender? - s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) - } else { - // in milliseconds - s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream))) - s.metrics.MessagesForwarded.Add(ctx, 1) - } - } else { + dstPeer, found := s.registry.Get(msg.RemoteKey) + + if !found { s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound))) s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected))) log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey) // todo respond to the sender? } + + s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound))) + start := time.Now() + + // forward the message to the target peer + if err := dstPeer.Stream.Send(msg); err != nil { + log.Warnf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err) + // todo respond to the sender? + s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) + return + } + + // in milliseconds + s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream))) + s.metrics.MessagesForwarded.Add(ctx, 1) } From 5953b43ead2e0059807c3d295a2f9b324f0773b3 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 10 Feb 2025 10:32:50 +0100 Subject: [PATCH 6/9] [client, relay] Fix/wg watch (#3261) Fix WireGuard watcher related issues - Fix race handling between TURN and Relayed reconnection - Move the WgWatcher logic to separate struct - Handle timeouts in a more defensive way - Fix initial Relay client reconnection to the home server --- client/internal/peer/conn.go | 49 +++----- client/internal/peer/guard/guard.go | 36 ++---- client/internal/peer/wg_watcher.go | 154 ++++++++++++++++++++++++ client/internal/peer/wg_watcher_test.go | 98 +++++++++++++++ client/internal/peer/worker_ice.go | 17 +-- client/internal/peer/worker_relay.go | 127 ++++--------------- relay/client/client.go | 18 --- relay/client/guard.go | 71 +++++++---- relay/client/manager.go | 10 +- 9 files changed, 365 insertions(+), 215 deletions(-) create mode 100644 client/internal/peer/wg_watcher.go create mode 100644 client/internal/peer/wg_watcher_test.go diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index b8cb2582fb9..7caafa53d31 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -32,8 +32,8 @@ const ( defaultWgKeepAlive = 25 * time.Second connPriorityRelay ConnPriority = 1 - connPriorityICETurn ConnPriority = 1 - connPriorityICEP2P ConnPriority = 2 + connPriorityICETurn ConnPriority = 2 + connPriorityICEP2P ConnPriority = 3 ) type WgConfig struct { @@ -66,14 +66,6 @@ type ConnConfig struct { ICEConfig icemaker.Config } -type WorkerCallbacks struct { - OnRelayReadyCallback func(info RelayConnInfo) - OnRelayStatusChanged func(ConnStatus) - - OnICEConnReadyCallback func(ConnPriority, ICEConnInfo) - OnICEStatusChanged func(ConnStatus) -} - type Conn struct { log *log.Entry mu sync.Mutex @@ -135,21 +127,11 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu semaphore: semaphore, } - rFns := WorkerRelayCallbacks{ - OnConnReady: conn.relayConnectionIsReady, - OnDisconnected: conn.onWorkerRelayStateDisconnected, - } - - wFns := WorkerICECallbacks{ - OnConnReady: conn.iCEConnectionIsReady, - OnStatusChanged: conn.onWorkerICEStateDisconnected, - } - ctrl := isController(config) - conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager, rFns) + conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager) relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() - conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns) + conn.workerICE, err = NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally) if err != nil { return nil, err } @@ -304,7 +286,7 @@ func (conn *Conn) GetKey() string { } // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected -func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) { +func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) { conn.mu.Lock() defer conn.mu.Unlock() @@ -376,7 +358,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon } // todo review to make sense to handle connecting and disconnected status also? -func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { +func (conn *Conn) onICEStateDisconnected() { conn.mu.Lock() defer conn.mu.Unlock() @@ -384,7 +366,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { return } - conn.log.Tracef("ICE connection state changed to %s", newState) + conn.log.Tracef("ICE connection state changed to disconnected") if conn.wgProxyICE != nil { if err := conn.wgProxyICE.CloseConn(); err != nil { @@ -404,10 +386,11 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { conn.currentConnPriority = connPriorityRelay } - changed := conn.statusICE.Get() != newState && newState != StatusConnecting - conn.statusICE.Set(newState) - - conn.guard.SetICEConnDisconnected(changed) + changed := conn.statusICE.Get() != StatusDisconnected + if changed { + conn.guard.SetICEConnDisconnected() + } + conn.statusICE.Set(StatusDisconnected) peerState := State{ PubKey: conn.config.Key, @@ -422,7 +405,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { } } -func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { +func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { conn.mu.Lock() defer conn.mu.Unlock() @@ -474,7 +457,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) } -func (conn *Conn) onWorkerRelayStateDisconnected() { +func (conn *Conn) onRelayDisconnected() { conn.mu.Lock() defer conn.mu.Unlock() @@ -497,8 +480,10 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { } changed := conn.statusRelay.Get() != StatusDisconnected + if changed { + conn.guard.SetRelayedConnDisconnected() + } conn.statusRelay.Set(StatusDisconnected) - conn.guard.SetRelayedConnDisconnected(changed) peerState := State{ PubKey: conn.config.Key, diff --git a/client/internal/peer/guard/guard.go b/client/internal/peer/guard/guard.go index bf3527a6264..1fc2b4a4a90 100644 --- a/client/internal/peer/guard/guard.go +++ b/client/internal/peer/guard/guard.go @@ -29,8 +29,8 @@ type Guard struct { isConnectedOnAllWay isConnectedFunc timeout time.Duration srWatcher *SRWatcher - relayedConnDisconnected chan bool - iCEConnDisconnected chan bool + relayedConnDisconnected chan struct{} + iCEConnDisconnected chan struct{} } func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { @@ -41,8 +41,8 @@ func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, isConnectedOnAllWay: isConnectedFn, timeout: timeout, srWatcher: srWatcher, - relayedConnDisconnected: make(chan bool, 1), - iCEConnDisconnected: make(chan bool, 1), + relayedConnDisconnected: make(chan struct{}, 1), + iCEConnDisconnected: make(chan struct{}, 1), } } @@ -54,16 +54,16 @@ func (g *Guard) Start(ctx context.Context) { } } -func (g *Guard) SetRelayedConnDisconnected(changed bool) { +func (g *Guard) SetRelayedConnDisconnected() { select { - case g.relayedConnDisconnected <- changed: + case g.relayedConnDisconnected <- struct{}{}: default: } } -func (g *Guard) SetICEConnDisconnected(changed bool) { +func (g *Guard) SetICEConnDisconnected() { select { - case g.iCEConnDisconnected <- changed: + case g.iCEConnDisconnected <- struct{}{}: default: } } @@ -96,19 +96,13 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context) { g.triggerOfferSending() } - case changed := <-g.relayedConnDisconnected: - if !changed { - continue - } + case <-g.relayedConnDisconnected: g.log.Debugf("Relay connection changed, reset reconnection ticker") ticker.Stop() ticker = g.prepareExponentTicker(ctx) tickerChannel = ticker.C - case changed := <-g.iCEConnDisconnected: - if !changed { - continue - } + case <-g.iCEConnDisconnected: g.log.Debugf("ICE connection changed, reset reconnection ticker") ticker.Stop() ticker = g.prepareExponentTicker(ctx) @@ -138,16 +132,10 @@ func (g *Guard) listenForDisconnectEvents(ctx context.Context) { g.log.Infof("start listen for reconnect events...") for { select { - case changed := <-g.relayedConnDisconnected: - if !changed { - continue - } + case <-g.relayedConnDisconnected: g.log.Debugf("Relay connection changed, triggering reconnect") g.triggerOfferSending() - case changed := <-g.iCEConnDisconnected: - if !changed { - continue - } + case <-g.iCEConnDisconnected: g.log.Debugf("ICE state changed, try to send new offer") g.triggerOfferSending() case <-srReconnectedChan: diff --git a/client/internal/peer/wg_watcher.go b/client/internal/peer/wg_watcher.go new file mode 100644 index 00000000000..6670c6517e7 --- /dev/null +++ b/client/internal/peer/wg_watcher.go @@ -0,0 +1,154 @@ +package peer + +import ( + "context" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/configurer" +) + +const ( + wgHandshakePeriod = 3 * time.Minute +) + +var ( + wgHandshakeOvertime = 30 * time.Second // allowed delay in network + checkPeriod = wgHandshakePeriod + wgHandshakeOvertime +) + +type WGInterfaceStater interface { + GetStats(key string) (configurer.WGStats, error) +} + +type WGWatcher struct { + log *log.Entry + wgIfaceStater WGInterfaceStater + peerKey string + + ctx context.Context + ctxCancel context.CancelFunc + ctxLock sync.Mutex + waitGroup sync.WaitGroup +} + +func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string) *WGWatcher { + return &WGWatcher{ + log: log, + wgIfaceStater: wgIfaceStater, + peerKey: peerKey, + } +} + +// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing. +func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) { + w.log.Debugf("enable WireGuard watcher") + w.ctxLock.Lock() + defer w.ctxLock.Unlock() + + if w.ctx != nil && w.ctx.Err() == nil { + w.log.Errorf("WireGuard watcher already enabled") + return + } + + ctx, ctxCancel := context.WithCancel(parentCtx) + w.ctx = ctx + w.ctxCancel = ctxCancel + + initialHandshake, err := w.wgState() + if err != nil { + w.log.Warnf("failed to read initial wg stats: %v", err) + } + + w.waitGroup.Add(1) + go w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake) +} + +// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit +func (w *WGWatcher) DisableWgWatcher() { + w.ctxLock.Lock() + defer w.ctxLock.Unlock() + + if w.ctxCancel == nil { + return + } + + w.log.Debugf("disable WireGuard watcher") + + w.ctxCancel() + w.ctxCancel = nil + w.waitGroup.Wait() +} + +// wgStateCheck help to check the state of the WireGuard handshake and relay connection +func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) { + w.log.Infof("WireGuard watcher started") + defer w.waitGroup.Done() + + timer := time.NewTimer(wgHandshakeOvertime) + defer timer.Stop() + defer ctxCancel() + + lastHandshake := initialHandshake + + for { + select { + case <-timer.C: + handshake, ok := w.handshakeCheck(lastHandshake) + if !ok { + onDisconnectedFn() + return + } + lastHandshake = *handshake + + resetTime := time.Until(handshake.Add(checkPeriod)) + timer.Reset(resetTime) + + w.log.Debugf("WireGuard watcher reset timer: %v", resetTime) + case <-ctx.Done(): + w.log.Infof("WireGuard watcher stopped") + return + } + } +} + +// handshakeCheck checks the WireGuard handshake and return the new handshake time if it is different from the previous one +func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) { + handshake, err := w.wgState() + if err != nil { + w.log.Errorf("failed to read wg stats: %v", err) + return nil, false + } + + w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake) + + // the current know handshake did not change + if handshake.Equal(lastHandshake) { + w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake) + return nil, false + } + + // in case if the machine is suspended, the handshake time will be in the past + if handshake.Add(checkPeriod).Before(time.Now()) { + w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake) + return nil, false + } + + // error handling for handshake time in the future + if handshake.After(time.Now()) { + w.log.Warnf("WireGuard handshake is in the future, closing relay connection: %v", handshake) + return nil, false + } + + return &handshake, true +} + +func (w *WGWatcher) wgState() (time.Time, error) { + wgState, err := w.wgIfaceStater.GetStats(w.peerKey) + if err != nil { + return time.Time{}, err + } + return wgState.LastHandshake, nil +} diff --git a/client/internal/peer/wg_watcher_test.go b/client/internal/peer/wg_watcher_test.go new file mode 100644 index 00000000000..a5b9026adb1 --- /dev/null +++ b/client/internal/peer/wg_watcher_test.go @@ -0,0 +1,98 @@ +package peer + +import ( + "context" + "testing" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/configurer" +) + +type MocWgIface struct { + initial bool + lastHandshake time.Time + stop bool +} + +func (m *MocWgIface) GetStats(key string) (configurer.WGStats, error) { + if !m.initial { + m.initial = true + return configurer.WGStats{}, nil + } + + if !m.stop { + m.lastHandshake = time.Now() + } + + stats := configurer.WGStats{ + LastHandshake: m.lastHandshake, + } + + return stats, nil +} + +func (m *MocWgIface) disconnect() { + m.stop = true +} + +func TestWGWatcher_EnableWgWatcher(t *testing.T) { + checkPeriod = 5 * time.Second + wgHandshakeOvertime = 1 * time.Second + + mlog := log.WithField("peer", "tet") + mocWgIface := &MocWgIface{} + watcher := NewWGWatcher(mlog, mocWgIface, "") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + onDisconnected := make(chan struct{}, 1) + watcher.EnableWgWatcher(ctx, func() { + mlog.Infof("onDisconnectedFn") + onDisconnected <- struct{}{} + }) + + // wait for initial reading + time.Sleep(2 * time.Second) + mocWgIface.disconnect() + + select { + case <-onDisconnected: + case <-time.After(10 * time.Second): + t.Errorf("timeout") + } + watcher.DisableWgWatcher() +} + +func TestWGWatcher_ReEnable(t *testing.T) { + checkPeriod = 5 * time.Second + wgHandshakeOvertime = 1 * time.Second + + mlog := log.WithField("peer", "tet") + mocWgIface := &MocWgIface{} + watcher := NewWGWatcher(mlog, mocWgIface, "") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + onDisconnected := make(chan struct{}, 1) + + watcher.EnableWgWatcher(ctx, func() {}) + watcher.DisableWgWatcher() + + watcher.EnableWgWatcher(ctx, func() { + onDisconnected <- struct{}{} + }) + + time.Sleep(2 * time.Second) + mocWgIface.disconnect() + + select { + case <-onDisconnected: + case <-time.After(10 * time.Second): + t.Errorf("timeout") + } + watcher.DisableWgWatcher() +} diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 00831849295..7dd84a98e56 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -31,20 +31,15 @@ type ICEConnInfo struct { RelayedOnLocal bool } -type WorkerICECallbacks struct { - OnConnReady func(ConnPriority, ICEConnInfo) - OnStatusChanged func(ConnStatus) -} - type WorkerICE struct { ctx context.Context log *log.Entry config ConnConfig + conn *Conn signaler *Signaler iFaceDiscover stdnet.ExternalIFaceDiscover statusRecorder *Status hasRelayOnLocally bool - conn WorkerICECallbacks agent *ice.Agent muxAgent sync.Mutex @@ -60,16 +55,16 @@ type WorkerICE struct { lastKnownState ice.ConnectionState } -func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) { +func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) { w := &WorkerICE{ ctx: ctx, log: log, config: config, + conn: conn, signaler: signaler, iFaceDiscover: ifaceDiscover, statusRecorder: statusRecorder, hasRelayOnLocally: hasRelayOnLocally, - conn: callBacks, } localUfrag, localPwd, err := icemaker.GenerateICECredentials() @@ -154,8 +149,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { Relayed: isRelayed(pair), RelayedOnLocal: isRelayCandidate(pair.Local), } - w.log.Debugf("on ICE conn read to use ready") - go w.conn.OnConnReady(selectedPriority(pair), ci) + w.log.Debugf("on ICE conn is ready to use") + go w.conn.onICEConnectionIsReady(selectedPriority(pair), ci) } // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. @@ -220,7 +215,7 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected: if w.lastKnownState != ice.ConnectionStateDisconnected { w.lastKnownState = ice.ConnectionStateDisconnected - w.conn.OnStatusChanged(StatusDisconnected) + w.conn.onICEStateDisconnected() } w.closeAgent(agentCancel) default: diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index c22dcdeda5d..56c19cd1e3c 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -6,52 +6,41 @@ import ( "net" "sync" "sync/atomic" - "time" log "github.com/sirupsen/logrus" relayClient "github.com/netbirdio/netbird/relay/client" ) -var ( - wgHandshakePeriod = 3 * time.Minute - wgHandshakeOvertime = 30 * time.Second -) - type RelayConnInfo struct { relayedConn net.Conn rosenpassPubKey []byte rosenpassAddr string } -type WorkerRelayCallbacks struct { - OnConnReady func(RelayConnInfo) - OnDisconnected func() -} - type WorkerRelay struct { log *log.Entry isController bool config ConnConfig + conn *Conn relayManager relayClient.ManagerService - callBacks WorkerRelayCallbacks - relayedConn net.Conn - relayLock sync.Mutex - ctxWgWatch context.Context - ctxCancelWgWatch context.CancelFunc - ctxLock sync.Mutex + relayedConn net.Conn + relayLock sync.Mutex relaySupportedOnRemotePeer atomic.Bool + + wgWatcher *WGWatcher } -func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay { +func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService) *WorkerRelay { r := &WorkerRelay{ log: log, isController: ctrl, config: config, + conn: conn, relayManager: relayManager, - callBacks: callbacks, + wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key), } return r } @@ -87,7 +76,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.relayedConn = relayedConn w.relayLock.Unlock() - err = w.relayManager.AddCloseListener(srv, w.onRelayMGDisconnected) + err = w.relayManager.AddCloseListener(srv, w.onRelayClientDisconnected) if err != nil { log.Errorf("failed to add close listener: %s", err) _ = relayedConn.Close() @@ -95,7 +84,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { } w.log.Debugf("peer conn opened via Relay: %s", srv) - go w.callBacks.OnConnReady(RelayConnInfo{ + go w.conn.onRelayConnectionIsReady(RelayConnInfo{ relayedConn: relayedConn, rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey, rosenpassAddr: remoteOfferAnswer.RosenpassAddr, @@ -103,32 +92,11 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { } func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) { - w.log.Debugf("enable WireGuard watcher") - w.ctxLock.Lock() - defer w.ctxLock.Unlock() - - if w.ctxWgWatch != nil && w.ctxWgWatch.Err() == nil { - return - } - - ctx, ctxCancel := context.WithCancel(ctx) - w.ctxWgWatch = ctx - w.ctxCancelWgWatch = ctxCancel - - w.wgStateCheck(ctx, ctxCancel) + w.wgWatcher.EnableWgWatcher(ctx, w.onWGDisconnected) } func (w *WorkerRelay) DisableWgWatcher() { - w.ctxLock.Lock() - defer w.ctxLock.Unlock() - - if w.ctxCancelWgWatch == nil { - return - } - - w.log.Debugf("disable WireGuard watcher") - - w.ctxCancelWgWatch() + w.wgWatcher.DisableWgWatcher() } func (w *WorkerRelay) RelayInstanceAddress() (string, error) { @@ -150,57 +118,17 @@ func (w *WorkerRelay) CloseConn() { return } - err := w.relayedConn.Close() - if err != nil { + if err := w.relayedConn.Close(); err != nil { w.log.Warnf("failed to close relay connection: %v", err) } } -// wgStateCheck help to check the state of the WireGuard handshake and relay connection -func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc) { - w.log.Debugf("WireGuard watcher started") - lastHandshake, err := w.wgState() - if err != nil { - w.log.Warnf("failed to read wg stats: %v", err) - lastHandshake = time.Time{} - } - - go func(lastHandshake time.Time) { - timer := time.NewTimer(wgHandshakeOvertime) - defer timer.Stop() - defer ctxCancel() - - for { - select { - case <-timer.C: - handshake, err := w.wgState() - if err != nil { - w.log.Errorf("failed to read wg stats: %v", err) - timer.Reset(wgHandshakeOvertime) - continue - } - - w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake) - - if handshake.Equal(lastHandshake) { - w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake) - w.relayLock.Lock() - _ = w.relayedConn.Close() - w.relayLock.Unlock() - w.callBacks.OnDisconnected() - return - } - - resetTime := time.Until(handshake.Add(wgHandshakePeriod + wgHandshakeOvertime)) - lastHandshake = handshake - timer.Reset(resetTime) - case <-ctx.Done(): - w.log.Debugf("WireGuard watcher stopped") - return - } - } - }(lastHandshake) +func (w *WorkerRelay) onWGDisconnected() { + w.relayLock.Lock() + _ = w.relayedConn.Close() + w.relayLock.Unlock() + w.conn.onRelayDisconnected() } func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool { @@ -217,20 +145,7 @@ func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress st return remoteRelayAddress } -func (w *WorkerRelay) wgState() (time.Time, error) { - wgState, err := w.config.WgConfig.WgInterface.GetStats(w.config.Key) - if err != nil { - return time.Time{}, err - } - return wgState.LastHandshake, nil -} - -func (w *WorkerRelay) onRelayMGDisconnected() { - w.ctxLock.Lock() - defer w.ctxLock.Unlock() - - if w.ctxCancelWgWatch != nil { - w.ctxCancelWgWatch() - } - go w.callBacks.OnDisconnected() +func (w *WorkerRelay) onRelayClientDisconnected() { + w.wgWatcher.DisableWgWatcher() + go w.conn.onRelayDisconnected() } diff --git a/relay/client/client.go b/relay/client/client.go index 3c23b70d27d..9e7e54393d4 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -141,7 +141,6 @@ type Client struct { muInstanceURL sync.Mutex onDisconnectListener func(string) - onConnectedListener func() listenerMutex sync.Mutex } @@ -190,7 +189,6 @@ func (c *Client) Connect() error { c.wgReadLoop.Add(1) go c.readLoop(c.relayConn) - go c.notifyConnected() return nil } @@ -238,12 +236,6 @@ func (c *Client) SetOnDisconnectListener(fn func(string)) { c.onDisconnectListener = fn } -func (c *Client) SetOnConnectedListener(fn func()) { - c.listenerMutex.Lock() - defer c.listenerMutex.Unlock() - c.onConnectedListener = fn -} - // HasConns returns true if there are connections. func (c *Client) HasConns() bool { c.mu.Lock() @@ -559,16 +551,6 @@ func (c *Client) notifyDisconnected() { go c.onDisconnectListener(c.connectionURL) } -func (c *Client) notifyConnected() { - c.listenerMutex.Lock() - defer c.listenerMutex.Unlock() - - if c.onConnectedListener == nil { - return - } - go c.onConnectedListener() -} - func (c *Client) writeCloseMsg() { msg := messages.MarshalCloseMsg() _, err := c.relayConn.Write(msg) diff --git a/relay/client/guard.go b/relay/client/guard.go index b971363a878..554330ea318 100644 --- a/relay/client/guard.go +++ b/relay/client/guard.go @@ -14,8 +14,9 @@ var ( // Guard manage the reconnection tries to the Relay server in case of disconnection event. type Guard struct { - // OnNewRelayClient is a channel that is used to notify the relay client about a new relay client instance. + // OnNewRelayClient is a channel that is used to notify the relay manager about a new relay client instance. OnNewRelayClient chan *Client + OnReconnected chan struct{} serverPicker *ServerPicker } @@ -23,6 +24,7 @@ type Guard struct { func NewGuard(sp *ServerPicker) *Guard { g := &Guard{ OnNewRelayClient: make(chan *Client, 1), + OnReconnected: make(chan struct{}, 1), serverPicker: sp, } return g @@ -39,14 +41,13 @@ func NewGuard(sp *ServerPicker) *Guard { // - relayClient: The relay client instance that was disconnected. // todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) { - if relayClient == nil { - goto RETRY - } - if g.isServerURLStillValid(relayClient) && g.quickReconnect(ctx, relayClient) { + // try to reconnect to the same server + if ok := g.tryToQuickReconnect(ctx, relayClient); ok { + g.notifyReconnected() return } -RETRY: + // start a ticker to pick a new server ticker := exponentTicker(ctx) defer ticker.Stop() @@ -64,28 +65,19 @@ RETRY: } } -func (g *Guard) retry(ctx context.Context) error { - log.Infof("try to pick up a new Relay server") - relayClient, err := g.serverPicker.PickServer(ctx) - if err != nil { - return err +func (g *Guard) tryToQuickReconnect(parentCtx context.Context, rc *Client) bool { + if rc == nil { + return false } - // prevent to work with a deprecated Relay client instance - g.drainRelayClientChan() - - g.OnNewRelayClient <- relayClient - return nil -} - -func (g *Guard) quickReconnect(parentCtx context.Context, rc *Client) bool { - ctx, cancel := context.WithTimeout(parentCtx, 1500*time.Millisecond) - defer cancel() - <-ctx.Done() + if !g.isServerURLStillValid(rc) { + return false + } - if parentCtx.Err() != nil { + if cancelled := waiteBeforeRetry(parentCtx); !cancelled { return false } + log.Infof("try to reconnect to Relay server: %s", rc.connectionURL) if err := rc.Connect(); err != nil { @@ -95,6 +87,20 @@ func (g *Guard) quickReconnect(parentCtx context.Context, rc *Client) bool { return true } +func (g *Guard) retry(ctx context.Context) error { + log.Infof("try to pick up a new Relay server") + relayClient, err := g.serverPicker.PickServer(ctx) + if err != nil { + return err + } + + // prevent to work with a deprecated Relay client instance + g.drainRelayClientChan() + + g.OnNewRelayClient <- relayClient + return nil +} + func (g *Guard) drainRelayClientChan() { select { case <-g.OnNewRelayClient: @@ -111,6 +117,13 @@ func (g *Guard) isServerURLStillValid(rc *Client) bool { return false } +func (g *Guard) notifyReconnected() { + select { + case g.OnReconnected <- struct{}{}: + default: + } +} + func exponentTicker(ctx context.Context) *backoff.Ticker { bo := backoff.WithContext(&backoff.ExponentialBackOff{ InitialInterval: 2 * time.Second, @@ -121,3 +134,15 @@ func exponentTicker(ctx context.Context) *backoff.Ticker { return backoff.NewTicker(bo) } + +func waiteBeforeRetry(ctx context.Context) bool { + timer := time.NewTimer(1500 * time.Millisecond) + defer timer.Stop() + + select { + case <-timer.C: + return true + case <-ctx.Done(): + return false + } +} diff --git a/relay/client/manager.go b/relay/client/manager.go index d847bb879f1..26b11305058 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -165,6 +165,9 @@ func (m *Manager) Ready() bool { } func (m *Manager) SetOnReconnectedListener(f func()) { + m.listenerLock.Lock() + defer m.listenerLock.Unlock() + m.onReconnectedListenerFn = f } @@ -284,6 +287,9 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { } func (m *Manager) onServerConnected() { + m.listenerLock.Lock() + defer m.listenerLock.Unlock() + if m.onReconnectedListenerFn == nil { return } @@ -304,8 +310,11 @@ func (m *Manager) onServerDisconnected(serverAddress string) { func (m *Manager) listenGuardEvent(ctx context.Context) { for { select { + case <-m.reconnectGuard.OnReconnected: + m.onServerConnected() case rc := <-m.reconnectGuard.OnNewRelayClient: m.storeClient(rc) + m.onServerConnected() case <-ctx.Done(): return } @@ -317,7 +326,6 @@ func (m *Manager) storeClient(client *Client) { defer m.relayClientMu.Unlock() m.relayClient = client - m.relayClient.SetOnConnectedListener(m.onServerConnected) m.relayClient.SetOnDisconnectListener(m.onServerDisconnected) } From 488b697479e95db88bca281022e965af192527ed Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 10 Feb 2025 18:13:34 +0100 Subject: [PATCH 7/9] [client] Support dns upstream failover for nameserver groups with same match domain (#3178) --- client/internal/dns/handler_chain.go | 24 +- client/internal/dns/handler_chain_test.go | 26 +- client/internal/dns/local.go | 7 +- client/internal/dns/server.go | 231 ++++++---- client/internal/dns/server_test.go | 525 +++++++++++++++++++++- client/internal/dns/upstream.go | 59 ++- client/internal/dns/upstream_android.go | 3 +- client/internal/dns/upstream_general.go | 3 +- client/internal/dns/upstream_ios.go | 3 +- client/internal/dns/upstream_test.go | 38 +- client/internal/peer/status.go | 4 +- 11 files changed, 742 insertions(+), 181 deletions(-) diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 673f410e295..3286daabf7f 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -12,7 +12,7 @@ import ( const ( PriorityDNSRoute = 100 PriorityMatchDomain = 50 - PriorityDefault = 0 + PriorityDefault = 1 ) type SubdomainMatcher interface { @@ -26,7 +26,6 @@ type HandlerEntry struct { Pattern string OrigPattern string IsWildcard bool - StopHandler handlerWithStop MatchSubdomains bool } @@ -64,7 +63,7 @@ func (w *ResponseWriterChain) GetOrigPattern() string { } // AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority -func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) { +func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) { c.mu.Lock() defer c.mu.Unlock() @@ -78,9 +77,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority // First remove any existing handler with same pattern (case-insensitive) and priority for i := len(c.handlers) - 1; i >= 0; i-- { if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority { - if c.handlers[i].StopHandler != nil { - c.handlers[i].StopHandler.stop() - } c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) break } @@ -101,7 +97,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority Pattern: pattern, OrigPattern: origPattern, IsWildcard: isWildcard, - StopHandler: stopHandler, MatchSubdomains: matchSubdomains, } @@ -142,9 +137,6 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) { for i := len(c.handlers) - 1; i >= 0; i-- { entry := c.handlers[i] if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority { - if entry.StopHandler != nil { - entry.StopHandler.stop() - } c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) return } @@ -180,8 +172,8 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if log.IsLevelEnabled(log.TraceLevel) { log.Tracef("current handlers (%d):", len(handlers)) for _, h := range handlers { - log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d", - h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority) + log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d", + h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority) } } @@ -206,13 +198,13 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } if !matched { - log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false", - qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard) + log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false", + qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority) continue } - log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v", - qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains) + log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d", + qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority) chainWriter := &ResponseWriterChain{ ResponseWriter: w, diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index d04bfbbb396..8c66446ee60 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -21,9 +21,9 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) { dnsRouteHandler := &nbdns.MockHandler{} // Setup handlers with different priorities - chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil) - chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil) - chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil) + chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault) + chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain) + chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute) // Create test request r := new(dns.Msg) @@ -138,7 +138,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) { pattern = "*." + tt.handlerDomain[2:] } - chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil) + chain.AddHandler(pattern, handler, nbdns.PriorityDefault) r := new(dns.Msg) r.SetQuestion(tt.queryDomain, dns.TypeA) @@ -253,7 +253,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe() } - chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil) + chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority) } // Create and execute request @@ -280,9 +280,9 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) { handler3 := &nbdns.MockHandler{} // Add handlers in priority order - chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil) - chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil) - chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil) + chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute) + chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain) + chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault) // Create test request r := new(dns.Msg) @@ -416,7 +416,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { if op.action == "add" { handler := &nbdns.MockHandler{} handlers[op.priority] = handler - chain.AddHandler(op.pattern, handler, op.priority, nil) + chain.AddHandler(op.pattern, handler, op.priority) } else { chain.RemoveHandler(op.pattern, op.priority) } @@ -471,9 +471,9 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { r.SetQuestion(testQuery, dns.TypeA) // Add handlers in mixed order - chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil) - chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil) - chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil) + chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault) + chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute) + chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain) // Test 1: Initial state with all three handlers w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} @@ -653,7 +653,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) { handler = mockHandler } - chain.AddHandler(pattern, handler, h.priority, nil) + chain.AddHandler(pattern, handler, h.priority) } // Execute request diff --git a/client/internal/dns/local.go b/client/internal/dns/local.go index 9a78d4d5057..1fe88f750d9 100644 --- a/client/internal/dns/local.go +++ b/client/internal/dns/local.go @@ -29,10 +29,15 @@ func (d *localResolver) String() string { return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap)) } +// ID returns the unique handler ID +func (d *localResolver) id() handlerID { + return "local-resolver" +} + // ServeDNS handles a DNS request func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) > 0 { - log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) } replyMessage := &dns.Msg{} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 1fe913fd9c1..f714f9857bb 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -5,7 +5,6 @@ import ( "fmt" "net/netip" "runtime" - "strings" "sync" "github.com/miekg/dns" @@ -42,7 +41,12 @@ type Server interface { ProbeAvailability() } -type registeredHandlerMap map[string]handlerWithStop +type handlerID string + +type nsGroupsByDomain struct { + domain string + groups []*nbdns.NameServerGroup +} // DefaultServer dns server object type DefaultServer struct { @@ -52,7 +56,6 @@ type DefaultServer struct { mux sync.Mutex service service dnsMuxMap registeredHandlerMap - handlerPriorities map[string]int localResolver *localResolver wgInterface WGIface hostManager hostManager @@ -77,14 +80,17 @@ type handlerWithStop interface { dns.Handler stop() probeAvailability() + id() handlerID } -type muxUpdate struct { +type handlerWrapper struct { domain string handler handlerWithStop priority int } +type registeredHandlerMap map[handlerID]handlerWrapper + // NewDefaultServer returns a new dns server func NewDefaultServer( ctx context.Context, @@ -158,13 +164,12 @@ func newDefaultServer( ) *DefaultServer { ctx, stop := context.WithCancel(ctx) defaultServer := &DefaultServer{ - ctx: ctx, - ctxCancel: stop, - disableSys: disableSys, - service: dnsService, - handlerChain: NewHandlerChain(), - dnsMuxMap: make(registeredHandlerMap), - handlerPriorities: make(map[string]int), + ctx: ctx, + ctxCancel: stop, + disableSys: disableSys, + service: dnsService, + handlerChain: NewHandlerChain(), + dnsMuxMap: make(registeredHandlerMap), localResolver: &localResolver{ registeredMap: make(registrationMap), }, @@ -192,8 +197,7 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p log.Warn("skipping empty domain") continue } - s.handlerChain.AddHandler(domain, handler, priority, nil) - s.handlerPriorities[domain] = priority + s.handlerChain.AddHandler(domain, handler, priority) s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain) } } @@ -209,14 +213,15 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) { log.Debugf("deregistering handler %v with priority %d", domains, priority) for _, domain := range domains { + if domain == "" { + log.Warn("skipping empty domain") + continue + } + s.handlerChain.RemoveHandler(domain, priority) // Only deregister from service if no handlers remain if !s.handlerChain.HasHandlers(domain) { - if domain == "" { - log.Warn("skipping empty domain") - continue - } s.service.DeregisterMux(nbdns.NormalizeZone(domain)) } } @@ -283,14 +288,24 @@ func (s *DefaultServer) Stop() { // OnUpdatedHostDNSServer update the DNS servers addresses for root zones // It will be applied if the mgm server do not enforce DNS settings for root zone + func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { s.hostsDNSHolder.set(hostsDnsList) - _, ok := s.dnsMuxMap[nbdns.RootZone] - if ok { + // Check if there's any root handler + var hasRootHandler bool + for _, handler := range s.dnsMuxMap { + if handler.domain == nbdns.RootZone { + hasRootHandler = true + break + } + } + + if hasRootHandler { log.Debugf("on new host DNS config but skip to apply it") return } + log.Debugf("update host DNS settings: %+v", hostsDnsList) s.addHostRootZone() } @@ -364,7 +379,7 @@ func (s *DefaultServer) ProbeAvailability() { go func(mux handlerWithStop) { defer wg.Done() mux.probeAvailability() - }(mux) + }(mux.handler) } wg.Wait() } @@ -419,8 +434,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { return nil } -func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) { - var muxUpdates []muxUpdate +func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, map[string]nbdns.SimpleRecord, error) { + var muxUpdates []handlerWrapper localRecords := make(map[string]nbdns.SimpleRecord, 0) for _, customZone := range customZones { @@ -428,7 +443,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) return nil, nil, fmt.Errorf("received an empty list of records") } - muxUpdates = append(muxUpdates, muxUpdate{ + muxUpdates = append(muxUpdates, handlerWrapper{ domain: customZone.Domain, handler: s.localResolver, priority: PriorityMatchDomain, @@ -446,15 +461,59 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) return muxUpdates, localRecords, nil } -func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) { +func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) { + var muxUpdates []handlerWrapper - var muxUpdates []muxUpdate for _, nsGroup := range nameServerGroups { if len(nsGroup.NameServers) == 0 { log.Warn("received a nameserver group with empty nameserver list") continue } + if !nsGroup.Primary && len(nsGroup.Domains) == 0 { + return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list") + } + + for _, domain := range nsGroup.Domains { + if domain == "" { + return nil, fmt.Errorf("received a nameserver group with an empty domain element") + } + } + } + + groupedNS := groupNSGroupsByDomain(nameServerGroups) + + for _, domainGroup := range groupedNS { + basePriority := PriorityMatchDomain + if domainGroup.domain == nbdns.RootZone { + basePriority = PriorityDefault + } + + updates, err := s.createHandlersForDomainGroup(domainGroup, basePriority) + if err != nil { + return nil, err + } + muxUpdates = append(muxUpdates, updates...) + } + + return muxUpdates, nil +} + +func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomain, basePriority int) ([]handlerWrapper, error) { + var muxUpdates []handlerWrapper + + for i, nsGroup := range domainGroup.groups { + // Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts + priority := basePriority - i + + // Check if we're about to overlap with the next priority tier + if basePriority == PriorityMatchDomain && priority <= PriorityDefault { + log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers", + domainGroup.domain, PriorityMatchDomain-PriorityDefault) + break + } + + log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority) handler, err := newUpstreamResolver( s.ctx, s.wgInterface.Name(), @@ -462,10 +521,12 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam s.wgInterface.Address().Network, s.statusRecorder, s.hostsDNSHolder, + domainGroup.domain, ) if err != nil { - return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err) + return nil, fmt.Errorf("create upstream resolver: %v", err) } + for _, ns := range nsGroup.NameServers { if ns.NSType != nbdns.UDPNameServerType { log.Warnf("skipping nameserver %s with type %s, this peer supports only %s", @@ -489,78 +550,47 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam // after some period defined by upstream it tries to reactivate self by calling this hook // everything we need here is just to re-apply current configuration because it already // contains this upstream settings (temporal deactivation not removed it) - handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler) - - if nsGroup.Primary { - muxUpdates = append(muxUpdates, muxUpdate{ - domain: nbdns.RootZone, - handler: handler, - priority: PriorityDefault, - }) - continue - } + handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler, priority) - if len(nsGroup.Domains) == 0 { - handler.stop() - return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list") - } - - for _, domain := range nsGroup.Domains { - if domain == "" { - handler.stop() - return nil, fmt.Errorf("received a nameserver group with an empty domain element") - } - muxUpdates = append(muxUpdates, muxUpdate{ - domain: domain, - handler: handler, - priority: PriorityMatchDomain, - }) - } + muxUpdates = append(muxUpdates, handlerWrapper{ + domain: domainGroup.domain, + handler: handler, + priority: priority, + }) } return muxUpdates, nil } -func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { - muxUpdateMap := make(registeredHandlerMap) - handlersByPriority := make(map[string]int) +func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { + // this will introduce a short period of time when the server is not able to handle DNS requests + for _, existing := range s.dnsMuxMap { + s.deregisterHandler([]string{existing.domain}, existing.priority) + existing.handler.stop() + } - var isContainRootUpdate bool + muxUpdateMap := make(registeredHandlerMap) + var containsRootUpdate bool - // First register new handlers for _, update := range muxUpdates { - s.registerHandler([]string{update.domain}, update.handler, update.priority) - muxUpdateMap[update.domain] = update.handler - handlersByPriority[update.domain] = update.priority - - if existingHandler, ok := s.dnsMuxMap[update.domain]; ok { - existingHandler.stop() - } - if update.domain == nbdns.RootZone { - isContainRootUpdate = true + containsRootUpdate = true } + s.registerHandler([]string{update.domain}, update.handler, update.priority) + muxUpdateMap[update.handler.id()] = update } - // Then deregister old handlers not in the update - for key, existingHandler := range s.dnsMuxMap { - _, found := muxUpdateMap[key] - if !found { - if !isContainRootUpdate && key == nbdns.RootZone { + // If there's no root update and we had a root handler, restore it + if !containsRootUpdate { + for _, existing := range s.dnsMuxMap { + if existing.domain == nbdns.RootZone { s.addHostRootZone() - existingHandler.stop() - } else { - existingHandler.stop() - // Deregister with the priority that was used to register - if oldPriority, ok := s.handlerPriorities[key]; ok { - s.deregisterHandler([]string{key}, oldPriority) - } + break } } } s.dnsMuxMap = muxUpdateMap - s.handlerPriorities = handlersByPriority } func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) { @@ -593,6 +623,7 @@ func getNSHostPort(ns nbdns.NameServer) string { func (s *DefaultServer) upstreamCallbacks( nsGroup *nbdns.NameServerGroup, handler dns.Handler, + priority int, ) (deactivate func(error), reactivate func()) { var removeIndex map[string]int deactivate = func(err error) { @@ -609,13 +640,13 @@ func (s *DefaultServer) upstreamCallbacks( if nsGroup.Primary { removeIndex[nbdns.RootZone] = -1 s.currentConfig.RouteAll = false - s.deregisterHandler([]string{nbdns.RootZone}, PriorityDefault) + s.deregisterHandler([]string{nbdns.RootZone}, priority) } for i, item := range s.currentConfig.Domains { if _, found := removeIndex[item.Domain]; found { s.currentConfig.Domains[i].Disabled = true - s.deregisterHandler([]string{item.Domain}, PriorityMatchDomain) + s.deregisterHandler([]string{item.Domain}, priority) removeIndex[item.Domain] = i } } @@ -635,8 +666,8 @@ func (s *DefaultServer) upstreamCallbacks( } s.updateNSState(nsGroup, err, false) - } + reactivate = func() { s.mux.Lock() defer s.mux.Unlock() @@ -646,7 +677,7 @@ func (s *DefaultServer) upstreamCallbacks( continue } s.currentConfig.Domains[i].Disabled = false - s.registerHandler([]string{domain}, handler, PriorityMatchDomain) + s.registerHandler([]string{domain}, handler, priority) } l := log.WithField("nameservers", nsGroup.NameServers) @@ -654,7 +685,7 @@ func (s *DefaultServer) upstreamCallbacks( if nsGroup.Primary { s.currentConfig.RouteAll = true - s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) + s.registerHandler([]string{nbdns.RootZone}, handler, priority) } if s.hostManager != nil { @@ -676,6 +707,7 @@ func (s *DefaultServer) addHostRootZone() { s.wgInterface.Address().Network, s.statusRecorder, s.hostsDNSHolder, + nbdns.RootZone, ) if err != nil { log.Errorf("unable to create a new upstream resolver, error: %v", err) @@ -732,5 +764,34 @@ func generateGroupKey(nsGroup *nbdns.NameServerGroup) string { for _, ns := range nsGroup.NameServers { servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port)) } - return fmt.Sprintf("%s_%s_%s", nsGroup.ID, nsGroup.Name, strings.Join(servers, ",")) + return fmt.Sprintf("%v_%v", servers, nsGroup.Domains) +} + +// groupNSGroupsByDomain groups nameserver groups by their match domains +func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain { + domainMap := make(map[string][]*nbdns.NameServerGroup) + + for _, group := range nsGroups { + if group.Primary { + domainMap[nbdns.RootZone] = append(domainMap[nbdns.RootZone], group) + continue + } + + for _, domain := range group.Domains { + if domain == "" { + continue + } + domainMap[domain] = append(domainMap[domain], group) + } + } + + var result []nsGroupsByDomain + for domain, groups := range domainMap { + result = append(result, nsGroupsByDomain{ + domain: domain, + groups: groups, + }) + } + + return result } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 14ff1bb713e..db49f96a2cf 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -13,6 +13,7 @@ import ( "github.com/golang/mock/gomock" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -88,6 +89,18 @@ func init() { formatter.SetTextFormatter(log.StandardLogger()) } +func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase { + var srvs []string + for _, srv := range servers { + srvs = append(srvs, getNSHostPort(srv)) + } + return &upstreamResolverBase{ + domain: domain, + upstreamServers: srvs, + cancel: func() {}, + } +} + func TestUpdateDNSServer(t *testing.T) { nameServers := []nbdns.NameServer{ { @@ -140,15 +153,37 @@ func TestUpdateDNSServer(t *testing.T) { }, }, }, - expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler, nbdns.RootZone: dummyHandler}, - expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, + expectedUpstreamMap: registeredHandlerMap{ + generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{ + domain: "netbird.io", + handler: dummyHandler, + priority: PriorityMatchDomain, + }, + dummyHandler.id(): handlerWrapper{ + domain: "netbird.cloud", + handler: dummyHandler, + priority: PriorityMatchDomain, + }, + generateDummyHandler(".", nameServers).id(): handlerWrapper{ + domain: nbdns.RootZone, + handler: dummyHandler, + priority: PriorityDefault, + }, + }, + expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, }, { - name: "New Config Should Succeed", - initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, - initUpstreamMap: registeredHandlerMap{buildRecordKey(zoneRecords[0].Name, 1, 1): dummyHandler}, - initSerial: 0, - inputSerial: 1, + name: "New Config Should Succeed", + initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, + initUpstreamMap: registeredHandlerMap{ + generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{ + domain: buildRecordKey(zoneRecords[0].Name, 1, 1), + handler: dummyHandler, + priority: PriorityMatchDomain, + }, + }, + initSerial: 0, + inputSerial: 1, inputUpdate: nbdns.Config{ ServiceEnable: true, CustomZones: []nbdns.CustomZone{ @@ -164,8 +199,19 @@ func TestUpdateDNSServer(t *testing.T) { }, }, }, - expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler}, - expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, + expectedUpstreamMap: registeredHandlerMap{ + generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{ + domain: "netbird.io", + handler: dummyHandler, + priority: PriorityMatchDomain, + }, + "local-resolver": handlerWrapper{ + domain: "netbird.cloud", + handler: dummyHandler, + priority: PriorityMatchDomain, + }, + }, + expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, }, { name: "Smaller Config Serial Should Be Skipped", @@ -242,9 +288,15 @@ func TestUpdateDNSServer(t *testing.T) { shouldFail: true, }, { - name: "Empty Config Should Succeed and Clean Maps", - initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, - initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler}, + name: "Empty Config Should Succeed and Clean Maps", + initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, + initUpstreamMap: registeredHandlerMap{ + generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{ + domain: zoneRecords[0].Name, + handler: dummyHandler, + priority: PriorityMatchDomain, + }, + }, initSerial: 0, inputSerial: 1, inputUpdate: nbdns.Config{ServiceEnable: true}, @@ -252,9 +304,15 @@ func TestUpdateDNSServer(t *testing.T) { expectedLocalMap: make(registrationMap), }, { - name: "Disabled Service Should clean map", - initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, - initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler}, + name: "Disabled Service Should clean map", + initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, + initUpstreamMap: registeredHandlerMap{ + generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{ + domain: zoneRecords[0].Name, + handler: dummyHandler, + priority: PriorityMatchDomain, + }, + }, initSerial: 0, inputSerial: 1, inputUpdate: nbdns.Config{ServiceEnable: false}, @@ -421,7 +479,13 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { } }() - dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}} + dnsServer.dnsMuxMap = registeredHandlerMap{ + "id1": handlerWrapper{ + domain: zoneRecords[0].Name, + handler: &localResolver{}, + priority: PriorityMatchDomain, + }, + } dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}} dnsServer.updateSerial = 0 @@ -562,9 +626,8 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { localResolver: &localResolver{ registeredMap: make(registrationMap), }, - handlerChain: NewHandlerChain(), - handlerPriorities: make(map[string]int), - hostManager: hostManager, + handlerChain: NewHandlerChain(), + hostManager: hostManager, currentConfig: HostDNSConfig{ Domains: []DomainConfig{ {false, "domain0", false}, @@ -593,7 +656,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { NameServers: []nbdns.NameServer{ {IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53}, }, - }, nil) + }, nil, 0) deactivate(nil) expected := "domain0,domain2" @@ -903,8 +966,8 @@ func TestHandlerChain_DomainPriorities(t *testing.T) { Subdomains: true, } - chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil) - chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil) + chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute) + chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain) testCases := []struct { name string @@ -959,3 +1022,421 @@ func TestHandlerChain_DomainPriorities(t *testing.T) { }) } } + +type mockHandler struct { + Id string +} + +func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {} +func (m *mockHandler) stop() {} +func (m *mockHandler) probeAvailability() {} +func (m *mockHandler) id() handlerID { return handlerID(m.Id) } + +type mockService struct{} + +func (m *mockService) Listen() error { return nil } +func (m *mockService) Stop() {} +func (m *mockService) RuntimeIP() string { return "127.0.0.1" } +func (m *mockService) RuntimePort() int { return 53 } +func (m *mockService) RegisterMux(string, dns.Handler) {} +func (m *mockService) DeregisterMux(string) {} + +func TestDefaultServer_UpdateMux(t *testing.T) { + baseMatchHandlers := registeredHandlerMap{ + "upstream-group1": { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group1", + }, + priority: PriorityMatchDomain, + }, + "upstream-group2": { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group2", + }, + priority: PriorityMatchDomain - 1, + }, + } + + baseRootHandlers := registeredHandlerMap{ + "upstream-root1": { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root1", + }, + priority: PriorityDefault, + }, + "upstream-root2": { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root2", + }, + priority: PriorityDefault - 1, + }, + } + + baseMixedHandlers := registeredHandlerMap{ + "upstream-group1": { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group1", + }, + priority: PriorityMatchDomain, + }, + "upstream-group2": { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group2", + }, + priority: PriorityMatchDomain - 1, + }, + "upstream-other": { + domain: "other.com", + handler: &mockHandler{ + Id: "upstream-other", + }, + priority: PriorityMatchDomain, + }, + } + + tests := []struct { + name string + initialHandlers registeredHandlerMap + updates []handlerWrapper + expectedHandlers map[string]string // map[handlerID]domain + description string + }{ + { + name: "Remove group1 from update", + initialHandlers: baseMatchHandlers, + updates: []handlerWrapper{ + // Only group2 remains + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group2", + }, + priority: PriorityMatchDomain - 1, + }, + }, + expectedHandlers: map[string]string{ + "upstream-group2": "example.com", + }, + description: "When group1 is not included in the update, it should be removed while group2 remains", + }, + { + name: "Remove group2 from update", + initialHandlers: baseMatchHandlers, + updates: []handlerWrapper{ + // Only group1 remains + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group1", + }, + priority: PriorityMatchDomain, + }, + }, + expectedHandlers: map[string]string{ + "upstream-group1": "example.com", + }, + description: "When group2 is not included in the update, it should be removed while group1 remains", + }, + { + name: "Add group3 in first position", + initialHandlers: baseMatchHandlers, + updates: []handlerWrapper{ + // Add group3 with highest priority + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group3", + }, + priority: PriorityMatchDomain + 1, + }, + // Keep existing groups with their original priorities + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group1", + }, + priority: PriorityMatchDomain, + }, + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group2", + }, + priority: PriorityMatchDomain - 1, + }, + }, + expectedHandlers: map[string]string{ + "upstream-group1": "example.com", + "upstream-group2": "example.com", + "upstream-group3": "example.com", + }, + description: "When adding group3 with highest priority, it should be first in chain while maintaining existing groups", + }, + { + name: "Add group3 in last position", + initialHandlers: baseMatchHandlers, + updates: []handlerWrapper{ + // Keep existing groups with their original priorities + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group1", + }, + priority: PriorityMatchDomain, + }, + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group2", + }, + priority: PriorityMatchDomain - 1, + }, + // Add group3 with lowest priority + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group3", + }, + priority: PriorityMatchDomain - 2, + }, + }, + expectedHandlers: map[string]string{ + "upstream-group1": "example.com", + "upstream-group2": "example.com", + "upstream-group3": "example.com", + }, + description: "When adding group3 with lowest priority, it should be last in chain while maintaining existing groups", + }, + // Root zone tests + { + name: "Remove root1 from update", + initialHandlers: baseRootHandlers, + updates: []handlerWrapper{ + { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root2", + }, + priority: PriorityDefault - 1, + }, + }, + expectedHandlers: map[string]string{ + "upstream-root2": ".", + }, + description: "When root1 is not included in the update, it should be removed while root2 remains", + }, + { + name: "Remove root2 from update", + initialHandlers: baseRootHandlers, + updates: []handlerWrapper{ + { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root1", + }, + priority: PriorityDefault, + }, + }, + expectedHandlers: map[string]string{ + "upstream-root1": ".", + }, + description: "When root2 is not included in the update, it should be removed while root1 remains", + }, + { + name: "Add root3 in first position", + initialHandlers: baseRootHandlers, + updates: []handlerWrapper{ + { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root3", + }, + priority: PriorityDefault + 1, + }, + { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root1", + }, + priority: PriorityDefault, + }, + { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root2", + }, + priority: PriorityDefault - 1, + }, + }, + expectedHandlers: map[string]string{ + "upstream-root1": ".", + "upstream-root2": ".", + "upstream-root3": ".", + }, + description: "When adding root3 with highest priority, it should be first in chain while maintaining existing root handlers", + }, + { + name: "Add root3 in last position", + initialHandlers: baseRootHandlers, + updates: []handlerWrapper{ + { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root1", + }, + priority: PriorityDefault, + }, + { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root2", + }, + priority: PriorityDefault - 1, + }, + { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root3", + }, + priority: PriorityDefault - 2, + }, + }, + expectedHandlers: map[string]string{ + "upstream-root1": ".", + "upstream-root2": ".", + "upstream-root3": ".", + }, + description: "When adding root3 with lowest priority, it should be last in chain while maintaining existing root handlers", + }, + // Mixed domain tests + { + name: "Update with mixed domains - remove one of duplicate domain", + initialHandlers: baseMixedHandlers, + updates: []handlerWrapper{ + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group1", + }, + priority: PriorityMatchDomain, + }, + { + domain: "other.com", + handler: &mockHandler{ + Id: "upstream-other", + }, + priority: PriorityMatchDomain, + }, + }, + expectedHandlers: map[string]string{ + "upstream-group1": "example.com", + "upstream-other": "other.com", + }, + description: "When updating mixed domains, should correctly handle removal of one duplicate while maintaining other domains", + }, + { + name: "Update with mixed domains - add new domain", + initialHandlers: baseMixedHandlers, + updates: []handlerWrapper{ + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group1", + }, + priority: PriorityMatchDomain, + }, + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group2", + }, + priority: PriorityMatchDomain - 1, + }, + { + domain: "other.com", + handler: &mockHandler{ + Id: "upstream-other", + }, + priority: PriorityMatchDomain, + }, + { + domain: "new.com", + handler: &mockHandler{ + Id: "upstream-new", + }, + priority: PriorityMatchDomain, + }, + }, + expectedHandlers: map[string]string{ + "upstream-group1": "example.com", + "upstream-group2": "example.com", + "upstream-other": "other.com", + "upstream-new": "new.com", + }, + description: "When updating mixed domains, should maintain existing duplicates and add new domain", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := &DefaultServer{ + dnsMuxMap: tt.initialHandlers, + handlerChain: NewHandlerChain(), + service: &mockService{}, + } + + // Perform the update + server.updateMux(tt.updates) + + // Verify the results + assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap), + "Number of handlers after update doesn't match expected") + + // Check each expected handler + for id, expectedDomain := range tt.expectedHandlers { + handler, exists := server.dnsMuxMap[handlerID(id)] + assert.True(t, exists, "Expected handler %s not found", id) + if exists { + assert.Equal(t, expectedDomain, handler.domain, + "Domain mismatch for handler %s", id) + } + } + + // Verify no unexpected handlers exist + for handlerID := range server.dnsMuxMap { + _, expected := tt.expectedHandlers[string(handlerID)] + assert.True(t, expected, "Unexpected handler found: %s", handlerID) + } + + // Verify the handlerChain state and order + previousPriority := 0 + for _, chainEntry := range server.handlerChain.handlers { + // Verify priority order + if previousPriority > 0 { + assert.True(t, chainEntry.Priority <= previousPriority, + "Handlers in chain not properly ordered by priority") + } + previousPriority = chainEntry.Priority + + // Verify handler exists in mux + foundInMux := false + for _, muxEntry := range server.dnsMuxMap { + if chainEntry.Handler == muxEntry.handler && + chainEntry.Priority == muxEntry.priority && + chainEntry.Pattern == dns.Fqdn(muxEntry.domain) { + foundInMux = true + break + } + } + assert.True(t, foundInMux, + "Handler in chain not found in dnsMuxMap") + } + }) + } +} diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index f0aa12b6539..4c69a173d8c 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -2,9 +2,13 @@ package dns import ( "context" + "crypto/sha256" + "encoding/hex" "errors" "fmt" "net" + "slices" + "strings" "sync" "sync/atomic" "time" @@ -40,6 +44,7 @@ type upstreamResolverBase struct { cancel context.CancelFunc upstreamClient upstreamClient upstreamServers []string + domain string disabled bool failsCount atomic.Int32 successCount atomic.Int32 @@ -53,12 +58,13 @@ type upstreamResolverBase struct { statusRecorder *peer.Status } -func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *upstreamResolverBase { +func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase { ctx, cancel := context.WithCancel(ctx) return &upstreamResolverBase{ ctx: ctx, cancel: cancel, + domain: domain, upstreamTimeout: upstreamTimeout, reactivatePeriod: reactivatePeriod, failsTillDeact: failsTillDeact, @@ -71,6 +77,17 @@ func (u *upstreamResolverBase) String() string { return fmt.Sprintf("upstream %v", u.upstreamServers) } +// ID returns the unique handler ID +func (u *upstreamResolverBase) id() handlerID { + servers := slices.Clone(u.upstreamServers) + slices.Sort(servers) + + hash := sha256.New() + hash.Write([]byte(u.domain + ":")) + hash.Write([]byte(strings.Join(servers, ","))) + return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8])) +} + func (u *upstreamResolverBase) MatchSubdomains() bool { return true } @@ -87,7 +104,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { u.checkUpstreamFails(err) }() - log.WithField("question", r.Question[0]).Trace("received an upstream question") + log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) // set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records if r.Extra == nil { r.SetEdns0(4096, false) @@ -96,6 +113,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { select { case <-u.ctx.Done(): + log.Tracef("%s has been stopped", u) return default: } @@ -112,41 +130,36 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if err != nil { if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { - log.WithError(err).WithField("upstream", upstream). - Warn("got an error while connecting to upstream") + log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) continue } - u.failsCount.Add(1) - log.WithError(err).WithField("upstream", upstream). - Error("got other error while querying the upstream") - return + log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err) + continue } - if rm == nil { - log.WithError(err).WithField("upstream", upstream). - Warn("no response from upstream") - return - } - // those checks need to be independent of each other due to memory address issues - if !rm.Response { - log.WithError(err).WithField("upstream", upstream). - Warn("no response from upstream") - return + if rm == nil || !rm.Response { + log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) + continue } u.successCount.Add(1) - log.Tracef("took %s to query the upstream %s", t, upstream) + log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) - err = w.WriteMsg(rm) - if err != nil { - log.WithError(err).Error("got an error while writing the upstream resolver response") + if err = w.WriteMsg(rm); err != nil { + log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) } // count the fails only if they happen sequentially u.failsCount.Store(0) return } u.failsCount.Add(1) - log.Error("all queries to the upstream nameservers failed with timeout") + log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) + + m := new(dns.Msg) + m.SetRcode(r, dns.RcodeServerFailure) + if err := w.WriteMsg(m); err != nil { + log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err) + } } // checkUpstreamFails counts fails and disables or enables upstream resolving diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index 36ea05e4409..a9e46ca0279 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -27,8 +27,9 @@ func newUpstreamResolver( _ *net.IPNet, statusRecorder *peer.Status, hostsDNSHolder *hostsDNSHolder, + domain string, ) (*upstreamResolver, error) { - upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) c := &upstreamResolver{ upstreamResolverBase: upstreamResolverBase, hostsDNSHolder: hostsDNSHolder, diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index a29350f8ce5..51acbf7a6cc 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -23,8 +23,9 @@ func newUpstreamResolver( _ *net.IPNet, statusRecorder *peer.Status, _ *hostsDNSHolder, + domain string, ) (*upstreamResolver, error) { - upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) nonIOS := &upstreamResolver{ upstreamResolverBase: upstreamResolverBase, } diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 60ed79d8769..7d3301e14e1 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -30,8 +30,9 @@ func newUpstreamResolver( net *net.IPNet, statusRecorder *peer.Status, _ *hostsDNSHolder, + domain string, ) (*upstreamResolverIOS, error) { - upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) ios := &upstreamResolverIOS{ upstreamResolverBase: upstreamResolverBase, diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index c1251dcc1e9..c5adc085817 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -20,6 +20,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { timeout time.Duration cancelCTX bool expectedAnswer string + acceptNXDomain bool }{ { name: "Should Resolve A Record", @@ -36,11 +37,11 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { expectedAnswer: "1.1.1.1", }, { - name: "Should Not Resolve If Can't Connect To Both Servers", - inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA), - InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"}, - timeout: 200 * time.Millisecond, - responseShouldBeNil: true, + name: "Should Not Resolve If Can't Connect To Both Servers", + inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA), + InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"}, + timeout: 200 * time.Millisecond, + acceptNXDomain: true, }, { name: "Should Not Resolve If Parent Context Is Canceled", @@ -51,14 +52,11 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { responseShouldBeNil: true, }, } - // should resolve if first upstream times out - // should not write when both fails - // should not resolve if parent context is canceled for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.TODO()) - resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil) + resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".") resolver.upstreamServers = testCase.InputServers resolver.upstreamTimeout = testCase.timeout if testCase.cancelCTX { @@ -84,16 +82,22 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { t.Fatalf("should write a response message") } - foundAnswer := false - for _, answer := range responseMSG.Answer { - if strings.Contains(answer.String(), testCase.expectedAnswer) { - foundAnswer = true - break - } + if testCase.acceptNXDomain && responseMSG.Rcode == dns.RcodeNameError { + return } - if !foundAnswer { - t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer) + if testCase.expectedAnswer != "" { + foundAnswer := false + for _, answer := range responseMSG.Answer { + if strings.Contains(answer.String(), testCase.expectedAnswer) { + foundAnswer = true + break + } + } + + if !foundAnswer { + t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer) + } } }) } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 0df2a2e81d7..311ddbd7f39 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -721,7 +721,9 @@ func (d *Status) GetRelayStates() []relay.ProbeResult { func (d *Status) GetDNSStates() []NSGroupState { d.mux.Lock() defer d.mux.Unlock() - return d.nsGroupStates + + // shallow copy is good enough, as slices fields are currently not updated + return slices.Clone(d.nsGroupStates) } func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo { From 44407a158a9416e076e5c9ca615134c42c64a036 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 11 Feb 2025 12:42:04 +0100 Subject: [PATCH 8/9] [client] Fix dns handler chain test (#3307) --- client/internal/dns/handler_chain_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index 8c66446ee60..94aa987af93 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -795,7 +795,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { if op.action == "add" { handler := &nbdns.MockSubdomainHandler{Subdomains: op.subdomain} handlers[op.pattern] = handler - chain.AddHandler(op.pattern, handler, op.priority, nil) + chain.AddHandler(op.pattern, handler, op.priority) } else { chain.RemoveHandler(op.pattern, op.priority) } From 18f84f0df5e978f0bfb4435ca3d9282b8ee9ac31 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 11 Feb 2025 13:09:17 +0100 Subject: [PATCH 9/9] [client] Check for fwmark support and use fallback routing if not supported (#3220) --- client/iface/configurer/usp.go | 2 +- client/internal/connect.go | 3 + client/internal/routemanager/manager.go | 11 +- .../routemanager/systemops/systemops_linux.go | 29 +---- .../systemops/systemops_unix_test.go | 1 + util/grpc/dialer.go | 1 - util/net/env.go | 21 ++-- util/net/env_generic.go | 12 ++ util/net/env_linux.go | 119 ++++++++++++++++++ util/net/net_linux.go | 20 +-- 10 files changed, 163 insertions(+), 56 deletions(-) create mode 100644 util/net/env_generic.go create mode 100644 util/net/env_linux.go diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index 21d65ab2a5d..391269dd09e 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -362,7 +362,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string { } func getFwmark() int { - if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() { + if nbnet.AdvancedRouting() { return nbnet.NetbirdFwmark } return 0 diff --git a/client/internal/connect.go b/client/internal/connect.go index ddd10e5cdf4..a0d585ffe04 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -31,6 +31,7 @@ import ( relayClient "github.com/netbirdio/netbird/relay/client" signal "github.com/netbirdio/netbird/signal/client" "github.com/netbirdio/netbird/util" + nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/version" ) @@ -109,6 +110,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH) + nbnet.Init() + backOff := &backoff.ExponentialBackOff{ InitialInterval: time.Second, RandomizationFactor: 1, diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 34bd67893d3..9c7f1f6faea 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -113,13 +113,14 @@ func NewManager(config ManagerConfig) *DefaultManager { disableServerRoutes: config.DisableServerRoutes, } + useNoop := netstack.IsEnabled() || config.DisableClientRoutes + dm.setupRefCounters(useNoop) + // don't proceed with client routes if it is disabled if config.DisableClientRoutes { return dm } - dm.setupRefCounters() - if runtime.GOOS == "android" { cr := dm.initialClientRoutes(config.InitialRoutes) dm.notifier.SetInitialClientRoutes(cr) @@ -127,7 +128,7 @@ func NewManager(config ManagerConfig) *DefaultManager { return dm } -func (m *DefaultManager) setupRefCounters() { +func (m *DefaultManager) setupRefCounters(useNoop bool) { m.routeRefCounter = refcounter.New( func(prefix netip.Prefix, _ struct{}) (struct{}, error) { return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface()) @@ -137,7 +138,7 @@ func (m *DefaultManager) setupRefCounters() { }, ) - if netstack.IsEnabled() { + if useNoop { m.routeRefCounter = refcounter.New( func(netip.Prefix, struct{}) (struct{}, error) { return struct{}{}, refcounter.ErrIgnore @@ -449,7 +450,7 @@ func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*ro } func isRouteSupported(route *route.Route) bool { - if !nbnet.CustomRoutingDisabled() || route.IsDynamic() { + if netstack.IsEnabled() || !nbnet.CustomRoutingDisabled() || route.IsDynamic() { return true } diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 1da92cc8059..d724cb1a7ab 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -53,20 +53,6 @@ type ruleParams struct { description string } -// isLegacy determines whether to use the legacy routing setup -func isLegacy() bool { - return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark() -} - -// setIsLegacy sets the legacy routing setup -func setIsLegacy(b bool) { - if b { - os.Setenv("NB_USE_LEGACY_ROUTING", "true") - } else { - os.Unsetenv("NB_USE_LEGACY_ROUTING") - } -} - func getSetupRules() []ruleParams { return []ruleParams{ {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"}, @@ -87,7 +73,7 @@ func getSetupRules() []ruleParams { // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { - if isLegacy() { + if !nbnet.AdvancedRouting() { log.Infof("Using legacy routing setup") return r.setupRefCounter(initAddresses, stateManager) } @@ -103,11 +89,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager rules := getSetupRules() for _, rule := range rules { if err := addRule(rule); err != nil { - if errors.Is(err, syscall.EOPNOTSUPP) { - log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") - setIsLegacy(true) - return r.setupRefCounter(initAddresses, stateManager) - } return nil, nil, fmt.Errorf("%s: %w", rule.description, err) } } @@ -130,7 +111,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager // It systematically removes the three rules and any associated routing table entries to ensure a clean state. // The function uses error aggregation to report any errors encountered during the cleanup process. func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { - if isLegacy() { + if !nbnet.AdvancedRouting() { return r.cleanupRefCounter(stateManager) } @@ -168,7 +149,7 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro } func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { - if isLegacy() { + if !nbnet.AdvancedRouting() { return r.genericAddVPNRoute(prefix, intf) } @@ -191,7 +172,7 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { } func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { - if isLegacy() { + if !nbnet.AdvancedRouting() { return r.genericRemoveVPNRoute(prefix, intf) } @@ -504,7 +485,7 @@ func getAddressFamily(prefix netip.Prefix) int { } func hasSeparateRouting() ([]netip.Prefix, error) { - if isLegacy() { + if !nbnet.AdvancedRouting() { return GetRoutesFromTable() } return nil, ErrRoutingIsSeparate diff --git a/client/internal/routemanager/systemops/systemops_unix_test.go b/client/internal/routemanager/systemops/systemops_unix_test.go index a6000d963a4..d88c1ab6bca 100644 --- a/client/internal/routemanager/systemops/systemops_unix_test.go +++ b/client/internal/routemanager/systemops/systemops_unix_test.go @@ -85,6 +85,7 @@ var testCases = []testCase{ } func TestRouting(t *testing.T) { + nbnet.Init() for _, tc := range testCases { // todo resolve test execution on freebsd if runtime.GOOS == "freebsd" { diff --git a/util/grpc/dialer.go b/util/grpc/dialer.go index 83a11c65dac..f6d6d2f0456 100644 --- a/util/grpc/dialer.go +++ b/util/grpc/dialer.go @@ -40,7 +40,6 @@ func WithCustomDialer() grpc.DialOption { } } - log.Debug("Using nbnet.NewDialer()") conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) if err != nil { log.Errorf("Failed to dial: %s", err) diff --git a/util/net/env.go b/util/net/env.go index 099da39b760..32425665dea 100644 --- a/util/net/env.go +++ b/util/net/env.go @@ -2,6 +2,7 @@ package net import ( "os" + "strconv" log "github.com/sirupsen/logrus" @@ -10,20 +11,24 @@ import ( const ( envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" - envSkipSocketMark = "NB_SKIP_SOCKET_MARK" ) +// CustomRoutingDisabled returns true if custom routing is disabled. +// This will fall back to the operation mode before the exit node functionality was implemented. +// In particular exclusion routes won't be set up and all dialers and listeners will use net.Dial and net.Listen, respectively. func CustomRoutingDisabled() bool { if netstack.IsEnabled() { return true } - return os.Getenv(envDisableCustomRouting) == "true" -} -func SkipSocketMark() bool { - if skipSocketMark := os.Getenv(envSkipSocketMark); skipSocketMark == "true" { - log.Infof("%s is set to true, skipping SO_MARK", envSkipSocketMark) - return true + var customRoutingDisabled bool + if val := os.Getenv(envDisableCustomRouting); val != "" { + var err error + customRoutingDisabled, err = strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", envDisableCustomRouting, err) + } } - return false + + return customRoutingDisabled } diff --git a/util/net/env_generic.go b/util/net/env_generic.go new file mode 100644 index 00000000000..6d142a8387c --- /dev/null +++ b/util/net/env_generic.go @@ -0,0 +1,12 @@ +//go:build !linux || android + +package net + +func Init() { + // nothing to do on non-linux +} + +func AdvancedRouting() bool { + // non-linux currently doesn't support advanced routing + return false +} diff --git a/util/net/env_linux.go b/util/net/env_linux.go new file mode 100644 index 00000000000..124bf64de56 --- /dev/null +++ b/util/net/env_linux.go @@ -0,0 +1,119 @@ +//go:build linux && !android + +package net + +import ( + "errors" + "os" + "strconv" + "syscall" + "time" + + log "github.com/sirupsen/logrus" + "github.com/vishvananda/netlink" + + "github.com/netbirdio/netbird/client/iface/netstack" +) + +const ( + // these have the same effect, skip socket env supported for backward compatibility + envSkipSocketMark = "NB_SKIP_SOCKET_MARK" + envUseLegacyRouting = "NB_USE_LEGACY_ROUTING" +) + +var advancedRoutingSupported bool + +func Init() { + advancedRoutingSupported = checkAdvancedRoutingSupport() +} + +func AdvancedRouting() bool { + return advancedRoutingSupported +} + +func checkAdvancedRoutingSupport() bool { + var err error + + var legacyRouting bool + if val := os.Getenv(envUseLegacyRouting); val != "" { + legacyRouting, err = strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err) + } + } + + var skipSocketMark bool + if val := os.Getenv(envSkipSocketMark); val != "" { + skipSocketMark, err = strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", envSkipSocketMark, err) + } + } + + // requested to disable advanced routing + if legacyRouting || skipSocketMark || + // envCustomRoutingDisabled disables the custom dialers. + // There is no point in using advanced routing without those, as they set up fwmarks on the sockets. + CustomRoutingDisabled() || + // netstack mode doesn't need routing at all + netstack.IsEnabled() { + + log.Info("advanced routing has been requested to be disabled") + return false + } + + if !CheckFwmarkSupport() || !CheckRuleOperationsSupport() { + log.Warn("system doesn't support required routing features, falling back to legacy routing") + return false + } + + log.Info("system supports advanced routing") + + return true +} + +func CheckFwmarkSupport() bool { + // temporarily enable advanced routing to check fwmarks are supported + old := advancedRoutingSupported + advancedRoutingSupported = true + defer func() { + advancedRoutingSupported = old + }() + + dialer := NewDialer() + dialer.Timeout = 100 * time.Millisecond + + conn, err := dialer.Dial("udp", "127.0.0.1:9") + if err != nil { + log.Warnf("failed to dial with fwmark: %v", err) + return false + } + if err := conn.Close(); err != nil { + log.Warnf("failed to close connection: %v", err) + + } + + return true +} + +func CheckRuleOperationsSupport() bool { + rule := netlink.NewRule() + // low precedence, semi-random + rule.Priority = 32321 + rule.Table = syscall.RT_TABLE_MAIN + rule.Family = netlink.FAMILY_V4 + + if err := netlink.RuleAdd(rule); err != nil { + if errors.Is(err, syscall.EOPNOTSUPP) { + log.Warn("IP rule operations are not supported") + return false + } + log.Warnf("failed to test rule support: %v", err) + return false + } + + if err := netlink.RuleDel(rule); err != nil { + log.Warnf("failed to delete test rule: %v", err) + } + return true +} diff --git a/util/net/net_linux.go b/util/net/net_linux.go index fc486ebd496..eae483a26a9 100644 --- a/util/net/net_linux.go +++ b/util/net/net_linux.go @@ -5,13 +5,11 @@ package net import ( "fmt" "syscall" - - log "github.com/sirupsen/logrus" ) // SetSocketMark sets the SO_MARK option on the given socket connection func SetSocketMark(conn syscall.Conn) error { - if isSocketMarkDisabled() { + if !AdvancedRouting() { return nil } @@ -25,7 +23,7 @@ func SetSocketMark(conn syscall.Conn) error { // SetSocketOpt sets the SO_MARK option on the given file descriptor func SetSocketOpt(fd int) error { - if isSocketMarkDisabled() { + if !AdvancedRouting() { return nil } @@ -36,7 +34,7 @@ func setRawSocketMark(conn syscall.RawConn) error { var setErr error err := conn.Control(func(fd uintptr) { - if isSocketMarkDisabled() { + if !AdvancedRouting() { return } setErr = setSocketOptInt(int(fd)) @@ -55,15 +53,3 @@ func setRawSocketMark(conn syscall.RawConn) error { func setSocketOptInt(fd int) error { return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) } - -func isSocketMarkDisabled() bool { - if CustomRoutingDisabled() { - log.Infof("Custom routing is disabled, skipping SO_MARK") - return true - } - - if SkipSocketMark() { - return true - } - return false -}