Skip to content

Commit 5b6487a

Browse files
committed
ec: expose low-level point arithmetic
1 parent 0624a80 commit 5b6487a

File tree

3 files changed

+148
-0
lines changed

3 files changed

+148
-0
lines changed

ec/mirage_crypto_ec.ml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,17 @@ end
6868
module type Dh_dsa = sig
6969
module Dh : Dh
7070
module Dsa : Dsa
71+
module Point : sig
72+
type point
73+
type scalar
74+
val of_octets : string -> (point, error) result
75+
val to_octets : ?compress:bool -> point -> string
76+
val scalar_of_octets : string -> (scalar, error) result
77+
val scalar_to_octets : scalar -> string
78+
val generator : point
79+
val add : point -> point -> point
80+
val scalar_mult : scalar -> point -> point
81+
end
7182
end
7283

7384
type field_element = string
@@ -774,6 +785,18 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Dige
774785
end
775786
end
776787

788+
module Point (P : Point) (S : Scalar) = struct
789+
type nonrec point = point
790+
type nonrec scalar = scalar
791+
let of_octets = P.of_octets
792+
let to_octets ?(compress = false) p = P.to_octets ~compress p
793+
let scalar_of_octets = S.of_octets
794+
let scalar_to_octets = S.to_octets
795+
let generator = P.params_g
796+
let add = P.add
797+
let scalar_mult = S.scalar_mult
798+
end
799+
777800
module P256 : Dh_dsa = struct
778801
module Params = struct
779802
let a = "\xFF\xFF\xFF\xFF\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFC"
@@ -823,6 +846,7 @@ module P256 : Dh_dsa = struct
823846
module Dh = Make_dh(Params)(P)(S)
824847
module Fn = Make_Fn(Params)(Foreign_n)
825848
module Dsa = Make_dsa(Params)(Fn)(P)(S)(Digestif.SHA256)
849+
module Point = Point(P)(S)
826850
end
827851

828852
module P384 : Dh_dsa = struct
@@ -875,6 +899,7 @@ module P384 : Dh_dsa = struct
875899
module Dh = Make_dh(Params)(P)(S)
876900
module Fn = Make_Fn(Params)(Foreign_n)
877901
module Dsa = Make_dsa(Params)(Fn)(P)(S)(Digestif.SHA384)
902+
module Point = Point(P)(S)
878903
end
879904

880905
module P521 : Dh_dsa = struct
@@ -928,6 +953,7 @@ module P521 : Dh_dsa = struct
928953
module Dh = Make_dh(Params)(P)(S)
929954
module Fn = Make_Fn(Params)(Foreign_n)
930955
module Dsa = Make_dsa(Params)(Fn)(P)(S)(Digestif.SHA512)
956+
module Point = Point(P)(S)
931957
end
932958

933959
module X25519 = struct

ec/mirage_crypto_ec.mli

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,40 @@ module type Dh_dsa = sig
163163

164164
(** Digital signature algorithm. *)
165165
module Dsa : Dsa
166+
167+
(** Low-level point arithmetic. *)
168+
module Point : sig
169+
type point
170+
(** The type for points on the elliptic curve. *)
171+
172+
type scalar
173+
(** The type for scalars. *)
174+
175+
val of_octets : string -> (point, error) result
176+
(** [of_octets buf] decodes a point from [buf] in uncompressed or compressed
177+
SEC 1 format. Returns an error if the point is not on the curve. *)
178+
179+
val to_octets : ?compress:bool -> point -> string
180+
(** [to_octets ~compress point] encodes [point] to SEC 1 format. If
181+
[compress] is [true] (default [false]), the compressed format is used. *)
182+
183+
val scalar_of_octets : string -> (scalar, error) result
184+
(** [scalar_of_octets buf] decodes a scalar from [buf]. Returns an error if
185+
the scalar is not in the valid range \[1, n-1\] where n is the group
186+
order. *)
187+
188+
val scalar_to_octets : scalar -> string
189+
(** [scalar_to_octets scalar] encodes [scalar] to a byte string. *)
190+
191+
val generator : point
192+
(** [generator] is the generator point (base point) of the curve. *)
193+
194+
val add : point -> point -> point
195+
(** [add p q] is the sum of points [p] and [q]. *)
196+
197+
val scalar_mult : scalar -> point -> point
198+
(** [scalar_mult s p] is the scalar multiplication of [p] by [s]. *)
199+
end
166200
end
167201

168202
(** The NIST P-256 curve, also known as SECP256R1. *)

tests/test_ec.ml

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,91 @@ let ed25519 =
803803
|};
804804
]
805805

