Skip to content

Commit

Permalink
mirage-crypto-ec: Use windowed algorithm for base scalar mult on NIST…
Browse files Browse the repository at this point in the history
… P-curves (mirage#191)

* [ec] Use windowed algorithm for base scalar mult

Using a sliding window method with pre-computed values of multiples of
the generator point, obtain far more efficient performance for the
special case where G = P in the scalar multiplication kP.

By using a safe selection algorithm for pre-computed values and no
branches in the main loop, the algorithm leaks no less information about
its inputs than the current Montgomery ladder.

* [ec] Rewrite scalar_mult_base in C

For performance. This implies the need to get generator points from C as
well. The pre-computed tables are stored in static memory, and computed
lazily.

* Generate pre-tables AOT and hardcode them

* Separate 64/32 tables

* Add 32-bit tables
  • Loading branch information
Firobe committed Apr 8, 2024
1 parent 24c7cb7 commit 040ef22
Show file tree
Hide file tree
Showing 20 changed files with 176,144 additions and 5 deletions.
4 changes: 4 additions & 0 deletions ec/gen_tables/dune
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
(include_subdirs no)
(executable
(name gen_tables)
(libraries mirage_crypto_ec))
113 changes: 113 additions & 0 deletions ec/gen_tables/gen_tables.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
open Format

let print_header name =
printf
{|
/*
Pre-computed %d-bit multiples of the generator point G for the curve %s,
used for speeding up its scalar multiplication in point_operations.h.

Generated by %s
*/|}
Sys.word_size name Sys.argv.(0)

let pp_array elem_fmt fmt arr =
let fout = fprintf fmt in
let len = Array.length arr in
fout "@[<2>{@\n";
for i = 0 to len - 1 do
elem_fmt fmt arr.(i);
if i < len - 1 then printf ",@ " else printf ""
done;
fout "@]@,}"

let div_round_up a b = (a / b) + if a mod b = 0 then 0 else 1

let pp_string_words ~wordsize fmt str =
assert (String.length str * 8 mod wordsize = 0);
let limbs = String.length str * 8 / wordsize in
(* Truncate at the beginning (little-endian) *)
let bytes = Bytes.unsafe_of_string str in
(* let bytes = rev_str_bytes str in *)
fprintf fmt "@[<2>{@\n";
for i = 0 to limbs - 1 do
let index = i * (wordsize / 8) in
(if wordsize = 64 then
let w = Bytes.get_int64_le bytes index in
fprintf fmt "%#016Lx" w
else
let w = Bytes.get_int32_le bytes index in
fprintf fmt "%#08lx" w);
if i < limbs - 1 then printf ",@ " else printf ""
done;
fprintf fmt "@]@,}"

let check_shape tables =
let fe_len = String.length tables.(0).(0).(0) in
let table_len = fe_len * 2 in
assert (Array.length tables = table_len);
Array.iter
(fun x ->
assert (Array.length x = 15);
Array.iter
(fun x ->
assert (Array.length x = 3);
Array.iter (fun x -> assert (String.length x = fe_len)) x)
x)
tables

let print_tables tables ~wordsize =
let fe_len = String.length tables.(0).(0).(0) in
printf "@[<2>static WORD generator_table[%d][15][3][LIMBS] = @," (fe_len * 2);
pp_array
(pp_array (pp_array (pp_string_words ~wordsize)))
std_formatter tables;
printf "@];@,"

let print_toplevel name wordsize (module P : Mirage_crypto_ec.Dh_dsa) =
let tables = P.Dsa.Precompute.generator_tables () in
assert (wordsize = Sys.word_size);
check_shape tables;
print_header name;
if wordsize = 64 then
printf
"@[<v>#ifndef ARCH_64BIT@,\
#error \"Cannot use 64-bit tables on a 32-bit architecture\"@,\
#endif@,\
@]"
else
printf
"@[<v>#ifdef ARCH_64BIT@,\
#error \"Cannot use 32-bit tables on a 64-bit architecture\"@,\
#endif@,\
@]";
print_tables ~wordsize tables

let curves =
Mirage_crypto_ec.
[
("p224", (module P224 : Dh_dsa));
("p256", (module P256));
("p384", (module P384));
("p521", (module P521));
]

let usage () =
printf "Usage: gen_tables [%a] [64 | 32]@."
(pp_print_list
~pp_sep:(fun fmt () -> pp_print_string fmt " | ")
pp_print_string)
(List.map fst curves)

let go =
let name, curve, wordsize =
try
let name, curve =
List.find (fun (name, _) -> name = Sys.argv.(1)) curves
in
(name, curve, int_of_string Sys.argv.(2))
with _ ->
usage ();
exit 1
in
print_toplevel name wordsize curve
10 changes: 10 additions & 0 deletions ec/implementation.mld
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ The following references discuss this algorithm:
- {{:https://eprint.iacr.org/2017/293.pdf}Montgomery curves and the Montgomery
ladder, Daniel J. Bernstein and Tanja Lange}

For the special case of base scalar multiplication (where the generator point of
the curve specifically is multiplied by a scalar), instead an algorithm
(implemented by hand in C) using pre-computed tables of point doubling is used
(tables are in `native/p*_tables_32|64.c`).

The key for this algorithm being constant-time is the function selecting values
from the tables, which conceals what value it selects by exploring the whole
table in the same order no matter the input, using const-time selection (as
defined in fiat code). See `native/point_operations.h`.

{2 Key exchange}

Key exchange consists in
Expand Down
57 changes: 52 additions & 5 deletions ec/mirage_crypto_ec.ml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ module type Dsa = sig
module K_gen (H : Mirage_crypto.Hash.S) : sig
val generate : key:priv -> Cstruct.t -> Cstruct.t
end
module Precompute : sig
val generator_tables : unit -> string array array array
end
end

module type Dh_dsa = sig
Expand Down Expand Up @@ -108,6 +111,7 @@ module type Foreign = sig

val double_c : out_point -> point -> unit
val add_c : out_point -> point -> point -> unit
val scalar_mult_base_c : out_point -> string -> unit
end

module type Field_element = sig
Expand All @@ -125,6 +129,7 @@ module type Field_element = sig
val to_octets : field_element -> string
val double_point : point -> point
val add_point : point -> point -> point
val scalar_mult_base_point : scalar -> point
end

module Make_field_element (P : Parameters) (F : Foreign) : Field_element = struct
Expand Down Expand Up @@ -213,6 +218,11 @@ module Make_field_element (P : Parameters) (F : Foreign) : Field_element = struc
let tmp = out_point () in
F.add_c tmp a b;
out_p_to_p tmp

let scalar_mult_base_point (Scalar d) =
let tmp = out_point () in
F.scalar_mult_base_c tmp d;
out_p_to_p tmp
end

module type Point = sig
Expand All @@ -226,6 +236,7 @@ module type Point = sig
val x_of_finite_point : point -> string
val params_g : point
val select : bool -> then_:point -> else_:point -> point
val scalar_mult_base : scalar -> point
end

module Make_point (P : Parameters) (F : Foreign) : Point = struct
Expand Down Expand Up @@ -406,6 +417,8 @@ module Make_point (P : Parameters) (F : Foreign) : Point = struct
of_octets buf
| 0x00 | 0x04 -> Error `Invalid_length
| _ -> Error `Invalid_format

let scalar_mult_base = Fe.scalar_mult_base_point
end

module type Scalar = sig
Expand All @@ -414,6 +427,8 @@ module type Scalar = sig
val of_octets : string -> (scalar, error) result
val to_octets : scalar -> string
val scalar_mult : scalar -> point -> point
val scalar_mult_base : scalar -> point
val generator_tables : unit -> field_element array array array
end

module Make_scalar (Param : Parameters) (P : Point) : Scalar = struct
Expand All @@ -433,6 +448,7 @@ module Make_scalar (Param : Parameters) (P : Point) : Scalar = struct

let to_octets (Scalar buf) = rev_string buf

(* Branchless Montgomery ladder method *)
let scalar_mult (Scalar s) p =
let r0 = ref (P.at_infinity ()) in
let r1 = ref p in
Expand All @@ -445,6 +461,29 @@ module Make_scalar (Param : Parameters) (P : Point) : Scalar = struct
r1 := P.select bit ~then_:r1_double ~else_:sum
done;
!r0

(* Specialization of [scalar_mult d p] when [p] is the generator *)
let scalar_mult_base = P.scalar_mult_base

(* Pre-compute multiples of the generator point
returns the tables along with the number of significant bytes *)
let generator_tables () =
let len = Param.fe_length * 2 in
let one_table _ = Array.init 15 (fun _ -> P.at_infinity ()) in
let table = Array.init len one_table in
let base = ref P.params_g in
for i = 0 to len - 1 do
table.(i).(0) <- !base;
for j = 1 to 14 do
table.(i).(j) <- P.add !base table.(i).(j - 1)
done;
base := P.double !base;
base := P.double !base;
base := P.double !base;
base := P.double !base
done;
let convert {f_x; f_y; f_z} = [|f_x; f_y; f_z|] in
Array.map (Array.map convert) table
end

module Make_dh (Param : Parameters) (P : Point) (S : Scalar) : Dh = struct
Expand All @@ -459,7 +498,7 @@ module Make_dh (Param : Parameters) (P : Point) (S : Scalar) : Dh = struct
type secret = scalar

let share ?(compress = false) private_key =
let public_key = S.scalar_mult private_key P.params_g in
let public_key = S.scalar_mult_base private_key in
point_to_octets ~compress public_key

let secret_of_octets ?compress s =
Expand Down Expand Up @@ -668,7 +707,7 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Mira
in
one ()
in
let q = S.scalar_mult d P.params_g in
let q = S.scalar_mult_base d in
(d, q)

let x_of_finite_point_mod_n p =
Expand All @@ -695,7 +734,7 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Mira
| Ok ksc -> ksc
| Error _ -> invalid_arg "k not in range" (* if no k is provided, this cannot happen since K_gen_*.gen already preserves the Scalar invariants *)
in
let point = S.scalar_mult ksc P.params_g in
let point = S.scalar_mult_base ksc in
match x_of_finite_point_mod_n point with
| None -> again ()
| Some r ->
Expand All @@ -719,7 +758,7 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Mira
let r, s = sign_octets ~key ?k:(Option.map Cstruct.to_string k) (Cstruct.to_string msg) in
Cstruct.of_string r, Cstruct.of_string s

let pub_of_priv priv = S.scalar_mult priv P.params_g
let pub_of_priv priv = S.scalar_mult_base priv

let verify_octets ~key (r, s) msg =
try
Expand All @@ -743,7 +782,7 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Mira
| Ok u1, Ok u2 ->
let point =
P.add
(S.scalar_mult u1 P.params_g)
(S.scalar_mult_base u1)
(S.scalar_mult u2 key)
in
begin match x_of_finite_point_mod_n point with
Expand All @@ -756,6 +795,10 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Mira

let verify ~key (r, s) digest =
verify_octets ~key (Cstruct.to_string r, Cstruct.to_string s) (Cstruct.to_string digest)

module Precompute = struct
let generator_tables = S.generator_tables
end
end

module P224 : Dh_dsa = struct
Expand Down Expand Up @@ -787,6 +830,7 @@ module P224 : Dh_dsa = struct
external select_c : out_field_element -> bool -> field_element -> field_element -> unit = "mc_p224_select" [@@noalloc]
external double_c : out_point -> point -> unit = "mc_p224_point_double" [@@noalloc]
external add_c : out_point -> point -> point -> unit = "mc_p224_point_add" [@@noalloc]
external scalar_mult_base_c : out_point -> string -> unit = "mc_p224_scalar_mult_base" [@@noalloc]
end

module Foreign_n = struct
Expand Down Expand Up @@ -836,6 +880,7 @@ module P256 : Dh_dsa = struct
external select_c : out_field_element -> bool -> field_element -> field_element -> unit = "mc_p256_select" [@@noalloc]
external double_c : out_point -> point -> unit = "mc_p256_point_double" [@@noalloc]
external add_c : out_point -> point -> point -> unit = "mc_p256_point_add" [@@noalloc]
external scalar_mult_base_c : out_point -> string -> unit = "mc_p256_scalar_mult_base" [@@noalloc]
end

module Foreign_n = struct
Expand Down Expand Up @@ -886,6 +931,7 @@ module P384 : Dh_dsa = struct
external select_c : out_field_element -> bool -> field_element -> field_element -> unit = "mc_p384_select" [@@noalloc]
external double_c : out_point -> point -> unit = "mc_p384_point_double" [@@noalloc]
external add_c : out_point -> point -> point -> unit = "mc_p384_point_add" [@@noalloc]
external scalar_mult_base_c : out_point -> string -> unit = "mc_p384_scalar_mult_base" [@@noalloc]
end

module Foreign_n = struct
Expand Down Expand Up @@ -937,6 +983,7 @@ module P521 : Dh_dsa = struct
external select_c : out_field_element -> bool -> field_element -> field_element -> unit = "mc_p521_select" [@@noalloc]
external double_c : out_point -> point -> unit = "mc_p521_point_double" [@@noalloc]
external add_c : out_point -> point -> point -> unit = "mc_p521_point_add" [@@noalloc]
external scalar_mult_base_c : out_point -> string -> unit = "mc_p521_scalar_mult_base" [@@noalloc]
end

module Foreign_n = struct
Expand Down
11 changes: 11 additions & 0 deletions ec/mirage_crypto_ec.mli
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,17 @@ module type Dsa = sig
(** [generate ~key digest] deterministically takes the given private key
and message digest to a [k] suitable for seeding the signing process. *)
end

(** {2 Misc} *)

(** Operations to precompute useful data meant to be hardcoded in
[mirage-crypto-ec] before compilation *)
module Precompute : sig
val generator_tables : unit -> string array array array
(** Return an array of shape (Fe_length * 2, 15, 3) containing multiples of
the generator point for the curve. Useful only to bootstrap tables
necessary for scalar multiplication. *)
end
end

(** Elliptic curve with Diffie-Hellman and DSA. *)
Expand Down
Loading

0 comments on commit 040ef22

Please sign in to comment.