Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ec] Use windowed algorithm for base scalar mult on NIST P-curves #191

Merged
merged 11 commits into from
Feb 19, 2024
4 changes: 4 additions & 0 deletions ec/gen_tables/dune
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
(include_subdirs no)
(executable
(name gen_tables)
(libraries mirage_crypto_ec))
95 changes: 95 additions & 0 deletions ec/gen_tables/gen_tables.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
open Format

let print_header name =
printf
{|/*
Pre-computed multiples of the generator point G for the curve %s,
used for speeding up its scalar multiplication in point_operations.h.

Generated by %s, with tables for both 64-bit and 32-bit architectures.
*/|}
name Sys.argv.(0)

let pp_array elem_fmt fmt arr =
let fout = fprintf fmt in
let len = Array.length arr in
fout "@[<2>{@\n";
for i = 0 to len - 1 do
elem_fmt fmt arr.(i);
if i < len - 1 then printf ",@ " else printf ""
done;
fout "@]@,}"

let pp_string_words ~wordsize fmt str =
let limbs = String.length str * 8 / wordsize in
Firobe marked this conversation as resolved.
Show resolved Hide resolved
assert (String.length str * 8 mod wordsize = 0);
let bytes = Bytes.unsafe_of_string str in
fprintf fmt "@[<2>{@\n";
for i = 0 to limbs - 1 do
let index = i * (wordsize / 8) in
(if wordsize = 64 then
let w = Bytes.get_int64_le bytes index in
fprintf fmt "%#016Lx" w
else
let w = Bytes.get_int32_le bytes index in
fprintf fmt "%#08lx" w);
if i < limbs - 1 then printf ",@ " else printf ""
done;
fprintf fmt "@]@,}"

let check_shape tables =
let fe_len = String.length tables.(0).(0).(0) in
let table_len = fe_len * 2 in
assert (Array.length tables = table_len);
Array.iter
(fun x ->
assert (Array.length x = 15);
Array.iter
(fun x ->
assert (Array.length x = 3);
Array.iter (fun x -> assert (String.length x = fe_len)) x)
x)
tables

let print_tables tables ~wordsize =
let fe_len = String.length tables.(0).(0).(0) in
printf "@[<2>static WORD generator_table[%d][15][3][LIMBS] = @," (fe_len * 2);
pp_array
(pp_array (pp_array (pp_string_words ~wordsize)))
std_formatter tables;
printf "@];@,"

let print_toplevel name (module P : Mirage_crypto_ec.Dh_dsa) =
let tables = P.Dsa.Precompute.generator_tables () in
check_shape tables;
print_header name;
printf "@[<v>#ifdef ARCH_64BIT@,";
print_tables ~wordsize:64 tables;
printf "#else // 32-bit@,";
print_tables ~wordsize:32 tables;
printf "@]#endif@."

let curves =
Mirage_crypto_ec.
[
("p224", (module P224 : Dh_dsa));
("p256", (module P256));
("p384", (module P384));
("p521", (module P521));
]

let usage () =
printf "Usage: gen_tables [%a]"
(pp_print_list
~pp_sep:(fun fmt () -> pp_print_string fmt " | ")
pp_print_string)
(List.map fst curves)

let go =
let name, curve =
try List.find (fun (name, _) -> name = Sys.argv.(1)) curves
with _ ->
usage ();
exit 1
in
print_toplevel name curve
56 changes: 51 additions & 5 deletions ec/mirage_crypto_ec.ml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ module type Dsa = sig
module K_gen (H : Mirage_crypto.Hash.S) : sig
val generate : key:priv -> Cstruct.t -> Cstruct.t
end
module Precompute : sig
val generator_tables : unit -> string array array array
end
end

module type Dh_dsa = sig
Expand Down Expand Up @@ -108,6 +111,7 @@ module type Foreign = sig

val double_c : out_point -> point -> unit
val add_c : out_point -> point -> point -> unit
val scalar_mult_base_c : out_point -> string -> unit
end

module type Field_element = sig
Expand All @@ -125,6 +129,7 @@ module type Field_element = sig
val to_octets : field_element -> string
val double_point : point -> point
val add_point : point -> point -> point
val scalar_mult_base_point : scalar -> point
end

module Make_field_element (P : Parameters) (F : Foreign) : Field_element = struct
Expand Down Expand Up @@ -213,6 +218,11 @@ module Make_field_element (P : Parameters) (F : Foreign) : Field_element = struc
let tmp = out_point () in
F.add_c tmp a b;
out_p_to_p tmp

let scalar_mult_base_point (Scalar d) =
let tmp = out_point () in
F.scalar_mult_base_c tmp d;
out_p_to_p tmp
end

module type Point = sig
Expand All @@ -226,6 +236,7 @@ module type Point = sig
val x_of_finite_point : point -> string
val params_g : point
val select : bool -> then_:point -> else_:point -> point
val scalar_mult_base : scalar -> point
end

