Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix DCE/Subst01 to work under lambdas #1809

Merged
merged 1 commit into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions src/BoundsPipeline.v
Original file line number Diff line number Diff line change
Expand Up @@ -660,12 +660,6 @@ Module Pipeline.
(E : Expr t)
: DebugM (Expr t)
:= (E <- DoRewrite E;
(* Note that DCE evaluates the expr with two different [var]
arguments, and so results in a pipeline that is 2x slower
unless we pass through a uniformly concrete [var] type
first *)
dlet_nd e := ToFlat E in
let E := FromFlat e in
E <- if with_subst01 return DebugM (Expr t)
then wrap_debug_rewrite ("subst01 for " ++ descr) (Subst01.Subst01 ident.is_comment) E
else if with_dead_code_elimination return DebugM (Expr t)
Expand All @@ -675,8 +669,6 @@ Module Pipeline.
then wrap_debug_rewrite ("LetBindReturn for " ++ descr) (UnderLets.LetBindReturn (@ident.is_var_like)) E
else Debug.ret E;
E <- DoRewrite E; (* after inlining, see if any new rewrite redexes are available *)
dlet_nd e := ToFlat E in
let E := FromFlat e in
E <- if with_dead_code_elimination
then wrap_debug_rewrite ("DCE after " ++ descr) (DeadCodeElimination.EliminateDead ident.is_comment) E
else Debug.ret E;
Expand Down Expand Up @@ -1150,7 +1142,7 @@ Module Pipeline.
first [ progress destruct_head'_and
| progress cbv [Classes.base Classes.ident Classes.ident_interp Classes.base_interp Classes.exprInfo] in *
| progress intros
| progress rewrite_strat repeat topdown hints interp
| progress rewrite_strat repeat topdown choice (hints interp_extra) (hints interp)
| solve [ typeclasses eauto with nocore interp_extra wf_extra ]
| solve [ typeclasses eauto ]
| break_innermost_match_step
Expand Down
70 changes: 70 additions & 0 deletions src/Language/TreeCaching.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
(** * Tree Caching for PHOAS Expressions *)
(** The naive encoding of PHOAS passes that need to produce both
[expr]-like output and data-like output simultaneously involves
exponential blowup.

This file allows caching of results (and/or intermediates) of a
data-producing PHOAS pass in a tree structure that mimics the
PHOAS expression so that a subsequent pass can consume this tree
and a PHOAS expression to produce a new expression.

More concretely, suppose we are trying to write a pass that is
[expr var1 * expr var2 -> A * expr var3]. We can define an
[expr]-like-tree-structure that (a) doesn't use higher-order
things for [Abs] nodes, and (b) stores [A] at every node. Then we
can write a pass that is [expr var1 * expr var2 -> A * tree-of-A]
and then [expr var1 * expr var2 * tree-of-A -> expr var3] such
that we incur only linear overhead.

See also
%\href{https://github.com/mit-plv/fiat-crypto/issues/1604#issuecomment-1553341559}{mit-plv/fiat-crypto\#1604 with option (2)}%
#<a href=https://github.com/mit-plv/fiat-crypto/issues/1604##issuecomment-1553341559">mit-plv/fiat-crypto##1604 with option (2)</a>#
and
%\href{https://github.com/mit-plv/fiat-crypto/issues/1761}{mit-plv/fiat-crypto\#1761}%
#<a href=https://github.com/mit-plv/fiat-crypto/issues/1761#">mit-plv/fiat-crypto##1761</a>#. *)

Require Import Rewriter.Language.Language.

Module Compilers.
Export Language.Compilers.
Local Set Boolean Equality Schemes.
Local Set Decidable Equality Schemes.

Module tree_nd.
Section with_result.
Context {base_type : Type}.
Local Notation type := (type base_type).
Context {ident : type -> Type}
{result : Type}.
Local Notation expr := (@expr.expr base_type ident).

Inductive tree : Type :=
| Ident (r : result) : tree
| Var (r : result) : tree
| Abs (r : result) (f : option tree) : tree
| App (r : result) (f : option tree) (x : option tree) : tree
| LetIn (r : result) (x : option tree) (f : option tree) : tree
.
End with_result.
Global Arguments tree result : clear implicits, assert.
End tree_nd.

Module tree.
Section with_result.
Context {base_type : Type}.
Local Notation type := (type base_type).
Context {ident : type -> Type}
{result : type -> Type}.
Local Notation expr := (@expr.expr base_type ident).

Inductive tree : type -> Type :=
| Ident {t} (r : result t) : tree t
| Var {t} (r : result t) : tree t
| Abs {s d} (r : result (s -> d)) (f : option (tree d)) : tree (s -> d)
| App {s d} (r : result d) (f : option (tree (s -> d))) (x : option (tree s)) : tree d
| LetIn {A B} (r : result B) (x : option (tree A)) (f : option (tree B)) : tree B
.
End with_result.
Global Arguments tree {base_type} {result} t, {base_type} result t : assert.
End tree.
End Compilers.
147 changes: 92 additions & 55 deletions src/MiscCompilerPasses.v
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ Require Import Coq.MSets.MSetPositive.
Require Import Coq.FSets.FMapPositive.
Require Import Crypto.Util.ListUtil Coq.Lists.List.
Require Import Rewriter.Language.Language.
Require Import Crypto.Util.LetIn.
Require Import Crypto.Util.Notations.
Require Import Crypto.Language.TreeCaching.
Import ListNotations. Local Open Scope Z_scope.

Module Compilers.
Export Language.Compilers.
Export Language.TreeCaching.Compilers.
Import invert_expr.

Module Subst01.
Expand All @@ -33,6 +36,8 @@ Module Compilers.
(** some identifiers, like [comment], might always be live *)
(is_ident_always_live : forall t, ident t -> bool).
Local Notation expr := (@expr.expr base_type ident).
(* [option t] is "is the let-in here live?", meaningless elsewhere; the thunk is for debugging *)
Local Notation tree := (@tree_nd.tree (option t * (unit -> positive * list (positive * t)))).
(** N.B. This does not work well when let-binders are not at top-level *)
Fixpoint contains_always_live_ident {var} (dummy : forall t, var t) {t} (e : @expr var t)
: bool
Expand All @@ -46,28 +51,39 @@ Module Compilers.
| expr.LetIn tx tC ex eC
=> contains_always_live_ident dummy ex || contains_always_live_ident dummy (eC (dummy _))
end%bool.
Definition meaningless : option t * (unit -> positive * list (positive * t)) := (None, (fun 'tt => (1%positive, []%list))).
Global Opaque meaningless.
Fixpoint compute_live_counts' {t} (e : @expr (fun _ => positive) t) (cur_idx : positive) (live : PositiveMap.t _)
: positive * PositiveMap.t _
: positive * PositiveMap.t _ * option tree
:= match e with
| expr.Var t v => (cur_idx, PositiveMap_incr v live)
| expr.Ident t idc => (cur_idx, live)
| expr.Var t v
=> let '(idx, live) := (cur_idx, PositiveMap_incr v live) in
(idx, live, Some (tree_nd.Var meaningless))
| expr.Ident t idc
=> let '(idx, live) := (cur_idx, live) in
(idx, live, Some (tree_nd.Ident meaningless))
| expr.App s d f x
=> let '(idx, live) := @compute_live_counts' _ f cur_idx live in
let '(idx, live) := @compute_live_counts' _ x idx live in
(idx, live)
=> let '(idx, live, f_tree) := @compute_live_counts' _ f cur_idx live in
let '(idx, live, x_tree) := @compute_live_counts' _ x idx live in
(idx, live, Some (tree_nd.App meaningless f_tree x_tree))
| expr.Abs s d f
=> let '(idx, live) := @compute_live_counts' _ (f cur_idx) (Pos.succ cur_idx) live in
(cur_idx, live)
=> let '(idx, live, f_tree) := @compute_live_counts' _ (f cur_idx) (Pos.succ cur_idx) live in
(idx, live, Some (tree_nd.Abs meaningless f_tree))
| expr.LetIn tx tC ex eC
=> let '(idx, live) := @compute_live_counts' tC (eC cur_idx) (Pos.succ cur_idx) live in
=> let '(idx, live, C_tree) := @compute_live_counts' tC (eC cur_idx) (Pos.succ cur_idx) live in
let live := if contains_always_live_ident (fun _ => cur_idx (* dummy *)) ex
then PositiveMap_incr_always_live cur_idx live
else live in
if PositiveMap.mem cur_idx live
then @compute_live_counts' tx ex idx live
else (idx, live)
let debug_info := fun 'tt => (Pos.succ cur_idx, PositiveMap.elements live) in
match PositiveMap.find cur_idx live with
| Some x_count
=> let '(x_idx, x_live, x_tree) := @compute_live_counts' tx ex idx live in
(x_idx, x_live, Some (tree_nd.LetIn (Some x_count, debug_info) x_tree C_tree))
| None
=> (idx, live, Some (tree_nd.LetIn (None, debug_info) None C_tree))
end
end%bool.
Definition compute_live_counts {t} e : PositiveMap.t _ := snd (@compute_live_counts' t e 1 (PositiveMap.empty _)).
Definition compute_live_counts {t} e : option tree := snd (@compute_live_counts' t e 1 (PositiveMap.empty _)).
Definition ComputeLiveCounts {t} (e : expr.Expr t) := compute_live_counts (e _).

Section with_var.
Expand All @@ -79,36 +95,61 @@ Module Compilers.
in extraction *)
Context (doing_subst_debug : forall T1 T2, T1 -> (unit -> T2) -> T1)
{var : type -> Type}
(should_subst : t -> bool)
(live : PositiveMap.t t).
Fixpoint subst0n' {t} (e : @expr (@expr var) t) (cur_idx : positive)
: positive * @expr var t
(should_subst : t -> bool).
(** When [live] is [None], we don't inline anything, just
dropping [var]. This is required for preventing blowup
in inlining lets in unused [LetIn]-bound expressions.
*)
Fixpoint subst0n (live : option tree) {t} (e : @expr (@expr var) t)
: @expr var t
:= match e with
| expr.Var t v => (cur_idx, v)
| expr.Ident t idc => (cur_idx, expr.Ident idc)
| expr.Var t v => v
| expr.Ident t idc => expr.Ident idc
| expr.App s d f x
=> let '(idx, f') := @subst0n' _ f cur_idx in
let '(idx, x') := @subst0n' _ x idx in
(idx, expr.App f' x')
=> let '(f_live, x_live)
:= match live with
| Some (tree_nd.App _ f_live x_live) => (f_live, x_live)
| _ => (None, None)
end%core in
let f' := @subst0n f_live _ f in
let x' := @subst0n x_live _ x in
expr.App f' x'
| expr.Abs s d f
=> (cur_idx, expr.Abs (fun v => snd (@subst0n' _ (f (expr.Var v)) (Pos.succ cur_idx))))
=> let f_tree
:= match live with
| Some (tree_nd.Abs _ f_tree) => f_tree
| _ => None
end in
expr.Abs (fun v => @subst0n f_tree _ (f (expr.Var v)))
| expr.LetIn tx tC ex eC
=> let '(idx, ex') := @subst0n' tx ex cur_idx in
let eC' := fun v => snd (@subst0n' tC (eC v) (Pos.succ cur_idx)) in
if match PositiveMap.find cur_idx live with
| Some n => should_subst n
| None => true
end
then (Pos.succ cur_idx, eC' (doing_subst_debug _ _ ex' (fun 'tt => (Pos.succ cur_idx, PositiveMap.elements live))))
else (Pos.succ cur_idx, expr.LetIn ex' (fun v => eC' (expr.Var v)))
=> match live with
| Some (tree_nd.LetIn (x_count, debug_info) x_tree C_tree)
=> let ex' := @subst0n x_tree tx ex in
let eC' := fun v => @subst0n C_tree tC (eC v) in
if match x_count with
| Some n => should_subst n
| None => true
end
then eC' (doing_subst_debug _ _ ex' debug_info)
else expr.LetIn ex' (fun v => eC' (expr.Var v))
| _
=> let ex' := @subst0n None tx ex in
let eC' := fun v => @subst0n None tC (eC v) in
expr.LetIn ex' (fun v => eC' (expr.Var v))
end
end.

Definition subst0n {t} e : expr t
:= snd (@subst0n' t e 1).
End with_var.

Definition Subst0n (doing_subst_debug : forall T1 T2, T1 -> (unit -> T2) -> T1) (should_subst : t -> bool) {t} (e : expr.Expr t) : expr.Expr t
:= fun var => subst0n doing_subst_debug should_subst (ComputeLiveCounts e) (e _).
Section with_transport.
Context {try_make_transport_base_type_cps : @type.try_make_transport_cpsT base_type}
{exprDefault : forall var, @DefaultValue.type.base.DefaultT type (@expr var)}.
(** We pass through [Flat] to ensure that the passed in
[Expr] only gets invoked at a single [var] type *)
Definition Subst0n (doing_subst_debug : forall T1 T2, T1 -> (unit -> T2) -> T1) (should_subst : t -> bool) {t} (E : expr.Expr t) : expr.Expr t
:= dlet_nd e := GeneralizeVar.ToFlat E in
let E := GeneralizeVar.FromFlat e in
fun var => subst0n doing_subst_debug should_subst (ComputeLiveCounts E) (E _).
End with_transport.
End with_ident.
End with_counter.

Expand All @@ -122,34 +163,30 @@ Module Compilers.
| more => false
end.

Definition Subst01 {base_type ident} (is_ident_always_live : forall t, ident t -> bool) {t} (e : expr.Expr t) : expr.Expr t
:= @Subst0n _ one incr (fun _ => more) base_type ident is_ident_always_live (fun _ _ x _ => x) should_subst t e.
Definition Subst01
{base_type ident}
{try_make_transport_base_type_cps : @type.try_make_transport_cpsT base_type}
{exprDefault : forall var, @DefaultValue.type.base.DefaultT _ _}
(is_ident_always_live : forall t, ident t -> bool)
{t} (e : expr.Expr t) : expr.Expr t
:= @Subst0n _ one incr (fun _ => more) base_type ident is_ident_always_live try_make_transport_base_type_cps exprDefault (fun _ _ x _ => x) should_subst t e.
End for_01.
End Subst01.

Module DeadCodeElimination.
Section with_ident.
Context {base_type : Type}.
Local Notation type := (type.type base_type).
Context {ident : type -> Type}
(is_ident_always_live : forall t, ident t -> bool).
Local Notation expr := (@expr.expr base_type ident).

Definition OUGHT_TO_BE_UNUSED {T1 T2} (v : T1) (v' : T2) := v.
Global Opaque OUGHT_TO_BE_UNUSED.

Definition ComputeLive {t} (e : expr.Expr t) : PositiveMap.t unit
:= @Subst01.ComputeLiveCounts unit tt (fun _ => tt) (fun _ => tt) base_type ident is_ident_always_live _ e.
Definition is_live (map : PositiveMap.t unit) (idx : positive) : bool
:= match PositiveMap.find idx map with
| Some tt => true
| None => false
end.
Definition is_dead (map : PositiveMap.t unit) (idx : positive) : bool
:= negb (is_live map idx).

Definition EliminateDead {t} (e : expr.Expr t) : expr.Expr t
:= @Subst01.Subst0n unit tt (fun _ => tt) (fun _ => tt) base_type ident is_ident_always_live (fun T1 T2 => OUGHT_TO_BE_UNUSED) (fun _ => false) t e.
Definition EliminateDead
{ident : type -> Type}
{try_make_transport_base_type_cps : @type.try_make_transport_cpsT base_type}
{exprDefault : forall var, @DefaultValue.type.base.DefaultT _ _}
(is_ident_always_live : forall t, ident t -> bool)
{t} (e : expr.Expr t)
: expr.Expr t
:= @Subst01.Subst0n unit tt (fun _ => tt) (fun _ => tt) base_type ident is_ident_always_live try_make_transport_base_type_cps exprDefault (fun T1 T2 => OUGHT_TO_BE_UNUSED) (fun _ => false) t e.
End with_ident.
End DeadCodeElimination.
End Compilers.
Loading
Loading