std.mem.indexOf with end-user's types

Hi,

I have a recurring need when using Zig, and I feel like a simple issue + pull request could solve it.
However, I am not confident enough to do it because:

  • I am pretty new to Zig (I have only been working with Zig for 6 months).
  • The Zig GitHub repository seems to have a lot of traffic in terms of contributions.

For these reasons, I do not want to bother the maintainers with my proposal if something already exists to solve it, or if it is more work than I imagine.

The problem is pretty simple:

I need to find the index of an element in an array. I already know that the std.mem.indexOf() collection of functions can achieve this job. However, it only works for primitive types, and what I need is to use it with end-user’s types. I also know that I can achieve something like this:

fn index_of (comptime T: type, slice: []const T, value: T) ?usize
{
  for (slice, 0 ..) |element, index| { if (std.meta.eql (value, element)) return index; }
  else return null;
}

And that is what I am currently doing. But since the std library already achieves this for primitive types, it could also do it for end-user’s types.
Looking into the std lib (and from my limited understanding), the only thing to change is the comparison operator used in each of the std.mem.indexOf functions:

for (slice[i..], i..) |c, j| {
  // Instead of this:
  if (c == value) return j;
  // use this:
  // if (std.meta.eql (c, value)) return j;
}

Maybe I am oversimplifying, and that is why I am posting here: please let me know if I am missing something.

Thank you.

I don’t understand why you dislike your index_of function, can you clarify why you don’t just use that function?

Also if it is difficult to express, maybe you could show some uses of the function that illustrate what you would prefer to be different.

One thing I am wondering is why do you need to find it, would it be possible to restructure the code so that it is more declarative and you already just know, without having to search first? But the answer to that depends highly on how you are using that function.


Hmm, after looking at std.mem.indexOf I see that it does something quite different from your index_of function: the former finds a subslice in a bigger slice, the latter finds an element in a slice.

I think the question boils down to: why does std.mem.indexOfScalar() use c == value instead of std.meta.eql(c, value) for the comparison, when the latter is strictly more powerful?
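To make that concrete, here is a minimal sketch (Point and indexOfEql are hypothetical names, not part of the std library): indexOfScalar works for u8 because `==` is defined for it, but the equivalent call on a struct type would not compile, while a std.meta.eql-based loop handles both.

```zig
const std = @import("std");

const Point = struct { x: u8, y: u8 };

fn indexOfEql(comptime T: type, slice: []const T, value: T) ?usize {
    for (slice, 0..) |element, index| {
        if (std.meta.eql(value, element)) return index;
    }
    return null;
}

pub fn main() void {
    // `==` exists for u8, so the std function works:
    const bytes = [_]u8{ 10, 20, 30 };
    std.debug.print("{?}\n", .{std.mem.indexOfScalar(u8, &bytes, 20)});

    const points = [_]Point{ .{ .x = 1, .y = 2 }, .{ .x = 3, .y = 4 } };
    // `==` is not defined for structs, so this line would be a compile error:
    // _ = std.mem.indexOfScalar(Point, &points, .{ .x = 3, .y = 4 });

    // std.meta.eql does a field-by-field comparison instead:
    std.debug.print("{?}\n", .{indexOfEql(Point, &points, .{ .x = 3, .y = 4 })});
}
```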

2 Likes

Hi @Sze, thank you for your answer.

I don’t understand why you dislike your index_of function, can you clarify why you don’t just use that function?

I like my function and I am using it. I think you are missing the point of what I am trying to show. Let me illustrate with a new function:

fn is_digit (char: u8) bool
{
  return switch (char)
  {
    '0' ... '9' => true,
    else => false,
  };
}

This function is awesome and I can use it in my code. However, a function to achieve this already exists: std.ascii.isDigit(). So even if my function can do the job, it is better to use the std library’s function. The std library handles common usages and hardens them over time.

To come back to my index_of function, it is the same thing: I have this “utility” function in most of my Zig projects because the std.mem.indexOf functions do not handle this use case. Maybe I am not alone, so maybe its place is not in my projects but in the std library. And maybe it could be handled by the std.mem.indexOf set of functions.

Also if it is difficult to express, maybe you could show some uses of the function that illustrate what you would prefer to be different.

With what I already explained, this is what I have now:

// utils.zig

const std = @import ("std");

pub fn index_of (comptime T: type, slice: []const T, value: T) ?usize
{
  for (slice, 0 ..) |element, index| { if (std.meta.eql (value, element)) return index; }
  else return null;
}

// main.zig

const std = @import ("std");
const utils = @import ("utils.zig");

const MyStruct = struct
{
  x: u8 = undefined,
  y: u8 = undefined,
};

pub fn main () void
{
  const array = [_] MyStruct
  {
    .{ .x = 5, .y = 2 },
    .{ .x = 3, .y = 7 },
    .{ .x = 6, .y = 1 },
    .{ .x = 8, .y = 4 },
  };
  std.debug.print ("{}\n{?}\n", .{ utils.index_of (MyStruct, &array, .{ .x = 6, .y = 1 }).?,
                                   utils.index_of (MyStruct, &array, .{ .x = 6, .y = 10 }), });
}

and what I want instead:

// main.zig

const std = @import ("std");

const MyStruct = struct
{
  x: u8 = undefined,
  y: u8 = undefined,
};

pub fn main () void
{
  const array = [_] MyStruct
  {
    .{ .x = 5, .y = 2 },
    .{ .x = 3, .y = 7 },
    .{ .x = 6, .y = 1 },
    .{ .x = 8, .y = 4 },
  };
  std.debug.print ("{}\n{?}\n", .{ std.mem.indexOfScalar (MyStruct, &array, .{ .x = 6, .y = 1 }).?,
                                   std.mem.lastIndexOfAny (MyStruct, &array, &[_] MyStruct
                                                                             {
                                                                               .{ .x = 6, .y = 10 },
                                                                               .{ .x = 6, .y = 12 },
                                                                             }), });
}

One thing I am wondering is why do you need to find it, would it be possible to restructure the code so that it is more declarative and you already just know, without having to search first?

Yeah, for sure I can. I could also use a std.AutoHashMap and its contains() and get() methods to solve all of this. But:

  1. it brings new problems (now I have to manage an allocator, because I am using a heap-allocated object instead of an array)
  2. it does not answer the real question: why can I use the std.mem.indexOf() functions with primitive types but not with end-user types?
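For reference, the hash-map version of the earlier example might look roughly like this (a sketch: std.AutoHashMap can auto-hash a plain struct like MyStruct because it contains no pointers, but note the allocator that now has to be managed):

```zig
const std = @import("std");

const MyStruct = struct { x: u8, y: u8 };

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();

    // Map each element to its index up front...
    var indices = std.AutoHashMap(MyStruct, usize).init(gpa.allocator());
    defer indices.deinit();
    try indices.put(.{ .x = 5, .y = 2 }, 0);
    try indices.put(.{ .x = 6, .y = 1 }, 2);

    // ...then lookups are O(1), but the allocator came along for the ride.
    std.debug.print("{?}\n", .{indices.get(.{ .x = 6, .y = 1 })}); // 2
    std.debug.print("{?}\n", .{indices.get(.{ .x = 6, .y = 10 })}); // null
}
```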

Hmm after looking at std.mem.indexOf I see that it does something quite different than your index_of function

Yes, you are right: std.mem.indexOf() cannot be adapted to this use case. But it could be done for all of these functions:

  • std.mem.indexOfAny
  • std.mem.indexOfDiff
  • std.mem.indexOfNone
  • std.mem.indexOfScalar
  • their Pos and last counterparts.

Hi @maksverver, thank you for your answer: even though I am not talking only about std.mem.indexOfScalar, that is exactly what I was trying to say.

2 Likes

I’m not seeing any objections to your idea - it feels like what you’re suggesting could be a good improvement. I can definitely see people leaning into the standard library more if it worked with custom types more fluently.

That said, it would need more work than I think we’re realizing. You’ll notice that indexOfScalar dispatches to indexOfScalarPos.

Here’s that code:

pub fn indexOfScalarPos(comptime T: type, slice: []const T, start_index: usize, value: T) ?usize {
    if (start_index >= slice.len) return null;

    var i: usize = start_index;
    if (backend_supports_vectors and
        !@inComptime() and
        (@typeInfo(T) == .Int or @typeInfo(T) == .Float) and std.math.isPowerOfTwo(@bitSizeOf(T)))
    {
        if (std.simd.suggestVectorSize(T)) |block_len| {
            // For Intel Nehalem (2009) and AMD Bulldozer (2012) or later, unaligned loads on aligned data result
            // in the same execution as aligned loads. We ignore older arch's here and don't bother pre-aligning.
            //
            // Use `std.simd.suggestVectorSize(T)` to get the same alignment as used in this function
            // however this usually isn't necessary unless your arch has a performance penalty due to this.
            //
            // This may differ for other arch's. Arm for example costs a cycle when loading across a cache
            // line so explicit alignment prologues may be worth exploration.

            // Unrolling here is ~10% improvement. We can then do one bounds check every 2 blocks
            // instead of one which adds up.
            const Block = @Vector(block_len, T);
            if (i + 2 * block_len < slice.len) {
                const mask: Block = @splat(value);
                while (true) {
                    inline for (0..2) |_| {
                        const block: Block = slice[i..][0..block_len].*;
                        const matches = block == mask;
                        if (@reduce(.Or, matches)) {
                            return i + std.simd.firstTrue(matches).?;
                        }
                        i += block_len;
                    }
                    if (i + 2 * block_len >= slice.len) break;
                }
            }

            // {block_len, block_len / 2} check
            inline for (0..2) |j| {
                const block_x_len = block_len / (1 << j);
                comptime if (block_x_len < 4) break;

                const BlockX = @Vector(block_x_len, T);
                if (i + block_x_len < slice.len) {
                    const mask: BlockX = @splat(value);
                    const block: BlockX = slice[i..][0..block_x_len].*;
                    const matches = block == mask;
                    if (@reduce(.Or, matches)) {
                        return i + std.simd.firstTrue(matches).?;
                    }
                    i += block_x_len;
                }
            }
        }
    }

    for (slice[i..], i..) |c, j| {
        if (c == value) return j;
    }
    return null;
}

So it’s attempting to use SIMD optimization which will only work with the types supported by @Vector. There would need to be an additional comptime check or dispatch.

At the very end, we see the following…

    for (slice[i..], i..) |c, j| {
        if (c == value) return j;
    }

That’s where std.meta.eql would come into play. Here’s the issue - c is being captured by value. That may end up creating many copies in the loop. If it were replaced by a pointer capture instead, std.meta.eql doesn’t follow pointers, so it would compare addresses rather than values. It would probably need to be replaced by an indexed loop instead of a capture.
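A sketch of what that fallback loop could become (indexOfScalarPosEql is a hypothetical name; whether the copies actually disappear would still need checking in the assembly, since std.meta.eql itself takes its arguments by value):

```zig
const std = @import("std");

// Hypothetical variant of the fallback loop: index into the slice
// directly instead of capturing each element by value, so large
// structs are not (visibly) copied per iteration.
fn indexOfScalarPosEql(comptime T: type, slice: []const T, start_index: usize, value: T) ?usize {
    if (start_index >= slice.len) return null;
    var i: usize = start_index;
    while (i < slice.len) : (i += 1) {
        if (std.meta.eql(slice[i], value)) return i;
    }
    return null;
}
```

The indexed while loop sidesteps the by-value capture entirely, at the cost of reading a bit less cleanly than the for capture.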

Same thing would need to change here for indexOfAny which dispatches to indexOfAnyPos:

pub fn indexOfAnyPos(comptime T: type, slice: []const T, start_index: usize, values: []const T) ?usize {
    if (start_index >= slice.len) return null;
    for (slice[start_index..], start_index..) |c, i| {
        for (values) |value| {
            if (c == value) return i;
        }
    }
    return null;
}

The issue here (which @matklad pointed out some time ago) is that if you’re capturing potentially large structs in a loop, it can make a lot of copies unintentionally. You’d have to convert it to an indexed loop and check the assembly (and it’s possible that std.meta.eql may get referential argument optimizations at higher optimization levels).

One addendum here too - you need to deal with things like slice comparisons. If I have two structs that contain strings, std.meta.eql will not compare their values as it’s currently implemented. Here’s the block that deals with that:

        .Struct => |info| {
            inline for (info.fields) |field_info| {
                if (!eql(@field(a, field_info.name), @field(b, field_info.name))) return false;
            }
            return true;
        },

It’s now calling eql on the struct fields - that recurses. The next time through, for []const u8, it will go to the pointer branch that then dispatches to slice… here’s that code:

        .Pointer => |info| {
            return switch (info.size) {
                .One, .Many, .C => a == b,
                .Slice => a.ptr == b.ptr and a.len == b.len,
            };
        },

You can see here that it’s comparing the slice’s ptr and len fields. It’s not doing a value comparison like you might expect: it’s a shallow comparison, not a deep one. This can easily be misunderstood and cause people to compare different things than they think.
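A quick demonstration of that pitfall (Named is a hypothetical type): two structs holding equal bytes through different slices compare as unequal under std.meta.eql, even though comparing the bytes of the field directly says they match.

```zig
const std = @import("std");

const Named = struct { name: []const u8 };

pub fn main() void {
    const a: Named = .{ .name = "zig" };

    // Same bytes, but at a different address:
    var buf: [3]u8 = undefined;
    @memcpy(&buf, "zig");
    const b: Named = .{ .name = &buf };

    // Shallow: compares the slice's ptr and len, not its contents.
    std.debug.print("{}\n", .{std.meta.eql(a, b)}); // false
    // Byte-wise comparison of the field itself:
    std.debug.print("{}\n", .{std.mem.eql(u8, a.name, b.name)}); // true
}
```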

This idea needs a lot of consideration to make sure we want what we’re actually asking for.

2 Likes

So here’s something that’s interesting… std.meta.eql compares arrays differently - it attempts to do a deeper comparison…

        .Array => {
            if (a.len != b.len) return false;
            for (a, 0..) |e, i|
                if (!eql(e, b[i])) return false;
            return true;
        },

I think splitting std.meta.eql into std.meta.equalDeep and std.meta.equalShallow would be my first recommendation here. Deep would follow pointers and slices whereas shallow would not. Until that’s sorted out, there’s a lot of caveats that could trip up the uninitiated lol.
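A rough sketch of what the deep half could look like, handling only the slice and struct branches and falling back to today's std.meta.eql for everything else (eqlDeep is a hypothetical name, and a real implementation would also need to guard against cycles in self-referential types):

```zig
const std = @import("std");

// Hypothetical sketch: only the slice and struct branches differ from
// std.meta.eql; everything else (ints, floats, arrays, ...) falls back
// to it, so e.g. arrays of slices would still need their own handling.
fn eqlDeep(comptime T: type, a: T, b: T) bool {
    switch (@typeInfo(T)) {
        .Pointer => |info| switch (info.size) {
            .Slice => {
                if (a.len != b.len) return false;
                for (a, b) |elem_a, elem_b| {
                    if (!eqlDeep(@TypeOf(elem_a), elem_a, elem_b)) return false;
                }
                return true;
            },
            else => return a == b,
        },
        .Struct => |info| {
            inline for (info.fields) |field| {
                if (!eqlDeep(field.type, @field(a, field.name), @field(b, field.name))) return false;
            }
            return true;
        },
        else => return std.meta.eql(a, b),
    }
}
```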

3 Likes

But after that’s sorted out, there’s still a lot of caveats that can trip up the uninitiated. For example, if equalDeep doesn’t handle it, any self-referential struct will cause an infinite loop (e.g. std.DoublyLinkedList(T).Node).

3 Likes

I was thinking there was still a sneaky edge-case here, but (as always @squeek502) that’s a great example.

The other issue is that the API doesn’t support this idea as it currently stands - you’d have to either make two separate implementations of search/equality functions or always provide a way for the user to parameterize that option. It’d be a massive breaking change.

Anyhow, to the issue at hand - equality isn’t a trivial thing to implement, so there’s a lot to consider.

3 Likes

Hi @AndrewCodeDev & @squeek502, thank you for your answers.

Because you mentioned this several times:

there’s a lot of caveats that could trip up the uninitiated

Your answers made me understand that this issue involves (a lot) more work than I was expecting. I do not regret opening this topic here before posting an issue on the Zig GitHub repository.

The logical next step of this discussion is to open an issue on the Zig GitHub repository. I will link it to this brainstorming; it should be a helpful resource for the maintainers.

1 Like

I think it’s also worth considering that trying to use indexOf with std.meta.eql should make you step back and think about what you’re actually trying to do. Linear scans are worth avoiding generally if something better can be done, and it seems like that’d be especially true when the equality of two items in the list is non-trivial to compute.

(as an example, see GeneralPurposeAllocator: Considerably improve worst case performance by squeek502 · Pull Request #17383 · ziglang/zig · GitHub where going from a O(n) linear scan to a O(log n) search via std.Treap took some real-world worst case performance from ~8 minutes to ~20 seconds)

6 Likes

I personally thought it was a great thread - at first I was like “yeah, why don’t we use that :thinking:…” because everything seemed to fit together nicely. It took some digging but I think we found some valuable stuff :slight_smile:

5 Likes