module Make_point (P : Parameters) (F : Foreign) : Point = struct
Expand Down Expand Up @@ -406,6 +417,8 @@ module Make_point (P : Parameters) (F : Foreign) : Point = struct
of_octets buf
| 0x00 | 0x04 -> Error `Invalid_length
| _ -> Error `Invalid_format

let scalar_mult_base = Fe.scalar_mult_base_point
end

module type Scalar = sig
Expand All @@ -414,6 +427,8 @@ module type Scalar = sig
val of_octets : string -> (scalar, error) result
val to_octets : scalar -> string
val scalar_mult : scalar -> point -> point
val scalar_mult_base : scalar -> point
val generator_tables : unit -> field_element array array array
end

module Make_scalar (Param : Parameters) (P : Point) : Scalar = struct
Expand All @@ -433,6 +448,7 @@ module Make_scalar (Param : Parameters) (P : Point) : Scalar = struct

let to_octets (Scalar buf) = rev_string buf

(* Branchless Montgomery ladder method *)
let scalar_mult (Scalar s) p =
let r0 = ref (P.at_infinity ()) in
let r1 = ref p in
Expand All @@ -445,6 +461,28 @@ module Make_scalar (Param : Parameters) (P : Point) : Scalar = struct
r1 := P.select bit ~then_:r1_double ~else_:sum
done;
!r0

(* Specialization of [scalar_mult d p] when [p] is the generator *)
let scalar_mult_base = P.scalar_mult_base

(* Pre-compute multiples of the generator point *)
let generator_tables () =
let len = Param.fe_length * 2 in
let one_table _ = Array.init 15 (fun _ -> P.at_infinity ()) in
let table = Array.init len one_table in
let base = ref P.params_g in
for i = 0 to len - 1 do
table.(i).(0) <- !base;
for j = 1 to 14 do
table.(i).(j) <- P.add !base table.(i).(j - 1)
done;
base := P.double !base;
base := P.double !base;
base := P.double !base;
base := P.double !base
done;
let convert {f_x; f_y; f_z} = [|f_x; f_y; f_z|] in
Array.map (Array.map convert) table
end

module Make_dh (Param : Parameters) (P : Point) (S : Scalar) : Dh = struct
Expand All @@ -459,7 +497,7 @@ module Make_dh (Param : Parameters) (P : Point) (S : Scalar) : Dh = struct
type secret = scalar

let share ?(compress = false) private_key =
let public_key = S.scalar_mult private_key P.params_g in
let public_key = S.scalar_mult_base private_key in
point_to_octets ~compress public_key

let secret_of_octets ?compress s =
Expand Down Expand Up @@ -668,7 +706,7 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Mira
in
one ()
in
let q = S.scalar_mult d P.params_g in
let q = S.scalar_mult_base d in
(d, q)

let x_of_finite_point_mod_n p =
Expand All @@ -695,7 +733,7 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Mira
| Ok ksc -> ksc
| Error _ -> invalid_arg "k not in range" (* if no k is provided, this cannot happen since K_gen_*.gen already preserves the Scalar invariants *)
in
let point = S.scalar_mult ksc P.params_g in
let point = S.scalar_mult_base ksc in
match x_of_finite_point_mod_n point with
| None -> again ()
| Some r ->
Expand All @@ -719,7 +757,7 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Mira
let r, s = sign_octets ~key ?k:(Option.map Cstruct.to_string k) (Cstruct.to_string msg) in
Cstruct.of_string r, Cstruct.of_string s

let pub_of_priv priv = S.scalar_mult priv P.params_g
let pub_of_priv priv = S.scalar_mult_base priv

let verify_octets ~key (r, s) msg =
try
Expand All @@ -743,7 +781,7 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Mira
| Ok u1, Ok u2 ->
let point =
P.add
(S.scalar_mult u1 P.params_g)
(S.scalar_mult_base u1)
(S.scalar_mult u2 key)
in
begin match x_of_finite_point_mod_n point with
Expand All @@ -756,6 +794,10 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Mira

let verify ~key (r, s) digest =
verify_octets ~key (Cstruct.to_string r, Cstruct.to_string s) (Cstruct.to_string digest)

module Precompute = struct
let generator_tables = S.generator_tables
end
end

module P224 : Dh_dsa = struct
Expand Down Expand Up @@ -787,6 +829,7 @@ module P224 : Dh_dsa = struct
external select_c : out_field_element -> bool -> field_element -> field_element -> unit = "mc_p224_select" [@@noalloc]
external double_c : out_point -> point -> unit = "mc_p224_point_double" [@@noalloc]
external add_c : out_point -> point -> point -> unit = "mc_p224_point_add" [@@noalloc]
external scalar_mult_base_c : out_point -> string -> unit = "mc_p224_scalar_mult_base" [@@noalloc]
end

module Foreign_n = struct
Expand Down Expand Up @@ -836,6 +879,7 @@ module P256 : Dh_dsa = struct
external select_c : out_field_element -> bool -> field_element -> field_element -> unit = "mc_p256_select" [@@noalloc]
external double_c : out_point -> point -> unit = "mc_p256_point_double" [@@noalloc]
external add_c : out_point -> point -> point -> unit = "mc_p256_point_add" [@@noalloc]
external scalar_mult_base_c : out_point -> string -> unit = "mc_p256_scalar_mult_base" [@@noalloc]
end

