diff --git a/assert.go b/assert.go index 6560ba0..5011797 100644 --- a/assert.go +++ b/assert.go @@ -237,7 +237,11 @@ func formatMsgAndArgs(dflt string, msgAndArgs ...any) string { if len(msgAndArgs) == 0 { return dflt } - return fmt.Sprintf(msgAndArgs[0].(string), msgAndArgs[1:]...) + format, ok := msgAndArgs[0].(string) + if !ok { + panic("message argument to assert function must be a fmt string") + } + return fmt.Sprintf(format, msgAndArgs[1:]...) } func needlePosition(haystack, needle string) (quotedHaystack, quotedNeedle, positions string) { diff --git a/assert_test.go b/assert_test.go index 00bdfa3..20bf37a 100644 --- a/assert_test.go +++ b/assert_test.go @@ -139,6 +139,12 @@ func TestIsError(t *testing.T) { }) } +func TestInvalidFormatMsg(t *testing.T) { + Panics(t, func() { + NotZero(t, Data{}, 123) + }) +} + func TestNotIsError(t *testing.T) { assertFail(t, "SameError", func(t testing.TB) { NotIsError(t, fmt.Errorf("os error: %w", os.ErrClosed), os.ErrClosed)