Skip to content

Commit

Permalink
Chacha20-Poly1305: use string instead of cstruct
Browse files Browse the repository at this point in the history
Performance improvement from 8MB/s to 20MB/s (with 16 byte blocks, on my laptop)
  • Loading branch information
hannesm committed Feb 25, 2024
1 parent 28f8cde commit e46d028
Show file tree
Hide file tree
Showing 13 changed files with 160 additions and 125 deletions.
25 changes: 22 additions & 3 deletions bench/speed.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ module Time = struct

end

let burn_period = 2.0
let burn_period = 3.0

let sizes = [16; 64; 256; 1024; 8192]
(* let sizes = [16] *)
Expand All @@ -38,6 +38,17 @@ let burn f n =
let time = Time.time ~n:iters f cs in
(iters, time, float (n * iters) /. time)

(* [burn_str f n] benchmarks [f] on a freshly generated string obtained from
   [Mirage_crypto_rng.generate n] (presumably [n] random bytes).

   It first calibrates: starting at 10 iterations, the iteration count is
   multiplied by ten until a single [Time.time] measurement exceeds 0.2.
   That measurement is then scaled so the real run lasts roughly
   [burn_period].

   Returns [(iters, time, rate)] where [rate] is [n * iters] divided by the
   measured [time]. *)
let burn_str f n =
  let input = Cstruct.to_string (Mirage_crypto_rng.generate n) in
  let rec calibrate count =
    let elapsed = Time.time ~n:count f input in
    if elapsed > 0.2 then (elapsed, count) else calibrate (count * 10)
  in
  let (t1, i1) = calibrate 10 in
  (* Extrapolate from the calibration run to the target duration. *)
  let iters = int_of_float (float i1 *. burn_period /. t1) in
  let time = Time.time ~n:iters f input in
  (iters, time, float (n * iters) /. time)

let mb = 1024. *. 1024.

let throughput title f =
Expand All @@ -48,6 +59,14 @@ let throughput title f =
Printf.printf " % 5d: %04f MB/s (%d iters in %.03f s)\n%!"
size (bw /. mb) iters time

(* [throughput_str title f] prints a header with [title], then for every
   entry of [sizes] runs [burn_str f size] and prints one line with the
   size, the measured rate divided by [mb], the iteration count and the
   elapsed time. A full major GC is forced before each measurement. *)
let throughput_str title f =
  Printf.printf "\n* [%s]\n%!" title ;
  let report size =
    Gc.full_major () ;
    let (iters, time, bw) = burn_str f size in
    Printf.printf " % 5d: %04f MB/s (%d iters in %.03f s)\n%!"
      size (bw /. mb) iters time
  in
  List.iter report sizes

let count_period = 10.

let count f n =
Expand Down Expand Up @@ -370,8 +389,8 @@ let benchmarks = [

bm "chacha20-poly1305" (fun name ->
let key = Mirage_crypto.Chacha20.of_secret (Mirage_crypto_rng.generate 32)
and nonce = Mirage_crypto_rng.generate 8 in
throughput name (Mirage_crypto.Chacha20.authenticate_encrypt ~key ~nonce)) ;
and nonce = Cstruct.to_string (Mirage_crypto_rng.generate 8) in
throughput_str name (Mirage_crypto.Chacha20.auth_enc_str ~key ~nonce)) ;

bm "aes-128-ecb" (fun name ->
let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 16) in
Expand Down
75 changes: 41 additions & 34 deletions src/chacha20.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,23 @@ open Uncommon

let block = 64

type key = Cstruct.t
type key = string

let of_secret a = a
let of_secret a = Cstruct.to_string a

let chacha20_block state idx key_stream =
Native.Chacha.round 10 state.Cstruct.buffer 0 key_stream.Cstruct.buffer idx
Native.Chacha.round 10 state key_stream idx

