Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chacha20-Poly1305: use string instead of cstruct #203

Merged
merged 8 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions bench/speed.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ module Time = struct

end

let burn_period = 2.0
let burn_period = 3.0

let sizes = [16; 64; 256; 1024; 8192]
(* let sizes = [16] *)
Expand All @@ -38,6 +38,17 @@ let burn f n =
let time = Time.time ~n:iters f cs in
(iters, time, float (n * iters) /. time)

let burn_str f n =
let cs = Cstruct.to_string (Mirage_crypto_rng.generate n) in
let (t1, i1) =
let rec loop it =
let t = Time.time ~n:it f cs in
if t > 0.2 then (t, it) else loop (it * 10) in
loop 10 in
let iters = int_of_float (float i1 *. burn_period /. t1) in
let time = Time.time ~n:iters f cs in
(iters, time, float (n * iters) /. time)

let mb = 1024. *. 1024.

let throughput title f =
Expand All @@ -48,6 +59,14 @@ let throughput title f =
Printf.printf " % 5d: %04f MB/s (%d iters in %.03f s)\n%!"
size (bw /. mb) iters time

let throughput_str title f =
Printf.printf "\n* [%s]\n%!" title ;
sizes |> List.iter @@ fun size ->
Gc.full_major () ;
let (iters, time, bw) = burn_str f size in
Printf.printf " % 5d: %04f MB/s (%d iters in %.03f s)\n%!"
size (bw /. mb) iters time

let count_period = 10.

let count f n =
Expand Down Expand Up @@ -370,8 +389,8 @@ let benchmarks = [

bm "chacha20-poly1305" (fun name ->
let key = Mirage_crypto.Chacha20.of_secret (Mirage_crypto_rng.generate 32)
and nonce = Mirage_crypto_rng.generate 8 in
throughput name (Mirage_crypto.Chacha20.authenticate_encrypt ~key ~nonce)) ;
and nonce = Cstruct.to_string (Mirage_crypto_rng.generate 8) in
throughput_str name (Mirage_crypto.Chacha20.auth_enc_str ~key ~nonce)) ;

bm "aes-128-ecb" (fun name ->
let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 16) in
Expand Down
75 changes: 41 additions & 34 deletions src/chacha20.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,23 @@ open Uncommon

let block = 64

type key = Cstruct.t
type key = string

let of_secret a = a
let of_secret a = Cstruct.to_string a

let chacha20_block state idx key_stream =
Native.Chacha.round 10 state.Cstruct.buffer 0 key_stream.Cstruct.buffer idx
Native.Chacha.round 10 state key_stream idx

let init ctr ~key ~nonce =
let ctr_off = 48 in
let set_ctr32 b v = Cstruct.LE.set_uint32 b ctr_off v
and set_ctr64 b v = Cstruct.LE.set_uint64 b ctr_off v
let set_ctr32 b v = Bytes.set_int32_le b ctr_off v
and set_ctr64 b v = Bytes.set_int64_le b ctr_off v
in
let inc32 b = set_ctr32 b (Int32.add (Cstruct.LE.get_uint32 b ctr_off) 1l)
and inc64 b = set_ctr64 b (Int64.add (Cstruct.LE.get_uint64 b ctr_off) 1L)
let inc32 b = set_ctr32 b (Int32.add (Bytes.get_int32_le b ctr_off) 1l)
and inc64 b = set_ctr64 b (Int64.add (Bytes.get_int64_le b ctr_off) 1L)
in
let s, key, init_ctr, nonce_off, inc =
match Cstruct.length key, Cstruct.length nonce, Int64.shift_right ctr 32 = 0L with
match String.length key, String.length nonce, Int64.shift_right ctr 32 = 0L with
| 32, 12, true ->
let ctr = Int64.to_int32 ctr in
"expand 32-byte k", key, (fun b -> set_ctr32 b ctr), 52, inc32
Expand All @@ -29,81 +29,82 @@ let init ctr ~key ~nonce =
| 32, 8, _ ->
"expand 32-byte k", key, (fun b -> set_ctr64 b ctr), 56, inc64
| 16, 8, _ ->
let k = Cstruct.append key key in
let k = key ^ key in
"expand 16-byte k", k, (fun b -> set_ctr64 b ctr), 56, inc64
| _ -> invalid_arg "Valid parameters are nonce 12 bytes and key 32 bytes \
(counter 32 bit), or nonce 8 byte and key 16 or 32 \
bytes (counter 64 bit)."
in
let state = Cstruct.create block in
Cstruct.blit_from_string s 0 state 0 16 ;
Cstruct.blit key 0 state 16 32 ;
let state = Bytes.create block in
Bytes.blit_string s 0 state 0 16 ;
Bytes.blit_string key 0 state 16 32 ;
hannesm marked this conversation as resolved.
Show resolved Hide resolved
init_ctr state ;
Cstruct.blit nonce 0 state nonce_off (Cstruct.length nonce) ;
Bytes.blit_string nonce 0 state nonce_off (String.length nonce) ;
state, inc