806+
let point_module_tests (module C : Mirage_crypto_ec.Dh_dsa) name =
807+
let open C in
808+
let test_generator_not_identity () =
809+
(* Generator should not be the identity (at infinity) *)
810+
let g = Point.generator in
811+
let g_bytes = Point.to_octets g in
812+
(* Generator serialized should not be just the identity point *)
813+
Alcotest.(check bool) "generator has non-trivial encoding"
814+
true (String.length g_bytes > 1)
815+
in
816+
let test_point_serialization_roundtrip () =
817+
(* Generate a key pair and check that the public key roundtrips through Point *)
818+
let _priv, pub = Dsa.generate () in
819+
let pub_bytes = Dsa.pub_to_octets pub in
820+
match Point.of_octets pub_bytes with
821+
| Ok point ->
822+
let point_bytes = Point.to_octets point in
823+
Alcotest.(check string) "point roundtrip" pub_bytes point_bytes
824+
| Error e -> Alcotest.failf "of_octets failed: %a" pp_error e
825+
in
826+
let test_point_compressed_serialization () =
827+
let _priv, pub = Dsa.generate () in
828+
let pub_bytes = Dsa.pub_to_octets pub in
829+
match Point.of_octets pub_bytes with
830+
| Ok point ->
831+
let compressed = Point.to_octets ~compress:true point in
832+
(* Compressed form should be shorter *)
833+
Alcotest.(check bool) "compressed is shorter"
834+
true (String.length compressed < String.length pub_bytes);
835+
(* Should be able to decode compressed form *)
836+
(match Point.of_octets compressed with
837+
| Ok point' ->
838+
let uncompressed = Point.to_octets point' in
839+
Alcotest.(check string) "compressed roundtrip" pub_bytes uncompressed
840+
| Error e -> Alcotest.failf "compressed of_octets failed: %a" pp_error e)
841+
| Error e -> Alcotest.failf "of_octets failed: %a" pp_error e
842+
in
843+
let test_scalar_serialization_roundtrip () =
844+
(* Generate a key and check scalar roundtrip *)
845+
let secret, _pub = Dh.gen_key () in
846+
let secret_bytes = Dh.secret_to_octets secret in
847+
match Point.scalar_of_octets secret_bytes with
848+
| Ok scalar ->
849+
let scalar_bytes = Point.scalar_to_octets scalar in
850+
Alcotest.(check string) "scalar roundtrip" secret_bytes scalar_bytes
851+
| Error e -> Alcotest.failf "scalar_of_octets failed: %a" pp_error e
852+
in
853+
let test_scalar_mult_with_generator () =
854+
(* scalar_mult with generator should give the same result as pub_of_priv *)
855+
let priv, pub = Dsa.generate () in
856+
let priv_bytes = Dsa.priv_to_octets priv in
857+
let pub_bytes = Dsa.pub_to_octets pub in
858+
match Point.scalar_of_octets priv_bytes with
859+
| Ok scalar ->
860+
let computed_pub = Point.scalar_mult scalar Point.generator in
861+
let computed_bytes = Point.to_octets computed_pub in
862+
Alcotest.(check string) "scalar_mult generator" pub_bytes computed_bytes
863+
| Error e -> Alcotest.failf "scalar_of_octets failed: %a" pp_error e
864+
in
865+
let test_point_add () =
866+
(* Test that P + P = 2P (scalar_mult 2 P) *)
867+
let g = Point.generator in
868+
let g_plus_g = Point.add g g in
869+
(* scalar 2 in big-endian encoding *)
870+
let two =
871+
let buf = Bytes.make Dsa.byte_length '\000' in
872+
Bytes.set_uint8 buf (Dsa.byte_length - 1) 2;
873+
Bytes.to_string buf
874+
in
875+
match Point.scalar_of_octets two with
876+
| Ok scalar_2 ->
877+
let two_g = Point.scalar_mult scalar_2 g in
878+
Alcotest.(check string) "G + G = 2G"
879+
(Point.to_octets g_plus_g) (Point.to_octets two_g)
880+
| Error e -> Alcotest.failf "scalar_of_octets 2 failed: %a" pp_error e
881+
in
882+
[
883+
name ^ " Point generator", `Quick, test_generator_not_identity;
884+
name ^ " Point serialization roundtrip", `Quick, test_point_serialization_roundtrip;
885+
name ^ " Point compressed serialization", `Quick, test_point_compressed_serialization;
886+
name ^ " Scalar serialization roundtrip", `Quick, test_scalar_serialization_roundtrip;
887+
name ^ " scalar_mult with generator", `Quick, test_scalar_mult_with_generator;
888+
name ^ " Point add", `Quick, test_point_add;
889+
]
890+
806891
let p521_regression () =
807892
let key = of_hex
808893
"04 01 e4 f8 8a 40 3d fe 2f 65 a0 20 50 01 9b 87
@@ -853,4 +938,7 @@ let () =
853938
("X25519", [ "RFC 7748", `Quick, x25519 ]);
854939
("ED25519", ed25519);
855940
("ECDSA P521 regression", [ "regreesion1", `Quick, p521_regression ]);
941+
("P256 Point module", point_module_tests (module P256) "P256");
942+
("P384 Point module", point_module_tests (module P384) "P384");
943+
("P521 Point module", point_module_tests (module P521) "P521");
856944
]

0 commit comments

Comments
 (0)