let init ctr ~key ~nonce =
let ctr_off = 48 in
let set_ctr32 b v = Cstruct.LE.set_uint32 b ctr_off v
and set_ctr64 b v = Cstruct.LE.set_uint64 b ctr_off v
let set_ctr32 b v = Bytes.set_int32_le b ctr_off v
and set_ctr64 b v = Bytes.set_int64_le b ctr_off v
in
let inc32 b = set_ctr32 b (Int32.add (Cstruct.LE.get_uint32 b ctr_off) 1l)
and inc64 b = set_ctr64 b (Int64.add (Cstruct.LE.get_uint64 b ctr_off) 1L)
let inc32 b = set_ctr32 b (Int32.add (Bytes.get_int32_le b ctr_off) 1l)
and inc64 b = set_ctr64 b (Int64.add (Bytes.get_int64_le b ctr_off) 1L)
in
let s, key, init_ctr, nonce_off, inc =
match Cstruct.length key, Cstruct.length nonce, Int64.shift_right ctr 32 = 0L with
match String.length key, String.length nonce, Int64.shift_right ctr 32 = 0L with
| 32, 12, true ->
let ctr = Int64.to_int32 ctr in
"expand 32-byte k", key, (fun b -> set_ctr32 b ctr), 52, inc32
Expand All @@ -29,81 +29,82 @@ let init ctr ~key ~nonce =
| 32, 8, _ ->
"expand 32-byte k", key, (fun b -> set_ctr64 b ctr), 56, inc64
| 16, 8, _ ->
let k = Cstruct.append key key in
let k = key ^ key in
"expand 16-byte k", k, (fun b -> set_ctr64 b ctr), 56, inc64
| _ -> invalid_arg "Valid parameters are nonce 12 bytes and key 32 bytes \
(counter 32 bit), or nonce 8 byte and key 16 or 32 \
bytes (counter 64 bit)."
in
let state = Cstruct.create block in
Cstruct.blit_from_string s 0 state 0 16 ;
Cstruct.blit key 0 state 16 32 ;
let state = Bytes.create block in
Bytes.blit_string s 0 state 0 16 ;
Bytes.blit_string key 0 state 16 32 ;
init_ctr state ;
Cstruct.blit nonce 0 state nonce_off (Cstruct.length nonce) ;
Bytes.blit_string nonce 0 state nonce_off (String.length nonce) ;
state, inc

let crypt ~key ~nonce ?(ctr = 0L) data =
let state, inc = init ctr ~key ~nonce in
let l = Cstruct.length data in
let l = String.length data in
let block_count = l // block in
let len = block * block_count in
let last_len =
let last = l mod block in
if last = 0 then block else last
in
let key_stream = Cstruct.create_unsafe len in
let key_stream = Bytes.create len in
let rec loop i = function
| 0 -> ()
| 1 ->
chacha20_block state i key_stream ;
Native.xor_into data.buffer (data.off + i) key_stream.buffer i last_len
Native.xor_into_bytes data i key_stream i last_len
| n ->
chacha20_block state i key_stream ;
Native.xor_into data.buffer (data.off + i) key_stream.buffer i block ;
Native.xor_into_bytes data i key_stream i block ;
inc state;
loop (i + block) (n - 1)
in
loop 0 block_count ;
Cstruct.sub key_stream 0 l
let res = Bytes.unsafe_to_string key_stream in
if l <> len then String.sub res 0 l else res

module P = Poly1305.It

let generate_poly1305_key ~key ~nonce =
crypt ~key ~nonce (Cstruct.create 32)
crypt ~key ~nonce (String.make 32 '\000')

let mac ~key ~adata ciphertext =
let pad16 b =
let len = Cstruct.length b mod 16 in
if len = 0 then Cstruct.empty else Cstruct.create (16 - len)
let len = String.length b mod 16 in
if len = 0 then "" else String.make (16 - len) '\000'
and len =
let data = Cstruct.create 16 in
Cstruct.LE.set_uint64 data 0 (Int64.of_int (Cstruct.length adata));
Cstruct.LE.set_uint64 data 8 (Int64.of_int (Cstruct.length ciphertext));
data
let data = Bytes.create 16 in
Bytes.set_int64_le data 0 (Int64.of_int (String.length adata));
Bytes.set_int64_le data 8 (Int64.of_int (String.length ciphertext));
Bytes.unsafe_to_string data
in
let ctx = P.empty ~key in
let ctx = P.feed ctx adata in
let ctx = P.feed ctx (pad16 adata) in
let ctx = P.feed ctx ciphertext in
let ctx = P.feed ctx (pad16 ciphertext) in
let ctx = P.feed ctx len in
P.get ctx
P.macl ~key [ adata ; pad16 adata ; ciphertext ; pad16 ciphertext ; len ]

let authenticate_encrypt_tag ~key ~nonce ?(adata = Cstruct.empty) data =
let adata = Cstruct.to_string adata in
let nonce = Cstruct.to_string nonce in
let data = Cstruct.to_string data in
let poly1305_key = generate_poly1305_key ~key ~nonce in
let ciphertext = crypt ~key ~nonce ~ctr:1L data in
let mac = mac ~key:poly1305_key ~adata ciphertext in
ciphertext, mac
Cstruct.of_string ciphertext, Cstruct.of_string mac

let authenticate_encrypt ~key ~nonce ?adata data =
let cdata, ctag = authenticate_encrypt_tag ~key ~nonce ?adata data in
Cstruct.append cdata ctag