let crypt ~key ~nonce ?(ctr = 0L) data =
let state, inc = init ctr ~key ~nonce in
let l = Cstruct.length data in
let l = String.length data in
let block_count = l // block in
let len = block * block_count in
let last_len =
let last = l mod block in
if last = 0 then block else last
in
let key_stream = Cstruct.create_unsafe len in
let key_stream = Bytes.create len in
let rec loop i = function
| 0 -> ()
| 1 ->
chacha20_block state i key_stream ;
Native.xor_into data.buffer (data.off + i) key_stream.buffer i last_len
Native.xor_into_bytes data i key_stream i last_len
| n ->
chacha20_block state i key_stream ;
Native.xor_into data.buffer (data.off + i) key_stream.buffer i block ;
Native.xor_into_bytes data i key_stream i block ;
inc state;
loop (i + block) (n - 1)
in
loop 0 block_count ;
Cstruct.sub key_stream 0 l
let res = Bytes.unsafe_to_string key_stream in
if l <> len then String.sub res 0 l else res
hannesm marked this conversation as resolved.
Show resolved Hide resolved

module P = Poly1305.It

let generate_poly1305_key ~key ~nonce =
crypt ~key ~nonce (Cstruct.create 32)
crypt ~key ~nonce (String.make 32 '\000')

let mac ~key ~adata ciphertext =
let pad16 b =
let len = Cstruct.length b mod 16 in
if len = 0 then Cstruct.empty else Cstruct.create (16 - len)
let len = String.length b mod 16 in
if len = 0 then "" else String.make (16 - len) '\000'
and len =
let data = Cstruct.create 16 in
Cstruct.LE.set_uint64 data 0 (Int64.of_int (Cstruct.length adata));
Cstruct.LE.set_uint64 data 8 (Int64.of_int (Cstruct.length ciphertext));
data
let data = Bytes.create 16 in
Bytes.set_int64_le data 0 (Int64.of_int (String.length adata));
Bytes.set_int64_le data 8 (Int64.of_int (String.length ciphertext));
Bytes.unsafe_to_string data
in
let ctx = P.empty ~key in
let ctx = P.feed ctx adata in
let ctx = P.feed ctx (pad16 adata) in
let ctx = P.feed ctx ciphertext in
let ctx = P.feed ctx (pad16 ciphertext) in
let ctx = P.feed ctx len in
P.get ctx
P.macl ~key [ adata ; pad16 adata ; ciphertext ; pad16 ciphertext ; len ]

let authenticate_encrypt_tag ~key ~nonce ?(adata = Cstruct.empty) data =
let adata = Cstruct.to_string adata in
let nonce = Cstruct.to_string nonce in
let data = Cstruct.to_string data in
let poly1305_key = generate_poly1305_key ~key ~nonce in
let ciphertext = crypt ~key ~nonce ~ctr:1L data in
let mac = mac ~key:poly1305_key ~adata ciphertext in
ciphertext, mac
Cstruct.of_string ciphertext, Cstruct.of_string mac

let authenticate_encrypt ~key ~nonce ?adata data =
let cdata, ctag = authenticate_encrypt_tag ~key ~nonce ?adata data in
Cstruct.append cdata ctag

let authenticate_decrypt_tag ~key ~nonce ?(adata = Cstruct.empty) ~tag data =
let adata = Cstruct.to_string adata in
let nonce = Cstruct.to_string nonce in
let data = Cstruct.to_string data in
let poly1305_key = generate_poly1305_key ~key ~nonce in
let ctag = mac ~key:poly1305_key ~adata data in
let plain = crypt ~key ~nonce ~ctr:1L data in
if Eqaf_cstruct.equal tag ctag then Some plain else None
if Eqaf_cstruct.equal tag (Cstruct.of_string ctag) then Some (Cstruct.of_string plain) else None

let authenticate_decrypt ~key ~nonce ?adata data =
if Cstruct.length data < P.mac_size then
Expand All @@ -112,4 +113,10 @@ let authenticate_decrypt ~key ~nonce ?adata data =
let cipher, tag = Cstruct.split data (Cstruct.length data - P.mac_size) in
authenticate_decrypt_tag ~key ~nonce ?adata ~tag cipher

