diff --git a/internal/lookuper/config.go b/internal/lookuper/config.go index 42e4229..8b21770 100644 --- a/internal/lookuper/config.go +++ b/internal/lookuper/config.go @@ -8,7 +8,7 @@ import ( "github.com/ghodss/yaml" "github.com/pabateman/dns-lookuper/internal/printer" - "github.com/pabateman/dns-lookuper/internal/resolver/v1" + "github.com/pabateman/dns-lookuper/internal/resolver/v2" cli "github.com/urfave/cli/v2" ) @@ -78,7 +78,7 @@ var ( }, &cli.StringFlag{ Name: argMode, - Usage: fmt.Sprintf("accept one of values: '%s', '%s' or '%s'", resolver.ModeIpv4, resolver.ModeIpv6, resolver.ModeAll), + Usage: fmt.Sprintf("accept one of values: '%s' or '%s'", resolver.ModeIpv4, resolver.ModeIpv6), Aliases: []string{"m"}, EnvVars: []string{"DNS_LOOKUPER_MODE"}, Value: modeDefault, @@ -151,7 +151,6 @@ var ( } modeEnum = []string{ - resolver.ModeAll, resolver.ModeIpv4, resolver.ModeIpv6, } diff --git a/internal/lookuper/lookuper.go b/internal/lookuper/lookuper.go index a9df179..721950e 100644 --- a/internal/lookuper/lookuper.go +++ b/internal/lookuper/lookuper.go @@ -9,7 +9,7 @@ import ( "github.com/pabateman/dns-lookuper/internal/parser" "github.com/pabateman/dns-lookuper/internal/printer" - "github.com/pabateman/dns-lookuper/internal/resolver/v1" + "github.com/pabateman/dns-lookuper/internal/resolver/v2" log "github.com/sirupsen/logrus" cli "github.com/urfave/cli/v2" @@ -109,36 +109,32 @@ func performTask(t *task, s *settings) error { return fmt.Errorf("error while parsing lookup timeout: %+v", err) } - resolver := resolver.NewResolver(). + r := resolver.NewResolver(). WithMode(t.Mode). WithTimeout(lookupTimeout) - responses, err := resolver.Resolve(domainNames.ParsedNames) + responses, err := r.Resolve(domainNames.ParsedNames) if err != nil { return fmt.Errorf("error while resolving domain name: %+v", err) } - unresolvedErrors := make([]error, 0) + responsesNxdomain := resolver.FilterResponsesNxdomain(responses) - for _, r := range responses { - if r.Error != nil { - unresolvedErrors = append(unresolvedErrors, r.Error) - } - } - - if len(unresolvedErrors) > 0 { + if len(responsesNxdomain) > 0 { if s.Fail { - for _, err := range unresolvedErrors { - log.Error(err) + for _, response := range responsesNxdomain { + log.Errorf("%s: no such host", response.Name) } return fmt.Errorf("encountered errors while resolving domain names") } else { - for _, err := range unresolvedErrors { - log.Warn(err) + for _, response := range responsesNxdomain { + log.Warnf("%s: no such host", response.Name) } } } + responses = resolver.FilterResponsesNoerror(responses) + var outputFile *os.File if t.Output == "-" || t.Output == "/dev/stdout" { diff --git a/internal/printer/printer.go b/internal/printer/printer.go index 53ce3f7..bb00784 100644 --- a/internal/printer/printer.go +++ b/internal/printer/printer.go @@ -9,7 +9,7 @@ import ( "github.com/ghodss/yaml" "github.com/valyala/fasttemplate" - "github.com/pabateman/dns-lookuper/internal/resolver/v1" + "github.com/pabateman/dns-lookuper/internal/resolver/v2" ) const ( @@ -123,7 +123,7 @@ func (p *Printer) printTemplate() error { for _, address := range response.Addresses { s := t.ExecuteString(map[string]interface{}{ "host": response.Name, - "address": address.String(), + "address": address, }) if _, err := io.WriteString(p.writer, fmt.Sprintln(s)); err != nil { @@ -145,9 +145,7 @@ func (p *Printer) printList() error { addresses := make([]string, 0) for _, response := range p.entries { - for _, address := range response.Addresses { - addresses = append(addresses, address.String()) - } + addresses = append(addresses, response.Addresses...) } slices.Sort(addresses) diff --git a/internal/printer/printer_test.go b/internal/printer/printer_test.go index ca1e515..a54458f 100644 --- a/internal/printer/printer_test.go +++ b/internal/printer/printer_test.go @@ -3,12 +3,11 @@ package printer import ( "bytes" "fmt" - "net" "os" "path" "testing" - "github.com/pabateman/dns-lookuper/internal/resolver/v1" + "github.com/pabateman/dns-lookuper/internal/resolver/v2" "github.com/stretchr/testify/require" ) @@ -16,17 +15,17 @@ var ( entries = []resolver.Response{ { Name: "cloudflare.com", - Addresses: []net.IP{ - {8, 8, 8, 8}, - {1, 1, 1, 1}, + Addresses: []string{ + "8.8.8.8", + "1.1.1.1", }, }, { Name: "google.com", - Addresses: []net.IP{ - {10, 10, 10, 10}, - {123, 123, 123, 123}, - {8, 8, 8, 8}, + Addresses: []string{ + "10.10.10.10", + "123.123.123.123", + "8.8.8.8", }, }, } diff --git a/internal/resolver/v1/resolver_test.go b/internal/resolver/v1/resolver_test.go index a670a42..fe4f6b6 100644 --- a/internal/resolver/v1/resolver_test.go +++ b/internal/resolver/v1/resolver_test.go @@ -98,10 +98,10 @@ func deepCopyResponses(r []Response) []Response { func filterResponses(r []Response, f func(net.IP) bool) []Response { for i := range r { for { - indexes := slices.IndexFunc(r[i].Addresses, f) + index := slices.IndexFunc(r[i].Addresses, f) - if indexes != -1 { - r[i].Addresses = slices.Delete(r[i].Addresses, indexes, indexes+1) + if index != -1 { + r[i].Addresses = slices.Delete(r[i].Addresses, index, index+1) } else { break } diff --git a/internal/resolver/v2/resolver.go b/internal/resolver/v2/resolver.go index be7fec7..04b5879 100644 --- a/internal/resolver/v2/resolver.go +++ b/internal/resolver/v2/resolver.go @@ -1,7 +1,8 @@ package resolver import ( - "net" + "fmt" + "strings" "time" "github.com/miekg/dns" @@ -10,8 +11,7 @@ import ( const ( ModeIpv4 = "ipv4" ModeIpv6 = "ipv6" - ModeAll = "all" - ModeDefault = ModeAll + ModeDefault = ModeIpv4 ) const ( @@ -20,22 +20,24 @@ const ( type Response struct { Name string `json:"name"` - Addresses []net.IP `json:"addresses"` - Error error `json:"-"` + Addresses []string `json:"addresses"` + rcode int } type Resolver struct { resolver *dns.Client timeout time.Duration - mode string + mode uint16 } func NewResolver() *Resolver { return &Resolver{ resolver: &dns.Client{ - Timeout: TimeoutDefault, + SingleInflight: true, + Timeout: TimeoutDefault, }, timeout: TimeoutDefault, + mode: getQueryTypes(ModeDefault), } } @@ -44,6 +46,86 @@ func (r *Resolver) WithTimeout(t time.Duration) *Resolver { return r } +func (r *Resolver) WithMode(m string) *Resolver { + r.mode = getQueryTypes(m) + return r +} + func (r *Resolver) Resolve(dn []string) ([]Response, error) { - return nil, nil + result := make([]Response, 0) + config, err := dns.ClientConfigFromFile("/etc/resolv.conf") + if err != nil { + return nil, err + } + + for _, name := range dn { + result = append(result, Response{ + Name: name, + Addresses: make([]string, 0), + }, + ) + + msgQuery := &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: dns.Id(), + RecursionDesired: true, + }, + Question: []dns.Question{ + { + Name: dns.Fqdn(name), + Qtype: r.mode, + Qclass: dns.ClassINET, + }, + }, + } + + msqResponse, _, err := r.resolver.Exchange( + msgQuery, + fmt.Sprintf("%s:%s", config.Servers[0], "53"), + ) + + if err != nil { + return nil, err + } + + result[len(result)-1].rcode = msqResponse.MsgHdr.Rcode + + for _, response := range msqResponse.Answer { + result[len(result)-1].Addresses = append(result[len(result)-1].Addresses, strings.Split(response.String(), "\t")[4]) + } + + } + + return result, nil +} + +func FilterResponsesByRcode(rs []Response, rcode int) []Response { + result := make([]Response, 0) + + for _, r := range rs { + if r.rcode == rcode { + result = append(result, r) + } + } + + return result +} + +func FilterResponsesNoerror(rs []Response) []Response { + return FilterResponsesByRcode(rs, dns.StringToRcode["NOERROR"]) +} + +func FilterResponsesNxdomain(rs []Response) []Response { + return FilterResponsesByRcode(rs, dns.StringToRcode["NXDOMAIN"]) +} + +func getQueryTypes(m string) uint16 { + switch m { + case ModeIpv4: + return dns.TypeA + case ModeIpv6: + return dns.TypeAAAA + default: + return dns.TypeA + } } diff --git a/internal/resolver/v2/resolver_test.go b/internal/resolver/v2/resolver_test.go index c926e37..7b34ba0 100644 --- a/internal/resolver/v2/resolver_test.go +++ b/internal/resolver/v2/resolver_test.go @@ -1,9 +1,11 @@ package resolver import ( - "net" + "slices" "testing" + "time" + "github.com/miekg/dns" "github.com/stretchr/testify/require" ) @@ -13,30 +15,144 @@ var ( "kernel.org", } - expectedValid = []Response{ + dnOnlyIPv4 = []string{ + "fedora.com", + "hashicorp.com", + } + + dnNxdomain = []string{ + "foo.iana.org", + "buz.kernel.org", + } + + expectedValidIPv4 = []Response{ + { + Name: "iana.org", + Addresses: []string{"192.0.43.8"}, + rcode: 0, + }, + { + Name: "kernel.org", + Addresses: []string{"139.178.84.217"}, + rcode: 0, + }, + } + + expectedValidIPv6 = []Response{ + { + Name: "iana.org", + Addresses: []string{"2001:500:88:200::8"}, + rcode: 0, + }, + { + Name: "kernel.org", + Addresses: []string{"2604:1380:4641:c500::1"}, + rcode: 0, + }, + } + + expectedOnlyIPv4 = []Response{ + { + Name: "fedora.com", + Addresses: []string{"86.105.245.69"}, + rcode: 0, + }, + { + Name: "hashicorp.com", + Addresses: []string{"76.76.21.21"}, + rcode: 0, + }, + } + + expectedOnlyIPv4Empty = []Response{ { - Name: "iana.org", - Addresses: []net.IP{ - {192, 0, 43, 8}, - {32, 1, 5, 0, 0, 136, 2, 0, 0, 0, 0, 0, 0, 0, 0, 8}, - }, - Error: nil, + Name: "fedora.com", + Addresses: []string{}, + rcode: 0, }, { - Name: "kernel.org", - Addresses: []net.IP{ - {139, 178, 84, 217}, - {38, 4, 19, 128, 70, 65, 197, 0, 0, 0, 0, 0, 0, 0, 0, 1}, - }, - Error: nil, + Name: "hashicorp.com", + Addresses: []string{}, + rcode: 0, + }, + } + + expectedNxdomain = []Response{ + { + Name: "foo.iana.org", + Addresses: []string{}, + rcode: 3, + }, + { + Name: "buz.kernel.org", + Addresses: []string{}, + rcode: 3, }, } ) func TestBasicResolver(t *testing.T) { r := NewResolver() - responsesValid, err := r.Resolve(dnValid) + response, err := r.Resolve(dnValid) require.Nil(t, err) - require.Equal(t, expectedValid, responsesValid) + require.Equal(t, expectedValidIPv4, response) + + r.WithMode(ModeIpv6) + response, err = r.Resolve(dnValid) + require.Nil(t, err) + + require.Equal(t, expectedValidIPv6, response) + + r.WithMode(ModeIpv4) + response, err = r.Resolve(dnValid) + require.Nil(t, err) + + require.Equal(t, expectedValidIPv4, response) + + r.WithMode("foobarbuzz") + require.Equal(t, r.mode, dns.TypeA) +} + +func TestOnlyIPv4(t *testing.T) { + r := NewResolver() + + response, err := r.Resolve(dnOnlyIPv4) + require.Nil(t, err) + + require.Equal(t, expectedOnlyIPv4, response) + + r.WithMode(ModeIpv6) + responseEmpty, err := r.Resolve(dnOnlyIPv4) + require.Nil(t, err) + + require.Equal(t, expectedOnlyIPv4Empty, responseEmpty) +} + +func TestNxdomain(t *testing.T) { + r := NewResolver() + + responseNxdomain, err := r.Resolve(dnNxdomain) + require.Nil(t, err) + + require.Equal(t, expectedNxdomain, responseNxdomain) + + responseValid, err := r.Resolve(dnValid) + require.Nil(t, err) + + responseTotal := slices.Concat(responseNxdomain, responseValid) + + responseTotal = FilterResponsesNoerror(responseTotal) + require.Equal(t, expectedValidIPv4, responseTotal) + + responseTotal = FilterResponsesNoerror(responseTotal) + require.Equal(t, expectedValidIPv4, responseTotal) +} + +func TestTimeout(t *testing.T) { + r := NewResolver().WithTimeout(time.Microsecond * 10) + + _, err := r.Resolve([]string{dnValid[0]}) + require.NotNil(t, err) + }