From 24c7cb70c19e2729fa10c2119c623703107954fd Mon Sep 17 00:00:00 2001 From: Calascibetta Romain Date: Tue, 13 Feb 2024 11:25:34 +0100 Subject: [PATCH] Replace the internal usage of Cstruct.t by string (#146) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Originally, we used Cstruct.t (bigarray) for interfacing. Instead, we use string now. The benefit is that allocating a string is cheap, and in line with OCaml's GC. After some years of stalling, we included benchmarks in bench/speed.ml fot the EC operations in #192 (sign, verify, generate for EC/EdDSA; and ECDH). The result for thi change is a factor between 2 and 2.5. The external API (mirage_crypto_ec.mli) does not change at all. There are various other cleanups in the code, such as providing a layer to isolate the C calls (which receive a bytes buffer for the result value, and thus mutate this buffer) to be immutable. Co-authored-by: Pierre Alain Co-authored-by: Hannes Mehnert Co-authored-by: Reynir Björnsson Reviewed-by: Virgile Robles Reviewed-by: Pierre Alain --- ec/dune | 2 +- ec/mirage_crypto_ec.ml | 1149 ++++++++++++++++++-------------- ec/native/curve25519_stubs.c | 26 +- ec/native/inversion_template.h | 2 +- ec/native/np224_stubs.c | 16 +- ec/native/np256_stubs.c | 16 +- ec/native/np384_stubs.c | 16 +- ec/native/np521_stubs.c | 16 +- ec/native/p224_stubs.c | 68 +- ec/native/p256_stubs.c | 68 +- ec/native/p384_stubs.c | 68 +- ec/native/p521_stubs.c | 68 +- src/native/mirage_crypto.h | 2 + 13 files changed, 815 insertions(+), 702 deletions(-) diff --git a/ec/dune b/ec/dune index 96767e95..95d80945 100644 --- a/ec/dune +++ b/ec/dune @@ -1,7 +1,7 @@ (library (name mirage_crypto_ec) (public_name mirage-crypto-ec) - (libraries cstruct eqaf.cstruct mirage-crypto mirage-crypto-rng) + (libraries cstruct eqaf mirage-crypto mirage-crypto-rng) (foreign_stubs (language c) (names p224_stubs np224_stubs p256_stubs np256_stubs p384_stubs np384_stubs diff --git a/ec/mirage_crypto_ec.ml b/ec/mirage_crypto_ec.ml index dbd04abb..d827c388 100644 --- a/ec/mirage_crypto_ec.ml +++ b/ec/mirage_crypto_ec.ml @@ -18,178 +18,213 @@ let error_to_string = function let pp_error fmt e = Format.fprintf fmt "Cannot parse point: %s" (error_to_string e) +let rev_string buf = + let len = String.length buf in + let res = Bytes.make len '\000' in + for i = 0 to len - 1 do + Bytes.set res (len - 1 - i) (String.get buf i) + done ; + Bytes.unsafe_to_string res + exception Message_too_long +let string_get_uint8 buf idx = + (* TODO: use String.get_uint8 when mirage-crypto-ec requires OCaml >= 4.13 *) + Bytes.get_uint8 (Bytes.unsafe_of_string buf) idx + let bit_at buf i = let byte_num = i / 8 in let bit_num = i mod 8 in - let byte = Cstruct.get_uint8 buf byte_num in + let byte = string_get_uint8 buf byte_num in byte land (1 lsl bit_num) <> 0 module type Dh = sig type secret - val secret_of_cs : ?compress:bool -> Cstruct.t -> (secret * Cstruct.t, error) result - val gen_key : ?compress:bool -> ?g:Mirage_crypto_rng.g -> unit -> secret * Cstruct.t - val key_exchange : secret -> Cstruct.t -> (Cstruct.t, error) result end module type Dsa = sig type priv - type pub - val byte_length : int - val priv_of_cstruct : Cstruct.t -> (priv, error) result - val priv_to_cstruct : priv -> Cstruct.t - val pub_of_cstruct : Cstruct.t -> (pub, error) result - val pub_to_cstruct : ?compress:bool -> pub -> Cstruct.t - val pub_of_priv : priv -> pub - val generate : ?g:Mirage_crypto_rng.g -> unit -> priv * pub - val sign : key:priv -> ?k:Cstruct.t -> Cstruct.t -> Cstruct.t * Cstruct.t - val verify : key:pub -> Cstruct.t * Cstruct.t -> Cstruct.t -> bool - module K_gen (H : Mirage_crypto.Hash.S) : sig - val generate : key:priv -> Cstruct.t -> Cstruct.t end end + module type Dh_dsa = sig module Dh : Dh module Dsa : Dsa end +type field_element = string + +type out_field_element = bytes + module type Parameters = sig - val a : Cstruct.t - val b : Cstruct.t - val g_x : Cstruct.t - val g_y : Cstruct.t - val p : Cstruct.t - val n : Cstruct.t - val pident: Cstruct.t + val a : field_element + val b : field_element + val g_x : field_element + val g_y : field_element + val p : field_element + val n : field_element + val pident: string val byte_length : int val fe_length : int val first_byte_bits : int option end -type field_element = Cstruct.buffer - type point = { f_x : field_element; f_y : field_element; f_z : field_element } -type scalar = Scalar of Cstruct.t +type out_point = { m_f_x : out_field_element; m_f_y : out_field_element; m_f_z : out_field_element } + +type scalar = Scalar of string module type Foreign = sig - val mul : field_element -> field_element -> field_element -> unit - val sub : field_element -> field_element -> field_element -> unit - val add : field_element -> field_element -> field_element -> unit - val to_montgomery : field_element -> unit - val from_bytes_buf : field_element -> Cstruct.buffer -> unit - val set_one : field_element -> unit + val mul : out_field_element -> field_element -> field_element -> unit + val sub : out_field_element -> field_element -> field_element -> unit + val add : out_field_element -> field_element -> field_element -> unit + val to_montgomery : out_field_element -> field_element -> unit + val from_octets : out_field_element -> string -> unit + val set_one : out_field_element -> unit val nz : field_element -> bool - val sqr : field_element -> field_element -> unit - val from_montgomery : field_element -> unit - val to_bytes_buf : Cstruct.buffer -> field_element -> unit - val inv : field_element -> field_element -> unit - val select_c : field_element -> bool -> field_element -> field_element -> unit - - val double_c : point -> point -> unit - val add_c : point -> point -> point -> unit + val sqr : out_field_element -> field_element -> unit + val from_montgomery : out_field_element -> field_element -> unit + val to_octets : bytes -> field_element -> unit + val inv : out_field_element -> field_element -> unit + val select_c : out_field_element -> bool -> field_element -> field_element -> unit + + val double_c : out_point -> point -> unit + val add_c : out_point -> point -> point -> unit end module type Field_element = sig - val create : unit -> field_element - - val copy : field_element -> field_element -> unit - - val one : unit -> field_element - - val to_bytes : Cstruct.t -> field_element -> unit - - val from_montgomery : field_element -> unit - - val add : field_element -> field_element -> field_element -> unit - - val sub : field_element -> field_element -> field_element -> unit - - val mul : field_element -> field_element -> field_element -> unit - + val mul : field_element -> field_element -> field_element + val sub : field_element -> field_element -> field_element + val add : field_element -> field_element -> field_element + val from_montgomery : field_element -> field_element + val zero : field_element + val one : field_element val nz : field_element -> bool - - val sqr : field_element -> field_element -> unit - - val inv : field_element -> field_element -> unit - - val from_be_cstruct : Cstruct.t -> field_element - + val sqr : field_element -> field_element + val inv : field_element -> field_element val select : bool -> then_:field_element -> else_:field_element -> field_element + val from_be_octets : string -> field_element + val to_octets : field_element -> string + val double_point : point -> point + val add_point : point -> point -> point end module Make_field_element (P : Parameters) (F : Foreign) : Field_element = struct - include F + let b_uts b = Bytes.unsafe_to_string b + + let create () = Bytes.make P.fe_length '\000' + + let mul a b = + let tmp = create () in + F.mul tmp a b; + b_uts tmp - let create () = Cstruct.to_bigarray (Cstruct.create P.fe_length) + let sub a b = + let tmp = create () in + F.sub tmp a b; + b_uts tmp - let copy dst src = Bigarray.Array1.blit src dst + let add a b = + let tmp = create () in + F.add tmp a b; + b_uts tmp - let checked_buffer cs = - assert (Cstruct.length cs = P.byte_length); - Cstruct.to_bigarray cs + let from_montgomery a = + let tmp = create () in + F.from_montgomery tmp a; + b_uts tmp - let from_bytes fe cs = - F.from_bytes_buf fe (checked_buffer cs) + let zero = + b_uts (create ()) - let one () = + let one = let fe = create () in F.set_one fe; - fe + b_uts fe - let to_bytes cs fe = - F.to_bytes_buf (checked_buffer cs) fe + let nz a = F.nz a - let from_be_cstruct cs = - let cs_rev = Cstruct.rev cs in - let fe = create () in - from_bytes fe cs_rev; - F.to_montgomery fe; - fe + let sqr a = + let tmp = create () in + F.sqr tmp a; + b_uts tmp + + let inv a = + let tmp = create () in + F.inv tmp a; + b_uts tmp let select bit ~then_ ~else_ = - let out = create () in - F.select_c out bit then_ else_; - out + let tmp = create () in + F.select_c tmp bit then_ else_; + b_uts tmp + + let from_be_octets buf = + let buf_rev = rev_string buf in + let tmp = create () in + F.from_octets tmp buf_rev; + F.to_montgomery tmp (b_uts tmp); + b_uts tmp + + let create_octets () = + Bytes.make P.byte_length '\000' + + let to_octets fe = + let tmp = create_octets () in + F.to_octets tmp fe; + b_uts tmp + + let out_point () = { + m_f_x = create (); + m_f_y = create (); + m_f_z = create (); + } + + let out_p_to_p p = { + f_x = b_uts p.m_f_x ; + f_y = b_uts p.m_f_y ; + f_z = b_uts p.m_f_z ; + } + + let double_point p = + let tmp = out_point () in + F.double_c tmp p; + out_p_to_p tmp + + let add_point a b = + let tmp = out_point () in + F.add_c tmp a b; + out_p_to_p tmp end module type Point = sig val at_infinity : unit -> point - val is_infinity : point -> bool - val add : point -> point -> point - val double : point -> point - - val of_cstruct : Cstruct.t -> (point, error) result - - val to_cstruct : compress:bool -> point -> Cstruct.t - + val of_octets : string -> (point, error) result + val to_octets : compress:bool -> point -> string val to_affine_raw : point -> (field_element * field_element) option - - val x_of_finite_point : point -> Cstruct.t - + val x_of_finite_point : point -> string val params_g : point - val select : bool -> then_:point -> else_:point -> point end @@ -197,36 +232,32 @@ module Make_point (P : Parameters) (F : Foreign) : Point = struct module Fe = Make_field_element(P)(F) let at_infinity () = - let f_x = Fe.one () in - let f_y = Fe.one () in - let f_z = Fe.create () in + let f_x = Fe.one in + let f_y = Fe.one in + let f_z = Fe.zero in { f_x; f_y; f_z } - let is_infinity p = not (Fe.nz p.f_z) + let is_infinity (p : point) = not (Fe.nz p.f_z) let is_solution_to_curve_equation = - let a = Fe.from_be_cstruct P.a in - let b = Fe.from_be_cstruct P.b in + let a = Fe.from_be_octets P.a in + let b = Fe.from_be_octets P.b in fun ~x ~y -> - let x3 = Fe.create () in - Fe.mul x3 x x; - Fe.mul x3 x3 x; - let ax = Fe.create () in - Fe.mul ax a x; - let y2 = Fe.create () in - Fe.mul y2 y y; - let sum = Fe.create () in - Fe.add sum x3 ax; - Fe.add sum sum b; - Fe.sub sum sum y2; + let x3 = Fe.mul x x in + let x3 = Fe.mul x3 x in + let ax = Fe.mul a x in + let y2 = Fe.mul y y in + let sum = Fe.add x3 ax in + let sum = Fe.add sum b in + let sum = Fe.sub sum y2 in not (Fe.nz sum) - let check_coordinate cs = - (* ensure cs < p: *) - match Eqaf_cstruct.compare_be_with_len ~len:P.byte_length cs P.p >= 0 with + let check_coordinate buf = + (* ensure buf < p: *) + match Eqaf.compare_be_with_len ~len:P.byte_length buf P.p >= 0 with | true -> None | exception Invalid_argument _ -> None - | false -> Some (Fe.from_be_cstruct cs) + | false -> Some (Fe.from_be_octets buf) (** Convert cstruct coordinates to a finite point ensuring: - x < p @@ -237,7 +268,7 @@ module Make_point (P : Parameters) (F : Foreign) : Point = struct match (check_coordinate x, check_coordinate y) with | Some f_x, Some f_y -> if is_solution_to_curve_equation ~x:f_x ~y:f_y then - let f_z = Fe.one () in + let f_z = Fe.one in Ok { f_x; f_y; f_z } else Error `Not_on_curve | _ -> Error `Invalid_range @@ -246,65 +277,49 @@ module Make_point (P : Parameters) (F : Foreign) : Point = struct if is_infinity p then None else - let z1 = Fe.create () in - let z2 = Fe.create () in - Fe.copy z1 p.f_z; - Fe.from_montgomery z1; - Fe.inv z2 z1; - Fe.sqr z1 z2; - Fe.from_montgomery z1; - let x = Fe.create () in - Fe.copy x p.f_x; - Fe.mul x x z1; - let y = Fe.create () in - Fe.copy y p.f_y; - Fe.mul z1 z1 z2; - Fe.mul y y z1; + let z1 = Fe.from_montgomery p.f_z in + let z2 = Fe.inv z1 in + let z1 = Fe.sqr z2 in + let z1 = Fe.from_montgomery z1 in + let x = Fe.mul p.f_x z1 in + let z1 = Fe.mul z1 z2 in + let y = Fe.mul p.f_y z1 in Some (x, y) let to_affine p = - match to_affine_raw p with - | None -> None - | Some (x, y) -> - let out_x = Cstruct.create P.byte_length in - let out_y = Cstruct.create P.byte_length in - Fe.to_bytes out_x x; - Fe.to_bytes out_y y; - Some (out_x, out_y) - - let to_cstruct ~compress p = + Option.map (fun (x, y) -> Fe.to_octets x, Fe.to_octets y) + (to_affine_raw p) + + let to_octets ~compress p = let buf = match to_affine p with - | None -> Cstruct.create 1 + | None -> String.make 1 '\000' | Some (x, y) -> - let four = Cstruct.create 1 in - Cstruct.set_uint8 four 0 4; - let rev_x = Cstruct.rev x and rev_y = Cstruct.rev y in - Cstruct.concat [ four; rev_x; rev_y ] + let len_x = String.length x and len_y = String.length y in + let res = Bytes.make (1 + len_x + len_y) '\000' in + Bytes.set res 0 '\004' ; + let rev_x = rev_string x and rev_y = rev_string y in + Bytes.blit_string rev_x 0 res 1 len_x ; + Bytes.blit_string rev_y 0 res (1 + len_x) len_y ; + Bytes.unsafe_to_string res in if compress then - let out = Cstruct.create (P.byte_length + 1) in + let out = Bytes.make (P.byte_length + 1) '\000' in let ident = - 2 + (Cstruct.get_uint8 buf ((P.byte_length * 2) - 1)) land 1 + 2 + (string_get_uint8 buf ((P.byte_length * 2) - 1)) land 1 in - Cstruct.blit buf 1 out 1 P.byte_length; - Cstruct.set_uint8 out 0 ident; - out + Bytes.blit_string buf 1 out 1 P.byte_length; + Bytes.set_uint8 out 0 ident; + Bytes.unsafe_to_string out else buf - let double p = - let out = { f_x = Fe.create (); f_y = Fe.create (); f_z = Fe.create () } in - F.double_c out p; - out + let double p = Fe.double_point p - let add fe_p fe_q = - let out = { f_x = Fe.create (); f_y = Fe.create (); f_z = Fe.create () } in - F.add_c out fe_p fe_q; - out + let add p q = Fe.add_point p q let x_of_finite_point p = - match to_affine p with None -> assert false | Some (x, _) -> Cstruct.rev x + match to_affine p with None -> assert false | Some (x, _) -> rev_string x let params_g = match validate_finite_point ~x:P.g_x ~y:P.g_y with @@ -318,25 +333,14 @@ module Make_point (P : Parameters) (F : Foreign) : Point = struct f_z = Fe.select bit ~then_:then_.f_z ~else_:else_.f_z; } - let pow = - let mult a b = - let r = Fe.create () in - Fe.mul r a b; - r - in - let sqr x = - let r = Fe.create () in - Fe.sqr r x; - r - in - fun x exp -> - let r0 = ref (Fe.one ()) in + let pow x exp = + let r0 = ref Fe.one in let r1 = ref x in for i = P.byte_length * 8 - 1 downto 0 do let bit = bit_at exp i in - let multiplied = mult !r0 !r1 in - let r0_sqr = sqr !r0 in - let r1_sqr = sqr !r1 in + let multiplied = Fe.mul !r0 !r1 in + let r0_sqr = Fe.sqr !r0 in + let r1_sqr = Fe.sqr !r1 in r0 := Fe.select bit ~then_:multiplied ~else_:r0_sqr; r1 := Fe.select bit ~then_:r1_sqr ~else_:multiplied; done; @@ -353,90 +357,81 @@ module Make_point (P : Parameters) (F : Foreign) : Point = struct Q=(x,y) is the canonical representation of the point *) let pident = P.pident (* (Params.p + 1) / 4*) in - let a = Fe.from_be_cstruct P.a in - let b = Fe.from_be_cstruct P.b in - let p = Fe.from_be_cstruct P.p in - fun pk_cstruct -> - let x = Fe.from_be_cstruct (Cstruct.sub pk_cstruct 1 P.byte_length) in - let x3 = Fe.create () in - let ax = Fe.create () in - let sum = Fe.create () in - Fe.mul x3 x x; - Fe.mul x3 x3 x; (* x3 *) - Fe.mul ax a x; (* ax *) - Fe.add sum x3 ax; - Fe.add sum sum b; (* y^2 *) + let a = Fe.from_be_octets P.a in + let b = Fe.from_be_octets P.b in + let p = Fe.from_be_octets P.p in + fun pk -> + let x = Fe.from_be_octets (String.sub pk 1 P.byte_length) in + let x3 = Fe.mul x x in + let x3 = Fe.mul x3 x in (* x3 *) + let ax = Fe.mul a x in (* ax *) + let sum = Fe.add x3 ax in + let sum = Fe.add sum b in (* y^2 *) let y = pow sum pident in (* https://tools.ietf.org/id/draft-jivsov-ecc-compact-00.xml#sqrt point 4.3*) - let y' = Fe.create () in - Fe.sub y' p y; - let y_struct = Cstruct.create (P.byte_length) in - Fe.from_montgomery y; - Fe.to_bytes y_struct y; (* number must not be in montgomery domain*) - let y_struct = Cstruct.rev y_struct in - let y_struct2 = Cstruct.create (P.byte_length) in - Fe.from_montgomery y'; - Fe.to_bytes y_struct2 y';(* number must not be in montgomery domain*) - let y_struct2 = Cstruct.rev y_struct2 in - let ident = Cstruct.get_uint8 pk_cstruct 0 in + let y' = Fe.sub p y in + let y = Fe.from_montgomery y in + let y_struct = Fe.to_octets y in (* number must not be in montgomery domain*) + let y_struct = rev_string y_struct in + let y' = Fe.from_montgomery y' in + let y_struct2 = Fe.to_octets y' in (* number must not be in montgomery domain*) + let y_struct2 = rev_string y_struct2 in + let ident = string_get_uint8 pk 0 in let signY = - 2 + (Cstruct.get_uint8 y_struct (P.byte_length - 2)) land 1 + 2 + (string_get_uint8 y_struct (P.byte_length - 2)) land 1 in let res = if Int.equal signY ident then y_struct else y_struct2 in - let out = Cstruct.create ((P.byte_length * 2) + 1) in - Cstruct.set_uint8 out 0 4; - Cstruct.blit pk_cstruct 1 out 1 P.byte_length; - Cstruct.blit res 0 out (P.byte_length + 1) P.byte_length; - out + let out = Bytes.make ((P.byte_length * 2) + 1) '\000' in + Bytes.set out 0 '\004'; + Bytes.blit_string pk 1 out 1 P.byte_length; + Bytes.blit_string res 0 out (P.byte_length + 1) P.byte_length; + Bytes.unsafe_to_string out - let of_cstruct cs = + let of_octets buf = let len = P.byte_length in - if Cstruct.length cs = 0 then + if String.length buf = 0 then Error `Invalid_format else - let of_cs cs = - let x = Cstruct.sub cs 1 len in - let y = Cstruct.sub cs (1 + len) len in + let of_octets buf = + let x = String.sub buf 1 len in + let y = String.sub buf (1 + len) len in validate_finite_point ~x ~y in - match Cstruct.get_uint8 cs 0 with - | 0x00 when Cstruct.length cs = 1 -> Ok (at_infinity ()) - | 0x02 | 0x03 when Cstruct.length P.pident > 0 -> - let decompressed = decompress cs in - of_cs decompressed - | 0x04 when Cstruct.length cs = 1 + len + len -> - of_cs cs + match string_get_uint8 buf 0 with + | 0x00 when String.length buf = 1 -> + Ok (at_infinity ()) + | 0x02 | 0x03 when String.length P.pident > 0 -> + let decompressed = decompress buf in + of_octets decompressed + | 0x04 when String.length buf = 1 + len + len -> + of_octets buf | 0x00 | 0x04 -> Error `Invalid_length | _ -> Error `Invalid_format end module type Scalar = sig - val not_zero : Cstruct.t -> bool - - val is_in_range : Cstruct.t -> bool - - val of_cstruct : Cstruct.t -> (scalar, error) result - - val to_cstruct : scalar -> Cstruct.t - + val not_zero : string -> bool + val is_in_range : string -> bool + val of_octets : string -> (scalar, error) result + val to_octets : scalar -> string val scalar_mult : scalar -> point -> point end module Make_scalar (Param : Parameters) (P : Point) : Scalar = struct let not_zero = - let zero = Cstruct.create Param.byte_length in - fun cs -> not (Eqaf_cstruct.equal cs zero) + let zero = String.make Param.byte_length '\000' in + fun buf -> not (Eqaf.equal buf zero) - let is_in_range cs = - not_zero cs - && Eqaf_cstruct.compare_be_with_len ~len:Param.byte_length Param.n cs > 0 + let is_in_range buf = + not_zero buf + && Eqaf.compare_be_with_len ~len:Param.byte_length Param.n buf > 0 - let of_cstruct cs = - match is_in_range cs with + let of_octets buf = + match is_in_range buf with | exception Invalid_argument _ -> Error `Invalid_length - | true -> Ok (Scalar (Cstruct.rev cs)) + | true -> Ok (Scalar (rev_string buf)) | false -> Error `Invalid_range - let to_cstruct (Scalar cs) = Cstruct.rev cs + let to_octets (Scalar buf) = rev_string buf let scalar_mult (Scalar s) p = let r0 = ref (P.at_infinity ()) in @@ -453,64 +448,157 @@ module Make_scalar (Param : Parameters) (P : Point) : Scalar = struct end module Make_dh (Param : Parameters) (P : Point) (S : Scalar) : Dh = struct - let point_of_cs c = - match P.of_cstruct c with + let point_of_octets c = + match P.of_octets c with | Ok p when not (P.is_infinity p) -> Ok p | Ok _ -> Error `At_infinity | Error _ as e -> e - let point_to_cs = P.to_cstruct + let point_to_octets = P.to_octets type secret = scalar let share ?(compress = false) private_key = let public_key = S.scalar_mult private_key P.params_g in - point_to_cs ~compress public_key + point_to_octets ~compress public_key - let secret_of_cs ?compress s = - match S.of_cstruct s with + let secret_of_octets ?compress s = + match S.of_octets s with | Ok p -> Ok (p, share ?compress p) | Error _ as e -> e + let secret_of_cs ?compress s = + Result.map (fun (p, share) -> p, Cstruct.of_string share) + (secret_of_octets ?compress (Cstruct.to_string s)) + let rec generate_private_key ?g () = let candidate = Mirage_crypto_rng.generate ?g Param.byte_length in - match S.of_cstruct candidate with + match S.of_octets (Cstruct.to_string candidate) with | Ok secret -> secret | Error _ -> generate_private_key ?g () - let gen_key ?compress ?g () = + let gen_key_octets ?compress ?g () = let private_key = generate_private_key ?g () in (private_key, share ?compress private_key) - let key_exchange secret received = - match point_of_cs received with + let gen_key ?compress ?g () = + let private_key, share = gen_key_octets ?compress ?g () in + private_key, Cstruct.of_string share + + let key_exchange_octets secret received = + match point_of_octets received with | Error _ as err -> err | Ok shared -> Ok (P.x_of_finite_point (S.scalar_mult secret shared)) + + let key_exchange secret received = + match key_exchange_octets secret (Cstruct.to_string received) with + | Error _ as err -> err + | Ok shared -> Ok (Cstruct.of_string shared) end module type Foreign_n = sig - val mul : field_element -> field_element -> field_element -> unit - val add : field_element -> field_element -> field_element -> unit - val inv : field_element -> field_element -> unit - val one : field_element -> unit - val from_bytes : field_element -> Cstruct.buffer -> unit - val to_bytes : Cstruct.buffer -> field_element -> unit - val from_montgomery : field_element -> field_element -> unit - val to_montgomery : field_element -> field_element -> unit + val mul : out_field_element -> field_element -> field_element -> unit + val add : out_field_element -> field_element -> field_element -> unit + val inv : out_field_element -> field_element -> unit + val one : out_field_element -> unit + val from_bytes : out_field_element -> string -> unit + val to_bytes : bytes -> field_element -> unit + val from_montgomery : out_field_element -> field_element -> unit + val to_montgomery : out_field_element -> field_element -> unit +end + +module type Fn = sig + val from_be_octets : string -> field_element + val to_be_octets : field_element -> string + val mul : field_element -> field_element -> field_element + val add : field_element -> field_element -> field_element + val inv : field_element -> field_element + val one : field_element + val from_montgomery : field_element -> field_element + val to_montgomery : field_element -> field_element end -module Make_dsa (Param : Parameters) (F : Foreign_n) (P : Point) (S : Scalar) (H : Mirage_crypto.Hash.S) = struct - let create () = Cstruct.to_bigarray (Cstruct.create Param.fe_length) +module Make_Fn (P : Parameters) (F : Foreign_n) : Fn = struct + let b_uts = Bytes.unsafe_to_string + + let create () = Bytes.make P.fe_length '\000' + + let create_octets () = Bytes.make P.byte_length '\000' + + let from_be_octets v = + let v' = create () in + F.from_bytes v' (rev_string v); + F.to_montgomery v' (b_uts v'); + b_uts v' + + let to_be_octets v = + let buf = create_octets () in + F.to_bytes buf v; + rev_string (b_uts buf) + + let mul a b = + let tmp = create () in + F.mul tmp a b; + b_uts tmp + + let add a b = + let tmp = create () in + F.add tmp a b; + b_uts tmp + + let inv a = + let tmp = create () in + F.inv tmp a; + F.to_montgomery tmp (b_uts tmp); + b_uts tmp + + let one = + let tmp = create () in + F.one tmp; + b_uts tmp + + let from_montgomery a = + let tmp = create () in + F.from_montgomery tmp a; + b_uts tmp + + let to_montgomery a = + let tmp = create () in + F.to_montgomery tmp a; + b_uts tmp +end +module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Mirage_crypto.Hash.S) = struct type priv = scalar let byte_length = Param.byte_length - let priv_of_cstruct = S.of_cstruct + let priv_of_octets= S.of_octets + + let priv_to_octets = S.to_octets - let priv_to_cstruct = S.to_cstruct + let priv_of_cstruct cs = priv_of_octets (Cstruct.to_string cs) + + let priv_to_cstruct p = Cstruct.of_string (priv_to_octets p) let padded msg = + let l = String.length msg in + let bl = Param.byte_length in + let first_byte_ok () = + match Param.first_byte_bits with + | None -> true + | Some m -> (string_get_uint8 msg 0) land (0xFF land (lnot m)) = 0 + in + if l > bl || (l = bl && not (first_byte_ok ())) then + raise Message_too_long + else if l = bl then + msg + else + ( let res = Bytes.make bl '\000' in + Bytes.blit_string msg 0 res (bl - l) (String.length msg) ; + Bytes.unsafe_to_string res ) + + let padded_cs msg = let l = Cstruct.length msg in let bl = Param.byte_length in let first_byte_ok () = @@ -525,52 +613,56 @@ module Make_dsa (Param : Parameters) (F : Foreign_n) (P : Point) (S : Scalar) (H else Cstruct.append (Cstruct.create (bl - l)) msg - let from_be_cstruct v = - let v' = create () in - F.from_bytes v' (Cstruct.to_bigarray (Cstruct.rev v)); - v' - - let to_be_cstruct v = - let cs = Cstruct.create Param.byte_length in - F.to_bytes (Cstruct.to_bigarray cs) v; - Cstruct.rev cs - (* RFC 6979: compute a deterministic k *) module K_gen (H : Mirage_crypto.Hash.S) = struct - let drbg : 'a Mirage_crypto_rng.generator = let module M = Mirage_crypto_rng.Hmac_drbg (H) in (module M) let g ~key cs = let g = Mirage_crypto_rng.create ~strict:true drbg in Mirage_crypto_rng.reseed ~g - (Cstruct.append (S.to_cstruct key) cs); + (Cstruct.append (Cstruct.of_string (S.to_octets key)) cs); + g + + let g_octets ~key msg = + let g = Mirage_crypto_rng.create ~strict:true drbg in + Mirage_crypto_rng.reseed ~g + (Cstruct.of_string (String.concat "" [ S.to_octets key ; msg ])); g (* take qbit length, and ensure it is suitable for ECDSA (> 0 & < n) *) let gen g = let rec go () = let r = Mirage_crypto_rng.generate ~g Param.byte_length in + let r = Cstruct.to_string r in if S.is_in_range r then r else go () in go () - let generate ~key cs = gen (g ~key (padded cs)) + (* let generate_octets ~key buf = gen (g ~key (Cstruct.of_string (padded buf))) *) + + let generate ~key buf = + Cstruct.of_string (gen (g ~key (padded_cs buf))) end module K_gen_default = K_gen(H) type pub = point - let pub_of_cstruct = P.of_cstruct + let pub_of_octets = P.of_octets + + let pub_to_octets ?(compress = false) pk = P.to_octets ~compress pk - let pub_to_cstruct ?(compress = false) pk = P.to_cstruct ~compress pk + let pub_of_cstruct cs = pub_of_octets (Cstruct.to_string cs) + + let pub_to_cstruct ?compress p = + Cstruct.of_string (pub_to_octets ?compress p) let generate ?g () = (* FIPS 186-4 B 4.2 *) let d = let rec one () = - match S.of_cstruct (Mirage_crypto_rng.generate ?g Param.byte_length) with + match S.of_octets (Cstruct.to_string (Mirage_crypto_rng.generate ?g Param.byte_length)) with | Ok x -> x | Error _ -> one () in @@ -583,17 +675,15 @@ module Make_dsa (Param : Parameters) (F : Foreign_n) (P : Point) (S : Scalar) (H match P.to_affine_raw p with | None -> None | Some (x, _) -> - F.to_montgomery x x; - let o = create () in - F.one o; - F.mul x x o; - F.from_montgomery x x; - Some (to_be_cstruct x) + let x = F.to_montgomery x in + let x = F.mul x F.one in + let x = F.from_montgomery x in + Some (F.to_be_octets x) - let sign ~key ?k msg = + let sign_octets ~key ?k msg = let msg = padded msg in - let e = from_be_cstruct msg in - let g = K_gen_default.g ~key msg in + let e = F.from_be_octets msg in + let g = K_gen_default.g_octets ~key msg in let rec do_sign g = let again () = match k with @@ -601,7 +691,7 @@ module Make_dsa (Param : Parameters) (F : Foreign_n) (P : Point) (S : Scalar) (H | Some _ -> invalid_arg "k not suitable" in let k' = match k with None -> K_gen_default.gen g | Some k -> k in - let ksc = match S.of_cstruct k' with + let ksc = match S.of_octets k' with | Ok ksc -> ksc | Error _ -> invalid_arg "k not in range" (* if no k is provided, this cannot happen since K_gen_*.gen already preserves the Scalar invariants *) in @@ -609,26 +699,15 @@ module Make_dsa (Param : Parameters) (F : Foreign_n) (P : Point) (S : Scalar) (H match x_of_finite_point_mod_n point with | None -> again () | Some r -> - let r_mon = from_be_cstruct r in - F.to_montgomery r_mon r_mon; - let kinv = create () in - let kmon = from_be_cstruct k' in - F.to_montgomery kmon kmon; - F.inv kinv kmon; - F.to_montgomery kmon kinv; - let rd = create () in - let dmon = from_be_cstruct (S.to_cstruct key) in - F.to_montgomery dmon dmon; - F.mul rd r_mon dmon; - let cmon = create () in - let zmon = create () in - F.to_montgomery zmon e; - F.add cmon zmon rd; - let smon = create () in - F.mul smon kmon cmon; - let s = create () in - F.from_montgomery s smon; - let s = to_be_cstruct s in + let r_mon = F.from_be_octets r in + let kmon = F.from_be_octets k' in + let kinv = F.inv kmon in + let dmon = F.from_be_octets (S.to_octets key) in + let rd = F.mul r_mon dmon in + let cmon = F.add e rd in + let smon = F.mul kinv cmon in + let s = F.from_montgomery smon in + let s = F.to_be_octets s in if S.not_zero s && S.not_zero r then r, s else @@ -636,33 +715,30 @@ module Make_dsa (Param : Parameters) (F : Foreign_n) (P : Point) (S : Scalar) (H in do_sign g + let sign ~key ?k msg = + let r, s = sign_octets ~key ?k:(Option.map Cstruct.to_string k) (Cstruct.to_string msg) in + Cstruct.of_string r, Cstruct.of_string s + let pub_of_priv priv = S.scalar_mult priv P.params_g - let verify ~key (r, s) msg = + let verify_octets ~key (r, s) msg = try let r = padded r and s = padded s in if not (S.is_in_range r && S.is_in_range s) then false else let msg = padded msg in - let z = from_be_cstruct msg in - let s_inv = create () in - let s_mon = from_be_cstruct s in - F.to_montgomery s_mon s_mon; - F.inv s_inv s_mon; - let u1 = create () in - F.to_montgomery s_inv s_inv; - F.to_montgomery z z; - F.mul u1 z s_inv; - let u2 = create () in - let r_mon = from_be_cstruct r in - F.to_montgomery r_mon r_mon; - F.mul u2 r_mon s_inv; - F.from_montgomery u1 u1; - F.from_montgomery u2 u2; + let z = F.from_be_octets msg in + let s_mon = F.from_be_octets s in + let s_inv = F.inv s_mon in + let u1 = F.mul z s_inv in + let r_mon = F.from_be_octets r in + let u2 = F.mul r_mon s_inv in + let u1 = F.from_montgomery u1 in + let u2 = F.from_montgomery u2 in match - S.of_cstruct (to_be_cstruct u1), - S.of_cstruct (to_be_cstruct u2) + S.of_octets (F.to_be_octets u1), + S.of_octets (F.to_be_octets u2) with | Ok u1, Ok u2 -> let point = @@ -672,348 +748,391 @@ module Make_dsa (Param : Parameters) (F : Foreign_n) (P : Point) (S : Scalar) (H in begin match x_of_finite_point_mod_n point with | None -> false (* point is infinity *) - | Some r' -> Cstruct.equal r r' + | Some r' -> String.equal r r' end | Error _, _ | _, Error _ -> false with | Message_too_long -> false + + let verify ~key (r, s) digest = + verify_octets ~key (Cstruct.to_string r, Cstruct.to_string s) (Cstruct.to_string digest) end module P224 : Dh_dsa = struct module Params = struct - let a = Cstruct.of_hex "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFE" - let b = Cstruct.of_hex "B4050A850C04B3ABF54132565044B0B7D7BFD8BA270B39432355FFB4" - let g_x = Cstruct.of_hex "B70E0CBD6BB4BF7F321390B94A03C1D356C21122343280D6115C1D21" - let g_y = Cstruct.of_hex "BD376388B5F723FB4C22DFE6CD4375A05A07476444D5819985007E34" - let p = Cstruct.of_hex "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF000000000000000000000001" - let n = Cstruct.of_hex "FFFFFFFFFFFFFFFFFFFFFFFFFFFF16A2E0B8F03E13DD29455C5C2A3D" - let pident = Cstruct.empty + let a = "\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFE" + let b = "\xB4\x05\x0A\x85\x0C\x04\xB3\xAB\xF5\x41\x32\x56\x50\x44\xB0\xB7\xD7\xBF\xD8\xBA\x27\x0B\x39\x43\x23\x55\xFF\xB4" + let g_x = "\xB7\x0E\x0C\xBD\x6B\xB4\xBF\x7F\x32\x13\x90\xB9\x4A\x03\xC1\xD3\x56\xC2\x11\x22\x34\x32\x80\xD6\x11\x5C\x1D\x21" + let g_y = "\xBD\x37\x63\x88\xB5\xF7\x23\xFB\x4C\x22\xDF\xE6\xCD\x43\x75\xA0\x5A\x07\x47\x64\x44\xD5\x81\x99\x85\x00\x7E\x34" + let p = "\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + let n = "\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x16\xA2\xE0\xB8\xF0\x3E\x13\xDD\x29\x45\x5C\x5C\x2A\x3D" + let pident = "" let byte_length = 28 let fe_length = if Sys.word_size == 64 then 32 else 28 (* TODO: is this congruent with C code? *) let first_byte_bits = None end module Foreign = struct - external mul : field_element -> field_element -> field_element -> unit = "mc_p224_mul" [@@noalloc] - external sub : field_element -> field_element -> field_element -> unit = "mc_p224_sub" [@@noalloc] - external add : field_element -> field_element -> field_element -> unit = "mc_p224_add" [@@noalloc] - external to_montgomery : field_element -> unit = "mc_p224_to_montgomery" [@@noalloc] - external from_bytes_buf : field_element -> Cstruct.buffer -> unit = "mc_p224_from_bytes" [@@noalloc] - external set_one : field_element -> unit = "mc_p224_set_one" [@@noalloc] + external mul : out_field_element -> field_element -> field_element -> unit = "mc_p224_mul" [@@noalloc] + external sub : out_field_element -> field_element -> field_element -> unit = "mc_p224_sub" [@@noalloc] + external add : out_field_element -> field_element -> field_element -> unit = "mc_p224_add" [@@noalloc] + external to_montgomery : out_field_element -> field_element -> unit = "mc_p224_to_montgomery" [@@noalloc] + external from_octets : out_field_element -> string -> unit = "mc_p224_from_bytes" [@@noalloc] + external set_one : out_field_element -> unit = "mc_p224_set_one" [@@noalloc] external nz : field_element -> bool = "mc_p224_nz" [@@noalloc] - external sqr : field_element -> field_element -> unit = "mc_p224_sqr" [@@noalloc] - external from_montgomery : field_element -> unit = "mc_p224_from_montgomery" [@@noalloc] - external to_bytes_buf : Cstruct.buffer -> field_element -> unit = "mc_p224_to_bytes" [@@noalloc] - external inv : field_element -> field_element -> unit = "mc_p224_inv" [@@noalloc] - external select_c : field_element -> bool -> field_element -> field_element -> unit = "mc_p224_select" [@@noalloc] - - external double_c : point -> point -> unit = "mc_p224_point_double" [@@noalloc] - external add_c : point -> point -> point -> unit = "mc_p224_point_add" [@@noalloc] + external sqr : out_field_element -> field_element -> unit = "mc_p224_sqr" [@@noalloc] + external from_montgomery : out_field_element -> field_element -> unit = "mc_p224_from_montgomery" [@@noalloc] + external to_octets : bytes -> field_element -> unit = "mc_p224_to_bytes" [@@noalloc] + external inv : out_field_element -> field_element -> unit = "mc_p224_inv" [@@noalloc] + external select_c : out_field_element -> bool -> field_element -> field_element -> unit = "mc_p224_select" [@@noalloc] + external double_c : out_point -> point -> unit = "mc_p224_point_double" [@@noalloc] + external add_c : out_point -> point -> point -> unit = "mc_p224_point_add" [@@noalloc] end module Foreign_n = struct - external mul : field_element -> field_element -> field_element -> unit = "mc_np224_mul" [@@noalloc] - external add : field_element -> field_element -> field_element -> unit = "mc_np224_add" [@@noalloc] - external inv : field_element -> field_element -> unit = "mc_np224_inv" [@@noalloc] - external one : field_element -> unit = "mc_np224_one" [@@noalloc] - external from_bytes : field_element -> Cstruct.buffer -> unit = "mc_np224_from_bytes" [@@noalloc] - external to_bytes : Cstruct.buffer -> field_element -> unit = "mc_np224_to_bytes" [@@noalloc] - external from_montgomery : field_element -> field_element -> unit = "mc_np224_from_montgomery" [@@noalloc] - external to_montgomery : field_element -> field_element -> unit = "mc_np224_to_montgomery" [@@noalloc] + external mul : out_field_element -> field_element -> field_element -> unit = "mc_np224_mul" [@@noalloc] + external add : out_field_element -> field_element -> field_element -> unit = "mc_np224_add" [@@noalloc] + external inv : out_field_element -> field_element -> unit = "mc_np224_inv" [@@noalloc] + external one : out_field_element -> unit = "mc_np224_one" [@@noalloc] + external from_bytes : out_field_element -> string -> unit = "mc_np224_from_bytes" [@@noalloc] + external to_bytes : bytes -> field_element -> unit = "mc_np224_to_bytes" [@@noalloc] + external from_montgomery : out_field_element -> field_element -> unit = "mc_np224_from_montgomery" [@@noalloc] + external to_montgomery : out_field_element -> field_element -> unit = "mc_np224_to_montgomery" [@@noalloc] end module P = Make_point(Params)(Foreign) module S = Make_scalar(Params)(P) module Dh = Make_dh(Params)(P)(S) - module Dsa = Make_dsa(Params)(Foreign_n)(P)(S)(Mirage_crypto.Hash.SHA256) + module Fn = Make_Fn(Params)(Foreign_n) + module Dsa = Make_dsa(Params)(Fn)(P)(S)(Mirage_crypto.Hash.SHA256) end module P256 : Dh_dsa = struct module Params = struct - let a = Cstruct.of_hex "FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFC" - let b = Cstruct.of_hex "5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B" - let g_x = - Cstruct.of_hex "6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296" - let g_y = - Cstruct.of_hex "4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5" - let p = Cstruct.of_hex "FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF" - let n = Cstruct.of_hex "FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551" - let pident = Cstruct.of_hex "3FFFFFFFC0000000400000000000000000000000400000000000000000000000" |> Cstruct.rev (* (Params.p + 1) / 4*) + let a = "\xFF\xFF\xFF\xFF\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFC" + let b = "\x5A\xC6\x35\xD8\xAA\x3A\x93\xE7\xB3\xEB\xBD\x55\x76\x98\x86\xBC\x65\x1D\x06\xB0\xCC\x53\xB0\xF6\x3B\xCE\x3C\x3E\x27\xD2\x60\x4B" + let g_x = "\x6B\x17\xD1\xF2\xE1\x2C\x42\x47\xF8\xBC\xE6\xE5\x63\xA4\x40\xF2\x77\x03\x7D\x81\x2D\xEB\x33\xA0\xF4\xA1\x39\x45\xD8\x98\xC2\x96" + let g_y = "\x4F\xE3\x42\xE2\xFE\x1A\x7F\x9B\x8E\xE7\xEB\x4A\x7C\x0F\x9E\x16\x2B\xCE\x33\x57\x6B\x31\x5E\xCE\xCB\xB6\x40\x68\x37\xBF\x51\xF5" + let p = "\xFF\xFF\xFF\xFF\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF" + let n = "\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xBC\xE6\xFA\xAD\xA7\x17\x9E\x84\xF3\xB9\xCA\xC2\xFC\x63\x25\x51" + let pident = "\x3F\xFF\xFF\xFF\xC0\x00\x00\x00\x40\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x40\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" |> rev_string (* (Params.p + 1) / 4*) let byte_length = 32 let fe_length = 32 let first_byte_bits = None end module Foreign = struct - external mul : field_element -> field_element -> field_element -> unit = "mc_p256_mul" [@@noalloc] - external sub : field_element -> field_element -> field_element -> unit = "mc_p256_sub" [@@noalloc] - external add : field_element -> field_element -> field_element -> unit = "mc_p256_add" [@@noalloc] - external to_montgomery : field_element -> unit = "mc_p256_to_montgomery" [@@noalloc] - external from_bytes_buf : field_element -> Cstruct.buffer -> unit = "mc_p256_from_bytes" [@@noalloc] - external set_one : field_element -> unit = "mc_p256_set_one" [@@noalloc] + external mul : out_field_element -> field_element -> field_element -> unit = "mc_p256_mul" [@@noalloc] + external sub : out_field_element -> field_element -> field_element -> unit = "mc_p256_sub" [@@noalloc] + external add : out_field_element -> field_element -> field_element -> unit = "mc_p256_add" [@@noalloc] + external to_montgomery : out_field_element -> field_element -> unit = "mc_p256_to_montgomery" [@@noalloc] + external from_octets : out_field_element -> string -> unit = "mc_p256_from_bytes" [@@noalloc] + external set_one : out_field_element -> unit = "mc_p256_set_one" [@@noalloc] external nz : field_element -> bool = "mc_p256_nz" [@@noalloc] - external sqr : field_element -> field_element -> unit = "mc_p256_sqr" [@@noalloc] - external from_montgomery : field_element -> unit = "mc_p256_from_montgomery" [@@noalloc] - external to_bytes_buf : Cstruct.buffer -> field_element -> unit = "mc_p256_to_bytes" [@@noalloc] - external inv : field_element -> field_element -> unit = "mc_p256_inv" [@@noalloc] - external select_c : field_element -> bool -> field_element -> field_element -> unit = "mc_p256_select" [@@noalloc] - - external double_c : point -> point -> unit = "mc_p256_point_double" [@@noalloc] - external add_c : point -> point -> point -> unit = "mc_p256_point_add" [@@noalloc] + external sqr : out_field_element -> field_element -> unit = "mc_p256_sqr" [@@noalloc] + external from_montgomery : out_field_element -> field_element -> unit = "mc_p256_from_montgomery" [@@noalloc] + external to_octets : bytes -> field_element -> unit = "mc_p256_to_bytes" [@@noalloc] + external inv : out_field_element -> field_element -> unit = "mc_p256_inv" [@@noalloc] + external select_c : out_field_element -> bool -> field_element -> field_element -> unit = "mc_p256_select" [@@noalloc] + external double_c : out_point -> point -> unit = "mc_p256_point_double" [@@noalloc] + external add_c : out_point -> point -> point -> unit = "mc_p256_point_add" [@@noalloc] end module Foreign_n = struct - external mul : field_element -> field_element -> field_element -> unit = "mc_np256_mul" [@@noalloc] - external add : field_element -> field_element -> field_element -> unit = "mc_np256_add" [@@noalloc] - external inv : field_element -> field_element -> unit = "mc_np256_inv" [@@noalloc] - external one : field_element -> unit = "mc_np256_one" [@@noalloc] - external from_bytes : field_element -> Cstruct.buffer -> unit = "mc_np256_from_bytes" [@@noalloc] - external to_bytes : Cstruct.buffer -> field_element -> unit = "mc_np256_to_bytes" [@@noalloc] - external from_montgomery : field_element -> field_element -> unit = "mc_np256_from_montgomery" [@@noalloc] - external to_montgomery : field_element -> field_element -> unit = "mc_np256_to_montgomery" [@@noalloc] + external mul : out_field_element -> field_element -> field_element -> unit = "mc_np256_mul" [@@noalloc] + external add : out_field_element -> field_element -> field_element -> unit = "mc_np256_add" [@@noalloc] + external inv : out_field_element -> field_element -> unit = "mc_np256_inv" [@@noalloc] + external one : out_field_element -> unit = "mc_np256_one" [@@noalloc] + external from_bytes : out_field_element -> string -> unit = "mc_np256_from_bytes" [@@noalloc] + external to_bytes : bytes -> field_element -> unit = "mc_np256_to_bytes" [@@noalloc] + external from_montgomery : out_field_element -> field_element -> unit = "mc_np256_from_montgomery" [@@noalloc] + external to_montgomery : out_field_element -> field_element -> unit = "mc_np256_to_montgomery" [@@noalloc] end module P = Make_point(Params)(Foreign) module S = Make_scalar(Params)(P) module Dh = Make_dh(Params)(P)(S) - module Dsa = Make_dsa(Params)(Foreign_n)(P)(S)(Mirage_crypto.Hash.SHA256) + module Fn = Make_Fn(Params)(Foreign_n) + module Dsa = Make_dsa(Params)(Fn)(P)(S)(Mirage_crypto.Hash.SHA256) end module P384 : Dh_dsa = struct module Params = struct - let a = Cstruct.of_hex "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFFFF0000000000000000FFFFFFFC" - let b = Cstruct.of_hex "B3312FA7E23EE7E4988E056BE3F82D19181D9C6EFE8141120314088F5013875AC656398D8A2ED19D2A85C8EDD3EC2AEF" - let g_x = - Cstruct.of_hex "AA87CA22BE8B05378EB1C71EF320AD746E1D3B628BA79B9859F741E082542A385502F25DBF55296C3A545E3872760AB7" + let a = "\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFE\xFF\xFF\xFF\xFF\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF\xFF\xFC" + let b = "\xB3\x31\x2F\xA7\xE2\x3E\xE7\xE4\x98\x8E\x05\x6B\xE3\xF8\x2D\x19\x18\x1D\x9C\x6E\xFE\x81\x41\x12\x03\x14\x08\x8F\x50\x13\x87\x5A\xC6\x56\x39\x8D\x8A\x2E\xD1\x9D\x2A\x85\xC8\xED\xD3\xEC\x2A\xEF" + let g_x = "\xAA\x87\xCA\x22\xBE\x8B\x05\x37\x8E\xB1\xC7\x1E\xF3\x20\xAD\x74\x6E\x1D\x3B\x62\x8B\xA7\x9B\x98\x59\xF7\x41\xE0\x82\x54\x2A\x38\x55\x02\xF2\x5D\xBF\x55\x29\x6C\x3A\x54\x5E\x38\x72\x76\x0A\xB7" let g_y = - Cstruct.of_hex "3617de4a96262c6f5d9e98bf9292dc29f8f41dbd289a147ce9da3113b5f0b8c00a60b1ce1d7e819d7a431d7c90ea0e5f" - let p = Cstruct.of_hex "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFFFF0000000000000000FFFFFFFF" - let n = Cstruct.of_hex "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFC7634D81F4372DDF581A0DB248B0A77AECEC196ACCC52973" - let pident = Cstruct.of_hex "3FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFBFFFFFFFC00000000000000040000000" |> Cstruct.rev (* (Params.p + 1) / 4*) +"\x36\x17\xde\x4a\x96\x26\x2c\x6f\x5d\x9e\x98\xbf\x92\x92\xdc\x29\xf8\xf4\x1d\xbd\x28\x9a\x14\x7c\xe9\xda\x31\x13\xb5\xf0\xb8\xc0\x0a\x60\xb1\xce\x1d\x7e\x81\x9d\x7a\x43\x1d\x7c\x90\xea\x0e\x5f" + let p = "\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFE\xFF\xFF\xFF\xFF\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF\xFF\xFF" + let n = "\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xC7\x63\x4D\x81\xF4\x37\x2D\xDF\x58\x1A\x0D\xB2\x48\xB0\xA7\x7A\xEC\xEC\x19\x6A\xCC\xC5\x29\x73" + let pident = "\x3F\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xBF\xFF\xFF\xFF\xC0\x00\x00\x00\x00\x00\x00\x00\x40\x00\x00\x00" |> rev_string (* (Params.p + 1) / 4*) let byte_length = 48 let fe_length = 48 let first_byte_bits = None end module Foreign = struct - external mul : field_element -> field_element -> field_element -> unit = "mc_p384_mul" [@@noalloc] - external sub : field_element -> field_element -> field_element -> unit = "mc_p384_sub" [@@noalloc] - external add : field_element -> field_element -> field_element -> unit = "mc_p384_add" [@@noalloc] - external to_montgomery : field_element -> unit = "mc_p384_to_montgomery" [@@noalloc] - external from_bytes_buf : field_element -> Cstruct.buffer -> unit = "mc_p384_from_bytes" [@@noalloc] - external set_one : field_element -> unit = "mc_p384_set_one" [@@noalloc] + external mul : out_field_element -> field_element -> field_element -> unit = "mc_p384_mul" [@@noalloc] + external sub : out_field_element -> field_element -> field_element -> unit = "mc_p384_sub" [@@noalloc] + external add : out_field_element -> field_element -> field_element -> unit = "mc_p384_add" [@@noalloc] + external to_montgomery : out_field_element -> field_element -> unit = "mc_p384_to_montgomery" [@@noalloc] + external from_octets : out_field_element -> string -> unit = "mc_p384_from_bytes" [@@noalloc] + external set_one : out_field_element -> unit = "mc_p384_set_one" [@@noalloc] external nz : field_element -> bool = "mc_p384_nz" [@@noalloc] - external sqr : field_element -> field_element -> unit = "mc_p384_sqr" [@@noalloc] - external from_montgomery : field_element -> unit = "mc_p384_from_montgomery" [@@noalloc] - external to_bytes_buf : Cstruct.buffer -> field_element -> unit = "mc_p384_to_bytes" [@@noalloc] - external inv : field_element -> field_element -> unit = "mc_p384_inv" [@@noalloc] - external select_c : field_element -> bool -> field_element -> field_element -> unit = "mc_p384_select" [@@noalloc] - - external double_c : point -> point -> unit = "mc_p384_point_double" [@@noalloc] - external add_c : point -> point -> point -> unit = "mc_p384_point_add" [@@noalloc] + external sqr : out_field_element -> field_element -> unit = "mc_p384_sqr" [@@noalloc] + external from_montgomery : out_field_element -> field_element -> unit = "mc_p384_from_montgomery" [@@noalloc] + external to_octets : bytes -> field_element -> unit = "mc_p384_to_bytes" [@@noalloc] + external inv : out_field_element -> field_element -> unit = "mc_p384_inv" [@@noalloc] + external select_c : out_field_element -> bool -> field_element -> field_element -> unit = "mc_p384_select" [@@noalloc] + external double_c : out_point -> point -> unit = "mc_p384_point_double" [@@noalloc] + external add_c : out_point -> point -> point -> unit = "mc_p384_point_add" [@@noalloc] end module Foreign_n = struct - external mul : field_element -> field_element -> field_element -> unit = "mc_np384_mul" [@@noalloc] - external add : field_element -> field_element -> field_element -> unit = "mc_np384_add" [@@noalloc] - external inv : field_element -> field_element -> unit = "mc_np384_inv" [@@noalloc] - external one : field_element -> unit = "mc_np384_one" [@@noalloc] - external from_bytes : field_element -> Cstruct.buffer -> unit = "mc_np384_from_bytes" [@@noalloc] - external to_bytes : Cstruct.buffer -> field_element -> unit = "mc_np384_to_bytes" [@@noalloc] - external from_montgomery : field_element -> field_element -> unit = "mc_np384_from_montgomery" [@@noalloc] - external to_montgomery : field_element -> field_element -> unit = "mc_np384_to_montgomery" [@@noalloc] + external mul : out_field_element -> field_element -> field_element -> unit = "mc_np384_mul" [@@noalloc] + external add : out_field_element -> field_element -> field_element -> unit = "mc_np384_add" [@@noalloc] + external inv : out_field_element -> field_element -> unit = "mc_np384_inv" [@@noalloc] + external one : out_field_element -> unit = "mc_np384_one" [@@noalloc] + external from_bytes : out_field_element -> string -> unit = "mc_np384_from_bytes" [@@noalloc] + external to_bytes : bytes -> field_element -> unit = "mc_np384_to_bytes" [@@noalloc] + external from_montgomery : out_field_element -> field_element -> unit = "mc_np384_from_montgomery" [@@noalloc] + external to_montgomery : out_field_element -> field_element -> unit = "mc_np384_to_montgomery" [@@noalloc] end module P = Make_point(Params)(Foreign) module S = Make_scalar(Params)(P) module Dh = Make_dh(Params)(P)(S) - module Dsa = Make_dsa(Params)(Foreign_n)(P)(S)(Mirage_crypto.Hash.SHA384) + module Fn = Make_Fn(Params)(Foreign_n) + module Dsa = Make_dsa(Params)(Fn)(P)(S)(Mirage_crypto.Hash.SHA384) end module P521 : Dh_dsa = struct module Params = struct - let a = Cstruct.of_hex "01FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFC" - let b = Cstruct.of_hex "0051953EB9618E1C9A1F929A21A0B68540EEA2DA725B99B315F3B8B489918EF109E156193951EC7E937B1652C0BD3BB1BF073573DF883D2C34F1EF451FD46B503F00" + let a = "\x01\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFC" + let b = "\x00\x51\x95\x3E\xB9\x61\x8E\x1C\x9A\x1F\x92\x9A\x21\xA0\xB6\x85\x40\xEE\xA2\xDA\x72\x5B\x99\xB3\x15\xF3\xB8\xB4\x89\x91\x8E\xF1\x09\xE1\x56\x19\x39\x51\xEC\x7E\x93\x7B\x16\x52\xC0\xBD\x3B\xB1\xBF\x07\x35\x73\xDF\x88\x3D\x2C\x34\xF1\xEF\x45\x1F\xD4\x6B\x50\x3F\x00" let g_x = - Cstruct.of_hex "00C6858E06B70404E9CD9E3ECB662395B4429C648139053FB521F828AF606B4D3DBAA14B5E77EFE75928FE1DC127A2FFA8DE3348B3C1856A429BF97E7E31C2E5BD66" +"\x00\xC6\x85\x8E\x06\xB7\x04\x04\xE9\xCD\x9E\x3E\xCB\x66\x23\x95\xB4\x42\x9C\x64\x81\x39\x05\x3F\xB5\x21\xF8\x28\xAF\x60\x6B\x4D\x3D\xBA\xA1\x4B\x5E\x77\xEF\xE7\x59\x28\xFE\x1D\xC1\x27\xA2\xFF\xA8\xDE\x33\x48\xB3\xC1\x85\x6A\x42\x9B\xF9\x7E\x7E\x31\xC2\xE5\xBD\x66" let g_y = - Cstruct.of_hex "011839296a789a3bc0045c8a5fb42c7d1bd998f54449579b446817afbd17273e662c97ee72995ef42640c550b9013fad0761353c7086a272c24088be94769fd16650" - let p = Cstruct.of_hex "01FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF" - let n = Cstruct.of_hex "01FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFA51868783BF2F966B7FCC0148F709A5D03BB5C9B8899C47AEBB6FB71E91386409" - let pident = Cstruct.of_hex "017fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" |> Cstruct.rev +"\x01\x18\x39\x29\x6a\x78\x9a\x3b\xc0\x04\x5c\x8a\x5f\xb4\x2c\x7d\x1b\xd9\x98\xf5\x44\x49\x57\x9b\x44\x68\x17\xaf\xbd\x17\x27\x3e\x66\x2c\x97\xee\x72\x99\x5e\xf4\x26\x40\xc5\x50\xb9\x01\x3f\xad\x07\x61\x35\x3c\x70\x86\xa2\x72\xc2\x40\x88\xbe\x94\x76\x9f\xd1\x66\x50" + let p = "\x01\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF" + let n = "\x01\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFA\x51\x86\x87\x83\xBF\x2F\x96\x6B\x7F\xCC\x01\x48\xF7\x09\xA5\xD0\x3B\xB5\xC9\xB8\x89\x9C\x47\xAE\xBB\x6F\xB7\x1E\x91\x38\x64\x09" + let pident = "\x01\x7f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" |> rev_string let byte_length = 66 let fe_length = if Sys.word_size == 64 then 72 else 68 (* TODO: is this congruent with C code? *) let first_byte_bits = Some 0x01 end module Foreign = struct - external mul : field_element -> field_element -> field_element -> unit = "mc_p521_mul" [@@noalloc] - external sub : field_element -> field_element -> field_element -> unit = "mc_p521_sub" [@@noalloc] - external add : field_element -> field_element -> field_element -> unit = "mc_p521_add" [@@noalloc] - external to_montgomery : field_element -> unit = "mc_p521_to_montgomery" [@@noalloc] - external from_bytes_buf : field_element -> Cstruct.buffer -> unit = "mc_p521_from_bytes" [@@noalloc] - external set_one : field_element -> unit = "mc_p521_set_one" [@@noalloc] + external mul : out_field_element -> field_element -> field_element -> unit = "mc_p521_mul" [@@noalloc] + external sub : out_field_element -> field_element -> field_element -> unit = "mc_p521_sub" [@@noalloc] + external add : out_field_element -> field_element -> field_element -> unit = "mc_p521_add" [@@noalloc] + external to_montgomery : out_field_element -> field_element -> unit = "mc_p521_to_montgomery" [@@noalloc] + external from_octets : out_field_element -> string -> unit = "mc_p521_from_bytes" [@@noalloc] + external set_one : out_field_element -> unit = "mc_p521_set_one" [@@noalloc] external nz : field_element -> bool = "mc_p521_nz" [@@noalloc] - external sqr : field_element -> field_element -> unit = "mc_p521_sqr" [@@noalloc] - external from_montgomery : field_element -> unit = "mc_p521_from_montgomery" [@@noalloc] - external to_bytes_buf : Cstruct.buffer -> field_element -> unit = "mc_p521_to_bytes" [@@noalloc] - external inv : field_element -> field_element -> unit = "mc_p521_inv" [@@noalloc] - external select_c : field_element -> bool -> field_element -> field_element -> unit = "mc_p521_select" [@@noalloc] - - external double_c : point -> point -> unit = "mc_p521_point_double" [@@noalloc] - external add_c : point -> point -> point -> unit = "mc_p521_point_add" [@@noalloc] + external sqr : out_field_element -> field_element -> unit = "mc_p521_sqr" [@@noalloc] + external from_montgomery : out_field_element -> field_element -> unit = "mc_p521_from_montgomery" [@@noalloc] + external to_octets : bytes -> field_element -> unit = "mc_p521_to_bytes" [@@noalloc] + external inv : out_field_element -> field_element -> unit = "mc_p521_inv" [@@noalloc] + external select_c : out_field_element -> bool -> field_element -> field_element -> unit = "mc_p521_select" [@@noalloc] + external double_c : out_point -> point -> unit = "mc_p521_point_double" [@@noalloc] + external add_c : out_point -> point -> point -> unit = "mc_p521_point_add" [@@noalloc] end module Foreign_n = struct - external mul : field_element -> field_element -> field_element -> unit = "mc_np521_mul" [@@noalloc] - external add : field_element -> field_element -> field_element -> unit = "mc_np521_add" [@@noalloc] - external inv : field_element -> field_element -> unit = "mc_np521_inv" [@@noalloc] - external one : field_element -> unit = "mc_np521_one" [@@noalloc] - external from_bytes : field_element -> Cstruct.buffer -> unit = "mc_np521_from_bytes" [@@noalloc] - external to_bytes : Cstruct.buffer -> field_element -> unit = "mc_np521_to_bytes" [@@noalloc] - external from_montgomery : field_element -> field_element -> unit = "mc_np521_from_montgomery" [@@noalloc] - external to_montgomery : field_element -> field_element -> unit = "mc_np521_to_montgomery" [@@noalloc] + external mul : out_field_element -> field_element -> field_element -> unit = "mc_np521_mul" [@@noalloc] + external add : out_field_element -> field_element -> field_element -> unit = "mc_np521_add" [@@noalloc] + external inv : out_field_element -> field_element -> unit = "mc_np521_inv" [@@noalloc] + external one : out_field_element -> unit = "mc_np521_one" [@@noalloc] + external from_bytes : out_field_element -> string -> unit = "mc_np521_from_bytes" [@@noalloc] + external to_bytes : bytes -> field_element -> unit = "mc_np521_to_bytes" [@@noalloc] + external from_montgomery : out_field_element -> field_element -> unit = "mc_np521_from_montgomery" [@@noalloc] + external to_montgomery : out_field_element -> field_element -> unit = "mc_np521_to_montgomery" [@@noalloc] end module P = Make_point(Params)(Foreign) module S = Make_scalar(Params)(P) module Dh = Make_dh(Params)(P)(S) - module Dsa = Make_dsa(Params)(Foreign_n)(P)(S)(Mirage_crypto.Hash.SHA512) + module Fn = Make_Fn(Params)(Foreign_n) + module Dsa = Make_dsa(Params)(Fn)(P)(S)(Mirage_crypto.Hash.SHA512) end module X25519 = struct (* RFC 7748 *) - external x25519_scalar_mult_generic : Cstruct.buffer -> Cstruct.buffer -> int -> Cstruct.buffer -> int -> unit = "mc_x25519_scalar_mult_generic" [@@noalloc] + external x25519_scalar_mult_generic : bytes -> string -> string -> unit = "mc_x25519_scalar_mult_generic" [@@noalloc] let key_len = 32 let scalar_mult in_ base = - let out = Cstruct.create key_len in - x25519_scalar_mult_generic out.Cstruct.buffer - in_.Cstruct.buffer in_.Cstruct.off base.Cstruct.buffer base.Cstruct.off; - out + let out = Bytes.make key_len '\000' in + x25519_scalar_mult_generic out in_ base; + Bytes.unsafe_to_string out - type secret = Cstruct.t + type secret = string - let basepoint = - let data = Cstruct.create key_len in - Cstruct.set_uint8 data 0 9; - data + let basepoint = String.init key_len (function 0 -> '\009' | _ -> '\000') let public priv = scalar_mult priv basepoint - let gen_key ?compress:_ ?g () = - let secret = Mirage_crypto_rng.generate ?g key_len in + let gen_key_octets ?compress:_ ?g () = + let secret = Cstruct.to_string (Mirage_crypto_rng.generate ?g key_len) in secret, public secret - let secret_of_cs ?compress:_ s = - if Cstruct.length s = key_len then + let gen_key ?compress ?g () = + let secret, public = gen_key_octets ~compress ?g () in + secret, Cstruct.of_string public + + let secret_of_octets ?compress:_ s = + if String.length s = key_len then Ok (s, public s) else Error `Invalid_length + let secret_of_cs ?compress cs = + Result.map (fun (secret, public) -> secret, Cstruct.of_string public) + (secret_of_octets ~compress (Cstruct.to_string cs)) + let is_zero = - let zero = Cstruct.create key_len in - fun cs -> Cstruct.equal zero cs + let zero = String.make key_len '\000' in + fun buf -> String.equal zero buf - let key_exchange secret public = - if Cstruct.length public = key_len then + let key_exchange_octets secret public = + if String.length public = key_len then let res = scalar_mult secret public in if is_zero res then Error `Low_order else Ok res else Error `Invalid_length + + let key_exchange secret public = + Result.map Cstruct.of_string + (key_exchange_octets secret (Cstruct.to_string public)) end module Ed25519 = struct + external scalar_mult_base_to_bytes : bytes -> string -> unit = "mc_25519_scalar_mult_base" [@@noalloc] + external reduce_l : bytes -> unit = "mc_25519_reduce_l" [@@noalloc] + external muladd : bytes -> string -> string -> string -> unit = "mc_25519_muladd" [@@noalloc] + external double_scalar_mult : bytes -> string -> string -> string -> bool = "mc_25519_double_scalar_mult" [@@noalloc] + external pub_ok : string -> bool = "mc_25519_pub_ok" [@@noalloc] - external scalar_mult_base_to_bytes : Cstruct.buffer -> Cstruct.buffer -> unit = "mc_25519_scalar_mult_base" [@@noalloc] - external reduce_l : Cstruct.buffer -> unit = "mc_25519_reduce_l" [@@noalloc] - external muladd : Cstruct.buffer -> Cstruct.buffer -> Cstruct.buffer -> Cstruct.buffer -> unit = "mc_25519_muladd" [@@noalloc] - external double_scalar_mult : Cstruct.buffer -> Cstruct.buffer -> Cstruct.buffer -> Cstruct.buffer -> int -> bool = "mc_25519_double_scalar_mult" [@@noalloc] - external pub_ok : Cstruct.buffer -> bool = "mc_25519_pub_ok" [@@noalloc] + let key_len = 32 - type pub = Cstruct.t + let scalar_mult_base_to_bytes p = + let tmp = Bytes.make key_len '\000' in + scalar_mult_base_to_bytes tmp p; + Bytes.unsafe_to_string tmp - type priv = Cstruct.t + let muladd a b c = + let tmp = Bytes.make key_len '\000' in + muladd tmp a b c; + Bytes.unsafe_to_string tmp - (* RFC 8032 *) - let key_len = 32 + let double_scalar_mult a b c = + let tmp = Bytes.make key_len '\000' in + let s = double_scalar_mult tmp a b c in + s, Bytes.unsafe_to_string tmp + + type pub = string + + type priv = string + (* RFC 8032 *) let public secret = (* section 5.1.5 *) (* step 1 *) - let h = Mirage_crypto.Hash.SHA512.digest secret in + let h = Mirage_crypto.Hash.SHA512.digest (Cstruct.of_string secret) in (* step 2 *) let s, rest = Cstruct.split h key_len in - Cstruct.set_uint8 s 0 (Cstruct.get_uint8 s 0 land 248); - Cstruct.set_uint8 s 31 ((Cstruct.get_uint8 s 31 land 127) lor 64); + let s, rest = Cstruct.to_bytes s, Cstruct.to_string rest in + Bytes.set_uint8 s 0 ((Bytes.get_uint8 s 0) land 248); + Bytes.set_uint8 s 31 (((Bytes.get_uint8 s 31) land 127) lor 64); + let s = Bytes.unsafe_to_string s in (* step 3 and 4 *) - let public = Cstruct.create key_len in - scalar_mult_base_to_bytes public.Cstruct.buffer s.Cstruct.buffer; + let public = scalar_mult_base_to_bytes s in public, (s, rest) let pub_of_priv secret = fst (public secret) - let priv_of_cstruct cs = - if Cstruct.length cs = key_len then Ok cs else Error `Invalid_length + let priv_of_octets buf = + if String.length buf = key_len then Ok buf else Error `Invalid_length + + let priv_of_cstruct p = priv_of_octets (Cstruct.to_string p) - let priv_to_cstruct priv = priv + let priv_to_octets priv = priv - let pub_of_cstruct cs = - if Cstruct.length cs = key_len then - let cs_copy = Cstruct.create key_len in - Cstruct.blit cs 0 cs_copy 0 key_len; - if pub_ok cs_copy.Cstruct.buffer then - Ok cs_copy + let priv_to_cstruct p = Cstruct.of_string (priv_to_octets p) + + let pub_of_octets buf = + if String.length buf = key_len then + if pub_ok buf then + Ok buf else Error `Not_on_curve else Error `Invalid_length - let pub_to_cstruct pub = pub + let pub_of_cstruct p = pub_of_octets (Cstruct.to_string p) + + let pub_to_octets pub = pub + + let pub_to_cstruct p = Cstruct.of_string (pub_to_octets p) let generate ?g () = let secret = Mirage_crypto_rng.generate ?g key_len in + let secret = Cstruct.to_string secret in secret, pub_of_priv secret - let sign ~key msg = + let sign_octets ~key msg = (* section 5.1.6 *) let pub, (s, prefix) = public key in - let r = Mirage_crypto.Hash.SHA512.digest (Cstruct.append prefix msg) in - reduce_l r.Cstruct.buffer; - let r_big = Cstruct.create key_len in - scalar_mult_base_to_bytes r_big.Cstruct.buffer r.Cstruct.buffer; - let k = Mirage_crypto.Hash.SHA512.digest (Cstruct.concat [ r_big ; pub ; msg ]) in - reduce_l k.Cstruct.buffer; - let s_out = Cstruct.create key_len in - muladd s_out.Cstruct.buffer k.Cstruct.buffer s.Cstruct.buffer r.Cstruct.buffer; - Cstruct.append r_big s_out - - let verify ~key signature ~msg = + let r = Mirage_crypto.Hash.SHA512.digest (Cstruct.of_string (String.concat "" [ prefix; msg ])) in + let r = Cstruct.to_bytes r in + reduce_l r; + let r = Bytes.unsafe_to_string r in + let r_big = scalar_mult_base_to_bytes r in + let k = Mirage_crypto.Hash.SHA512.digest (Cstruct.of_string (String.concat "" [ r_big; pub; msg])) in + let k = Cstruct.to_bytes k in + reduce_l k; + let k = Bytes.unsafe_to_string k in + let s_out = muladd k s r in + let res = Bytes.make (key_len + key_len) '\000' in + Bytes.blit_string r_big 0 res 0 key_len ; + Bytes.blit_string s_out 0 res key_len key_len ; + Bytes.unsafe_to_string res + + let sign ~key msg = Cstruct.of_string (sign_octets ~key (Cstruct.to_string msg)) + + let verify_octets ~key signature ~msg = (* section 5.1.7 *) - if Cstruct.length signature = 2 * key_len then - let r, s = Cstruct.split signature key_len in + if String.length signature = 2 * key_len then + let r, s = + String.sub signature 0 key_len, + String.sub signature key_len key_len + in let s_smaller_l = (* check s within 0 <= s < L *) - let s' = Cstruct.create (key_len * 2) in - Cstruct.blit s 0 s' 0 key_len; - reduce_l s'.Cstruct.buffer; - let s'' = Cstruct.(append s (create key_len)) in - Cstruct.equal s'' s' + let s' = Bytes.make (key_len * 2) '\000' in + Bytes.blit_string s 0 s' 0 key_len; + reduce_l s'; + let s' = Bytes.unsafe_to_string s' in + let s'' = String.concat "" [ s; String.make key_len '\000' ] in + String.equal s'' s' in if s_smaller_l then begin let k = - Mirage_crypto.Hash.SHA512.digest (Cstruct.concat [ r ; key ; msg ]) + let data_to_hash = String.concat "" [ r ; key ; msg ] in + Mirage_crypto.Hash.SHA512.digest (Cstruct.of_string data_to_hash) in - reduce_l k.Cstruct.buffer; - let r' = Cstruct.create key_len in - let success = - double_scalar_mult r'.Cstruct.buffer k.Cstruct.buffer - key.Cstruct.buffer s.Cstruct.buffer s.Cstruct.off - in - success && Cstruct.equal r r' + let k = Cstruct.to_bytes k in + reduce_l k; + let k = Bytes.unsafe_to_string k in + let success, r' = double_scalar_mult k key s in + success && String.equal r r' end else false else false + + let verify ~key signature ~msg = + verify_octets ~key (Cstruct.to_string signature) ~msg:(Cstruct.to_string msg) end diff --git a/ec/native/curve25519_stubs.c b/ec/native/curve25519_stubs.c index 283a453a..f4056f53 100644 --- a/ec/native/curve25519_stubs.c +++ b/ec/native/curve25519_stubs.c @@ -1804,10 +1804,10 @@ static void sc_muladd(uint8_t *s, const uint8_t *a, const uint8_t *b, #include -CAMLprim value mc_x25519_scalar_mult_generic(value out, value scalar, value soff, value point, value poff) +CAMLprim value mc_x25519_scalar_mult_generic(value out, value scalar, value point) { - CAMLparam5(out, scalar, soff, point, poff); - x25519_scalar_mult_generic(Caml_ba_data_val(out), _ba_uint8_off(scalar, soff), _ba_uint8_off(point, poff)); + CAMLparam3(out, scalar, point); + x25519_scalar_mult_generic(Bytes_val(out), _st_uint8(scalar), _st_uint8(point)); CAMLreturn(Val_unit); } @@ -1816,39 +1816,39 @@ CAMLprim value mc_25519_scalar_mult_base(value out, value hash) CAMLparam2(out, hash); ge_p3 A; ge_p3_0(&A); - x25519_ge_scalarmult_base(&A, Caml_ba_data_val(hash)); - ge_p3_tobytes(Caml_ba_data_val(out), &A); + x25519_ge_scalarmult_base(&A, _st_uint8(hash)); + ge_p3_tobytes(Bytes_val(out), &A); CAMLreturn(Val_unit); } CAMLprim value mc_25519_reduce_l(value buf) { CAMLparam1(buf); - x25519_sc_reduce(Caml_ba_data_val(buf)); + x25519_sc_reduce(Bytes_val(buf)); CAMLreturn(Val_unit); } CAMLprim value mc_25519_muladd(value out, value a, value b, value c) { CAMLparam4(out, a, b, c); - sc_muladd(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b), Caml_ba_data_val(c)); + sc_muladd(Bytes_val(out), _st_uint8(a), _st_uint8(b), _st_uint8(c)); CAMLreturn(Val_unit); } -CAMLprim value mc_25519_double_scalar_mult(value out, value k, value key, value c, value coff) +CAMLprim value mc_25519_double_scalar_mult(value out, value k, value key, value c) { - CAMLparam5(out, k, key, c, coff); + CAMLparam4(out, k, key, c); ge_p2 R; ge_p3 B; fe_loose t; int success = 0; - success = x25519_ge_frombytes_vartime(&B, Caml_ba_data_val(key)); + success = x25519_ge_frombytes_vartime(&B, _st_uint8(key)); fe_neg(&t, &B.X); fe_carry(&B.X, &t); fe_neg(&t, &B.T); fe_carry(&B.T, &t); - ge_double_scalarmult_vartime(&R, Caml_ba_data_val(k), &B, _ba_uint8_off(c, coff)); - x25519_ge_tobytes(Caml_ba_data_val(out), &R); + ge_double_scalarmult_vartime(&R, _st_uint8(k), &B, _st_uint8(c)); + x25519_ge_tobytes(Bytes_val(out), &R); CAMLreturn(Val_bool(success)); } @@ -1857,6 +1857,6 @@ CAMLprim value mc_25519_pub_ok(value key) CAMLparam1(key); int success = 0; ge_p3 B; - success = x25519_ge_frombytes_vartime(&B, Caml_ba_data_val(key)); + success = x25519_ge_frombytes_vartime(&B, _st_uint8(key)); CAMLreturn(Val_bool(success)); } diff --git a/ec/native/inversion_template.h b/ec/native/inversion_template.h index cb4c079a..7eb5ab35 100644 --- a/ec/native/inversion_template.h +++ b/ec/native/inversion_template.h @@ -52,7 +52,7 @@ static void inverse(WORD out[LIMBS], WORD g[SAT_LIMBS]) { return; } -static void inversion (WORD out[LIMBS], WORD in[LIMBS]) { +static void inversion (WORD out[LIMBS], const WORD in[LIMBS]) { WORD in_[SAT_LIMBS]; for (int i = 0; i < LIMBS; i++) in_[i] = in[i]; in_[LIMBS] = 0; diff --git a/ec/native/np224_stubs.c b/ec/native/np224_stubs.c index fe64b5b4..75083a5d 100644 --- a/ec/native/np224_stubs.c +++ b/ec/native/np224_stubs.c @@ -22,55 +22,55 @@ CAMLprim value mc_np224_inv(value out, value in) { CAMLparam2(out, in); - inversion(Caml_ba_data_val(out), Caml_ba_data_val(in)); + inversion((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_np224_mul(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_np224_mul(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_np224_mul((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_np224_add(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_np224_add(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_np224_add((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_np224_one(value out) { CAMLparam1(out); - fiat_np224_set_one(Caml_ba_data_val(out)); + fiat_np224_set_one((WORD*)Bytes_val(out)); CAMLreturn(Val_unit); } CAMLprim value mc_np224_from_bytes(value out, value in) { CAMLparam2(out, in); - fiat_np224_from_bytes(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_np224_from_bytes((WORD*)Bytes_val(out), _st_uint8(in)); CAMLreturn(Val_unit); } CAMLprim value mc_np224_to_bytes(value out, value in) { CAMLparam2(out, in); - fiat_np224_to_bytes(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_np224_to_bytes(Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_np224_from_montgomery(value out, value in) { CAMLparam2(out, in); - fiat_np224_from_montgomery(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_np224_from_montgomery((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_np224_to_montgomery(value out, value in) { CAMLparam2(out, in); - fiat_np224_to_montgomery(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_np224_to_montgomery((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } diff --git a/ec/native/np256_stubs.c b/ec/native/np256_stubs.c index cd4aa30d..2152231c 100644 --- a/ec/native/np256_stubs.c +++ b/ec/native/np256_stubs.c @@ -22,55 +22,55 @@ CAMLprim value mc_np256_inv(value out, value in) { CAMLparam2(out, in); - inversion(Caml_ba_data_val(out), Caml_ba_data_val(in)); + inversion((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_np256_mul(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_np256_mul(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_np256_mul((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_np256_add(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_np256_add(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_np256_add((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_np256_one(value out) { CAMLparam1(out); - fiat_np256_set_one(Caml_ba_data_val(out)); + fiat_np256_set_one((WORD*)Bytes_val(out)); CAMLreturn(Val_unit); } CAMLprim value mc_np256_from_bytes(value out, value in) { CAMLparam2(out, in); - fiat_np256_from_bytes(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_np256_from_bytes((WORD*)Bytes_val(out), _st_uint8(in)); CAMLreturn(Val_unit); } CAMLprim value mc_np256_to_bytes(value out, value in) { CAMLparam2(out, in); - fiat_np256_to_bytes(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_np256_to_bytes(Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_np256_from_montgomery(value out, value in) { CAMLparam2(out, in); - fiat_np256_from_montgomery(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_np256_from_montgomery((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_np256_to_montgomery(value out, value in) { CAMLparam2(out, in); - fiat_np256_to_montgomery(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_np256_to_montgomery((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } diff --git a/ec/native/np384_stubs.c b/ec/native/np384_stubs.c index c1abd4e5..37a41577 100644 --- a/ec/native/np384_stubs.c +++ b/ec/native/np384_stubs.c @@ -22,55 +22,55 @@ CAMLprim value mc_np384_inv(value out, value in) { CAMLparam2(out, in); - inversion(Caml_ba_data_val(out), Caml_ba_data_val(in)); + inversion((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_np384_mul(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_np384_mul(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_np384_mul((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_np384_add(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_np384_add(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_np384_add((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_np384_one(value out) { CAMLparam1(out); - fiat_np384_set_one(Caml_ba_data_val(out)); + fiat_np384_set_one((WORD*)Bytes_val(out)); CAMLreturn(Val_unit); } CAMLprim value mc_np384_from_bytes(value out, value in) { CAMLparam2(out, in); - fiat_np384_from_bytes(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_np384_from_bytes((WORD*)Bytes_val(out), _st_uint8(in)); CAMLreturn(Val_unit); } CAMLprim value mc_np384_to_bytes(value out, value in) { CAMLparam2(out, in); - fiat_np384_to_bytes(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_np384_to_bytes(Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_np384_from_montgomery(value out, value in) { CAMLparam2(out, in); - fiat_np384_from_montgomery(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_np384_from_montgomery((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_np384_to_montgomery(value out, value in) { CAMLparam2(out, in); - fiat_np384_to_montgomery(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_np384_to_montgomery((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } diff --git a/ec/native/np521_stubs.c b/ec/native/np521_stubs.c index aaa8bf5a..7f7e1c0b 100644 --- a/ec/native/np521_stubs.c +++ b/ec/native/np521_stubs.c @@ -22,56 +22,56 @@ CAMLprim value mc_np521_inv(value out, value in) { CAMLparam2(out, in); - inversion(Caml_ba_data_val(out), Caml_ba_data_val(in)); + inversion((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_np521_mul(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_np521_mul(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_np521_mul((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_np521_add(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_np521_add(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_np521_add((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_np521_one(value out) { CAMLparam1(out); - fiat_np521_set_one(Caml_ba_data_val(out)); + fiat_np521_set_one((WORD*)Bytes_val(out)); CAMLreturn(Val_unit); } CAMLprim value mc_np521_from_bytes(value out, value in) { CAMLparam2(out, in); - fiat_np521_from_bytes(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_np521_from_bytes((WORD*)Bytes_val(out), _st_uint8(in)); CAMLreturn(Val_unit); } CAMLprim value mc_np521_to_bytes(value out, value in) { CAMLparam2(out, in); - fiat_np521_to_bytes(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_np521_to_bytes(Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_np521_from_montgomery(value out, value in) { CAMLparam2(out, in); - fiat_np521_from_montgomery(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_np521_from_montgomery((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_np521_to_montgomery(value out, value in) { CAMLparam2(out, in); - fiat_np521_to_montgomery(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_np521_to_montgomery((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } diff --git a/ec/native/p224_stubs.c b/ec/native/p224_stubs.c index 0a5da891..b7ee1980 100644 --- a/ec/native/p224_stubs.c +++ b/ec/native/p224_stubs.c @@ -23,78 +23,76 @@ CAMLprim value mc_p224_sub(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_p224_sub(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_p224_sub((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_p224_add(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_p224_add(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_p224_add((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_p224_mul(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_p224_mul(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_p224_mul((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_p224_from_bytes(value out, value in) { CAMLparam2(out, in); - fiat_p224_from_bytes(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_p224_from_bytes((WORD*)Bytes_val(out), _st_uint8(in)); CAMLreturn(Val_unit); } CAMLprim value mc_p224_to_bytes(value out, value in) { CAMLparam2(out, in); - fiat_p224_to_bytes(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_p224_to_bytes(Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_p224_sqr(value out, value in) { CAMLparam2(out, in); - fiat_p224_square(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_p224_square((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } -CAMLprim value mc_p224_from_montgomery(value x) +CAMLprim value mc_p224_from_montgomery(value out, value in) { - CAMLparam1(x); - WORD *l = Caml_ba_data_val(x); - fiat_p224_from_montgomery(l, l); + CAMLparam2(out, in); + fiat_p224_from_montgomery((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } -CAMLprim value mc_p224_to_montgomery(value x) +CAMLprim value mc_p224_to_montgomery(value out, value in) { - CAMLparam1(x); - WORD *l = Caml_ba_data_val(x); - fiat_p224_to_montgomery(l, l); + CAMLparam2(out, in); + fiat_p224_to_montgomery((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_p224_nz(value x) { CAMLparam1(x); - CAMLreturn(Val_bool(fe_nz(Caml_ba_data_val(x)))); + CAMLreturn(Val_bool(fe_nz((const WORD*)String_val(x)))); } CAMLprim value mc_p224_set_one(value x) { CAMLparam1(x); - fiat_p224_set_one(Caml_ba_data_val(x)); + fiat_p224_set_one((WORD*)Bytes_val(x)); CAMLreturn(Val_unit); } CAMLprim value mc_p224_inv(value out, value in) { CAMLparam2(out, in); - inversion(Caml_ba_data_val(out), Caml_ba_data_val(in)); + inversion((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } @@ -102,12 +100,12 @@ CAMLprim value mc_p224_point_double(value out, value in) { CAMLparam2(out, in); point_double( - Caml_ba_data_val(Field(out, 0)), - Caml_ba_data_val(Field(out, 1)), - Caml_ba_data_val(Field(out, 2)), - Caml_ba_data_val(Field(in, 0)), - Caml_ba_data_val(Field(in, 1)), - Caml_ba_data_val(Field(in, 2)) + (WORD*)Bytes_val(Field(out, 0)), + (WORD*)Bytes_val(Field(out, 1)), + (WORD*)Bytes_val(Field(out, 2)), + (const WORD*)String_val(Field(in, 0)), + (const WORD*)String_val(Field(in, 1)), + (const WORD*)String_val(Field(in, 2)) ); CAMLreturn(Val_unit); } @@ -116,16 +114,16 @@ CAMLprim value mc_p224_point_add(value out, value p, value q) { CAMLparam3(out, p, q); point_add( - Caml_ba_data_val(Field(out, 0)), - Caml_ba_data_val(Field(out, 1)), - Caml_ba_data_val(Field(out, 2)), - Caml_ba_data_val(Field(p, 0)), - Caml_ba_data_val(Field(p, 1)), - Caml_ba_data_val(Field(p, 2)), + (WORD*)Bytes_val(Field(out, 0)), + (WORD*)Bytes_val(Field(out, 1)), + (WORD*)Bytes_val(Field(out, 2)), + (const WORD*)String_val(Field(p, 0)), + (const WORD*)String_val(Field(p, 1)), + (const WORD*)String_val(Field(p, 2)), 0, - Caml_ba_data_val(Field(q, 0)), - Caml_ba_data_val(Field(q, 1)), - Caml_ba_data_val(Field(q, 2)) + (const WORD*)String_val(Field(q, 0)), + (const WORD*)String_val(Field(q, 1)), + (const WORD*)String_val(Field(q, 2)) ); CAMLreturn(Val_unit); } @@ -134,10 +132,10 @@ CAMLprim value mc_p224_select(value out, value bit, value t, value f) { CAMLparam4(out, bit, t, f); fe_cmovznz( - Caml_ba_data_val(out), + (WORD*)Bytes_val(out), Bool_val(bit), - Caml_ba_data_val(f), - Caml_ba_data_val(t) + (const WORD*)String_val(f), + (const WORD*)String_val(t) ); CAMLreturn(Val_unit); } diff --git a/ec/native/p256_stubs.c b/ec/native/p256_stubs.c index 7dc2d927..e5ac7d47 100644 --- a/ec/native/p256_stubs.c +++ b/ec/native/p256_stubs.c @@ -23,78 +23,76 @@ CAMLprim value mc_p256_sub(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_p256_sub(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_p256_sub((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_p256_add(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_p256_add(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_p256_add((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_p256_mul(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_p256_mul(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_p256_mul((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_p256_from_bytes(value out, value in) { CAMLparam2(out, in); - fiat_p256_from_bytes(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_p256_from_bytes((WORD*)Bytes_val(out), _st_uint8(in)); CAMLreturn(Val_unit); } CAMLprim value mc_p256_to_bytes(value out, value in) { CAMLparam2(out, in); - fiat_p256_to_bytes(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_p256_to_bytes(Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_p256_sqr(value out, value in) { CAMLparam2(out, in); - fiat_p256_square(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_p256_square((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } -CAMLprim value mc_p256_from_montgomery(value x) +CAMLprim value mc_p256_from_montgomery(value out, value in) { - CAMLparam1(x); - WORD *l = Caml_ba_data_val(x); - fiat_p256_from_montgomery(l, l); + CAMLparam2(out, in); + fiat_p256_from_montgomery((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } -CAMLprim value mc_p256_to_montgomery(value x) +CAMLprim value mc_p256_to_montgomery(value out, value in) { - CAMLparam1(x); - WORD *l = Caml_ba_data_val(x); - fiat_p256_to_montgomery(l, l); + CAMLparam2(out, in); + fiat_p256_to_montgomery((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_p256_nz(value x) { CAMLparam1(x); - CAMLreturn(Val_bool(fe_nz(Caml_ba_data_val(x)))); + CAMLreturn(Val_bool(fe_nz((const WORD*)String_val(x)))); } CAMLprim value mc_p256_set_one(value x) { CAMLparam1(x); - fiat_p256_set_one(Caml_ba_data_val(x)); + fiat_p256_set_one((WORD*)Bytes_val(x)); CAMLreturn(Val_unit); } CAMLprim value mc_p256_inv(value out, value in) { CAMLparam2(out, in); - inversion(Caml_ba_data_val(out), Caml_ba_data_val(in)); + inversion((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } @@ -102,12 +100,12 @@ CAMLprim value mc_p256_point_double(value out, value in) { CAMLparam2(out, in); point_double( - Caml_ba_data_val(Field(out, 0)), - Caml_ba_data_val(Field(out, 1)), - Caml_ba_data_val(Field(out, 2)), - Caml_ba_data_val(Field(in, 0)), - Caml_ba_data_val(Field(in, 1)), - Caml_ba_data_val(Field(in, 2)) + (WORD*)Bytes_val(Field(out, 0)), + (WORD*)Bytes_val(Field(out, 1)), + (WORD*)Bytes_val(Field(out, 2)), + (const WORD*)String_val(Field(in, 0)), + (const WORD*)String_val(Field(in, 1)), + (const WORD*)String_val(Field(in, 2)) ); CAMLreturn(Val_unit); } @@ -116,16 +114,16 @@ CAMLprim value mc_p256_point_add(value out, value p, value q) { CAMLparam3(out, p, q); point_add( - Caml_ba_data_val(Field(out, 0)), - Caml_ba_data_val(Field(out, 1)), - Caml_ba_data_val(Field(out, 2)), - Caml_ba_data_val(Field(p, 0)), - Caml_ba_data_val(Field(p, 1)), - Caml_ba_data_val(Field(p, 2)), + (WORD*)Bytes_val(Field(out, 0)), + (WORD*)Bytes_val(Field(out, 1)), + (WORD*)Bytes_val(Field(out, 2)), + (const WORD*)String_val(Field(p, 0)), + (const WORD*)String_val(Field(p, 1)), + (const WORD*)String_val(Field(p, 2)), 0, - Caml_ba_data_val(Field(q, 0)), - Caml_ba_data_val(Field(q, 1)), - Caml_ba_data_val(Field(q, 2)) + (const WORD*)String_val(Field(q, 0)), + (const WORD*)String_val(Field(q, 1)), + (const WORD*)String_val(Field(q, 2)) ); CAMLreturn(Val_unit); } @@ -134,10 +132,10 @@ CAMLprim value mc_p256_select(value out, value bit, value t, value f) { CAMLparam4(out, bit, t, f); fe_cmovznz( - Caml_ba_data_val(out), + (WORD*)Bytes_val(out), Bool_val(bit), - Caml_ba_data_val(f), - Caml_ba_data_val(t) + (const WORD*)String_val(f), + (const WORD*)String_val(t) ); CAMLreturn(Val_unit); } diff --git a/ec/native/p384_stubs.c b/ec/native/p384_stubs.c index 2b2efd2a..7ffb60e4 100644 --- a/ec/native/p384_stubs.c +++ b/ec/native/p384_stubs.c @@ -23,78 +23,76 @@ CAMLprim value mc_p384_sub(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_p384_sub(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_p384_sub((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_p384_add(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_p384_add(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_p384_add((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_p384_mul(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_p384_mul(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_p384_mul((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_p384_from_bytes(value out, value in) { CAMLparam2(out, in); - fiat_p384_from_bytes(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_p384_from_bytes((WORD*)Bytes_val(out), _st_uint8(in)); CAMLreturn(Val_unit); } CAMLprim value mc_p384_to_bytes(value out, value in) { CAMLparam2(out, in); - fiat_p384_to_bytes(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_p384_to_bytes(Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_p384_sqr(value out, value in) { CAMLparam2(out, in); - fiat_p384_square(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_p384_square((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } -CAMLprim value mc_p384_from_montgomery(value x) +CAMLprim value mc_p384_from_montgomery(value out, value in) { - CAMLparam1(x); - WORD *l = Caml_ba_data_val(x); - fiat_p384_from_montgomery(l, l); + CAMLparam2(out, in); + fiat_p384_from_montgomery((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } -CAMLprim value mc_p384_to_montgomery(value x) +CAMLprim value mc_p384_to_montgomery(value out, value in) { - CAMLparam1(x); - WORD *l = Caml_ba_data_val(x); - fiat_p384_to_montgomery(l, l); + CAMLparam2(out, in); + fiat_p384_to_montgomery((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_p384_nz(value x) { CAMLparam1(x); - CAMLreturn(Val_bool(fe_nz(Caml_ba_data_val(x)))); + CAMLreturn(Val_bool(fe_nz((const WORD*)String_val(x)))); } CAMLprim value mc_p384_set_one(value x) { CAMLparam1(x); - fiat_p384_set_one(Caml_ba_data_val(x)); + fiat_p384_set_one((WORD*)Bytes_val(x)); CAMLreturn(Val_unit); } CAMLprim value mc_p384_inv(value out, value in) { CAMLparam2(out, in); - inversion(Caml_ba_data_val(out), Caml_ba_data_val(in)); + inversion((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } @@ -102,12 +100,12 @@ CAMLprim value mc_p384_point_double(value out, value in) { CAMLparam2(out, in); point_double( - Caml_ba_data_val(Field(out, 0)), - Caml_ba_data_val(Field(out, 1)), - Caml_ba_data_val(Field(out, 2)), - Caml_ba_data_val(Field(in, 0)), - Caml_ba_data_val(Field(in, 1)), - Caml_ba_data_val(Field(in, 2)) + (WORD*)Bytes_val(Field(out, 0)), + (WORD*)Bytes_val(Field(out, 1)), + (WORD*)Bytes_val(Field(out, 2)), + (const WORD*)String_val(Field(in, 0)), + (const WORD*)String_val(Field(in, 1)), + (const WORD*)String_val(Field(in, 2)) ); CAMLreturn(Val_unit); } @@ -116,16 +114,16 @@ CAMLprim value mc_p384_point_add(value out, value p, value q) { CAMLparam3(out, p, q); point_add( - Caml_ba_data_val(Field(out, 0)), - Caml_ba_data_val(Field(out, 1)), - Caml_ba_data_val(Field(out, 2)), - Caml_ba_data_val(Field(p, 0)), - Caml_ba_data_val(Field(p, 1)), - Caml_ba_data_val(Field(p, 2)), + (WORD*)Bytes_val(Field(out, 0)), + (WORD*)Bytes_val(Field(out, 1)), + (WORD*)Bytes_val(Field(out, 2)), + (const WORD*)String_val(Field(p, 0)), + (const WORD*)String_val(Field(p, 1)), + (const WORD*)String_val(Field(p, 2)), 0, - Caml_ba_data_val(Field(q, 0)), - Caml_ba_data_val(Field(q, 1)), - Caml_ba_data_val(Field(q, 2)) + (const WORD*)String_val(Field(q, 0)), + (const WORD*)String_val(Field(q, 1)), + (const WORD*)String_val(Field(q, 2)) ); CAMLreturn(Val_unit); } @@ -134,10 +132,10 @@ CAMLprim value mc_p384_select(value out, value bit, value t, value f) { CAMLparam4(out, bit, t, f); fe_cmovznz( - Caml_ba_data_val(out), + (WORD*)Bytes_val(out), Bool_val(bit), - Caml_ba_data_val(f), - Caml_ba_data_val(t) + (const WORD*)String_val(f), + (const WORD*)String_val(t) ); CAMLreturn(Val_unit); } diff --git a/ec/native/p521_stubs.c b/ec/native/p521_stubs.c index e8d6764d..81f92fbc 100644 --- a/ec/native/p521_stubs.c +++ b/ec/native/p521_stubs.c @@ -23,78 +23,76 @@ CAMLprim value mc_p521_sub(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_p521_sub(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_p521_sub((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_p521_add(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_p521_add(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_p521_add((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_p521_mul(value out, value a, value b) { CAMLparam3(out, a, b); - fiat_p521_mul(Caml_ba_data_val(out), Caml_ba_data_val(a), Caml_ba_data_val(b)); + fiat_p521_mul((WORD*)Bytes_val(out), (const WORD*)String_val(a), (const WORD*)String_val(b)); CAMLreturn(Val_unit); } CAMLprim value mc_p521_from_bytes(value out, value in) { CAMLparam2(out, in); - fiat_p521_from_bytes(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_p521_from_bytes((WORD*)Bytes_val(out), _st_uint8(in)); CAMLreturn(Val_unit); } CAMLprim value mc_p521_to_bytes(value out, value in) { CAMLparam2(out, in); - fiat_p521_to_bytes(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_p521_to_bytes(Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_p521_sqr(value out, value in) { CAMLparam2(out, in); - fiat_p521_square(Caml_ba_data_val(out), Caml_ba_data_val(in)); + fiat_p521_square((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } -CAMLprim value mc_p521_from_montgomery(value x) +CAMLprim value mc_p521_from_montgomery(value out, value in) { - CAMLparam1(x); - WORD *l = Caml_ba_data_val(x); - fiat_p521_from_montgomery(l, l); + CAMLparam2(out, in); + fiat_p521_from_montgomery((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } -CAMLprim value mc_p521_to_montgomery(value x) +CAMLprim value mc_p521_to_montgomery(value out, value in) { - CAMLparam1(x); - WORD *l = Caml_ba_data_val(x); - fiat_p521_to_montgomery(l, l); + CAMLparam2(out, in); + fiat_p521_to_montgomery((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } CAMLprim value mc_p521_nz(value x) { CAMLparam1(x); - CAMLreturn(Val_bool(fe_nz(Caml_ba_data_val(x)))); + CAMLreturn(Val_bool(fe_nz((const WORD*)String_val(x)))); } CAMLprim value mc_p521_set_one(value x) { CAMLparam1(x); - fiat_p521_set_one(Caml_ba_data_val(x)); + fiat_p521_set_one((WORD*)Bytes_val(x)); CAMLreturn(Val_unit); } CAMLprim value mc_p521_inv(value out, value in) { CAMLparam2(out, in); - inversion(Caml_ba_data_val(out), Caml_ba_data_val(in)); + inversion((WORD*)Bytes_val(out), (const WORD*)String_val(in)); CAMLreturn(Val_unit); } @@ -102,12 +100,12 @@ CAMLprim value mc_p521_point_double(value out, value in) { CAMLparam2(out, in); point_double( - Caml_ba_data_val(Field(out, 0)), - Caml_ba_data_val(Field(out, 1)), - Caml_ba_data_val(Field(out, 2)), - Caml_ba_data_val(Field(in, 0)), - Caml_ba_data_val(Field(in, 1)), - Caml_ba_data_val(Field(in, 2)) + (WORD*)Bytes_val(Field(out, 0)), + (WORD*)Bytes_val(Field(out, 1)), + (WORD*)Bytes_val(Field(out, 2)), + (const WORD*)String_val(Field(in, 0)), + (const WORD*)String_val(Field(in, 1)), + (const WORD*)String_val(Field(in, 2)) ); CAMLreturn(Val_unit); } @@ -116,16 +114,16 @@ CAMLprim value mc_p521_point_add(value out, value p, value q) { CAMLparam3(out, p, q); point_add( - Caml_ba_data_val(Field(out, 0)), - Caml_ba_data_val(Field(out, 1)), - Caml_ba_data_val(Field(out, 2)), - Caml_ba_data_val(Field(p, 0)), - Caml_ba_data_val(Field(p, 1)), - Caml_ba_data_val(Field(p, 2)), + (WORD*)Bytes_val(Field(out, 0)), + (WORD*)Bytes_val(Field(out, 1)), + (WORD*)Bytes_val(Field(out, 2)), + (const WORD*)String_val(Field(p, 0)), + (const WORD*)String_val(Field(p, 1)), + (const WORD*)String_val(Field(p, 2)), 0, - Caml_ba_data_val(Field(q, 0)), - Caml_ba_data_val(Field(q, 1)), - Caml_ba_data_val(Field(q, 2)) + (const WORD*)String_val(Field(q, 0)), + (const WORD*)String_val(Field(q, 1)), + (const WORD*)String_val(Field(q, 2)) ); CAMLreturn(Val_unit); } @@ -134,10 +132,10 @@ CAMLprim value mc_p521_select(value out, value bit, value t, value f) { CAMLparam4(out, bit, t, f); fe_cmovznz( - Caml_ba_data_val(out), + (WORD*)Bytes_val(out), Bool_val(bit), - Caml_ba_data_val(f), - Caml_ba_data_val(t) + (const WORD*)String_val(f), + (const WORD*)String_val(t) ); CAMLreturn(Val_unit); } diff --git a/src/native/mirage_crypto.h b/src/native/mirage_crypto.h index 22cc7af6..ad3e2855 100644 --- a/src/native/mirage_crypto.h +++ b/src/native/mirage_crypto.h @@ -60,6 +60,8 @@ extern struct _mc_cpu_features mc_detected_cpu_features; #endif #define __unit() value __unused(_) +#define _st_uint8(v) ((const uint8_t*) (String_val(v))) + #define _ba_uint8_off(ba, off) ((uint8_t*) Caml_ba_data_val (ba) + Long_val (off)) #define _ba_uint32_off(ba, off) ((uint32_t*) Caml_ba_data_val (ba) + Long_val (off))