diff --git a/src/euneus_encoder.erl b/src/euneus_encoder.erl index b996b11..7da24bf 100644 --- a/src/euneus_encoder.erl +++ b/src/euneus_encoder.erl @@ -7,6 +7,7 @@ -export([encode/2]). -export([continue/2]). +-export([codec_callback/2]). -export([key_to_binary/1]). -export([escape/1]). -export([encode_integer/2]). @@ -23,6 +24,7 @@ % -ignore_xref([continue/2]). +-ignore_xref([codec_callback/2]). -ignore_xref([key_to_binary/1]). -ignore_xref([escape/1]). -ignore_xref([encode_integer/2]). @@ -48,6 +50,8 @@ %% -------------------------------------------------------------------- -export_type([options/0]). +-export_type([codec_fun/0]). +-export_type([codec_result/0]). -export_type([codec_callback/0]). -export_type([is_proplist/0]). -export_type([encode/1]). @@ -75,6 +79,7 @@ -type options() :: #{ codecs => [codec()], + codec_callback => codec_callback(), nulls => [term()], skip_values => [term()], key_to_binary => fun((term()) -> binary()), @@ -99,10 +104,17 @@ | ipv4 | ipv6 | {records, #{Name :: atom() := {Fields :: [atom()], Size :: pos_integer()}}} - | codec_callback(). + | codec_fun() + | custom_codec(). -export_type([codec/0]). --type codec_callback() :: fun((tuple()) -> next | {halt, term()}). +-type codec_fun() :: fun((tuple()) -> codec_result()). + +-type codec_result() :: next | {halt, term()}. + +-type custom_codec() :: term(). + +-type codec_callback() :: fun((codec(), tuple()) -> codec_result()). -type is_proplist() :: fun((list()) -> boolean()). @@ -110,6 +122,7 @@ -record(state, { codecs :: [codec()], + codec_callback :: codec_callback(), nulls :: #{term() := null}, skip_values :: #{term() := skip}, key_to_binary :: fun((term()) -> binary()), @@ -264,6 +277,11 @@ %% ''' %% %%
  • +%% `codec_callback' - Overrides the default codec resolver. +%% +%% Default is `codec_callback/2'. +%%
  • +%%
  • %% `nulls' - Defines which values should be encoded as null. %% %% Default is `[null]'. @@ -413,7 +431,7 @@ continue(List, State) when is_list(List) -> continue(Map, State) when is_map(Map) -> (State#state.encode_map)(Map, State); continue(Tuple, State) when is_tuple(Tuple) -> - case traverse_codecs(State#state.codecs, Tuple) of + case traverse_codecs(State#state.codecs, State#state.codec_callback, Tuple) of NewTuple when is_tuple(NewTuple) -> (State#state.encode_tuple)(NewTuple, State); NewTerm -> @@ -428,6 +446,20 @@ continue(Ref, State) when is_reference(Ref) -> continue(Term, State) -> (State#state.encode_term)(Term, State). +-spec codec_callback(codec(), tuple()) -> codec_result(). +codec_callback(timestamp, Tuple) -> + timestamp_codec_callback(Tuple); +codec_callback(datetime, Tuple) -> + datetime_codec_callback(Tuple); +codec_callback(ipv4, Tuple) -> + ipv4_codec_callback(Tuple); +codec_callback(ipv6, Tuple) -> + ipv6_codec_callback(Tuple); +codec_callback({records, Records}, Tuple) -> + records_codec_callback(Tuple, Records); +codec_callback(CodecFun, Tuple) -> + CodecFun(Tuple). + -spec key_to_binary(Term) -> binary() when Term :: binary() | string() | atom() | integer(). key_to_binary(Bin) when is_binary(Bin) -> @@ -551,6 +583,7 @@ encode_term(Term, State) -> new_state(Opts) -> #state{ codecs = maps:get(codecs, Opts, []), + codec_callback = maps:get(codec_callback, Opts, fun codec_callback/2), nulls = maps:from_keys(maps:get(nulls, Opts, [null]), null), skip_values = maps:from_keys(maps:get(skip_values, Opts, [undefined]), skip), key_to_binary = maps:get(key_to_binary, Opts, fun key_to_binary/1), @@ -571,29 +604,16 @@ new_state(Opts) -> % Codecs -traverse_codecs([Codec | Codecs], Tuple) -> - case codec_callback(Codec, Tuple) of +traverse_codecs([Codec | Codecs], CodecCallback, Tuple) -> + case CodecCallback(Codec, Tuple) of next -> - traverse_codecs(Codecs, Tuple); + traverse_codecs(Codecs, CodecCallback, Tuple); {halt, NewTerm} -> NewTerm end; -traverse_codecs([], Tuple) -> +traverse_codecs([], _CodecCallback, Tuple) -> Tuple. -codec_callback(timestamp, Tuple) -> - timestamp_codec_callback(Tuple); -codec_callback(datetime, Tuple) -> - datetime_codec_callback(Tuple); -codec_callback(ipv4, Tuple) -> - ipv4_codec_callback(Tuple); -codec_callback(ipv6, Tuple) -> - ipv6_codec_callback(Tuple); -codec_callback({records, Records}, Tuple) -> - records_codec_callback(Tuple, Records); -codec_callback(Callback, Tuple) -> - Callback(Tuple). - timestamp_codec_callback({MegaSecs, Secs, MicroSecs} = Timestamp) when ?IS_MIN(MegaSecs, 0), ?IS_MIN(Secs, 0), ?IS_MIN(MicroSecs, 0) -> diff --git a/test/euneus_encoder_SUITE.erl b/test/euneus_encoder_SUITE.erl index 8d90709..db1765c 100644 --- a/test/euneus_encoder_SUITE.erl +++ b/test/euneus_encoder_SUITE.erl @@ -64,6 +64,17 @@ codecs_test(Config) when is_list(Config) -> ) ]. +codec_callback_test(Config) when is_list(Config) -> + ?assertEqual( + <<"[\"foo\"]">>, + encode({foo}, #{ + codecs => [tuple_to_list], + codec_callback => fun(tuple_to_list, Tuple) -> + {halt, erlang:tuple_to_list(Tuple)} + end + }) + ). + timestamp_codec_test(Config) when is_list(Config) -> ?assertEqual(<<"\"1970-01-01T00:00:00.000Z\"">>, encode({0, 0, 0}, #{codecs => [timestamp]})).