Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 88 additions & 18 deletions ac_simplifier/elpi/theories/AC.v
Original file line number Diff line number Diff line change
@@ -1,24 +1,12 @@
From Coq Require Import ZArith.

Fixpoint nth {T : Type} (x0 : T) (s : list T) (n : nat) {struct n} : T :=
match s, n with
| cons x s', S n' => nth x0 s' n'
| cons x _, 0 => x
| _, _ => x0
end.

Fixpoint ncons {T : Type} (n : nat) (x : T) (s : list T) : list T :=
(* ncons n x s = x :: ... :: x :: s where x is repeated for n times *)
Fixpoint ncons {T : Type} (n : nat) (x : T) (s : list T) : list T :=
match n with
| S n' => cons x (ncons n' x s)
| 0 => s
end.

Fixpoint all {T : Type} (p : T -> bool) (xs : list T) : bool :=
match xs with
| cons x xs' => p x && all p xs'
| _ => true
end.

(* reified additive group expressions *)
Inductive AGExpr : Set :=
| AGX : nat -> AGExpr (* variable *)
Expand All @@ -43,34 +31,53 @@ Fixpoint ZMnorm (e : AGExpr) : list Z :=

Section AGeval.

(* We assume the carrier type and the additive group operators on it as *)
(* section variables. *)
Context {G : Type} (zeroG : G) (oppG : G -> G) (addG : G -> G -> G).

(* the interpretation function for AGExpr *)
Fixpoint AGeval (vm : list G) (e : AGExpr) : G :=
match e with
| AGX j => nth zeroG vm j
| AGX j => List.nth j vm zeroG
| AGO => zeroG
| AGOpp e1 => oppG (AGeval vm e1)
| AGAdd e1 e2 => addG (AGeval vm e1) (AGeval vm e2)
end.

(* multiplication of a group element by a binary integer *)
Definition mulGz (x : G) (n : Z) : G :=
match n with
| Z0 => zeroG
| Zpos p => Pos.iter (fun y => addG y x) zeroG p
| Zneg p => oppG (Pos.iter (fun y => addG y x) zeroG p)
end.

(* the interpretation function for formal sums *)
Fixpoint ZMsubst (vm : list G) (e : list Z) {struct e} : G :=
match e, vm with
| cons n e', cons x vm' => addG (mulGz x n) (ZMsubst vm' e')
| _, _ => zeroG
end.

(* an auxiliary function for the Z-module simplifier *)
Fixpoint ZMsimpl_aux (vm : list G) (e : list Z) {struct e} : list G * list G :=
match e, vm with
| cons Z0 e', cons _ vm' => ZMsimpl_aux vm' e'
| cons (Z.pos p) e', cons x vm' =>
let '(e1, e2) := ZMsimpl_aux vm' e' in (Pos.iter (cons x) e1 p, e2)
| cons (Z.neg p) e', cons x vm' =>
let '(e1, e2) := ZMsimpl_aux vm' e' in (e1, Pos.iter (cons x) e2 p)
| _, _ => (nil, nil)
end.

(* We assume the commutative group axioms on G and its operators as section *)
(* variables. *)
Context (addA : forall x y z : G, addG x (addG y z) = addG (addG x y) z).
Context (addC : forall x y : G, addG x y = addG y x).
Context (add0x : forall x : G, addG zeroG x = x).
Context (addNx : forall x : G, addG (oppG x) x = zeroG).

(* some facts about the commutative group structure *)
Let addx0 x : addG x zeroG = x.
Proof. now rewrite addC, add0x. Qed.

Expand Down Expand Up @@ -136,6 +143,7 @@ destruct n as [|n|n], m as [|m|m]; rewrite ?add0x, ?addx0; trivial.
now rewrite <- Hpos, oppD.
Qed.

(* the correctness lemma for normalization *)
Lemma ZM_norm_subst (vm : list G) (e : AGExpr) :
ZMsubst vm (ZMnorm e) = AGeval vm e.
Proof.
Expand All @@ -155,27 +163,89 @@ induction e as [j| |e IHe|e1 IHe1 e2 IHe2]; simpl; trivial.
now rewrite addACA, <-IHxs, mulzDl.
Qed.

(* the reflection lemma for proving valid commutative group equations *)
Lemma ZM_correct (vm : list G) (e1 e2 : AGExpr) :
let isZero (n : Z) := match n with Z0 => true | _ => false end in
all isZero (ZMnorm (AGAdd e1 (AGOpp e2))) = true ->
List.forallb isZero (ZMnorm (AGAdd e1 (AGOpp e2))) = true ->
AGeval vm e1 = AGeval vm e2.
Proof.
set (e := AGAdd e1 (AGOpp e2)); intros isZero Hzeros.
rewrite <- (addx0 (AGeval vm e1)), <- (add0x (AGeval vm e2)).
rewrite <- (addNx (AGeval vm e2)), addA at 1; f_equal.
change (addG (AGeval vm e1) (oppG (AGeval vm e2))) with (AGeval vm e).
rewrite <- !ZM_norm_subst; revert vm Hzeros.
rewrite <- ZM_norm_subst; revert vm Hzeros.
induction (ZMnorm e) as [|x xs IHxs]; destruct vm as [|v vm]; simpl; trivial.
now destruct x; try discriminate; rewrite add0x; apply IHxs.
Qed.