module Foreign_n = struct
Expand Down Expand Up @@ -886,6 +930,7 @@ module P384 : Dh_dsa = struct
external select_c : out_field_element -> bool -> field_element -> field_element -> unit = "mc_p384_select" [@@noalloc]
external double_c : out_point -> point -> unit = "mc_p384_point_double" [@@noalloc]
external add_c : out_point -> point -> point -> unit = "mc_p384_point_add" [@@noalloc]
external scalar_mult_base_c : out_point -> string -> unit = "mc_p384_scalar_mult_base" [@@noalloc]
end

module Foreign_n = struct
Expand Down Expand Up @@ -937,6 +982,7 @@ module P521 : Dh_dsa = struct
external select_c : out_field_element -> bool -> field_element -> field_element -> unit = "mc_p521_select" [@@noalloc]
external double_c : out_point -> point -> unit = "mc_p521_point_double" [@@noalloc]
external add_c : out_point -> point -> point -> unit = "mc_p521_point_add" [@@noalloc]
external scalar_mult_base_c : out_point -> string -> unit = "mc_p521_scalar_mult_base" [@@noalloc]
end

module Foreign_n = struct
Expand Down
11 changes: 11 additions & 0 deletions ec/mirage_crypto_ec.mli
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,17 @@ module type Dsa = sig
(** [generate ~key digest] deterministically takes the given private key
and message digest to a [k] suitable for seeding the signing process. *)
end

(** {2 Misc} *)

(** Operations to precompute useful data meant to be hardcoded in
[mirage-crypto-ec] before compilation *)
module Precompute : sig
val generator_tables : unit -> string array array array
(** Return an array of shape (Fe_length * 2, 15, 3) containing multiples of
the generator point for the curve. Useful only to bootstrap tables
necessary for scalar multiplication. *)
end
end

(** Elliptic curve with Diffie-Hellman and DSA. *)
Expand Down
26 changes: 22 additions & 4 deletions ec/native/GNUmakefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ WBW_MONT ?= ../../../fiat-crypto/src/ExtractionOCaml/word_by_word_montgomery --s
UNSAT_SOLINAS ?= ../../../fiat-crypto/src/ExtractionOCaml/unsaturated_solinas --static --use-value-barrier --inline-internal
N_FUNCS=mul add opp from_montgomery to_montgomery one msat divstep_precomp divstep to_bytes from_bytes selectznz

GEN_TABLE=../../_build/default/ec/gen_tables/gen_tables.exe

# The NIST curve P-224 (AKA SECP224R1)
P224="2^224 - 2^96 + 1"

Expand All @@ -32,8 +34,12 @@ np224_64.h:
np224_32.h:
$(WBW_MONT) np224 32 $(P224N) $(N_FUNCS) > $@

.PHONY: p224_tables.h
p224_tables.h:
$(GEN_TABLE) p224 > $@

.PHONY: p224
p224: p224_64.h p224_32.h np224_64.h np224_32.h
p224: p224_64.h p224_32.h np224_64.h np224_32.h p224_tables.h


# The NIST curve P-256 (AKA SECP256R1)
Expand All @@ -58,8 +64,12 @@ np256_64.h:
np256_32.h:
$(WBW_MONT) np256 32 $(P256N) $(N_FUNCS) > $@

.PHONY: p256_tables.h
p256_tables.h:
$(GEN_TABLE) p256 > $@

.PHONY: p256
p256: p256_64.h p256_32.h np256_64.h np256_32.h
p256: p256_64.h p256_32.h np256_64.h np256_32.h p256_tables.h

# The NIST curve P-384 (AKA SECP384R1)
P384="2^384 - 2^128 - 2^96 + 2^32 - 1"
Expand All @@ -83,8 +93,12 @@ np384_64.h:
np384_32.h:
$(WBW_MONT) np384 32 $(P384N) $(N_FUNCS) > $@

.PHONY: p384_tables.h
p384_tables.h:
$(GEN_TABLE) p384 > $@

.PHONY: p384
p384: p384_64.h p384_32.h np384_64.h np384_32.h
p384: p384_64.h p384_32.h np384_64.h np384_32.h p384_tables.h

# The NIST curve P-521 (AKA SECP521R1)
P521="2^521 - 1"
Expand All @@ -108,8 +122,12 @@ np521_64.h:
np521_32.h:
$(WBW_MONT) np521 32 $(P521N) $(N_FUNCS) > $@

.PHONY: p521_tables.h
p521_tables.h:
$(GEN_TABLE) p521 > $@

.PHONY: p521
p521: p521_64.h p521_32.h np521_64.h np521_32.h
p521: p521_64.h p521_32.h np521_64.h np521_32.h p521_tables.h

# 25519
25519="2^255 - 19"
Expand Down
Loading
Loading