Skip to content

Commit

Permalink
Add loop-unrolling debug rewrite for SystemVerilog (#744)
Browse files Browse the repository at this point in the history
* Add loop-unrolling debug rewrite for SystemVerilog

It can be useful for debug of SystemVerilog generation.
Loops being a relevant source of issues in the generation,
being able to replace loops with their expanded equivalent
helps us determine whether an issue is related to loops or not.

* Format with ocamlformat 0.26.0

* Replace print_endline with Reporting.warn for warning

* Add option to override warning supression on rewrites

---------

Co-authored-by: Nicolas Phan <nicolas.phan@codasip.com>
  • Loading branch information
NicolasVanPhan and Nicolas Phan authored Oct 29, 2024
1 parent b74c771 commit 40a6548
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 3 deletions.
5 changes: 5 additions & 0 deletions src/bin/sail.ml
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,11 @@ let rec options =
("-dno_cast", Arg.Unit (fun () -> ()), "");
(* No longer does anything, preserved for backwards compatibility only *)
("-dallow_cast", Arg.Unit (fun () -> ()), "");
("-unroll_loops", Arg.Set Rewrites.opt_unroll_loops, " turn on rewrites for unrolling loops with constant bounds.");
( "-unroll_loops_max_iter",
Arg.Int (fun n -> Rewrites.opt_unroll_loops_max_iter := n),
"<nb_iter> Don't unroll loops if they have more than <nb_iter> iterations."
);
( "-ddump_rewrite_ast",
Arg.String
(fun l ->
Expand Down
2 changes: 2 additions & 0 deletions src/lib/ast_util.mli
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,8 @@ val def_loc : ('a, 'b) def -> Parse_ast.l
Note: For debugging and error messages only - not guaranteed to
produce parseable Sail, or even print all language constructs! *)

val string_of_order : order -> string

val string_of_id : id -> string
val string_of_kid : kid -> string
val string_of_kind_aux : kind_aux -> string
Expand Down
4 changes: 2 additions & 2 deletions src/lib/reporting.ml
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ let suppressed_warning_info () =
suppressed_warnings := 0
)

let warn ?once_from short_str l explanation =
let warn ?once_from ?(force_show = false) short_str l explanation =
let already_shown =
match once_from with
| Some (file, lnum, cnum, enum) when not !opt_all_warnings ->
Expand All @@ -268,7 +268,7 @@ let warn ?once_from short_str l explanation =
)
| _ -> false
in
if !opt_warnings && not already_shown then (
if (!opt_warnings && not already_shown) || force_show then (
match simp_loc l with
| Some (p1, p2) when not (StringSet.mem p1.pos_fname !ignored_files) ->
let shorts = RangeMap.find_opt (p1, p2) !seen_warnings |> Option.value ~default:[] in
Expand Down
2 changes: 1 addition & 1 deletion src/lib/reporting.mli
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ val forbid_errors : string * int * int * int -> ('a -> 'b) -> 'a -> 'b

(** Print a warning message. The first string is printed before the
location, the second after. *)
val warn : ?once_from:string * int * int * int -> string -> Parse_ast.l -> string -> unit
val warn : ?once_from:string * int * int * int -> ?force_show:bool -> string -> Parse_ast.l -> string -> unit

val format_warn : ?once_from:string * int * int * int -> string -> Parse_ast.l -> Error_format.message -> unit

Expand Down
120 changes: 120 additions & 0 deletions src/lib/rewrites.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4186,6 +4186,125 @@ let rewrite_truncate_hex_literals _type_env defs =
{ rewriters_base with rewrite_exp = (fun _ -> fold_exp { id_exp_alg with e_aux = rewrite_aux }) }
defs

let opt_unroll_loops = ref false
let opt_unroll_loops_max_iter = ref 0

(** The loop unrolling pass replaces :
{[
foreach k in 0 to 3 by 1 increasing:
f(k, foo, bar) + k
]}
with
{[
f(0, foo, bar) + 0;
f(1, foo, bar) + 1;
f(2, foo, bar) + 2;
f(3, foo, bar) + 3;
}]
*)
let rewrite_unroll_constant_loops _type_env defs =
(* This pass replaces expressions like
f(k, foo, bar) + k
with
f(42, foo, bar) + 42
*)
let rewrite_exp_replace_id_with_num (id : string) (num : Z.t) =
let rewrite_aux (id : string) (num : Z.t) (e, annot) =
match e with
| E_id (Id_aux (Id v, _)) when String.equal v id -> E_aux (E_lit (L_aux (L_num num, Parse_ast.Unknown)), annot)
| _ -> E_aux (e, annot)
in
fun e ->
rewrite_exp
{ rewriters_base with rewrite_exp = (fun _ -> fold_exp { id_exp_alg with e_aux = rewrite_aux id num }) }
e
in

(* Builds a list of integers from a start, end, step and direction
[list_of_ord_range 0 10 2 inc = 0, 2, 4, 6, 8, 10]
[list_of_ord_range 7 0 1 dec = 7, 6, 5, 4, 3, 2, 1, 0]
*)
let list_of_ord_range (Ord_aux (ord, _)) (n_start : Z.t) (n_end : Z.t) (n_step : Z.t) : Z.t list =
let list_of_range ns ne =
let rec aux n acc = if Z.gt n ne then acc else aux (Z.add n n_step) (n :: acc) in
aux ns []
in
match ord with Ord_inc -> List.rev @@ list_of_range n_start n_end | Ord_dec -> list_of_range n_end n_start
in

(* This is the main rewrite function of the pass.
Replacing :
{[
foreach k in 0 to 3 by 1 increasing:
f(k, foo, bar) + k
]}
with
{[
f(0, foo, bar) + 0;
f(1, foo, bar) + 1;
f(2, foo, bar) + 2;
f(3, foo, bar) + 3;
]}
*)
let rewrite_aux (e, annot) =
match e with
| E_for
( id (* 'k' in our example *),
E_aux (_, (_, tannot1)) (* '0' in our example *),
E_aux (_, (_, tannot2)) (* '3' in our example *),
E_aux (_, (_, tannot3)) (* '1' in our example *),
atyp (* 'increasing' in our example *),
e_loop_body (* 'f(k, foo, bar) + k' in our example *)
) -> (
(* We get the int values of the bounds from their types inferred by the typer *)
let int_of_tannot_opt tannot =
let int_of_typ_aux_opt : typ_aux -> Z.t option = function
| Typ_app (id, [A_aux (A_nexp nexp, _)]) when Id.compare id (mk_id "atom") = 0 ->
int_of_nexp_opt @@ nexp_simp nexp
| _ -> None
in
match destruct_tannot tannot with Some (_, Typ_aux (typ, l)) -> int_of_typ_aux_opt typ | None -> None
in
let n_start_opt = int_of_tannot_opt tannot1 in
let n_end_opt = int_of_tannot_opt tannot2 in
let n_step_opt = int_of_tannot_opt tannot3 in

(* Abort unrolling with a warning when types infered are too complex (i.e. not immediate literals) *)
match (n_start_opt, n_end_opt, n_step_opt) with
| Some n_start, Some n_end, Some n_step ->
(* Build a range out of the bounds
in our example, from (start=0, end=3, step=1)
the range will be the list [0; 1; 2; 3]
*)
let range = list_of_ord_range atyp n_start n_end n_step in

(* Only unroll "small" loops, i.e. those with less than 'max_iter' iterations *)
if !opt_unroll_loops_max_iter <> 0 && List.length range > !opt_unroll_loops_max_iter then E_aux (e, annot)
else (
(* Build the final expression, a block of n times the body *)
let bodies = List.map (fun z -> rewrite_exp_replace_id_with_num "i" z e_loop_body) range in
E_aux (E_block bodies, annot) (* else *)
)
(* Some n_start... *)
| _ ->
let e' = E_aux (e, annot) in
let (l : Parse_ast.l), _tannot = annot in
Reporting.warn ~force_show:true "" l
@@ Printf.sprintf
"Cannot unroll the loop because the bounds numerical values couldn't be fully determined on \
expression :\n\
%s"
(string_of_exp e');
e'
)
| _ -> E_aux (e, annot)
in
rewrite_ast_base
{ rewriters_base with rewrite_exp = (fun _ -> fold_exp { id_exp_alg with e_aux = rewrite_aux }) }
defs

(** Remove bitfield records and turn them into plain bitvectors
This can improve performance for Isabelle, because processing record types is slow there
(and we don't gain much by having record types with just a `bits` field).
Expand Down Expand Up @@ -4418,6 +4537,7 @@ let all_rewriters =
("pat_string_append", basic_rewriter rewrite_ast_pat_string_append);
("mapping_patterns", basic_rewriter (fun _ -> Mappings.rewrite_ast));
("truncate_hex_literals", basic_rewriter rewrite_truncate_hex_literals);
("unroll_constant_loops", basic_rewriter rewrite_unroll_constant_loops);
("mono_rewrites", basic_rewriter mono_rewrites);
("complete_record_params", basic_rewriter rewrite_complete_record_params);
("toplevel_nexps", basic_rewriter rewrite_toplevel_nexps);
Expand Down
5 changes: 5 additions & 0 deletions src/lib/rewrites.mli
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ val opt_auto_mono : bool ref
val opt_dall_split_errors : bool ref
val opt_dmono_continue : bool ref

(** Unroll loops with constant bounds if less than 'max_iter' iterations *)
val opt_unroll_loops : bool ref

val opt_unroll_loops_max_iter : int ref

(** Warn about matches where we add a default case for Coq because
they're not exhaustive *)
val opt_coq_warn_nonexhaustive : bool ref
Expand Down
1 change: 1 addition & 0 deletions src/sail_sv_backend/sail_plugin_sv.ml
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ let verilog_rewrites =
("merge_function_clauses", []);
("recheck_defs", []);
("constant_fold", [String_arg "systemverilog"]);
("unroll_constant_loops", [If_flag opt_unroll_loops]);
]

module type JIB_CONFIG = sig
Expand Down

0 comments on commit 40a6548

Please sign in to comment.