let authenticate_decrypt_tag ~key ~nonce ?(adata = Cstruct.empty) ~tag data =
let adata = Cstruct.to_string adata in
let nonce = Cstruct.to_string nonce in
let data = Cstruct.to_string data in
let poly1305_key = generate_poly1305_key ~key ~nonce in
let ctag = mac ~key:poly1305_key ~adata data in
let plain = crypt ~key ~nonce ~ctr:1L data in
if Eqaf_cstruct.equal tag ctag then Some plain else None
if Eqaf_cstruct.equal tag (Cstruct.of_string ctag) then Some (Cstruct.of_string plain) else None

let authenticate_decrypt ~key ~nonce ?adata data =
if Cstruct.length data < P.mac_size then
Expand All @@ -112,4 +113,10 @@ let authenticate_decrypt ~key ~nonce ?adata data =
let cipher, tag = Cstruct.split data (Cstruct.length data - P.mac_size) in
authenticate_decrypt_tag ~key ~nonce ?adata ~tag cipher

(** [auth_enc_str ~key ~nonce ?adata data] is the string-based
    ChaCha20-Poly1305 authenticated encryption of [data].

    The Poly1305 key comes from [generate_poly1305_key ~key ~nonce], the
    payload is encrypted with [crypt] starting at counter 1, and the tag is
    [mac] over [adata] (default [""]) and the ciphertext. The result is the
    ciphertext immediately followed by the tag. *)
let auth_enc_str ~key ~nonce ?(adata = "") data =
  let otk = generate_poly1305_key ~key ~nonce in
  let encrypted = crypt ~key ~nonce ~ctr:1L data in
  let tag = mac ~key:otk ~adata encrypted in
  encrypted ^ tag

let tag_size = P.mac_size
20 changes: 13 additions & 7 deletions src/mirage_crypto.mli
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ end

(** The poly1305 message authentication code *)
module Poly1305 : sig
type mac = Cstruct.t
type mac = string

type 'a iter = ('a -> unit) -> unit

Expand All @@ -214,27 +214,30 @@ module Poly1305 : sig
val mac_size : int
(** [mac_size] is the size of the output. *)

val empty : key:Cstruct.t -> t
val empty : key:string -> t
(** [empty] is the empty context with the given [key].
@raise Invalid_argument if key is not 32 bytes. *)

val feed : t -> Cstruct.t -> t
val feed : t -> string -> t
(** [feed t msg] adds the information in [msg] to [t]. *)

val feedi : t -> Cstruct.t iter -> t
val feedi : t -> string iter -> t
(** [feedi t iter] feeds iter into [t]. *)

val get : t -> mac
(** [get t] is the mac corresponding to [t]. *)

val mac : key:Cstruct.t -> Cstruct.t -> mac
val mac : key:string -> string -> mac
(** [mac ~key msg] is the all-in-one mac computation:
[get (feed (empty ~key) msg)]. *)

val maci : key:Cstruct.t -> Cstruct.t iter -> mac
val maci : key:string -> string iter -> mac
(** [maci ~key iter] is the all-in-one mac computation:
[get (feedi (empty ~key) iter)]. *)

val macl : key:string -> string list -> mac
(** [macl ~key datas] computes the [mac] of [datas]. *)
end

(** {1 Symmetric-key cryptography} *)
Expand Down Expand Up @@ -506,7 +509,7 @@ end
module Chacha20 : sig
include AEAD

val crypt : key:key -> nonce:Cstruct.t -> ?ctr:int64 -> Cstruct.t -> Cstruct.t
val crypt : key:key -> nonce:string -> ?ctr:int64 -> string -> string
(** [crypt ~key ~nonce ~ctr data] generates a ChaCha20 key stream using
the [key], and [nonce]. The [ctr] defaults to 0. The generated key
stream is of the same length as [data], and the output is the XOR
Expand All @@ -520,6 +523,9 @@ module Chacha20 : sig
IETF mode (and counter fit into 32 bits), or [key] must be either 16
bytes or 32 bytes and [nonce] 8 bytes.
*)

val auth_enc_str : key:key -> nonce:string -> ?adata:string ->
string -> string
(** [auth_enc_str ~key ~nonce ~adata data] is the authenticated encryption
    of [data] operating on strings: the ChaCha20 ciphertext of [data]
    (counter starting at 1) with the Poly1305 tag over [adata] and the
    ciphertext appended. [adata] defaults to the empty string.

    @raise Invalid_argument if the [key] and [nonce] lengths are not one of
    the combinations accepted by {!crypt}. *)
end

(** Streaming ciphers. *)
Expand Down
10 changes: 6 additions & 4 deletions src/native.ml
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ module DES = struct
end

module Chacha = struct
external round : int -> buffer -> off -> buffer -> off -> unit = "mc_chacha_round" [@@noalloc]
external round : int -> bytes -> bytes -> off -> unit = "mc_chacha_round" [@@noalloc]
end