(* the reflection lemma for simplifying commutative group equations *)
Lemma ZMsimpl_correct (vm : list G) (e1 e2 : AGExpr) :
let sum zero add zs :=
match zs with cons z zs => List.fold_left add zs z | nil => zero end
in
(forall zero add vm', zero = zeroG -> add = addG -> vm' = vm ->
let norm := ZMnorm (AGAdd e1 (AGOpp e2)) in
let '(xs, ys) := ZMsimpl_aux vm' norm in
sum zero add xs = sum zero add ys) ->
Comment on lines +186 to +189
Copy link
Collaborator Author

@pi8027 pi8027 Jun 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the proof obligation of the Z_zmodule_simplify tactic: zero, add, and vm' are abstracted out to make sure that they will not be reduced. I thought that norm should be reduced by vm_compute, and the other part should be reduced by lazy or compute because its readback seems to be costly.

It is quite intricate, so should be documented in detail.

AGeval vm e1 = AGeval vm e2.
Proof.
set (e := AGAdd e1 (AGOpp e2)).
intros sum H; generalize (H _ _ _ eq_refl eq_refl eq_refl); clear H; cbv zeta.
case_eq (ZMsimpl_aux vm (ZMnorm e)); intros xs ys simplE Hsum.
rewrite <- (addx0 (AGeval vm e1)), <- (add0x (AGeval vm e2)).
rewrite <- (addNx (AGeval vm e2)), addA at 1; f_equal.
change (addG (AGeval vm e1) (oppG (AGeval vm e2))) with (AGeval vm e).
rewrite <- ZM_norm_subst.
replace zeroG with (addG (sum zeroG addG xs) (oppG (sum zeroG addG ys)))
by now rewrite Hsum; rewrite addxN.
assert (sumE : forall zs, sum zeroG addG zs = List.fold_right addG zeroG zs).
destruct zs as [|z zs]; simpl; trivial.
revert z; induction zs as [|z' zs IHzs]; intro z; simpl.
now rewrite addx0.
now rewrite IHzs, addA.
rewrite !sumE; clear sum Hsum sumE; revert vm xs ys simplE.
induction (ZMnorm e) as [|z zs IHzs]; destruct vm as [|v vm]; intros xs ys.
- now intro H; injection H; intros; subst; simpl; rewrite addxN.
- now intro H; injection H; intros; subst; simpl; rewrite addxN.
- now destruct z; intro H; injection H; intros; subst; simpl; rewrite addxN.
- destruct z as [|p|p]; simpl; [rewrite add0x; apply IHzs| |].
generalize (IHzs vm); clear IHzs.
destruct (ZMsimpl_aux vm zs) as [xs' ys']; intro IHzs.
rewrite (IHzs xs' ys' eq_refl); intro H; injection H; clear H IHzs.
intros; subst xs ys; rewrite addA, !Pos2Nat.inj_iter; f_equal.
induction (Pos.to_nat p) as [|n IHn]; simpl; [now rewrite add0x |].
now rewrite addAC, IHn, (addC v).
generalize (IHzs vm); clear IHzs.
destruct (ZMsimpl_aux vm zs) as [xs' ys']; intro IHzs.
rewrite (IHzs xs' ys' eq_refl); intro H; injection H; clear H IHzs.
intros; subst xs ys; rewrite addCA, <- oppD, !Pos2Nat.inj_iter.
f_equal; f_equal.
induction (Pos.to_nat p) as [|n IHn]; simpl; [now rewrite add0x |].
now rewrite addAC, IHn, (addC v).
Qed.

End AGeval.

(* the reflection lemmas specialized to binary integers Z *)
Fact ZM_correct_Z (vm : list Z) (e1 e2 : AGExpr) :
all (Z.eqb 0) (ZMnorm (AGAdd e1 (AGOpp e2))) = true ->
List.forallb (Z.eqb 0) (ZMnorm (AGAdd e1 (AGOpp e2))) = true ->
AGeval Z0 Z.opp Z.add vm e1 = AGeval Z0 Z.opp Z.add vm e2.
Proof.
now apply (ZM_correct _ _ _ Z.add_assoc Z.add_comm Z.add_0_l Z.add_opp_diag_l).
Qed.

Lemma ZMsimpl_correct_Z (vm : list Z) (e1 e2 : AGExpr) :
let sum zero add zs :=
match zs with cons z zs => List.fold_left add zs z | nil => zero end
in
(forall zero add vm', zero = Z0 -> add = Z.add -> vm' = vm ->
let norm := ZMnorm (AGAdd e1 (AGOpp e2)) in
let '(xs, ys) := ZMsimpl_aux vm' norm in
sum zero add xs = sum zero add ys) ->
AGeval Z0 Z.opp Z.add vm e1 = AGeval Z0 Z.opp Z.add vm e2.
Proof.
now apply
(ZMsimpl_correct _ _ _ Z.add_assoc Z.add_comm Z.add_0_l Z.add_opp_diag_l).
Qed.

Strategy expand [AGeval].
72 changes: 62 additions & 10 deletions ac_simplifier/elpi/theories/Tactic.v
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,10 @@ From elpi Require Import elpi.
From AC Require Import AC.

(******************************************************************************)
(* The Z_zmodule tactic *)
(* Tactics specific to Z *)
(******************************************************************************)

Ltac Z_zmodule_reflection VM ZE1 ZE2 :=
apply (@ZM_correct_Z VM ZE1 ZE2); [vm_compute; reflexivity].

Elpi Tactic Z_zmodule.
Elpi Accumulate lp:{{

pred mem o:list term, o:term, o:term.
mem [X|_] X {{ O }} :- !.
mem [_|XS] X {{ S lp:N }} :- !, mem XS X N.
Elpi Db Z_reify lp:{{

pred quote i:term, o:term, o:list term.
quote {{ Z0 }} {{ AGO }} _ :- !.
Expand All @@ -23,10 +15,25 @@ quote {{ Z.add lp:In1 lp:In2 }} {{ AGAdd lp:Out1 lp:Out2 }} VM :- !,
quote In1 Out1 VM, quote In2 Out2 VM.
quote In {{ AGX lp:N }} VM :- !, mem VM In N.

pred mem o:list term, o:term, o:term.
mem [X|_] X {{ O }} :- !.
mem [_|XS] X {{ S lp:N }} :- !, mem XS X N.

pred list-constant o:term, o:list term, o:term.
list-constant T [] {{ @nil lp:T }} :- !.
list-constant T [X|XS] {{ @cons lp:T lp:X lp:XS' }} :- list-constant T XS XS'.

}}.

(* The Z_zmodule tactic *)

Ltac Z_zmodule_reflection VM ZE1 ZE2 :=
apply (@ZM_correct_Z VM ZE1 ZE2); [vm_compute; reflexivity].

Elpi Tactic Z_zmodule.
Elpi Accumulate Db Z_reify.
Elpi Accumulate lp:{{

pred solve i:goal, o:list sealed-goal.
solve (goal _ _ {{ @eq Z lp:T1 lp:T2 }} _ _ as Goal) GS :- !,
quote T1 ZE1 VM, !,
Expand All @@ -45,3 +52,48 @@ zmod-reflection _ _ _ _ _ :-
Elpi Typecheck.

Tactic Notation "Z_zmodule" := (elpi Z_zmodule).

(* The Z_zmodule_simplify tactic *)

Ltac Z_zmodule_simplify_reflection VM ZE1 ZE2 :=
let zero := fresh "zero" in
let add := fresh "add" in
let vm := fresh "vm" in
let zeroE := fresh "zeroE" in
let addE := fresh "addE" in
let vmE := fresh "vmE" in
let norm := fresh "norm" in
apply (@ZMsimpl_correct_Z VM ZE1 ZE2);
intros zero add vm zeroE addE vmE norm;
vm_compute in norm; compute;
rewrite zeroE, addE, vmE; clear zero add vm zeroE addE vmE norm.
Comment on lines +67 to +69
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And this is the normalization of the proof obligation.


Elpi Tactic Z_zmodule_simplify.
Elpi Accumulate Db Z_reify.
Elpi Accumulate lp:{{

pred solve i:goal, o:list sealed-goal.
solve (goal _ _ {{ @eq Z lp:T1 lp:T2 }} _ _ as Goal) GS :- !,
quote T1 ZE1 VM, !,
quote T2 ZE2 VM, !,
list-constant {{ Z }} VM VM', !,
zmod-simpl-reflection VM' ZE1 ZE2 Goal GS.
solve _ _ :- coq.ltac.fail 0 "The goal is not an equation".

pred zmod-simpl-reflection i:term, i:term, i:term, i:goal, o:list sealed-goal.
zmod-simpl-reflection VM ZE1 ZE2 G GS :-
coq.ltac.call "Z_zmodule_simplify_reflection" [trm VM, trm ZE1, trm ZE2] G GS.
zmod-simpl-reflection _ _ _ _ _ :-
coq.ltac.fail 0 "Reflection failed".

}}.
Elpi Typecheck.

Tactic Notation "Z_zmodule_simplify" := (elpi Z_zmodule_simplify).

Goal forall x y y' z, y = y' -> (x + y + - z + y = x + - z + y + y')%Z.
Proof.
intros x y y' z Hy.
Z_zmodule_simplify.
exact Hy.
Qed.
6 changes: 6 additions & 0 deletions ac_simplifier/elpi/theories/Tests.v
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,9 @@ Proof.
intros x y z.
Z_zmodule.
Qed.

Goal forall x y y' z, y = y' -> (x + y + - z + y = x + - z + y + y')%Z.
Proof.
intros x y y' z Hy.
Z_zmodule_simplify; exact Hy.
Qed.