Skip to content

Commit

Permalink
json: add ParseFlags values UseInt64, UseUint64, UseBigInt
Browse files Browse the repository at this point in the history
  • Loading branch information
extemporalgenome committed Dec 1, 2023
1 parent 1999556 commit cdd4c01
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 9 deletions.
5 changes: 5 additions & 0 deletions json/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding"
"encoding/json"
"fmt"
"math/big"
"reflect"
"sort"
"strconv"
Expand Down Expand Up @@ -838,6 +839,7 @@ func constructInlineValueEncodeFunc(encode encodeFunc) encodeFunc {
// compiles down to zero instructions.
// USE CAREFULLY!
// This was copied from the runtime; see issues 23382 and 7921.
//
//go:nosplit
func noescape(p unsafe.Pointer) unsafe.Pointer {
x := uintptr(p)
Expand Down Expand Up @@ -1078,6 +1080,7 @@ var (
float32Type = reflect.TypeOf(float32(0))
float64Type = reflect.TypeOf(float64(0))

bigIntType = reflect.TypeOf(new(big.Int))
numberType = reflect.TypeOf(json.Number(""))
stringType = reflect.TypeOf("")
stringsType = reflect.TypeOf([]string(nil))
Expand All @@ -1104,6 +1107,8 @@ var (
jsonUnmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()

bigIntDecoder = constructJSONUnmarshalerDecodeFunc(bigIntType, false)
)

// =============================================================================
Expand Down
77 changes: 68 additions & 9 deletions json/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"math"
"math/big"
"reflect"
"strconv"
"time"
Expand All @@ -16,6 +17,10 @@ import (
"github.com/segmentio/encoding/iso8601"
)

// anyFlagsSet reports whether at least one of the bits in flags is also set
// in the decoder's own flags.
func (d decoder) anyFlagsSet(flags ParseFlags) bool {
	shared := d.flags & flags
	return shared != 0
}

func (d decoder) decodeNull(b []byte, p unsafe.Pointer) ([]byte, error) {
if hasNullPrefix(b) {
return b[4:], nil
Expand Down Expand Up @@ -1306,15 +1311,7 @@ func (d decoder) decodeInterface(b []byte, p unsafe.Pointer) ([]byte, error) {
v, val = nil, k == True

case Num:
if (d.flags & UseNumber) != 0 {
n := Number("")
v, err = d.decodeNumber(v, unsafe.Pointer(&n))
val = n
} else {
f := 0.0
v, err = d.decodeFloat64(v, unsafe.Pointer(&f))
val = f
}
v, err = d.decodeDynamicNumber(v, unsafe.Pointer(&val))

default:
return b, syntaxError(v, "expected token but found '%c'", v[0])
Expand All @@ -1332,6 +1329,68 @@ func (d decoder) decodeInterface(b []byte, p unsafe.Pointer) ([]byte, error) {
return b, nil
}

// decodeDynamicNumber decodes the JSON number at the start of b into the
// interface value addressed by p (p is used as a *any), picking the dynamic
// Go type from the decoder flags. It returns the remaining input and any
// decode error.
//
// The candidates are tried in this order:
//
//  1. uint64, for literals classified as Uint when UseUint64 is set;
//  2. int64, for literals classified as Uint or Int when UseInt64 is set
//     (and the uint64 attempt did not apply);
//  3. *big.Int, for literals classified as Uint or Int when UseBigInt is set;
//  4. Number, for any literal when UseNumber is set;
//  5. float64 otherwise.
//
// A failed uint64 or int64 attempt (for example on overflow) is not reported;
// decoding simply continues with steps 3 to 5.
func (d decoder) decodeDynamicNumber(b []byte, p unsafe.Pointer) ([]byte, error) {
	// Treating the literal as Float by default routes it straight to the
	// Number/float64 handling below. The literal is only classified up front
	// when one of the integer-specific flags makes the answer matter.
	kind := Float

	if d.anyFlagsSet(UseBigInt | UseInt64 | UseUint64) {
		var err error
		if _, _, kind, err = d.parseNumber(b); err != nil {
			return b, err
		}
	}

	dst := (*any)(p)
	isInteger := kind == Uint || kind == Int

	// Fixed-width attempts. At most one of them runs, and a failure falls
	// through to the wider representations instead of being returned.
	if kind == Uint && d.anyFlagsSet(UseUint64) {
		if rest, err := decodeInto[uint64](dst, b, d, decoder.decodeUint64); err == nil {
			return rest, nil
		}
	} else if isInteger && d.anyFlagsSet(UseInt64) {
		if rest, err := decodeInto[int64](dst, b, d, decoder.decodeInt64); err == nil {
			return rest, nil
		}
	}

	// Fallbacks, also reached when a fixed-width attempt above failed.
	if isInteger && d.anyFlagsSet(UseBigInt) {
		return decodeInto[*big.Int](dst, b, d, bigIntDecoder)
	}

	if d.anyFlagsSet(UseNumber) {
		return decodeInto[Number](dst, b, d, decoder.decodeNumber)
	}

	return decodeInto[float64](dst, b, d, decoder.decodeFloat64)
}

func (d decoder) decodeMaybeEmptyInterface(b []byte, p unsafe.Pointer, t reflect.Type) ([]byte, error) {
if hasNullPrefix(b) {
*(*interface{})(p) = nil
Expand Down
13 changes: 13 additions & 0 deletions json/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,19 @@ const (
// mode.
DontMatchCaseInsensitiveStructFields

// UseBigInt is a parsing flag used to decode integers into *big.Int.
// It takes precedence over UseNumber for integers.
UseBigInt

// UseInt64 is a parsing flag used to decode in-range integers into int64.
// It takes precedence over UseNumber and UseBigInt for in-range integers.
UseInt64

// UseUint64 is a parsing flag used to decode in-range non-negative
// integers into uint64.
// It takes precedence over UseNumber, UseBigInt, and UseInt64
// for non-negative, in-range integers.
UseUint64

// ZeroCopy is a parsing flag that combines all the copy optimizations
// available in the package.
//
Expand Down
139 changes: 139 additions & 0 deletions json/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"io"
"math"
"math/big"
"os"
"path/filepath"
"reflect"
Expand Down Expand Up @@ -85,6 +86,13 @@ type tree struct {
Right *tree
}

// bigPos128 is 1<<128 and bigNeg128 is -(1<<128). Both magnitudes need 129
// bits, so neither value can be represented by a uint64 or an int64.
var (
	bigPos128 = new(big.Int).Lsh(big.NewInt(1), 128)
	bigNeg128 = new(big.Int).Lsh(big.NewInt(-1), 128)
)

var testValues = [...]interface{}{
// constants
nil,
Expand Down Expand Up @@ -126,6 +134,9 @@ var testValues = [...]interface{}{
float64(math.SmallestNonzeroFloat64),
float64(math.MaxFloat64),

bigPos128,
bigNeg128,

// number
Number("0"),
Number("1234567890"),
Expand Down Expand Up @@ -484,6 +495,134 @@ func TestCodecDuration(t *testing.T) {
}
}

// numericParseTests lists JSON number literals together with the ParseFlags
// used to parse them into an interface value and the exact Go value (dynamic
// type included) that the parse is expected to produce.
//
// Case names follow the pattern <input>_flags_<flag list>, where the flag
// list mirrors the flags expression of the case.
var numericParseTests = [...]struct {
	name  string
	input string
	flags ParseFlags
	want  any
}{
	{
		name:  "zero_flags_default",
		input: `0`,
		flags: 0,
		want:  float64(0),
	},
	{
		name:  "zero_flags_int_uint_bigint_number",
		input: `0`,
		flags: UseInt64 | UseUint64 | UseBigInt | UseNumber,
		want:  uint64(0),
	},
	{
		name:  "zero_flags_int_bigint_number",
		input: `0`,
		flags: UseInt64 | UseBigInt | UseNumber,
		want:  int64(0),
	},
	{
		name:  "zero_flags_bigint_number",
		input: `0`,
		flags: UseBigInt | UseNumber,
		want:  big.NewInt(0),
	},
	{
		name:  "zero_flags_number",
		input: `0`,
		flags: UseNumber,
		want:  json.Number(`0`),
	},
	{
		name:  "max_uint64_flags_default",
		input: fmt.Sprint(uint64(math.MaxUint64)),
		flags: 0,
		want:  float64(math.MaxUint64),
	},
	{
		name:  "max_uint64_flags_int_uint_bigint_number",
		input: fmt.Sprint(uint64(math.MaxUint64)),
		flags: UseInt64 | UseUint64 | UseBigInt | UseNumber,
		want:  uint64(math.MaxUint64),
	},
	{
		// UseUint64 is set on purpose: a negative literal must skip the
		// uint64 attempt and still be decoded as int64.
		name:  "min_int64_flags_uint_int_bigint_number",
		input: fmt.Sprint(int64(math.MinInt64)),
		flags: UseUint64 | UseInt64 | UseBigInt | UseNumber,
		want:  int64(math.MinInt64),
	},
	{
		name:  "max_uint64_flags_int_bigint_number",
		input: fmt.Sprint(uint64(math.MaxUint64)),
		flags: UseInt64 | UseBigInt | UseNumber,
		want:  new(big.Int).SetUint64(math.MaxUint64),
	},
	{
		name:  "overflow_uint64_flags_uint_int_bigint_number",
		input: bigPos128.String(),
		flags: UseUint64 | UseInt64 | UseBigInt | UseNumber,
		want:  bigPos128,
	},
	{
		name:  "underflow_uint64_flags_uint_int_bigint_number",
		input: bigNeg128.String(),
		flags: UseUint64 | UseInt64 | UseBigInt | UseNumber,
		want:  bigNeg128,
	},
	{
		name:  "overflow_uint64_flags_uint_int_number",
		input: bigPos128.String(),
		flags: UseUint64 | UseInt64 | UseNumber,
		want:  json.Number(bigPos128.String()),
	},
	{
		name:  "underflow_uint64_flags_uint_int_number",
		input: bigNeg128.String(),
		flags: UseUint64 | UseInt64 | UseNumber,
		want:  json.Number(bigNeg128.String()),
	},
	{
		name:  "overflow_uint64_flags_uint_int",
		input: bigPos128.String(),
		flags: UseUint64 | UseInt64,
		want:  float64(1 << 128),
	},
	{
		name:  "underflow_uint64_flags_uint_int",
		input: bigNeg128.String(),
		flags: UseUint64 | UseInt64,
		want:  float64(-1 << 128),
	},
}

// TestParse_numeric parses every input in numericParseTests into an empty
// interface value using the flags of the case, and checks that the parse
// reports no error, leaves no unconsumed input, and yields the expected
// value with the expected dynamic type.
func TestParse_numeric(t *testing.T) {
	t.Parallel()

	for _, test := range numericParseTests {
		// Per-iteration copy for the parallel subtest closure below.
		test := test

		t.Run(test.name, func(t *testing.T) {
			t.Parallel()

			var got any

			rem, err := Parse([]byte(test.input), &got, test.flags)
			if err != nil {
				format := "Parse(%#q, ..., %#b) = %q [error], want nil"
				t.Errorf(format, test.input, test.flags, err)
			}

			if len(rem) != 0 {
				format := "Parse(%#q, ..., %#b) = %#q, want zero length"
				t.Errorf(format, test.input, test.flags, rem)
			}

			// reflect.DeepEqual compares the internal representation of a
			// big.Int (for example a nil versus an empty word slice) rather
			// than its numeric value, so big integers are compared with Cmp.
			// The type assertion on got still enforces the dynamic type.
			equal := reflect.DeepEqual(got, test.want)
			if wantInt, ok := test.want.(*big.Int); ok && wantInt != nil {
				gotInt, isBigInt := got.(*big.Int)
				equal = isBigInt && gotInt != nil && gotInt.Cmp(wantInt) == 0
			}

			if !equal {
				format := "Parse(%#q, %#b) -> %T(%#[3]v), want %T(%#[4]v)"
				t.Errorf(format, test.input, test.flags, got, test.want)
			}
		})
	}
}

func newValue(model interface{}) reflect.Value {
if model == nil {
return reflect.New(reflect.TypeOf(&model).Elem())
Expand Down

0 comments on commit cdd4c01

Please sign in to comment.