let auth_enc_str ~key ~nonce ?(adata = "") data =
let poly1305_key = generate_poly1305_key ~key ~nonce in
let ciphertext = crypt ~key ~nonce ~ctr:1L data in
let mac = mac ~key:poly1305_key ~adata ciphertext in
ciphertext ^ mac

let tag_size = P.mac_size
20 changes: 13 additions & 7 deletions src/mirage_crypto.mli
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ end

(** The poly1305 message authentication code *)
module Poly1305 : sig
type mac = Cstruct.t
type mac = string

type 'a iter = ('a -> unit) -> unit

Expand All @@ -214,27 +214,30 @@ module Poly1305 : sig
val mac_size : int
(** [mac_size] is the size of the output. *)

val empty : key:Cstruct.t -> t
val empty : key:string -> t
(** [empty] is the empty context with the given [key].

@raise Invalid_argument if key is not 32 bytes. *)

val feed : t -> Cstruct.t -> t
val feed : t -> string -> t
(** [feed t msg] adds the information in [msg] to [t]. *)

val feedi : t -> Cstruct.t iter -> t
val feedi : t -> string iter -> t
(** [feedi t iter] feeds iter into [t]. *)

val get : t -> mac
(** [get t] is the mac corresponding to [t]. *)

val mac : key:Cstruct.t -> Cstruct.t -> mac
val mac : key:string -> string -> mac
(** [mac ~key msg] is the all-in-one mac computation:
[get (feed (empty ~key) msg)]. *)

val maci : key:Cstruct.t -> Cstruct.t iter -> mac
val maci : key:string -> string iter -> mac
(** [maci ~key iter] is the all-in-one mac computation:
[get (feedi (empty ~key) iter)]. *)

val macl : key:string -> string list -> mac
(** [macl ~key datas] computes the [mac] of [datas]. *)
end

(** {1 Symmetric-key cryptography} *)
Expand Down Expand Up @@ -506,7 +509,7 @@ end
module Chacha20 : sig
include AEAD

val crypt : key:key -> nonce:Cstruct.t -> ?ctr:int64 -> Cstruct.t -> Cstruct.t
val crypt : key:key -> nonce:string -> ?ctr:int64 -> string -> string
(** [crypt ~key ~nonce ~ctr data] generates a ChaCha20 key stream using
the [key], and [nonce]. The [ctr] defaults to 0. The generated key
stream is of the same length as [data], and the output is the XOR
Expand All @@ -520,6 +523,9 @@ module Chacha20 : sig
IETF mode (and counter fit into 32 bits), or [key] must be either 16
bytes or 32 bytes and [nonce] 8 bytes.
*)

val auth_enc_str : key:key -> nonce:string -> ?adata:string ->
string -> string
end

(** Streaming ciphers. *)
Expand Down
10 changes: 6 additions & 4 deletions src/native.ml
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ module DES = struct
end

module Chacha = struct
external round : int -> buffer -> off -> buffer -> off -> unit = "mc_chacha_round" [@@noalloc]
external round : int -> bytes -> bytes -> off -> unit = "mc_chacha_round" [@@noalloc]
end

module Poly1305 = struct
external init : ctx -> buffer -> off -> unit = "mc_poly1305_init" [@@noalloc]
external update : ctx -> buffer -> off -> size -> unit = "mc_poly1305_update" [@@noalloc]
external finalize : ctx -> buffer -> off -> unit = "mc_poly1305_finalize" [@@noalloc]
external init : ctx -> string -> unit = "mc_poly1305_init" [@@noalloc]
external update : ctx -> string -> size -> unit = "mc_poly1305_update" [@@noalloc]
external finalize : ctx -> bytes -> unit = "mc_poly1305_finalize" [@@noalloc]
external ctx_size : unit -> int = "mc_poly1305_ctx_size" [@@noalloc]
external mac_size : unit -> int = "mc_poly1305_mac_size" [@@noalloc]
end
Expand Down Expand Up @@ -95,6 +95,8 @@ end
* Unsolved: bounds-checked XORs are slowing things down considerably... *)
external xor_into : buffer -> off -> buffer -> off -> size -> unit = "mc_xor_into" [@@noalloc]

external xor_into_bytes : string -> off -> bytes -> off -> size -> unit = "mc_xor_into_bytes" [@@noalloc]

external count8be : bytes -> buffer -> off -> blocks:size -> unit = "mc_count_8_be" [@@noalloc]
external count16be : bytes -> buffer -> off -> blocks:size -> unit = "mc_count_16_be" [@@noalloc]
external count16be4 : bytes -> buffer -> off -> blocks:size -> unit = "mc_count_16_be_4" [@@noalloc]
Expand Down
8 changes: 4 additions & 4 deletions src/native/bitfn.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,24 +99,24 @@ static inline uint64_t ror64(uint64_t word, uint32_t shift)
return (word >> shift) | (word << (64 - shift));
}

