RESP protocol deserializer

I find myself, yet again, doing a lot of serialization / de-serialization.

Specifically, I am working on the RESP protocol right now, which is the wire protocol used by Valkey, Redis, etc.

(I see @ralph and @kristoff are working on this too)

The strategy I am using this time around is the “recursive data structure” style.

const std = @import("std");
const assert = std.debug.assert;

pub const DataType = enum(u8) {
    simple_string = '+',
    simple_error = '-',
    integer = ':',
    bulk_string = '$',
    array = '*',
    null = '_',
    bool = '#',
    double = ',',
    big_number = '(',
    bulk_error = '!',
    verbatim_string = '=',
    map = '%',
    set = '~',
    push = '>',
};

pub const RESPType = union(DataType) {
    simple_string: []const u8,
    simple_error: []const u8,
    integer: i64,
    bulk_string: []const u8,
    array: []const RESPType,
    null: void,
    bool: bool,
    double: f64,
    big_number: []const u8, // TODO: use i128 or something?
    bulk_error: []const u8,
    verbatim_string: struct {
        encoding: [3]u8,
        data: []const u8,
    },
    map: []const MapItem,
    set: []const RESPType,
    push: []const RESPType,

    pub const MapItem = struct {
        key: RESPType,
        value: RESPType,
    };
};

This union can represent everything in the protocol. Notice that it contains pointers to itself (it is a recursive data structure).
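For example, the wire reply *2\r\n$5\r\nhello\r\n:42\r\n (a two-element array holding a bulk string and an integer) corresponds to a value like this, written out by hand just to show the shape:

const example: RESPType = .{ .array = &.{
    .{ .bulk_string = "hello" },
    .{ .integer = 42 },
} };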

And I can make a decoder for it:

/// This function doesn't free. The caller is responsible for using
/// an arena.
pub fn decodeRecursive(allocator: std.mem.Allocator, reader: anytype, max_size: usize) error{ OutOfMemory, Invalid, EndOfStream, StreamTooLong, InvalidCharacter, Overflow }!RESPType {
    const byte = try reader.readByte();
    const data_type = std.meta.intToEnum(DataType, byte) catch return error.Invalid;

    switch (data_type) {
        .simple_string => {
            const slice = try reader.readUntilDelimiterAlloc(allocator, '\r', max_size);
            try reader.skipBytes(1, .{});
            return RESPType{ .simple_string = slice };
        },
        .simple_error => {
            const slice = try reader.readUntilDelimiterAlloc(allocator, '\r', max_size);
            try reader.skipBytes(1, .{});
            return RESPType{ .simple_error = slice };
        },
        .integer => {
            var buf: [100]u8 = undefined;
            const slice = try reader.readUntilDelimiter(&buf, '\r');
            const int = try std.fmt.parseInt(i64, slice, 10);
            try reader.skipBytes(1, .{});
            return RESPType{ .integer = int };
        },
        .bulk_string => {
            const length = try decodeElementCount(reader, i64);
            // this is stupid
            if (length == -1) {
                return RESPType{ .null = {} };
            } else if (length < -1) return error.Invalid;

            if (length > max_size) return error.StreamTooLong;
            assert(length <= std.math.maxInt(usize));
            const string = try allocator.alloc(u8, @intCast(length));
            try reader.readNoEof(string);
            try reader.skipBytes(2, .{});
            return RESPType{ .bulk_string = string };
        },
        .array => {
            const length = try decodeElementCount(reader, i64);
            if (length == -1) {
                return RESPType{ .null = {} };
            } else if (length < -1) return error.Invalid;

            if (length > max_size) return error.StreamTooLong;
            assert(length <= std.math.maxInt(usize));
            const array = try allocator.alloc(RESPType, @intCast(length));
            for (array) |*element| {
                element.* = try decodeRecursive(allocator, reader, max_size);
            }
            return RESPType{ .array = array };
        },
        .null => {
            try reader.skipBytes(2, .{});
            return RESPType{ .null = {} };
        },
        .bool => {
            const value: bool = switch (try reader.readByte()) {
                't' => true,
                'f' => false,
                else => return error.Invalid,
            };
            try reader.skipBytes(2, .{});
            return RESPType{ .bool = value };
        },
        .double => {
            var buf: [100]u8 = undefined;
            const slice = try reader.readUntilDelimiter(&buf, '\r');
            const double = try std.fmt.parseFloat(f64, slice);
            try reader.skipBytes(1, .{});
            return RESPType{ .double = double };
        },
        .big_number => {
            const slice = try reader.readUntilDelimiterAlloc(allocator, '\r', max_size);
            try reader.skipBytes(1, .{});
            return RESPType{ .big_number = slice };
        },
        .bulk_error => {
            const length = try decodeElementCount(reader, i64);
            // this is stupid
            if (length == -1) {
                return RESPType{ .null = {} };
            } else if (length < -1) return error.Invalid;

            if (length > max_size) return error.StreamTooLong;
            assert(length <= std.math.maxInt(usize));
            const string = try allocator.alloc(u8, @intCast(length));
            try reader.readNoEof(string);
            try reader.skipBytes(2, .{});
            return RESPType{ .bulk_error = string };
        },
        .verbatim_string => {
            const length = try decodeElementCount(reader, i64);
            // this is stupid
            if (length == -1) {
                return RESPType{ .null = {} };
            } else if (length < -1) return error.Invalid;

            if (length > max_size) return error.StreamTooLong;
            assert(length <= std.math.maxInt(usize));
            const string = try allocator.alloc(u8, @intCast(length));
            try reader.readNoEof(string);
            try reader.skipBytes(2, .{});
            if (length < 4) {
                return RESPType{ .bulk_string = string };
            } else {
                var encoding: [3]u8 = undefined;
                @memcpy(&encoding, string[0..3]);
                return RESPType{ .verbatim_string = .{ .data = string[4..], .encoding = encoding } };
            }
        },
        .map => {
            const length = try decodeElementCount(reader, u64);
            if (length > max_size) return error.StreamTooLong;
            comptime assert(@TypeOf(max_size) == usize);
            assert(length <= std.math.maxInt(usize));
            const map = try allocator.alloc(RESPType.MapItem, @intCast(length));
            for (map) |*kv| {
                kv.key = try decodeRecursive(allocator, reader, max_size);
                kv.value = try decodeRecursive(allocator, reader, max_size);
            }
            return RESPType{ .map = map };
        },
        .set => {
            const length = try decodeElementCount(reader, i64);
            if (length == -1) {
                return RESPType{ .null = {} };
            } else if (length < -1) return error.Invalid;

            if (length > max_size) return error.StreamTooLong;
            assert(length <= std.math.maxInt(usize));
            const set = try allocator.alloc(RESPType, @intCast(length));
            for (set) |*element| {
                element.* = try decodeRecursive(allocator, reader, max_size);
            }
            return RESPType{ .set = set };
        },
        .push => {
            const length = try decodeElementCount(reader, i64);
            if (length == -1) {
                return RESPType{ .null = {} };
            } else if (length < -1) return error.Invalid;

            if (length > max_size) return error.StreamTooLong;
            assert(length <= std.math.maxInt(usize));
            const push = try allocator.alloc(RESPType, @intCast(length));
            for (push) |*element| {
                element.* = try decodeRecursive(allocator, reader, max_size);
            }
            return RESPType{ .push = push };
        },
    }
}

fn decodeElementCount(reader: anytype, int_type: type) !int_type {
    var buf: [100]u8 = undefined;
    const slice = try reader.readUntilDelimiter(&buf, '\r');
    const int = try std.fmt.parseInt(int_type, slice, 10);
    try reader.skipBytes(1, .{});
    return int;
}
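
For reference, the happy path in a quick test (a sketch; it decodes a simple string out of a fixed buffer stream):

test "decode a simple string" {
    var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
    defer arena.deinit();

    var fbs = std.io.fixedBufferStream("+OK\r\n");
    const value = try decodeRecursive(arena.allocator(), fbs.reader(), 1024);
    try std.testing.expectEqualStrings("OK", value.simple_string);
}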

There are problems with this:

  1. It allocates like crazy. (somewhat unavoidable?)
  2. It accepts a reader, which is actually annoying because the return type changes with every kind of reader you give it. (A TCP stream has different read errors than a fixed buffer stream, for example.)

I can wrap this up in a bow and reduce the errors like this:

/// Call deinit() on this to free it.
pub fn Decoded(comptime T: type) type {
    return struct {
        arena: *std.heap.ArenaAllocator,
        value: T,
        pub fn deinit(self: @This()) void {
            const allocator = self.arena.child_allocator;
            self.arena.deinit();
            allocator.destroy(self.arena);
        }
    };
}

pub fn decodeAlloc(allocator: std.mem.Allocator, reader: anytype, max_size: usize) !Decoded(RESPType) {
    const arena = try allocator.create(std.heap.ArenaAllocator);
    errdefer allocator.destroy(arena);
    arena.* = .init(allocator);
    errdefer arena.deinit();
    const res = decodeRecursive(arena.allocator(), reader, max_size) catch |err| switch (err) {
        error.OutOfMemory => return error.OutOfMemory,
        error.Invalid, error.EndOfStream, error.StreamTooLong, error.InvalidCharacter, error.Overflow => return error.Invalid,
    };
    return Decoded(RESPType){ .arena = arena, .value = res };
}
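
Usage then looks something like this (a sketch, with gpa standing in for whatever general-purpose allocator you have):

var fbs = std.io.fixedBufferStream("*1\r\n$4\r\nPING\r\n");
const decoded = try decodeAlloc(gpa, fbs.reader(), 1024);
defer decoded.deinit();
// decoded.value is the root RESPType; every slice it points at lives in the arena.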

But this has the following problems:

  1. It requires passing a slice to the decoder, so you need to know the length of your Redis message up front. That is impossible, because RESP is a streaming protocol and I don’t get to know the length of a message before I read it.

Everything also has the additional problem:

  1. When an error occurs, like OutOfMemory, I don’t really want to “corrupt” the position of the stream. I need to continue to “consume” the rest of the message. Otherwise, I will lose my place and have to open a new TCP connection to the database.

Anyone have resources or better ideas to deal with these problems?

RESP is very similar to JSON in the sense that both have similar data types and both can be arbitrarily deeply nested. Your arena approach is good and also what the zig stdlib json parser does by default.

You can use std.io.AnyReader to make it non-generic. You can get an AnyReader by calling .any() on a reader.
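
Something like this (a rough sketch; readTypeByte is just a made-up helper to show the pattern). Note that AnyReader’s read error is anyerror, so the explicit error set on decodeRecursive would have to be loosened or the reader errors handled with a catch:

fn readTypeByte(reader: std.io.AnyReader) !DataType {
    const byte = try reader.readByte();
    return std.meta.intToEnum(DataType, byte) catch return error.Invalid;
}

test "AnyReader erases the concrete reader type" {
    var fbs = std.io.fixedBufferStream("+OK\r\n");
    // .any() turns the generic fixed-buffer reader into a std.io.AnyReader.
    const data_type = try readTypeByte(fbs.reader().any());
    try std.testing.expectEqual(DataType.simple_string, data_type);
}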

I haven’t touched my parser in years, so the code will be horribly outdated, but in my case I have a so-called “void parser” that basically consumes a well-formed piece of RESP data and throws it away.

In other parts of the parser, when an error is detected, I try to consume any remaining bytes in order to get to a good “start point” and then feed the remaining stuff to the void parser. I don’t remember how religiously I did this in the original code, but I remember at least starting that work and finding it doable.


I think I can get around this whole void parser complexity by still requiring the user to pass a complete slice, but also providing a readUntilEoResp method, which is like a void parser but just returns a slice of the bytes containing a single RESP value.

It has the disadvantage of reading the RESP replies twice (first to find the end of the RESP value and again to actually decode it), but I think that’s fine for now.

I think if the server returns invalid RESP there’s no reasonable action for a user to take other than to just close the TCP connection and open a new one. But at least with readUntilEoResp I will be able to separate recoverable errors from non-recoverable ones, I hope.

I came up with this stream-from-reader-to-writer thing that will allow me to consume one full RESP value before passing it to my decoder, so I don’t have to corrupt the stream after I have the RESP value:

// The CRLF that terminates every RESP token (used as "separator" below).
const separator = "\r\n";

/// Stream data from reader to writer for one RESP value.
pub fn streamUntilEoResp(reader: anytype, writer: anytype) !void {
    const byte = try reader.readByte();
    const data_type = std.meta.intToEnum(DataType, byte) catch return error.InvalidRESP;
    try writer.writeByte(byte);

    return switch (data_type) {
        .simple_string, .simple_error, .integer, .double, .big_number => {
            try reader.streamUntilDelimiter(writer, '\r', null);
            try writer.writeAll(separator);
            try reader.skipBytes(1, .{});
        },
        .bulk_string, .bulk_error, .verbatim_string => {
            var buf: [100]u8 = undefined;
            const slice = try reader.readUntilDelimiter(&buf, '\r');
            try writer.writeAll(slice);
            try reader.skipBytes(1, .{});
            try writer.writeAll(separator);
            const length = try std.fmt.parseInt(i64, slice, 10);

            // this is stupid
            if (length == -1) {
                return;
            } else if (length < -1) return error.Invalid;

            if (length > std.math.maxInt(usize)) return error.StreamTooLong;
            assert(length <= std.math.maxInt(usize));
            var limited = std.io.limitedReader(reader, @intCast(length));
            const limited_reader = limited.reader();

            var fifo = std.fifo.LinearFifo(u8, .{ .Static = 128 }).init();
            try fifo.pump(limited_reader, writer);
            try reader.skipBytes(2, .{});
            try writer.writeAll(separator);
        },
        .array, .map, .set, .push => |tag| {
            var buf: [100]u8 = undefined;
            const slice = try reader.readUntilDelimiter(&buf, '\r');
            try writer.writeAll(slice);
            try reader.skipBytes(1, .{});
            try writer.writeAll(separator);
            const length = try std.fmt.parseInt(i64, slice, 10);

            // this is stupid
            if (length == -1) {
                return;
            } else if (length < -1) return error.Invalid;

            if (length > std.math.maxInt(usize)) return error.StreamTooLong;
            assert(length <= std.math.maxInt(usize));
            for (0..@intCast(length)) |_| {
                switch (tag) {
                    .array, .set, .push => try streamUntilEoResp(reader, writer),
                    .map => {
                        try streamUntilEoResp(reader, writer);
                        try streamUntilEoResp(reader, writer);
                    },
                    else => unreachable,
                }
            }
        },
        .null => {
            try reader.skipBytes(2, .{});
            try writer.writeAll(separator);
        },
        .bool => {
            try writer.writeByte(try reader.readByte());
            try reader.skipBytes(2, .{});
            try writer.writeAll(separator);
        },
    };
}
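
With std.io.null_writer as the writer this doubles as the “void parser” from above (consume and discard one well-formed value to get back to a clean start point), and with a buffering writer it captures the bytes of exactly one value, which can then be handed to the decoder. A rough sketch of the capture-then-decode flow (the buffer size and the helper names are placeholders):

/// Sketch: read exactly one RESP value from reader into a buffer, then decode it.
fn readOneValue(gpa: std.mem.Allocator, reader: anytype) !Decoded(RESPType) {
    var buf: [64 * 1024]u8 = undefined;
    var fbs = std.io.fixedBufferStream(&buf);
    try streamUntilEoResp(reader, fbs.writer());

    var captured = std.io.fixedBufferStream(fbs.getWritten());
    return decodeAlloc(gpa, captured.reader(), buf.len);
}

/// Sketch: consume and throw away one well-formed RESP value (the "void parser").
fn skipOneValue(reader: anytype) !void {
    try streamUntilEoResp(reader, std.io.null_writer);
}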