From 1ca85f346429a02ff8f708be80b57f4e836eeaa9 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Tue, 19 Mar 2024 21:05:08 +0100 Subject: [PATCH] avoid global buffers (#219) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * avoid global buffers * rng: safety - ensure generate_into takes a long enough buffer (raise otherwise) * rng: interrupt_hook only one unit argument (@reynir) * remove offset from counters Co-authored-by: Reynir Björnsson Co-authored-by: Calascibetta Romain --- ec/mirage_crypto_ec.ml | 46 ++++++++++++++++++------------------ pk/dsa.ml | 2 +- pk/rsa.ml | 6 ++--- rng/entropy.ml | 12 ++++------ rng/fortuna.ml | 17 +++++++------- rng/hmac_drbg.ml | 4 ++-- rng/mirage_crypto_rng.mli | 4 ++-- rng/rng.ml | 2 ++ src/ccm.ml | 4 ++-- src/cipher_block.ml | 48 +++++++++++++++++++------------------- src/native.ml | 6 ++--- src/native/mirage_crypto.h | 2 +- src/native/misc.c | 4 ++-- src/native/misc_sse.c | 6 ++--- tests/test_entropy.ml | 2 +- 15 files changed, 82 insertions(+), 83 deletions(-) diff --git a/ec/mirage_crypto_ec.ml b/ec/mirage_crypto_ec.ml index 25772ddc..5c1e47ed 100644 --- a/ec/mirage_crypto_ec.ml +++ b/ec/mirage_crypto_ec.ml @@ -20,7 +20,7 @@ let pp_error fmt e = let rev_string buf = let len = String.length buf in - let res = Bytes.make len '\000' in + let res = Bytes.create len in for i = 0 to len - 1 do Bytes.set res (len - 1 - i) (String.get buf i) done ; @@ -135,7 +135,7 @@ end module Make_field_element (P : Parameters) (F : Foreign) : Field_element = struct let b_uts b = Bytes.unsafe_to_string b - let create () = Bytes.make P.fe_length '\000' + let create () = Bytes.create P.fe_length let mul a b = let tmp = create () in @@ -190,7 +190,7 @@ module Make_field_element (P : Parameters) (F : Foreign) : Field_element = struc b_uts tmp let create_octets () = - Bytes.make P.byte_length '\000' + Bytes.create P.byte_length let to_octets fe = let tmp = create_octets () in @@ -307,19 +307,19 @@ module Make_point (P : Parameters) (F : Foreign) : Point = struct | None -> String.make 1 '\000' | Some (x, y) -> let len_x = String.length x and len_y = String.length y in - let res = Bytes.make (1 + len_x + len_y) '\000' in + let res = Bytes.create (1 + len_x + len_y) in Bytes.set res 0 '\004' ; let rev_x = rev_string x and rev_y = rev_string y in - Bytes.blit_string rev_x 0 res 1 len_x ; - Bytes.blit_string rev_y 0 res (1 + len_x) len_y ; + Bytes.unsafe_blit_string rev_x 0 res 1 len_x ; + Bytes.unsafe_blit_string rev_y 0 res (1 + len_x) len_y ; Bytes.unsafe_to_string res in if compress then - let out = Bytes.make (P.byte_length + 1) '\000' in + let out = Bytes.create (P.byte_length + 1) in let ident = 2 + (string_get_uint8 buf ((P.byte_length * 2) - 1)) land 1 in - Bytes.blit_string buf 1 out 1 P.byte_length; + Bytes.unsafe_blit_string buf 1 out 1 P.byte_length; Bytes.set_uint8 out 0 ident; Bytes.unsafe_to_string out else @@ -391,10 +391,10 @@ module Make_point (P : Parameters) (F : Foreign) : Point = struct 2 + (string_get_uint8 y_struct (P.byte_length - 2)) land 1 in let res = if Int.equal signY ident then y_struct else y_struct2 in - let out = Bytes.make ((P.byte_length * 2) + 1) '\000' in + let out = Bytes.create ((P.byte_length * 2) + 1) in Bytes.set out 0 '\004'; - Bytes.blit_string pk 1 out 1 P.byte_length; - Bytes.blit_string res 0 out (P.byte_length + 1) P.byte_length; + Bytes.unsafe_blit_string pk 1 out 1 P.byte_length; + Bytes.unsafe_blit_string res 0 out (P.byte_length + 1) P.byte_length; Bytes.unsafe_to_string out let of_octets buf = @@ -547,9 +547,9 @@ end module Make_Fn (P : Parameters) (F : Foreign_n) : Fn = struct let b_uts = Bytes.unsafe_to_string - let create () = Bytes.make P.fe_length '\000' + let create () = Bytes.create P.fe_length - let create_octets () = Bytes.make P.byte_length '\000' + let create_octets () = Bytes.create P.byte_length let from_be_octets v = let v' = create () in @@ -617,7 +617,7 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Dige msg else ( let res = Bytes.make bl '\000' in - Bytes.blit_string msg 0 res (bl - l) (String.length msg) ; + Bytes.unsafe_blit_string msg 0 res (bl - l) l ; Bytes.unsafe_to_string res ) (* RFC 6979: compute a deterministic k *) @@ -907,7 +907,7 @@ module X25519 = struct let key_len = 32 let scalar_mult in_ base = - let out = Bytes.make key_len '\000' in + let out = Bytes.create key_len in x25519_scalar_mult_generic out in_ base; Bytes.unsafe_to_string out @@ -949,17 +949,17 @@ module Ed25519 = struct let key_len = 32 let scalar_mult_base_to_bytes p = - let tmp = Bytes.make key_len '\000' in + let tmp = Bytes.create key_len in scalar_mult_base_to_bytes tmp p; Bytes.unsafe_to_string tmp let muladd a b c = - let tmp = Bytes.make key_len '\000' in + let tmp = Bytes.create key_len in muladd tmp a b c; Bytes.unsafe_to_string tmp let double_scalar_mult a b c = - let tmp = Bytes.make key_len '\000' in + let tmp = Bytes.create key_len in let s = double_scalar_mult tmp a b c in s, Bytes.unsafe_to_string tmp @@ -1024,9 +1024,9 @@ module Ed25519 = struct reduce_l k; let k = Bytes.unsafe_to_string k in let s_out = muladd k s r in - let res = Bytes.make (key_len + key_len) '\000' in - Bytes.blit_string r_big 0 res 0 key_len ; - Bytes.blit_string s_out 0 res key_len key_len ; + let res = Bytes.create (key_len + key_len) in + Bytes.unsafe_blit_string r_big 0 res 0 key_len ; + Bytes.unsafe_blit_string s_out 0 res key_len key_len ; Bytes.unsafe_to_string res let verify ~key signature ~msg = @@ -1039,10 +1039,10 @@ module Ed25519 = struct let s_smaller_l = (* check s within 0 <= s < L *) let s' = Bytes.make (key_len * 2) '\000' in - Bytes.blit_string s 0 s' 0 key_len; + Bytes.unsafe_blit_string s 0 s' 0 key_len; reduce_l s'; let s' = Bytes.unsafe_to_string s' in - let s'' = String.concat "" [ s; String.make key_len '\000' ] in + let s'' = s ^ String.make key_len '\000' in String.equal s'' s' in if s_smaller_l then begin diff --git a/pk/dsa.ml b/pk/dsa.ml index 352d4318..a460702e 100644 --- a/pk/dsa.ml +++ b/pk/dsa.ml @@ -149,7 +149,7 @@ let rec shift_left_inplace buf = function | bits when bits mod 8 = 0 -> let off = bits / 8 in let to_blit = Bytes.length buf - off in - Bytes.blit buf off buf 0 to_blit ; + Bytes.unsafe_blit buf off buf 0 to_blit ; Bytes.unsafe_fill buf to_blit (Bytes.length buf - to_blit) '\x00' | bits when bits < 8 -> let foo = 8 - bits in diff --git a/pk/rsa.ml b/pk/rsa.ml index cc0d7d35..5eae25b9 100644 --- a/pk/rsa.ml +++ b/pk/rsa.ml @@ -307,10 +307,10 @@ end module MGF1 (H : Digestif.S) = struct - let _buf = Bytes.create 4 let repr n = - Bytes.set_int32_be _buf 0 n; - Bytes.unsafe_to_string _buf + let buf = Bytes.create 4 in + Bytes.set_int32_be buf 0 n; + Bytes.unsafe_to_string buf (* Assumes len < 2^32 * H.digest_size. *) let mgf ~seed len = diff --git a/rng/entropy.ml b/rng/entropy.ml index 710989b6..d7226e6a 100644 --- a/rng/entropy.ml +++ b/rng/entropy.ml @@ -130,17 +130,15 @@ let bootstrap id = let interrupt_hook () = let buf = Bytes.create 4 in - fun () -> - let a = Cpu_native.cycles () in - Bytes.set_int32_le buf 0 (Int32.of_int a) ; - Bytes.unsafe_to_string buf + let a = Cpu_native.cycles () in + Bytes.set_int32_le buf 0 (Int32.of_int a) ; + Bytes.unsafe_to_string buf let timer_accumulator g = let g = match g with None -> Some (Rng.default_generator ()) | Some g -> Some g in let source = register_source "timer" in let `Acc handle = Rng.accumulate g source in - let hook = interrupt_hook () in - (fun () -> handle (hook ())) + (fun () -> handle (interrupt_hook ())) let feed_pools g source f = let g = match g with None -> Some (Rng.default_generator ()) | Some g -> Some g in @@ -159,8 +157,8 @@ let cpu_rng = let s = match insn with `Rdrand -> "rdrand" | `Rdseed -> "rdseed" in register_source s in - let buf = Bytes.create 8 in let f () = + let buf = Bytes.create 8 in Bytes.set_int64_le buf 0 (Int64.of_int (randomf ())); Bytes.unsafe_to_string buf in diff --git a/rng/fortuna.ml b/rng/fortuna.ml index 1930682d..853a9b07 100644 --- a/rng/fortuna.ml +++ b/rng/fortuna.ml @@ -68,7 +68,7 @@ let generate_rekey ~g buf ~off len = let b = len // block + 2 in let n = b * block in let r = AES_CTR.stream ~key:g.key ~ctr:g.ctr n in - Bytes.blit_string r 0 buf off len; + Bytes.unsafe_blit_string r 0 buf off len; let r2 = String.sub r (n - 32) 32 in set_key ~g r2 ; g.ctr <- AES_CTR.add_ctr g.ctr (Int64.of_int b) @@ -105,15 +105,14 @@ let generate_into ~g buf ~off len = in chunk off len -let _buf = Bytes.create 2 - let add ~g (source, _) ~pool data = - let pool = pool land (pools - 1) - and source = source land 0xff in - Bytes.set_uint8 _buf 0 source; - Bytes.set_uint8 _buf 1 (String.length data); - g.pools.(pool) <- SHAd256.feedi g.pools.(pool) (iter2 (Bytes.unsafe_to_string _buf) data); - if pool = 0 then g.pool0_size <- g.pool0_size + String.length data + let buf = Bytes.create 2 + and pool = pool land (pools - 1) + and source = source land 0xff in + Bytes.set_uint8 buf 0 source; + Bytes.set_uint8 buf 1 (String.length data); + g.pools.(pool) <- SHAd256.feedi g.pools.(pool) (iter2 (Bytes.unsafe_to_string buf) data); + if pool = 0 then g.pool0_size <- g.pool0_size + String.length data (* XXX * Schneier recommends against using generator-imposed pool-seeding schedule diff --git a/rng/hmac_drbg.ml b/rng/hmac_drbg.ml index e48d7b5a..0dc781ca 100644 --- a/rng/hmac_drbg.ml +++ b/rng/hmac_drbg.ml @@ -34,11 +34,11 @@ module Make (H : Digestif.S) = struct let rem = len mod H.digest_size in if rem = 0 then H.digest_size else rem in - Bytes.blit_string v 0 buf off len; + Bytes.unsafe_blit_string v 0 buf off len; v | i -> let v = H.hmac_string ~key:k v |> H.to_raw_string in - Bytes.blit_string v 0 buf off H.digest_size; + Bytes.unsafe_blit_string v 0 buf off H.digest_size; go (off + H.digest_size) k v (pred i) in let v = go off g.k g.v Mirage_crypto.Uncommon.(len // H.digest_size) in diff --git a/rng/mirage_crypto_rng.mli b/rng/mirage_crypto_rng.mli index 63aff42b..6b9210d2 100644 --- a/rng/mirage_crypto_rng.mli +++ b/rng/mirage_crypto_rng.mli @@ -114,8 +114,8 @@ module Entropy : sig (** {1 Timer source} *) - val interrupt_hook : unit -> unit -> string - (** [interrupt_hook ()] collects lower bytes from the cycle counter, to be + val interrupt_hook : unit -> string + (** [interrupt_hook] collects lower bytes from the cycle counter, to be used for entropy collection in the event loop. *) val timer_accumulator : g option -> unit -> unit diff --git a/rng/rng.ml b/rng/rng.ml index 2ddaa4cd..a6948589 100644 --- a/rng/rng.ml +++ b/rng/rng.ml @@ -70,6 +70,8 @@ let get = function Some g -> g | None -> default_generator () let generate_into ?(g = default_generator ()) b ?(off = 0) n = let Generator (g, _, m) = g in let module M = (val m) in + if Bytes.length b - off < n then + invalid_arg "buffer too short"; M.generate_into ~g b ~off n let generate ?g n = diff --git a/src/ccm.ml b/src/ccm.ml index 746c02d5..a8368a03 100644 --- a/src/ccm.ml +++ b/src/ccm.ml @@ -46,8 +46,8 @@ let gen_adata a = llen + String.length a + to_pad, fun buf off -> set_llen buf off; - Bytes.blit_string a 0 buf (off + llen) (String.length a); - Bytes.fill buf (off + llen + String.length a) to_pad '\000' + Bytes.unsafe_blit_string a 0 buf (off + llen) (String.length a); + Bytes.unsafe_fill buf (off + llen + String.length a) to_pad '\000' let gen_ctr nonce i = let n = String.length nonce in diff --git a/src/cipher_block.ml b/src/cipher_block.ml index 7fd43365..21a4b3ab 100644 --- a/src/cipher_block.ml +++ b/src/cipher_block.ml @@ -83,20 +83,19 @@ module Counters = struct val size : int val add : ctr -> int64 -> ctr val of_octets : string -> ctr - val unsafe_count_into : ctr -> bytes -> int -> blocks:int -> unit + val unsafe_count_into : ctr -> bytes -> blocks:int -> unit end - let _tmp = Bytes.make 16 '\x00' - module C64be = struct type ctr = int64 let size = 8 (* Until OCaml 4.13 is lower bound*) let of_octets cs = Bytes.get_int64_be (Bytes.unsafe_of_string cs) 0 let add = Int64.add - let unsafe_count_into t buf off ~blocks = - Bytes.set_int64_be _tmp 0 t; - Native.count8be _tmp buf off ~blocks + let unsafe_count_into t buf ~blocks = + let tmp = Bytes.create 8 in + Bytes.set_int64_be tmp 0 t; + Native.count8be tmp buf ~blocks end module C128be = struct @@ -109,9 +108,10 @@ module Counters = struct let w0' = Int64.add w0 n in let flip = if Int64.logxor w0 w0' < 0L then w0' > w0 else w0' < w0 in ((if flip then Int64.succ w1 else w1), w0') - let unsafe_count_into (w1, w0) buf off ~blocks = - Bytes.set_int64_be _tmp 0 w1; Bytes.set_int64_be _tmp 8 w0; - Native.count16be _tmp buf off ~blocks + let unsafe_count_into (w1, w0) buf ~blocks = + let tmp = Bytes.create 16 in + Bytes.set_int64_be tmp 0 w1; Bytes.set_int64_be tmp 8 w0; + Native.count16be tmp buf ~blocks end module C128be32 = struct @@ -119,9 +119,10 @@ module Counters = struct let add (w1, w0) n = let hi = 0xffffffff00000000L and lo = 0x00000000ffffffffL in (w1, Int64.(logor (logand hi w0) (add n w0 |> logand lo))) - let unsafe_count_into (w1, w0) buf off ~blocks = - Bytes.set_int64_be _tmp 0 w1; Bytes.set_int64_be _tmp 8 w0; - Native.count16be4 _tmp buf off ~blocks + let unsafe_count_into (w1, w0) buf ~blocks = + let tmp = Bytes.create 16 in + Bytes.set_int64_be tmp 0 w1; Bytes.set_int64_be tmp 8 w0; + Native.count16be4 tmp buf ~blocks end end @@ -207,15 +208,15 @@ module Modes = struct let stream ~key ~ctr n = let blocks = imax 0 n / block_size in let buf = Bytes.create n in - Ctr.unsafe_count_into ctr ~blocks buf 0 ; + Ctr.unsafe_count_into ctr ~blocks buf ; Core.encrypt ~key ~blocks (Bytes.unsafe_to_string buf) 0 buf 0 ; let slack = imax 0 n mod block_size in if slack <> 0 then begin let buf' = Bytes.create block_size in let ctr = Ctr.add ctr (Int64.of_int blocks) in - Ctr.unsafe_count_into ctr ~blocks:1 buf' 0 ; + Ctr.unsafe_count_into ctr ~blocks:1 buf' ; Core.encrypt ~key ~blocks:1 (Bytes.unsafe_to_string buf') 0 buf' 0 ; - Bytes.blit buf' 0 buf (blocks * block_size) slack + Bytes.unsafe_blit buf' 0 buf (blocks * block_size) slack end; Bytes.unsafe_to_string buf @@ -245,9 +246,8 @@ module Modes = struct let k = Bytes.create keysize in Native.GHASH.keyinit cs k; Bytes.unsafe_to_string k - let hash0 = Bytes.make tagsize '\x00' let digesti ~key i = - let res = Bytes.copy hash0 in + let res = Bytes.make tagsize '\x00' in i (fun cs -> Native.GHASH.ghash key res cs (String.length cs)); Bytes.unsafe_to_string res end @@ -261,21 +261,21 @@ module Modes = struct let tag_size = GHASH.tagsize let key_sizes, block_size = C.(key, block) - let z128, h = String.make block_size '\x00', Bytes.create block_size + let z128 = String.make block_size '\x00' let of_secret cs = + let h = Bytes.create block_size in let key = C.e_of_secret cs in C.encrypt ~key ~blocks:1 z128 0 h 0; { key ; hkey = GHASH.derive (Bytes.unsafe_to_string h) } let bits64 cs = Int64.of_int (String.length cs * 8) - let pack64s = - let _cs = Bytes.create 16 in - fun a b -> - Bytes.set_int64_be _cs 0 a; - Bytes.set_int64_be _cs 8 b; - Bytes.unsafe_to_string _cs + let pack64s a b = + let cs = Bytes.create 16 in + Bytes.set_int64_be cs 0 a; + Bytes.set_int64_be cs 8 b; + Bytes.unsafe_to_string cs (* OCaml 4.13 *) let string_get_int64 s idx = diff --git a/src/native.ml b/src/native.ml index 911a050f..5684235a 100644 --- a/src/native.ml +++ b/src/native.ml @@ -39,9 +39,9 @@ end * Unsolved: bounds-checked XORs are slowing things down considerably... *) external xor_into_bytes : string -> int -> bytes -> int -> int -> unit = "mc_xor_into_bytes" [@@noalloc] -external count8be : bytes -> bytes -> int -> blocks:int -> unit = "mc_count_8_be" [@@noalloc] -external count16be : bytes -> bytes -> int -> blocks:int -> unit = "mc_count_16_be" [@@noalloc] -external count16be4 : bytes -> bytes -> int -> blocks:int -> unit = "mc_count_16_be_4" [@@noalloc] +external count8be : bytes -> bytes -> blocks:int -> unit = "mc_count_8_be" [@@noalloc] +external count16be : bytes -> bytes -> blocks:int -> unit = "mc_count_16_be" [@@noalloc] +external count16be4 : bytes -> bytes -> blocks:int -> unit = "mc_count_16_be_4" [@@noalloc] external misc_mode : unit -> int = "mc_misc_mode" [@@noalloc] diff --git a/src/native/mirage_crypto.h b/src/native/mirage_crypto.h index 6608a1b1..0542db2f 100644 --- a/src/native/mirage_crypto.h +++ b/src/native/mirage_crypto.h @@ -114,6 +114,6 @@ CAMLprim value mc_xor_into_bytes_generic (value b1, value off1, value b2, value off2, value n); CAMLprim value -mc_count_16_be_4_generic (value ctr, value dst, value off, value blocks); +mc_count_16_be_4_generic (value ctr, value dst, value blocks); #endif /* H__MIRAGE_CRYPTO */ diff --git a/src/native/misc.c b/src/native/misc.c index dea76e18..ba9590f8 100644 --- a/src/native/misc.c +++ b/src/native/misc.c @@ -60,9 +60,9 @@ mc_xor_into_bytes_generic (value b1, value off1, value b2, value off2, value n) } #define __export_counter(name, f) \ - CAMLprim value name (value ctr, value dst, value off, value blocks) { \ + CAMLprim value name (value ctr, value dst, value blocks) { \ f ( (uint64_t*) Bp_val (ctr), \ - (uint64_t*) _bp_uint8_off (dst, off), Long_val (blocks) ); \ + (uint64_t*) _bp_uint8 (dst), Long_val (blocks) ); \ return Val_unit; \ } diff --git a/src/native/misc_sse.c b/src/native/misc_sse.c index a5a068c5..1f2265da 100644 --- a/src/native/misc_sse.c +++ b/src/native/misc_sse.c @@ -48,11 +48,11 @@ mc_xor_into_bytes (value b1, value off1, value b2, value off2, value n) { } #define __export_counter(name, f) \ - CAMLprim value name (value ctr, value dst, value off, value blocks) { \ + CAMLprim value name (value ctr, value dst, value blocks) { \ _mc_switch_accel(ssse3, \ - name##_generic (ctr, dst, off, blocks), \ + name##_generic (ctr, dst, blocks), \ f ( (uint64_t*) Bp_val (ctr), \ - (uint64_t*) _bp_uint8_off (dst, off), Long_val (blocks) )) \ + (uint64_t*) _bp_uint8 (dst), Long_val (blocks) )) \ return Val_unit; \ } diff --git a/tests/test_entropy.ml b/tests/test_entropy.ml index f0984d38..13cb91fe 100644 --- a/tests/test_entropy.ml +++ b/tests/test_entropy.ml @@ -32,7 +32,7 @@ let whirlwind_bootstrap_check () = let timer_check () = for i = 0 to 10 do - let data' = Mirage_crypto_rng.Entropy.interrupt_hook () () in + let data' = Mirage_crypto_rng.Entropy.interrupt_hook () in if String.equal !data data' then begin Ohex.pp Format.std_formatter data'; failwith ("same data from timer at " ^ string_of_int i);