module Poly1305 = struct
external init : ctx -> buffer -> off -> unit = "mc_poly1305_init" [@@noalloc]
external update : ctx -> buffer -> off -> size -> unit = "mc_poly1305_update" [@@noalloc]
external finalize : ctx -> buffer -> off -> unit = "mc_poly1305_finalize" [@@noalloc]
external init : ctx -> string -> unit = "mc_poly1305_init" [@@noalloc]
external update : ctx -> string -> size -> unit = "mc_poly1305_update" [@@noalloc]
external finalize : ctx -> bytes -> unit = "mc_poly1305_finalize" [@@noalloc]
external ctx_size : unit -> int = "mc_poly1305_ctx_size" [@@noalloc]
external mac_size : unit -> int = "mc_poly1305_mac_size" [@@noalloc]
end
Expand Down Expand Up @@ -95,6 +95,8 @@ end
* Unsolved: bounds-checked XORs are slowing things down considerably... *)
external xor_into : buffer -> off -> buffer -> off -> size -> unit = "mc_xor_into" [@@noalloc]

external xor_into_bytes : string -> off -> bytes -> off -> size -> unit = "mc_xor_into_bytes" [@@noalloc]

external count8be : bytes -> buffer -> off -> blocks:size -> unit = "mc_count_8_be" [@@noalloc]
external count16be : bytes -> buffer -> off -> blocks:size -> unit = "mc_count_16_be" [@@noalloc]
external count16be4 : bytes -> buffer -> off -> blocks:size -> unit = "mc_count_16_be_4" [@@noalloc]
Expand Down
34 changes: 10 additions & 24 deletions src/native/chacha.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

#include "mirage_crypto.h"

extern void mc_chacha_core_generic(int count, uint8_t *src, uint8_t *dst);
extern void mc_chacha_core_generic(int count, const uint32_t *src, uint32_t *dst);

#ifdef __mc_ACCELERATE__

Expand All @@ -13,24 +13,10 @@ static inline void mc_chacha_quarterround(uint32_t *x, int a, int b, int c, int
x[c] += x[d]; x[b] = rol32(x[b] ^ x[c], 7);
}

static inline uint32_t mc_get_u32_le(uint8_t *input, int offset) {
return input[offset]
| (input[offset + 1] << 8)
| (input[offset + 2] << 16)
| (input[offset + 3] << 24);
}

static inline void mc_set_u32_le(uint8_t *input, int offset, uint32_t value) {
input[offset] = (uint8_t) value;
input[offset + 1] = (uint8_t) (value >> 8);
input[offset + 2] = (uint8_t) (value >> 16);
input[offset + 3] = (uint8_t) (value >> 24);
}

static void mc_chacha_core(int count, uint8_t *src, uint8_t *dst) {
static void mc_chacha_core(int count, const uint32_t *src, uint32_t *dst) {
uint32_t x[16];
for (int i = 0; i < 16; i++) {
x[i] = mc_get_u32_le(src, i * 4);
x[i] = src[i];
}
for (int i = 0; i < count; i++) {
mc_chacha_quarterround(x, 0, 4, 8, 12);
Expand All @@ -45,26 +31,26 @@ static void mc_chacha_core(int count, uint8_t *src, uint8_t *dst) {
}
for (int i = 0; i < 16; i++) {
uint32_t xi = x[i];
uint32_t hj = mc_get_u32_le(src, i * 4);
mc_set_u32_le(dst, i * 4, xi + hj);
uint32_t hj = src[i];
dst[i] = xi + hj;
}
}

CAMLprim value
mc_chacha_round(value count, value src, value off1, value dst, value off2)
mc_chacha_round(value count, value src, value dst, value off)
{
_mc_switch_accel(ssse3,
mc_chacha_core_generic(Int_val(count), _ba_uint8_off(src, off1), _ba_uint8_off(dst, off2)),
mc_chacha_core(Int_val(count), _ba_uint8_off(src, off1), _ba_uint8_off(dst, off2)));
mc_chacha_core_generic(Int_val(count), (const uint32_t *)(String_val(src)), (uint32_t *)(Bytes_val(dst) + Long_val(off))),
mc_chacha_core(Int_val(count), (const uint32_t *)(String_val(src)), (uint32_t *)(Bytes_val(dst) + Long_val(off))));
return Val_unit;
}

#else //#ifdef __mc_ACCELERATE__

CAMLprim value
mc_chacha_round(value count, value src, value off1, value dst, value off2)
mc_chacha_round(value count, value src, value dst, value off)
{
mc_chacha_core_generic(Int_val(count), _ba_uint8_off(src, off1), _ba_uint8_off(dst, off2));
mc_chacha_core_generic(Int_val(count), (const uint32_t *)(String_val(src)), (uint32_t *)(Bytes_val(dst) + Long_val(off)));
return Val_unit;
}

Expand Down
Loading

0 comments on commit e46d028

Please sign in to comment.