diff --git a/assert.go b/assert.go index 70d5b7d..6560ba0 100644 --- a/assert.go +++ b/assert.go @@ -15,52 +15,62 @@ import ( "github.com/hexops/gotextdiff/myers" ) -func objectsAreEqual(expected, actual interface{}) bool { - if expected == nil || actual == nil { - return expected == actual - } +// A CompareOption modifies how object comparisons behave. +type CompareOption func() []repr.Option - exp, eok := expected.([]byte) - act, aok := actual.([]byte) - - if eok && aok { - return bytes.Equal(exp, act) +// Exclude fields of the given type from comparison. +func Exclude[T any]() CompareOption { + return func() []repr.Option { + return []repr.Option{repr.Hide[T]()} } - - return reflect.DeepEqual(expected, actual) } // Compare two values for equality and return true or false. -func Compare[T any](t testing.TB, x, y T) bool { - return objectsAreEqual(x, y) +func Compare[T any](t testing.TB, x, y T, options ...CompareOption) bool { + return objectsAreEqual(x, y, options...) +} + +func extractCompareOptions(msgAndArgs ...any) ([]any, []CompareOption) { + compareOptions := []CompareOption{} + out := []any{} + for _, arg := range msgAndArgs { + if opt, ok := arg.(CompareOption); ok { + compareOptions = append(compareOptions, opt) + } else { + out = append(out, arg) + } + } + return out, compareOptions } // Equal asserts that "expected" and "actual" are equal. // // If they are not, a diff of the Go representation of the values will be displayed. -func Equal[T any](t testing.TB, expected, actual T, msgAndArgs ...interface{}) { - if objectsAreEqual(expected, actual) { +func Equal[T any](t testing.TB, expected, actual T, msgArgsAndCompareOptions ...any) { + msgArgsAndCompareOptions, compareOptions := extractCompareOptions(msgArgsAndCompareOptions...) + if objectsAreEqual(expected, actual, compareOptions...) { return } t.Helper() - msg := formatMsgAndArgs("Expected values to be equal:", msgAndArgs...) - t.Fatalf("%s\n%s", msg, diff(expected, actual)) + msg := formatMsgAndArgs("Expected values to be equal:", msgArgsAndCompareOptions...) + t.Fatalf("%s\n%s", msg, diff(expected, actual, compareOptions...)) } // NotEqual asserts that "expected" is not equal to "actual". // // If they are equal the expected value will be displayed. -func NotEqual[T any](t testing.TB, expected, actual T, msgAndArgs ...interface{}) { - if !objectsAreEqual(expected, actual) { +func NotEqual[T any](t testing.TB, expected, actual T, msgArgsAndCompareOptions ...any) { + msgArgsAndCompareOptions, compareOptions := extractCompareOptions(msgArgsAndCompareOptions...) + if !objectsAreEqual(expected, actual, compareOptions...) { return } t.Helper() - msg := formatMsgAndArgs("Expected values to not be equal but both were:", msgAndArgs...) + msg := formatMsgAndArgs("Expected values to not be equal but both were:", msgArgsAndCompareOptions...) t.Fatalf("%s\n%s", msg, repr.String(expected, repr.Indent(" "))) } // Contains asserts that "haystack" contains "needle". -func Contains(t testing.TB, haystack string, needle string, msgAndArgs ...interface{}) { +func Contains(t testing.TB, haystack string, needle string, msgAndArgs ...any) { if strings.Contains(haystack, needle) { return } @@ -70,7 +80,7 @@ func Contains(t testing.TB, haystack string, needle string, msgAndArgs ...interf } // NotContains asserts that "haystack" does not contain "needle". -func NotContains(t testing.TB, haystack string, needle string, msgAndArgs ...interface{}) { +func NotContains(t testing.TB, haystack string, needle string, msgAndArgs ...any) { if !strings.Contains(haystack, needle) { return } @@ -81,7 +91,7 @@ func NotContains(t testing.TB, haystack string, needle string, msgAndArgs ...int } // Zero asserts that a value is its zero value. -func Zero[T any](t testing.TB, value T, msgAndArgs ...interface{}) { +func Zero[T any](t testing.TB, value T, msgAndArgs ...any) { var zero T if objectsAreEqual(value, zero) { return @@ -96,7 +106,7 @@ func Zero[T any](t testing.TB, value T, msgAndArgs ...interface{}) { } // NotZero asserts that a value is not its zero value. -func NotZero[T any](t testing.TB, value T, msgAndArgs ...interface{}) { +func NotZero[T any](t testing.TB, value T, msgAndArgs ...any) { var zero T if !objectsAreEqual(value, zero) { val := reflect.ValueOf(value) @@ -111,7 +121,7 @@ func NotZero[T any](t testing.TB, value T, msgAndArgs ...interface{}) { // EqualError asserts that either an error is non-nil and that its message is what is expected, // or that error is nil if the expected message is empty. -func EqualError(t testing.TB, err error, errString string, msgAndArgs ...interface{}) { +func EqualError(t testing.TB, err error, errString string, msgAndArgs ...any) { if err == nil && errString == "" { return } @@ -126,7 +136,7 @@ func EqualError(t testing.TB, err error, errString string, msgAndArgs ...interfa } // IsError asserts than any error in "err"'s tree matches "target". -func IsError(t testing.TB, err, target error, msgAndArgs ...interface{}) { +func IsError(t testing.TB, err, target error, msgAndArgs ...any) { if errors.Is(err, target) { return } @@ -135,7 +145,7 @@ func IsError(t testing.TB, err, target error, msgAndArgs ...interface{}) { } // NotIsError asserts than no error in "err"'s tree matches "target". -func NotIsError(t testing.TB, err, target error, msgAndArgs ...interface{}) { +func NotIsError(t testing.TB, err, target error, msgAndArgs ...any) { if !errors.Is(err, target) { return } @@ -144,7 +154,7 @@ func NotIsError(t testing.TB, err, target error, msgAndArgs ...interface{}) { } // Error asserts that an error is not nil. -func Error(t testing.TB, err error, msgAndArgs ...interface{}) { +func Error(t testing.TB, err error, msgAndArgs ...any) { if err != nil { return } @@ -153,7 +163,7 @@ func Error(t testing.TB, err error, msgAndArgs ...interface{}) { } // NoError asserts that an error is nil. -func NoError(t testing.TB, err error, msgAndArgs ...interface{}) { +func NoError(t testing.TB, err error, msgAndArgs ...any) { if err == nil { return } @@ -163,7 +173,7 @@ func NoError(t testing.TB, err error, msgAndArgs ...interface{}) { } // True asserts that an expression is true. -func True(t testing.TB, ok bool, msgAndArgs ...interface{}) { +func True(t testing.TB, ok bool, msgAndArgs ...any) { if ok { return } @@ -172,7 +182,7 @@ func True(t testing.TB, ok bool, msgAndArgs ...interface{}) { } // False asserts that an expression is false. -func False(t testing.TB, ok bool, msgAndArgs ...interface{}) { +func False(t testing.TB, ok bool, msgAndArgs ...any) { if !ok { return } @@ -181,7 +191,7 @@ func False(t testing.TB, ok bool, msgAndArgs ...interface{}) { } // Panics asserts that the given function panics. -func Panics(t testing.TB, fn func(), msgAndArgs ...interface{}) { +func Panics(t testing.TB, fn func(), msgAndArgs ...any) { t.Helper() defer func() { if recover() == nil { @@ -193,7 +203,7 @@ func Panics(t testing.TB, fn func(), msgAndArgs ...interface{}) { } // NotPanics asserts that the given function does not panic. -func NotPanics(t testing.TB, fn func(), msgAndArgs ...interface{}) { +func NotPanics(t testing.TB, fn func(), msgAndArgs ...any) { t.Helper() defer func() { if err := recover(); err != nil { @@ -204,15 +214,16 @@ func NotPanics(t testing.TB, fn func(), msgAndArgs ...interface{}) { fn() } -func diff[T any](before, after T) string { +func diff[T any](before, after T, compareOptions ...CompareOption) string { var lhss, rhss string // Special case strings so we get nice diffs. if l, ok := any(before).(string); ok { lhss = l rhss = any(after).(string) } else { - lhss = repr.String(before, repr.Indent(" ")) + "\n" - rhss = repr.String(after, repr.Indent(" ")) + "\n" + ropts := expandCompareOptions(compareOptions...) + lhss = repr.String(before, ropts...) + "\n" + rhss = repr.String(after, ropts...) + "\n" } edits := myers.ComputeEdits("a.txt", lhss, rhss) lines := strings.Split(fmt.Sprint(gotextdiff.ToUnified("expected.txt", "actual.txt", lhss, edits)), "\n") @@ -222,7 +233,7 @@ func diff[T any](before, after T) string { return strings.Join(lines[3:], "\n") } -func formatMsgAndArgs(dflt string, msgAndArgs ...interface{}) string { +func formatMsgAndArgs(dflt string, msgAndArgs ...any) string { if len(msgAndArgs) == 0 { return dflt } @@ -243,3 +254,33 @@ func needlePosition(haystack, needle string) (quotedHaystack, quotedNeedle, posi } return } + +func expandCompareOptions(options ...CompareOption) []repr.Option { + ropts := []repr.Option{repr.Indent(" ")} + for _, option := range options { + ropts = append(ropts, option()...) + } + return ropts +} + +func objectsAreEqual(expected, actual any, options ...CompareOption) bool { + if expected == nil || actual == nil { + return expected == actual + } + if exp, eok := expected.([]byte); eok { + if act, aok := actual.([]byte); aok { + return bytes.Equal(exp, act) + } + } + if exp, eok := expected.(string); eok { + if act, aok := actual.(string); aok { + return exp == act + } + } + + ropts := expandCompareOptions(options...) + expectedStr := repr.String(expected, ropts...) + actualStr := repr.String(actual, ropts...) + + return expectedStr == actualStr +} diff --git a/assert_test.go b/assert_test.go index 728ea64..00bdfa3 100644 --- a/assert_test.go +++ b/assert_test.go @@ -33,6 +33,9 @@ func TestEqual(t *testing.T) { assertFail(t, "Different numbers", func(t testing.TB) { Equal(t, 42, 43) }) + assertOk(t, "Exclude", func(t testing.TB) { + Equal(t, Data{Str: "expected", Num: 1234}, Data{Str: "expected"}, Exclude[int64]()) + }) } func TestEqualStrings(t *testing.T) { @@ -48,6 +51,9 @@ func TestNotEqual(t *testing.T) { assertFail(t, "SameValue", func(t testing.TB) { NotEqual(t, Data{"expected", 1234}, Data{"expected", 1234}) }) + assertFail(t, "Exclude", func(t testing.TB) { + NotEqual(t, Data{Str: "expected", Num: 1234}, Data{Str: "expected"}, Exclude[int64]()) + }) } func TestContains(t *testing.T) { diff --git a/go.mod b/go.mod index 45eb9de..ae81508 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,6 @@ module github.com/alecthomas/assert/v2 go 1.18 require ( - github.com/alecthomas/repr v0.2.0 + github.com/alecthomas/repr v0.3.0 github.com/hexops/gotextdiff v1.0.3 ) diff --git a/go.sum b/go.sum index 690b029..343eabe 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk= github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= +github.com/alecthomas/repr v0.3.0 h1:NeYzUPfjjlqHY4KtzgKJiWd6sVq2eNUPTi34PiFGjY8= +github.com/alecthomas/repr v0.3.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=