diff --git a/src/pgzx/datum.zig b/src/pgzx/datum.zig index 81177d4..6b8c2b2 100644 --- a/src/pgzx/datum.zig +++ b/src/pgzx/datum.zig @@ -7,55 +7,116 @@ const mem = @import("mem.zig"); const meta = @import("meta.zig"); const varatt = @import("varatt.zig"); -pub fn Conv(comptime T: type, comptime from: anytype, comptime to: anytype) type { +pub fn fromNullableDatum(comptime T: type, d: pg.NullableDatum) !T { + return findConv(T).fromNullableDatum(d); +} + +pub fn fromNullableDatumWithOID(comptime T: type, d: pg.NullableDatum, oid: ?pg.Oid) !T { + return findConv(T).fromNullableDatumWithOID(d, oid); +} + +pub fn fromDatum(comptime T: type, d: pg.Datum, is_null: bool) !T { + return findConv(T).fromNullableDatum(.{ .value = d, .isnull = is_null }); +} + +pub fn fromDatumWithOID(comptime T: type, d: pg.Datum, is_null: bool, oid: ?pg.Oid) !T { + return findConv(T).fromNullableDatumWithOID(.{ .value = d, .isnull = is_null }, oid); +} + +pub fn toNullableDatum(v: anytype) !pg.NullableDatum { + return findConv(@TypeOf(v)).toNullableDatum(v); +} + +pub fn toNullableDatumWithOID(v: anytype, oid: ?pg.Oid) !pg.NullableDatum { + return findConv(@TypeOf(v)).toNullableDatumWithOID(v, oid); +} + +// pub fn Conv(comptime T: type, comptime from: anytype, comptime to: anytype) type { +pub fn Conv(comptime context: type) type { return struct { - pub const Type = T; + pub const Type = context.Type; + + const Self = @This(); + pub fn fromNullableDatum(d: pg.NullableDatum) !Type { + return Self.fromNullableDatumWithOID(d, null); + } + + pub fn fromNullableDatumWithOID(d: pg.NullableDatum, oid: ?pg.Oid) !Type { if (d.isnull) { return err.PGError.UnexpectedNullValue; } - return try from(d.value); + return try context.from(d.value, normalizeOid(oid)); } + pub fn toNullableDatum(v: Type) !pg.NullableDatum { + return Self.toNullableDatumWithOID(v, null); + } + + pub fn toNullableDatumWithOID(v: Type, oid: ?pg.Oid) !pg.NullableDatum { return .{ - .value = try to(v), + .value = try context.to(v, normalizeOid(oid)), .isnull = false, }; } }; } -pub fn ConvNoFail(comptime T: type, comptime from: anytype, comptime to: anytype) type { - return struct { +pub fn ConvNoFail(comptime context: type) type { + return Conv(struct { + pub const Type = context.Type; + + pub fn from(d: pg.Datum, oid: pg.Oid) !Type { + return context.from(d, oid); + } + + pub fn to(v: Type, oid: pg.Oid) !pg.Datum { + return context.to(v, oid); + } + }); +} + +pub fn SimpleConv(comptime T: type, comptime from_datum: anytype, comptime to_datum: anytype) type { + return ConvNoFail(struct { pub const Type = T; - pub fn fromNullableDatum(d: pg.NullableDatum) !T { - if (d.isnull) { - return err.PGError.UnexpectedNullValue; - } - return from(d.value); + + pub fn from(d: pg.Datum, oid: pg.Oid) !Type { + _ = oid; + return from_datum(d); } - pub fn toNullableDatum(v: T) !pg.NullableDatum { - return .{ - .value = to(v), - .isnull = false, - }; + + pub fn to(v: Type, oid: pg.Oid) !pg.Datum { + _ = oid; + return to_datum(v); } - }; + }); } /// Conversion decorator for optional types. pub fn OptConv(comptime C: anytype) type { return struct { pub const Type = ?C.Type; + + const Self = @This(); + pub fn fromNullableDatum(d: pg.NullableDatum) !Type { + return try Self.fromNullableDatumWithOID(d, null); + } + + pub fn fromNullableDatumWithOID(d: pg.NullableDatum, oid: ?pg.Oid) !Type { if (d.isnull) { return null; } - return try C.fromNullableDatum(d); + return try C.fromNullableDatumWithOID(d, oid); } + pub fn toNullableDatum(v: Type) !pg.NullableDatum { + return Self.toNullableDatumWithOID(v, null); + } + + pub fn toNullableDatumWithOID(v: Type, oid: ?pg.Oid) !pg.NullableDatum { if (v) |value| { - return try C.toNullableDatum(value); + return try C.toNullableDatumWithOID(value, oid); } else { return .{ .value = 0, @@ -71,21 +132,8 @@ pub fn OptConv(comptime C: anytype) type { /// reflection only. var directMappings = .{ .{ pg.Datum, PGDatum }, - .{ pg.NullableDatum, PGNullableDatum }, }; -pub fn fromNullableDatum(comptime T: type, d: pg.NullableDatum) !T { - return findConv(T).fromNullableDatum(d); -} - -pub fn fromDatum(comptime T: type, d: pg.Datum, is_null: bool) !T { - return findConv(T).fromNullableDatum(.{ .value = d, .isnull = is_null }); -} - -pub fn toNullableDatum(v: anytype) !pg.NullableDatum { - return findConv(@TypeOf(v)).toNullableDatum(v); -} - pub fn findConv(comptime T: type) type { if (isConv(T)) { // is T already a converter? return T; @@ -155,32 +203,35 @@ inline fn isConv(comptime T: type) bool { return @hasDecl(T, "Type") and @hasDecl(T, "fromNullableDatum") and @hasDecl(T, "toNullableDatum"); } -pub const Void = ConvNoFail(void, idDatum, toVoid); -pub const Bool = ConvNoFail(bool, pg.DatumGetBool, pg.BoolGetDatum); -pub const Int8 = ConvNoFail(i8, datumGetInt8, pg.Int8GetDatum); -pub const Int16 = ConvNoFail(i16, pg.DatumGetInt16, pg.Int16GetDatum); -pub const Int32 = ConvNoFail(i32, pg.DatumGetInt32, pg.Int32GetDatum); -pub const Int64 = ConvNoFail(i64, pg.DatumGetInt64, pg.Int64GetDatum); -pub const UInt8 = ConvNoFail(u8, pg.DatumGetUInt8, pg.UInt8GetDatum); -pub const UInt16 = ConvNoFail(u16, pg.DatumGetUInt16, pg.UInt16GetDatum); -pub const UInt32 = ConvNoFail(u32, pg.DatumGetUInt32, pg.UInt32GetDatum); -pub const UInt64 = ConvNoFail(u64, pg.DatumGetUInt64, pg.UInt64GetDatum); -pub const Float32 = ConvNoFail(f32, pg.DatumGetFloat4, pg.Float4GetDatum); -pub const Float64 = ConvNoFail(f64, pg.DatumGetFloat8, pg.Float8GetDatum); - -pub const SliceU8 = Conv([]const u8, getDatumTextSlice, sliceToDatumText); -pub const SliceU8Z = Conv([:0]const u8, getDatumTextSliceZ, sliceToDatumTextZ); - -pub const PGDatum = ConvNoFail(pg.Datum, idDatum, idDatum); -const PGNullableDatum = struct { - pub const Type = pg.NullableDatum; - pub fn fromNullableDatum(d: pg.NullableDatum) !Type { - return d; - } - pub fn toNullableDatum(v: Type) !pg.NullableDatum { - return v; - } -}; +inline fn normalizeOid(oid: ?pg.Oid) pg.Oid { + return oid orelse pg.InvalidOid; +} + +pub const Void = SimpleConv(void, idDatum, toVoid); +pub const Bool = SimpleConv(bool, pg.DatumGetBool, pg.BoolGetDatum); +pub const Int8 = SimpleConv(i8, datumGetInt8, pg.Int8GetDatum); +pub const Int16 = SimpleConv(i16, pg.DatumGetInt16, pg.Int16GetDatum); +pub const Int32 = SimpleConv(i32, pg.DatumGetInt32, pg.Int32GetDatum); +pub const Int64 = SimpleConv(i64, pg.DatumGetInt64, pg.Int64GetDatum); +pub const UInt8 = SimpleConv(u8, pg.DatumGetUInt8, pg.UInt8GetDatum); +pub const UInt16 = SimpleConv(u16, pg.DatumGetUInt16, pg.UInt16GetDatum); +pub const UInt32 = SimpleConv(u32, pg.DatumGetUInt32, pg.UInt32GetDatum); +pub const UInt64 = SimpleConv(u64, pg.DatumGetUInt64, pg.UInt64GetDatum); +pub const Float32 = SimpleConv(f32, pg.DatumGetFloat4, pg.Float4GetDatum); +pub const Float64 = SimpleConv(f64, pg.DatumGetFloat8, pg.Float8GetDatum); +pub const PGDatum = SimpleConv(pg.Datum, idDatum, idDatum); + +pub const SliceU8Z = Conv(struct { + pub const Type = [:0]const u8; + pub const from = getDatumStringLikeZ; + pub const to = sliceToDatumStringLikeZ; +}); + +pub const SliceU8 = Conv(struct { + pub const Type = []const u8; + pub const from = getDatumStringLikeZ; + pub const to = sliceToDatumStringLike; +}); // TODO: conversion decorator for array types @@ -199,10 +250,26 @@ fn datumGetInt8(d: pg.Datum) i8 { return @as(i8, @bitCast(@as(i8, @truncate(d)))); } +pub fn getDatumStringLike(datum: pg.Datum, oid: pg.Oid) ![]const u8 { + return getDatumStringLikeZ(datum, oid); +} + /// Convert a datum to a TEXT slice. This function detoast the datum if necessary. /// All allocations will be performed in the Current Memory Context. -pub fn getDatumTextSlice(datum: pg.Datum) ![]const u8 { - return getDatumTextSliceZ(datum); +pub fn getDatumTextSlice(datum: pg.Datum, oid: pg.Oid) ![]const u8 { + return getDatumTextSliceZ(datum, oid); +} + +pub inline fn getDatumCString(datum: pg.Datum) ![]const u8 { + return getDatumCStringZ(datum); +} + +pub fn getDatumStringLikeZ(datum: pg.Datum, oid: pg.Oid) ![:0]const u8 { + return if (useStringPointer(oid)) getDatumCStringZ(datum) else getDatumTextSliceZ(datum); +} + +pub inline fn getDatumCStringZ(datum: pg.Datum) ![:0]const u8 { + return std.mem.span(pg.DatumGetCString(datum)); } /// Convert a datum to a TEXT slice. This function detoast the datum if necessary. @@ -222,6 +289,24 @@ pub fn getDatumTextSliceZ(datum: pg.Datum) ![:0]const u8 { return buffer[0..len :0]; } +pub fn sliceToDatumStringLikeZ(slice: [:0]const u8, oid: pg.Oid) !pg.Datum { + return if (useStringPointer(oid)) sliceToDatumCStringZ(slice) else sliceToDatumTextZ(slice); +} + +pub fn sliceToDatumStringLike(slice: []const u8, oid: pg.Oid) !pg.Datum { + return if (useStringPointer(oid)) sliceToDatumCString(slice) else sliceToDatumText(slice); +} + +pub inline fn sliceToDatumCString(slice: []const u8) !pg.Datum { + const alloc = mem.PGCurrentContextAllocator; + const slice_z = try alloc.dupeZ(u8, slice); + return pg.CStringGetDatum(slice_z.ptr); +} + +pub inline fn sliceToDatumCStringZ(slice: [:0]const u8) !pg.Datum { + return pg.CStringGetDatum(slice.ptr); +} + pub inline fn sliceToDatumText(slice: []const u8) !pg.Datum { const text = pg.cstring_to_text_with_len(slice.ptr, @intCast(slice.len)); return pg.PointerGetDatum(text); @@ -230,3 +315,10 @@ pub inline fn sliceToDatumText(slice: []const u8) !pg.Datum { pub inline fn sliceToDatumTextZ(slice: [:0]const u8) !pg.Datum { return sliceToDatumText(slice); } + +pub inline fn useStringPointer(oid: pg.Oid) bool { + return switch (oid) { + pg.CHAROID, pg.NAMEOID, pg.CSTRINGOID => true, + else => false, + }; +} diff --git a/src/pgzx/fmgr/args.zig b/src/pgzx/fmgr/args.zig index db12b13..66ca96b 100644 --- a/src/pgzx/fmgr/args.zig +++ b/src/pgzx/fmgr/args.zig @@ -1,5 +1,7 @@ const std = @import("std"); const pg = @import("pgzx_pgsys"); + +const err = @import("../err.zig"); const datum = @import("../datum.zig"); /// Index function argument type. @@ -58,7 +60,9 @@ fn readArg(comptime T: type, fcinfo: pg.FunctionCallInfo, argNum: u32) !readArgT return fcinfo; } const converter = comptime datum.findConv(T); - return converter.fromNullableDatum(try mustGetArgNullable(fcinfo, argNum)); + const oid = try err.wrap(pg.get_fn_expr_argtype, .{ fcinfo.*.flinfo, @as(c_int, @intCast(argNum)) }); + const ndatum = try mustGetArgNullable(fcinfo, argNum); + return converter.fromNullableDatumWithOID(ndatum, oid); } fn readOptionalArg(comptime T: type, fcinfo: pg.FunctionCallInfo, argNum: u32) !?T { diff --git a/src/pgzx/spi.zig b/src/pgzx/spi.zig index 4dbead4..932045d 100644 --- a/src/pgzx/spi.zig +++ b/src/pgzx/spi.zig @@ -1,6 +1,7 @@ const std = @import("std"); const pg = @import("pgzx_pgsys"); +const meta = @import("meta.zig"); const mem = @import("mem.zig"); const err = @import("err.zig"); const datum = @import("datum.zig"); @@ -43,7 +44,24 @@ pub const ExecOptions = struct { pub const SPIError = err.PGError || std.mem.Allocator.Error; -pub fn exec(sql: [:0]const u8, options: ExecOptions) SPIError!c_int { +pub fn exec(sql: [:0]const u8, options: ExecOptions) SPIError!isize { + const ret = try execImpl(sql, options); + var rows = Rows.init(); + defer rows.deinit(); + return @intCast(ret); +} + +pub fn query(sql: [:0]const u8, options: ExecOptions) SPIError!Rows { + _ = try execImpl(sql, options); + return Rows.init(); +} + +pub fn queryTyped(comptime T: type, sql: [:0]const u8, options: ExecOptions) SPIError!RowsOf(T) { + const rows = try query(sql, options); + return rows.typed(T); +} + +fn execImpl(sql: [:0]const u8, options: ExecOptions) SPIError!c_int { if (options.args) |args| { if (args.types.len != args.values.len) { return err.PGError.SPIArgument; @@ -92,83 +110,128 @@ pub fn exec(sql: [:0]const u8, options: ExecOptions) SPIError!c_int { } } -pub fn query(sql: [:0]const u8, options: ExecOptions) SPIError!Rows { - _ = try exec(sql, options); - return Rows.init(); -} - -pub fn queryTyped(comptime T: type, sql: [:0]const u8, options: ExecOptions) SPIError!RowsOf(T) { - const rows = try query(sql, options); - return rows.typed(T); +fn scanProcessed(row: usize, values: anytype) !void { + scanProcessedFrame(SPIFrame.get(), row, values); } -pub fn scanProcessed(row: usize, values: anytype) !void { - if (pg.SPI_processed <= row) { - return err.PGError.SPIInvalidRowIndex; - } - +inline fn scanProcessedFrame(frame: SPIFrame, row: usize, values: anytype) !void { var column: c_int = 1; inline for (std.meta.fields(@TypeOf(values)), 0..) |field, i| { - column = try scanField(field.type, values[i], row, column); + column = try scanField(field.type, frame, values[i], row, column); } } -fn scanField(comptime fieldType: type, to: anytype, row: usize, column: c_int) !c_int { - const fieldInfo = @typeInfo(fieldType); - if (fieldInfo != .Pointer) { +fn scanField( + comptime fieldType: type, + frame: SPIFrame, + to: anytype, + row: usize, + column: c_int, +) !c_int { + const field_info = @typeInfo(fieldType); + if (field_info != .Pointer) { @compileError("scanField requires a pointer"); } - if (fieldInfo.Pointer.size == .Slice) { + if (field_info.Pointer.size == .Slice) { @compileError("scanField requires a single pointer, not a slice"); } - const childType = fieldInfo.Pointer.child; - if (@typeInfo(childType) == .Struct) { - var structColumn = column; - inline for (std.meta.fields(childType)) |field| { - const childPtr = &@field(to.*, field.name); - structColumn = try scanField(@TypeOf(childPtr), childPtr, row, structColumn); + const child_type = field_info.Pointer.child; + if (@typeInfo(child_type) == .Struct) { + var struct_column = column; + inline for (std.meta.fields(child_type)) |field| { + const child_ptr = &@field(to.*, field.name); + struct_column = try scanField(@TypeOf(child_ptr), frame, child_ptr, row, struct_column); } - return structColumn; + return struct_column; } else { - const value = try convBinValue(childType, pg.SPI_tuptable, row, column); + const value = try convBinValue(child_type, frame, row, column); to.* = value; return column + 1; } } -pub const Rows = struct { - row: isize = -1, +pub fn OwnedSPIFrameRows(comptime R: type) type { + return struct { + rows: R, - const Self = @This(); + const Self = @This(); + + pub inline fn init(r: R) Self { + return .{ .rows = r }; + } + + pub inline fn deinit(self: *Self) void { + self.rows.deinit(); + finish(); + } + + pub fn next(self: *Self) meta.fnReturnType(@TypeOf(R.next)) { + return self.rows.next(); + } + + pub const scan = if (@hasField(R, "scan")) + R.scan + else + @compileError("no scan method available"); + }; +} + +// Rows iterates over SPI_tuptable from the last executed SPI query. +// When initializing a Rows iterator we capture the current SPI_tuptable from +// the active SPI frame. +// +// Safety: +// ======= +// +// The underlying tuple table is released when the current frame is released +// via `finish`. The iterator must not be used after. We have no way to check +// if the current frame was released or not. Accessing the tuple table after a +// release will result in undefined behavior. +// +// Due to SPI managing a stack of SPI frames it is safe to use `connect` to +// create a child frame to run queries while iterating over the rows. +// +pub const Rows = struct { + row: isize, + spi_frame: SPIFrame, - fn init() Self { - return .{}; + fn init() Rows { + return .{ + .row = -1, + .spi_frame = SPIFrame.get(), + }; } - fn typed(self: Self, comptime T: type) RowsOf(T) { + fn typed(self: Rows, comptime T: type) RowsOf(T) { return RowsOf(T).init(self); } - pub fn deinit(self: *Self) void { - pg.SPI_freetuptable(pg.SPI_tuptable); + fn ownedSPIFrame(self: Rows) OwnedSPIFrameRows(Rows) { + return OwnedSPIFrameRows(Rows).init(self); + } + + pub fn deinit(self: *Rows) void { + if (self.spi_frame.tuptable) |tt| { + pg.SPI_freetuptable(tt); + } self.row = -1; } - pub fn next(self: *Self) bool { + pub fn next(self: *Rows) bool { const next_idx = self.row + 1; - if (next_idx >= pg.SPI_processed) { + if (self.spi_frame.tuptable == null or next_idx >= self.spi_frame.processed) { return false; } self.row = next_idx; return true; } - pub fn scan(self: *Self, values: anytype) !void { + pub fn scan(self: *Rows, values: anytype) !void { if (self.row < 0) { return err.PGError.SPIInvalidRowIndex; } - try scanProcessed(@intCast(self.row), values); + try scanProcessedFrame(self.spi_frame, @intCast(self.row), values); } }; @@ -177,6 +240,7 @@ pub fn RowsOf(comptime T: type) type { rows: Rows, const Self = @This(); + pub const Owned = OwnedSPIFrameRows(Self); pub fn init(rows: Rows) Self { return .{ .rows = rows }; @@ -186,6 +250,10 @@ pub fn RowsOf(comptime T: type) type { self.rows.deinit(); } + pub fn ownedSPIFrame(self: Self) Self.Owned { + return OwnedSPIFrameRows(Self).init(self); + } + pub fn next(self: *Self) !?T { if (!self.rows.next()) { return null; @@ -197,20 +265,37 @@ pub fn RowsOf(comptime T: type) type { }; } +// The SPI interface uses a +const SPIFrame = struct { + processed: u64, + tuptable: ?*pg.SPITupleTable, + + inline fn get() SPIFrame { + return .{ + .processed = pg.SPI_processed, + .tuptable = pg.SPI_tuptable, + }; + } +}; + pub fn convProcessed(comptime T: type, row: c_int, col: c_int) !T { if (pg.SPI_processed <= row) { return err.PGError.SPIInvalidRowIndex; } - return convBinValue(T, pg.SPI_tuptable, row, col); + return convBinValue(T, SPIFrame.get(), row, col); } -pub fn convBinValue(comptime T: type, table: *pg.SPITupleTable, row: usize, col: c_int) !T { +pub fn convBinValue(comptime T: type, frame: SPIFrame, row: usize, col: c_int) !T { // TODO: check index? var nd: pg.NullableDatum = undefined; - nd.value = pg.SPI_getbinval(table.*.vals[row], table.*.tupdesc, col, @ptrCast(&nd.isnull)); + const table = frame.tuptable.?; + const desc = table.*.tupdesc; + nd.value = pg.SPI_getbinval(table.*.vals[row], desc, col, @ptrCast(&nd.isnull)); try checkStatus(pg.SPI_result); - return try datum.fromNullableDatum(T, nd); + const attr_desc = &desc.*.attrs()[@intCast(col - 1)]; + const oid = attr_desc.atttypid; + return try datum.fromNullableDatumWithOID(T, nd, oid); } fn checkStatus(st: c_int) err.PGError!void {