diff --git a/hostutil/hostutil.go b/hostutil/hostutil.go index 5ff3856..85149e3 100644 --- a/hostutil/hostutil.go +++ b/hostutil/hostutil.go @@ -1,56 +1,9 @@ package hostutil import ( - "fmt" - "net" - "strconv" "strings" ) -// NormalizeHost takes a host input (either a domain name or hostname) and returns it in a standardized format. -// It converts the input to lowercase, removes redundant ports (e.g., 443 for HTTPS and 80 for HTTP), -// validates that the input is a valid fully qualified domain name (FQDN), and ensures that the hostname is properly formatted. -func NormalizeHost(host string) (string, error) { - host = strings.TrimSpace(strings.ToLower(host)) - - if strings.Contains(host, "://") || strings.Contains(host, "/") { - return "", fmt.Errorf("input contains URL scheme or path: %s", host) - } - - hostname, port, err := net.SplitHostPort(host) - if err != nil { - if strings.Contains(host, ":") { - return "", fmt.Errorf("failed to parse host input: %s", host) - } - hostname = host - } - - if port != "" { - if _, err := strconv.Atoi(port); err != nil { - return "", fmt.Errorf("invalid port number: %s", port) - } - } - - if net.ParseIP(hostname) == nil && !IsValidHostname(hostname) { - return "", fmt.Errorf("invalid hostname or IP address: %s", hostname) - } - - if !strings.Contains(hostname, ".") { - return "", fmt.Errorf("invalid FQDN: %s", hostname) - } - - switch port { - case "443", "80": - port = "" - } - - if port != "" { - hostname = net.JoinHostPort(hostname, port) - } - - return hostname, nil -} - // IsValidHostname checks if the given hostname is valid based on RFC 1123. func IsValidHostname(hostname string) bool { if len(hostname) == 0 || len(hostname) > 255 { diff --git a/hostutil/hostutil_test.go b/hostutil/hostutil_test.go index cddc68f..be6f03f 100644 --- a/hostutil/hostutil_test.go +++ b/hostutil/hostutil_test.go @@ -1,42 +1,50 @@ package hostutil import ( + "strings" "testing" ) -func TestNormalizeHost(t *testing.T) { +func TestIsValidHostname(t *testing.T) { tests := []struct { - input string - expected string - hasError bool + hostname string + valid bool }{ - {"example.com", "example.com", false}, - {"Example.COM", "example.com", false}, - {"example.com:8080", "example.com:8080", false}, - {"example.com:80", "example.com", false}, - {"example.com:443", "example.com", false}, - {"subdomain.example.com:443", "subdomain.example.com", false}, - {"subdomain.example.com:80", "subdomain.example.com", false}, - {" example.com ", "example.com", false}, - {"invalid_host:port", "", true}, - {"subdomain:invalidport", "", true}, - {"http://example.com", "", true}, - {"https://example.com", "", true}, - {"ftp://example.com", "", true}, - {"example.com/path", "", true}, - {"http://example.com:443", "", true}, - {"example", "", true}, - {"localhost", "", true}, + {"example.com", true}, + {"localhost", true}, + {"sub.domain.example.com", true}, + {"example", true}, + {"example123.com", true}, + {"123example.com", true}, + {"example-com", true}, + {"example.com-", false}, + {"-example.com", false}, + {"exa_mple.com", false}, + {"example..com", false}, + {"", false}, + {strings.Repeat("a", 256), false}, + {"example!.com", false}, + {"example .com", false}, + {".example.com", false}, + {"example.com.", false}, + {"ex%ample.com", false}, + {"example.com/", false}, + {"example..com", false}, + {"-example-.com", false}, + {"ex--ample.com", true}, + {"example.-com", false}, + {"example.com-", false}, + {"example-.com", false}, + {"exa*mple.com", false}, + {"example@com", false}, + {"example,com", false}, } for _, tt := range tests { - t.Run(tt.input, func(t *testing.T) { - result, err := NormalizeHost(tt.input) - if (err != nil) != tt.hasError { - t.Errorf("expected error status %v, got %v (error: %v)", tt.hasError, (err != nil), err) - } - if result != tt.expected { - t.Errorf("expected %v, got %v", tt.expected, result) + t.Run(tt.hostname, func(t *testing.T) { + result := IsValidHostname(tt.hostname) + if result != tt.valid { + t.Errorf("IsValidHostname(%q) = %v; want %v", tt.hostname, result, tt.valid) } }) } diff --git a/urlutil/urlutil.go b/urlutil/urlutil.go index 5bacb20..75e274a 100644 --- a/urlutil/urlutil.go +++ b/urlutil/urlutil.go @@ -98,27 +98,20 @@ func HasScheme(rawURL string) bool { return re.MatchString(rawURL) } -// EnsureScheme ensures a URL has a scheme. If no scheme is provided, it defaults to "http". -func EnsureScheme(rawURL string, scheme ...string) string { - if rawURL == "" { - return rawURL - } - - defaultScheme := "http" - if len(scheme) > 0 && scheme[0] != "" { - defaultScheme = scheme[0] - } - - u, err := url.Parse(rawURL) - if err != nil { - return defaultScheme + "://" + rawURL +// EnsureHTTP ensures a URL has an HTTP scheme +func EnsureHTTP(rawURL string) string { + if !HasScheme(rawURL) { + rawURL = "http://" + rawURL } + return rawURL +} - if u.Scheme == "" { - u.Scheme = defaultScheme +// EnsureHTTPS ensures a URL has an HTTPS scheme +func EnsureHTTPS(rawURL string) string { + if !HasScheme(rawURL) { + rawURL = "https://" + rawURL } - - return u.String() + return rawURL } // HasFileExtension checks if the given rawURL string has a file extension in its path diff --git a/urlutil/urlutil_test.go b/urlutil/urlutil_test.go index ee342e3..cc93ea7 100644 --- a/urlutil/urlutil_test.go +++ b/urlutil/urlutil_test.go @@ -75,23 +75,6 @@ func TestHasScheme(t *testing.T) { } } -func TestEnsureScheme(t *testing.T) { - tests := []struct { - input string - expected string - }{ - {"example.com", "http://example.com"}, - {"http://example.com", "http://example.com"}, - } - - for _, test := range tests { - result := EnsureScheme(test.input) - if result != test.expected { - t.Errorf("EnsureScheme(%s) = %s; want %s", test.input, result, test.expected) - } - } -} - func TestHasFileExtension(t *testing.T) { tests := []struct { input string @@ -228,3 +211,85 @@ func TestIsMediaExt(t *testing.T) { } } } + +func TestRemoveDefaultPort(t *testing.T) { + tests := []struct { + input string + expected string + hasError bool + }{ + {"http://example.com:80", "http://example.com", false}, + {"https://example.com:443", "https://example.com", false}, + {"http://example.com:8080", "http://example.com:8080", false}, + {"https://example.com:8443", "https://example.com:8443", false}, + {"ftp://example.com:21", "ftp://example.com", false}, + {"ftp://example.com:2121", "ftp://example.com:2121", false}, + {"http://example.com", "http://example.com", false}, + {"https://example.com", "https://example.com", false}, + {"invalid_url", "", true}, + {"example.com", "", true}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result, err := RemoveDefaultPort(tt.input) + if (err != nil) != tt.hasError { + t.Errorf("expected error status %v, got %v (error: %v)", tt.hasError, (err != nil), err) + } + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestEnsureHTTP(t *testing.T) { + tests := []struct { + rawURL string + expected string + }{ + {"example.com", "http://example.com"}, + {"http://example.com", "http://example.com"}, + {"https://example.com", "https://example.com"}, + {"example.com/path", "http://example.com/path"}, + {"localhost", "http://localhost"}, + {"http://localhost", "http://localhost"}, + {"https://localhost", "https://localhost"}, + {"192.168.0.1", "http://192.168.0.1"}, + {"http://192.168.0.1", "http://192.168.0.1"}, + } + + for _, tt := range tests { + t.Run(tt.rawURL, func(t *testing.T) { + result := EnsureHTTP(tt.rawURL) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestEnsureHTTPS(t *testing.T) { + tests := []struct { + rawURL string + expected string + }{ + {"example.com", "https://example.com"}, + {"http://example.com", "http://example.com"}, + {"https://example.com", "https://example.com"}, + {"example.com/path", "https://example.com/path"}, + {"localhost", "https://localhost"}, + {"https://localhost", "https://localhost"}, + {"192.168.0.1", "https://192.168.0.1"}, + {"https://192.168.0.1", "https://192.168.0.1"}, + } + + for _, tt := range tests { + t.Run(tt.rawURL, func(t *testing.T) { + result := EnsureHTTPS(tt.rawURL) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +}