Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix std.mem.eql, std.mem.indexOfDiff for zero-sized types, comptime-only types, non-reflexive equality types #22102

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
123 changes: 83 additions & 40 deletions lib/std/mem.zig
Original file line number Diff line number Diff line change
Expand Up @@ -654,49 +654,67 @@ const eqlBytes_allowed = switch (builtin.zig_backend) {
else => !builtin.fuzz,
};

/// Returns true if and only if the slices have the same length and all elements
/// compare true using equality operator.
pub fn eql(comptime T: type, a: []const T, b: []const T) bool {
if (!@inComptime() and @sizeOf(T) != 0 and std.meta.hasUniqueRepresentation(T) and
eqlBytes_allowed)
{
return eqlBytes(sliceAsBytes(a), sliceAsBytes(b));
}
/// Compares two slices and returns whether they are equal.
///
/// Two slices are defined to be equal if and only if they are
/// - the same length, and
/// - all of their element pairs compare equal, according to `==`.
///
/// Runtime-available types that have a unique bit representation
/// (according to `std.meta.hasUniqueRepresentation`) may still be
/// compared, even if they do not support the `==` operator. Such
/// types are compared bitwise.
pub inline fn eql(comptime T: type, a: []const T, b: []const T) bool {
return struct {
fn impl(x: []const T, y: []const T) union { res: bool, force_comptime: T } {
// slices with different lengths are always unequal
if (x.len != y.len) return .{ .res = false };

if (a.len != b.len) return false;
if (a.len == 0 or a.ptr == b.ptr) return true;
// pointer equality optimisation disabled for floating point
// numbers, as they may compare unequal to themselves: NaN != NaN
if (@typeInfo(T) != .float and x.ptr == y.ptr) return .{ .res = true };

for (a, b) |a_elem, b_elem| {
if (a_elem != b_elem) return false;
}
return true;
if (!@inComptime() and std.meta.hasUniqueRepresentation(T) and eqlBytes_allowed)
return .{ .res = eqlBytes(sliceAsBytes(x), sliceAsBytes(y)) };

return for (x, y) |x_elem, y_elem| {
if (x_elem != y_elem) break .{ .res = false };
} else .{ .res = true };
}
}.impl(a, b).res;
}

test eql {
// bitwise unique
try testing.expect(eql(u8, "abcd", "abcd"));
try testing.expect(!eql(u8, "abcdef", "abZdef"));
try testing.expect(!eql(u8, "abcdefg", "abcdef"));

comptime {
try testing.expect(eql(type, &.{ bool, f32 }, &.{ bool, f32 }));
try testing.expect(!eql(type, &.{ bool, f32 }, &.{ f32, bool }));
try testing.expect(!eql(type, &.{ bool, f32 }, &.{bool}));
// ZSTs
try testing.expect(eql(void, &.{ {}, {}, {} }, &.{ {}, {}, {} }));
try testing.expect(!eql(void, &.{ {}, {}, {} }, &.{ {}, {}, {}, {} }));

try testing.expect(eql(comptime_int, &.{ 1, 2, 3 }, &.{ 1, 2, 3 }));
try testing.expect(!eql(comptime_int, &.{ 1, 2, 3 }, &.{ 3, 2, 1 }));
try testing.expect(!eql(comptime_int, &.{1}, &.{ 1, 2 }));
}
// padded types
try testing.expect(eql(u7, &.{ 1, 2, 3 }, &.{ 1, 2, 3 }));
try testing.expect(!eql(u7, &.{ 1, 2, 3 }, &.{ 1, 50, 3 }));

// comptime-only types
try testing.expect(eql(type, &.{ usize, []const u8, void }, &.{ usize, []const u8, void }));
try testing.expect(!eql(type, &.{ usize, []const u8, void }, &.{ usize, [*:false]bool, void }));

try testing.expect(eql(void, &.{ {}, {} }, &.{ {}, {} }));
try testing.expect(!eql(void, &.{{}}, &.{ {}, {} }));
// non-reflexive types
const floats: [4]f32 = .{ 0.0, 2.5, 3.141592, 1e10 };
const sinks: [4]f32 = .{ -0.0, 2.5, 3.141592, 1e10 };
const nans: [3]f32 = .{ 100.0, std.math.nan(f32), -100.0 };
try testing.expect(eql(f32, &floats, &sinks)); // 0.0 == -0.0
try testing.expect(!eql(f32, &nans, &nans)); // NaN != NaN
}

/// std.mem.eql heavily optimized for slices of bytes.
fn eqlBytes(a: []const u8, b: []const u8) bool {
comptime assert(eqlBytes_allowed);

if (a.len != b.len) return false;
if (a.len == 0 or a.ptr == b.ptr) return true;
if (a.len == 0) return true;

if (a.len <= 16) {
if (a.len < 4) {
Expand Down Expand Up @@ -752,23 +770,48 @@ fn eqlBytes(a: []const u8, b: []const u8) bool {
return !Scan.isNotEqual(last_a_chunk, last_b_chunk);
}

/// Compares two slices and returns the index of the first inequality.
/// Returns null if the slices are equal.
pub fn indexOfDiff(comptime T: type, a: []const T, b: []const T) ?usize {
const shortest = @min(a.len, b.len);
if (a.ptr == b.ptr)
return if (a.len == b.len) null else shortest;
var index: usize = 0;
while (index < shortest) : (index += 1) if (a[index] != b[index]) return index;
return if (a.len == b.len) null else shortest;
/// Compares two slices and returns the index of the first inequality,
/// returns `null` if the slices are equal.
///
/// Elements are tested according to the `==` operator, if one slice is
/// a [proper prefix](https://en.wikipedia.org/wiki/Substring#Prefix) of
/// the other, the length of the former is returned.
pub inline fn indexOfDiff(comptime T: type, a: []const T, b: []const T) ?usize {
return struct {
fn impl(x: []const T, y: []const T) union { res: ?usize, force_comptime: T } {
const short = @min(x.len, y.len);

// pointer equality optimisation disabled for floating point
// numbers, as they may compare unequal to themselves: NaN != NaN
if (@typeInfo(T) != .float and x.ptr == y.ptr)
return .{ .res = if (x.len == y.len) null else short };

return for (x.ptr, y.ptr, 0..short) |x_elem, y_elem, i| {
if (x_elem != y_elem) break .{ .res = i };
} else .{ .res = if (x.len == y.len) null else short };
}
}.impl(a, b).res;
}

test indexOfDiff {
try testing.expectEqual(indexOfDiff(u8, "one", "one"), null);
try testing.expectEqual(indexOfDiff(u8, "one two", "one"), 3);
try testing.expectEqual(indexOfDiff(u8, "one", "one two"), 3);
try testing.expectEqual(indexOfDiff(u8, "one twx", "one two"), 6);
try testing.expectEqual(indexOfDiff(u8, "xne", "one"), 0);
try testing.expectEqual(null, indexOfDiff(u8, "one", "one"));
try testing.expectEqual(3, indexOfDiff(u8, "one two", "one"));
try testing.expectEqual(3, indexOfDiff(u8, "one", "one two"));
try testing.expectEqual(6, indexOfDiff(u8, "one twx", "one two"));
try testing.expectEqual(0, indexOfDiff(u8, "xne", "one"));
try testing.expectEqual(null, indexOfDiff(void, &.{ {}, {}, {} }, &.{ {}, {}, {} }));
try testing.expectEqual(3, indexOfDiff(void, &.{ {}, {}, {} }, &.{ {}, {}, {}, {} }));
try testing.expectEqual(null, indexOfDiff(u7, &.{ 1, 2, 3 }, &.{ 1, 2, 3 }));
try testing.expectEqual(1, indexOfDiff(u7, &.{ 1, 2, 3 }, &.{ 1, 50, 3 }));

try testing.expectEqual(null, indexOfDiff(type, &.{ usize, []const u8, void }, &.{ usize, []const u8, void }));
try testing.expectEqual(1, indexOfDiff(type, &.{ usize, []const u8, void }, &.{ usize, [*:false]bool, void }));

const floats: [4]f32 = .{ 0.0, 2.5, 3.141592, 1e10 };
const sinks: [4]f32 = .{ -0.0, 2.5, 3.141592, 1e10 };
const nans: [3]f32 = .{ 100.0, std.math.nan(f32), -100.0 };
try testing.expectEqual(null, indexOfDiff(f32, &floats, &sinks)); // 0.0 == -0.0
try testing.expectEqual(1, indexOfDiff(f32, &nans, &nans)); // NaN != NaN
}

/// Takes a sentinel-terminated pointer and returns a slice preserving pointer attributes.
Expand Down
1 change: 1 addition & 0 deletions lib/std/meta.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1194,6 +1194,7 @@ pub inline fn hasUniqueRepresentation(comptime T: type) bool {
.@"enum",
.error_set,
.@"fn",
.void,
=> true,

.bool => false,
Expand Down
Loading