static inline void array_swap32(uint32_t *d, uint32_t *s, uint32_t nb)
static inline void array_swap32(uint32_t *d, const uint32_t *s, uint32_t nb)
{
while (nb--)
*d++ = bitfn_swap32(*s++);
}

static inline void array_swap64(uint64_t *d, uint64_t *s, uint32_t nb)
static inline void array_swap64(uint64_t *d, const uint64_t *s, uint32_t nb)
{
while (nb--)
*d++ = bitfn_swap64(*s++);
}

static inline void array_copy32(uint32_t *d, uint32_t *s, uint32_t nb)
static inline void array_copy32(uint32_t *d, const uint32_t *s, uint32_t nb)
{
while (nb--) *d++ = *s++;
}

static inline void array_copy64(uint64_t *d, uint64_t *s, uint32_t nb)
static inline void array_copy64(uint64_t *d, const uint64_t *s, uint32_t nb)
{
while (nb--) *d++ = *s++;
}
Expand Down
36 changes: 10 additions & 26 deletions src/native/chacha.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

#include "mirage_crypto.h"

extern void mc_chacha_core_generic(int count, uint8_t *src, uint8_t *dst);
extern void mc_chacha_core_generic(int count, const uint32_t *src, uint32_t *dst);

#ifdef __mc_ACCELERATE__

Expand All @@ -13,25 +13,9 @@ static inline void mc_chacha_quarterround(uint32_t *x, int a, int b, int c, int
x[c] += x[d]; x[b] = rol32(x[b] ^ x[c], 7);
}

static inline uint32_t mc_get_u32_le(uint8_t *input, int offset) {
return input[offset]
| (input[offset + 1] << 8)
| (input[offset + 2] << 16)
| (input[offset + 3] << 24);
}

static inline void mc_set_u32_le(uint8_t *input, int offset, uint32_t value) {
input[offset] = (uint8_t) value;
input[offset + 1] = (uint8_t) (value >> 8);
input[offset + 2] = (uint8_t) (value >> 16);
input[offset + 3] = (uint8_t) (value >> 24);
}

static void mc_chacha_core(int count, uint8_t *src, uint8_t *dst) {
static void mc_chacha_core(int count, const uint32_t *src, uint32_t *dst) {
uint32_t x[16];
for (int i = 0; i < 16; i++) {
x[i] = mc_get_u32_le(src, i * 4);
}
cpu_to_le32_array(x, src, 16);
for (int i = 0; i < count; i++) {
mc_chacha_quarterround(x, 0, 4, 8, 12);
mc_chacha_quarterround(x, 1, 5, 9, 13);
Expand All @@ -45,26 +29,26 @@ static void mc_chacha_core(int count, uint8_t *src, uint8_t *dst) {
}
for (int i = 0; i < 16; i++) {
uint32_t xi = x[i];
uint32_t hj = mc_get_u32_le(src, i * 4);
mc_set_u32_le(dst, i * 4, xi + hj);
uint32_t hj = cpu_to_le32(src[i]);
dst[i] = le32_to_cpu(xi + hj);
}
}

CAMLprim value
mc_chacha_round(value count, value src, value off1, value dst, value off2)
mc_chacha_round(value count, value src, value dst, value off)
{
_mc_switch_accel(ssse3,
mc_chacha_core_generic(Int_val(count), _ba_uint8_off(src, off1), _ba_uint8_off(dst, off2)),
mc_chacha_core(Int_val(count), _ba_uint8_off(src, off1), _ba_uint8_off(dst, off2)));
mc_chacha_core_generic(Int_val(count), (const uint32_t *)(String_val(src)), (uint32_t *)(Bytes_val(dst) + Long_val(off))),
mc_chacha_core(Int_val(count), (const uint32_t *)(String_val(src)), (uint32_t *)(Bytes_val(dst) + Long_val(off))));
return Val_unit;
}

#else //#ifdef __mc_ACCELERATE__

CAMLprim value
mc_chacha_round(value count, value src, value off1, value dst, value off2)
mc_chacha_round(value count, value src, value dst, value off)
{
mc_chacha_core_generic(Int_val(count), _ba_uint8_off(src, off1), _ba_uint8_off(dst, off2));
mc_chacha_core_generic(Int_val(count), (const uint32_t *)(String_val(src)), (uint32_t *)(Bytes_val(dst) + Long_val(off)));
return Val_unit;
}

Expand Down
Loading
Loading