diff --git a/dune-project b/dune-project index a8bc5091e4..7b35af2834 100644 --- a/dune-project +++ b/dune-project @@ -24,4 +24,5 @@ (why3 (and (>= 1.7.0) (< 1.8))) yojson (zarith (>= 1.10)) -)) + ) +) diff --git a/examples/tcstdlib/TcMonoid.ec b/examples/tcstdlib/TcMonoid.ec new file mode 100644 index 0000000000..f33a9da550 --- /dev/null +++ b/examples/tcstdlib/TcMonoid.ec @@ -0,0 +1,35 @@ +require import Int. + +(* -------------------------------------------------------------------- *) +type class monoid = { + op idm : monoid + op (+) : monoid -> monoid -> monoid + + axiom addmA: associative (+) + axiom addmC: commutative (+) + axiom add0m: left_id idm (+) +}. + +(* -------------------------------------------------------------------- *) +section. +declare type m <: monoid. + +lemma addm0: right_id idm (+)<:m>. +proof. by move=> x; rewrite addmC add0m. qed. + +lemma addmCA: left_commutative (+)<:m>. +proof. by move=> x y z; rewrite !addmA (addmC x). qed. + +lemma addmAC: right_commutative (+)<:m>. +proof. by move=> x y z; rewrite -!addmA (addmC y). qed. + +lemma addmACA: interchange (+)<:m> (+). +proof. by move=> x y z t; rewrite -!addmA (addmCA y). qed. + +lemma iteropE n (x : m): iterop n (+) x idm = iter n ((+) x) idm. +proof. +elim/natcase n => [n le0_n|n ge0_n]. ++ by rewrite ?(iter0, iterop0). ++ by rewrite iterSr // addm0 iteropS. +qed. +end section. diff --git a/examples/tcstdlib/TcRing.ec b/examples/tcstdlib/TcRing.ec new file mode 100644 index 0000000000..7213ba5f32 --- /dev/null +++ b/examples/tcstdlib/TcRing.ec @@ -0,0 +1,857 @@ +pragma +implicits. + +(* -------------------------------------------------------------------- *) +require import Core Int TcMonoid. + +(* -------------------------------------------------------------------- *) +type class group <: monoid = { + op [ - ] : group -> group + + axiom addNr: left_inverse idm [-] (+)<:group> +}. + +section. +declare type g <: group. + +abbrev zeror = idm<:g>. +abbrev ( - ) (x y : g) = x + -y. + +(* -------------------------------------------------------------------- *) +lemma nosmt addrA: associative (+)<:g>. +proof. by exact: addmA. qed. + +lemma nosmt addrC: commutative (+)<:g>. +proof. by exact: addmC. qed. + +lemma nosmt add0r: left_id zeror (+)<:g>. +proof. by exact: add0m. qed. + +(* -------------------------------------------------------------------- *) +lemma nosmt addr0: right_id zeror (+)<:g>. +proof. by move=> x; rewrite addrC add0r. qed. + +lemma nosmt addrN: right_inverse zeror [-] (+)<:g>. +proof. by move=> x; rewrite addrC addNr. qed. + +lemma nosmt addrCA: left_commutative (+)<:g>. +proof. by move=> x y z; rewrite !addrA (@addrC x y). qed. + +lemma nosmt addrAC: right_commutative (+)<:g>. +proof. by move=> x y z; rewrite -!addrA (@addrC y z). qed. + +lemma nosmt addrACA: interchange (+)<:g> (+)<:g>. +proof. by move=> x y z t; rewrite -!addrA (addrCA y). qed. + +lemma nosmt subrr (x : g): x - x = zeror. +proof. by rewrite addrN. qed. + +lemma nosmt addKr: left_loop [-] (+)<:g>. +proof. by move=> x y; rewrite addrA addNr add0r. qed. + +lemma nosmt addNKr: rev_left_loop [-] (+)<:g>. +proof. by move=> x y; rewrite addrA addrN add0r. qed. + +lemma nosmt addrK: right_loop [-] (+)<:g>. +proof. by move=> x y; rewrite -addrA addrN addr0. qed. + +lemma nosmt addrNK: rev_right_loop [-] (+)<:g>. +proof. by move=> x y; rewrite -addrA addNr addr0. qed. + +lemma nosmt subrK (x y : g): (x - y) + y = x. +proof. by rewrite addrNK. qed. + +lemma nosmt addrI: right_injective (+)<:g>. +proof. by move=> x y z h; rewrite -(@addKr x z) -h addKr. qed. + +lemma nosmt addIr: left_injective (+)<:g>. +proof. by move=> x y z h; rewrite -(@addrK x z) -h addrK. qed. + +lemma nosmt opprK: involutive [-]<:g>. +proof. by move=> x; apply (@addIr (-x)); rewrite addNr addrN. qed. + +lemma nosmt oppr_inj : injective [-]<:g>. +proof. by move=> x y eq; apply/(addIr (-x)); rewrite subrr eq subrr. qed. + +lemma nosmt oppr0 : -zeror = zeror. +proof. by rewrite -(@addr0 (-zeror)) addNr. qed. + +lemma nosmt oppr_eq0 (x : g) : (- x = zeror) <=> (x = zeror). +proof. by rewrite (inv_eq opprK) oppr0. qed. + +lemma nosmt subr0 (x : g): x - zeror = x. +proof. by rewrite oppr0 addr0. qed. + +lemma nosmt sub0r (x : g): zeror - x = - x. +proof. by rewrite add0r. qed. + +lemma nosmt opprD (x y : g): -(x + y) = -x + -y. +proof. by apply (@addrI (x + y)); rewrite addrA addrN addrAC addrK addrN. qed. + +lemma nosmt opprB (x y : g): -(x - y) = y - x. +proof. by rewrite opprD opprK addrC. qed. + +lemma nosmt subrACA: interchange (-) (+)<:g>. +proof. by move=> x y z t; rewrite addrACA opprD. qed. + +lemma nosmt subr_eq (x y z : g): + (x - z = y) <=> (x = y + z). +proof. +move: (can2_eq (fun x, x - z) (fun x, x + z) _ _ x y) => //=. ++ by move=> {x} x /=; rewrite addrNK. ++ by move=> {x} x /=; rewrite addrK. +qed. + +lemma nosmt subr_eq0 (x y : g): (x - y = zeror) <=> (x = y). +proof. by rewrite subr_eq add0r. qed. + +lemma nosmt addr_eq0 (x y : g): (x + y = zeror) <=> (x = -y). +proof. by rewrite -(@subr_eq0 x) opprK. qed. + +lemma nosmt eqr_opp (x y : g): (- x = - y) <=> (x = y). +proof. by apply/(@can_eq _ _ opprK x y). qed. + +lemma nosmt eqr_oppLR (x y : g) : (- x = y) <=> (x = - y). +proof. by apply/(@inv_eq _ opprK x y). qed. + +lemma nosmt eqr_sub (x y z t : g) : (x - y = z - t) <=> (x + t = z + y). +proof. +rewrite -{1}(addrK t x) -{1}(addrK y z) -!addrA. +by rewrite (addrC (-t)) !addrA; split=> [/addIr /addIr|->//]. +qed. + +lemma nosmt subr_add2r (z x y : g): (x + z) - (y + z) = x - y. +proof. by rewrite opprD addrACA addrN addr0. qed. + +op intmul (x : g) (n : int) = + (* (signz n) * (iterop `|n| ZModule.(+) x zeror) *) + if n < 0 + then -(iterop (-n) (+)<:g> x zeror) + else (iterop n (+)<:g> x zeror). + +lemma nosmt intmulpE (z : g) c : 0 <= c => + intmul z c = iterop c (+)<:g> z zeror. +proof. by rewrite /intmul lezNgt => ->. qed. + +lemma nosmt mulr0z (x : g): intmul x 0 = zeror. +proof. by rewrite /intmul /= iterop0. qed. + +lemma nosmt mulr1z (x : g): intmul x 1 = x. +proof. by rewrite /intmul /= iterop1. qed. + +lemma nosmt mulr2z (x : g): intmul x 2 = x + x. +proof. by rewrite /intmul /= (@iteropS 1) // (@iterS 0) // iter0. qed. + +lemma nosmt mulrNz (x : g) (n : int): intmul x (-n) = -(intmul x n). +proof. +case: (n = 0)=> [->|nz_c]; first by rewrite oppz0 mulr0z oppr0. +rewrite /intmul oppz_lt0 oppzK ltz_def nz_c lezNgt /=. +by case: (n < 0); rewrite ?opprK. +qed. + +lemma nosmt mulrS (x : g) (n : int): 0 <= n => + intmul x (n+1) = x + intmul x n. +proof. +move=> ge0n; rewrite !intmulpE 1:addz_ge0 //. +by rewrite !iteropE iterS. +qed. + +lemma nosmt mulNrz (x : g) n : intmul (- x) n = - (intmul x n). +proof. +elim/intwlog: n => [n h| | n ge0_n ih]. ++ by rewrite -(@oppzK n) !(@mulrNz _ (- n)) h. ++ by rewrite !mulr0z oppr0. ++ by rewrite !mulrS // ih opprD. +qed. + +lemma nosmt mulNrNz (x : g) (n : int) : intmul (-x) (-n) = intmul x n. +proof. by rewrite mulNrz mulrNz opprK. qed. + +lemma nosmt mulrSz (x : g) n : intmul x (n + 1) = x + intmul x n. +proof. +case: (0 <= n) => [/mulrS ->//|]; rewrite -ltzNge => gt0_n. +case: (n = -1) => [->/=|]; 1: by rewrite mulrNz mulr1z mulr0z subrr. +move=> neq_n_N1; rewrite -!(@mulNrNz x). +rewrite (_ : -n = -(n+1) + 1) 1:/# mulrS 1:/#. +by rewrite addrA subrr add0r. +qed. + +lemma nosmt mulrDz (x : g) (n m : int) : intmul x (n + m) = intmul x n + intmul x m. +proof. +wlog: n m / 0 <= m => [wlog|]. ++ case: (0 <= m) => [/wlog|]; first by apply. + rewrite -ltzNge => lt0_m; rewrite (_ : n + m = -(-m - n)) 1:/#. + by rewrite mulrNz addzC wlog 1:/# !mulrNz -opprD opprK. +elim: m => /= [|m ge0_m ih]; first by rewrite mulr0z addr0. +by rewrite addzA !mulrSz ih addrCA. +qed. + +end section. + +(* -------------------------------------------------------------------- *) +type class comring <: group = { + op oner : comring + op ( * ) : comring -> comring -> comring + op invr : comring -> comring + op unit : comring -> bool + + axiom oner_neq0 : oner <> zeror + axiom mulrA : associative ( * ) + axiom mulrC : commutative ( * ) + axiom mul1r : left_id oner ( * ) + axiom mulrDl : left_distributive ( * ) (+)<:comring> + axiom mulVr : left_inverse_in unit oner invr ( * ) + axiom unitP : forall (x y : comring), y * x = oner => unit x + axiom unitout : forall (x : comring), !unit x => invr x = x +}. + +section. +declare type r <: comring. + +instance monoid with r + op idm = oner<:r> + op (+) = ( * )<:r>. +realize addmA by exact: mulrA. +realize addmC by exact: mulrC. +realize add0m by exact: mul1r. + +abbrev ( / ) (x y : r) = x * (invr y). + +lemma nosmt mulr1: right_id oner ( * )<:r>. +proof. by move=> x; rewrite mulrC mul1r. qed. + +lemma nosmt mulrCA: left_commutative ( * )<:r>. +proof. by move=> x y z; rewrite !mulrA (@mulrC x y). qed. + +lemma nosmt mulrAC: right_commutative ( * )<:r>. +proof. by move=> x y z; rewrite -!mulrA (@mulrC y z). qed. + +lemma nosmt mulrACA: interchange ( * ) ( * )<:r>. +proof. by move=> x y z t; rewrite -!mulrA (mulrCA y). qed. + +lemma nosmt mulrSl (x y : r) : (x + oner) * y = x * y + y. +proof. by rewrite mulrDl mul1r. qed. + +lemma nosmt mulrDr: right_distributive ( * ) (+)<:r>. +proof. by move=> x y z; rewrite mulrC mulrDl !(@mulrC _ x). qed. + +lemma nosmt mul0r: left_zero zeror ( * )<:r>. +proof. by move=> x; apply: (@addIr (oner * x)); rewrite -mulrDl !add0r mul1r. qed. + +lemma nosmt mulr0: right_zero zeror ( * )<:r>. +proof. by move=> x; apply: (@addIr (x * oner)); rewrite -mulrDr !add0r mulr1. qed. + +lemma nosmt mulrN (x y : r): x * (- y) = - (x * y). +proof. by apply: (@addrI (x * y)); rewrite -mulrDr !addrN mulr0. qed. + +lemma nosmt mulNr (x y : r): (- x) * y = - (x * y). +proof. by apply: (@addrI (x * y)); rewrite -mulrDl !addrN mul0r. qed. + +lemma nosmt mulrNN (x y : r): (- x) * (- y) = x * y. +proof. by rewrite mulrN mulNr opprK. qed. + +lemma nosmt mulN1r (x : r): (-oner) * x = -x. +proof. by rewrite mulNr mul1r. qed. + +lemma nosmt mulrN1 (x : r): x * -oner = -x. +proof. by rewrite mulrN mulr1. qed. + +lemma nosmt mulrBl: left_distributive ( * ) (-)<:r>. +proof. by move=> x y z; rewrite mulrDl !mulNr. qed. + +lemma nosmt mulrBr: right_distributive ( * ) (-)<:r>. +proof. by move=> x y z; rewrite mulrDr !mulrN. qed. + +lemma nosmt mulrnAl (x y : r) n : 0 <= n => (intmul x n) * y = intmul (x * y) n. +proof. +elim: n => [|n ge0n ih]; rewrite !(mulr0z, mulrS) ?mul0r //. +by rewrite mulrDl ih. +qed. + +lemma nosmt mulrnAr (x y : r) n : 0 <= n => x * (intmul y n) = intmul (x * y) n. +proof. +elim: n => [|n ge0n ih]; rewrite !(mulr0z, mulrS) ?mulr0 //. +by rewrite mulrDr ih. +qed. + +lemma nosmt mulrzAl (x y : r) z : (intmul x z) * y = intmul (x * y) z. +proof. +case: (lezWP 0 z)=> [|_] le; first by rewrite mulrnAl. +by rewrite -oppzK mulrNz mulNr mulrnAl -?mulrNz // oppz_ge0. +qed. + +lemma nosmt mulrzAr x (y : r) z : x * (intmul y z) = intmul (x * y) z. +proof. +case: (lezWP 0 z)=> [|_] le; first by rewrite mulrnAr. +by rewrite -oppzK mulrNz mulrN mulrnAr -?mulrNz // oppz_ge0. +qed. + +lemma nosmt mulrV: right_inverse_in unit oner invr ( * )<:r>. +proof. by move=> x /mulVr; rewrite mulrC. qed. + +lemma nosmt divrr (x : r): unit x => x / x = oner. +proof. by apply/mulrV. qed. + +lemma nosmt invr_out (x : r): !unit x => invr x = x. +proof. by apply/unitout. qed. + +lemma nosmt unitrP (x : r): unit x <=> (exists y, y * x = oner). +proof. by split=> [/mulVr<- |]; [exists (invr x) | case=> y /unitP]. qed. + +lemma nosmt mulKr: left_loop_in unit invr ( * )<:r>. +proof. by move=> x un_x y; rewrite mulrA mulVr // mul1r. qed. + +lemma nosmt mulrK: right_loop_in unit invr ( * )<:r>. +proof. by move=> y un_y x; rewrite -mulrA mulrV // mulr1. qed. + +lemma nosmt mulVKr: rev_left_loop_in unit invr ( * )<:r>. +proof. by move=> x un_x y; rewrite mulrA mulrV // mul1r. qed. + +lemma nosmt mulrVK: rev_right_loop_in unit invr ( * )<:r>. +proof. by move=> y nz_y x; rewrite -mulrA mulVr // mulr1. qed. + +lemma nosmt mulrI: right_injective_in unit ( * )<:r>. +proof. by move=> x Ux; have /can_inj h := mulKr _ Ux. qed. + +lemma nosmt mulIr: left_injective_in unit ( * )<:r>. +proof. by move=> x /mulrI h y1 y2; rewrite !(@mulrC _ x) => /h. qed. + +lemma nosmt unitrE (x : r): unit x <=> (x / x = oner). +proof. +split=> [Ux|xx1]; 1: by apply/divrr. +by apply/unitrP; exists (invr x); rewrite mulrC. +qed. + +lemma nosmt invrK: involutive invr<:r>. +proof. +move=> x; case: (unit x)=> Ux; 2: by rewrite !invr_out. +rewrite -(mulrK _ Ux (invr (invr x))) -mulrA. +rewrite (@mulrC x) mulKr //; apply/unitrP. +by exists x; rewrite mulrV. +qed. + +lemma nosmt invr_inj: injective invr<:r>. +proof. by apply: (can_inj _ _ invrK). qed. + +lemma nosmt unitrV (x : r): unit (invr x) <=> unit x. +proof. by rewrite !unitrE invrK mulrC. qed. + +lemma nosmt unitr1: unit oner<:r>. +proof. by apply/unitrP; exists oner; rewrite mulr1. qed. + +lemma nosmt invr1: invr oner = oner<:r>. +proof. by rewrite -{2}(mulVr _ unitr1) mulr1. qed. + +lemma nosmt div1r x: oner / x = invr x. +proof. by rewrite mul1r. qed. + +lemma nosmt divr1 x: x / oner = x. +proof. by rewrite invr1 mulr1. qed. + +lemma nosmt unitr0: !unit zeror<:r>. +proof. by apply/negP=> /unitrP [y]; rewrite mulr0 eq_sym oner_neq0. qed. + +lemma nosmt invr0: invr zeror = zeror<:r>. +proof. by rewrite invr_out ?unitr0. qed. + +lemma nosmt unitrN1: unit (-oner<:r>). +proof. by apply/unitrP; exists (-oner); rewrite mulrNN mulr1. qed. + +lemma nosmt invrN1: invr (-oner) = -oner<:r>. +proof. by rewrite -{2}(divrr unitrN1) mulN1r opprK. qed. + +lemma nosmt unitrMl (x y : r) : unit y => (unit (x * y) <=> unit x). +proof. (* FIXME: wlog *) +move=> uy; case: (unit x)=> /=; last first. + apply/contra=> uxy; apply/unitrP; exists (y * invr (x * y)). + apply/(mulrI (invr y)); first by rewrite unitrV. + rewrite !mulrA mulVr // mul1r; apply/(mulIr y)=> //. + by rewrite -mulrA mulVr // mulr1 mulVr. +move=> ux; apply/unitrP; exists (invr y * invr x). +by rewrite -!mulrA mulKr // mulVr. +qed. + +lemma nosmt unitrMr (x y : r): unit x => (unit (x * y) <=> unit y). +proof. +move=> ux; split=> [uxy|uy]; last by rewrite unitrMl. +by rewrite -(mulKr _ ux y) unitrMl ?unitrV. +qed. + +lemma nosmt unitrM (x y : r) : unit (x * y) <=> (unit x /\ unit y). +proof. +case: (unit x) => /=; first by apply: unitrMr. +apply: contra => /unitrP[z] zVE; apply/unitrP. +by exists (y * z); rewrite mulrAC (@mulrC y) (@mulrC _ z). +qed. + +lemma nosmt unitrN (x : r) : unit (-x) <=> unit x. +proof. by rewrite -mulN1r unitrMr // unitrN1. qed. + +lemma nosmt invrM (x y : r) : unit x => unit y => invr (x * y) = invr y * invr x. +proof. +move=> Ux Uy; have Uxy: unit (x * y) by rewrite unitrMl. +by apply: (mulrI _ Uxy); rewrite mulrV ?mulrA ?mulrK ?mulrV. +qed. + +lemma nosmt invrN (x : r) : invr (- x) = - (invr x). +proof. +case: (unit x) => ux; last by rewrite !invr_out ?unitrN. +by rewrite -mulN1r invrM ?unitrN1 // invrN1 mulrN1. +qed. + +lemma nosmt invr_neq0 (x : r) : x <> zeror => invr x <> zeror. +proof. +move=> nx0; case: (unit x)=> Ux; last by rewrite invr_out ?Ux. +by apply/negP=> x'0; move: Ux; rewrite -unitrV x'0 unitr0. +qed. + +lemma nosmt invr_eq0 (x : r) : (invr x = zeror) <=> (x = zeror). +proof. by apply/iff_negb; split=> /invr_neq0; rewrite ?invrK. qed. + +lemma nosmt invr_eq1 (x : r) : (invr x = oner) <=> (x = oner). +proof. by rewrite (inv_eq invrK) invr1. qed. + +op ofint n = intmul oner<:r> n. + +lemma nosmt ofint0: ofint 0 = zeror. +proof. by apply/mulr0z. qed. + +lemma nosmt ofint1: ofint 1 = oner. +proof. by apply/mulr1z. qed. + +lemma nosmt ofintS (i : int): 0 <= i => ofint (i+1) = oner + ofint i. +proof. by apply/mulrS. qed. + +lemma nosmt ofintN (i : int): ofint (-i) = - (ofint i). +proof. by apply/mulrNz. qed. + +lemma nosmt mul1r0z x: x * ofint 0 = zeror. +proof. by rewrite ofint0 mulr0. qed. + +lemma nosmt mul1r1z x : x * ofint 1 = x. +proof. by rewrite ofint1 mulr1. qed. + +lemma nosmt mul1r2z x : x * ofint 2 = x + x. +proof. by rewrite /ofint mulr2z mulrDr mulr1. qed. + +lemma nosmt mulr_intl x z : (ofint z) * x = intmul x z. +proof. by rewrite mulrzAl mul1r. qed. + +lemma nosmt mulr_intr x z : x * (ofint z) = intmul x z. +proof. by rewrite mulrzAr mulr1. qed. + +op exp (x : r) (n : int) = + if n < 0 + then invr (iterop (-n) ( * ) x oner) + else iterop n ( * ) x oner. + +lemma nosmt expr0 x: exp x 0 = oner. +proof. by rewrite /exp /= iterop0. qed. + +lemma nosmt expr1 x: exp x 1 = x. +proof. by rewrite /exp /= iterop1. qed. + +lemma nosmt exprS (x : r) i: 0 <= i => exp x (i+1) = x * (exp x i). +proof. +move=> ge0i; rewrite /exp !ltzNge ge0i addz_ge0 //=. +(* we want to use the multiplicative monoid instance here *) +(* by rewrite !Monoid.iteropE iterS. *) admit. +qed. + +lemma nosmt expr_pred (x : r) i : 0 < i => exp x i = x * (exp x (i - 1)). +proof. smt(exprS). qed. + +lemma nosmt exprSr (x : r) i: 0 <= i => exp x (i+1) = (exp x i) * x. +proof. by move=> ge0_i; rewrite exprS // mulrC. qed. + +lemma nosmt expr2 x: exp x 2 = x * x. +proof. by rewrite (@exprS _ 1) // expr1. qed. + +lemma nosmt exprN (x : r) (i : int): exp x (-i) = invr (exp x i). +proof. +case: (i = 0) => [->|]; first by rewrite oppz0 expr0 invr1. +rewrite /exp oppz_lt0 ltzNge lez_eqVlt oppzK=> -> /=. +by case: (_ < _)%Int => //=; rewrite invrK. +qed. + +lemma nosmt exprN1 (x : r) : exp x (-1) = invr x. +proof. by rewrite exprN expr1. qed. + +lemma nosmt unitrX x m : unit x => unit (exp x m). +proof. +move=> invx; wlog: m / (0 <= m) => [wlog|]. ++ (have [] : (0 <= m \/ 0 <= -m) by move=> /#); first by apply: wlog. + by move=> ?; rewrite -oppzK exprN unitrV &(wlog). +elim: m => [|m ge0_m ih]; first by rewrite expr0 unitr1. +by rewrite exprS // &(unitrMl). +qed. + +lemma nosmt unitrX_neq0 x m : m <> 0 => unit (exp x m) => unit x. +proof. +wlog: m / (0 < m) => [wlog|]. ++ case: (0 < m); [by apply: wlog | rewrite ltzNge /= => le0_m nz_m]. + by move=> h; (apply: (wlog (-m)); 1,2:smt()); rewrite exprN unitrV. +by move=> gt0_m _; rewrite (_ : m = m - 1 + 1) // exprS 1:/# unitrM. +qed. + +lemma nosmt exprV (x : r) (i : int): exp (invr x) i = exp x (-i). +proof. +wlog: i / (0 <= i) => [wlog|]; first by smt(exprN). +elim: i => /= [|i ge0_i ih]; first by rewrite !expr0. +case: (i = 0) => [->|] /=; first by rewrite exprN1 expr1. +move=> nz_i; rewrite exprS // ih !exprN. +case: (unit x) => [invx|invNx]. ++ by rewrite -invrM ?unitrX // exprS // mulrC. +rewrite !invr_out //; last by rewrite exprS. ++ by apply: contra invNx; apply: unitrX_neq0 => /#. ++ by apply: contra invNx; apply: unitrX_neq0 => /#. +qed. + +lemma nosmt exprVn (x : r) (n : int) : 0 <= n => exp (invr x) n = invr (exp x n). +proof. +elim: n => [|n ge0_n ih]; first by rewrite !expr0 invr1. +case: (unit x) => ux. +- by rewrite exprSr -1:exprS // invrM ?unitrX // ih -invrM // unitrX. +- by rewrite !invr_out //; apply: contra ux; apply: unitrX_neq0 => /#. +qed. + +lemma nosmt exprMn (x y : r) (n : int) : 0 <= n => exp (x * y) n = exp x n * exp y n. +proof. +elim: n => [|n ge0_n ih]; first by rewrite !expr0 mulr1. +by rewrite !exprS // mulrACA ih. +qed. + +lemma nosmt exprD_nneg x (m n : int) : 0 <= m => 0 <= n => + exp x (m + n) = exp x m * exp x n. +proof. + move=> ge0_m ge0_n; elim: m ge0_m => [|m ge0_m ih]. + by rewrite expr0 mul1r. + by rewrite addzAC !exprS ?addz_ge0 // ih mulrA. +qed. + +lemma nosmt exprD x (m n : int) : unit x => exp x (m + n) = exp x m * exp x n. +proof. +wlog: m n x / (0 <= m + n) => [wlog invx|]. ++ case: (0 <= m + n); [by move=> ?; apply: wlog | rewrite lezNgt /=]. + move=> lt0_mDn; rewrite -(@oppzK (m + n)) -exprV. + rewrite -{2}(@oppzK m) -{2}(@oppzK n) -!(@exprV _ (- _)%Int). + by rewrite -wlog 1:/# ?unitrV //#. +move=> ge0_mDn invx; wlog: m n ge0_mDn / (m <= n) => [wlog|le_mn]. ++ by case: (m <= n); [apply: wlog | rewrite mulrC addzC /#]. +(have ge0_n: 0 <= n by move=> /#); elim: n ge0_n m le_mn ge0_mDn. ++ by move=> n _ _ /=; rewrite expr0 mulr1. +move=> n ge0_n ih m le_m_Sn ge0_mDSn; move: ge0_mDSn. +rewrite lez_eqVlt => -[?|]; first have->: n+1 = -m by move=> /#. ++ by rewrite subzz exprN expr0 divrr // unitrX. +move=> gt0_mDSn; move: le_m_Sn; rewrite lez_eqVlt. +case=> [->>|lt_m_Sn]; first by rewrite exprD_nneg //#. +by rewrite addzA exprS 1:/# ih 1,2:/# exprS // mulrCA. +qed. + +lemma nosmt exprM x (m n : int) : + exp x (m * n) = exp (exp x m) n. +proof. +wlog : n / 0 <= n. ++ move=> h; case: (0 <= n) => hn; 1: by apply h. + by rewrite -{1}(@oppzK n) (_: m * - -n = -(m * -n)) 1:/# + exprN h 1:/# exprN invrK. +wlog : m / 0 <= m. ++ move=> h; case: (0 <= m) => hm hn; 1: by apply h. + rewrite -{1}(@oppzK m) (_: (- -m) * n = - (-m) * n) 1:/#. + by rewrite exprN h 1:/# // exprN exprV exprN invrK. +elim/natind: n => [|n hn ih hm _]; 1: smt (expr0). +by rewrite mulzDr exprS //= mulrC exprD_nneg 1:/# 1:// ih. +qed. + +lemma nosmt expr0n n : 0 <= n => exp zeror n = if n = 0 then oner else zeror. +proof. +elim: n => [|n ge0_n _]; first by rewrite expr0. +by rewrite exprS // mul0r addz1_neq0. +qed. + +lemma nosmt expr0z z : exp zeror z = if z = 0 then oner else zeror. +proof. +case: (0 <= z) => [/expr0n // | /ltzNge lt0_z]. +rewrite -{1}(@oppzK z) exprN; have ->/=: z <> 0 by smt(). +by rewrite invr_eq0 expr0n ?oppz_ge0 1:ltzW. +qed. + +lemma nosmt expr1z z : exp oner z = oner. +proof. +elim/intwlog: z. ++ by move=> n h; rewrite -(@oppzK n) exprN h invr1. ++ by rewrite expr0. ++ by move=> n ge0_n ih; rewrite exprS // mul1r ih. +qed. + +lemma nosmt sqrrD (x y : r) : + exp (x + y) 2 = exp x 2 + intmul (x * y) 2 + exp y 2. +proof. +by rewrite !expr2 mulrDl !mulrDr mulr2z !addrA (@mulrC y x). +qed. + +lemma nosmt sqrrN x : exp (-x) 2 = exp x 2. +proof. by rewrite !expr2 mulrNN. qed. + +lemma nosmt sqrrB x y : + exp (x - y) 2 = exp x 2 - intmul (x * y) 2 + exp y 2. +proof. by rewrite sqrrD sqrrN mulrN mulNrz. qed. + +lemma nosmt signr_odd n : 0 <= n => exp (-oner) (b2i (odd n)) = exp (-oner) n. +proof. +elim: n => [|n ge0_nih]; first by rewrite odd0 expr0 expr0. +rewrite !(iterS, oddS) // exprS // -/(odd _) => <-. +by case: (odd _); rewrite /b2i /= !(expr0, expr1) mulN1r ?opprK. +qed. + +lemma nosmt subr_sqr_1 x : exp x 2 - oner = (x - oner) * (x + oner). +proof. +rewrite mulrBl mulrDr !(mulr1, mul1r) expr2 -addrA. +by congr; rewrite opprD addrA addrN add0r. +qed. + +op lreg (x : r) = injective (fun y => x * y). + +lemma nosmt mulrI_eq0 x y : lreg x => (x * y = zeror) <=> (y = zeror). +proof. by move=> reg_x; rewrite -{1}(mulr0 x) (inj_eq reg_x). qed. + +lemma nosmt lreg_neq0 x : lreg x => x <> zeror. +proof. +apply/contraL=> ->; apply/negP => /(_ zeror oner). +by rewrite (@eq_sym _ oner) oner_neq0 /= !mul0r. +qed. + +lemma nosmt mulrI0_lreg x : (forall y, x * y = zeror => y = zeror) => lreg x. +proof. +by move=> reg_x y z eq; rewrite -subr_eq0 &(reg_x) mulrBr eq subrr. +qed. + +lemma nosmt lregN x : lreg x => lreg (-x). +proof. by move=> reg_x y z; rewrite !mulNr => /oppr_inj /reg_x. qed. + +lemma nosmt lreg1 : lreg oner. +proof. by move=> x y; rewrite !mul1r. qed. + +lemma nosmt lregM x y : lreg x => lreg y => lreg (x * y). +proof. by move=> reg_x reg_y z t; rewrite -!mulrA => /reg_x /reg_y. qed. + +lemma nosmt lregXn x n : 0 <= n => lreg x => lreg (exp x n). +proof. +move=> + reg_x; elim: n => [|n ge0_n ih]. +- by rewrite expr0 &(lreg1). +- by rewrite exprS // &(lregM). +qed. +end section. + +(* +(* -------------------------------------------------------------------- *) +abstract theory ComRingDflInv. + clone include ComRing with + pred unit (x : t) = exists y, y * x = oner, + op invr (x : t) = choiceb (fun y => y * x = oner) x + + proof mulVr, unitP, unitout. + + realize mulVr. + proof. + move=> x ^ un_x [y ^ -> <-] @/invr_. + by have /= -> := choicebP _ x un_x. + qed. + + realize unitP. + proof. by move=> x y eq; exists y. qed. + + realize unitout. + proof. + by move=> x; rewrite /unit_ negb_exists => /choiceb_dfl /(_ x). + qed. +end ComRingDflInv. +*) + +(* -------------------------------------------------------------------- *) +type class boolring <: comring = { + axiom mulrr : forall (x : boolring), x * x = x +}. + +lemma nosmt addrr ['a <: boolring] (x : 'a): x + x = zeror. +proof. +apply (@addrI (x + x)); rewrite addr0 -{1 2 3 4}mulrr. +by rewrite -mulrDr -mulrDl mulrr. +qed. + +(* -------------------------------------------------------------------- *) +type class idomain <: comring = { + axiom mulf_eq0: + forall (x y : idomain), x * y = zeror <=> x = zeror \/ y = zeror +}. + +section. +declare type r <: idomain. + +lemma nosmt mulf_neq0 (x y : r): x <> zeror => y <> zeror => x * y <> zeror. +proof. by move=> nz_x nz_y; apply/negP; rewrite mulf_eq0 /#. qed. + +lemma nosmt expf_eq0 (x : r) n : (exp x n = zeror) <=> (n <> 0 /\ x = zeror). +proof. +elim/intwlog: n => [n| |n ge0_n ih]. ++ by rewrite exprN invr_eq0 /#. ++ by rewrite expr0 oner_neq0. +by rewrite exprS // mulf_eq0 ih addz1_neq0 ?andKb. +qed. + +lemma nosmt mulfI (x : r): x <> zeror => injective (( * ) x). +proof. +move=> ne0_x y y'; rewrite -(opprK (x * y')) -mulrN -addr_eq0. +by rewrite -mulrDr mulf_eq0 ne0_x /= addr_eq0 opprK. +qed. + +lemma nosmt mulIf (x : r): x <> zeror => injective (fun y => y * x). +proof. by move=> nz_x y z; rewrite -!(@mulrC x); exact: mulfI. qed. + +lemma nosmt sqrf_eq1 (x : r): (exp x 2 = oner) <=> (x = oner \/ x = -oner). +proof. by rewrite -subr_eq0 subr_sqr_1 mulf_eq0 subr_eq0 addr_eq0. qed. + +lemma nosmt lregP (x : r): lreg x <=> x <> zeror. +proof. by split=> [/lreg_neq0//|/mulfI]. qed. + +lemma nosmt eqr_div (x1 y1 x2 y2 : r) : unit y1 => unit y2 => + (x1 / y1 = x2 / y2) <=> (x1 * y2 = x2 * y1). +proof. +move=> Nut1 Nut2; rewrite -{1}(@mulrK y2 _ x1) //. +rewrite -{1}(@mulrK y1 _ x2) // -!mulrA (@mulrC (invr y1)) !mulrA. +split=> [|->] //; + (have nz_Vy1: unit (invr y1) by rewrite unitrV); + (have nz_Vy2: unit (invr y2) by rewrite unitrV). +by move/(mulIr _ nz_Vy1)/(mulIr _ nz_Vy2). +qed. +end section. + +(* -------------------------------------------------------------------- *) +(* +(* TODO: Disjointness of type class operator names? *) +type class ffield <: group = { + op onef : ffield + op ( * ) : ffield -> ffield -> ffield + op invf : ffield -> ffield + + axiom onef_neq0 : onef <> zeror + axiom mulfA : associative ( * ) + axiom mulfC : commutative ( * ) + axiom mul1f : left_id onef ( * ) + axiom mulfDl : left_distributive ( * ) (+)<:ffield> + axiom mulVf : left_inverse_in (predC (pred1 zeror)) onef invf ( * ) + axiom unitP : forall (x y : ffield), y * x = onef => x <> zeror + axiom unitout : invr zeror = zeror +}. +*) + +(* TODO: Probably not the right way *) +type class ffield <: comring = { + axiom unit_neq0: forall (x : ffield), unit x <=> x <> zeror +}. + +section. +declare type f <: ffield. + +lemma nosmt mulfV (x : f): x <> zeror => x * (invr x) = oner. +proof. by move=> /unit_neq0/mulrV. qed. + +lemma nosmt mulVf (x : f): x <> zeror => (invr x) * x = oner. +proof. by move=> /unit_neq0/mulVr. qed. + +lemma nosmt divff (x : f): x <> zeror => x / x = oner. +proof. by move=> /unit_neq0/divrr. qed. + +lemma nosmt invfM (x y : f) : invr (x * y) = invr x * invr y. +proof. +case: (x = zeror) => [->|nz_x]; first by rewrite !(mul0r, invr0). +case: (y = zeror) => [->|nz_y]; first by rewrite !(mulr0, invr0). +by rewrite invrM ?unit_neq0 // mulrC. +qed. + +lemma nosmt invf_div (x y : f) : invr (x / y) = y / x. +proof. by rewrite invfM invrK mulrC. qed. + +lemma nosmt eqf_div (x1 y1 x2 y2 : f) : y1 <> zeror => y2 <> zeror => + (x1 / y1 = x2 / y2) <=> (x1 * y2 = x2 * y1). +proof. by rewrite -!unit_neq0; exact: eqr_div<:f>. qed. + +lemma nosmt expfM (x y : f) n : exp (x * y) n = exp x n * exp y n. +proof. +elim/intwlog: n => [n h | | n ge0_n ih]. ++ by rewrite -(@oppzK n) !(@exprN _ (-n)) h invfM. ++ by rewrite !expr0 mulr1. ++ by rewrite !exprS // mulrCA -!mulrA -ih mulrCA. +qed. +end section. + +(* --------------------------------------------------------------------- *) +(* Rewrite database for algebra tactic *) + +hint rewrite rw_algebra : . +hint rewrite inj_algebra : . + +(* -------------------------------------------------------------------- *) +(* TODO: Instantiation of type classes with inheritance is broken *) +(* TODO: Instantiation of type class operators with literals is broken *) +op zeroz = 0. +op addz (x y : int) = x + y. +op negz (x : int) = -x. + + +instance monoid with int + op idm = zeroz + op (+) = addz. +realize addmA by exact: addzA. +realize addmC by exact: addzC. +realize add0m by exact: add0z. + +(* TODO: This is just broken *) +instance group with int + (* op idm = zeroz *) + op [-] = negz. +realize addNr. +(* TODO: Note that the zero remains undefined *) +rewrite /left_inverse /negz /idm. +(* by exact: addNz. *) admit. + +(* +theory IntID. +clone include IDomain with + type t <- int, + pred unit (z : int) <- (z = 1 \/ z = -1), + op zeror <- 0, + op oner <- 1, + op ( + ) <- Int.( + ), + op [ - ] <- Int.([-]), + op ( * ) <- Int.( * ), + op invr <- (fun (z : int) => z) + proof * by smt + remove abbrev (-) + remove abbrev (/) + rename "ofint" as "ofint_id". + +abbrev (^) = exp. + +lemma intmulz z c : intmul z c = z * c. +proof. +have h: forall cp, 0 <= cp => intmul z cp = z * cp. + elim=> /= [|cp ge0_cp ih]; first by rewrite mulr0z. + by rewrite mulrS // ih mulrDr /= addrC. +smt(opprK mulrNz opprK). +qed. + +lemma poddX n x : 0 < n => odd (exp x n) = odd x. +proof. +rewrite ltz_def => - [] + ge0_n; elim: n ge0_n => // + + _ _. +elim=> [|n ge0_n ih]; first by rewrite expr1. +by rewrite exprS ?addz_ge0 // oddM ih andbb. +qed. + +lemma oddX n x : 0 <= n => odd (exp x n) = (odd x \/ n = 0). +proof. +rewrite lez_eqVlt; case: (n = 0) => [->// _|+ h]. ++ by rewrite expr0 odd1. ++ by case: h => [<-//|] /poddX ->. +qed. +end IntID. +*) diff --git a/examples/typeclasses/monoidtc.ec b/examples/typeclasses/monoidtc.ec new file mode 100644 index 0000000000..f69122c423 --- /dev/null +++ b/examples/typeclasses/monoidtc.ec @@ -0,0 +1,54 @@ +require import Int. + +(* -------------------------------------------------------------------- *) +type class addmonoid = { + op idm : addmonoid + op (+) : addmonoid -> addmonoid -> addmonoid + + axiom addmA : associative (+) + axiom addmC : commutative (+) + axiom add0m : left_id idm (+) +}. + +(* -------------------------------------------------------------------- *) +lemma addm0 ['a <: addmonoid] : right_id idm (+)<:'a>. +proof. by move=> x; rewrite addmC add0m. qed. + +lemma addmCA ['a <: addmonoid] : left_commutative (+)<:'a>. +proof. by move=> x y z; rewrite !addmA (addmC x). qed. + +lemma addmAC ['a <: addmonoid] : right_commutative (+)<:'a>. +proof. by move=> x y z; rewrite -!addmA (addmC y). qed. + +lemma addmACA ['a <: addmonoid] : interchange (+)<:'a> (+)<:'a>. +proof. by move=> x y z t; rewrite -!addmA (addmCA y). qed. + +lemma iteropE ['a <: addmonoid] n x: iterop n (+)<:'a> x idm<:'a> = iter n ((+)<:'a> x) idm<:'a>. +proof. + elim/natcase n => [n le0_n|n ge0_n]. + + by rewrite ?(iter0, iterop0). + + by rewrite iterSr // addm0 iteropS. +qed. + +(* -------------------------------------------------------------------- *) +abstract theory AddMonoid. + type t. + + op idm : t. + op (+) : t -> t -> t. + + theory Axioms. + axiom nosmt addmA: associative (+). + axiom nosmt addmC: commutative (+). + axiom nosmt add0m: left_id idm (+). + end Axioms. + + instance addmonoid with t + op idm = idm + op (+) = (+). + + realize addmA by exact Axioms.addmA. + realize addmC by exact Axioms.addmC. + realize add0m by exact Axioms.add0m. + +end AddMonoid. diff --git a/examples/typeclasses/typeclass.ec b/examples/typeclasses/typeclass.ec new file mode 100644 index 0000000000..eaee3603cf --- /dev/null +++ b/examples/typeclasses/typeclass.ec @@ -0,0 +1,353 @@ +(* ==================================================================== *) +(* Typeclass examples *) + +(* -------------------------------------------------------------------- *) +require import AllCore List. + +(* -------------------------------------------------------------------- *) +(* Set theory *) + +type class ['a] artificial = { + op myop : artificial * 'a +}. + +op myopi ['a] : int * 'a = (0, witness<:'a>). + +instance 'b artificial with ['b] int + op myop = myopi<:'b>. + +lemma reduce_tc : myop<:bool, int> = (0, witness). +proof. +class. +reflexivity. +qed. + +(* -------------------------------------------------------------------- *) +type class witness = { + op witness : witness +}. + +print witness. + +type class finite = { + op enum : finite list + axiom enumP : forall (x : finite), x \in enum +}. + +print enum. +print enumP. + +type class countable = { + op count : int -> countable + axiom countP : forall (x : countable), exists (n : int), x = count n +}. + +(* -------------------------------------------------------------------- *) +(* Simple algebraic structures *) + +type class magma = { + op mmul : magma -> magma -> magma +}. + +print mmul. + +type class semigroup <: magma = { + axiom mmulA : associative mmul<:semigroup> +}. + +print associative. + +type class monoid <: semigroup = { + op mid : monoid + + axiom mmulr0 : right_id mid mmul<:monoid> + axiom mmul0r : left_id mid mmul<:monoid> +}. + +type class group <: monoid = { + op minv : group -> group + + axiom mmulN : left_inverse mid minv mmul +}. + +type class ['a <: semigroup] semigroup_action = { + op amul : 'a -> semigroup_action -> semigroup_action + + axiom compatibility : + forall (g h : 'a) (x : semigroup_action), amul (mmul g h) x = amul g (amul h x) +}. + +type class ['a <: monoid] monoid_action <: 'a semigroup_action = { + axiom identity : forall (x : monoid_action), amul mid<:'a> x = x +}. + +(* TODO: why again is this not possible/a good idea? *) +(*type class finite_group <: group & finite = {}.*) + +(* -------------------------------------------------------------------- *) +(* Advanced algebraic structures *) + +type class comgroup = { + op zero : comgroup + op ([-]) : comgroup -> comgroup + op ( + ) : comgroup -> comgroup -> comgroup + + axiom addr0 : right_id zero ( + ) + axiom addrN : left_inverse zero ([-]) ( + ) + axiom addrC : commutative ( + ) + axiom addrA : associative ( + ) +}. + +type class comring <: comgroup = { + op one : comring + op ( * ) : comring -> comring -> comring + + axiom mulr1 : right_id one ( * ) + axiom mulrC : commutative ( * ) + axiom mulrA : associative ( * ) + axiom mulrDl : left_distributive ( * ) ( + ) +}. + +type class ['a <: comring] commodule <: comgroup = { + op ( ** ) : 'a -> commodule -> commodule + + axiom scalerDl : forall (a b : 'a) (x : commodule), + (a + b) ** x = (a ** x) + (b ** x) + axiom scalerDr : forall (a : 'a) (x y : commodule), + a ** (x + y) = (a ** x) + (a ** y) +}. + + +(* ==================================================================== *) +(* Abstract type examples *) + +(* TODO: finish the hierarchy here: + https://en.wikipedia.org/wiki/Magma_(algebra) *) +type foo <: witness. +type fingroup <: group & finite. + + + +(* TODO: printing typeclasses *) +print countable. +print magma. +print semigroup. +print monoid. +print group. +print semigroup_action. +print monoid_action. + + +(* ==================================================================== *) +(* Operator examples *) + +(* -------------------------------------------------------------------- *) +(* Set theory *) + +op all_finite ['a <: finite] (p : 'a -> bool) = + all p enum<:'a>. + +op all_countable ['a <: countable] (p : 'a -> bool) = + forall (n : int), p (count<:'a> n). + +(* -------------------------------------------------------------------- *) +(* Simple algebraic structures *) + +(* TODO: weird issue and/or inapropriate error message : bug in ecUnify select_op*) + +print amul. +(* +op foo1 ['a <: semigroup, 'b <: 'a semigroup_action] = amul<:'a,'b>. +*) +op foo2 ['a <: semigroup, 'b <: 'a semigroup_action] (g : 'a) (x : 'b) = amul g x. +(* +op foo3 ['a <: semigroup, 'b <: 'a semigroup_action] (g : 'a) (x : 'b) = amul<:'a,'b> g x. +*) + +op big ['a, 'b <: monoid] (P : 'a -> bool) (F : 'a -> 'b) (r : 'a list) = + foldr mmul mid (map F (filter P r)). + + +(* ==================================================================== *) +(* Lemma examples *) + +(* -------------------------------------------------------------------- *) +(* Set theory *) + +lemma all_finiteP ['a <: finite] p : (all_finite p) <=> (forall (x : 'a), p x). +proof. by rewrite/all_finite allP; split=> Hp x; rewrite Hp enumP. qed. + +lemma all_countableP ['a <: countable] p : (all_countable p) <=> (forall (x : 'a), p x). +proof. + rewrite/all_countable; split => [Hp x|Hp n]. + by case (countP x) => n ->>; rewrite Hp. + by rewrite Hp. +qed. + +lemma all_finite_countable ['a <: finite & countable] (p : 'a -> bool) : (all_finite p) <=> (all_countable p). +proof. by rewrite all_finiteP all_countableP. qed. + + +(* ==================================================================== *) +(* Instance examples *) + +(* -------------------------------------------------------------------- *) +(* Set theory *) + +op bool_enum = [true; false]. + +(* TODO: we want to be able to give the list directly.*) +instance finite with bool + op enum = bool_enum. + +realize enumP. +proof. by case. qed. + +(* -------------------------------------------------------------------- *) +(* Advanced algebraic structures *) + +(* +op izero = 0. + +instance comgroup with int + op zero = izero + op ( + ) = CoreInt.add + op ([-]) = CoreInt.opp. + +(* TODO: might be any of the two addr0, also apply fails but rewrite works. + In ecScope, where instances are declared. *) +realize addr0 by rewrite addr0. +realize addrN by trivial. +realize addrC by rewrite addrC. +realize addrA by rewrite addrA. + +op foo = 1 + 3. + +print ( + ). +print foo. + +op ione = 1. + +(* TODO: this automatically fetches the only instance of comgroup we have defined for int. + We should give the choice of which instance to use, by adding as desired_name after the with. + Also we should give the choice to define directly an instance of comring with int. *) +instance comring with int + op one = ione + op ( * ) = CoreInt.mul. + +realize mulr1 by trivial. +realize mulrC by rewrite mulrC. +realize mulrA by rewrite mulrA. + +realize mulrDl. +proof. + (*TODO: in the goal, the typeclass operator + should have been replaced with the + from CoreInt, but has not been.*) + print mulrDl. + move => x y z. + class. + apply Ring.IntID.mulrDl. +qed. + +(* ==================================================================== *) +(* Misc *) + +(* -------------------------------------------------------------------- *) +(* TODO: which instance is kept in memory after this? *) + +op bool_enum_alt = [true; false]. + +instance finite with bool + op enum = bool_enum_alt. + +realize enumP. +proof. by case. qed. + +type class find_out <: finite = { + axiom rev_enum : rev<:find_out> enum = enum +}. + +instance find_out with bool. + +realize rev_enum. +proof. + admit. +qed. + + + +(* ==================================================================== *) +(* Old TODO list: 1-3 are done, modulo bugs, 4 is to be done, 5 will be done later. *) + +(* + 1. typage -> selection des operateurs / inference des instances de tc + 2. reduction + 3. unification (tactiques) + 4. clonage + 5. envoi au SMT + + 1. + Fop : + -(old) path * ty list -> form + -(new) path * (ty * (map tcname -> tcinstance)) list -> form + + op ['a <: monoid] (+) : 'a -> 'a -> 'a. + + (+)<:int + monoid -> intadd_monoid> + (+)<:int + monoid -> intmul_monoid> + + 1.1 module de construction des formules avec typage + 1.2 utiliser le module ci-dessous + + let module M = MkForm(struct let env = env' end) in + + 1.3 UnionFind avec contraintes de TC + + 1.4 Overloading: + 3 + 4 + a. 3 Int.(+) 4 + b. 3 Monoid<:int>.(+) 4 (-> instance du dessus -> ignore) + + 1.5 foo<: int[monoid -> intadd_monoid] > + foo<: int[monoid -> intmul_monoid] > + + 2. -> Monoid.(+)<:int> -> Int.(+) + + 3. -> Pb d'unification des op + (+)<: ?[monoid -> ?] > ~ Int.(+) + + Mecanisme de resolution des TC + + 4. -> il faut cloner les TC + + 5. + + a. encodage + + record 'a premonoid = { + op zero : 'a + op add : 'a -> 'a -> 'a; + } + + pred ['a] ismonoid (m : 'a premonoid) = { + left_id m.zero m.add + } + + op ['a <: monoid] foo (x y : 'a) = x + y + + ->> foo ['a] (m : 'a premonoid) (x y : 'a) = m.add x y + + lemma foo ['a <: monoid] P + + ->> foo ['a] (m : 'a premonoid) : ismonoid m => P + + let intmonoid = { zero = 0; add = intadd } + + lemma intmonoid_is_monoid : ismonoid int_monoid + + b. reduction avant envoi + (+)<: int[monoid -> intadd_monoid > -> Int.(+) + + c. ne pas envoyer certaines instances (e.g. int est un groupe) + -> instance [nosmt] e.g. +*) +*) diff --git a/src/ecAst.ml b/src/ecAst.ml index f66e6d93ec..b6ef0c713c 100644 --- a/src/ecAst.ml +++ b/src/ecAst.ml @@ -3,7 +3,6 @@ open EcUtils open EcSymbols open EcIdent open EcPath -open EcUid module BI = EcBigInt @@ -33,7 +32,6 @@ type quantif = type hoarecmp = FHle | FHeq | FHge (* -------------------------------------------------------------------- *) - type 'a use_restr = { ur_pos : 'a option; (* If not None, can use only element in this set. *) ur_neg : 'a; (* Cannot use element in this set. *) @@ -42,6 +40,13 @@ type 'a use_restr = { type mr_xpaths = EcPath.Sx.t use_restr type mr_mpaths = EcPath.Sm.t use_restr +(* -------------------------------------------------------------------- *) +module TyUni = EcUid.CoreGen () +module TcUni = EcUid.CoreGen () + +type tyuni = TyUni.uid +type tcuni = TcUni.uid + (* -------------------------------------------------------------------- *) type ty = { ty_node : ty_node; @@ -51,12 +56,37 @@ type ty = { and ty_node = | Tglob of EcIdent.t (* The tuple of global variable of the module *) - | Tunivar of EcUid.uid + | Tunivar of tyuni | Tvar of EcIdent.t | Ttuple of ty list - | Tconstr of EcPath.path * ty list + | Tconstr of EcPath.path * etyarg list | Tfun of ty * ty +(* -------------------------------------------------------------------- *) +and etyarg = ty * tcwitness list + +and tcwitness = + | TCIUni of tcuni + + | TCIConcrete of { + path: EcPath.path; + etyargs: (ty * tcwitness list) list; + } + + | TCIAbstract of { + support: [ + | `Var of EcIdent.t + | `Abs of EcPath.path + ]; + offset: int; + } + +(* -------------------------------------------------------------------- *) +and typeclass = { + tc_name : EcPath.path; + tc_args : etyarg list; +} + (* -------------------------------------------------------------------- *) and ovariable = { ov_name : EcSymbols.symbol option; @@ -84,7 +114,7 @@ and expr_node = | Eint of BI.zint (* int. literal *) | Elocal of EcIdent.t (* let-variables *) | Evar of prog_var (* module variable *) - | Eop of EcPath.path * ty list (* op apply to type args *) + | Eop of EcPath.path * etyarg list (* op apply to type args *) | Eapp of expr * expr list (* op. application *) | Equant of equantif * ebindings * expr (* fun/forall/exists *) | Elet of lpattern * expr * expr (* let binding *) @@ -185,7 +215,7 @@ and f_node = | Flocal of EcIdent.t | Fpvar of prog_var * memory | Fglob of EcIdent.t * memory - | Fop of EcPath.path * ty list + | Fop of EcPath.path * etyarg list | Fapp of form * form list | Ftuple of form list | Fproj of form * int @@ -354,6 +384,83 @@ let lp_fv = function (fun s (id, _) -> ofold Sid.add s id) Sid.empty ids +(* -------------------------------------------------------------------- *) +let rec tcw_fv (tcw : tcwitness) = + match tcw with + | TCIUni _ -> + Mid.empty + + | TCIConcrete { etyargs } -> + List.fold_left + (fun fv (ty, tcws) -> fv_union fv (fv_union ty.ty_fv (tcws_fv tcws))) + Mid.empty etyargs + + | TCIAbstract _ -> + Mid.empty (* FIXME:TC *) + +and tcws_fv (tcws : tcwitness list) = + List.fold_left + (fun fv tcw -> fv_union fv (tcw_fv tcw)) + Mid.empty tcws + +let etyarg_fv ((ty, tcws) : etyarg) = + fv_union ty.ty_fv (tcws_fv tcws) + +let etyargs_fv (tyargs : etyarg list) = + List.fold_left + (fun fv tyarg -> fv_union fv (etyarg_fv tyarg)) + Mid.empty tyargs + +(* -------------------------------------------------------------------- *) +let rec tcw_equal (tcw1 : tcwitness) (tcw2 : tcwitness) = + match tcw1, tcw2 with + | TCIUni uid1, TCIUni uid2 -> + TcUni.uid_equal uid1 uid2 + + | TCIConcrete tcw1, TCIConcrete tcw2 -> + EcPath.p_equal tcw1.path tcw2.path + && List.all2 etyarg_equal tcw1.etyargs tcw2.etyargs + + | TCIAbstract { support = support1; offset = o1; } + , TCIAbstract { support = support2; offset = o2; } + -> + let tyvar_eq () = + match support1, support2 with + | `Var x1, `Var x2 -> + EcIdent.id_equal x1 x2 + | `Abs p1, `Abs p2 -> + EcPath.p_equal p1 p2 + | _, _ -> false + + in o1 = o2 && tyvar_eq () + + | _, _ -> + false + +and etyarg_equal ((ty1, tcws1) : etyarg) ((ty2, tcws2) : etyarg) = + ty_equal ty1 ty2 && List.all2 tcw_equal tcws1 tcws2 + +(* -------------------------------------------------------------------- *) +let rec tcw_hash (tcw : tcwitness) = + match tcw with + | TCIUni uid -> + Hashtbl.hash uid + + | TCIConcrete tcw -> + Why3.Hashcons.combine_list + etyarg_hash + (p_hash tcw.path) + tcw.etyargs + + | TCIAbstract { support = `Var tyvar; offset } -> + Why3.Hashcons.combine (EcIdent.id_hash tyvar) offset + + | TCIAbstract { support = `Abs p; offset } -> + Why3.Hashcons.combine (EcPath.p_hash p) offset + + and etyarg_hash ((ty, tcws) : etyarg) = + Why3.Hashcons.combine_list tcw_hash (ty_hash ty) tcws + (* -------------------------------------------------------------------- *) let e_equal = ((==) : expr -> expr -> bool) let e_hash = fun e -> e.e_tag @@ -364,7 +471,6 @@ let eqt_equal : equantif -> equantif -> bool = (==) let eqt_hash : equantif -> int = Hashtbl.hash (* -------------------------------------------------------------------- *) - let lv_equal lv1 lv2 = match lv1, lv2 with | LvVar (pv1, ty1), LvVar (pv2, ty2) -> @@ -388,7 +494,6 @@ let lv_fv = function let add s (pv, _) = EcIdent.fv_union s (pv_fv pv) in List.fold_left add Mid.empty pvs - let lv_hash = function | LvVar (pv, ty) -> Why3.Hashcons.combine (pv_hash pv) (ty_hash ty) @@ -398,7 +503,6 @@ let lv_hash = function (fun (pv, ty) -> Why3.Hashcons.combine (pv_hash pv) (ty_hash ty)) 0 pvs - (* -------------------------------------------------------------------- *) let i_equal = ((==) : instr -> instr -> bool) let i_hash = fun i -> i.i_tag @@ -408,7 +512,6 @@ let s_equal = ((==) : stmt -> stmt -> bool) let s_hash = fun s -> s.s_tag let s_fv = fun s -> s.s_fv - (*-------------------------------------------------------------------- *) let qt_equal : quantif -> quantif -> bool = (==) let qt_hash : quantif -> int = Hashtbl.hash @@ -775,7 +878,7 @@ module Hsty = Why3.Hashcons.Make (struct EcIdent.id_equal m1 m2 | Tunivar u1, Tunivar u2 -> - uid_equal u1 u2 + TyUni.uid_equal u1 u2 | Tvar v1, Tvar v2 -> id_equal v1 v2 @@ -784,7 +887,7 @@ module Hsty = Why3.Hashcons.Make (struct List.all2 ty_equal lt1 lt2 | Tconstr (p1, lt1), Tconstr (p2, lt2) -> - EcPath.p_equal p1 p2 && List.all2 ty_equal lt1 lt2 + EcPath.p_equal p1 p2 && List.all2 etyarg_equal lt1 lt2 | Tfun (d1, c1), Tfun (d2, c2)-> ty_equal d1 d2 && ty_equal c1 c2 @@ -794,10 +897,10 @@ module Hsty = Why3.Hashcons.Make (struct let hash ty = match ty.ty_node with | Tglob m -> EcIdent.id_hash m - | Tunivar u -> u + | Tunivar u -> Hashtbl.hash u | Tvar id -> EcIdent.tag id | Ttuple tl -> Why3.Hashcons.combine_list ty_hash 0 tl - | Tconstr (p, tl) -> Why3.Hashcons.combine_list ty_hash p.p_tag tl + | Tconstr (p, tl) -> Why3.Hashcons.combine_list etyarg_hash p.p_tag tl | Tfun (t1, t2) -> Why3.Hashcons.combine (ty_hash t1) (ty_hash t2) let fv ty = @@ -809,7 +912,7 @@ module Hsty = Why3.Hashcons.Make (struct | Tunivar _ -> Mid.empty | Tvar _ -> Mid.empty (* FIXME: section *) | Ttuple tys -> union (fun a -> a.ty_fv) tys - | Tconstr (_, tys) -> union (fun a -> a.ty_fv) tys + | Tconstr (_, tys) -> union etyarg_fv tys | Tfun (t1, t2) -> union (fun a -> a.ty_fv) [t1; t2] let tag n ty = { ty with ty_tag = n; ty_fv = fv ty.ty_node; } @@ -819,7 +922,6 @@ let mk_ty node = Hsty.hashcons { ty_node = node; ty_tag = -1; ty_fv = Mid.empty } (* ----------------------------------------------------------------- *) - module Hexpr = Why3.Hashcons.Make (struct type t = expr @@ -836,7 +938,7 @@ module Hexpr = Why3.Hashcons.Make (struct | Eop (p1, tys1), Eop (p2, tys2) -> (EcPath.p_equal p1 p2) - && (List.all2 ty_equal tys1 tys2) + && (List.all2 etyarg_equal tys1 tys2) | Eapp (e1, es1), Eapp (e2, es2) -> (e_equal e1 e2) @@ -879,9 +981,8 @@ module Hexpr = Why3.Hashcons.Make (struct | Elocal x -> Hashtbl.hash x | Evar x -> pv_hash x - | Eop (p, tys) -> - Why3.Hashcons.combine_list ty_hash - (EcPath.p_hash p) tys + | Eop (p, tyargs) -> + Why3.Hashcons.combine_list etyarg_hash (EcPath.p_hash p) tyargs | Eapp (e, es) -> Why3.Hashcons.combine_list e_hash (e_hash e) es @@ -915,7 +1016,7 @@ module Hexpr = Why3.Hashcons.Make (struct match e with | Eint _ -> Mid.empty - | Eop (_, tys) -> union (fun a -> a.ty_fv) tys + | Eop (_, tyargs) -> etyargs_fv tyargs | Evar v -> pv_fv v | Elocal id -> fv_singleton id | Eapp (e, es) -> union e_fv (e :: es) @@ -932,7 +1033,27 @@ module Hexpr = Why3.Hashcons.Make (struct end) (* -------------------------------------------------------------------- *) -let mk_expr e ty = +let normalize_enode (node : expr_node) : expr_node = + match node with + | Equant (_, [], body) -> + body.e_node + + | Equant (q1, bds1, { e_node = Equant (q2, bds2, body) }) + when q1 = q2 + -> Equant (q1, bds1 @ bds2, body) + + | Eapp (hd, []) -> + hd.e_node + + | Eapp ({ e_node = Eapp (hd, args1) }, args2) -> + Eapp (hd, args1 @ args2) + + | _ -> + node + +(* -------------------------------------------------------------------- *) +let mk_expr (e : expr_node) (ty : ty) = + let e = normalize_enode e in Hexpr.hashcons { e_node = e; e_tag = -1; e_fv = Mid.empty; e_ty = ty } (* -------------------------------------------------------------------- *) @@ -971,7 +1092,7 @@ module Hsform = Why3.Hashcons.Make (struct EcIdent.id_equal mp1 mp2 && EcIdent.id_equal m1 m2 | Fop(p1,lty1), Fop(p2,lty2) -> - EcPath.p_equal p1 p2 && List.all2 ty_equal lty1 lty2 + EcPath.p_equal p1 p2 && List.all2 etyarg_equal lty1 lty2 | Fapp(f1,args1), Fapp(f2,args2) -> f_equal f1 f2 && List.all2 f_equal args1 args2 @@ -1025,8 +1146,10 @@ module Hsform = Why3.Hashcons.Make (struct | Fglob(mp, m) -> Why3.Hashcons.combine (EcIdent.id_hash mp) (EcIdent.id_hash m) - | Fop(p, lty) -> - Why3.Hashcons.combine_list ty_hash (EcPath.p_hash p) lty + | Fop(p, tyargs) -> + Why3.Hashcons.combine_list + etyarg_hash (EcPath.p_hash p) + tyargs | Fapp(f, args) -> Why3.Hashcons.combine_list f_hash (f_hash f) args @@ -1056,7 +1179,7 @@ module Hsform = Why3.Hashcons.Make (struct match f with | Fint _ -> Mid.empty - | Fop (_, tys) -> union (fun a -> a.ty_fv) tys + | Fop (_, tyargs) -> union etyarg_fv tyargs | Fpvar (PVglob pv,m) -> EcPath.x_fv (fv_add m Mid.empty) pv | Fpvar (PVloc _,m) -> fv_add m Mid.empty | Fglob (mp,m) -> fv_add mp (fv_add m Mid.empty) @@ -1132,7 +1255,28 @@ module Hsform = Why3.Hashcons.Make (struct { f with f_tag = n; f_fv = fv; } end) -let mk_form node ty = +(* -------------------------------------------------------------------- *) +let normalize_fnode (node : f_node) : f_node = + match node with + | Fquant (_, [], body) -> + body.f_node + + | Fquant (q1, bds1, { f_node = Fquant (q2, bds2, body) }) + when q1 = q2 + -> Fquant (q1, bds1 @ bds2, body) + + | Fapp (hd, []) -> + hd.f_node + + | Fapp ({ f_node = Fapp (hd, args1)}, args2) -> + Fapp (hd, args1 @ args2) + + | _ -> + node + +(* -------------------------------------------------------------------- *) +let mk_form (node : f_node) (ty : ty) = + let node = normalize_fnode (node) in let aout = Hsform.hashcons { f_node = node; diff --git a/src/ecAst.mli b/src/ecAst.mli index 9ef452a231..55e177353f 100644 --- a/src/ecAst.mli +++ b/src/ecAst.mli @@ -37,6 +37,13 @@ type mr_xpaths = EcPath.Sx.t use_restr type mr_mpaths = EcPath.Sm.t use_restr +(* -------------------------------------------------------------------- *) +module TyUni : EcUid.ICore with type uid = private EcUid.uid +module TcUni : EcUid.ICore with type uid = private EcUid.uid + +type tyuni = TyUni.uid +type tcuni = TcUni.uid + (* -------------------------------------------------------------------- *) type ty = private { ty_node : ty_node; @@ -46,12 +53,37 @@ type ty = private { and ty_node = | Tglob of EcIdent.t (* The tuple of global variable of the module *) - | Tunivar of EcUid.uid + | Tunivar of tyuni | Tvar of EcIdent.t | Ttuple of ty list - | Tconstr of EcPath.path * ty list + | Tconstr of EcPath.path * etyarg list | Tfun of ty * ty +(* -------------------------------------------------------------------- *) +and etyarg = ty * tcwitness list + +and tcwitness = + | TCIUni of tcuni + + | TCIConcrete of { + path: EcPath.path; + etyargs: (ty * tcwitness list) list; + } + + | TCIAbstract of { + support: [ + | `Var of EcIdent.t + | `Abs of EcPath.path + ]; + offset: int; + } + +(* -------------------------------------------------------------------- *) +and typeclass = { + tc_name : EcPath.path; + tc_args : etyarg list; +} + (* -------------------------------------------------------------------- *) and ovariable = { ov_name : EcSymbols.symbol option; @@ -79,7 +111,7 @@ and expr_node = | Eint of BI.zint (* int. literal *) | Elocal of EcIdent.t (* let-variables *) | Evar of prog_var (* module variable *) - | Eop of EcPath.path * ty list (* op apply to type args *) + | Eop of EcPath.path * etyarg list (* op apply to type args *) | Eapp of expr * expr list (* op. application *) | Equant of equantif * ebindings * expr (* fun/forall/exists *) | Elet of lpattern * expr * expr (* let binding *) @@ -92,7 +124,6 @@ and ebinding = EcIdent.t * ty and ebindings = ebinding list (* -------------------------------------------------------------------- *) - and lvalue = | LvVar of (prog_var * ty) | LvTuple of (prog_var * ty) list @@ -180,7 +211,7 @@ and f_node = | Flocal of EcIdent.t | Fpvar of prog_var * memory | Fglob of EcIdent.t * memory - | Fop of EcPath.path * ty list + | Fop of EcPath.path * etyarg list | Fapp of form * form list | Ftuple of form list | Fproj of form * int @@ -301,6 +332,17 @@ val lp_equal : lpattern equality val lp_hash : lpattern hash val lp_fv : lpattern -> EcIdent.Sid.t +(* -------------------------------------------------------------------- *) +val etyarg_fv : etyarg -> int Mid.t +val etyargs_fv : etyarg list -> int Mid.t +val etyarg_hash : etyarg -> int +val etyarg_equal : etyarg -> etyarg -> bool + +(* -------------------------------------------------------------------- *) +val tcw_fv : tcwitness -> int Mid.t +val tcw_hash : tcwitness -> int +val tcw_equal : tcwitness -> tcwitness -> bool + (* -------------------------------------------------------------------- *) val e_equal : expr equality val e_hash : expr hash diff --git a/src/ecBigInt.ml b/src/ecBigInt.ml index a9a8b5a845..85d741e473 100644 --- a/src/ecBigInt.ml +++ b/src/ecBigInt.ml @@ -71,6 +71,7 @@ module ZImpl : EcBigIntCore.TheInterface = struct with Failure _ -> raise InvalidString let pp_print = (Z.pp_print : Format.formatter -> zint -> unit) + let pp_zint = pp_print let to_why3 (x : zint) = Why3.BigInt.of_string (to_string x) @@ -148,6 +149,8 @@ module BigNumImpl : EcBigIntCore.TheInterface = struct let pp_print fmt x = Format.fprintf fmt "%s" (B.string_of_big_int x) + let pp_zint = pp_print + let to_why3 (x : zint) = Why3.BigInt.of_string (to_string x) end diff --git a/src/ecBigIntCore.ml b/src/ecBigIntCore.ml index 39d9391478..1b7de0b7e7 100644 --- a/src/ecBigIntCore.ml +++ b/src/ecBigIntCore.ml @@ -62,6 +62,7 @@ module type TheInterface = sig val to_string : zint -> string val pp_print : Format.formatter -> zint -> unit + val pp_zint : Format.formatter -> zint -> unit val to_why3 : zint -> Why3.BigInt.t end diff --git a/src/ecCallbyValue.ml b/src/ecCallbyValue.ml index aee423acb8..23ad0bebab 100644 --- a/src/ecCallbyValue.ml +++ b/src/ecCallbyValue.ml @@ -216,7 +216,7 @@ and betared st s bd f args = (* -------------------------------------------------------------------- *) and try_reduce_record_projection - (st : state) ((p, _tys) : EcPath.path * ty list) (args : args) + (st : state) ((p, _tys) : EcPath.path * EcAst.etyarg list) (args : args) = let exception Bailout in @@ -244,7 +244,7 @@ and try_reduce_record_projection (* -------------------------------------------------------------------- *) and try_reduce_fixdef - (st : state) ((p, tys) : EcPath.path * ty list) (args : args) + (st : state) ((p, tys) : EcPath.path * EcAst.etyarg list) (args : args) = let exception Bailout in @@ -299,7 +299,10 @@ and try_reduce_fixdef let body = EcFol.form_of_expr EcFol.mhr body in let body = - Tvar.f_subst ~freshen:true (List.map fst op.EcDecl.op_tparams) tys body in + Tvar.f_subst + ~freshen:true + (List.combine (List.fst op.EcDecl.op_tparams) tys) + body in Some (cbv st subst body (Args.create ty eargs)) @@ -336,7 +339,12 @@ and reduce_user_delta st f1 p tys args = | #Op.redmode as mode when Op.reducible ~mode ~nargs st.st_env p -> let f = Op.reduce ~mode ~nargs st.st_env p tys in cbv st Subst.subst_id f args - | _ -> f2 + | _ -> + if st.st_ri.delta_tc then begin + match Op.tc_reduce st.st_env p tys with + | f -> cbv st Subst.subst_id f args + | exception NotReducible -> f2 + end else f2 (* -------------------------------------------------------------------- *) and reduce_logic st f = diff --git a/src/ecCoreEqTest.ml b/src/ecCoreEqTest.ml new file mode 100644 index 0000000000..c16d062942 --- /dev/null +++ b/src/ecCoreEqTest.ml @@ -0,0 +1,86 @@ +(* -------------------------------------------------------------------- + * Copyright (c) - 2012--2016 - IMDEA Software Institute + * Copyright (c) - 2012--2018 - Inria + * Copyright (c) - 2012--2018 - Ecole Polytechnique + * + * Distributed under the terms of the CeCILL-C-V1 license + * -------------------------------------------------------------------- *) + +(* -------------------------------------------------------------------- *) +open EcUtils +open EcTypes +open EcEnv + +(* -------------------------------------------------------------------- *) +type 'a eqtest = env -> 'a -> 'a -> bool + +(* -------------------------------------------------------------------- *) +let rec for_type env t1 t2 = + ty_equal t1 t2 || for_type_r env t1 t2 + +(* -------------------------------------------------------------------- *) +and for_type_r env t1 t2 = + match t1.ty_node, t2.ty_node with + | Tunivar uid1, Tunivar uid2 -> + EcAst.TyUni.uid_equal uid1 uid2 + + | Tvar i1, Tvar i2 -> i1 = i2 + + | Ttuple lt1, Ttuple lt2 -> + List.length lt1 = List.length lt2 + && List.all2 (for_type env) lt1 lt2 + + | Tfun (t1, t2), Tfun (t1', t2') -> + for_type env t1 t1' && for_type env t2 t2' + + | Tglob m1, Tglob m2 -> EcIdent.id_equal m1 m2 + + | Tconstr (p1, lt1), Tconstr (p2, lt2) when EcPath.p_equal p1 p2 -> + if + List.length lt1 = List.length lt2 + && List.all2 (for_etyarg env) lt1 lt2 + then true + else + if Ty.defined p1 env + then for_type env (Ty.unfold p1 lt1 env) (Ty.unfold p2 lt2 env) + else false + + | Tconstr (p1, lt1), _ when Ty.defined p1 env -> + for_type env (Ty.unfold p1 lt1 env) t2 + + | _, Tconstr (p2, lt2) when Ty.defined p2 env -> + for_type env t1 (Ty.unfold p2 lt2 env) + + | _, _ -> false + +(* -------------------------------------------------------------------- *) +and for_etyarg env ((ty1, tcws1) : etyarg) ((ty2, tcws2) : etyarg) = + for_type env ty1 ty2 && for_tcws env tcws1 tcws2 + +and for_etyargs env (tyargs1 : etyarg list) (tyargs2 : etyarg list) = + List.length tyargs1 = List.length tyargs2 + && List.for_all2 (for_etyarg env) tyargs1 tyargs2 + +and for_tcw env (tcw1 : tcwitness) (tcw2 : tcwitness) = + match tcw1, tcw2 with + | TCIUni uid1, TCIUni uid2 -> + EcAst.TcUni.uid_equal uid1 uid2 + + | TCIConcrete tcw1, TCIConcrete tcw2 -> + EcPath.p_equal tcw1.path tcw2.path + && for_etyargs env tcw1.etyargs tcw2.etyargs + + | TCIAbstract { support = `Var v1; offset = o1 }, + TCIAbstract { support = `Var v2; offset = o2 } -> + EcIdent.id_equal v1 v2 && o1 = o2 + + | TCIAbstract { support = `Abs p1; offset = o1 }, + TCIAbstract { support = `Abs p2; offset = o2 } -> + EcPath.p_equal p1 p2 && o1 = o2 + + | _, _ -> + false + +and for_tcws env (tcws1 : tcwitness list) (tcws2 : tcwitness list) = + List.length tcws1 = List.length tcws2 + && List.for_all2 (for_tcw env) tcws1 tcws2 diff --git a/src/ecCoreEqTest.mli b/src/ecCoreEqTest.mli new file mode 100644 index 0000000000..aa6e5f705b --- /dev/null +++ b/src/ecCoreEqTest.mli @@ -0,0 +1,9 @@ +(* -------------------------------------------------------------------- *) +open EcTypes +open EcEnv + +(* -------------------------------------------------------------------- *) +type 'a eqtest = env -> 'a -> 'a -> bool + +val for_type : ty eqtest +val for_etyarg : etyarg eqtest diff --git a/src/ecCoreFol.ml b/src/ecCoreFol.ml index 962125360b..b4224bdc5d 100644 --- a/src/ecCoreFol.ml +++ b/src/ecCoreFol.ml @@ -12,12 +12,9 @@ module Sx = EcPath.Sx open EcBigInt.Notations (* -------------------------------------------------------------------- *) -type quantif = EcAst.quantif - +type quantif = EcAst.quantif type hoarecmp = EcAst.hoarecmp - -type gty = EcAst.gty - +type gty = EcAst.gty type binding = (EcIdent.t * gty) type bindings = binding list @@ -158,18 +155,14 @@ let mk_form = EcAst.mk_form let f_node { f_node = form } = form (* -------------------------------------------------------------------- *) -let f_op x tys ty = mk_form (Fop (x, tys)) ty +let f_op_tc x tyargs ty = + mk_form (Fop (x, tyargs)) ty -let f_app f args ty = - let f, args' = - match f.f_node with - | Fapp (f, args') -> (f, args') - | _ -> (f, []) - in let args' = args' @ args in +let f_op x tyargs ty = + f_op_tc x (List.map (fun ty -> (ty, [])) tyargs) ty - if List.is_empty args' then begin - (*if ty_equal ty f.f_ty then f else mk_form f.f_node ty *) f - end else mk_form (Fapp (f, args')) ty +let f_app f args ty = + mk_form (Fapp (f, args)) ty (* -------------------------------------------------------------------- *) let f_local x ty = mk_form (Flocal x) ty @@ -194,18 +187,18 @@ let f_tuple args = | [x] -> x | _ -> mk_form (Ftuple args) (ttuple (List.map f_ty args)) +(* -------------------------------------------------------------------- *) let f_quant q b f = - if List.is_empty b then f else - let (q, b, f) = - match f.f_node with - | Fquant(q',b',f') when q = q' -> (q, b@b', f') - | _ -> q, b , f in - let ty = - if q = Llambda - then toarrow (List.map (fun (_,gty) -> gty_as_ty gty) b) f.f_ty - else tbool in - - mk_form (Fquant (q, b, f)) ty + let ty = + match q with + | Llambda -> + let dom = + List.map (fun (_, gty) -> gty_as_ty gty) b + in toarrow dom f.f_ty + + | _ -> tbool in + + mk_form (Fquant (q, b, f)) ty let f_proj f i ty = mk_form (Fproj(f, i)) ty let f_if f1 f2 f3 = mk_form (Fif (f1, f2, f3)) f2.f_ty @@ -396,115 +389,88 @@ let f_some ({ f_ty = ty } as f : form) : form = f_app op [f] (toption ty) (* -------------------------------------------------------------------- *) -let f_map gt g fp = +let f_map (g : form -> form) (fp : form) : form = match fp.f_node with - | Fquant(q, b, f) -> - let map_gty ((x, gty) as b1) = - let gty' = - match gty with - | GTty ty -> - let ty' = gt ty in if ty == ty' then gty else GTty ty' - | _ -> gty - in - if gty == gty' then b1 else (x, gty') - in - - let b' = List.Smart.map map_gty b in - let f' = g f in - - f_quant q b' f' + | Fint _ -> fp + | Fglob _ -> fp + | Flocal _ -> fp + | Fpvar _ -> fp + | Fop _ -> fp - | Fint _ -> fp - | Fglob _ -> fp + | Fquant(q, b, f) -> + f_quant q b (g f) | Fif (f1, f2, f3) -> - f_if (g f1) (g f2) (g f3) + f_if (g f1) (g f2) (g f3) | Fmatch (b, fs, ty) -> - f_match (g b) (List.map g fs) (gt ty) + f_match (g b) (List.map g fs) ty | Flet (lp, f1, f2) -> - f_let lp (g f1) (g f2) - - | Flocal id -> - let ty' = gt fp.f_ty in - f_local id ty' - - | Fpvar (id, s) -> - let ty' = gt fp.f_ty in - f_pvar id ty' s - - | Fop (p, tys) -> - let tys' = List.Smart.map gt tys in - let ty' = gt fp.f_ty in - f_op p tys' ty' + f_let lp (g f1) (g f2) - | Fapp (f, fs) -> - let f' = g f in - let fs' = List.Smart.map g fs in - let ty' = gt fp.f_ty in - f_app f' fs' ty' + | Fapp (hd, args) -> + let hd = g hd in + let args = List.Smart.map g args in + f_app hd args fp.f_ty | Ftuple fs -> - let fs' = List.Smart.map g fs in - f_tuple fs' + f_tuple (List.Smart.map g fs) | Fproj (f, i) -> - let f' = g f in - let ty' = gt fp.f_ty in - f_proj f' i ty' + f_proj (g f) i fp.f_ty | FhoareF hf -> - let pr' = g hf.hf_pr in - let po' = g hf.hf_po in - f_hoareF_r { hf with hf_pr = pr'; hf_po = po'; } + let pr' = g hf.hf_pr in + let po' = g hf.hf_po in + f_hoareF_r { hf with hf_pr = pr'; hf_po = po'; } | FhoareS hs -> - let pr' = g hs.hs_pr in - let po' = g hs.hs_po in - f_hoareS_r { hs with hs_pr = pr'; hs_po = po'; } + let pr' = g hs.hs_pr in + let po' = g hs.hs_po in + f_hoareS_r { hs with hs_pr = pr'; hs_po = po'; } | FeHoareF hf -> - let pr' = g hf.ehf_pr in - let po' = g hf.ehf_po in - f_eHoareF_r { hf with ehf_pr = pr'; ehf_po = po' } + let pr' = g hf.ehf_pr in + let po' = g hf.ehf_po in + f_eHoareF_r { hf with ehf_pr = pr'; ehf_po = po' } | FeHoareS hs -> - let pr' = g hs.ehs_pr in - let po' = g hs.ehs_po in - f_eHoareS_r { hs with ehs_pr = pr'; ehs_po = po'; } + let pr' = g hs.ehs_pr in + let po' = g hs.ehs_po in + f_eHoareS_r { hs with ehs_pr = pr'; ehs_po = po'; } | FbdHoareF bhf -> - let pr' = g bhf.bhf_pr in - let po' = g bhf.bhf_po in - let bd' = g bhf.bhf_bd in - f_bdHoareF_r { bhf with bhf_pr = pr'; bhf_po = po'; bhf_bd = bd'; } + let pr' = g bhf.bhf_pr in + let po' = g bhf.bhf_po in + let bd' = g bhf.bhf_bd in + f_bdHoareF_r { bhf with bhf_pr = pr'; bhf_po = po'; bhf_bd = bd'; } | FbdHoareS bhs -> - let pr' = g bhs.bhs_pr in - let po' = g bhs.bhs_po in - let bd' = g bhs.bhs_bd in - f_bdHoareS_r { bhs with bhs_pr = pr'; bhs_po = po'; bhs_bd = bd'; } + let pr' = g bhs.bhs_pr in + let po' = g bhs.bhs_po in + let bd' = g bhs.bhs_bd in + f_bdHoareS_r { bhs with bhs_pr = pr'; bhs_po = po'; bhs_bd = bd'; } | FequivF ef -> - let pr' = g ef.ef_pr in - let po' = g ef.ef_po in - f_equivF_r { ef with ef_pr = pr'; ef_po = po'; } + let pr' = g ef.ef_pr in + let po' = g ef.ef_po in + f_equivF_r { ef with ef_pr = pr'; ef_po = po'; } | FequivS es -> - let pr' = g es.es_pr in - let po' = g es.es_po in - f_equivS_r { es with es_pr = pr'; es_po = po'; } + let pr' = g es.es_pr in + let po' = g es.es_po in + f_equivS_r { es with es_pr = pr'; es_po = po'; } | FeagerF eg -> - let pr' = g eg.eg_pr in - let po' = g eg.eg_po in - f_eagerF_r { eg with eg_pr = pr'; eg_po = po'; } + let pr' = g eg.eg_pr in + let po' = g eg.eg_po in + f_eagerF_r { eg with eg_pr = pr'; eg_po = po'; } | Fpr pr -> - let args' = g pr.pr_args in - let ev' = g pr.pr_event in - f_pr_r { pr with pr_args = args'; pr_event = ev'; } + let args' = g pr.pr_args in + let ev' = g pr.pr_event in + f_pr_r { pr with pr_args = args'; pr_event = ev'; } (* -------------------------------------------------------------------- *) let f_iter g f = @@ -910,7 +876,7 @@ let rec form_of_expr mem (e : expr) = f_pvar pv e.e_ty mem | Eop (op, tys) -> - f_op op tys e.e_ty + f_op_tc op tys e.e_ty | Eapp (ef, es) -> f_app (form_of_expr mem ef) (List.map (form_of_expr mem) es) e.e_ty @@ -950,7 +916,7 @@ let expr_of_form mh f = | Fint z -> e_int z | Flocal x -> e_local x fp.f_ty - | Fop (p, tys) -> e_op p tys fp.f_ty + | Fop (p, tys) -> e_op_tc p tys fp.f_ty | Fapp (f, fs) -> e_app (aux f) (List.map aux fs) fp.f_ty | Ftuple fs -> e_tuple (List.map aux fs) | Fproj (f, i) -> e_proj (aux f) i fp.f_ty diff --git a/src/ecCoreFol.mli b/src/ecCoreFol.mli index 07f61851d1..1b6b22db7a 100644 --- a/src/ecCoreFol.mli +++ b/src/ecCoreFol.mli @@ -14,12 +14,9 @@ val mleft : memory val mright : memory (* -------------------------------------------------------------------- *) -type quantif = EcAst.quantif - +type quantif = EcAst.quantif type hoarecmp = EcAst.hoarecmp - -type gty = EcAst.gty - +type gty = EcAst.gty type binding = (EcIdent.t * gty) type bindings = binding list @@ -79,8 +76,9 @@ val f_node : form -> f_node (* -------------------------------------------------------------------- *) (* not recursive *) -val f_map : (EcTypes.ty -> EcTypes.ty) -> (form -> form) -> form -> form +val f_map : (form -> form) -> form -> form val f_iter : (form -> unit) -> form -> unit + val form_exists: (form -> bool) -> form -> bool val form_forall: (form -> bool) -> form -> bool @@ -98,7 +96,8 @@ val f_pvloc : variable -> memory -> form val f_glob : EcIdent.t -> memory -> form (* soft-constructors - common formulas constructors *) -val f_op : path -> EcTypes.ty list -> EcTypes.ty -> form +val f_op : path -> ty list -> EcTypes.ty -> form +val f_op_tc : path -> etyarg list -> EcTypes.ty -> form val f_app : form -> form list -> EcTypes.ty -> form val f_tuple : form list -> form val f_proj : form -> int -> EcTypes.ty -> form @@ -254,13 +253,13 @@ val destr_forall1 : form -> ident * gty * form val destr_exists1 : form -> ident * gty * form val destr_lambda1 : form -> ident * gty * form -val destr_op : form -> EcPath.path * ty list +val destr_op : form -> EcPath.path * etyarg list val destr_local : form -> EcIdent.t val destr_pvar : form -> prog_var * memory val destr_proj : form -> form * int val destr_tuple : form -> form list val destr_app : form -> form * form list -val destr_op_app : form -> (EcPath.path * ty list) * form list +val destr_op_app : form -> (EcPath.path * etyarg list) * form list val destr_not : form -> form val destr_nots : form -> bool * form val destr_and : form -> form * form diff --git a/src/ecCoreGoal.ml b/src/ecCoreGoal.ml index 74ff095f5b..97d4440063 100644 --- a/src/ecCoreGoal.ml +++ b/src/ecCoreGoal.ml @@ -51,7 +51,7 @@ and pt_head = | PTCut of EcFol.form * cutsolve option | PTHandle of handle | PTLocal of EcIdent.t -| PTGlobal of EcPath.path * (ty list) +| PTGlobal of EcPath.path * etyarg list | PTTerm of proofterm and cutsolve = [`Done | `Smt | `DoneSmt] diff --git a/src/ecCoreGoal.mli b/src/ecCoreGoal.mli index f574b49bf3..7725546407 100644 --- a/src/ecCoreGoal.mli +++ b/src/ecCoreGoal.mli @@ -53,7 +53,7 @@ and pt_head = | PTCut of EcFol.form * cutsolve option | PTHandle of handle | PTLocal of EcIdent.t -| PTGlobal of EcPath.path * (ty list) +| PTGlobal of EcPath.path * etyarg list | PTTerm of proofterm and cutsolve = [`Done | `Smt | `DoneSmt] @@ -82,12 +82,12 @@ val pamemory : EcMemory.memory -> pt_arg val pamodule : EcPath.mpath * EcModules.module_sig -> pt_arg (* -------------------------------------------------------------------- *) -val paglobal : ?args:pt_arg list -> tys:ty list -> EcPath.path -> pt_arg +val paglobal : ?args:pt_arg list -> tys:etyarg list -> EcPath.path -> pt_arg val palocal : ?args:pt_arg list -> EcIdent.t -> pt_arg val pahandle : ?args:pt_arg list -> handle -> pt_arg (* -------------------------------------------------------------------- *) -val ptglobal : ?args:pt_arg list -> tys:ty list -> EcPath.path -> proofterm +val ptglobal : ?args:pt_arg list -> tys:etyarg list -> EcPath.path -> proofterm val ptlocal : ?args:pt_arg list -> EcIdent.t -> proofterm val pthandle : ?args:pt_arg list -> handle -> proofterm val ptcut : ?args:pt_arg list -> ?cutsolve:cutsolve -> EcFol.form -> proofterm diff --git a/src/ecCorePrinting.ml b/src/ecCorePrinting.ml index a8187a70e0..ae1690ee39 100644 --- a/src/ecCorePrinting.ml +++ b/src/ecCorePrinting.ml @@ -4,7 +4,7 @@ module type PrinterAPI = sig open EcIdent open EcSymbols open EcPath - open EcTypes + open EcAst open EcFol open EcDecl open EcModules @@ -59,7 +59,8 @@ module type PrinterAPI = sig val pp_mem : PPEnv.t -> EcIdent.t pp val pp_memtype : PPEnv.t -> EcMemory.memtype pp val pp_tyvar : PPEnv.t -> ident pp - val pp_tyunivar : PPEnv.t -> EcUid.uid pp + val pp_tyunivar : PPEnv.t -> EcAst.tyuni pp + val pp_tcunivar : PPEnv.t -> EcAst.tcuni pp val pp_path : path pp (* ------------------------------------------------------------------ *) @@ -70,6 +71,7 @@ module type PrinterAPI = sig (* ------------------------------------------------------------------ *) val pp_typedecl : PPEnv.t -> (path * tydecl ) pp + val pp_typeclass : PPEnv.t -> (typeclass ) pp val pp_opdecl : ?long:bool -> PPEnv.t -> (path * operator ) pp val pp_added_op : PPEnv.t -> operator pp val pp_axiom : ?long:bool -> PPEnv.t -> (path * axiom ) pp diff --git a/src/ecCoreSubst.ml b/src/ecCoreSubst.ml index 7e0253c63c..c234ee5372 100644 --- a/src/ecCoreSubst.ml +++ b/src/ecCoreSubst.ml @@ -14,17 +14,12 @@ type mod_extra = { mex_glob : memory -> form; } -type sc_instanciate = { - sc_memtype : memtype; - sc_mempred : mem_pr Mid.t; - sc_expr : expr Mid.t; -} - (* -------------------------------------------------------------------- *) type f_subst = { fs_freshen : bool; (* true means freshen locals *) - fs_u : ty Muid.t; - fs_v : ty Mid.t; + fs_u : ty TyUni.Muid.t; + fs_utc : tcwitness TcUni.Muid.t; + fs_v : etyarg Mid.t; fs_mod : EcPath.mpath Mid.t; fs_modex : mod_extra Mid.t; fs_loc : form Mid.t; @@ -49,25 +44,41 @@ let mex_fv (mp : mpath) (ex : mod_extra) : uid Mid.t = (* -------------------------------------------------------------------- *) let fv_Mid (type a) - (fv : a -> uid Mid.t) (m : a Mid.t) (s : uid Mid.t) : uid Mid.t + (fv : a -> int Mid.t) (m : a Mid.t) (s : int Mid.t) : int Mid.t = Mid.fold (fun _ t s -> fv_union s (fv t)) m s +(* -------------------------------------------------------------------- *) +type unisubst = { + uvars : ty TyUni.Muid.t; + utcvars : tcwitness TcUni.Muid.t; +} + +(* -------------------------------------------------------------------- *) +let unisubst0 : unisubst = { + uvars = TyUni.Muid.empty; + utcvars = TcUni.Muid.empty; +} + (* -------------------------------------------------------------------- *) let f_subst_init - ?(freshen=false) - ?(tu=Muid.empty) - ?(tv=Mid.empty) - ?(esloc=Mid.empty) - () = + ?(freshen = false) + ?(tu = unisubst0) + ?(tv = Mid.empty) + ?(esloc = Mid.empty) + () += + let fv = Mid.empty in - let fv = Muid.fold (fun _ t s -> fv_union s (ty_fv t)) tu fv in - let fv = fv_Mid ty_fv tv fv in + let fv = TyUni.Muid.fold (fun _ t s -> fv_union s (ty_fv t)) tu.uvars fv in + let fv = TcUni.Muid.fold (fun _ t s -> fv_union s (tcw_fv t)) tu.utcvars fv in + let fv = fv_Mid etyarg_fv tv fv in let fv = fv_Mid e_fv esloc fv in { fs_freshen = freshen; - fs_u = tu; + fs_u = tu.uvars; + fs_utc = tu.utcvars; fs_v = tv; fs_mod = Mid.empty; fs_modex = Mid.empty; @@ -158,7 +169,8 @@ let f_rem_mod (s : f_subst) (x : ident) : f_subst = (* -------------------------------------------------------------------- *) let is_ty_subst_id (s : f_subst) : bool = Mid.is_empty s.fs_mod - && Muid.is_empty s.fs_u + && TyUni.Muid.is_empty s.fs_u + && TcUni.Muid.is_empty s.fs_utc && Mid.is_empty s.fs_v (* -------------------------------------------------------------------- *) @@ -168,19 +180,78 @@ let rec ty_subst (s : f_subst) (ty : ty) : ty = Mid.find_opt m s.fs_modex |> Option.map (fun ex -> ex.mex_tglob) |> Option.value ~default:ty + | Tunivar id -> - Muid.find_opt id s.fs_u + TyUni.Muid.find_opt id s.fs_u |> Option.map (ty_subst s) |> Option.value ~default:ty + | Tvar id -> - Mid.find_def ty id s.fs_v - | _ -> - ty_map (ty_subst s) ty + Mid.find_opt id s.fs_v + |> Option.map fst + |> Option.value ~default:ty + + | Tfun (ty1, ty2) -> + let ty1 = ty_subst s ty1 in + let ty2 = ty_subst s ty2 in + tfun ty1 ty2 + + | Ttuple tys -> + let tys = List.Smart.map (ty_subst s) tys in + ttuple tys + + | Tconstr (p, etyargs) -> + let etyargs = List.Smart.map (etyarg_subst s) etyargs in + tconstr_tc p etyargs + +(* -------------------------------------------------------------------- *) +and tcw_subst (s : f_subst) (tcw : tcwitness) : tcwitness = + match tcw with + | TCIUni uid -> + TcUni.Muid.find_opt uid s.fs_utc + |> Option.value ~default:tcw + + | TCIConcrete ({ etyargs = etyargs0 } as rtcw) -> + let etyargs = List.Smart.map (etyarg_subst s) etyargs0 in + if etyargs ==(*phy*) etyargs0 then + tcw + else TCIConcrete { rtcw with etyargs } + + | TCIAbstract { support = `Var tyvar; offset } -> + Mid.find_opt tyvar s.fs_v + |> Option.map (fun (_, tcws) -> List.nth tcws offset) + |> Option.value ~default:tcw + + | TCIAbstract { support = `Abs _ } -> + tcw + +(* -------------------------------------------------------------------- *) +and etyarg_subst (s : f_subst) ((ty, tcws) as tyarg : etyarg) : etyarg = + let ty' = ty_subst s ty in + let tcws' = List.Smart.map (tcw_subst s) tcws in + SmartPair.mk tyarg ty' tcws' + +(* -------------------------------------------------------------------- *) +let tc_subst (s : f_subst) (tc : typeclass) : typeclass = + { tc_name = tc.tc_name; + tc_args = List.map (etyarg_subst s) tc.tc_args; } (* -------------------------------------------------------------------- *) let ty_subst (s : f_subst) : ty -> ty = if is_ty_subst_id s then identity else ty_subst s +(* -------------------------------------------------------------------- *) +let etyarg_subst (s : f_subst) : etyarg -> etyarg = + if is_ty_subst_id s then identity else etyarg_subst s + +(* -------------------------------------------------------------------- *) +let tcw_subst (s : f_subst) : tcwitness -> tcwitness = + if is_ty_subst_id s then identity else tcw_subst s + +(* -------------------------------------------------------------------- *) +let tc_subst (s : f_subst) : typeclass -> typeclass = + if is_ty_subst_id s then identity else tc_subst s + (* -------------------------------------------------------------------- *) let is_e_subst_id (s : f_subst) = not s.fs_freshen @@ -243,35 +314,57 @@ let elp_subst (s : f_subst) (lp : lpattern) : f_subst * lpattern = (* -------------------------------------------------------------------- *) let rec e_subst (s : f_subst) (e : expr) : expr = + let mk (node : expr_node) = + let ty = ty_subst s e.e_ty in + mk_expr node ty in + match e.e_node with + | Eint _ -> + e + | Elocal id -> begin match Mid.find_opt id s.fs_eloc with | Some e' -> e' - | None -> e_local id (ty_subst s e.e_ty) + | None -> mk (Elocal id) end | Evar pv -> - let pv' = pv_subst s pv in - let ty' = ty_subst s e.e_ty in - e_var pv' ty' + mk (Evar (pv_subst s pv)) - | Eop (p, tys) -> - let tys' = List.Smart.map (ty_subst s) tys in - let ty' = ty_subst s e.e_ty in - e_op p tys' ty' + | Eop (p, etyargs) -> + mk (Eop (p, List.Smart.map (etyarg_subst s) etyargs)) | Elet (lp, e1, e2) -> let e1' = e_subst s e1 in let s, lp' = elp_subst s lp in let e2' = e_subst s e2 in - e_let lp' e1' e2' + mk (Elet (lp', e1', e2')) - | Equant (q, b, e1) -> + | Equant (q, b, bd) -> let s, b' = add_elocals s b in - let e1' = e_subst s e1 in - e_quantif q b' e1' - - | _ -> e_map (ty_subst s) (e_subst s) e + let bd' = e_subst s bd in + mk (Equant (q, b', bd')) + + | Eapp (e, es) -> + let e = e_subst s e in + let es = List.Smart.map (e_subst s) es in + mk (Eapp (e, es)) + + | Etuple es -> + let es = List.Smart.map (e_subst s) es in + mk (Etuple es) + + | Eif (c, e1, e2) -> + mk (Eif (e_subst s c, e_subst s e1, e_subst s e2)) + + | Ematch (e, bs, ty) -> + let e = e_subst s e in + let bs = List.Smart.map (e_subst s) bs in + let ty = ty_subst s ty in + mk (Ematch (e, bs, ty)) + + | Eproj (e, (i : int)) -> + mk (Eproj (e_subst s e, i)) (* -------------------------------------------------------------------- *) let e_subst (s : f_subst) : expr -> expr= @@ -411,37 +504,44 @@ module Fsubst = struct (* ------------------------------------------------------------------ *) let rec f_subst ~(tx : tx) (s : f_subst) (fp : form) : form = + let f_subst = f_subst ~tx in + + let mk (node : f_node) : form = + let ty = ty_subst s fp.f_ty in + mk_form node ty in + tx ~before:fp ~after:(match fp.f_node with - | Fquant (q, b, f) -> - let s, b' = add_bindings s b in - let f' = f_subst ~tx s f in - f_quant q b' f' + | Fint _ -> + fp + + | Fquant (q, b, bd) -> + let s, b = add_bindings s b in + let bd = f_subst s bd in + mk (Fquant (q, b, bd)) | Flet (lp, f1, f2) -> - let f1' = f_subst ~tx s f1 in - let s, lp' = lp_subst s lp in - let f2' = f_subst ~tx s f2 in - f_let lp' f1' f2' - - | Flocal id -> begin - match Mid.find_opt id s.fs_loc with - | Some f -> - f - | None -> - let ty' = ty_subst s fp.f_ty in - f_local id ty' - end + let f1 = f_subst s f1 in + let s, lp = lp_subst s lp in + let f2 = f_subst s f2 in + mk (Flet (lp, f1, f2)) + + | Flocal id -> + Mid.find_opt id s.fs_loc + |> ofdfl (fun () -> mk (Flocal id)) - | Fop (p, tys) -> - let ty' = ty_subst s fp.f_ty in - let tys' = List.Smart.map (ty_subst s) tys in - f_op p tys' ty' + | Fop (p, etyargs) -> + let etyargs = List.Smart.map (etyarg_subst s) etyargs in + mk (Fop (p, etyargs)) + + | Fapp (f, fs) -> + let f = f_subst s f in + let fs = List.Smart.map (f_subst s) fs in + mk (Fapp (f, fs)) | Fpvar (pv, m) -> - let pv' = pv_subst s pv in - let m' = m_subst s m in - let ty' = ty_subst s fp.f_ty in - f_pvar pv' ty' m' + let pv = pv_subst s pv in + let m = m_subst s m in + mk (Fpvar (pv, m)) | Fglob (mid, m) -> let m' = m_subst s m in @@ -450,48 +550,68 @@ module Fsubst = struct | Some _ -> (Mid.find mid s.fs_modex).mex_glob m' end + | Ftuple fs -> + let fs = List.Smart.map (f_subst s) fs in + mk (Ftuple fs) + + | Fproj (f, (i : int)) -> + let f = f_subst s f in + mk (Fproj (f, i)) + + | Fif (c, f1, f2) -> + let c = f_subst s c in + let f1 = f_subst s f1 in + let f2 = f_subst s f2 in + mk (Fif (c, f1, f2)) + + | Fmatch (f, bs, ty) -> + let f = f_subst s f in + let bs = List.Smart.map (f_subst s) bs in + let ty = ty_subst s ty in + mk (Fmatch (f, bs, ty)) + | FhoareF hf -> let hf_f = x_subst s hf.hf_f in let s = f_rem_mem s mhr in - let hf_pr = f_subst ~tx s hf.hf_pr in - let hf_po = f_subst ~tx s hf.hf_po in + let hf_pr = f_subst s hf.hf_pr in + let hf_po = f_subst s hf.hf_po in f_hoareF hf_pr hf_f hf_po | FhoareS hs -> let hs_s = s_subst s hs.hs_s in let s, hs_m = add_me_binding s hs.hs_m in - let hs_pr = f_subst ~tx s hs.hs_pr in - let hs_po = f_subst ~tx s hs.hs_po in + let hs_pr = f_subst s hs.hs_pr in + let hs_po = f_subst s hs.hs_po in f_hoareS hs_m hs_pr hs_s hs_po | FeHoareF hf -> let hf_f = x_subst s hf.ehf_f in let s = f_rem_mem s mhr in - let hf_pr = f_subst ~tx s hf.ehf_pr in - let hf_po = f_subst ~tx s hf.ehf_po in + let hf_pr = f_subst s hf.ehf_pr in + let hf_po = f_subst s hf.ehf_po in f_eHoareF hf_pr hf_f hf_po | FeHoareS hs -> let hs_s = s_subst s hs.ehs_s in let s, hs_m = add_me_binding s hs.ehs_m in - let hs_pr = f_subst ~tx s hs.ehs_pr in - let hs_po = f_subst ~tx s hs.ehs_po in + let hs_pr = f_subst s hs.ehs_pr in + let hs_po = f_subst s hs.ehs_po in f_eHoareS hs_m hs_pr hs_s hs_po | FbdHoareF hf -> let hf_f = x_subst s hf.bhf_f in let s = f_rem_mem s mhr in - let hf_pr = f_subst ~tx s hf.bhf_pr in - let hf_po = f_subst ~tx s hf.bhf_po in - let hf_bd = f_subst ~tx s hf.bhf_bd in + let hf_pr = f_subst s hf.bhf_pr in + let hf_po = f_subst s hf.bhf_po in + let hf_bd = f_subst s hf.bhf_bd in f_bdHoareF hf_pr hf_f hf_po hf.bhf_cmp hf_bd | FbdHoareS hs -> let hs_s = s_subst s hs.bhs_s in let s, hs_m = add_me_binding s hs.bhs_m in - let hs_pr = f_subst ~tx s hs.bhs_pr in - let hs_po = f_subst ~tx s hs.bhs_po in - let hs_bd = f_subst ~tx s hs.bhs_bd in + let hs_pr = f_subst s hs.bhs_pr in + let hs_po = f_subst s hs.bhs_po in + let hs_bd = f_subst s hs.bhs_bd in f_bdHoareS hs_m hs_pr hs_s hs_po hs.bhs_cmp hs_bd | FequivF ef -> @@ -499,8 +619,8 @@ module Fsubst = struct let ef_fr = x_subst s ef.ef_fr in let s = f_rem_mem s mleft in let s = f_rem_mem s mright in - let ef_pr = f_subst ~tx s ef.ef_pr in - let ef_po = f_subst ~tx s ef.ef_po in + let ef_pr = f_subst s ef.ef_pr in + let ef_po = f_subst s ef.ef_po in f_equivF ef_pr ef_fl ef_fr ef_po | FequivS es -> @@ -508,8 +628,8 @@ module Fsubst = struct let es_sr = s_subst s es.es_sr in let s, es_ml = add_me_binding s es.es_ml in let s, es_mr = add_me_binding s es.es_mr in - let es_pr = f_subst ~tx s es.es_pr in - let es_po = f_subst ~tx s es.es_po in + let es_pr = f_subst s es.es_pr in + let es_po = f_subst s es.es_po in f_equivS es_ml es_mr es_pr es_sl es_sr es_po | FeagerF eg -> @@ -519,21 +639,18 @@ module Fsubst = struct let eg_sr = s_subst s eg.eg_sr in let s = f_rem_mem s mleft in let s = f_rem_mem s mright in - let eg_pr = f_subst ~tx s eg.eg_pr in - let eg_po = f_subst ~tx s eg.eg_po in + let eg_pr = f_subst s eg.eg_pr in + let eg_po = f_subst s eg.eg_po in f_eagerF eg_pr eg_sl eg_fl eg_fr eg_sr eg_po | Fpr pr -> let pr_mem = m_subst s pr.pr_mem in let pr_fun = x_subst s pr.pr_fun in - let pr_args = f_subst ~tx s pr.pr_args in + let pr_args = f_subst s pr.pr_args in let s = f_rem_mem s mhr in - let pr_event = f_subst ~tx s pr.pr_event in + let pr_event = f_subst s pr.pr_event in - f_pr pr_mem pr_fun pr_args pr_event - - | _ -> - f_map (ty_subst s) (f_subst ~tx s) fp) + f_pr pr_mem pr_fun pr_args pr_event) (* ------------------------------------------------------------------ *) and oi_subst (s : f_subst) (oi : PreOI.t) : PreOI.t = @@ -667,60 +784,65 @@ module Fsubst = struct fun f -> if Mid.mem m1 f.f_fv then f_subst s f else f (* ------------------------------------------------------------------ *) - let init_subst_tvar ~(freshen : bool) (s : ty Mid.t) : f_subst = + let init_subst_tvar ~(freshen : bool) (s : etyarg Mid.t) : f_subst = f_subst_init ~freshen ~tv:s () - let f_subst_tvar ~(freshen : bool) (s : ty Mid.t) : form -> form = + let f_subst_tvar ~(freshen : bool) (s : etyarg Mid.t) : form -> form = f_subst (init_subst_tvar ~freshen s) end (* -------------------------------------------------------------------- *) module Tuni = struct - let subst (uidmap : ty Muid.t) : f_subst = + let subst (uidmap : unisubst) : f_subst = f_subst_init ~tu:uidmap () - let subst1 ((id, t) : uid * ty) : f_subst = - subst (Muid.singleton id t) + let subst1 ((id, t) : tyuni * ty) : f_subst = + subst { unisubst0 with uvars = TyUni.Muid.singleton id t } - let subst_dom (uidmap : ty Muid.t) (dom : dom) : dom = + let subst_dom (uidmap : unisubst) (dom : dom) : dom = List.map (ty_subst (subst uidmap)) dom - let occurs (u : uid) : ty -> bool = + let occurs (u : tyuni) : ty -> bool = let rec aux t = match t.ty_node with - | Tunivar u' -> uid_equal u u' + | Tunivar u' -> TyUni.uid_equal u u' | _ -> ty_sub_exists aux t in aux - let univars : ty -> Suid.t = + let univars : ty -> TyUni.Suid.t = let rec doit univars t = match t.ty_node with - | Tunivar uid -> Suid.add uid univars + | Tunivar uid -> TyUni.Suid.add uid univars | _ -> ty_fold doit univars t - in fun t -> doit Suid.empty t + in fun t -> doit TyUni.Suid.empty t - let rec fv_rec (fv : Suid.t) (t : ty) : Suid.t = + let rec fv_rec (fv : TyUni.Suid.t) (t : ty) : TyUni.Suid.t = match t.ty_node with - | Tunivar id -> Suid.add id fv + | Tunivar id -> TyUni.Suid.add id fv | _ -> ty_fold fv_rec fv t - let fv (ty : ty) : Suid.t = - fv_rec Suid.empty ty + let fv (ty : ty) : TyUni.Suid.t = + fv_rec TyUni.Suid.empty ty end (* -------------------------------------------------------------------- *) module Tvar = struct - let subst (s : ty Mid.t) (ty : ty) : ty = + let subst (s : etyarg Mid.t) (ty : ty) : ty = ty_subst { f_subst_id with fs_v = s } ty - let subst1 ((id, t) : ebinding) (ty : ty) : ty = + let subst1 ((id, t) : ident * etyarg) (ty : ty) : ty = subst (Mid.singleton id t) ty - let init (lv : ident list) (lt : ty list) : ty Mid.t = - assert (List.length lv = List.length lt); - List.fold_left2 (fun s v t -> Mid.add v t s) Mid.empty lv lt + let init (init : (ident * etyarg) list) : etyarg Mid.t = + Mid.of_list init + + let subst_etyarg (s : etyarg Mid.t) (ety : etyarg) : etyarg = + etyarg_subst { f_subst_id with fs_v = s } ety + + let subst_tc (s : etyarg Mid.t) (tc : typeclass) : typeclass = + tc_subst { f_subst_id with fs_v = s } tc - let f_subst ~(freshen : bool) (lv : ident list) (lt : ty list) : form -> form = - Fsubst.f_subst_tvar ~freshen (init lv lt) + let f_subst ~(freshen : bool) (bds : (ident * etyarg) list) : form -> form = + Fsubst.f_subst_tvar ~freshen (init bds) end diff --git a/src/ecCoreSubst.mli b/src/ecCoreSubst.mli index 80531ef9c6..a22d5f572c 100644 --- a/src/ecCoreSubst.mli +++ b/src/ecCoreSubst.mli @@ -1,5 +1,4 @@ (* -------------------------------------------------------------------- *) -open EcUid open EcIdent open EcPath open EcAst @@ -7,13 +6,6 @@ open EcTypes open EcCoreModules open EcCoreFol -(* -------------------------------------------------------------------- *) -type sc_instanciate = { - sc_memtype : memtype; - sc_mempred : mem_pr Mid.t; - sc_expr : expr Mid.t; -} - (* -------------------------------------------------------------------- *) type f_subst @@ -23,31 +15,40 @@ type tx = before:form -> after:form -> form type 'a tx_substitute = ?tx:tx -> 'a substitute type 'a subst_binder = f_subst -> 'a -> f_subst * 'a +(* -------------------------------------------------------------------- *) +type unisubst = { + uvars : ty TyUni.Muid.t; + utcvars : tcwitness TcUni.Muid.t; +} + (* -------------------------------------------------------------------- *) val f_subst_init : ?freshen:bool - -> ?tu:ty Muid.t - -> ?tv:ty Mid.t + -> ?tu:unisubst + -> ?tv:etyarg Mid.t -> ?esloc:expr Mid.t -> unit -> f_subst (* -------------------------------------------------------------------- *) module Tuni : sig - val univars : ty -> Suid.t - val subst1 : (uid * ty) -> f_subst - val subst : ty Muid.t -> f_subst - val subst_dom : ty Muid.t -> dom -> dom - val occurs : uid -> ty -> bool - val fv : ty -> Suid.t + val univars : ty -> TyUni.Suid.t + val subst1 : (tyuni * ty) -> f_subst + val subst : unisubst -> f_subst + val subst_dom : unisubst -> dom -> dom + val occurs : tyuni -> ty -> bool + val fv : ty -> TyUni.Suid.t end (* -------------------------------------------------------------------- *) module Tvar : sig - val init : EcIdent.t list -> ty list -> ty Mid.t - val subst1 : (EcIdent.t * ty) -> ty -> ty - val subst : ty Mid.t -> ty -> ty - val f_subst : freshen:bool -> EcIdent.t list -> ty list -> form -> form + val init : (EcIdent.t * etyarg) list -> etyarg Mid.t + val subst1 : (EcIdent.t * etyarg) -> ty -> ty + val subst : etyarg Mid.t -> ty -> ty + val subst_etyarg : etyarg Mid.t -> etyarg -> etyarg + val subst_tc : etyarg Mid.t -> typeclass -> typeclass + + val f_subst : freshen:bool -> (EcIdent.t * etyarg) list -> form -> form end (* -------------------------------------------------------------------- *) @@ -55,11 +56,12 @@ val add_elocal : (EcIdent.t * ty) subst_binder val add_elocals : (EcIdent.t * ty) list subst_binder val bind_elocal : f_subst -> EcIdent.t -> expr -> f_subst - (* -------------------------------------------------------------------- *) -val ty_subst : ty substitute -val e_subst : expr substitute -val s_subst : stmt substitute +val ty_subst : ty substitute +val etyarg_subst : etyarg substitute +val tc_subst : typeclass substitute +val e_subst : expr substitute +val s_subst : stmt substitute (* -------------------------------------------------------------------- *) module Fsubst : sig @@ -68,8 +70,8 @@ module Fsubst : sig val f_subst_init : ?freshen:bool - -> ?tu:ty Muid.t - -> ?tv:ty Mid.t + -> ?tu:unisubst + -> ?tv:etyarg Mid.t -> ?esloc:expr Mid.t -> unit -> f_subst @@ -85,11 +87,7 @@ module Fsubst : sig val f_subst_local : EcIdent.t -> form -> form -> form val f_subst_mem : EcIdent.t -> EcIdent.t -> form -> form - - val f_subst_tvar : - freshen:bool -> - EcTypes.ty EcIdent.Mid.t -> - form -> form + val f_subst_tvar : freshen:bool -> etyarg Mid.t -> form -> form val add_binding : binding subst_binder val add_bindings : bindings subst_binder diff --git a/src/ecDecl.ml b/src/ecDecl.ml index 5806407fa3..0f6084d0fb 100644 --- a/src/ecDecl.ml +++ b/src/ecDecl.ml @@ -5,13 +5,12 @@ open EcTypes open EcCoreFol module Sp = EcPath.Sp -module TC = EcTypeClass module BI = EcBigInt module Ssym = EcSymbols.Ssym module CS = EcCoreSubst (* -------------------------------------------------------------------- *) -type ty_param = EcIdent.t * EcPath.Sp.t +type ty_param = EcIdent.t * typeclass list type ty_params = ty_param list type ty_pctor = [ `Int of int | `Named of ty_params ] @@ -24,7 +23,7 @@ type tydecl = { and ty_body = [ | `Concrete of EcTypes.ty - | `Abstract of Sp.t + | `Abstract of typeclass list | `Datatype of ty_dtype | `Record of EcCoreFol.form * (EcSymbols.symbol * EcTypes.ty) list ] @@ -48,7 +47,7 @@ let tydecl_as_record (td : tydecl) = match td.tyd_type with `Record x -> Some x | _ -> None (* -------------------------------------------------------------------- *) -let abs_tydecl ?(resolve = true) ?(tc = Sp.empty) ?(params = `Int 0) lc = +let abs_tydecl ?(resolve = true) ?(tc = []) ?(params = `Int 0) lc = let params = match params with | `Named params -> @@ -56,15 +55,26 @@ let abs_tydecl ?(resolve = true) ?(tc = Sp.empty) ?(params = `Int 0) lc = | `Int n -> let fmt = fun x -> Printf.sprintf "'%s" x in List.map - (fun x -> (EcIdent.create x, Sp.empty)) + (fun x -> (EcIdent.create x, [])) (EcUid.NameGen.bulk ~fmt n) in - { tyd_params = params; tyd_type = `Abstract tc; tyd_resolve = resolve; tyd_loca = lc; } + { tyd_params = params; + tyd_type = `Abstract tc; + tyd_resolve = resolve; + tyd_loca = lc; } (* -------------------------------------------------------------------- *) -let ty_instanciate (params : ty_params) (args : ty list) (ty : ty) = - let subst = CS.Tvar.init (List.map fst params) args in +let etyargs_of_tparams (tps : ty_params) : etyarg list = + List.map (fun (a, tcs) -> + let ety = + List.mapi (fun offset _ -> TCIAbstract { support = `Var a; offset }) tcs + in (tvar a, ety) + ) tps + +(* -------------------------------------------------------------------- *) +let ty_instanciate (params : ty_params) (args : etyarg list) (ty : ty) = + let subst = CS.Tvar.init (List.combine (List.map fst params) args) in CS.Tvar.subst subst ty (* -------------------------------------------------------------------- *) @@ -81,7 +91,7 @@ and opbody = | OP_Record of EcPath.path | OP_Proj of EcPath.path * int * int | OP_Fix of opfix - | OP_TC + | OP_TC of EcPath.path * string and prbody = | PR_Plain of form @@ -176,6 +186,11 @@ let is_rcrd op = | OB_oper (Some (OP_Record _)) -> true | _ -> false +let is_tc_op op = + match op.op_kind with + | OB_oper (Some (OP_TC _)) -> true + | _ -> false + let is_fix op = match op.op_kind with | OB_oper (Some (OP_Fix _)) -> true @@ -249,41 +264,18 @@ let operator_as_prind (op : operator) = | OB_pred (Some (PR_Ind pri)) -> pri | _ -> assert false -(* -------------------------------------------------------------------- *) -let axiomatized_op ?(nargs = 0) ?(nosmt = false) path (tparams, axbd) lc = - let axbd, axpm = - let bdpm = List.map fst tparams in - let axpm = List.map EcIdent.fresh bdpm in - (CS.Tvar.f_subst ~freshen:true bdpm (List.map EcTypes.tvar axpm) axbd, - List.combine axpm (List.map snd tparams)) - in - - let args, axbd = - match axbd.f_node with - | Fquant (Llambda, bds, axbd) -> - let bds, flam = List.split_at nargs bds in - (bds, f_lambda flam axbd) - | _ -> [], axbd - in - - let opargs = List.map (fun (x, ty) -> f_local x (gty_as_ty ty)) args in - let tyargs = List.map (EcTypes.tvar |- fst) axpm in - let op = f_op path tyargs (toarrow (List.map f_ty opargs) axbd.EcAst.f_ty) in - let op = f_app op opargs axbd.f_ty in - let axspec = f_forall args (f_eq op axbd) in - - { ax_tparams = axpm; - ax_spec = axspec; - ax_kind = `Axiom (Ssym.empty, false); - ax_loca = lc; - ax_visibility = if nosmt then `NoSmt else `Visible; } +let operator_as_tc (op : operator) = + match op.op_kind with + | OB_oper (Some OP_TC (tcpath, name)) -> (tcpath, name) + | _ -> assert false (* -------------------------------------------------------------------- *) -type typeclass = { - tc_prt : EcPath.path option; - tc_ops : (EcIdent.t * EcTypes.ty) list; - tc_axs : (EcSymbols.symbol * EcCoreFol.form) list; - tc_loca: is_local; +type tc_decl = { + tc_tparams : ty_params; + tc_prt : typeclass option; + tc_ops : (EcIdent.t * EcTypes.ty) list; + tc_axs : (EcSymbols.symbol * EcCoreFol.form) list; + tc_loca : is_local; } (* -------------------------------------------------------------------- *) diff --git a/src/ecDecl.mli b/src/ecDecl.mli index 65e2dea27c..22ee075d46 100644 --- a/src/ecDecl.mli +++ b/src/ecDecl.mli @@ -1,13 +1,13 @@ (* -------------------------------------------------------------------- *) open EcUtils +open EcAst open EcSymbols open EcBigInt -open EcPath open EcTypes open EcCoreFol (* -------------------------------------------------------------------- *) -type ty_param = EcIdent.t * EcPath.Sp.t +type ty_param = EcIdent.t * typeclass list type ty_params = ty_param list type ty_pctor = [ `Int of int | `Named of ty_params ] @@ -20,7 +20,7 @@ type tydecl = { and ty_body = [ | `Concrete of EcTypes.ty - | `Abstract of Sp.t + | `Abstract of typeclass list | `Datatype of ty_dtype | `Record of form * (EcSymbols.symbol * EcTypes.ty) list ] @@ -32,13 +32,15 @@ and ty_dtype = { } val tydecl_as_concrete : tydecl -> EcTypes.ty option -val tydecl_as_abstract : tydecl -> Sp.t option +val tydecl_as_abstract : tydecl -> typeclass list option val tydecl_as_datatype : tydecl -> ty_dtype option val tydecl_as_record : tydecl -> (form * (EcSymbols.symbol * EcTypes.ty) list) option -val abs_tydecl : ?resolve:bool -> ?tc:Sp.t -> ?params:ty_pctor -> locality -> tydecl +val abs_tydecl : ?resolve:bool -> ?tc:typeclass list -> ?params:ty_pctor -> locality -> tydecl -val ty_instanciate : ty_params -> ty list -> ty -> ty +val etyargs_of_tparams : ty_params -> etyarg list + +val ty_instanciate : ty_params -> etyarg list -> ty -> ty (* -------------------------------------------------------------------- *) type locals = EcIdent.t list @@ -54,7 +56,7 @@ and opbody = | OP_Record of EcPath.path | OP_Proj of EcPath.path * int * int | OP_Fix of opfix - | OP_TC + | OP_TC of EcPath.path * string and prbody = | PR_Plain of form @@ -112,6 +114,7 @@ val is_oper : operator -> bool val is_ctor : operator -> bool val is_proj : operator -> bool val is_rcrd : operator -> bool +val is_tc_op : operator -> bool val is_fix : operator -> bool val is_abbrev : operator -> bool val is_prind : operator -> bool @@ -130,6 +133,7 @@ val operator_as_rcrd : operator -> EcPath.path val operator_as_proj : operator -> EcPath.path * int * int val operator_as_fix : operator -> opfix val operator_as_prind : operator -> prind +val operator_as_tc : operator -> EcPath.path * string (* -------------------------------------------------------------------- *) type axiom_kind = [`Axiom of (Ssym.t * bool) | `Lemma] @@ -149,20 +153,12 @@ val is_axiom : axiom_kind -> bool val is_lemma : axiom_kind -> bool (* -------------------------------------------------------------------- *) -val axiomatized_op : - ?nargs: int - -> ?nosmt:bool - -> EcPath.path - -> (ty_params * form) - -> locality - -> axiom - -(* -------------------------------------------------------------------- *) -type typeclass = { - tc_prt : EcPath.path option; - tc_ops : (EcIdent.t * EcTypes.ty) list; - tc_axs : (EcSymbols.symbol * form) list; - tc_loca: is_local; +type tc_decl = { + tc_tparams : ty_params; + tc_prt : typeclass option; + tc_ops : (EcIdent.t * EcTypes.ty) list; + tc_axs : (EcSymbols.symbol * EcCoreFol.form) list; + tc_loca : is_local; } (* -------------------------------------------------------------------- *) diff --git a/src/ecEnv.ml b/src/ecEnv.ml index 34662c4b04..346d138535 100644 --- a/src/ecEnv.ml +++ b/src/ecEnv.ml @@ -18,8 +18,8 @@ module Msym = EcSymbols.Msym module Mp = EcPath.Mp module Sid = EcIdent.Sid module Mid = EcIdent.Mid -module TC = EcTypeClass module Mint = EcMaps.Mint +module Mstr = EcMaps.Mstr (* -------------------------------------------------------------------- *) type 'a suspension = { @@ -89,7 +89,8 @@ type mc = { mc_operators : (ipath * EcDecl.operator) MMsym.t; mc_axioms : (ipath * EcDecl.axiom) MMsym.t; mc_theories : (ipath * ctheory) MMsym.t; - mc_typeclasses: (ipath * typeclass) MMsym.t; + mc_typeclasses: (ipath * tc_decl) MMsym.t; + mc_tcinstances: (ipath * tcinstance) MMsym.t; mc_rwbase : (ipath * path) MMsym.t; mc_components : ipath MMsym.t; } @@ -178,8 +179,7 @@ type preenv = { env_memories : EcMemory.memtype Mmem.t; env_actmem : EcMemory.memory option; env_abs_st : EcModules.abs_uses Mid.t; - env_tci : ((ty_params * ty) * tcinstance) list; - env_tc : TC.graph; + env_tci : (path option * tcinstance) list; env_rwbase : Sp.t Mip.t; env_atbase : (path list Mint.t) Msym.t; env_redbase : mredinfo; @@ -205,12 +205,6 @@ and scope = [ | `Fun of EcPath.xpath ] -and tcinstance = [ - | `Ring of EcDecl.ring - | `Field of EcDecl.field - | `General of EcPath.path -] - and redinfo = { ri_priomap : (EcTheory.rule list) Mint.t; ri_list : (EcTheory.rule list) Lazy.t; } @@ -272,6 +266,7 @@ let empty_mc params = { mc_variables = MMsym.empty; mc_functions = MMsym.empty; mc_typeclasses= MMsym.empty; + mc_tcinstances= MMsym.empty; mc_rwbase = MMsym.empty; mc_components = MMsym.empty; } @@ -303,7 +298,6 @@ let empty gstate = env_actmem = None; env_abs_st = Mid.empty; env_tci = []; - env_tc = TC.Graph.empty; env_rwbase = Mip.empty; env_atbase = Msym.empty; env_redbase = Mrd.empty; @@ -501,12 +495,13 @@ module MC = struct | IPIdent _ -> assert false | IPPath p -> p - let _downpath_for_tydecl = _downpath_for_th - let _downpath_for_modsig = _downpath_for_th - let _downpath_for_operator = _downpath_for_th - let _downpath_for_axiom = _downpath_for_th - let _downpath_for_typeclass = _downpath_for_th - let _downpath_for_rwbase = _downpath_for_th + let _downpath_for_tydecl = _downpath_for_th + let _downpath_for_modsig = _downpath_for_th + let _downpath_for_operator = _downpath_for_th + let _downpath_for_axiom = _downpath_for_th + let _downpath_for_typeclass = _downpath_for_th + let _downpath_for_tcinstance = _downpath_for_th + let _downpath_for_rwbase = _downpath_for_th (* ------------------------------------------------------------------ *) let _params_of_path p env = @@ -899,10 +894,12 @@ module MC = struct let on1 (opid, optype) = let opname = EcIdent.name opid in let optype = EcSubst.subst_ty tsubst optype in - let opdecl = - mk_op ~opaque:optransparent [(self, Sp.singleton mypath)] - optype (Some OP_TC) loca - in (opid, xpath opname, optype, opdecl) + let tcargs = etyargs_of_tparams tc.tc_tparams in + let opargs = (self, [{tc_name = mypath; tc_args = tcargs;}]) in + let opargs = tc.tc_tparams @ [opargs] in + let opdecl = OP_TC (mypath, opname) in + let opdecl = mk_op ~opaque:optransparent opargs optype (Some opdecl) loca in + (opid, xpath opname, optype, opdecl) in List.map on1 tc.tc_ops in @@ -919,8 +916,11 @@ module MC = struct let axioms = List.map (fun (x, ax) -> + let tcargs = etyargs_of_tparams tc.tc_tparams in + let axargs = (self, [{tc_name = mypath; tc_args = tcargs}]) in + let axargs = tc.tc_tparams @ [axargs] in let ax = EcSubst.subst_form fsubst ax in - (x, { ax_tparams = [(self, Sp.singleton mypath)]; + (x, { ax_tparams = axargs; ax_spec = ax; ax_kind = `Lemma; ax_loca = loca; @@ -944,6 +944,20 @@ module MC = struct let import_typeclass p ax env = import (_up_typeclass true) (IPPath p) ax env + (* -------------------------------------------------------------------- *) + let lookup_tcinstance qnx env = + match lookup (fun mc -> mc.mc_tcinstances) qnx env with + | None -> lookup_error (`QSymbol qnx) + | Some (p, (args, obj)) -> (_downpath_for_tcinstance env p args, obj) + + let _up_tcinstance candup mc x obj= + if not candup && MMsym.last x mc.mc_tcinstances <> None then + raise (DuplicatedBinding x); + { mc with mc_tcinstances = MMsym.add x obj mc.mc_tcinstances } + + let import_tcinstance p tci env = + import (_up_tcinstance true) (IPPath p) tci env + (* -------------------------------------------------------------------- *) let lookup_rwbase qnx env = match lookup (fun mc -> mc.mc_rwbase) qnx env with @@ -1099,11 +1113,17 @@ module MC = struct | Th_typeclass (x, tc) -> (add2mc _up_typeclass x tc mc, None) + | Th_instance (x, tci) -> + let mc = + x |> Option.fold + ~none:mc + ~some:(fun x -> add2mc _up_tcinstance x tci mc) + in (mc, None) + | Th_baserw (x, _) -> (add2mc _up_rwbase x (expath x) mc, None) - | Th_export _ | Th_addrw _ | Th_instance _ - | Th_auto _ | Th_reduction _ -> + | Th_export _ | Th_addrw _ | Th_auto _ | Th_reduction _ -> (mc, None) in @@ -1182,6 +1202,9 @@ module MC = struct and bind_typeclass x tc env = bind _up_typeclass x tc env + and bind_tcinstance x tci env = + bind _up_tcinstance x tci env + and bind_rwbase x p env = bind _up_rwbase x p env end @@ -1338,7 +1361,7 @@ let gen_all fmc flk ?(check = fun _ _ -> true) ?name (env : env) = (* ------------------------------------------------------------------ *) module TypeClass = struct - type t = typeclass + type t = tc_decl let by_path_opt (p : EcPath.path) (env : env) = omap @@ -1351,47 +1374,77 @@ module TypeClass = struct | Some obj -> obj let add (p : EcPath.path) (env : env) = - let obj = by_path p env in - MC.import_typeclass p obj env + MC.import_typeclass p (by_path p env) env - let rebind name tc env = - let env = MC.bind_typeclass name tc env in - match tc.tc_prt with - | None -> env - | Some prt -> - let myself = EcPath.pqname (root env) name in - { env with env_tc = TC.Graph.add ~src:myself ~dst:prt env.env_tc } + let rebind (name : symbol) (tc : t) (env : env) = + MC.bind_typeclass name tc env - let bind ?(import = import0) name tc env = + let bind ?(import = import0) (name : symbol) (tc : t) (env : env) = let env = if import.im_immediate then rebind name tc env else env in { env with env_item = mkitem import (Th_typeclass (name, tc)) :: env.env_item } - let lookup qname (env : env) = + let lookup (qname : qsymbol) (env : env) = MC.lookup_typeclass qname env - let lookup_opt name env = + let lookup_opt (name : qsymbol) (env : env) = try_lf (fun () -> lookup name env) - let lookup_path name env = + let lookup_path (name : qsymbol) (env : env) = fst (lookup name env) +end - let graph (env : env) = - env.env_tc +(* ------------------------------------------------------------------ *) +module TcInstance = struct + type t = tcinstance - let bind_instance ty cr tci = - (ty, cr) :: tci + let by_path_opt (p : EcPath.path) (env : env) = + omap + check_not_suspended + (MC.by_path (fun mc -> mc.mc_tcinstances) (IPPath p) env) + + let by_path (p : EcPath.path) (env : env) = + match by_path_opt p env with + | None -> lookup_error (`Path p) + | Some obj -> obj + + let add (p : EcPath.path) (env : env) = + MC.import_tcinstance p (by_path p env) env + + let bind_instance (path : path option) (tci : t) (env : _) = + (path, tci) :: env - let add_instance ?(import = import0) ty cr lc env = + let rebind (name : symbol option) (tci : t) (env : env) = let env = - if import.im_immediate then - { env with env_tci = bind_instance ty cr env.env_tci } - else env in + name |> Option.fold ~none:env ~some:(fun name -> + MC.bind_tcinstance name tci env) + in + let path = + Option.map + (fun name -> EcPath.pqname (root env) name) + name + in { env with env_tci = bind_instance path tci env.env_tci } + + let bind ?(import = import0) (name : symbol option) (tci : t) (env : env) = + let env = + if import.im_immediate then rebind name tci env else env in { env with - env_tci = bind_instance ty cr env.env_tci; - env_item = mkitem import (Th_instance (ty, cr, lc)) :: env.env_item; } + env_item = mkitem import (Th_instance (name, tci)) :: env.env_item } + + let lookup qname (env : env) = + MC.lookup_tcinstance qname env + + let lookup_opt (name : qsymbol) (env : env) = + try_lf (fun () -> lookup name env) - let get_instances env = env.env_tci + let lookup_path (name : qsymbol) (env : env) = + fst (lookup name env) + + let get_instances (env : env) = + env.env_tci + + let get_all (env : env) : (path option * t) list = + env.env_tci end (* -------------------------------------------------------------------- *) @@ -2479,7 +2532,7 @@ module Ty = struct let add (p : EcPath.path) (env : env) = let obj = by_path p env in - MC.import_tydecl p obj env + MC.import_tydecl p obj env let lookup qname (env : env) = MC.lookup_tydecl qname env @@ -2495,11 +2548,11 @@ module Ty = struct | Some { tyd_type = `Concrete _ } -> true | _ -> false - let unfold (name : EcPath.path) (args : EcTypes.ty list) (env : env) = + let unfold (name : EcPath.path) (args : etyarg list) (env : env) = match by_path_opt name env with | Some ({ tyd_type = `Concrete body } as tyd) -> Tvar.subst - (Tvar.init (List.map fst tyd.tyd_params) args) + (Tvar.init (List.combine (List.fst tyd.tyd_params) args)) body | _ -> raise (LookupFailure (`Path name)) @@ -2508,13 +2561,11 @@ module Ty = struct | Tconstr (p, tys) when defined p env -> hnorm (unfold p tys env) env | _ -> ty - let rec ty_hnorm (ty : ty) (env : env) = match ty.ty_node with | Tconstr (p, tys) when defined p env -> ty_hnorm (unfold p tys env) env | _ -> ty - let rec decompose_fun (ty : ty) (env : env) : dom * ty = match (hnorm ty env).ty_node with | Tfun (ty1, ty2) -> @@ -2552,30 +2603,14 @@ module Ty = struct | Tconstr (p, tys) -> Some (p, oget (by_path_opt p env), tys) | _ -> None - let rebind name ty env = - let env = MC.bind_tydecl name ty env in - - match ty.tyd_type with - | `Abstract tc -> - let myty = - let myp = EcPath.pqname (root env) name in - let typ = List.map (fst_map EcIdent.fresh) ty.tyd_params in - (typ, EcTypes.tconstr myp (List.map (tvar |- fst) typ)) in - let instr = - Sp.fold - (fun p inst -> TypeClass.bind_instance myty (`General p) inst) - tc env.env_tci - in - { env with env_tci = instr } - - | _ -> env + let rebind (name : symbol) (tyd : t) (env : env) = + MC.bind_tydecl name tyd env let bind ?(import = import0) name ty env = let env = if import.im_immediate then rebind name ty env else env in { env with env_item = mkitem import (Th_type (name, ty)) :: env.env_item } - let iter ?name f (env : env) = gen_iter (fun mc -> mc.mc_tydecls) MC.lookup_tydecls ?name f env @@ -2646,7 +2681,6 @@ module Op = struct let core_reduce ?(mode = `IfTransparent) ?(nargs = 0) env p = let op = oget (by_path_opt p env) in - match op.op_kind with | OB_oper (Some (OP_Plain f)) | OB_pred (Some (PR_Plain f)) -> begin @@ -2674,8 +2708,60 @@ module Op = struct else false let reduce ?mode ?nargs env p tys = - let op, f = core_reduce ?mode ?nargs env p in - Tvar.f_subst ~freshen:true (List.map fst op.op_tparams) tys f + let op, form = core_reduce ?mode ?nargs env p in + Tvar.f_subst ~freshen:true + (List.combine (List.fst op.op_tparams) tys) + form + + let tc_core_reduce (env : env) (p : path) (tys : etyarg list) = + let op = by_path p env in + + if not (is_tc_op op) then + raise NotReducible; + + (* Last type application if the TC parameter. We extract the type-class * + * information from the witness. *) + let _, (_, tcw) = List.betail tys in + + match as_seq1 tcw with + | TCIConcrete { path = tcipath; etyargs = tciargs; } -> begin + let tci = TcInstance.by_path tcipath env in + + match tci.tci_instance with + | `General (_, Some symbols) -> + (EcDecl.operator_as_tc op, (tciargs, (tci.tci_params, symbols))) + + | _ -> raise NotReducible + end + + | _ -> + raise NotReducible + + let tc_reducible (env : env) (p : path) (tys : etyarg list) = + try + ignore (tc_core_reduce env p tys); + true + with NotReducible -> false + + let tc_reduce (env : env) (p : path) (tys : etyarg list) = + let ((_, opname), (tciargs, (tciparams, symbols))) = + tc_core_reduce env p tys in + + let subst = + List.fold_left + (fun subst (a, ety) -> + let ety = EcSubst.subst_etyarg subst ety in + EcSubst.add_tyvar subst a ety) + EcSubst.empty + (List.combine (List.fst tciparams) tciargs) + in + + let optg, opargs = EcMaps.Mstr.find opname symbols in + let opargs = List.map (EcSubst.subst_etyarg subst) opargs in + let optg_decl = by_path optg env in + let tysubst = Tvar.init (List.combine (List.fst optg_decl.op_tparams) opargs) in + + f_op_tc optg opargs (Tvar.subst tysubst optg_decl.op_ty) let is_projection env p = try EcDecl.is_proj (by_path p env) @@ -2685,6 +2771,10 @@ module Op = struct try EcDecl.is_rcrd (by_path p env) with LookupFailure _ -> false + let is_tc_op env p = + try EcDecl.is_tc_op (by_path p env) + with LookupFailure _ -> false + let is_dtype_ctor ?nargs env p = try match (by_path p env).op_kind with @@ -2768,7 +2858,7 @@ module Ax = struct let instanciate p tys env = match by_path_opt p env with | Some ({ ax_spec = f } as ax) -> - Tvar.f_subst ~freshen:true (List.map fst ax.ax_tparams) tys f + Tvar.f_subst ~freshen:true (List.combine (List.map fst ax.ax_tparams) tys) f | _ -> raise (LookupFailure (`Path p)) let iter ?name f (env : env) = @@ -2778,22 +2868,6 @@ module Ax = struct gen_all (fun mc -> mc.mc_axioms) MC.lookup_axioms ?check ?name env end -(* -------------------------------------------------------------------- *) -module Algebra = struct - let bind_ring ty cr env = - assert (Mid.is_empty ty.ty_fv); - { env with env_tci = - TypeClass.bind_instance ([], ty) (`Ring cr) env.env_tci } - - let bind_field ty cr env = - assert (Mid.is_empty ty.ty_fv); - { env with env_tci = - TypeClass.bind_instance ([], ty) (`Field cr) env.env_tci } - - let add_ring ty cr lc env = TypeClass.add_instance ([], ty) (`Ring cr) lc env - let add_field ty cr lc env = TypeClass.add_instance ([], ty) (`Field cr) lc env -end - (* -------------------------------------------------------------------- *) module Theory = struct type t = ctheory @@ -2859,26 +2933,12 @@ module Theory = struct let xpath x = EcPath.pqname path x in match item.ti_item with - | Th_instance (ty, k, _) -> - TypeClass.bind_instance ty k inst + | Th_instance (name, tci) -> + TcInstance.bind_instance (Option.map xpath name) tci inst | Th_theory (x, cth) when cth.cth_mode = `Concrete -> bind_instance_th (xpath x) inst cth.cth_items - | Th_type (x, tyd) -> begin - match tyd.tyd_type with - | `Abstract tc -> - let myty = - let typ = List.map (fst_map EcIdent.fresh) tyd.tyd_params in - (typ, EcTypes.tconstr (xpath x) (List.map (tvar |- fst) typ)) - in - Sp.fold - (fun p inst -> TypeClass.bind_instance myty (`General p) inst) - tc inst - - | _ -> inst - end - | _ -> inst (* ------------------------------------------------------------------ *) @@ -2901,11 +2961,10 @@ module Theory = struct (* ------------------------------------------------------------------ *) let bind_tc_th = - let for1 path base = function - | Th_typeclass (x, tc) -> - tc.tc_prt |> omap (fun prt -> - let src = EcPath.pqname path x in - TC.Graph.add ~src ~dst:prt base) + let for1 _path base = function + | Th_typeclass (_, tc) -> + Some (tc :: base) + | _ -> None in bind_base_th for1 @@ -2982,15 +3041,12 @@ module Theory = struct | _, `Concrete -> let thname = EcPath.pqname (root env) cth.name in let env_tci = bind_instance_th thname env.env_tci items in - let env_tc = bind_tc_th thname env.env_tc items in let env_rwbase = bind_br_th thname env.env_rwbase items in let env_atbase = bind_at_th thname env.env_atbase items in let env_ntbase = bind_nt_th thname env.env_ntbase items in let env_redbase = bind_rd_th thname env.env_redbase items in let env = - { env with - env_tci ; env_tc ; env_rwbase; - env_atbase; env_ntbase; env_redbase; } + { env with env_tci; env_rwbase; env_atbase; env_ntbase; env_redbase; } in add_restr_th thname env items @@ -3198,7 +3254,6 @@ module Theory = struct | `Concrete -> { env with env_tci = bind_instance_th thpath env.env_tci cth.cth_items; - env_tc = bind_tc_th thpath env.env_tc cth.cth_items; env_rwbase = bind_br_th thpath env.env_rwbase cth.cth_items; env_atbase = bind_at_th thpath env.env_atbase cth.cth_items; env_ntbase = bind_nt_th thpath env.env_ntbase cth.cth_items; diff --git a/src/ecEnv.mli b/src/ecEnv.mli index 0e8712174a..a6c06eb484 100644 --- a/src/ecEnv.mli +++ b/src/ecEnv.mli @@ -166,7 +166,7 @@ module Ax : sig val iter : ?name:qsymbol -> (path -> t -> unit) -> env -> unit val all : ?check:(path -> t -> bool) -> ?name:qsymbol -> env -> (path * t) list - val instanciate : path -> EcTypes.ty list -> env -> form + val instanciate : path -> etyarg list -> env -> form end (* -------------------------------------------------------------------- *) @@ -311,11 +311,15 @@ module Op : sig val bind : ?import:import -> symbol -> operator -> env -> env val reducible : ?mode:redmode -> ?nargs:int -> env -> path -> bool - val reduce : ?mode:redmode -> ?nargs:int -> env -> path -> ty list -> form + val reduce : ?mode:redmode -> ?nargs:int -> env -> path -> etyarg list -> form + + val tc_reducible : env -> path -> etyarg list -> bool + val tc_reduce : env -> path -> etyarg list -> form val is_projection : env -> path -> bool val is_record_ctor : env -> path -> bool val is_dtype_ctor : ?nargs:int -> env -> path -> bool + val is_tc_op : env -> path -> bool val is_fix_def : env -> path -> bool val is_abbrev : env -> path -> bool val is_prind : env -> path -> bool @@ -345,16 +349,15 @@ module Ty : sig val bind : ?import:import -> symbol -> t -> env -> env val defined : path -> env -> bool - val unfold : path -> EcTypes.ty list -> env -> EcTypes.ty - val hnorm : EcTypes.ty -> env -> EcTypes.ty - val decompose_fun : EcTypes.ty -> env -> EcTypes.dom * EcTypes.ty + val unfold : path -> etyarg list -> env -> ty + val hnorm : ty -> env -> ty + val decompose_fun : ty -> env -> EcTypes.dom * ty val get_top_decl : - EcTypes.ty -> env -> (path * EcDecl.tydecl * EcTypes.ty list) option - + EcTypes.ty -> env -> (path * EcDecl.tydecl * etyarg list) option val scheme_of_ty : - [`Ind | `Case] -> EcTypes.ty -> env -> (path * EcTypes.ty list) option + [`Ind | `Case] -> EcTypes.ty -> env -> (path * etyarg list) option val signature : env -> ty -> ty list * ty @@ -365,18 +368,25 @@ end val ty_hnorm : ty -> env -> ty (* -------------------------------------------------------------------- *) -module Algebra : sig - val add_ring : ty -> EcDecl.ring -> is_local -> env -> env - val add_field : ty -> EcDecl.field -> is_local -> env -> env +module TypeClass : sig + type t = tc_decl + + val add : path -> env -> env + val bind : ?import:import -> symbol -> t -> env -> env + + val by_path : path -> env -> t + val by_path_opt : path -> env -> t option + val lookup : qsymbol -> env -> path * t + val lookup_opt : qsymbol -> env -> (path * t) option + val lookup_path : qsymbol -> env -> path end (* -------------------------------------------------------------------- *) -module TypeClass : sig - type t = typeclass +module TcInstance : sig + type t = tcinstance - val add : path -> env -> env - val bind : ?import:import -> symbol -> t -> env -> env - val graph : env -> EcTypeClass.graph + val add : path -> env -> env + val bind : ?import:import -> symbol option -> t -> env -> env val by_path : path -> env -> t val by_path_opt : path -> env -> t option @@ -384,8 +394,7 @@ module TypeClass : sig val lookup_opt : qsymbol -> env -> (path * t) option val lookup_path : qsymbol -> env -> path - val add_instance : ?import:import -> (ty_params * ty) -> tcinstance -> is_local -> env -> env - val get_instances : env -> ((ty_params * ty) * tcinstance) list + val get_all : env -> (path option * t) list end (* -------------------------------------------------------------------- *) diff --git a/src/ecFol.ml b/src/ecFol.ml index e5d6eb2c2c..fe5960da45 100644 --- a/src/ecFol.ml +++ b/src/ecFol.ml @@ -179,8 +179,7 @@ let f_mu_x f1 f2 = let proj_distr_ty env ty = match (EcEnv.Ty.hnorm ty env).ty_node with - | Tconstr(_,lty) when List.length lty = 1 -> - List.hd lty + | Tconstr(_, [lty, []]) -> lty | _ -> assert false let f_mu env f1 f2 = @@ -842,7 +841,7 @@ type sform = | SFimp of form * form | SFiff of form * form | SFeq of form * form - | SFop of (EcPath.path * ty list) * (form list) + | SFop of (EcPath.path * etyarg list) * (form list) | SFhoareF of sHoareF | SFhoareS of sHoareS diff --git a/src/ecFol.mli b/src/ecFol.mli index 403224fe8d..9b21c566e7 100644 --- a/src/ecFol.mli +++ b/src/ecFol.mli @@ -212,7 +212,7 @@ type sform = | SFimp of form * form | SFiff of form * form | SFeq of form * form - | SFop of (path * ty list) * (form list) + | SFop of (path * etyarg list) * (form list) | SFhoareF of sHoareF | SFhoareS of sHoareS diff --git a/src/ecHiGoal.ml b/src/ecHiGoal.ml index 52f091d732..bc14d47d21 100644 --- a/src/ecHiGoal.ml +++ b/src/ecHiGoal.ml @@ -114,15 +114,16 @@ let process_simplify_info ri (tc : tcenv1) = in { - EcReduction.beta = ri.pbeta; - EcReduction.delta_p = delta_p; - EcReduction.delta_h = delta_h; - EcReduction.zeta = ri.pzeta; - EcReduction.iota = ri.piota; - EcReduction.eta = ri.peta; - EcReduction.logic = if ri.plogic then Some `Full else None; - EcReduction.modpath = ri.pmodpath; - EcReduction.user = ri.puser; + EcReduction.beta = ri.pbeta; + EcReduction.delta_p = delta_p; + EcReduction.delta_h = delta_h; + EcReduction.delta_tc = ri.pdeltatc; + EcReduction.zeta = ri.pzeta; + EcReduction.iota = ri.piota; + EcReduction.eta = ri.peta; + EcReduction.logic = if ri.plogic then Some `Full else None; + EcReduction.modpath = ri.pmodpath; + EcReduction.user = ri.puser; } (*-------------------------------------------------------------------- *) @@ -649,8 +650,10 @@ let process_delta ~und_delta ?target (s, o, p) tc = in - let ri = { EcReduction.full_red with - delta_p = (fun p -> if Some p = dp then `Force else `IfTransparent)} in + let ri = + let delta_p p = + if Some p = dp then `Force else `IfTransparent + in { EcReduction.full_red with delta_p } in let na = List.length args in match s with @@ -688,8 +691,12 @@ let process_delta ~und_delta ?target (s, o, p) tc = match sform_of_form fp with | SFop ((_, tvi), []) -> begin (* FIXME: TC HOOK *) - let body = Tvar.f_subst ~freshen:true (List.map fst tparams) tvi body in - let body = f_app body args topfp.f_ty in + let body = + Tvar.f_subst + ~freshen:true + (List.combine (List.map fst tparams) tvi) + body in + let body = f_app body args topfp.f_ty in try EcReduction.h_red EcReduction.beta_red hyps body with EcEnv.NotReducible -> body end @@ -711,8 +718,13 @@ let process_delta ~und_delta ?target (s, o, p) tc = | `RtoL -> let fp = (* FIXME: TC HOOK *) - let body = Tvar.f_subst ~freshen:true (List.map fst tparams) tvi body in - let fp = f_app body args p.f_ty in + let body = + Tvar.f_subst + ~freshen:true + (List.combine (List.map fst tparams) tvi) + body + in + let fp = f_app body args p.f_ty in try EcReduction.h_red EcReduction.beta_red hyps fp with EcEnv.NotReducible -> fp in @@ -1426,7 +1438,10 @@ let rec process_mintros_1 ?(cf = true) ttenv pis gs = end in - let tc = t_ors [t_elimT_ind `Case; t_elim; t_elim_prind `Case] in + let tc = t_ors [ + t_elimT_ind ~reduce:`Full `Case; + t_elim ~reduce:`Full; + t_elim_prind ~reduce:`Full `Case] in let tc = fun g -> try tc g @@ -2034,7 +2049,11 @@ let process_split (tc : tcenv1) = let process_elim (pe, qs) tc = let doelim tc = match qs with - | None -> t_or (t_elimT_ind `Ind) t_elim tc + | None -> + t_or + (t_elimT_ind ~reduce:`Full `Ind) + (t_elim ~reduce:`Full) + tc | Some qs -> let qs = { fp_mode = `Implicit; @@ -2080,7 +2099,10 @@ let process_case ?(doeq = false) gp tc = with E.LEMFailure -> try FApi.t_last - (t_ors [t_elimT_ind `Case; t_elim; t_elim_prind `Case]) + (t_ors [ + t_elimT_ind ~reduce:`Full `Case; + t_elim ~reduce:`Full; + t_elim_prind ~reduce:`Full `Case]) (process_move ~doeq gp.pr_view gp.pr_rev tc) with EcCoreGoal.InvalidGoalShape -> diff --git a/src/ecHiInductive.ml b/src/ecHiInductive.ml index bef40e9497..73cbe0f8bf 100644 --- a/src/ecHiInductive.ml +++ b/src/ecHiInductive.ml @@ -84,7 +84,7 @@ let trans_datatype (env : EcEnv.env) (name : ptydname) (dt : pdatatype) = let env0 = let myself = { tyd_params = EcUnify.UniEnv.tparams ue; - tyd_type = `Abstract EcPath.Sp.empty; + tyd_type = `Abstract []; tyd_loca = lc; tyd_resolve = true; } in @@ -137,7 +137,7 @@ let trans_datatype (env : EcEnv.env) (name : ptydname) (dt : pdatatype) = match tdecl.tyd_type with | `Abstract _ -> - List.exists isempty (targs) + List.exists isempty (List.fst targs) (* FIXME:TC *) | `Concrete ty -> isempty_1 [tyinst () ty] @@ -315,8 +315,8 @@ let trans_matchfix EcUnify.UniEnv.restore ~src:subue ~dst:ue; let ctorty = - let tvi = Some (EcUnify.TVIunamed tvi) in - fst (EcUnify.UniEnv.opentys ue indty.tyd_params tvi ctorty) in + let tvi = Some (EcUnify.tvi_unamed tvi) in + fst (EcUnify.UniEnv.opentys ue indty.tyd_params tvi ctorty) in let pty = EcUnify.UniEnv.fresh ue in (try EcUnify.unify env ue (toarrow ctorty pty) opty diff --git a/src/ecHiNotations.ml b/src/ecHiNotations.ml index ea8959d97c..79c11df3fe 100644 --- a/src/ecHiNotations.ml +++ b/src/ecHiNotations.ml @@ -12,7 +12,7 @@ module TT = EcTyping (* -------------------------------------------------------------------- *) type nterror = | NTE_Typing of EcTyping.tyerror -| NTE_TyNotClosed +| NTE_TyNotClosed of EcUnify.uniflags | NTE_DupIdent | NTE_UnknownBinder of symbol | NTE_AbbrevIsVar @@ -62,8 +62,8 @@ let trans_notation_r (env : env) (nt : pnotation located) = let codom = TT.transty TT.tp_relax env ue nt.nt_codom in let body = TT.transexpcast benv `InOp ue codom nt.nt_body in - if not (EcUnify.UniEnv.closed ue) then - nterror gloc env NTE_TyNotClosed; + Option.iter (fun infos -> nterror gloc env (NTE_TyNotClosed infos)) + @@ EcUnify.UniEnv.xclosed ue; ignore body; () @@ -80,11 +80,11 @@ let trans_abbrev_r (env : env) (at : pabbrev located) = let codom = TT.transty TT.tp_relax env ue (fst at.ab_def) in let body = TT.transexpcast benv `InOp ue codom (snd at.ab_def) in - if not (EcUnify.UniEnv.closed ue) then - nterror gloc env NTE_TyNotClosed; + Option.iter (fun infos -> nterror gloc env (NTE_TyNotClosed infos)) + @@ EcUnify.UniEnv.xclosed ue; - let ts = Tuni.subst (EcUnify.UniEnv.close ue) in - let es = e_subst ts in + let ts = Tuni.subst (EcUnify.UniEnv.close ue) in + let es = e_subst ts in let body = es body in let codom = ty_subst ts codom in let xs = List.map (snd_map (ty_subst ts)) xs in diff --git a/src/ecHiNotations.mli b/src/ecHiNotations.mli index 54dd54543e..53aa868c15 100644 --- a/src/ecHiNotations.mli +++ b/src/ecHiNotations.mli @@ -8,7 +8,7 @@ open EcEnv (* -------------------------------------------------------------------- *) type nterror = | NTE_Typing of EcTyping.tyerror -| NTE_TyNotClosed +| NTE_TyNotClosed of EcUnify.uniflags | NTE_DupIdent | NTE_UnknownBinder of symbol | NTE_AbbrevIsVar diff --git a/src/ecHiPredicates.ml b/src/ecHiPredicates.ml index 49e725ad58..e8f6143ced 100644 --- a/src/ecHiPredicates.ml +++ b/src/ecHiPredicates.ml @@ -2,7 +2,6 @@ open EcUtils open EcSymbols open EcLocation -open EcTypes open EcCoreSubst open EcParsetree open EcDecl @@ -11,8 +10,8 @@ module TT = EcTyping (* -------------------------------------------------------------------- *) type tperror = -| TPE_Typing of EcTyping.tyerror -| TPE_TyNotClosed +| TPE_Typing of EcTyping.tyerror +| TPE_TyNotClosed of EcUnify.uniflags | TPE_DuplicatedConstr of symbol exception TransPredError of EcLocation.t * EcEnv.env * tperror @@ -20,8 +19,8 @@ exception TransPredError of EcLocation.t * EcEnv.env * tperror let tperror loc env e = raise (TransPredError (loc, env, e)) (* -------------------------------------------------------------------- *) -let close_pr_body (uni : ty EcUid.Muid.t) (body : prbody) = - let fsubst = EcFol.Fsubst.f_subst_init ~tu:uni () in +let close_pr_body (uidmap : unisubst) (body : prbody) = + let fsubst = EcFol.Fsubst.f_subst_init ~tu:uidmap () in let tsubst = ty_subst fsubst in match body with @@ -74,13 +73,13 @@ let trans_preddecl_r (env : EcEnv.env) (pr : ppredicate located) = in - if not (EcUnify.UniEnv.closed ue) then - tperror loc env TPE_TyNotClosed; + Option.iter + (fun infos -> tperror loc env (TPE_TyNotClosed infos)) + (EcUnify.UniEnv.xclosed ue); - let uidmap = EcUnify.UniEnv.assubst ue in + let uidmap = EcUnify.UniEnv.assubst ue in let tparams = EcUnify.UniEnv.tparams ue in let body = body |> omap (close_pr_body uidmap) in - let dom = Tuni.subst_dom uidmap dom in EcDecl.mk_pred ~opaque:optransparent tparams dom body pr.pp_locality diff --git a/src/ecHiPredicates.mli b/src/ecHiPredicates.mli index eb56da6628..f411802cce 100644 --- a/src/ecHiPredicates.mli +++ b/src/ecHiPredicates.mli @@ -5,8 +5,8 @@ open EcParsetree (* -------------------------------------------------------------------- *) type tperror = -| TPE_Typing of EcTyping.tyerror -| TPE_TyNotClosed +| TPE_Typing of EcTyping.tyerror +| TPE_TyNotClosed of EcUnify.uniflags | TPE_DuplicatedConstr of symbol exception TransPredError of EcLocation.t * EcEnv.env * tperror diff --git a/src/ecIdent.ml b/src/ecIdent.ml index 60ab346526..3b2e29a0a3 100644 --- a/src/ecIdent.ml +++ b/src/ecIdent.ml @@ -57,3 +57,4 @@ let tostring (id : t) = (* -------------------------------------------------------------------- *) let pp_ident fmt id = Format.fprintf fmt "%s" (name id) +let pp = pp_ident diff --git a/src/ecIdent.mli b/src/ecIdent.mli index 988430a72e..2c3d5d6046 100644 --- a/src/ecIdent.mli +++ b/src/ecIdent.mli @@ -38,3 +38,4 @@ val fv_add : ident -> int Mid.t -> int Mid.t (* -------------------------------------------------------------------- *) val pp_ident : Format.formatter -> t -> unit +val pp : Format.formatter -> t -> unit diff --git a/src/ecInductive.ml b/src/ecInductive.ml index a873688f4d..b20fa72d7a 100644 --- a/src/ecInductive.ml +++ b/src/ecInductive.ml @@ -38,15 +38,15 @@ let datatype_proj_path (p : EP.path) (x : symbol) = (* -------------------------------------------------------------------- *) let indsc_of_record (rc : record) = - let targs = List.map (tvar |- fst) rc.rc_tparams in - let recty = tconstr rc.rc_path targs in + let targs = etyargs_of_tparams rc.rc_tparams in + let recty = tconstr_tc rc.rc_path targs in let recx = fresh_id_of_ty recty in let recfm = FL.f_local recx recty in let predty = tfun recty tbool in let predx = EcIdent.create "P" in let pred = FL.f_local predx predty in let ctor = record_ctor_path rc.rc_path in - let ctor = FL.f_op ctor targs (toarrow (List.map snd rc.rc_fields) recty) in + let ctor = FL.f_op_tc ctor targs (toarrow (List.map snd rc.rc_fields) recty) in let prem = let ids = List.map (fun (_, fty) -> (fresh_id_of_ty fty, fty)) rc.rc_fields in let vars = List.map (fun (x, xty) -> FL.f_local x xty) ids in @@ -104,7 +104,9 @@ let indsc_of_datatype ?normty (mode : indmode) (dt : datatype) = end | Tconstr (p', ts) -> - if List.exists (occurs p) ts then raise NonPositive; + (* FIXME:TC *) + if List.exists (EcTypes.etyarg_sub_exists (occurs p)) ts then + raise NonPositive; if not (EcPath.p_equal p p') then None else Some (FL.f_app pred [fac] tbool) @@ -115,11 +117,11 @@ let indsc_of_datatype ?normty (mode : indmode) (dt : datatype) = |> omap (FL.f_forall [x, GTty ty1]) and schemec mode (targs, p) pred (ctor, tys) = - let indty = tconstr p (List.map tvar targs) in + let indty = tconstr_tc p targs in let xs = List.map (fun xty -> (fresh_id_of_ty xty, xty)) tys in let cargs = List.map (fun (x, xty) -> FL.f_local x xty) xs in let ctor = EcPath.pqoname (EcPath.prefix tpath) ctor in - let ctor = FL.f_op ctor (List.map tvar targs) (toarrow tys indty) in + let ctor = FL.f_op_tc ctor targs (toarrow tys indty) in let form = FL.f_app pred [FL.f_app ctor cargs indty] tbool in let form = match mode with @@ -139,7 +141,7 @@ let indsc_of_datatype ?normty (mode : indmode) (dt : datatype) = form and scheme mode (targs, p) ctors = - let indty = tconstr p (List.map tvar targs) in + let indty = tconstr_tc p targs in let indx = fresh_id_of_ty indty in let indfm = FL.f_local indx indty in let predty = tfun indty tbool in @@ -157,7 +159,7 @@ let indsc_of_datatype ?normty (mode : indmode) (dt : datatype) = | Tconstr (p', _) when EcPath.p_equal p p' -> true | _ -> EcTypes.ty_sub_exists (occurs p) t - in scheme mode (List.map fst dt.dt_tparams, tpath) dt.dt_ctors + in scheme mode (etyargs_of_tparams dt.dt_tparams, tpath) dt.dt_ctors (* -------------------------------------------------------------------- *) let datatype_projectors (tpath, tparams, { tydt_ctors = ctors }) = diff --git a/src/ecLowGoal.ml b/src/ecLowGoal.ml index 003bf07a9a..f959e5b9f1 100644 --- a/src/ecLowGoal.ml +++ b/src/ecLowGoal.ml @@ -383,9 +383,9 @@ let rec t_lazy_match ?(reduce = `Full) (tx : form -> FApi.backward) with TTC.NoMatch -> let strategy = match reduce with - | `None -> raise InvalidGoalShape - | `Full -> EcReduction.full_red - | `NoDelta -> EcReduction.nodelta in + | `None -> raise InvalidGoalShape + | `Full -> EcReduction.full_red + | `NoDelta -> EcReduction.nodelta in FApi.t_seq (t_hred_with_info strategy) (t_lazy_match ~reduce tx) tc (* -------------------------------------------------------------------- *) @@ -712,9 +712,14 @@ let t_apply_hyp (x : EcIdent.t) ?args ?sk tc = let t_hyp (x : EcIdent.t) tc = t_apply_hyp x ~args:[] ~sk:0 tc +(* -------------------------------------------------------------------- *) +let t_apply_s_tc (p : path) (etys : etyarg list) ?args ?sk tc = + tt_apply_s p etys ?args ?sk (FApi.tcenv_of_tcenv1 tc) + (* -------------------------------------------------------------------- *) let t_apply_s (p : path) (tys : ty list) ?args ?sk tc = - tt_apply_s p tys ?args ?sk (FApi.tcenv_of_tcenv1 tc) + let etys = List.map (fun ty -> (ty, [])) tys in + tt_apply_s p etys ?args ?sk (FApi.tcenv_of_tcenv1 tc) (* -------------------------------------------------------------------- *) let t_apply_hd (hd : handle) ?args ?sk tc = @@ -971,7 +976,7 @@ let t_true (tc : tcenv1) = let t_reflex_s (f : form) (tc : tcenv1) = t_apply_s LG.p_eq_refl [f.f_ty] ~args:[f] tc -let t_reflex ?(mode=`Conv) ?reduce (tc : tcenv1) = +let t_reflex ?(mode = `Conv) ?reduce (tc : tcenv1) = let t_reflex_r (fp : form) (tc : tcenv1) = match sform_of_form fp with | SFeq (f1, f2) -> @@ -1133,9 +1138,9 @@ let t_elim_r ?(reduce = (`Full : lazyred)) txs tc = | None -> begin let strategy = match reduce with - | `None -> raise InvalidGoalShape - | `Full -> EcReduction.full_red - | `NoDelta -> EcReduction.nodelta in + | `None -> raise InvalidGoalShape + | `Full -> EcReduction.full_red + | `NoDelta -> EcReduction.nodelta in match h_red_opt strategy (FApi.tc1_hyps tc) f1 with | None -> raise InvalidGoalShape @@ -1470,9 +1475,9 @@ let t_elim_prind_r ?reduce ?accept (_mode : [`Case | `Ind]) tc = end; (oget (EcEnv.Op.scheme_of_prind env `Case p), tv, args) - | _ -> raise InvalidGoalShape + | _ -> raise InvalidGoalShape in - in t_apply_s p tv ~args:(args @ [f2]) ~sk tc + t_apply_s_tc p tv ~args:(args @ [f2]) ~sk tc | _ -> raise TTC.NoMatch @@ -1552,7 +1557,7 @@ let t_split_prind ?reduce (tc : tcenv1) = | None -> raise InvalidGoalShape | Some (x, sk) -> let p = EcInductive.prind_introsc_path p x in - t_apply_s p tv ~args ~sk tc + t_apply_s_tc p tv ~args ~sk tc in t_lazy_match ?reduce t_split_r tc @@ -1572,10 +1577,10 @@ let t_or_intro_prind ?reduce (side : side) (tc : tcenv1) = match EcInductive.prind_is_iso_ors pri with | Some ((x, sk), _) when side = `Left -> let p = EcInductive.prind_introsc_path p x in - t_apply_s p tv ~args ~sk tc + t_apply_s_tc p tv ~args ~sk tc | Some (_, (x, sk)) when side = `Right -> let p = EcInductive.prind_introsc_path p x in - t_apply_s p tv ~args ~sk tc + t_apply_s_tc p tv ~args ~sk tc | _ -> raise InvalidGoalShape in t_lazy_match ?reduce t_split_r tc @@ -2175,8 +2180,7 @@ let t_progress ?options ?ti (tt : FApi.backward) (tc : tcenv1) = else elims in - let reduce = - if options.pgo_delta.pgod_case then `Full else `NoDelta in + let reduce = if options.pgo_delta.pgod_case then `Full else `NoDelta in FApi.t_switch ~on:`All (t_elim_r ~reduce elims) ~ifok:aux0 ~iffail tc end @@ -2197,7 +2201,6 @@ let t_progress ?options ?ti (tt : FApi.backward) (tc : tcenv1) = in entry tc (* -------------------------------------------------------------------- *) - let pp_tc tc = let pr = proofenv_of_proof (proof_of_tcenv tc) in let cl = List.map (FApi.get_pregoal_by_id^~ pr) (FApi.tc_opened tc) in diff --git a/src/ecLowGoal.mli b/src/ecLowGoal.mli index 093f144240..7b6cf8c621 100644 --- a/src/ecLowGoal.mli +++ b/src/ecLowGoal.mli @@ -18,7 +18,6 @@ exception InvalidProofTerm (* invalid proof term *) type side = [`Left|`Right] type lazyred = [`Full | `NoDelta | `None] - (* -------------------------------------------------------------------- *) val (@!) : FApi.backward -> FApi.backward -> FApi.backward val (@+) : FApi.backward -> FApi.backward list -> FApi.backward @@ -113,6 +112,8 @@ val t_apply : ?cutsolver:cutsolver -> proofterm -> FApi.backward * skip before applying [p]. *) val t_apply_s : path -> ty list -> ?args:(form list) -> ?sk:int -> FApi.backward +val t_apply_s_tc : path -> etyarg list -> ?args:(form list) -> ?sk:int -> FApi.backward + (* Apply a proof term of the form [h f1...fp _ ... _] constructed from * the local hypothesis and formulas given to the function. The [int] * argument gives the number of premises to skip before applying @@ -189,7 +190,7 @@ val t_elim_iso_or : ?reduce:lazyred -> tcenv1 -> int list * tcenv (* Elimination using an custom elimination principle. *) val t_elimT_form : proofterm -> ?sk:int -> form -> FApi.backward -val t_elimT_form_global : path -> ?typ:(ty list) -> ?sk:int -> form -> FApi.backward +val t_elimT_form_global : path -> ?typ:(etyarg list) -> ?sk:int -> form -> FApi.backward (* Eliminiation using an elimation principle of an induction type *) val t_elimT_ind : ?reduce:lazyred -> [ `Case | `Ind ] -> FApi.backward diff --git a/src/ecMatching.ml b/src/ecMatching.ml index fd7f256330..6a3043cb22 100644 --- a/src/ecMatching.ml +++ b/src/ecMatching.ml @@ -639,6 +639,8 @@ let f_match_core opts hyps (ue, ev) f1 f2 = | Fop (op1, tys1), Fop (op2, tys2) -> begin if not (EcPath.p_equal op1 op2) then failure (); + let tys1 = List.fst tys1 in (* FIXME:TC *) + let tys2 = List.fst tys2 in (* FIXME:TC *) try List.iter2 (EcUnify.unify env ue) tys1 tys2 with EcUnify.UnificationFailure _ -> failure () end @@ -732,6 +734,12 @@ let f_match_core opts hyps (ue, ev) f1 f2 = | _, (Fop (op2, tys2), args2) when EcEnv.Op.reducible env op2 -> doit_reduce env (doit env ilc f1) f2.f_ty op2 tys2 args2 + | (Fop (op1, tys1), args1), _ when EcEnv.Op.tc_reducible env op1 tys1 -> + doit_tc_reduce env ((doit env ilc)^~ f2) f1.f_ty op1 tys1 args1 + + | _, (Fop (op2, tys2), args2) when EcEnv.Op.tc_reducible env op2 tys2 -> + doit_tc_reduce env (doit env ilc f1) f2.f_ty op2 tys2 args2 + | _, _ -> failure () in @@ -757,6 +765,12 @@ let f_match_core opts hyps (ue, ev) f1 f2 = with NotReducible -> raise MatchFailure in cb (odfl reduced (EcReduction.h_red_opt EcReduction.beta_red hyps reduced)) + and doit_tc_reduce env cb ty op tys args = + let reduced = + try f_app (EcEnv.Op.tc_reduce env op tys) args ty + with NotReducible -> raise MatchFailure in + cb (odfl reduced (EcReduction.h_red_opt EcReduction.beta_red hyps reduced)) + and doit_lreduce _env cb ty x args = let reduced = try f_app (LDecl.unfold x hyps) args ty @@ -841,7 +855,7 @@ let f_match opts hyps (ue, ev) f1 f2 = raise MatchFailure; let clue = try EcUnify.UniEnv.close ue - with EcUnify.UninstanciateUni -> raise MatchFailure + with EcUnify.UninstanciateUni _ -> raise MatchFailure in (ue, clue, ev) diff --git a/src/ecMatching.mli b/src/ecMatching.mli index 9961f1c24e..d1f822f3d7 100644 --- a/src/ecMatching.mli +++ b/src/ecMatching.mli @@ -1,6 +1,5 @@ (* -------------------------------------------------------------------- *) open EcMaps -open EcUid open EcIdent open EcTypes open EcModules @@ -196,7 +195,7 @@ val f_match : -> unienv * mevmap -> form -> form - -> unienv * (ty Muid.t) * mevmap + -> unienv * unisubst * mevmap (* -------------------------------------------------------------------- *) type ptnpos = private [`Select of int | `Sub of ptnpos] Mint.t diff --git a/src/ecPV.ml b/src/ecPV.ml index 2ad9fed2dc..25c09b9e20 100644 --- a/src/ecPV.ml +++ b/src/ecPV.ml @@ -116,7 +116,7 @@ module Mpv = struct let rec esubst env (s : esubst) e = match e.e_node with | Evar pv -> (try find env pv s with Not_found -> e) - | _ -> EcTypes.e_map (fun ty -> ty) (esubst env s) e + | _ -> EcTypes.e_map (esubst env s) e let rec isubst env (s : esubst) (i : instr) = let esubst = esubst env s in @@ -182,30 +182,30 @@ module PVM = struct | FequivF _ -> check_binding EcFol.mleft s; check_binding EcFol.mright s; - EcFol.f_map (fun ty -> ty) aux f + EcFol.f_map aux f | FequivS es -> check_binding (fst es.es_ml) s; check_binding (fst es.es_mr) s; - EcFol.f_map (fun ty -> ty) aux f + EcFol.f_map aux f | FhoareF _ | FbdHoareF _ -> check_binding EcFol.mhr s; - EcFol.f_map (fun ty -> ty) aux f + EcFol.f_map aux f | FhoareS hs -> check_binding (fst hs.hs_m) s; - EcFol.f_map (fun ty -> ty) aux f + EcFol.f_map aux f | FbdHoareS hs -> check_binding (fst hs.bhs_m) s; - EcFol.f_map (fun ty -> ty) aux f + EcFol.f_map aux f | Fpr pr -> check_binding pr.pr_mem s; - EcFol.f_map (fun ty -> ty) aux f + EcFol.f_map aux f | Fquant(q,b,f1) -> let f1 = if has_mod b then subst (Mod.add_mod_binding b env) s f1 else aux f1 in f_quant q b f1 - | _ -> EcFol.f_map (fun ty -> ty) aux f) + | _ -> EcFol.f_map aux f) let subst1 env pv m f = let s = add env pv m f empty in @@ -852,7 +852,7 @@ module Mpv2 = struct when EcIdent.id_equal ml m1 && EcIdent.id_equal mr m2 -> add_glob env (EcPath.mident mp1) (EcPath.mident mp2) eqs | Fop(op1,tys1), Fop(op2,tys2) when EcPath.p_equal op1 op2 && - List.all2 (EcReduction.EqTest.for_type env) tys1 tys2 -> eqs + List.all2 (EcReduction.EqTest.for_etyarg env) tys1 tys2 -> eqs | Fapp(f1,a1), Fapp(f2,a2) -> List.fold_left2 (add_eq local) eqs (f1::a1) (f2::a2) | Ftuple es1, Ftuple es2 -> @@ -951,7 +951,7 @@ module Mpv2 = struct I postpone this for latter *) | Eop(op1,tys1), Eop(op2,tys2) when EcPath.p_equal op1 op2 && - List.all2 (EcReduction.EqTest.for_type env) tys1 tys2 -> eqs + List.all2 (EcReduction.EqTest.for_etyarg env) tys1 tys2 -> eqs | Eapp(f1,a1), Eapp(f2,a2) -> List.fold_left2 (add_eqs_loc env local) eqs (f1::a1) (f2::a2) | Elet(lp1,a1,b1), Elet(lp2,a2,b2) -> diff --git a/src/ecParser.mly b/src/ecParser.mly index c7d89e783c..958294eb30 100644 --- a/src/ecParser.mly +++ b/src/ecParser.mly @@ -71,17 +71,18 @@ let mk_simplify l = if l = [] then - { pbeta = true; pzeta = true; - piota = true; peta = true; - plogic = true; pdelta = None; - pmodpath = true; puser = true; } + { pbeta = true; pzeta = true; + piota = true; peta = true; + plogic = true; pdelta = None; + pdeltatc = true; pmodpath = true; + puser = true; } else let doarg acc = function | `Delta l -> if l = [] || acc.pdelta = None then { acc with pdelta = None } else { acc with pdelta = Some (oget acc.pdelta @ l) } - + | `DeltaTC -> { acc with pdeltatc = true } | `Zeta -> { acc with pzeta = true } | `Iota -> { acc with piota = true } | `Beta -> { acc with pbeta = true } @@ -91,10 +92,11 @@ | `User -> { acc with puser = true } in List.fold_left doarg - { pbeta = false; pzeta = false; - piota = false; peta = false; - plogic = false; pdelta = Some []; - pmodpath = false; puser = false; } l + { pbeta = false; pzeta = false; + piota = false; peta = false; + plogic = false; pdelta = Some []; + pdeltatc = false; pmodpath = false; + puser = false; } l let simplify_red = [`Zeta; `Iota; `Beta; `Eta; `Logic; `ModPath; `User] @@ -1558,6 +1560,7 @@ signature_item: pfd_uses = { pmre_name = x; pmre_orcls = orcls; } } } (* -------------------------------------------------------------------- *) +(* EcTypes declarations / definitions *) %inline locality: | (* empty *) { `Global } | LOCAL { `Local } @@ -1572,12 +1575,13 @@ signature_item: %inline is_local: | lc=loc(locality) { locality_as_local lc } -(* -------------------------------------------------------------------- *) -(* EcTypes declarations / definitions *) +tcparam: +| tys=ioption(type_args) x=lqident + { (x, odfl [] tys) } typaram: | x=tident { (x, []) } -| x=tident LTCOLON tc=plist1(lqident, AMP) { (x, tc) } +| x=tident LTCOLON tc=plist1(tcparam, AMP) { (x, tc) } typarams: | empty { [] } @@ -1605,7 +1609,7 @@ typedecl: | locality=locality TYPE td=rlist1(tyd_name, COMMA) { List.map (fun x -> mk_tydecl ~locality x (PTYD_Abstract [])) td } -| locality=locality TYPE td=tyd_name LTCOLON tcs=rlist1(qident, COMMA) +| locality=locality TYPE td=tyd_name LTCOLON tcs=rlist1(tcparam, AMP) { [mk_tydecl ~locality td (PTYD_Abstract tcs)] } | locality=locality TYPE td=tyd_name EQ te=loc(type_exp) @@ -1620,18 +1624,16 @@ typedecl: (* -------------------------------------------------------------------- *) (* Type classes *) typeclass: -| loca=is_local TYPE CLASS x=lident inth=tc_inth? EQ LBRACE body=tc_body RBRACE { - { ptc_name = x; - ptc_inth = inth; - ptc_ops = fst body; - ptc_axs = snd body; - ptc_loca = loca; - } +| loca=is_local TYPE CLASS tya=tyvars_decl? x=lident inth=prefix(LTCOLON, tcparam)? + EQ LBRACE body=tc_body RBRACE { + { ptc_name = x; + ptc_params = tya; + ptc_inth = inth; + ptc_ops = fst body; + ptc_axs = snd body; + ptc_loca = loca; } } -tc_inth: -| LTCOLON x=lqident { x } - tc_body: | ops=tc_op* axs=tc_ax* { (ops, axs) } @@ -1644,29 +1646,22 @@ tc_ax: (* -------------------------------------------------------------------- *) (* Type classes (instances) *) tycinstance: -| loca=is_local INSTANCE x=qident - WITH typ=tyvars_decl? ty=loc(type_exp) ops=tyci_op* axs=tyci_ax* +| loca=is_local INSTANCE tc=tcparam args=tyci_args? + name=prefix(AS, lident)? WITH typ=tyvars_decl? ty=loc(type_exp) ops=tyci_op* axs=tyci_ax* { - { pti_name = x; + let args = args |> omap (fun (c, p) -> `Ring (c, p)) in + { pti_tc = tc; + pti_name = name; pti_type = (odfl [] typ, ty); pti_ops = ops; pti_axs = axs; - pti_args = None; - pti_loca = loca; - } + pti_args = args; + pti_loca = loca; } } -| loca=is_local INSTANCE x=qident c=uoption(UINT) p=uoption(UINT) - WITH typ=tyvars_decl? ty=loc(type_exp) ops=tyci_op* axs=tyci_ax* - { - { pti_name = x; - pti_type = (odfl [] typ, ty); - pti_ops = ops; - pti_axs = axs; - pti_args = Some (`Ring (c, p)); - pti_loca = loca; - } - } +tyci_args: +| c=uoption(UINT) p=uoption(UINT) + { (c, p) } tyci_op: | OP x=oident EQ tg=qoident @@ -2406,6 +2401,7 @@ genpattern: simplify_arg: | DELTA l=qoident* { `Delta l } +| CLASS { `DeltaTC } | ZETA { `Zeta } | IOTA { `Iota } | BETA { `Beta } diff --git a/src/ecParsetree.ml b/src/ecParsetree.ml index 13f8f8e604..df95ff8366 100644 --- a/src/ecParsetree.ml +++ b/src/ecParsetree.ml @@ -79,18 +79,21 @@ type locality = [`Declare | `Local | `Global] (* -------------------------------------------------------------------- *) type pmodule_type = pqsymbol -type ptyparams = (psymbol * pqsymbol list) list +(* -------------------------------------------------------------------- *) +type ptcparam = pqsymbol * pty list +type ptyparam = psymbol * ptcparam list +type ptyparams = ptyparam list type ptydname = (ptyparams * psymbol) located type ptydecl = { - pty_name : psymbol; - pty_tyvars : ptyparams; - pty_body : ptydbody; + pty_name : psymbol; + pty_tyvars : ptyparams; + pty_body : ptydbody; pty_locality : locality; } and ptydbody = - | PTYD_Abstract of pqsymbol list + | PTYD_Abstract of ptcparam list | PTYD_Alias of pty | PTYD_Record of precord | PTYD_Datatype of pdatatype @@ -104,7 +107,6 @@ type f_or_mod_ident = | FM_FunOrVar of pgamepath | FM_Mod of pmsymbol located - type pmod_restr_mem_el = | PMPlus of f_or_mod_ident | PMMinus of f_or_mod_ident @@ -114,7 +116,7 @@ type pmod_restr_mem_el = type pmod_restr_mem = pmod_restr_mem_el list (* -------------------------------------------------------------------- *) -type pmemory = psymbol +type pmemory = psymbol type phoarecmp = EcFol.hoarecmp @@ -345,9 +347,6 @@ let rec pf_ident ?(raw = false) f = | _ -> None (* -------------------------------------------------------------------- *) -type ptyvardecls = - (psymbol * pqsymbol list) list - type pop_def = | PO_abstr of pty | PO_concr of pty * pformula @@ -369,7 +368,7 @@ type poperator = { po_name : psymbol; po_aliases: psymbol list; po_tags : psymbol list; - po_tyvars : ptyvardecls option; + po_tyvars : ptyparams option; po_args : ptybindings * ptybindings option; po_def : pop_def; po_ax : osymbol_r; @@ -397,7 +396,7 @@ and ppind = ptybindings * (ppind_ctor list) type ppredicate = { pp_name : psymbol; - pp_tyvars : (psymbol * pqsymbol list) list option; + pp_tyvars : ptyparams option; pp_def : ppred_def; pp_locality : locality; } @@ -405,7 +404,7 @@ type ppredicate = { (* -------------------------------------------------------------------- *) type pnotation = { nt_name : psymbol; - nt_tv : ptyvardecls option; + nt_tv : ptyparams option; nt_bd : (psymbol * pty) list; nt_args : (psymbol * (psymbol list * pty option)) list; nt_codom : pty; @@ -419,7 +418,7 @@ type abrvopts = (bool * abrvopt) list type pabbrev = { ab_name : psymbol; - ab_tv : ptyvardecls option; + ab_tv : ptyparams option; ab_args : ptybindings; ab_def : pty * pexpr; ab_opts : abrvopts; @@ -460,6 +459,7 @@ type pmpred_args = (osymbol * pformula) list type preduction = { pbeta : bool; (* β-reduction *) pdelta : pqsymbol list option; (* definition unfolding *) + pdeltatc : bool; pzeta : bool; (* let-reduction *) piota : bool; (* case/if-reduction *) peta : bool; (* η-reduction *) @@ -1027,7 +1027,7 @@ type mempred_binding = PT_MemPred of psymbol list type paxiom = { pa_name : psymbol; pa_pvars : mempred_binding option; - pa_tyvars : (psymbol * pqsymbol list) list option; + pa_tyvars : ptyparams option; pa_vars : pgtybindings option; pa_formula : pformula; pa_kind : paxiom_kind; @@ -1042,16 +1042,18 @@ type prealize = { (* -------------------------------------------------------------------- *) type ptypeclass = { - ptc_name : psymbol; - ptc_inth : pqsymbol option; - ptc_ops : (psymbol * pty) list; - ptc_axs : (psymbol * pformula) list; - ptc_loca : is_local; + ptc_name : psymbol; + ptc_params : ptyparams option; + ptc_inth : ptcparam option; + ptc_ops : (psymbol * pty) list; + ptc_axs : (psymbol * pformula) list; + ptc_loca : is_local; } type ptycinstance = { - pti_name : pqsymbol; - pti_type : (psymbol * pqsymbol list) list * pty; + pti_tc : ptcparam; + pti_name : psymbol option; + pti_type : ptyparams * pty; pti_ops : (psymbol * (pty list * pqsymbol)) list; pti_axs : (psymbol * ptactic_core) list; pti_args : [`Ring of (zint option * zint option)] option; diff --git a/src/ecPath.ml b/src/ecPath.ml index 091840b6bd..97599e8b45 100644 --- a/src/ecPath.ml +++ b/src/ecPath.ml @@ -104,6 +104,9 @@ let rec tostring p = | Psymbol x -> x | Pqname (p,x) -> Printf.sprintf "%s.%s" (tostring p) x +let pp_path fmt p = + Format.fprintf fmt "%s" (tostring p) + let tolist = let rec aux l p = match p.p_node with @@ -394,10 +397,16 @@ let rec m_tostring (m : mpath) = in Printf.sprintf "%s%s%s" top args sub +let pp_mpath fmt p = + Format.fprintf fmt "%s" (m_tostring p) + let x_tostring x = Printf.sprintf "%s./%s" (m_tostring x.x_top) x.x_sub +let pp_xpath fmt x = + Format.fprintf fmt "%s" (x_tostring x) + (* -------------------------------------------------------------------- *) type smsubst = { sms_crt : path Mp.t; diff --git a/src/ecPath.mli b/src/ecPath.mli index ef2d2e8c0f..a34361bc7b 100644 --- a/src/ecPath.mli +++ b/src/ecPath.mli @@ -13,6 +13,8 @@ and path_node = | Psymbol of symbol | Pqname of path * symbol +val pp_path : Format.formatter -> path -> unit + (* -------------------------------------------------------------------- *) val psymbol : symbol -> path val pqname : path -> symbol -> path @@ -62,6 +64,8 @@ and mpath_top = [ | `Local of ident | `Concrete of path * path option ] +val pp_mpath : Format.formatter -> mpath -> unit + (* -------------------------------------------------------------------- *) val mpath : mpath_top -> mpath list -> mpath val mpath_abs : ident -> mpath list -> mpath @@ -96,6 +100,8 @@ type xpath = private { x_tag : int; } +val pp_xpath : Format.formatter -> xpath -> unit + val xpath : mpath -> symbol -> xpath val xastrip : xpath -> xpath diff --git a/src/ecPrinting.ml b/src/ecPrinting.ml index 9b5232c5b7..e7e5c2e965 100644 --- a/src/ecPrinting.ml +++ b/src/ecPrinting.ml @@ -158,25 +158,36 @@ module PPEnv = struct shorten (List.rev nm) ([], x) let ty_symb (ppe : t) p = - let exists sm = - try EcPath.p_equal (EcEnv.Ty.lookup_path sm ppe.ppe_env) p - with EcEnv.LookupFailure _ -> false + let exists sm = + let p1 = Option.map fst (EcEnv.Ty.lookup_opt sm ppe.ppe_env) in + let p2 = Option.map fst (EcEnv.TypeClass.lookup_opt sm ppe.ppe_env) in + + List.exists + (EcPath.p_equal p) + (Option.to_list p1 @ Option.to_list p2) in p_shorten exists p - let tc_symb (ppe : t) p = + let tc_symb (ppe : t) p = let exists sm = try EcPath.p_equal (EcEnv.TypeClass.lookup_path sm ppe.ppe_env) p with EcEnv.LookupFailure _ -> false in p_shorten exists p + let tci_symb (ppe : t) p = + let exists sm = + try EcPath.p_equal (EcEnv.TcInstance.lookup_path sm ppe.ppe_env) p + with EcEnv.LookupFailure _ -> false + in + p_shorten exists p + let rw_symb (ppe : t) p = - let exists sm = - try EcPath.p_equal (EcEnv.BaseRw.lookup_path sm ppe.ppe_env) p - with EcEnv.LookupFailure _ -> false - in - p_shorten exists p + let exists sm = + try EcPath.p_equal (EcEnv.BaseRw.lookup_path sm ppe.ppe_env) p + with EcEnv.LookupFailure _ -> false + in + p_shorten exists p let ax_symb (ppe : t) p = let exists sm = @@ -185,7 +196,7 @@ module PPEnv = struct in p_shorten exists p - let op_symb (ppe : t) p info = + let op_symb (ppe : t) (p : P.path) (info : ([`Expr | `Form] * etyarg list * dom) option) = let specs = [1, EcPath.pqoname (EcPath.prefix EcCoreLib.CI_Bool.p_eq) "<>"] in let check_for_local sm = @@ -199,13 +210,13 @@ module PPEnv = struct check_for_local sm; EcEnv.Op.lookup_path sm ppe.ppe_env - | Some (mode, typ, dom) -> + | Some (mode, ety, dom) -> let filter = match mode with | `Expr -> fun _ op -> not (EcDecl.is_pred op) | `Form -> fun _ _ -> true in - let tvi = Some (EcUnify.TVIunamed typ) in + let tvi = Some (EcUnify.tvi_unamed ety) in fun sm -> check_for_local sm; @@ -326,12 +337,12 @@ module PPEnv = struct let tyvar (ppe : t) x = match Mid.find_opt x ppe.ppe_locals with - | None -> EcIdent.tostring x + | None -> EcIdent.name x | Some x -> x exception FoundUnivarSym of symbol - let tyunivar (ppe : t) i = + let univar (ppe : t) (i : EcUid.uid) = if not (Mint.mem i (fst !(ppe.ppe_univar))) then begin let alpha = "abcdefghijklmnopqrstuvwxyz" in @@ -425,6 +436,14 @@ let pp_paren pp fmt x = let pp_maybe_paren c pp = pp_maybe c pp_paren pp +(* -------------------------------------------------------------------- *) +let pp_bracket pp fmt x = + pp_enclose ~pre:"[" ~post:"]" pp fmt x + +(* -------------------------------------------------------------------- *) +let pp_maybe_bracket c pp = + pp_maybe c pp_bracket pp + (* -------------------------------------------------------------------- *) let pp_string fmt x = Format.fprintf fmt "%s" x @@ -457,8 +476,12 @@ let pp_tyvar ppe fmt x = Format.fprintf fmt "%s" (PPEnv.tyvar ppe x) (* -------------------------------------------------------------------- *) -let pp_tyunivar ppe fmt x = - Format.fprintf fmt "%s" (PPEnv.tyunivar ppe x) +let pp_tyunivar (ppe : PPEnv.t) (fmt : Format.formatter) (a : tyuni) = + Format.fprintf fmt "%s" (PPEnv.univar ppe (a :> EcUid.uid)) + +(* -------------------------------------------------------------------- *) +let pp_tcunivar (ppe : PPEnv.t) (fmt : Format.formatter) (a : tcuni) = + Format.fprintf fmt "%s" (PPEnv.univar ppe (a :> EcUid.uid)) (* -------------------------------------------------------------------- *) let pp_tyname ppe fmt p = @@ -469,6 +492,10 @@ let pp_tcname ppe fmt p = Format.fprintf fmt "%a" EcSymbols.pp_qsymbol (PPEnv.tc_symb ppe p) (* -------------------------------------------------------------------- *) +let pp_tciname ppe fmt p = + Format.fprintf fmt "%a" EcSymbols.pp_qsymbol (PPEnv.tci_symb ppe p) + + (* -------------------------------------------------------------------- *) let pp_rwname ppe fmt p = Format.fprintf fmt "%a" EcSymbols.pp_qsymbol (PPEnv.rw_symb ppe p) @@ -725,7 +752,7 @@ let rec pp_type_r ppe outer fmt ty = (pp_paren (pp_list ",@ " subpp)) xs (pp_tyname ppe) name in - maybe_paren_nosc outer t_prio_name pp fmt (name, tyargs) + maybe_paren_nosc outer t_prio_name pp fmt (name, List.fst tyargs) end | Tfun (t1, t2) -> @@ -922,6 +949,61 @@ let pp_app (ppe : PPEnv.t) (pp_first, pp_sub) outer fmt (e, args) = in maybe_paren outer ([], e_app_prio) pp fmt () +(* -------------------------------------------------------------------- *) +let pp_opname fmt (nm, op) = + let op = + if EcCoreLib.is_mixfix_op op then + Printf.sprintf "\"%s\"" op + else if is_binop op then begin + if op.[0] = '*' || op.[String.length op - 1] = '*' + then Format.sprintf "( %s )" op + else Format.sprintf "(%s)" op + end else op + + in EcSymbols.pp_qsymbol fmt (nm, op) + +(* -------------------------------------------------------------------- *) +let rec pp_etyarg (ppe : PPEnv.t) (fmt : Format.formatter) ((ty, tcws) : etyarg) = + Format.fprintf fmt "%a[%a]" (pp_type ppe) ty (pp_tcws ppe) tcws + +(* -------------------------------------------------------------------- *) +and pp_etyargs (ppe : PPEnv.t) (fmt : Format.formatter) (etys : etyarg list) = + Format.fprintf fmt "%a" (pp_list ",@ " (pp_etyarg ppe)) etys + +(* -------------------------------------------------------------------- *) +and pp_tcw (ppe : PPEnv.t) (fmt : Format.formatter) (tcw : tcwitness) = + match tcw with + | TCIUni uid -> + Format.fprintf fmt "%a" (pp_tcunivar ppe) uid + + | TCIConcrete { path; etyargs } -> + Format.fprintf fmt "%a[%a]" + (pp_tciname ppe) path (pp_etyargs ppe) etyargs + + | TCIAbstract { support = `Var x; offset } -> + Format.fprintf fmt "%a.`%d" (pp_tyvar ppe) x (offset + 1) + + | TCIAbstract { support = `Abs path; offset } -> + Format.fprintf fmt "%a.`%d" (pp_tyname ppe) path (offset + 1) + +(* -------------------------------------------------------------------- *) +and pp_tcws (ppe : PPEnv.t) (fmt : Format.formatter) (tcws : tcwitness list) = + Format.fprintf fmt "%a" (pp_list ",@ " (pp_tcw ppe)) tcws + +(* -------------------------------------------------------------------- *) +let pp_opname_with_tvi + (ppe : PPEnv.t) + (fmt : Format.formatter) + ((nm, op, tvi) : symbol list * symbol * etyarg list option) += + match tvi with + | None -> + pp_opname fmt (nm, op) + + | Some tvi -> + Format.fprintf fmt "%a<:%a>" + pp_opname (nm, op) (pp_etyargs ppe) tvi + (* -------------------------------------------------------------------- *) let pp_opapp (ppe : PPEnv.t) @@ -937,7 +1019,7 @@ let pp_opapp (fmt : Format.formatter) ((pred : [`Expr | `Form]), (op : EcPath.path), - (tvi : EcTypes.ty list), + (tvi : EcTypes.etyarg list), (es : 'a list)) = let (nm, opname) = @@ -1001,12 +1083,13 @@ let pp_opapp fun () -> match es with | [] -> - pp_opname fmt (nm, opname) + pp_opname_with_tvi ppe fmt (nm, opname, Some tvi) | _ -> - let pp_subs = ((fun _ _ -> pp_opname), pp_sub) in - let pp fmt () = pp_app ppe pp_subs outer fmt (([], opname), es) in - maybe_paren outer (inm, max_op_prec) pp fmt () + let pp_subs = ((fun ppe _ -> pp_opname_with_tvi ppe), pp_sub) in + let pp fmt () = + pp_app ppe pp_subs outer fmt (([], opname, Some tvi), es) + in maybe_paren outer (inm, max_op_prec) pp fmt () and try_pp_as_uniop () = match es with @@ -1343,7 +1426,7 @@ let lower_left (ppe : PPEnv.t) (t_ty : form -> EcTypes.ty) (f : form) else l_l f2 onm e_bin_prio_rop4 | Fapp ({f_node = Fop (op, tys)}, [f1; f2]) -> (let (inm, opname) = - PPEnv.op_symb ppe op (Some (`Form, tys, List.map t_ty [f1; f2])) in + PPEnv.op_symb ppe op (Some (`Form, tys, List.map t_ty [f1; f2])) in (* FIXME: TC *) if inm <> [] && inm <> onm then None else match priority_of_binop opname with @@ -1534,7 +1617,8 @@ and try_pp_chained_orderings (ppe : PPEnv.t) outer fmt f = match collect [] None f with | None | Some (_, ([] | [_])) -> false | Some (f, fs) -> - pp_chained_orderings ppe f_ty pp_form_r outer fmt (f, fs); + pp_chained_orderings + ppe f_ty pp_form_r outer fmt (f, fs); true and try_pp_lossless (ppe : PPEnv.t) outer fmt f = @@ -1575,11 +1659,11 @@ and try_pp_notations (ppe : PPEnv.t) outer fmt f = let ev = MEV.of_idents (List.map fst nt.ont_args) `Form in let ue = EcUnify.UniEnv.create None in let ov = EcUnify.UniEnv.opentvi ue tv None in - let ti = Tvar.subst ov in + let ti = Tvar.subst ov.subst in let hy = EcEnv.LDecl.init ppe.PPEnv.ppe_env [] in let mr = odfl mhr (EcEnv.Memory.get_active ppe.PPEnv.ppe_env) in let bd = form_of_expr mr nt.ont_body in - let bd = Fsubst.f_subst_tvar ~freshen:true ov bd in + let bd = Fsubst.f_subst_tvar ~freshen:true ov.subst bd in try let (ue, ev) = @@ -1825,7 +1909,7 @@ and pp_form_core_r (ppe : PPEnv.t) outer fmt f = (string_of_hcmp hs.bhs_cmp) (pp_form_r ppef (fst outer, (max_op_prec,`NonAssoc))) hs.bhs_bd - | Fpr pr-> + | Fpr pr -> let me = EcEnv.Fun.prF_memenv EcFol.mhr pr.pr_fun ppe.PPEnv.ppe_env in let ppep = PPEnv.create_and_push_mem ppe ~active:true me in @@ -1842,16 +1926,19 @@ and pp_form_core_r (ppe : PPEnv.t) outer fmt f = (pp_form ppep) pr.pr_event and pp_form_r (ppe : PPEnv.t) outer fmt f = - let printers = - [try_pp_notations; - try_pp_form_eqveq; - try_pp_chained_orderings; - try_pp_lossless] - in + let doit fmt = + let printers = + [try_pp_notations; + try_pp_form_eqveq; + try_pp_chained_orderings; + try_pp_lossless] + in - match List.ofind (fun pp -> pp ppe outer fmt f) printers with - | Some _ -> () - | None -> pp_form_core_r ppe outer fmt f + match List.ofind (fun pp -> pp ppe outer fmt f) printers with + | Some _ -> () + | None -> pp_form_core_r ppe outer fmt f + + in Format.fprintf fmt "(%t : %a)" doit (pp_type ppe) f.f_ty and pp_form ppe fmt f = pp_form_r ppe ([], (min_op_prec, `NonAssoc)) fmt f @@ -2075,14 +2162,30 @@ let pp_typedecl (ppe : PPEnv.t) fmt (x, tyd) = in Format.fprintf fmt "@[%a%t%t.@]" pp_locality tyd.tyd_loca pp_prelude pp_body +(* -------------------------------------------------------------------- *) +let pp_typeclass (ppe : PPEnv.t) fmt tc = + match tc.tc_args with + | [] -> + pp_tyname ppe fmt tc.tc_name + + | [ty] -> + Format.fprintf fmt "%a %a" + (pp_type ppe) (fst ty) + (pp_tyname ppe) tc.tc_name + + | tys -> + Format.fprintf fmt "(%a) %a" + (pp_list ",@ " (pp_type ppe)) (List.fst tys) + (pp_tyname ppe) tc.tc_name + (* -------------------------------------------------------------------- *) let pp_tyvar_ctt (ppe : PPEnv.t) fmt (tvar, ctt) = - match EcPath.Sp.elements ctt with + match ctt with | [] -> pp_tyvar ppe fmt tvar | ctt -> Format.fprintf fmt "%a <: %a" (pp_tyvar ppe) tvar - (pp_list " &@ " (pp_tcname ppe)) ctt + (pp_list " &@ " (fun fmt tc -> pp_typeclass ppe fmt tc)) ctt (* -------------------------------------------------------------------- *) let pp_tyvarannot (ppe : PPEnv.t) fmt ids = @@ -2297,8 +2400,9 @@ let pp_opdecl_op (ppe : PPEnv.t) fmt (basename, ts, ty, op) = (pp_type ppe) fix.opf_resty (pp_list "@\n" pp_branch) cfix - | Some (OP_TC) -> - Format.fprintf fmt "= < type-class-operator >" + | Some (OP_TC (path, name)) -> + Format.fprintf fmt ": %a = < type-class operator `%s' of `%a'>" + (pp_type ppe) ty name (pp_tyname ppe) path in match ts with @@ -2853,8 +2957,8 @@ let pp_equivS (ppe : PPEnv.t) ?prpo fmt es = let insync = EcMemory.mt_equal (snd es.es_ml) (snd es.es_mr) - && EcReduction.EqTest.for_stmt - ppe.PPEnv.ppe_env ~norm:false es.es_sl es.es_sr in +(* && EcReduction.EqTest.for_stmt + ppe.PPEnv.ppe_env ~norm:false es.es_sl es.es_sr in *) in let ppnode = if insync then begin @@ -2889,6 +2993,46 @@ let pp_rwbase ppe fmt (p, rws) = Format.fprintf fmt "%a = %a@\n%!" (pp_rwname ppe) p (pp_list ", " (pp_axname ppe)) (Sp.elements rws) +(* -------------------------------------------------------------------- *) +let pp_tparam ppe fmt (id, tcs) = + Format.fprintf fmt "%a <: %a" + pp_symbol (EcIdent.name id) + (pp_list " &@ " (pp_typeclass ppe)) tcs + +let pp_tparams ppe fmt tparams = + Format.fprintf fmt "%a" + (pp_maybe (List.length tparams != 0) (pp_enclose ~pre:"[" ~post:"] ") (pp_list ",@ " (pp_tparam ppe))) tparams + +let pp_prt ppe = + pp_option (pp_enclose ~pre:" <: " ~post:"" (pp_typeclass ppe)) + +let pp_op ppe fmt (t, ty) = + Format.fprintf fmt " @[op %s :@ %a.@]" + (EcIdent.name t) + (pp_type ppe) ty + +let pp_ops ppe fmt ops = + pp_maybe (List.length ops != 0) (pp_enclose ~pre:"" ~post:"@,@,") (pp_list "@,@," (pp_op ppe)) fmt ops + +let pp_ax ppe fmt (s, f) = + Format.fprintf fmt " @[axiom %s :@ %a.@]" + s (pp_form ppe) f + +let pp_axs ppe fmt axs = + pp_maybe (List.length axs != 0) (pp_enclose ~pre:"" ~post:"@,@,") (pp_list "@,@," (pp_ax ppe)) fmt axs + +let pp_ops_axs ppe fmt (ops, axs) = + Format.fprintf fmt "%a%a" + (pp_maybe (List.length ops + List.length axs != 0) (pp_enclose ~pre:"@,@," ~post:"") (pp_ops ppe)) ops + (pp_axs ppe) axs + +let pp_tc_decl ppe fmt (p, tcdecl) = + Format.fprintf fmt "@[type class %a%a%a = {%a}.@]" + (pp_tparams ppe) tcdecl.tc_tparams + (pp_tyname ppe) p + (pp_prt ppe) tcdecl.tc_prt + (pp_ops_axs ppe) (tcdecl.tc_ops, tcdecl.tc_axs) + (* -------------------------------------------------------------------- *) let pp_solvedb ppe fmt db = List.iter (fun (lvl, ps) -> @@ -3172,7 +3316,7 @@ let rec pp_instr_r (ppe : PPEnv.t) fmt i = let pp_branch fmt ((vars, s), (cname, _)) = let ptn = EcTypes.toarrow (List.snd vars) e.e_ty in - let ptn = f_op (EcPath.pqoname (EcPath.prefix p) cname) typ ptn in + let ptn = f_op_tc (EcPath.pqoname (EcPath.prefix p) cname) typ ptn in let ptn = f_app ptn (List.map (fun (x, ty) -> f_local x ty) vars) e.e_ty in Format.fprintf fmt "| %a => @[%a@]@ " @@ -3323,10 +3467,10 @@ let rec pp_theory ppe (fmt : Format.formatter) (path, cth) = | EcTheory.Th_typeclass _ -> Format.fprintf fmt "typeclass ." - | EcTheory.Th_instance ((typ, ty), tc, lc) -> begin - let ppe = PPEnv.add_locals ppe (List.map fst typ) in (* FIXME *) + | EcTheory.Th_instance (_, tci) -> begin + let ppe = PPEnv.add_locals ppe (List.fst tci.tci_params) in - match tc with + match tci.tci_instance with | (`Ring _ | `Field _) as tc -> begin let (name, ops) = let rec ops_of_ring cr = @@ -3362,10 +3506,10 @@ let rec pp_theory ppe (fmt : Format.formatter) (path, cth) = in Format.fprintf fmt "%ainstance %s with [%a] %a@\n@[ %a@]" - pp_locality lc + pp_locality tci.tci_local name - (pp_paren (pp_list ",@ " (pp_tyvar ppe))) (List.map fst typ) - (pp_type ppe) ty + (pp_paren (pp_list ",@ " (pp_tyvar ppe))) (List.fst tci.tci_params) + (pp_type ppe) tci.tci_type (pp_list "@\n" (fun fmt (name, op) -> Format.fprintf fmt "op %s = %s" @@ -3373,9 +3517,11 @@ let rec pp_theory ppe (fmt : Format.formatter) (path, cth) = ops end - | `General p -> + | `General (tc, _) -> Format.fprintf fmt "%ainstance %a with %a." - pp_locality lc (pp_type ppe) ty pp_path p + pp_locality tci.tci_local + (pp_type ppe) tci.tci_type + (pp_typeclass ppe) tc end | EcTheory.Th_baserw (name, _lc) -> @@ -3533,6 +3679,12 @@ module ObjectInfo = struct | `Rewrite name -> pr_rw fmt env name | `Solve name -> pr_at fmt env name + (* ------------------------------------------------------------------ *) + let pr_tc_r = + { od_name = "type classes"; + od_lookup = EcEnv.TypeClass.lookup; + od_printer = pp_tc_decl; } + (* ------------------------------------------------------------------ *) let pr_any fmt env qs = let printers = [pr_gen_r ~prcat:true pr_ty_r ; @@ -3542,7 +3694,8 @@ module ObjectInfo = struct pr_gen_r ~prcat:true pr_mod_r; pr_gen_r ~prcat:true pr_mty_r; pr_gen_r ~prcat:true pr_rw_r ; - pr_gen_r ~prcat:true pr_at_r ; ] in + pr_gen_r ~prcat:true pr_at_r ; + pr_gen_r ~prcat:true pr_tc_r ; ] in let ok = ref (List.length printers) in diff --git a/src/ecProcSem.ml b/src/ecProcSem.ml index 808ea8674d..97f0b8a657 100644 --- a/src/ecProcSem.ml +++ b/src/ecProcSem.ml @@ -416,7 +416,7 @@ and translate_e (env : senv) (e : expr) = raise SemNotSupported | _ -> - e_map (fun x -> x) (translate_e env) e + e_map (translate_e env) e (* -------------------------------------------------------------------- *) and translate_lv (env : senv) (lv : lvalue) : lpattern = diff --git a/src/ecProofTerm.ml b/src/ecProofTerm.ml index b17fe4649b..5d732b8e63 100644 --- a/src/ecProofTerm.ml +++ b/src/ecProofTerm.ml @@ -120,8 +120,8 @@ let concretize_e_form_gen (CPTEnv subst) ids f = f_forall ids f (* -------------------------------------------------------------------- *) -let concretize_e_form cptenv f = - concretize_e_form_gen cptenv [] f +let concretize_e_form (CPTEnv subst) f = + Fsubst.f_subst subst f (* -------------------------------------------------------------------- *) let rec concretize_e_arg ((CPTEnv subst) as cptenv) arg = @@ -137,7 +137,7 @@ and concretize_e_head ((CPTEnv subst) as cptenv) head = | PTCut (f, s) -> PTCut (Fsubst.f_subst subst f, s) | PTHandle h -> PTHandle h | PTLocal x -> PTLocal x - | PTGlobal (p, tys) -> PTGlobal (p, List.map (ty_subst subst) tys) + | PTGlobal (p, tys) -> PTGlobal (p, List.map (etyarg_subst subst) tys) | PTTerm pt -> PTTerm (concretize_e_pt cptenv pt) and concretize_e_pt ((CPTEnv subst) as cptenv) pt = @@ -191,23 +191,31 @@ let pt_of_hyp_r ptenv x = ptev_ax = ax; } (* -------------------------------------------------------------------- *) -let pt_of_global pf hyps p tys = +let pt_of_global_tc pf hyps p etyargs = let ptenv = ptenv_of_penv hyps pf in - let ax = EcEnv.Ax.instanciate p tys (LDecl.toenv hyps) in + let ax = EcEnv.Ax.instanciate p etyargs (LDecl.toenv hyps) in { ptev_env = ptenv; - ptev_pt = ptglobal ~tys p; + ptev_pt = ptglobal ~tys:etyargs p; ptev_ax = ax; } (* -------------------------------------------------------------------- *) -let pt_of_global_r ptenv p tys = +let pt_of_global pf hyps p tys = + pt_of_global_tc pf hyps p (List.map (fun ty -> (ty, [])) tys) + +(* -------------------------------------------------------------------- *) +let pt_of_global_tc_r ptenv p etyargs = let env = LDecl.toenv ptenv.pte_hy in - let ax = EcEnv.Ax.instanciate p tys env in + let ax = EcEnv.Ax.instanciate p etyargs env in { ptev_env = ptenv; - ptev_pt = ptglobal ~tys p; + ptev_pt = ptglobal ~tys:etyargs p; ptev_ax = ax; } +(* -------------------------------------------------------------------- *) +let pt_of_global_r ptenv p tys = + pt_of_global_tc_r ptenv p (List.map (fun ty -> (ty, [])) tys) + (* -------------------------------------------------------------------- *) let pt_of_handle_r ptenv hd = let g = FApi.get_pregoal_by_id hd ptenv.pte_pe in @@ -222,13 +230,11 @@ let pt_of_uglobal_r ptenv p = let ax = oget (EcEnv.Ax.by_path_opt p env) in let typ, ax = (ax.EcDecl.ax_tparams, ax.EcDecl.ax_spec) in - (* FIXME: TC HOOK *) let fs = EcUnify.UniEnv.opentvi ptenv.pte_ue typ None in - let ax = Fsubst.f_subst_tvar ~freshen:true fs ax in - let typ = List.map (fun (a, _) -> EcIdent.Mid.find a fs) typ in + let ax = Fsubst.f_subst_tvar ~freshen:true fs.subst ax in { ptev_env = ptenv; - ptev_pt = ptglobal ~tys:typ p; + ptev_pt = ptglobal ~tys:fs.args p; ptev_ax = ax; } (* -------------------------------------------------------------------- *) @@ -264,7 +270,7 @@ let pattern_form ?name hyps ~ptn subject = (fun aux f -> if EcReduction.is_alpha_eq hyps f ptn then fx - else f_map (fun ty -> ty) aux f) + else f_map aux f) subject in (x, body) @@ -512,12 +518,10 @@ let process_named_pterm pe (tvi, fp) = PT.pf_check_tvi pe.pte_pe typ tvi; - (* FIXME: TC HOOK *) let fs = EcUnify.UniEnv.opentvi pe.pte_ue typ tvi in - let ax = Fsubst.f_subst_tvar ~freshen:false fs ax in - let typ = List.map (fun (a, _) -> EcIdent.Mid.find a fs) typ in + let ax = Fsubst.f_subst_tvar ~freshen:false fs.subst ax in - (p, (typ, ax)) + (p, (fs.args, ax)) (* ------------------------------------------------------------------ *) let process_pterm_cut ~prcut pe pt = @@ -908,7 +912,7 @@ let tc1_process_full_closed_pterm (tc : tcenv1) (ff : ppterm) = (* -------------------------------------------------------------------- *) type prept = [ | `Hy of EcIdent.t - | `G of EcPath.path * ty list + | `G of EcPath.path * etyarg list | `UG of EcPath.path | `HD of handle | `App of prept * prept_arg list @@ -928,8 +932,8 @@ let pt_of_prept tc (pt : prept) = let rec build_pt = function | `Hy id -> pt_of_hyp_r ptenv id - | `G (p, tys) -> pt_of_global_r ptenv p tys - | `UG p -> pt_of_global_r ptenv p [] + | `G (p, tys) -> pt_of_global_tc_r ptenv p tys + | `UG p -> pt_of_global_tc_r ptenv p [] | `HD hd -> pt_of_handle_r ptenv hd | `App (pt, args) -> List.fold_left app_pt_ev (build_pt pt) args diff --git a/src/ecProofTerm.mli b/src/ecProofTerm.mli index 55ec0f6c84..55b2f5ff31 100644 --- a/src/ecProofTerm.mli +++ b/src/ecProofTerm.mli @@ -150,12 +150,13 @@ val ptenv : proofenv -> LDecl.hyps -> (EcUnify.unienv * mevmap) -> pt_env val copy : pt_env -> pt_env (* Proof-terms construction from components *) -val pt_of_hyp : proofenv -> LDecl.hyps -> EcIdent.t -> pt_ev -val pt_of_global_r : pt_env -> EcPath.path -> ty list -> pt_ev -val pt_of_global : proofenv -> LDecl.hyps -> EcPath.path -> ty list -> pt_ev -val pt_of_uglobal_r : pt_env -> EcPath.path -> pt_ev -val pt_of_uglobal : proofenv -> LDecl.hyps -> EcPath.path -> pt_ev - +val pt_of_hyp : proofenv -> LDecl.hyps -> EcIdent.t -> pt_ev +val pt_of_global_tc_r : pt_env -> EcPath.path -> etyarg list -> pt_ev +val pt_of_global_tc : proofenv -> LDecl.hyps -> EcPath.path -> etyarg list -> pt_ev +val pt_of_global_r : pt_env -> EcPath.path -> ty list -> pt_ev +val pt_of_global : proofenv -> LDecl.hyps -> EcPath.path -> ty list -> pt_ev +val pt_of_uglobal_r : pt_env -> EcPath.path -> pt_ev +val pt_of_uglobal : proofenv -> LDecl.hyps -> EcPath.path -> pt_ev (* -------------------------------------------------------------------- *) val ffpattern_of_genpattern : LDecl.hyps -> genpattern -> ppterm option @@ -163,7 +164,7 @@ val ffpattern_of_genpattern : LDecl.hyps -> genpattern -> ppterm option (* -------------------------------------------------------------------- *) type prept = [ | `Hy of EcIdent.t - | `G of EcPath.path * ty list + | `G of EcPath.path * etyarg list | `UG of EcPath.path | `HD of handle | `App of prept * prept_arg list @@ -184,7 +185,7 @@ module Prept : sig val (@) : prept -> prept_arg list -> prept val hyp : EcIdent.t -> prept - val glob : EcPath.path -> ty list -> prept + val glob : EcPath.path -> etyarg list -> prept val uglob : EcPath.path -> prept val hdl : handle -> prept diff --git a/src/ecProofTyping.ml b/src/ecProofTyping.ml index eb1cc232a3..01fd18cc49 100644 --- a/src/ecProofTyping.ml +++ b/src/ecProofTyping.ml @@ -25,9 +25,9 @@ let process_form_opt ?mv hyps pf oty = let ts = Tuni.subst (EcUnify.UniEnv.close ue) in EcFol.Fsubst.f_subst ts ff - with EcUnify.UninstanciateUni -> + with EcUnify.UninstanciateUni infos -> EcTyping.tyerror pf.EcLocation.pl_loc - (LDecl.toenv hyps) EcTyping.FreeTypeVariables + (LDecl.toenv hyps) (FreeUniVariables infos) let process_form ?mv hyps pf ty = process_form_opt ?mv hyps pf (Some ty) @@ -188,7 +188,7 @@ let tc1_process_codepos1 tc (side, cpos) = (* ------------------------------------------------------------------ *) (* FIXME: factor out to typing module *) -(* FIXME: TC HOOK - check parameter constraints *) +(* FIXME:TC HOOK - check parameter constraints *) (* ------------------------------------------------------------------ *) let pf_check_tvi (pe : proofenv) typ tvi = match tvi with diff --git a/src/ecReduction.ml b/src/ecReduction.ml index 355d420dc4..baf639d5c8 100644 --- a/src/ecReduction.ml +++ b/src/ecReduction.ml @@ -15,47 +15,15 @@ exception IncompatibleType of env * (ty * ty) exception IncompatibleForm of env * (form * form) exception IncompatibleExpr of env * (expr * expr) -(* -------------------------------------------------------------------- *) -type 'a eqtest = env -> 'a -> 'a -> bool +type 'a eqtest = env -> 'a -> 'a -> bool type 'a eqntest = env -> ?norm:bool -> 'a -> 'a -> bool type 'a eqantest = env -> ?alpha:(EcIdent.t * ty) Mid.t -> ?norm:bool -> 'a -> 'a -> bool +(* -------------------------------------------------------------------- *) module EqTest_base = struct - let rec for_type env t1 t2 = - ty_equal t1 t2 || for_type_r env t1 t2 - - and for_type_r env t1 t2 = - match t1.ty_node, t2.ty_node with - | Tunivar uid1, Tunivar uid2 -> EcUid.uid_equal uid1 uid2 - - | Tvar i1, Tvar i2 -> i1 = i2 - - | Ttuple lt1, Ttuple lt2 -> - List.length lt1 = List.length lt2 - && List.all2 (for_type env) lt1 lt2 - - | Tfun (t1, t2), Tfun (t1', t2') -> - for_type env t1 t1' && for_type env t2 t2' - - | Tglob m1, Tglob m2 -> EcIdent.id_equal m1 m2 - - | Tconstr (p1, lt1), Tconstr (p2, lt2) when EcPath.p_equal p1 p2 -> - if - List.length lt1 = List.length lt2 - && List.all2 (for_type env) lt1 lt2 - then true - else - if Ty.defined p1 env - then for_type env (Ty.unfold p1 lt1 env) (Ty.unfold p2 lt2 env) - else false - - | Tconstr(p1,lt1), _ when Ty.defined p1 env -> - for_type env (Ty.unfold p1 lt1 env) t2 - - | _, Tconstr(p2,lt2) when Ty.defined p2 env -> - for_type env t1 (Ty.unfold p2 lt2 env) - - | _, _ -> false + (* ------------------------------------------------------------------ *) + let for_type = EcCoreEqTest.for_type + let for_etyarg = EcCoreEqTest.for_etyarg (* ------------------------------------------------------------------ *) let is_unit env ty = for_type env tunit ty @@ -134,7 +102,7 @@ module EqTest_base = struct for_pv env ~norm p1 p2 | Eop(o1,ty1), Eop(o2,ty2) -> - p_equal o1 o2 && List.all2 (for_type env) ty1 ty2 + p_equal o1 o2 && List.all2 (for_etyarg env) ty1 ty2 | Equant(q1,b1,e1), Equant(q2,b2,e2) when eqt_equal q1 q2 -> let alpha = check_bindings env alpha b1 b2 in @@ -403,6 +371,9 @@ let ensure b = if b then () else raise NotConv let check_ty env subst ty1 ty2 = ensure (EqTest_base.for_type env ty1 (ty_subst subst ty2)) +let check_etyarg env subst etyarg1 etyarg2 = + ensure (EqTest_base.for_etyarg env etyarg1 (etyarg_subst subst etyarg2)) + let add_local (env, subst) (x1, ty1) (x2, ty2) = check_ty env subst ty1 ty2; env, @@ -528,7 +499,7 @@ let is_alpha_eq hyps f1 f2 = check_mod subst m1 m2 | Fop(p1, ty1), Fop(p2, ty2) when EcPath.p_equal p1 p2 -> - List.iter2 (check_ty env subst) ty1 ty2 + List.iter2 (check_etyarg env subst) ty1 ty2 | Fapp(f1',args1), Fapp(f2',args2) when List.length args1 = List.length args2 -> @@ -620,6 +591,7 @@ type reduction_info = { beta : bool; delta_p : (path -> deltap); (* reduce operators *) delta_h : (ident -> bool); (* reduce local definitions *) + delta_tc : bool; zeta : bool; iota : bool; eta : bool; @@ -636,6 +608,7 @@ let full_red = { beta = true; delta_p = (fun _ -> `IfTransparent); delta_h = EcUtils.predT; + delta_tc = true; zeta = true; iota = true; eta = true; @@ -645,15 +618,16 @@ let full_red = { } let no_red = { - beta = false; - delta_p = (fun _ -> `No); - delta_h = EcUtils.pred0; - zeta = false; - iota = false; - eta = false; - logic = None; - modpath = false; - user = false; + beta = false; + delta_p = (fun _ -> `No); + delta_h = EcUtils.pred0; + delta_tc = false; + zeta = false; + iota = false; + eta = false; + logic = None; + modpath = false; + user = false; } let beta_red = { no_red with beta = true; } @@ -661,8 +635,8 @@ let betaiota_red = { no_red with beta = true; iota = true; } let nodelta = { full_red with - delta_h = EcUtils.pred0; - delta_p = (fun _ -> `No); } + delta_h = EcUtils.pred0; + delta_p = (fun _ -> `No); } let delta = { no_red with delta_p = (fun _ -> `IfTransparent); } @@ -692,6 +666,15 @@ let reduce_op ri env nargs p tys = Op.reduce ~mode ~nargs env p tys with NotReducible -> raise nohead +let reduce_tc_op (ri : reduction_info) (env : EcEnv.env) (p : path) (tys : etyarg list) = + if ri.delta_tc then + try + Op.tc_reduce env p tys + with NotReducible -> raise nohead + else + raise nohead + +(* -------------------------------------------------------------------- *) let is_record env f = match EcFol.destr_app f with | { f_node = Fop (p, _) }, _ -> EcEnv.Op.is_record_ctor env p @@ -734,8 +717,8 @@ let reduce_user_gen simplify ri env hyps f = oget ~exn:needsubterm (List.Exceptionless.find_map (fun rule -> try - let ue = EcUnify.UniEnv.create None in - let tvi = EcUnify.UniEnv.opentvi ue rule.R.rl_tyd None in + let ue = EcUnify.UniEnv.create None in + let tvi = EcUnify.UniEnv.opentvi ue rule.R.rl_tyd None in let check_alpha_eq f f' = if not (is_alpha_eq hyps f f') then raise NotReducible @@ -753,7 +736,8 @@ let reduce_user_gen simplify ri env hyps f = | ({ f_node = Fop (p, tys) }, args), R.Rule (`Op (p', tys'), args') when EcPath.p_equal p p' && List.length args = List.length args' -> - let tys' = List.map (Tvar.subst tvi) tys' in + let tys' = List.map (Tvar.subst tvi.subst) tys' in + let tys = List.fst tys in (* FIXME:TC *) begin try List.iter2 (EcUnify.unify env ue) tys tys' @@ -788,7 +772,7 @@ let reduce_user_gen simplify ri env hyps f = let subst = ts in let subst = Mid.fold (fun x f s -> Fsubst.f_bind_local s x f) !pv subst in - Fsubst.f_subst subst (Fsubst.f_subst_tvar ~freshen:true tvi f) + Fsubst.f_subst subst (Fsubst.f_subst_tvar ~freshen:true tvi.subst f) in List.iter (fun cond -> @@ -867,7 +851,7 @@ let reduce_logic ri env hyps f p args = when EcPath.p_equal p1 p2 && EcEnv.Op.is_record_ctor env p1 && EcEnv.Op.is_record_ctor env p2 - && List.for_all2 (EqTest_i.for_type env) tys1 tys2 -> + && List.for_all2 (EqTest_i.for_etyarg env) tys1 tys2 -> f_ands (List.map2 f_eq args1 args2) @@ -888,14 +872,28 @@ let reduce_logic ri env hyps f p args = check_reduced hyps needsubterm f f' (* -------------------------------------------------------------------- *) -let reduce_delta ri env _hyps f = +let reduce_delta ri env f = match f.f_node with | Fop (p, tys) when ri.delta_p p <> `No -> - reduce_op ri env 0 p tys + reduce_op ri env 0 p tys | Fapp ({ f_node = Fop (p, tys) }, args) when ri.delta_p p <> `No -> - let op = reduce_op ri env (List.length args) p tys in - f_app_simpl op args f.f_ty + let op = reduce_op ri env (List.length args) p tys in + f_app_simpl op args f.f_ty + + | _ -> raise nohead + +(* -------------------------------------------------------------------- *) +let reduce_tc ri env f = + match f.f_node with + | Fop (p, etyargs) when ri.delta_tc && Op.tc_reducible env p etyargs -> + reduce_tc_op ri env p etyargs + + | Fapp ({ f_node = Fop (p, etyargs) }, args) + when ri.delta_tc && Op.tc_reducible env p etyargs + -> + let op = reduce_tc_op ri env p etyargs in + f_app_simpl op args f.f_ty | _ -> raise nohead @@ -1048,7 +1046,10 @@ let reduce_head simplify ri env hyps f = let body = EcFol.form_of_expr EcFol.mhr body in (* FIXME subst-refact can we do both subst in once *) let body = - Tvar.f_subst ~freshen:true (List.map fst op.EcDecl.op_tparams) tys body in + Tvar.f_subst ~freshen:true + (List.combine + (List.map fst op.EcDecl.op_tparams) + tys) body in f_app (Fsubst.f_subst subst body) eargs f.f_ty @@ -1065,20 +1066,24 @@ let reduce_head simplify ri env hyps f = when ri.eta && can_eta x (fn, args) -> f_app fn (List.take (List.length args - 1) args) f.f_ty - | Fop _ -> begin + | Fop _ -> + oget ~exn:nohead @@ + List.find_map_opt + (fun cb -> try Some (cb f) with NotRed _ -> None) + [ reduce_user_gen simplify ri env hyps + ; reduce_delta ri env + ; reduce_tc ri env ] + + | Fapp ({ f_node = Fop (p, _); }, args) -> begin try - reduce_user_gen simplify ri env hyps f + reduce_logic ri env hyps f p args with NotRed _ -> - reduce_delta ri env hyps f - end - - | Fapp({ f_node = Fop(p,_); }, args) -> begin - try reduce_logic ri env hyps f p args - with NotRed kind1 -> - try reduce_user_gen simplify ri env hyps f - with NotRed kind2 -> - if kind1 = NoHead && kind2 = NoHead then reduce_delta ri env hyps f - else raise needsubterm + oget ~exn:needsubterm @@ + List.find_map_opt + (fun cb -> try Some (cb f) with NotRed NoHead -> None) + [ reduce_user_gen simplify ri env hyps + ; reduce_delta ri env + ; reduce_tc ri env ] end | Ftuple _ -> begin @@ -1179,9 +1184,18 @@ and reduce_head_top_force ri env onhead f = match reduce_head_sub ri env f with | f -> if onhead then reduce_head_top ri env ~onhead f else f - | exception (NotRed _) -> - try reduce_delta ri.ri env ri.hyps f - with NotRed _ -> RedTbl.set_norm ri.redtbl f; raise nohead + | exception (NotRed _) -> begin + match + List.find_map_opt + (fun cb -> try Some (cb ri.ri env f) with NotRed _ -> None) + [reduce_delta; reduce_tc] + with + | Some f -> + f + | None -> + RedTbl.set_norm ri.redtbl f; + raise nohead + end end and reduce_head_sub ri env f = @@ -1242,36 +1256,36 @@ let rec simplify ri env f = match f.f_node with | FhoareF hf when ri.ri.modpath -> let hf_f = EcEnv.NormMp.norm_xfun env hf.hf_f in - f_map (fun ty -> ty) (simplify ri env) (f_hoareF_r { hf with hf_f }) + f_map (simplify ri env) (f_hoareF_r { hf with hf_f }) | FeHoareF hf when ri.ri.modpath -> let ehf_f = EcEnv.NormMp.norm_xfun env hf.ehf_f in - f_map (fun ty -> ty) (simplify ri env) (f_eHoareF_r { hf with ehf_f }) + f_map (simplify ri env) (f_eHoareF_r { hf with ehf_f }) | FbdHoareF hf when ri.ri.modpath -> let bhf_f = EcEnv.NormMp.norm_xfun env hf.bhf_f in - f_map (fun ty -> ty) (simplify ri env) (f_bdHoareF_r { hf with bhf_f }) + f_map (simplify ri env) (f_bdHoareF_r { hf with bhf_f }) | FequivF ef when ri.ri.modpath -> let ef_fl = EcEnv.NormMp.norm_xfun env ef.ef_fl in let ef_fr = EcEnv.NormMp.norm_xfun env ef.ef_fr in - f_map (fun ty -> ty) (simplify ri env) (f_equivF_r { ef with ef_fl; ef_fr; }) + f_map (simplify ri env) (f_equivF_r { ef with ef_fl; ef_fr; }) | FeagerF eg when ri.ri.modpath -> let eg_fl = EcEnv.NormMp.norm_xfun env eg.eg_fl in let eg_fr = EcEnv.NormMp.norm_xfun env eg.eg_fr in - f_map (fun ty -> ty) (simplify ri env) (f_eagerF_r { eg with eg_fl ; eg_fr; }) + f_map (simplify ri env) (f_eagerF_r { eg with eg_fl ; eg_fr; }) | Fpr pr when ri.ri.modpath -> let pr_fun = EcEnv.NormMp.norm_xfun env pr.pr_fun in - f_map (fun ty -> ty) (simplify ri env) (f_pr_r { pr with pr_fun }) + f_map (simplify ri env) (f_pr_r { pr with pr_fun }) | Fquant (q, bd, f) -> let env = Mod.add_mod_binding bd env in f_quant q bd (simplify ri env f) | _ -> - f_map (fun ty -> ty) (simplify ri env) f + f_map (simplify ri env) f let simplify ri hyps f = let ri, env = init_redinfo ri hyps in @@ -1365,6 +1379,9 @@ let zpop ri side f hd = let rec conv ri env f1 f2 stk = if f_equal f1 f2 then conv_next ri env f1 stk else match f1.f_node, f2.f_node with + | Flocal x, Flocal y when EcIdent.id_equal x y -> + true + | Fquant (q1, bd1, f1'), Fquant(q2,bd2,f2') -> if q1 <> q2 then force_head_sub ri env f1 f2 stk else @@ -1418,7 +1435,8 @@ let rec conv ri env f1 f2 stk = end | Fop(p1, ty1), Fop(p2,ty2) - when EcPath.p_equal p1 p2 && List.all2 (EqTest_i.for_type env) ty1 ty2 -> + when EcPath.p_equal p1 p2 + && List.all2 (EqTest_i.for_etyarg env) ty1 ty2 -> conv_next ri env f1 stk | Fapp(f1', args1), Fapp(f2', args2) @@ -1688,8 +1706,10 @@ module User = struct let rule = let rec rule (f : form) : EcTheory.rule_pattern = match EcFol.destr_app f with - | { f_node = Fop (p, tys) }, args -> - R.Rule (`Op (p, tys), List.map rule args) + | { f_node = Fop (p, etyargs) }, args + when List.for_all (fun (_, ws) -> List.is_empty ws) etyargs + -> (* FIXME: TC *) + R.Rule (`Op (p, List.fst etyargs), List.map rule args) | { f_node = Ftuple args }, [] -> R.Rule (`Tuple, List.map rule args) | { f_node = Fproj (target, i) }, [] -> diff --git a/src/ecReduction.mli b/src/ecReduction.mli index 7d5a47dffb..eac29237f8 100644 --- a/src/ecReduction.mli +++ b/src/ecReduction.mli @@ -19,16 +19,17 @@ type 'a eqantest = env -> ?alpha:(EcIdent.t * ty) Mid.t -> ?norm:bool -> 'a -> ' module EqTest : sig val for_type_exn : env -> ty -> ty -> unit - val for_type : ty eqtest - val for_pv : prog_var eqntest - val for_lv : lvalue eqntest - val for_xp : xpath eqntest - val for_mp : mpath eqntest - val for_instr : instr eqantest - val for_stmt : stmt eqantest - val for_expr : expr eqantest - val for_msig : module_sig eqntest - val for_mexpr : env -> ?norm:bool -> ?body:bool -> module_expr -> module_expr -> bool + val for_type : ty eqtest + val for_etyarg : etyarg eqtest + val for_pv : prog_var eqntest + val for_lv : lvalue eqntest + val for_xp : xpath eqntest + val for_mp : mpath eqntest + val for_instr : instr eqantest + val for_stmt : stmt eqantest + val for_expr : expr eqantest + val for_msig : module_sig eqntest + val for_mexpr : env -> ?norm:bool -> ?body:bool -> module_expr -> module_expr -> bool val is_unit : env -> ty -> bool val is_bool : env -> ty -> bool @@ -64,6 +65,7 @@ type reduction_info = { beta : bool; delta_p : (path -> deltap); (* reduce operators *) delta_h : (ident -> bool); (* reduce local definitions *) + delta_tc : bool; (* reduce tc-operators *) zeta : bool; (* reduce let *) iota : bool; (* reduce case *) eta : bool; (* reduce eta-expansion *) diff --git a/src/ecScope.ml b/src/ecScope.ml index b17295a4d5..e42bf616d6 100644 --- a/src/ecScope.ml +++ b/src/ecScope.ml @@ -874,8 +874,11 @@ module Ax = struct let concl = TT.trans_prop env ue pconcl in - if not (EcUnify.UniEnv.closed ue) then - hierror "the formula contains free type variables"; + Option.iter (fun infos -> + hierror + "the formula contains free %a variables" + EcUserMessages.TypingError.pp_uniflags infos + ) (EcUnify.UniEnv.xclosed ue); let uidmap = EcUnify.UniEnv.close ue in let fs = Tuni.subst uidmap in @@ -1088,11 +1091,39 @@ module Op = struct let item = EcTheory.mkitem import (EcTheory.Th_operator (x, op)) in { scope with sc_env = EcSection.add_item item scope.sc_env; } + (* -------------------------------------------------------------------- *) + let axiomatized_op ?(nargs = 0) ?(nosmt = false) path (tparams, axbd) lc = + let axpm, axbd = + let subst, axpm = EcSubst.fresh_tparams EcSubst.empty tparams in + (axpm, EcSubst.subst_form subst axbd) + in + + let args, axbd = + match axbd.f_node with + | Fquant (Llambda, bds, axbd) -> + let bds, flam = List.split_at nargs bds in + (bds, f_lambda flam axbd) + | _ -> [], axbd + in + + let opargs = List.map (fun (x, ty) -> f_local x (gty_as_ty ty)) args in + let opty = toarrow (List.map f_ty opargs) axbd.EcAst.f_ty in + let op = f_op_tc path (etyargs_of_tparams axpm) opty in + let op = f_app op opargs axbd.f_ty in + let axspec = f_forall args (f_eq op axbd) in + + { ax_tparams = axpm; + ax_spec = axspec; + ax_kind = `Axiom (Ssym.empty, false); + ax_loca = lc; + ax_visibility = if nosmt then `NoSmt else `Visible; } + let add (scope : scope) (op : poperator located) = assert (scope.sc_pr_uc = None); let op = op.pl_desc and loc = op.pl_loc in let eenv = env scope in let ue = TT.transtyvars eenv (loc, op.po_tyvars) in + let lc = op.po_locality in let args = fst op.po_args @ odfl [] (snd op.po_args) in let (ty, body, refts) = @@ -1127,8 +1158,11 @@ module Op = struct (opty, `Abstract, [(rname, xs, reft, codom)]) in - if not (EcUnify.UniEnv.closed ue) then - hierror ~loc "this operator type contains free type variables"; + Option.iter (fun infos -> + hierror ~loc + "this operator type contains free %a variables" + EcUserMessages.TypingError.pp_uniflags infos + ) (EcUnify.UniEnv.xclosed ue); let uidmap = EcUnify.UniEnv.close ue in let ts = Tuni.subst uidmap in @@ -1171,7 +1205,7 @@ module Op = struct try EcUnify.unify eenv tue ty tfun; - let msg = "this operator type is (unifiable) to a function type" in + let msg = "this operator type is (unifiable to) a function type" in hierror ~loc "%s" msg with EcUnify.UnificationFailure _ -> () end; @@ -1185,7 +1219,7 @@ module Op = struct let path = EcPath.pqname (path scope) (unloc op.po_name) in let axop = let nargs = List.sum (List.map (List.length |- fst) args) in - EcDecl.axiomatized_op ~nargs path (tyop.op_tparams, bd) lc in + axiomatized_op ~nargs path (tyop.op_tparams, bd) lc in let tyop = { tyop with op_opaque = { reduction = true; smt = false; }} in let scope = bind scope (unloc op.po_name, tyop) in Ax.bind scope (unloc ax, axop) @@ -1212,11 +1246,10 @@ module Op = struct ax in - let ax, axpm = - let bdpm = List.map fst tparams in - let axpm = List.map EcIdent.fresh bdpm in - (Tvar.f_subst ~freshen:true bdpm (List.map EcTypes.tvar axpm) ax, - List.combine axpm (List.map snd tparams)) in + let axpm, ax = + let subst, tparams = EcSubst.fresh_tparams EcSubst.empty tparams in + (tparams, EcSubst.subst_form subst ax) in + let ax = { ax_tparams = axpm; ax_spec = ax; @@ -1233,11 +1266,11 @@ module Op = struct hierror ~loc "multiple names are only allowed for non-refined abstract operators"; let addnew scope name = - let nparams = List.map (fst_map EcIdent.fresh) tparams in - let subst = Tvar.init - (List.map fst tparams) - (List.map (tvar |- fst) nparams) in - let rop = EcDecl.mk_op ~opaque:optransparent nparams (Tvar.subst subst ty) None lc in + let subst, nparams = + EcSubst.fresh_tparams EcSubst.empty tparams in + let rop = + EcDecl.mk_op ~opaque:optransparent + nparams (EcSubst.subst_ty subst ty) None lc in bind scope (unloc name, rop) in List.fold_left addnew scope op.po_aliases @@ -1252,10 +1285,18 @@ module Op = struct if not (EcAlgTactic.is_module_loaded (env scope)) then hierror "for tag %s, load Distr first" tag; - let oppath = EcPath.pqname (path scope) (unloc op.po_name) in - let nparams = List.map (EcIdent.fresh |- fst) tyop.op_tparams in - let subst = Tvar.init (List.fst tyop.op_tparams) (List.map tvar nparams) in - let ty = Tvar.subst subst tyop.op_ty in + let subst, nparams = + EcSubst.fresh_tparams EcSubst.empty tyop.op_tparams in + let oppath = EcPath.pqname (path scope) (unloc op.po_name) in + let optyargs = + let mktcw (a : EcIdent.t) (i : int) = + TCIAbstract { support = `Var a; offset = i; } + in + List.map + (fun (a, tcs) -> (tvar a, List.mapi (fun i _ -> mktcw a i) tcs)) + nparams + in + let ty = EcSubst.subst_ty subst tyop.op_ty in let aty, rty = EcTypes.tyfun_flat ty in let dty = @@ -1265,13 +1306,13 @@ module Op = struct in let bds = List.combine (List.map EcTypes.fresh_id_of_ty aty) aty in - let ax = EcFol.f_op oppath (List.map tvar nparams) ty in + let ax = EcFol.f_op_tc oppath optyargs ty in let ax = EcFol.f_app ax (List.map (curry f_local) bds) rty in let ax = EcFol.f_app (EcFol.f_op pred [dty] (tfun rty tbool)) [ax] tbool in let ax = EcFol.f_forall (List.map (snd_map gtty) bds) ax in let ax = - { ax_tparams = List.map (fun ty -> (ty, Sp.empty)) nparams; + { ax_tparams = nparams; ax_spec = ax; ax_kind = `Axiom (Ssym.empty, false); ax_loca = lc; @@ -1578,14 +1619,12 @@ module Ty = struct let tyd_params, tyd_type = match body with | PTYD_Abstract tcs -> - let tcs = - List.map - (fun tc -> fst (EcEnv.TypeClass.lookup (unloc tc) env)) - tcs in let ue = TT.transtyvars env (loc, Some args) in - EcUnify.UniEnv.tparams ue, `Abstract (Sp.of_list tcs) + let tcs = List.map (TT.transtc env ue) tcs in + let tp = EcUnify.UniEnv.tparams ue in + tp, `Abstract tcs - | PTYD_Alias bd -> + | PTYD_Alias bd -> let ue = TT.transtyvars env (loc, Some args) in let body = transty tp_tydecl env ue bd in EcUnify.UniEnv.tparams ue, `Concrete body @@ -1613,7 +1652,7 @@ module Ty = struct { scope with sc_env = EcSection.add_item item scope.sc_env } (* ------------------------------------------------------------------ *) - let add_class (scope : scope) { pl_desc = tcd } = + let add_class (scope : scope) { pl_desc = tcd; pl_loc = loc } = assert (scope.sc_pr_uc = None); let lc = tcd.ptc_loca in let name = unloc tcd.ptc_name in @@ -1622,21 +1661,16 @@ module Ty = struct check_name_available scope tcd.ptc_name; let tclass = - let uptc = - tcd.ptc_inth |> omap - (fun { pl_loc = uploc; pl_desc = uptc } -> - match EcEnv.TypeClass.lookup_opt uptc scenv with - | None -> hierror ~loc:uploc "unknown type-class: `%s'" - (string_of_qsymbol uptc) - | Some (tcp, _) -> tcp) - in + (* Check typeclasses arguments *) + let ue = TT.transtyvars scenv (loc, tcd.ptc_params) in + + let uptc = tcd.ptc_inth |> omap (TT.transtc scenv ue) in let asty = - let body = ofold (fun p tc -> Sp.add p tc) Sp.empty uptc in - { tyd_params = []; - tyd_type = `Abstract body; - tyd_loca = (lc :> locality); - tyd_resolve = true; } in + { tyd_params = []; + tyd_type = `Abstract (otolist uptc); + tyd_resolve = true; + tyd_loca = (lc :> locality); } in let scenv = EcEnv.Ty.bind name asty scenv in (* Check for duplicated field names *) @@ -1650,7 +1684,7 @@ module Ty = struct (* Check operators types *) let operators = let check1 (x, ty) = - let ue = EcUnify.UniEnv.create (Some []) in + let ue = EcUnify.UniEnv.copy ue in let ty = transty tp_tydecl scenv ue ty in let uidmap = EcUnify.UniEnv.close ue in let ty = ty_subst (Tuni.subst uidmap) ty in @@ -1662,7 +1696,7 @@ module Ty = struct let axioms = let scenv = EcEnv.Var.bind_locals operators scenv in let check1 (x, ax) = - let ue = EcUnify.UniEnv.create (Some []) in + let ue = EcUnify.UniEnv.copy ue in let ax = trans_prop scenv ue ax in let uidmap = EcUnify.UniEnv.close ue in let fs = Tuni.subst uidmap in @@ -1672,21 +1706,22 @@ module Ty = struct tcd.ptc_axs |> List.map check1 in (* Construct actual type-class *) - { tc_prt = uptc; tc_ops = operators; tc_axs = axioms; tc_loca = lc} + { tc_prt = uptc; tc_tparams = EcUnify.UniEnv.tparams ue; + tc_ops = operators; tc_axs = axioms; tc_loca = lc; } in bindclass scope (name, tclass) (* ------------------------------------------------------------------ *) let check_tci_operators env tcty ops reqs = - let ue = EcUnify.UniEnv.create (Some (fst tcty)) in - let rmap = Mstr.of_list reqs in + let ue = EcUnify.UniEnv.create (Some (fst tcty)) in let ops = let tt1 m (x, (tvi, op)) = - if not (Mstr.mem (unloc x) rmap) then + if not (Mstr.mem (unloc x) reqs) then hierror ~loc:x.pl_loc "invalid operator name: `%s'" (unloc x); let tvi = List.map (TT.transty tp_tydecl env ue) tvi in + let tvi = List.map (fun ty -> (Some ty, None)) tvi in let selected = EcUnify.select_op ~filter:(fun _ -> EcDecl.is_oper) (Some (EcUnify.TVIunamed tvi)) env (unloc op) ue [] @@ -1694,19 +1729,16 @@ module Ty = struct let op = match selected with | [] -> hierror ~loc:op.pl_loc "unknown operator" - | op1::op2::_ -> + | op1 :: op2 :: _ -> hierror ~loc:op.pl_loc "ambiguous operator (%s / %s)" (EcPath.tostring (fst (proj4_1 op1))) (EcPath.tostring (fst (proj4_1 op2))) - | [((p, _), _, _, _)] -> - let op = EcEnv.Op.by_path p env in - let opty = - Tvar.subst - (Tvar.init (List.map fst op.op_tparams) tvi) - op.op_ty - in - (p, opty) + | [((p, opparams), opty, subue, _)] -> + let subst = Tuni.subst (EcUnify.UniEnv.assubst subue) in + let opty = ty_subst subst opty in + let opparams = List.map (etyarg_subst subst) opparams in + ((p, opparams), opty) in Mstr.change @@ -1718,20 +1750,25 @@ module Ty = struct in List.fold_left tt1 Mstr.empty ops in - List.iter - (fun (x, (req, _)) -> + Mstr.iter + (fun x (req, _) -> if req && not (Mstr.mem x ops) then hierror "no definition for operator `%s'" x) reqs; - List.fold_left - (fun m (x, (_, ty)) -> + Mstr.fold + (fun x (_, ty) m -> match Mstr.find_opt x ops with | None -> m - | Some (loc, (p, opty)) -> - if not (EcReduction.EqTest.for_type env ty opty) then - hierror ~loc "invalid type for operator `%s'" x; - Mstr.add x p m) - Mstr.empty reqs + | Some (loc, ((p, opparams), opty)) -> + if not (EcReduction.EqTest.for_type env ty opty) then begin + let ppe = EcPrinting.PPEnv.ofenv env in + hierror ~loc +"invalid type for operator `%s':@\n\ +\ - expected: %a@\n\ +\ - got : %a" + x (EcPrinting.pp_type ppe) ty (EcPrinting.pp_type ppe) opty + end; Mstr.add x (p, opparams) m) + reqs Mstr.empty (* ------------------------------------------------------------------ *) let check_tci_axioms scope mode axs reqs lc = @@ -1786,26 +1823,23 @@ module Ty = struct interactive (* ------------------------------------------------------------------ *) - (* FIXME section: those path does not exists ... - futhermode Ring.ZModule is an abstract theory *) - let p_zmod = EcPath.fromqsymbol ([EcCoreLib.i_top; "Ring"; "ZModule"], "zmodule") - let p_ring = EcPath.fromqsymbol ([EcCoreLib.i_top; "Ring"; "ComRing"], "ring" ) - let p_idomain = EcPath.fromqsymbol ([EcCoreLib.i_top; "Ring"; "IDomain"], "idomain") - let p_field = EcPath.fromqsymbol ([EcCoreLib.i_top; "Ring"; "Field" ], "field" ) + let get_ring_field_op (name : string) (symbols : (path * etyarg list) Mstr.t) = + Option.map + (fun (p, tys) -> assert (List.is_empty tys); p) + (Mstr.find_opt name symbols) - (* ------------------------------------------------------------------ *) let ring_of_symmap env ty kind symbols = { r_type = ty; - r_zero = oget (Mstr.find_opt "rzero" symbols); - r_one = oget (Mstr.find_opt "rone" symbols); - r_add = oget (Mstr.find_opt "add" symbols); - r_opp = (Mstr.find_opt "opp" symbols); - r_mul = oget (Mstr.find_opt "mul" symbols); - r_exp = (Mstr.find_opt "expr" symbols); - r_sub = (Mstr.find_opt "sub" symbols); + r_zero = oget (get_ring_field_op "rzero" symbols); + r_one = oget (get_ring_field_op "rone" symbols); + r_add = oget (get_ring_field_op "add" symbols); + r_opp = (get_ring_field_op "opp" symbols); + r_mul = oget (get_ring_field_op "mul" symbols); + r_exp = (get_ring_field_op "expr" symbols); + r_sub = (get_ring_field_op "sub" symbols); r_kind = kind; r_embed = - (match Mstr.find_opt "ofint" symbols with + (match get_ring_field_op "ofint" symbols with | None when EcReduction.EqTest.for_type env ty tint -> `Direct | None -> `Default | Some p -> `Embed p); } @@ -1821,36 +1855,36 @@ module Ty = struct let uidmap = EcUnify.UniEnv.close ue in (EcUnify.UniEnv.tparams ue, ty_subst (Tuni.subst uidmap) ty) in + if not (List.is_empty (fst ty)) then hierror "ring instances cannot be polymorphic"; let symbols = EcAlgTactic.ring_symbols env kind (snd ty) in + let symbols = Mstr.of_list symbols in let symbols = check_tci_operators env ty tci.pti_ops symbols in let cr = ring_of_symmap env (snd ty) kind symbols in let axioms = EcAlgTactic.ring_axioms env cr in let lc = (tci.pti_loca :> locality) in let inter = check_tci_axioms scope mode tci.pti_axs axioms lc in - let add env p = - let item = EcTheory.Th_instance (ty,`General p, tci.pti_loca) in - let item = EcTheory.mkitem import item in - EcSection.add_item item env in - let scope = - { scope with sc_env = - List.fold_left add - (let item = - EcTheory.Th_instance (([], snd ty), `Ring cr, tci.pti_loca) in - let item = EcTheory.mkitem import item in - EcSection.add_item item scope.sc_env) - [p_zmod; p_ring; p_idomain] } + let instance = EcTheory. + { tci_params = fst ty + ; tci_type = snd ty + ; tci_instance = `Ring cr + ; tci_local = (tci.pti_loca :> locality) } in + + let scope = + let item = EcTheory.Th_instance (None, instance) in + let item = EcTheory.mkitem import item in + { scope with sc_env = EcSection.add_item item scope.sc_env } in - in Ax.add_defer scope inter + Ax.add_defer scope inter (* ------------------------------------------------------------------ *) let field_of_symmap env ty symbols = { f_ring = ring_of_symmap env ty `Integer symbols; - f_inv = oget (Mstr.find_opt "inv" symbols); - f_div = Mstr.find_opt "div" symbols; } + f_inv = oget (get_ring_field_op "inv" symbols); + f_div = get_ring_field_op "div" symbols; } let addfield ~import (scope : scope) mode { pl_desc = tci; pl_loc = loc; } = let env = env scope in @@ -1864,75 +1898,121 @@ module Ty = struct let uidmap = EcUnify.UniEnv.close ue in (EcUnify.UniEnv.tparams ue, ty_subst (Tuni.subst uidmap) ty) in + if not (List.is_empty (fst ty)) then hierror "field instances cannot be polymorphic"; + let symbols = EcAlgTactic.field_symbols env (snd ty) in + let symbols = Mstr.of_list symbols in let symbols = check_tci_operators env ty tci.pti_ops symbols in let cr = field_of_symmap env (snd ty) symbols in let axioms = EcAlgTactic.field_axioms env cr in let lc = (tci.pti_loca :> locality) in let inter = check_tci_axioms scope mode tci.pti_axs axioms lc; in - let add env p = - let item = EcTheory.Th_instance(ty,`General p, tci.pti_loca) in + + let instance = EcTheory. + { tci_params = fst ty + ; tci_type = snd ty + ; tci_instance = `Field cr + ; tci_local = (tci.pti_loca :> locality) } in + + let scope = + let item = EcTheory.Th_instance (None, instance) in let item = EcTheory.mkitem import item in - EcSection.add_item item env in - let scope = - { scope with - sc_env = - List.fold_left add - (let item = - EcTheory.Th_instance (([], snd ty), `Field cr, tci.pti_loca) in - let item = EcTheory.mkitem import item in - EcSection.add_item item scope.sc_env) - [p_zmod; p_ring; p_idomain; p_field] } - - in Ax.add_defer scope inter + { scope with sc_env = EcSection.add_item item scope.sc_env } in + + Ax.add_defer scope inter (* ------------------------------------------------------------------ *) - let symbols_of_tc (_env : EcEnv.env) ty (tcp, tc) = - let subst = EcSubst.add_tydef EcSubst.empty tcp ([], ty) in - List.map (fun (x, opty) -> - (EcIdent.name x, (true, EcSubst.subst_ty subst opty))) - tc.tc_ops + let symbols_of_tc (_env : EcEnv.env) ((tparams, ty) : ty_params * ty) (tcp, tc) = + let subst, tparams = EcSubst.fresh_tparams EcSubst.empty tparams in + let ty = EcSubst.subst_ty subst ty in + let subst = EcSubst.add_tydef subst tcp.tc_name (List.fst tparams, ty) in + let subst = + List.fold_left + (fun subst (a, ty) -> EcSubst.add_tyvar subst a ty) + subst (List.combine (List.fst tc.tc_tparams) tcp.tc_args) in -(* - (* ------------------------------------------------------------------ *) - let add_generic_tc (scope : scope) _mode { pl_desc = tci; pl_loc = loc; } = - let ty = - let ue = TT.transtyvars scope.sc_env (loc, Some (fst tci.pti_type)) in - let ty = transty tp_tydecl scope.sc_env ue (snd tci.pti_type) in - assert (EcUnify.UniEnv.closed ue); - (EcUnify.UniEnv.tparams ue, Tuni.offun (EcUnify.UniEnv.close ue) ty) - in + List.map (fun (x, opty) -> + (EcIdent.name x, (true, EcSubst.subst_ty subst opty))) + tc.tc_ops - let (tcp, tc) = - match EcEnv.TypeClass.lookup_opt (unloc tci.pti_name) (env scope) with + (* ------------------------------------------------------------------ *) + let add_generic_instance + ~import (scope : scope) mode { pl_desc = tci; pl_loc = loc; } + = + let name = + match tci.pti_name with | None -> - hierror ~loc:tci.pti_name.pl_loc - "unknown type-class: %s" (string_of_qsymbol (unloc tci.pti_name)) - | Some tc -> tc + hierror ~loc "typeclass instances must be given a name" + | Some name -> name in + + let (typarams, _) as ty = + let ue = TT.transtyvars (env scope) (loc, Some (fst tci.pti_type)) in + let ty = transty tp_tydecl (env scope) ue (snd tci.pti_type) in + assert (EcUnify.UniEnv.closed ue); + ( + EcUnify.UniEnv.tparams ue, + ty_subst (Tuni.subst (EcUnify.UniEnv.close ue)) ty + ) in - let symbols = symbols_of_tc scope.sc_env (snd ty) (tcp, tc) in - let _symbols = check_tci_operators scope.sc_env ty tci.pti_ops symbols in + let tcp = + let ue = EcUnify.UniEnv.create (Some typarams) in + TT.transtc (env scope) ue tci.pti_tc in - { scope with - sc_env = EcEnv.TypeClass.add_instance ty (`General tcp) scope.sc_env } + let tc = EcEnv.TypeClass.by_path tcp.tc_name (env scope) in + + let tcsyms = symbols_of_tc (env scope) ty (tcp, tc) in + let tcsyms = Mstr.of_list tcsyms in + let symbols = check_tci_operators (env scope) ty tci.pti_ops tcsyms in + + let subst = EcSubst.empty in + let subst = EcSubst.add_tydef subst tcp.tc_name ([], snd ty) in + let subst = + List.fold_left + (fun subst (a, ty) -> EcSubst.add_tyvar subst a ty) + subst (List.combine (List.fst tc.tc_tparams) tcp.tc_args) in + + let subst = + List.fold_left + (fun subst (opname, ty) -> + let oppath, optys = Mstr.find (EcIdent.name opname) symbols in + let op = + EcFol.f_op_tc + oppath + (List.map (EcSubst.subst_etyarg subst) optys) + (EcSubst.subst_ty subst ty) + in EcSubst.add_flocal subst opname op) + subst tc.tc_ops in + + let axioms = + List.map + (fun (name, ax) -> + let ax = EcSubst.subst_form subst ax in + (name, ax)) + tc.tc_axs in + let lc = (tci.pti_loca :> locality) in + let inter = check_tci_axioms scope mode tci.pti_axs axioms lc in + + let instance = EcTheory. + { tci_params = fst ty + ; tci_type = snd ty + ; tci_instance = `General (tcp, Some symbols) + ; tci_local = lc } in + + let scope = + let item = EcTheory.Th_instance (Some (unloc name), instance) in + let item = EcTheory.mkitem import item in + { scope with sc_env = EcSection.add_item item scope.sc_env } in -(* - let ue = EcUnify.UniEnv.create (Some []) in - let ty = fst (EcUnify.UniEnv.openty ue (fst ty) None (snd ty)) in - try EcUnify.hastc scope.sc_env ue ty (Sp.singleton (fst tc)); tc - with EcUnify.UnificationFailure _ -> - hierror "type must be an instance of `%s'" (EcPath.tostring (fst tc)) -*) -*) + Ax.add_defer scope inter (* ------------------------------------------------------------------ *) let add_instance ?(import = EcTheory.import0) (scope : scope) mode ({ pl_desc = tci } as toptci) = - match unloc tci.pti_name with + match unloc (fst tci.pti_tc) with | ([], "bring") -> begin if EcUtils.is_some tci.pti_args then hierror "unsupported-option"; @@ -1960,7 +2040,7 @@ module Ty = struct | _ -> if EcUtils.is_some tci.pti_args then hierror "unsupported-option"; - failwith "unsupported" (* FIXME *) + add_generic_instance ~import scope mode toptci end (* -------------------------------------------------------------------- *) @@ -2319,15 +2399,15 @@ module Search = struct let ps = ref Mid.empty in let ue = EcUnify.UniEnv.create None in let tip = EcUnify.UniEnv.opentvi ue decl.op_tparams None in - let tip = f_subst_init ~tv:tip () in - let es = e_subst tip in + let tip = f_subst_init ~tv:tip.subst () in + let es = e_subst tip in let xs = List.map (snd_map (ty_subst tip)) nt.ont_args in let bd = EcFol.form_of_expr EcFol.mhr (es nt.ont_body) in let fp = EcFol.f_lambda (List.map (snd_map EcFol.gtty) xs) bd in match fp.f_node with | Fop (pf, _) -> (pf :: paths, pts) - | _ -> (paths, (ps, ue, fp) ::pts) + | _ -> (paths, (ps, ue, fp) :: pts) end | _ -> (p :: paths, pts) in diff --git a/src/ecScope.mli b/src/ecScope.mli index f04f9595aa..724b06091b 100644 --- a/src/ecScope.mli +++ b/src/ecScope.mli @@ -117,7 +117,6 @@ end (* -------------------------------------------------------------------- *) module Ty : sig val add : scope -> ptydecl located -> scope - val add_class : scope -> ptypeclass located -> scope val add_instance : ?import:EcTheory.import -> scope -> Ax.proofmode -> ptycinstance located -> scope end diff --git a/src/ecSection.ml b/src/ecSection.ml index 127040f6b1..94a41e1d1e 100644 --- a/src/ecSection.ml +++ b/src/ecSection.ml @@ -1,6 +1,7 @@ (* -------------------------------------------------------------------- *) open EcUtils open EcSymbols +open EcMaps open EcPath open EcAst open EcTypes @@ -21,7 +22,7 @@ type cbarg = [ | `Module of mpath | `ModuleType of path | `Typeclass of path - | `Instance of tcinstance + | `TcInstance of [`General of path | `Ring | `Field] ] type cb = cbarg -> unit @@ -46,16 +47,18 @@ let pp_cbarg env fmt (who : cbarg) = | _ -> ppe in Format.fprintf fmt "module %a" (EcPrinting.pp_topmod ppe) mp | `ModuleType p -> - let mty = EcEnv.ModTy.modtype p env in - Format.fprintf fmt "module type %a" (EcPrinting.pp_modtype1 ppe) mty + Format.fprintf fmt "module type %a" + (EcPrinting.pp_modtype1 ppe) + (EcEnv.ModTy.modtype p env) | `Typeclass p -> - Format.fprintf fmt "typeclass %a" (EcPrinting.pp_tcname ppe) p - | `Instance tci -> - match tci with - | `Ring _ -> Format.fprintf fmt "ring instance" - | `Field _ -> Format.fprintf fmt "field instance" - | `General _ -> Format.fprintf fmt "instance" - + Format.fprintf fmt "typeclass %a" (EcPrinting.pp_tyname ppe) p + | `TcInstance (`General p) -> + Format.fprintf fmt "typeclass instance %s" (EcPath.tostring p) (* FIXME:TC *) + | `TcInstance `Ring -> + Format.fprintf fmt "ring instance" + | `TcInstance `Field -> + Format.fprintf fmt "field instance" + let pp_locality fmt = function | `Local -> Format.fprintf fmt "local" | `Global -> () @@ -105,9 +108,28 @@ let rec on_ty (cb : cb) (ty : ty) = | Tvar _ -> () | Tglob _ -> () | Ttuple tys -> List.iter (on_ty cb) tys - | Tconstr (p, tys) -> cb (`Type p); List.iter (on_ty cb) tys + | Tconstr (p, tys) -> cb (`Type p); List.iter (on_etyarg cb) tys | Tfun (ty1, ty2) -> List.iter (on_ty cb) [ty1; ty2] +and on_etyarg cb ((ty, tcw) : etyarg) = + on_ty cb ty; + List.iter (on_tcwitness cb) tcw + +and on_tcwitness cb (tcw : tcwitness) = + match tcw with + | TCIUni _ -> + () + + | TCIConcrete { path; etyargs } -> + List.iter (on_etyarg cb) etyargs; + cb (`TcInstance (`General path)) + + | TCIAbstract { support = `Abs path } -> + cb (`Type path) + + | TCIAbstract { support = `Var _ } -> + () + let on_pv (cb : cb) (pv : prog_var)= match pv with | PVglob xp -> on_xp cb xp @@ -136,7 +158,7 @@ let rec on_expr (cb : cb) (e : expr) = | Evar pv -> on_pv cb pv | Elet (lp, e1, e2) -> on_lp cb lp; List.iter cbrec [e1; e2] | Etuple es -> List.iter cbrec es - | Eop (p, tys) -> cb (`Op p); List.iter (on_ty cb) tys + | Eop (p, tys) -> cb (`Op p); List.iter (on_etyarg cb) tys | Eapp (e, es) -> List.iter cbrec (e :: es) | Eif (c, e1, e2) -> List.iter cbrec [c; e1; e2] | Ematch (e, es, ty) -> on_ty cb ty; List.iter cbrec (e :: es) @@ -201,7 +223,7 @@ let rec on_form (cb : cb) (f : EcFol.form) = | EcAst.Fif (f1, f2, f3) -> List.iter cbrec [f1; f2; f3] | EcAst.Fmatch (b, fs, ty) -> on_ty cb ty; List.iter cbrec (b :: fs) | EcAst.Flet (lp, f1, f2) -> on_lp cb lp; List.iter cbrec [f1; f2] - | EcAst.Fop (p, tys) -> cb (`Op p); List.iter (on_ty cb) tys + | EcAst.Fop (p, tys) -> cb (`Op p); List.iter (on_etyarg cb) tys | EcAst.Fapp (f, fs) -> List.iter cbrec (f :: fs) | EcAst.Ftuple fs -> List.iter cbrec fs | EcAst.Fproj (f, _) -> cbrec f @@ -352,11 +374,15 @@ and on_oi (cb : cb) (oi : OI.t) = List.iter (on_xp cb) (OI.allowed oi) (* -------------------------------------------------------------------- *) -let on_typeclasses cb s = - Sp.iter (fun p -> cb (`Typeclass p)) s +let on_typeclass cb tc = + cb (`Typeclass tc.tc_name); + List.iter (on_etyarg cb) tc.tc_args + +let on_typeclasses cb tcs = + List.iter (on_typeclass cb) tcs let on_typarams cb typarams = - List.iter (fun (_,s) -> on_typeclasses cb s) typarams + List.iter (fun (_, tc) -> on_typeclasses cb tc) typarams (* -------------------------------------------------------------------- *) let on_tydecl (cb : cb) (tyd : tydecl) = @@ -371,8 +397,8 @@ let on_tydecl (cb : cb) (tyd : tydecl) = List.iter (List.iter (on_ty cb) |- snd) dt.tydt_ctors; List.iter (on_form cb) [dt.tydt_schelim; dt.tydt_schcase] -let on_typeclass cb tc = - oiter (fun p -> cb (`Typeclass p)) tc.tc_prt; +let on_tcdecl cb tc = + oiter (on_typeclass cb) tc.tc_prt; List.iter (fun (_,ty) -> on_ty cb ty) tc.tc_ops; List.iter (fun (_,f) -> on_form cb f) tc.tc_axs @@ -402,8 +428,8 @@ let on_opdecl (cb : cb) (opdecl : operator) = | OB_oper Some b -> match b with | OP_Constr _ | OP_Record _ | OP_Proj _ -> assert false - | OP_TC -> assert false - | OP_Plain f -> on_form cb f + | OP_TC _ -> assert false + | OP_Plain f -> on_form cb f | OP_Fix f -> let rec on_mpath_branches br = match br with @@ -448,15 +474,19 @@ let on_field cb f = let on_p p = cb (`Op p) in on_p f.f_inv; oiter on_p f.f_div -let on_instance cb ty tci = - on_typarams cb (fst ty); - on_ty cb (snd ty); - match tci with - | `Ring r -> on_ring cb r - | `Field f -> on_field cb f - | `General p -> - (* FIXME section: ring/field use type class that do not exists *) - cb (`Typeclass p) +let on_instance cb tci = + on_typarams cb tci.tci_params; + on_ty cb tci.tci_type; + (* FIXME section: ring/field use type class that do not exists *) + match tci.tci_instance with + | `Ring r -> on_ring cb r + | `Field f -> on_field cb f + + | `General (tci, syms) -> + on_typeclass cb tci; + Option.iter + (Mstr.iter (fun _ (p, tys) -> cb (`Op p); List.iter (on_etyarg cb) tys)) + syms (* -------------------------------------------------------------------- *) type sc_name = @@ -504,7 +534,11 @@ let pp_thname scenv = (* -------------------------------------------------------------------- *) let locality (env : EcEnv.env) (who : cbarg) = match who with - | `Type p -> (EcEnv. Ty.by_path p env).tyd_loca + | `Type p -> begin + match EcEnv.TypeClass.by_path_opt p env with + | Some tc -> (tc.tc_loca :> locality) + | _ -> (EcEnv.Ty.by_path p env).tyd_loca + end | `Op p -> (EcEnv. Op.by_path p env).op_loca | `Ax p -> (EcEnv. Ax.by_path p env).ax_loca | `Typeclass p -> ((EcEnv.TypeClass.by_path p env).tc_loca :> locality) @@ -515,7 +549,8 @@ let locality (env : EcEnv.env) (who : cbarg) = | _ -> `Global end | `ModuleType p -> ((EcEnv.ModTy.by_path p env).tms_loca :> locality) - | `Instance _ -> assert false + | `TcInstance (`General p) -> (EcEnv.TcInstance.by_path p env).tci_local + | `TcInstance (`Ring | `Field) -> `Global (* -------------------------------------------------------------------- *) type to_clear = @@ -525,7 +560,7 @@ type to_clear = type to_gen = { tg_env : scenv; - tg_params : (EcIdent.t * Sp.t) list; + tg_params : (EcIdent.t * typeclass list) list; (* FIXME: TC *) tg_binds : bind list; tg_subst : EcSubst.subst; tg_clear : to_clear; } @@ -577,11 +612,12 @@ let add_declared_ty to_gen path tydecl = | `Abstract s -> s | _ -> assert false in - let name = "'" ^ basename path in - let id = EcIdent.create name in + let name = Format.sprintf "'%s" (basename path) in + let id = EcIdent.create name in + { to_gen with - tg_params = to_gen.tg_params @ [id, s]; - tg_subst = EcSubst.add_tydef to_gen.tg_subst path ([], tvar id); + tg_params = to_gen.tg_params @ [id, s]; + tg_subst = EcSubst.add_tydef to_gen.tg_subst path ([], tvar id); } let add_declared_op to_gen path opdecl = @@ -603,14 +639,22 @@ let add_declared_op to_gen path opdecl = | _ -> assert false } let tvar_fv ty = Mid.map (fun () -> 1) (EcTypes.Tvar.fv ty) + + let etyargs_tvar_fv etyargs = + Mid.map (fun () -> 1) (EcTypes.etyargs_tvar_fv etyargs) + let fv_and_tvar_e e = let rec aux fv e = let fv = EcIdent.fv_union fv (tvar_fv e.e_ty) in match e.e_node with - | Eop(_, tys) -> List.fold_left (fun fv ty -> EcIdent.fv_union fv (tvar_fv ty)) fv tys + | Eop(_, etyargs) -> + EcIdent.fv_union fv (etyargs_tvar_fv etyargs) | Equant(_,d,e) -> - let fv = List.fold_left (fun fv (_,ty) -> EcIdent.fv_union fv (tvar_fv ty)) fv d in - aux fv e + let fv = + List.fold_left + (fun fv (_,ty) -> EcIdent.fv_union fv (tvar_fv ty)) + fv d + in aux fv e | _ -> e_fold aux fv e in aux e.e_fv e @@ -629,7 +673,8 @@ and fv_and_tvar_f f = let rec aux f = fv := EcIdent.fv_union !fv (tvar_fv f.f_ty); match f.f_node with - | Fop(_, tys) -> fv := List.fold_left (fun fv ty -> EcIdent.fv_union fv (tvar_fv ty)) !fv tys + | Fop(_, tys) -> + fv := EcIdent.fv_union !fv (etyargs_tvar_fv tys) | Fquant(_, d, f) -> fv := List.fold_left (fun fv (_,gty) -> EcIdent.fv_union fv (gty_fv_and_tvar gty)) !fv d; aux f @@ -659,7 +704,7 @@ let op_body_fv body ty = let fv = ty_fv_and_tvar ty in match body with | OP_Plain f -> EcIdent.fv_union fv (fv_and_tvar_f f) - | OP_Constr _ | OP_Record _ | OP_Proj _ | OP_TC -> fv + | OP_Constr _ | OP_Record _ | OP_Proj _ | OP_TC _ -> fv | OP_Fix opfix -> let fv = List.fold_left (fun fv (_, ty) -> EcIdent.fv_union fv (ty_fv_and_tvar ty)) @@ -844,7 +889,7 @@ let generalize_opdecl to_gen prefix (name, operator) = let body = match body with | OP_Constr _ | OP_Record _ | OP_Proj _ -> assert false - | OP_TC -> assert false (* ??? *) + | OP_TC _ -> assert false (* ??? *) | OP_Plain f -> OP_Plain (f_lambda (List.map (fun (x, ty) -> (x, GTty ty)) extra_a) f) | OP_Fix opfix -> @@ -995,11 +1040,11 @@ let generalize_export to_gen (p,lc) = if lc = `Local || to_clear to_gen (`Th p) then to_gen, None else to_gen, Some (Th_export (p,lc)) -let generalize_instance to_gen (ty,tci, lc) = - if lc = `Local then to_gen, None - (* FIXME: be sure that we have no dep to declare or local, +let generalize_instance to_gen (x, tci) = + if tci.tci_local = `Local then to_gen, None + (* FIXME:TC be sure that we have no dep to declare or local, or fix this code *) - else to_gen, Some (Th_instance (ty,tci,lc)) + else to_gen, Some (Th_instance (x, tci)) let generalize_baserw to_gen prefix (s,lc) = if lc = `Local then @@ -1042,7 +1087,7 @@ let rec set_local_item item = | Th_typeclass (s,tc) -> Th_typeclass (s, { tc with tc_loca = set_local tc.tc_loca }) | Th_theory (s, th) -> Th_theory (s, set_local_th th) | Th_export (p,lc) -> Th_export (p, set_local lc) - | Th_instance (ty,ti,lc) -> Th_instance (ty,ti, set_local lc) + | Th_instance (x,tci) -> Th_instance (x, { tci with tci_local = set_local tci.tci_local }) | Th_baserw (s,lc) -> Th_baserw (s, set_local lc) | Th_addrw (p,ps,lc) -> Th_addrw (p, ps, set_local lc) | Th_reduction r -> Th_reduction r @@ -1070,22 +1115,6 @@ let is_abstract_ty = function | `Abstract _ -> true | _ -> false -(* -let rec check_glob_mp_ty s scenv mp = - let mtop = `Module (mastrip mp) in - if is_declared scenv mtop then - hierror "global %s can't depend on declared module" s; - if is_local scenv mtop then - hierror "global %s can't depend on local module" s; - List.iter (check_glob_mp_ty s scenv) mp.m_args - -let rec check_glob_mp scenv mp = - let mtop = `Module (mastrip mp) in - if is_local scenv mtop then - hierror "global definition can't depend on local module"; - List.iter (check_glob_mp scenv) mp.m_args - *) - let check s scenv who b = if not b then hierror "%a %s" (pp_lc_cbarg scenv.sc_env) who s @@ -1099,24 +1128,26 @@ let check_polymorph scenv who typarams = let check_abstract = check "should be abstract" type can_depend = { - d_ty : locality list; - d_op : locality list; - d_ax : locality list; - d_sc : locality list; - d_mod : locality list; - d_modty : locality list; - d_tc : locality list; - } + d_ty : locality list; + d_op : locality list; + d_ax : locality list; + d_sc : locality list; + d_mod : locality list; + d_modty : locality list; + d_tc : locality list; + d_tci : locality list; +} -let cd_glob = - { d_ty = [`Global]; - d_op = [`Global]; - d_ax = [`Global]; - d_sc = [`Global]; - d_mod = [`Global]; - d_modty = [`Global]; - d_tc = [`Global]; - } +let cd_glob = { + d_ty = [`Global]; + d_op = [`Global]; + d_ax = [`Global]; + d_sc = [`Global]; + d_mod = [`Global]; + d_modty = [`Global]; + d_tc = [`Global]; + d_tci = [`Global]; +} let can_depend (cd : can_depend) = function | `Type _ -> cd.d_ty @@ -1126,8 +1157,7 @@ let can_depend (cd : can_depend) = function | `Module _ -> cd.d_mod | `ModuleType _ -> cd.d_modty | `Typeclass _ -> cd.d_tc - | `Instance _ -> assert false - + | `TcInstance _ -> cd.d_tci let cb scenv from cd who = let env = scenv.sc_env in @@ -1158,29 +1188,10 @@ let check_tyd scenv prefix name tyd = d_mod = [`Global]; d_modty = []; d_tc = [`Global]; + d_tci = [`Global]; } in on_tydecl (cb scenv from cd) tyd -(* -let cb_glob scenv (who:cbarg) = - match who with - | `Type p -> - if is_local scenv who then - hierror "global definition can't depend of local type %s" - (EcPath.tostring p) - | `Module mp -> - check_glob_mp scenv mp - | `Op p -> - if is_local scenv who then - hierror "global definition can't depend of local op %s" - (EcPath.tostring p) - | `ModuleType p -> - if is_local scenv who then - hierror "global definition can't depend of local module type %s" - (EcPath.tostring p) - | `Ax _ | `Typeclass _ -> assert false -*) - let is_abstract_op op = match op.op_kind with | OB_oper None | OB_pred None -> true @@ -1204,6 +1215,7 @@ let check_op scenv prefix name op = d_mod = [`Declare; `Global]; d_modty = []; d_tc = [`Global]; + d_tci = [`Global]; } in on_opdecl (cb scenv from cd) op @@ -1216,6 +1228,7 @@ let check_op scenv prefix name op = d_mod = [`Global]; d_modty = []; d_tc = [`Global]; + d_tci = [`Global]; } in on_opdecl (cb scenv from cd) op @@ -1235,6 +1248,7 @@ let check_ax (scenv : scenv) (prefix : path) (name : symbol) (ax : axiom) = d_mod = [`Declare; `Global]; d_modty = [`Global]; d_tc = [`Global]; + d_tci = [`Global]; } in let doit = on_axiom (cb scenv from cd) in let error b s1 s = @@ -1287,28 +1301,39 @@ let check_module scenv prefix tme = d_mod = [`Global]; (* FIXME section: add local *) d_modty = [`Global]; d_tc = [`Global]; + d_tci = [`Global]; } in on_module (cb scenv from cd) me | `Declare -> (* Should be SC_decl_mod ... *) assert false -let check_typeclass scenv prefix name tc = +let check_tcdecl scenv prefix name tc = let path = pqname prefix name in let from = ((tc.tc_loca :> locality), `Typeclass path) in if tc.tc_loca = `Local then check_section scenv from else - on_typeclass (cb scenv from cd_glob) tc - -let check_instance scenv ty tci lc = - let from = (lc :> locality), `Instance tci in - if lc = `Local then check_section scenv from + on_tcdecl (cb scenv from cd_glob) tc + +let check_instance scenv prefix x tci = + let from = + match x, tci.tci_instance with + | Some x, `General _ -> `General (pqname prefix x) + | None , `Ring _ -> `Ring + | None , `Field _ -> `Field + | _ , _ -> assert false in + + let from = (tci.tci_local, `TcInstance from) in + + if tci.tci_local = `Local then check_section scenv from else if scenv.sc_insec then - match tci with - | `Ring _ | `Field _ -> on_instance (cb scenv from cd_glob) ty tci + match tci.tci_instance with + | `Ring _ | `Field _ -> + on_instance (cb scenv from cd_glob) tci + | `General _ -> - let cd = { cd_glob with d_ty = [`Declare; `Global]; } in - on_instance (cb scenv from cd) ty tci + let cd = { cd_glob with d_ty = [`Declare; `Global]; } in + on_instance (cb scenv from cd) tci (* -----------------------------------------------------------*) let enter_theory (name:symbol) (lc:is_local) (mode:thmode) scenv : scenv = @@ -1344,7 +1369,7 @@ let add_item_ (item : theory_item) (scenv:scenv) = | Th_module me -> EcEnv.Mod.bind me.tme_expr.me_name me env | Th_typeclass(s,tc) -> EcEnv.TypeClass.bind s tc env | Th_export (p, lc) -> EcEnv.Theory.export p lc env - | Th_instance (tys,i,lc) -> EcEnv.TypeClass.add_instance tys i lc env + | Th_instance (x, tc) -> EcEnv.TcInstance.bind x tc env | Th_baserw (s,lc) -> EcEnv.BaseRw.add s lc env | Th_addrw (p,ps,lc) -> EcEnv.BaseRw.addto p ps lc env | Th_auto (level, base, ps, lc) -> EcEnv.Auto.add ~level ?base ps lc env @@ -1370,8 +1395,8 @@ let rec generalize_th_item (to_gen : to_gen) (prefix : path) (th_item : theory_i | Th_module me -> generalize_module to_gen prefix me | Th_theory th -> (generalize_ctheory to_gen prefix th, None) | Th_export (p,lc) -> generalize_export to_gen (p,lc) - | Th_instance (ty,i,lc) -> generalize_instance to_gen (ty,i,lc) - | Th_typeclass _ -> assert false + | Th_instance (x,tci)-> generalize_instance to_gen (x,tci) + | Th_typeclass _ -> assert false (* FIXME:TC *) | Th_baserw (s,lc) -> generalize_baserw to_gen prefix (s,lc) | Th_addrw (p,ps,lc) -> generalize_addrw to_gen (p, ps, lc) | Th_reduction rl -> generalize_reduction to_gen rl @@ -1484,9 +1509,9 @@ let check_item scenv item = | Th_axiom (s, ax) -> check_ax scenv prefix s ax | Th_modtype (s, ms) -> check_modtype scenv prefix s ms | Th_module me -> check_module scenv prefix me - | Th_typeclass (s,tc) -> check_typeclass scenv prefix s tc + | Th_typeclass (s,tc) -> check_tcdecl scenv prefix s tc | Th_export (_, lc) -> assert (lc = `Global || scenv.sc_insec); - | Th_instance (ty,tci,lc) -> check_instance scenv ty tci lc + | Th_instance(x, tci) -> check_instance scenv prefix x tci | Th_baserw (_,lc) -> if (lc = `Local && not scenv.sc_insec) then hierror "local base rewrite can only be declared inside section"; @@ -1530,6 +1555,7 @@ let add_decl_mod id mt scenv = d_mod = [`Declare; `Global]; d_modty = [`Global]; d_tc = [`Global]; + d_tci = [`Global]; } in let from = `Declare, `Module (mpath_abs id []) in on_mty_mr (cb scenv from cd) mt; diff --git a/src/ecSmt.ml b/src/ecSmt.ml index 6d81016c62..b685a6710c 100644 --- a/src/ecSmt.ml +++ b/src/ecSmt.ml @@ -376,7 +376,7 @@ let rec trans_ty ((genv, lenv) as env) ty = | Tconstr (p, tys) -> let id = trans_pty genv p in - WTy.ty_app id (trans_tys env tys) + WTy.ty_app id (trans_tys env (List.fst tys)) (* FIXME:TC *) | Tfun (t1, t2) -> WTy.ty_func (trans_ty env t1) (trans_ty env t2) @@ -712,6 +712,7 @@ and trans_app ((genv, lenv) as env : tenv * lenv) (f : form) args = | Fop (p, ts) -> let wop = trans_op genv p in + let ts = List.fst ts in (* FIXME:TC *) let tys = List.map (trans_ty (genv,lenv)) ts in apply_wop genv wop tys args @@ -764,7 +765,7 @@ and trans_branch (genv, lenv) (p, _dty, tvs) (f, (cname, argsty)) = in let lenv, ws = trans_lvars genv lenv xs in - let wcty = trans_ty (genv, lenv) (tconstr p tvs) in + let wcty = trans_ty (genv, lenv) (tconstr_tc p tvs) in let ws = List.map WTerm.pat_var ws in let ws = WTerm.pat_app csymb ws wcty in let wf = trans_app (genv, lenv) f [] in diff --git a/src/ecSubst.ml b/src/ecSubst.ml index 977f7e3657..c3bebf2464 100644 --- a/src/ecSubst.ml +++ b/src/ecSubst.ml @@ -1,5 +1,6 @@ (* -------------------------------------------------------------------- *) open EcUtils +open EcMaps open EcAst open EcTypes open EcDecl @@ -26,7 +27,7 @@ exception InconsistentSubst type subst = { sb_module : EcPath.mpath Mid.t; sb_path : EcPath.path Mp.t; - sb_tyvar : ty Mid.t; + sb_tyvar : etyarg Mid.t; sb_elocal : expr Mid.t; sb_flocal : EcCoreFol.form Mid.t; sb_fmem : EcIdent.t Mid.t; @@ -137,17 +138,17 @@ let has_def (s : subst) (p : EcPath.path) = Mp.mem p s.sb_def (* -------------------------------------------------------------------- *) -let add_tyvar (s : subst) (x : EcIdent.t) (ty : ty) = +let add_tyvar (s : subst) (x : EcIdent.t) (ety : etyarg) = (* FIXME: check name clash *) let merger = function - | None -> Some ty + | None -> Some ety | Some _ -> raise (SubstNameClash (`Ident x)) in { s with sb_tyvar = Mid.change merger x s.sb_tyvar } (* -------------------------------------------------------------------- *) -let add_tyvars (s : subst) (xs : EcIdent.t list) (tys : ty list) = - List.fold_left2 add_tyvar s xs tys +let add_tyvars (s : subst) (xs : (EcIdent.t * etyarg) list) = + List.fold_left (fun s (x, ety) -> add_tyvar s x ety) s xs (* -------------------------------------------------------------------- *) let rec subst_ty (s : subst) (ty : ty) = @@ -156,23 +157,25 @@ let rec subst_ty (s : subst) (ty : ty) = tglob (EcPath.mget_ident (subst_mpath s (EcPath.mident mp))) | Tunivar _ -> - ty (* FIXME *) + ty | Tvar a -> - Mid.find_def ty a s.sb_tyvar + Mid.find_opt a s.sb_tyvar + |> Option.map fst + |> Option.value ~default:ty | Ttuple tys -> ttuple (subst_tys s tys) - | Tconstr (p, tys) -> begin - let tys = subst_tys s tys in + | Tconstr (p, etys) -> begin + let etys = subst_etyargs s etys in match Mp.find_opt p s.sb_tydef with | None -> - tconstr (subst_path s p) tys + tconstr_tc (subst_path s p) etys | Some (args, body) -> - let s = List.fold_left2 add_tyvar empty args tys in + let s = List.fold_left2 add_tyvar empty args etys in subst_ty s body end @@ -183,6 +186,43 @@ let rec subst_ty (s : subst) (ty : ty) = and subst_tys (s : subst) (tys : ty list) = List.map (subst_ty s) tys +(* -------------------------------------------------------------------- *) +and subst_etyarg (s : subst) ((ty, tcws) : etyarg) : etyarg = + (subst_ty s ty, subst_tcws s tcws) + +(* -------------------------------------------------------------------- *) +and subst_etyargs (s : subst) (tyargs : etyarg list) : etyarg list = + List.map (subst_etyarg s) tyargs + +(* -------------------------------------------------------------------- *) +and subst_tcw (s : subst) (tcw : tcwitness) = + match tcw with + | TCIUni _ -> + tcw + + | TCIConcrete { etyargs; path } -> + let path = subst_path s path in + let etyargs = subst_etyargs s etyargs in + TCIConcrete { etyargs; path } + + | TCIAbstract { support = `Var a; offset } -> + Mid.find_opt a s.sb_tyvar + |> Option.map snd + |> Option.map (fun tcs -> List.nth tcs offset) + |> Option.value ~default:tcw + + | TCIAbstract ({ support = `Abs p } as tcw) -> + match Mp.find_opt p s.sb_tydef with + | None -> + TCIAbstract { tcw with support = `Abs (subst_path s p) } + + | Some _ -> + assert false (* FIXME:TC *) + +(* -------------------------------------------------------------------- *) +and subst_tcws (s : subst) (tcws : tcwitness list) : tcwitness list = + List.map (subst_tcw s) tcws + (* -------------------------------------------------------------------- *) let add_module (s : subst) (x : EcIdent.t) (m : EcPath.mpath) = let merger = function @@ -267,9 +307,9 @@ let add_path (s : subst) ~src ~dst = assert (Mp.find_opt src s.sb_path = None); { s with sb_path = Mp.add src dst s.sb_path } -let add_tydef (s : subst) p (ids, ty) = +let add_tydef (s : subst) p (typ, ty) = assert (Mp.find_opt p s.sb_tydef = None); - { s with sb_tydef = Mp.add p (ids, ty) s.sb_tydef } + { s with sb_tydef = Mp.add p (typ, ty) s.sb_tydef } let add_opdef (s : subst) p (ids, f) = assert (Mp.find_opt p s.sb_def = None); @@ -317,51 +357,80 @@ let subst_expr_lpattern (s : subst) (lp : lpattern) = (* -------------------------------------------------------------------- *) let rec subst_expr (s : subst) (e : expr) = + let mk (node : expr_node) = + let ty = subst_ty s e.e_ty in + mk_expr node ty in + match e.e_node with + | Eint _ -> + mk e.e_node + | Elocal id -> begin match Mid.find id s.sb_elocal with | aout -> aout - | exception Not_found -> e_local id (subst_ty s e.e_ty) + | exception Not_found -> mk (Elocal id) end | Evar pv -> - e_var (subst_progvar s pv) (subst_ty s e.e_ty) - - | Eapp ({ e_node = Eop (p, tys) }, args) when has_opdef s p -> - let tys = subst_tys s tys in - let ty = subst_ty s e.e_ty in - let body = oget (get_opdef s p) in - let args = List.map (subst_expr s) args in - subst_eop ty tys args body - - | Eop (p, tys) when has_opdef s p -> - let tys = subst_tys s tys in - let ty = subst_ty s e.e_ty in - let body = oget (get_opdef s p) in - subst_eop ty tys [] body - - | Eop (p, tys) -> - let p = subst_path s p in - let tys = subst_tys s tys in - let ty = subst_ty s e.e_ty in - e_op p tys ty + mk (Evar (subst_progvar s pv)) + + | Eapp ({ e_node = Eop (p, tyargs) }, args) when has_opdef s p -> + let tyargs = subst_etyargs s tyargs in + let ty = subst_ty s e.e_ty in + let body = oget (get_opdef s p) in + let args = List.map (subst_expr s) args in + subst_eop ty tyargs args body + + | Eapp (hd, args) -> + let hd = subst_expr s hd in + let args = List.map (subst_expr s) args in + mk (Eapp (hd, args)) + + | Eop (p, tyargs) when has_opdef s p -> + let tys = subst_etyargs s tyargs in + let ty = subst_ty s e.e_ty in + let body = oget (get_opdef s p) in + subst_eop ty tys [] body + + | Eop (p, tyargs) -> + let p = subst_path s p in + let tyargs = subst_etyargs s tyargs in + mk (Eop (p, tyargs)) + + | Eif (c, e1, e2) -> + let c = subst_expr s c in + let e1 = subst_expr s e1 in + let e2 = subst_expr s e2 in + mk (Eif (c, e1, e2)) + + | Ematch (c, bs, ty) -> + let c = subst_expr s c in + let bs = List.map (subst_expr s) bs in + let ty = subst_ty s ty in + mk (Ematch (c, bs, ty)) + + | Eproj (sube, (i : int)) -> + let sube = subst_expr s sube in + mk (Eproj (sube, i)) + + | Etuple es -> + let es = List.map (subst_expr s) es in + mk (Etuple es) | Elet (lp, e1, e2) -> - let e1 = subst_expr s e1 in - let s, lp = subst_expr_lpattern s lp in - let e2 = subst_expr s e2 in - e_let lp e1 e2 + let e1 = subst_expr s e1 in + let s, lp = subst_expr_lpattern s lp in + let e2 = subst_expr s e2 in + mk (Elet (lp, e1, e2)) - | Equant (q, b, e1) -> - let s, b = fresh_elocals s b in - let e1 = subst_expr s e1 in - e_quantif q b e1 - - | _ -> e_map (subst_ty s) (subst_expr s) e + | Equant (q, b, bd) -> + let s, b = fresh_elocals s b in + let bd = subst_expr s bd in + mk (Equant (q, b, bd)) (* -------------------------------------------------------------------- *) and subst_eop ety tys args (tyids, e) = - let s = add_tyvars empty tyids tys in + let s = add_tyvars empty (List.combine tyids tys) in let (s, args, e) = match e.e_node with @@ -475,166 +544,187 @@ let subst_form_lpattern (s : subst) (lp : lpattern) = (* -------------------------------------------------------------------- *) let rec subst_form (s : subst) (f : form) = + let mk (node : f_node) = + let ty = subst_ty s f.f_ty in + mk_form node ty in + match f.f_node with - | Fquant (q, b, f1) -> - let s, b = fresh_glocals s b in - let e1 = subst_form s f1 in - f_quant q b e1 + | Fint _ -> + mk (f.f_node) + + | Fquant (q, b, bd) -> + let s, b = fresh_glocals s b in + let bd = subst_form s bd in + mk (Fquant (q, b, bd)) | Fmatch (f, bs, ty) -> - let f = subst_form s f in - let bs = List.map (subst_form s) bs in - let ty = subst_ty s ty in - f_match f bs ty + let f = subst_form s f in + let bs = List.map (subst_form s) bs in + let ty = subst_ty s ty in + mk (Fmatch (f, bs, ty)) | Flet (lp, f, body) -> - let f = subst_form s f in - let s, lp = subst_form_lpattern s lp in - let body = subst_form s body in - f_let lp f body + let f = subst_form s f in + let s, lp = subst_form_lpattern s lp in + let body = subst_form s body in + mk (Flet (lp, f, body)) | Flocal x -> begin - match Mid.find x s.sb_flocal with - | aout -> aout - | exception Not_found -> f_local x (subst_ty s f.f_ty) - end + match Mid.find x s.sb_flocal with + | aout -> aout + | exception Not_found -> mk (Flocal x) + end | Fpvar (pv, m) -> - let pv = subst_progvar s pv in - let ty = subst_ty s f.f_ty in - let m = subst_mem s m in - f_pvar pv ty m + let pv = subst_progvar s pv in + let m = subst_mem s m in + mk (Fpvar (pv, m)) | Fglob (mp, m) -> - let mp = EcPath.mget_ident (subst_mpath s (EcPath.mident mp)) in - let m = subst_mem s m in - f_glob mp m - - | Fapp ({ f_node = Fop (p, tys) }, args) when has_def s p -> - let tys = subst_tys s tys in - let ty = subst_ty s f.f_ty in - let body = oget (get_def s p) in - let args = List.map (subst_form s) args in - subst_fop ty tys args body - - | Fop (p, tys) when has_def s p -> - let tys = subst_tys s tys in - let ty = subst_ty s f.f_ty in - let body = oget (get_def s p) in - subst_fop ty tys [] body - - | Fop (p, tys) -> - let p = subst_path s p in - let tys = subst_tys s tys in - let ty = subst_ty s f.f_ty in - f_op p tys ty + let mp = EcPath.mget_ident (subst_mpath s (EcPath.mident mp)) in + let m = subst_mem s m in + mk (Fglob (mp, m)) + + | Fapp ({ f_node = Fop (p, tyargs) }, args) when has_def s p -> + let tys = subst_etyargs s tyargs in + let ty = subst_ty s f.f_ty in + let body = oget (get_def s p) in + let args = List.map (subst_form s) args in + subst_fop ty tys args body + + | Fapp (hd, args) -> + let hd = subst_form s hd in + let args = List.map (subst_form s) args in + mk (Fapp (hd, args)) + + | Fop (p, tyargs) when has_def s p -> + let tyargs = subst_etyargs s tyargs in + let ty = subst_ty s f.f_ty in + let body = oget (get_def s p) in + subst_fop ty tyargs [] body + + | Fop (p, tyargs) -> + let p = subst_path s p in + let tyargs = subst_etyargs s tyargs in + mk (Fop (p, tyargs)) + + | Fif (c, f1, f2) -> + let c = subst_form s c in + let f1 = subst_form s f1 in + let f2 = subst_form s f2 in + mk (Fif (c, f1, f2)) + + | Ftuple fs -> + let fs = List.map (subst_form s) fs in + mk (Ftuple fs) + + | Fproj (subf, (i : int)) -> + let subf = subst_form s subf in + mk (Fproj (subf, i)) | FhoareF { hf_pr; hf_f; hf_po } -> - let hf_pr, hf_po = - let s = add_memory s mhr mhr in - let hf_pr = subst_form s hf_pr in - let hf_po = subst_form s hf_po in - (hf_pr, hf_po) in - let hf_f = subst_xpath s hf_f in - f_hoareF hf_pr hf_f hf_po + let hf_pr, hf_po = + let s = add_memory s mhr mhr in + let hf_pr = subst_form s hf_pr in + let hf_po = subst_form s hf_po in + (hf_pr, hf_po) in + let hf_f = subst_xpath s hf_f in + f_hoareF hf_pr hf_f hf_po | FhoareS { hs_m; hs_pr; hs_s; hs_po } -> - let hs_m, (hs_pr, hs_po) = - let s, hs_m = subst_memtype s hs_m in - let hs_pr = subst_form s hs_pr in - let hs_po = subst_form s hs_po in - hs_m, (hs_pr, hs_po) in - let hs_s = subst_stmt s hs_s in - f_hoareS hs_m hs_pr hs_s hs_po + let hs_m, (hs_pr, hs_po) = + let s, hs_m = subst_memtype s hs_m in + let hs_pr = subst_form s hs_pr in + let hs_po = subst_form s hs_po in + hs_m, (hs_pr, hs_po) in + let hs_s = subst_stmt s hs_s in + f_hoareS hs_m hs_pr hs_s hs_po | FbdHoareF { bhf_pr; bhf_f; bhf_po; bhf_cmp; bhf_bd } -> - let bhf_pr, bhf_po = - let s = add_memory s mhr mhr in - let bhf_pr = subst_form s bhf_pr in - let bhf_po = subst_form s bhf_po in - (bhf_pr, bhf_po) in - let bhf_f = subst_xpath s bhf_f in - let bhf_bd = subst_form s bhf_bd in - f_bdHoareF bhf_pr bhf_f bhf_po bhf_cmp bhf_bd + let bhf_pr, bhf_po = + let s = add_memory s mhr mhr in + let bhf_pr = subst_form s bhf_pr in + let bhf_po = subst_form s bhf_po in + (bhf_pr, bhf_po) in + let bhf_f = subst_xpath s bhf_f in + let bhf_bd = subst_form s bhf_bd in + f_bdHoareF bhf_pr bhf_f bhf_po bhf_cmp bhf_bd | FbdHoareS { bhs_m; bhs_pr; bhs_s; bhs_po; bhs_cmp; bhs_bd } -> - let bhs_m, (bhs_pr, bhs_po, bhs_bd) = - let s, bhs_m = subst_memtype s bhs_m in - let bhs_pr = subst_form s bhs_pr in - let bhs_po = subst_form s bhs_po in - let bhs_bd = subst_form s bhs_bd in - bhs_m, (bhs_pr, bhs_po, bhs_bd) in - let bhs_s = subst_stmt s bhs_s in - f_bdHoareS bhs_m bhs_pr bhs_s bhs_po bhs_cmp bhs_bd + let bhs_m, (bhs_pr, bhs_po, bhs_bd) = + let s, bhs_m = subst_memtype s bhs_m in + let bhs_pr = subst_form s bhs_pr in + let bhs_po = subst_form s bhs_po in + let bhs_bd = subst_form s bhs_bd in + bhs_m, (bhs_pr, bhs_po, bhs_bd) in + let bhs_s = subst_stmt s bhs_s in + f_bdHoareS bhs_m bhs_pr bhs_s bhs_po bhs_cmp bhs_bd | FeHoareF { ehf_pr; ehf_f; ehf_po } -> - let ehf_pr, ehf_po = - let s = add_memory s mhr mhr in - let ehf_pr = subst_form s ehf_pr in - let ehf_po = subst_form s ehf_po in - (ehf_pr, ehf_po) in - let ehf_f = subst_xpath s ehf_f in - f_eHoareF ehf_pr ehf_f ehf_po + let ehf_pr, ehf_po = + let s = add_memory s mhr mhr in + let ehf_pr = subst_form s ehf_pr in + let ehf_po = subst_form s ehf_po in + (ehf_pr, ehf_po) in + let ehf_f = subst_xpath s ehf_f in + f_eHoareF ehf_pr ehf_f ehf_po | FeHoareS { ehs_m; ehs_pr; ehs_s; ehs_po } -> - let ehs_m, (ehs_pr, ehs_po) = - let s, ehs_m = subst_memtype s ehs_m in - let ehs_pr = subst_form s ehs_pr in - let ehs_po = subst_form s ehs_po in - ehs_m, (ehs_pr, ehs_po) in - let ehs_s = subst_stmt s ehs_s in - f_eHoareS ehs_m ehs_pr ehs_s ehs_po + let ehs_m, (ehs_pr, ehs_po) = + let s, ehs_m = subst_memtype s ehs_m in + let ehs_pr = subst_form s ehs_pr in + let ehs_po = subst_form s ehs_po in + ehs_m, (ehs_pr, ehs_po) in + let ehs_s = subst_stmt s ehs_s in + f_eHoareS ehs_m ehs_pr ehs_s ehs_po | FequivF { ef_pr; ef_fl; ef_fr; ef_po } -> - let ef_pr, ef_po = - let s = add_memory s mleft mleft in - let s = add_memory s mright mright in - let ef_pr = subst_form s ef_pr in - let ef_po = subst_form s ef_po in - (ef_pr, ef_po) in - let ef_fl = subst_xpath s ef_fl in - let ef_fr = subst_xpath s ef_fr in - f_equivF ef_pr ef_fl ef_fr ef_po + let ef_pr, ef_po = + let s = add_memory s mleft mleft in + let s = add_memory s mright mright in + let ef_pr = subst_form s ef_pr in + let ef_po = subst_form s ef_po in + (ef_pr, ef_po) in + let ef_fl = subst_xpath s ef_fl in + let ef_fr = subst_xpath s ef_fr in + f_equivF ef_pr ef_fl ef_fr ef_po | FequivS { es_ml; es_mr; es_pr; es_sl; es_sr; es_po } -> - let (es_ml, es_mr), (es_pr, es_po) = - let s, es_ml = subst_memtype s es_ml in - let s, es_mr = subst_memtype s es_mr in - let es_pr = subst_form s es_pr in - let es_po = subst_form s es_po in - (es_ml, es_mr), (es_pr, es_po) in - let es_sl = subst_stmt s es_sl in - let es_sr = subst_stmt s es_sr in - f_equivS es_ml es_mr es_pr es_sl es_sr es_po + let (es_ml, es_mr), (es_pr, es_po) = + let s, es_ml = subst_memtype s es_ml in + let s, es_mr = subst_memtype s es_mr in + let es_pr = subst_form s es_pr in + let es_po = subst_form s es_po in + (es_ml, es_mr), (es_pr, es_po) in + let es_sl = subst_stmt s es_sl in + let es_sr = subst_stmt s es_sr in + f_equivS es_ml es_mr es_pr es_sl es_sr es_po | FeagerF { eg_pr; eg_sl; eg_fl; eg_fr; eg_sr; eg_po } -> - let eg_pr, eg_po = - let s = add_memory s mleft mleft in - let s = add_memory s mright mright in - let eg_pr = subst_form s eg_pr in - let eg_po = subst_form s eg_po in - (eg_pr, eg_po) in - let eg_sl = subst_stmt s eg_sl in - let eg_sr = subst_stmt s eg_sr in - let eg_fl = subst_xpath s eg_fl in - let eg_fr = subst_xpath s eg_fr in - f_eagerF eg_pr eg_sl eg_fl eg_fr eg_sr eg_po + let eg_pr, eg_po = + let s = add_memory s mleft mleft in + let s = add_memory s mright mright in + let eg_pr = subst_form s eg_pr in + let eg_po = subst_form s eg_po in + (eg_pr, eg_po) in + let eg_sl = subst_stmt s eg_sl in + let eg_sr = subst_stmt s eg_sr in + let eg_fl = subst_xpath s eg_fl in + let eg_fr = subst_xpath s eg_fr in + f_eagerF eg_pr eg_sl eg_fl eg_fr eg_sr eg_po | Fpr { pr_mem; pr_fun; pr_args; pr_event } -> - let pr_mem = subst_mem s pr_mem in - let pr_fun = subst_xpath s pr_fun in - let pr_args = subst_form s pr_args in - let pr_event = - let s = add_memory s mhr mhr in - subst_form s pr_event in - f_pr pr_mem pr_fun pr_args pr_event - - | Fif _ | Fint _ | Ftuple _ | Fproj _ | Fapp _ -> - f_map (subst_ty s) (subst_form s) f + let pr_mem = subst_mem s pr_mem in + let pr_fun = subst_xpath s pr_fun in + let pr_args = subst_form s pr_args in + let pr_event = + let s = add_memory s mhr mhr in + subst_form s pr_event in + f_pr pr_mem pr_fun pr_args pr_event (* -------------------------------------------------------------------- *) and subst_fop fty tys args (tyids, f) = - let s = add_tyvars empty tyids tys in + let s = add_tyvars empty (List.combine tyids tys) in let (s, args, f) = match f.f_node with @@ -837,14 +927,19 @@ let subst_top_module (s : subst) (m : top_module_expr) = tme_loca = m.tme_loca; } (* -------------------------------------------------------------------- *) -let subst_typeclass (s : subst) (tcs : Sp.t) = - Sp.map (subst_path s) tcs +let subst_typeclass (s : subst) (tc : typeclass) = + { tc_name = subst_path s tc.tc_name; + tc_args = subst_etyargs s tc.tc_args; } (* -------------------------------------------------------------------- *) let fresh_tparam (s : subst) ((x, tcs) : ty_param) = let newx = EcIdent.fresh x in - let tcs = subst_typeclass s tcs in - let s = add_tyvar s x (tvar newx) in + let tcs = List.map (subst_typeclass s) tcs in + let tcw = + let mk (offset : int) = + TCIAbstract { support = `Var newx; offset; } + in List.mapi (fun i _ -> mk i) tcs in + let s = add_tyvar s x (tvar newx, tcw) in (s, (newx, tcs)) (* -------------------------------------------------------------------- *) @@ -861,7 +956,7 @@ let subst_genty (s : subst) (tparams, ty) = let subst_tydecl_body (s : subst) (tyd : ty_body) = match tyd with | `Abstract tc -> - `Abstract (subst_typeclass s tc) + `Abstract (List.map (subst_typeclass s) tc) | `Concrete ty -> `Concrete (subst_ty s ty) @@ -924,7 +1019,7 @@ and subst_op_body (s : subst) (bd : opbody) = opf_struct = opfix.opf_struct; opf_branches = subst_branches es opfix.opf_branches; } - | OP_TC -> OP_TC + | OP_TC (p, n) -> OP_TC (subst_path s p, n) and subst_branches (s : subst) = function | OPB_Leaf (locals, e) -> @@ -1019,19 +1114,37 @@ let subst_field (s : subst) cr = f_inv = subst_path s cr.f_inv; f_div = omap (subst_path s) cr.f_div; } -(* -------------------------------------------------------------------- *) -let subst_instance (s : subst) tci = - match tci with - | `Ring cr -> `Ring (subst_ring s cr) - | `Field cr -> `Field (subst_field s cr) - | `General p -> `General (subst_path s p) - (* -------------------------------------------------------------------- *) let subst_tc (s : subst) tc = - let tc_prt = omap (subst_path s) tc.tc_prt in + let s, tc_tparams = fresh_tparams s tc.tc_tparams in + let tc_prt = omap (subst_typeclass s) tc.tc_prt in let tc_ops = List.map (snd_map (subst_ty s)) tc.tc_ops in let tc_axs = List.map (snd_map (subst_form s)) tc.tc_axs in - { tc_prt; tc_ops; tc_axs; tc_loca = tc.tc_loca } + { tc_tparams; tc_prt; tc_ops; tc_axs; tc_loca = tc.tc_loca } + +(* -------------------------------------------------------------------- *) +let subst_tcibody (s : subst) (tci : tcibody) = + match tci with + | `Ring cr -> `Ring (subst_ring s cr) + | `Field cr -> `Field (subst_field s cr) + + | `General (tc, syms) -> + let tc = subst_typeclass s tc in + let syms = + Option.map + (Mstr.map (fun (p, tys) -> (subst_path s p, subst_etyargs s tys))) + syms in + `General (tc, syms) + + +(* -------------------------------------------------------------------- *) +let subst_tcinstance (s : subst) (tci : tcinstance) = + let s, tci_params = fresh_tparams s tci.tci_params in + let tci_type = subst_ty s tci.tci_type in + let tci_instance = subst_tcibody s tci.tci_instance in + let tci_local = tci.tci_local in + + { tci_params; tci_type; tci_instance; tci_local; } (* -------------------------------------------------------------------- *) (* SUBSTITUTION OVER THEORIES *) @@ -1058,8 +1171,8 @@ let rec subst_theory_item_r (s : subst) (item : theory_item_r) = | Th_export (p, lc) -> Th_export (subst_path s p, lc) - | Th_instance (ty, tci, lc) -> - Th_instance (subst_genty s ty, subst_instance s tci, lc) + | Th_instance (x, tci) -> + Th_instance (x, subst_tcinstance s tci) | Th_typeclass (x, tc) -> Th_typeclass (x, subst_tc s tc) @@ -1099,16 +1212,16 @@ and subst_theory_source (s : subst) (ths : thsource) = { ths_base = subst_path s ths.ths_base; } (* -------------------------------------------------------------------- *) -let init_tparams (params : (EcIdent.t * ty) list) : subst = - List.fold_left (fun s (x, ty) -> add_tyvar s x ty) empty params +let init_tparams (params : (EcIdent.t * etyarg) list) : subst = + add_tyvars empty params (* -------------------------------------------------------------------- *) -let open_oper op tys = +let open_oper (op : operator) (tys : etyarg list) : ty * operator_kind = let s = List.combine (List.fst op.op_tparams) tys in let s = init_tparams s in (subst_ty s op.op_ty, subst_op_kind s op.op_kind) -let open_tydecl tyd tys = +let open_tydecl (tyd : tydecl) (tys : etyarg list) : EcDecl.ty_body = let s = List.combine (List.fst tyd.tyd_params) tys in let s = init_tparams s in subst_tydecl_body s tyd.tyd_type diff --git a/src/ecSubst.mli b/src/ecSubst.mli index 8eabb02aeb..eab598b759 100644 --- a/src/ecSubst.mli +++ b/src/ecSubst.mli @@ -26,6 +26,7 @@ val is_empty : subst -> bool val add_module : subst -> EcIdent.t -> mpath -> subst val add_path : subst -> src:path -> dst:path -> subst val add_tydef : subst -> path -> (EcIdent.t list * ty) -> subst +val add_tyvar : subst -> EcIdent.t -> etyarg -> subst val add_opdef : subst -> path -> (EcIdent.t list * expr) -> subst val add_pddef : subst -> path -> (EcIdent.t list * form) -> subst val add_moddef : subst -> src:path -> dst:mpath -> subst (* Only concrete modules *) @@ -43,7 +44,7 @@ val subst_theory : subst -> theory -> theory val subst_ax : subst -> axiom -> axiom val subst_op : subst -> operator -> operator val subst_tydecl : subst -> tydecl -> tydecl -val subst_tc : subst -> typeclass -> typeclass +val subst_tc : subst -> tc_decl -> tc_decl val subst_theory : subst -> theory -> theory val subst_branches : subst -> opbranches -> opbranches @@ -64,17 +65,21 @@ val subst_mod_restr : subst -> mod_restr -> mod_restr val subst_oracle_infos : subst -> oracle_infos -> oracle_infos (* -------------------------------------------------------------------- *) -val subst_gty : subst -> gty -> gty -val subst_genty : subst -> (ty_params * ty) -> (ty_params * ty) -val subst_ty : subst -> ty -> ty -val subst_form : subst -> form -> form -val subst_expr : subst -> expr -> expr -val subst_stmt : subst -> stmt -> stmt - val subst_progvar : subst -> prog_var -> prog_var -val subst_mem : subst -> EcIdent.t -> EcIdent.t -val subst_flocal : subst -> form -> form +val subst_mem : subst -> EcIdent.t -> EcIdent.t +val subst_flocal : subst -> form -> form +val subst_gty : subst -> gty -> gty +val subst_genty : subst -> (ty_params * ty) -> (ty_params * ty) +val subst_ty : subst -> ty -> ty +val subst_etyarg : subst -> etyarg -> etyarg +val subst_tcw : subst -> tcwitness -> tcwitness +val subst_form : subst -> form -> form +val subst_expr : subst -> expr -> expr +val subst_stmt : subst -> stmt -> stmt + +(* -------------------------------------------------------------------- *) +val open_oper : operator -> etyarg list -> ty * operator_kind +val open_tydecl : tydecl -> etyarg list -> ty_body (* -------------------------------------------------------------------- *) -val open_oper : operator -> ty list -> ty * operator_kind -val open_tydecl : tydecl -> ty list -> ty_body +val fresh_tparams : subst -> ty_params -> subst * ty_params diff --git a/src/ecTheory.ml b/src/ecTheory.ml index 0d39c8d21d..e439a2cfb2 100644 --- a/src/ecTheory.ml +++ b/src/ecTheory.ml @@ -1,6 +1,7 @@ (* -------------------------------------------------------------------- *) open EcUtils open EcSymbols +open EcMaps open EcPath open EcAst open EcTypes @@ -32,8 +33,8 @@ and theory_item_r = | Th_module of top_module_expr | Th_theory of (symbol * ctheory) | Th_export of EcPath.path * is_local - | Th_instance of (ty_params * EcTypes.ty) * tcinstance * is_local - | Th_typeclass of (symbol * typeclass) + | Th_instance of (symbol option * tcinstance) + | Th_typeclass of (symbol * tc_decl) | Th_baserw of symbol * is_local | Th_addrw of EcPath.path * EcPath.path list * is_local | Th_reduction of (EcPath.path * rule_option * rule option) list @@ -50,8 +51,20 @@ and ctheory = { cth_source : thsource option; } -and tcinstance = [ `Ring of ring | `Field of field | `General of path ] -and thmode = [ `Abstract | `Concrete ] +and tcinstance = { + tci_params : ty_params; + tci_type : ty; + tci_instance : tcibody; + tci_local : locality; +} + +and tcibody = [ + | `Ring of ring + | `Field of field + | `General of typeclass * ((path * etyarg list) Mstr.t) option +] + +and thmode = [ `Abstract | `Concrete ] and rule_pattern = | Rule of top_rule_pattern * rule_pattern list @@ -59,7 +72,7 @@ and rule_pattern = | Var of EcIdent.t and top_rule_pattern = - [`Op of (EcPath.path * EcTypes.ty list) | `Tuple | `Proj of int] + [`Op of (EcPath.path * ty list) | `Tuple | `Proj of int] and rule = { rl_tyd : EcDecl.ty_params; diff --git a/src/ecTheory.mli b/src/ecTheory.mli index 472928561f..949ce569b2 100644 --- a/src/ecTheory.mli +++ b/src/ecTheory.mli @@ -1,5 +1,6 @@ (* -------------------------------------------------------------------- *) open EcSymbols +open EcMaps open EcPath open EcAst open EcTypes @@ -28,11 +29,11 @@ and theory_item_r = | Th_module of top_module_expr | Th_theory of (symbol * ctheory) | Th_export of EcPath.path * is_local - | Th_instance of (ty_params * EcTypes.ty) * tcinstance * is_local - | Th_typeclass of (symbol * typeclass) + | Th_instance of (symbol option * tcinstance) + | Th_typeclass of (symbol * tc_decl) | Th_baserw of symbol * is_local | Th_addrw of EcPath.path * EcPath.path list * is_local - (* reduction rule does not survive to section so no locality *) + (* reduction rule does not survive section => no locality *) | Th_reduction of (EcPath.path * rule_option * rule option) list | Th_auto of (int * symbol option * path list * is_local) @@ -47,8 +48,20 @@ and ctheory = { cth_source : thsource option; } -and tcinstance = [ `Ring of ring | `Field of field | `General of EcPath.path ] -and thmode = [ `Abstract | `Concrete ] +and tcinstance = { + tci_params : ty_params; + tci_type : ty; + tci_instance : tcibody; + tci_local : locality; +} + +and tcibody = [ + | `Ring of ring + | `Field of field + | `General of typeclass * ((path * etyarg list) Mstr.t) option +] + +and thmode = [ `Abstract | `Concrete ] and rule_pattern = | Rule of top_rule_pattern * rule_pattern list @@ -56,7 +69,7 @@ and rule_pattern = | Var of EcIdent.t and top_rule_pattern = - [`Op of (EcPath.path * EcTypes.ty list) | `Tuple | `Proj of int] + [`Op of (EcPath.path * ty list) | `Tuple | `Proj of int] and rule = { rl_tyd : EcDecl.ty_params; @@ -71,7 +84,6 @@ and rule_option = { ur_delta : bool; ur_eqtrue : bool; } - val mkitem : import -> theory_item_r -> theory_item (* -------------------------------------------------------------------- *) diff --git a/src/ecTheoryReplay.ml b/src/ecTheoryReplay.ml index 227aaee341..97374e5a04 100644 --- a/src/ecTheoryReplay.ml +++ b/src/ecTheoryReplay.ml @@ -1,5 +1,6 @@ (* ------------------------------------------------------------------ *) open EcSymbols +open EcMaps open EcUtils open EcLocation open EcParsetree @@ -50,14 +51,17 @@ let keep_of_mode (mode : clmode) = (* -------------------------------------------------------------------- *) exception Incompatible of incompatible -let tparams_compatible rtyvars ntyvars = +(* FIXME:TC *) +let tparams_compatible (rtyvars : ty_params) (ntyvars : ty_params) = let rlen = List.length rtyvars and nlen = List.length ntyvars in if rlen <> nlen then - raise (Incompatible (NotSameNumberOfTyParam(rlen,nlen))) + raise (Incompatible (NotSameNumberOfTyParam (rlen, nlen))) let ty_compatible env ue (rtyvars, rty) (ntyvars, nty) = tparams_compatible rtyvars ntyvars; - let subst = CS.Tvar.init rtyvars (List.map tvar ntyvars) in + let subst = + let etyargs = etyargs_of_tparams ntyvars in + CS.Tvar.init (List.combine (List.fst rtyvars) etyargs) in let rty = CS.Tvar.subst subst rty in try EcUnify.unify env ue rty nty with EcUnify.UnificationFailure _ -> @@ -113,7 +117,7 @@ let rec tybody_compatible exn hyps ty_body1 ty_body2 = let tydecl_compatible env tyd1 tyd2 = let params = tyd1.tyd_params in tparams_compatible params tyd2.tyd_params; - let tparams = List.map (fun (id,_) -> tvar id) params in + let tparams = etyargs_of_tparams params in let ty_body1 = tyd1.tyd_type in let ty_body2 = EcSubst.open_tydecl tyd2 tparams in let exn = Incompatible (TyBody(*tyd1,tyd2*)) in @@ -130,12 +134,13 @@ let expr_compatible exn env s e1 e2 = let get_open_oper exn env p tys = let oper = EcEnv.Op.by_path p env in - let _, okind = EcSubst.open_oper oper tys in + let _, okind = EcSubst.open_oper oper tys in (* FIXME:TC *) match okind with | OB_oper (Some ob) -> ob | _ -> raise exn let rec oper_compatible exn env ob1 ob2 = + (* FIXME: duplicated code *) match ob1, ob2 with | OP_Plain f1, OP_Plain f2 -> error_body exn (EcReduction.is_conv ~ri:ri_compatible (EcEnv.LDecl.init env []) f1 f2) @@ -153,7 +158,8 @@ let rec oper_compatible exn env ob1 ob2 = error_body exn (EcPath.p_equal p1 p2 && i11 = i21 && i12 = i22) | OP_Fix f1, OP_Fix f2 -> opfix_compatible exn env f1 f2 - | OP_TC, OP_TC -> () + | OP_TC (p1, n1), OP_TC (p2, n2) -> + error_body exn (EcPath.p_equal p1 p2 && n1 = n2) | _, _ -> raise exn and opfix_compatible exn env f1 f2 = @@ -198,10 +204,10 @@ let rec pred_compatible exn env pb1 pb2 = match pb1, pb2 with | PR_Plain f1, PR_Plain f2 -> error_body exn (EcReduction.is_conv (EcEnv.LDecl.init env []) f1 f2) | PR_Plain {f_node = Fop(p,tys)}, _ -> - let pb1 = get_open_pred exn env p tys in + let pb1 = get_open_pred exn env p tys in pred_compatible exn env pb1 pb2 | _, PR_Plain {f_node = Fop(p,tys)} -> - let pb2 = get_open_pred exn env p tys in + let pb2 = get_open_pred exn env p tys in pred_compatible exn env pb1 pb2 | PR_Ind pr1, PR_Ind pr2 -> ind_compatible exn env pr1 pr2 @@ -230,7 +236,7 @@ let operator_compatible env oper1 oper2 = let params = oper1.op_tparams in tparams_compatible oper1.op_tparams oper2.op_tparams; let oty1, okind1 = oper1.op_ty, oper1.op_kind in - let tparams = List.map (fun (id,_) -> tvar id) params in + let tparams = etyargs_of_tparams params in let oty2, okind2 = EcSubst.open_oper oper2 tparams in if not (EcReduction.EqTest.for_type env oty1 oty2) then raise (Incompatible (DifferentType(oty1, oty2))); @@ -373,17 +379,17 @@ let rec replay_tyd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, otyd | `Datatype { tydt_ctors = octors }, Tconstr (np, _) -> begin match (EcEnv.Ty.by_path np env).tyd_type with | `Datatype { tydt_ctors = _ } -> - let newtparams = List.fst newtyd.tyd_params in - let newtparams_ty = List.map tvar newtparams in - let newdtype = tconstr np newtparams_ty in - let tysubst = CS.Tvar.init (List.fst otyd.tyd_params) newtparams_ty in + let newtparams = etyargs_of_tparams newtyd.tyd_params in + let newdtype = tconstr_tc np newtparams in + let tysubst = + CS.Tvar.init (List.combine (List.fst otyd.tyd_params) newtparams) in List.fold_left (fun subst (name, tyargs) -> let np = EcPath.pqoname (EcPath.prefix np) name in let newtyargs = List.map (CS.Tvar.subst tysubst) tyargs in EcSubst.add_opdef subst (xpath ove name) - (newtparams, e_op np newtparams_ty (toarrow newtyargs newdtype))) + (List.fst newtyd.tyd_params, e_op_tc np newtparams (toarrow newtyargs newdtype))) subst octors | _ -> subst end @@ -446,15 +452,18 @@ and replay_opd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, oopd) = in begin try ty_compatible env ue - (List.map fst reftyvars, refty) - (List.map fst (EcUnify.UniEnv.tparams ue), ty) + (reftyvars, refty) + (EcUnify.UniEnv.tparams ue, ty) with Incompatible err -> clone_error env (CE_OpIncompatible ((snd ove.ovre_prefix, x), err)) end; - if not (EcUnify.UniEnv.closed ue) then - ove.ovre_hooks.herr - ~loc "this operator body contains free type variables"; + Option.iter (fun infos -> + ove.ovre_hooks.herr ~loc + (Format.asprintf + "this operator body contains free %a variables" + EcUserMessages.TypingError.pp_uniflags infos) + ) (EcUnify.UniEnv.xclosed ue); let sty = CS.Tuni.subst (EcUnify.UniEnv.close ue) in let body = EcFol.Fsubst.f_subst sty body in @@ -560,16 +569,19 @@ and replay_prd (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, oopr) = begin try ty_compatible env ue - (List.map fst reftyvars, refty) - (List.map fst (EcUnify.UniEnv.tparams ue), body.f_ty) + (reftyvars, refty) + (EcUnify.UniEnv.tparams ue, body.f_ty) with Incompatible err -> clone_error env (CE_OpIncompatible ((snd ove.ovre_prefix, x), err)) end; - if not (EcUnify.UniEnv.closed ue) then - ove.ovre_hooks.herr - ~loc "this predicate body contains free type variables"; + Option.iter (fun infos -> + ove.ovre_hooks.herr ~loc + (Format.asprintf + "this predicate body contains free %a variables" + EcUserMessages.TypingError.pp_uniflags infos) + ) (EcUnify.UniEnv.xclosed ue); let fs = CS.Tuni.subst (EcUnify.UniEnv.close ue) in let body = EcFol.Fsubst.f_subst fs body in @@ -870,7 +882,7 @@ and replay_typeclass (* -------------------------------------------------------------------- *) and replay_instance - (ove : _ ovrenv) (subst, ops, proofs, scope) (import, (typ, ty), tc, lc) + (ove : _ ovrenv) (subst, ops, proofs, scope) (import, x, tci) = let opath = ove.ovre_opath in let npath = ove.ovre_npath in @@ -898,7 +910,7 @@ and replay_instance | OB_oper (Some (OP_Record _)) | OB_oper (Some (OP_Proj _)) | OB_oper (Some (OP_Fix _)) - | OB_oper (Some (OP_TC )) -> + | OB_oper (Some (OP_TC _)) -> Some (EcPath.pappend npath q) | OB_oper (Some (OP_Plain f)) -> match f.f_node with @@ -908,9 +920,15 @@ and replay_instance let forpath p = odfl p (forpath p) in + let fortypeclass (tc : typeclass) = + { tc_name = forpath tc.tc_name; + tc_args = List.map (EcSubst.subst_etyarg subst) tc.tc_args; } in + try - let (typ, ty) = EcSubst.subst_genty subst (typ, ty) in - let tc = + let subst, tci_params = EcSubst.fresh_tparams subst tci.tci_params in + let tci_type = EcSubst.subst_ty subst tci.tci_type in + + let tci_instance : tcibody = let rec doring cr = { r_type = EcSubst.subst_ty subst cr.r_type; r_zero = forpath cr.r_zero; @@ -933,14 +951,25 @@ and replay_instance f_inv = forpath cr.f_inv; f_div = cr.f_div |> omap forpath; } in - match tc with - | `Ring cr -> `Ring (doring cr) - | `Field cr -> `Field (dofield cr) - | `General p -> `General (forpath p) + match tci.tci_instance with + | `Ring cr -> `Ring (doring cr) + | `Field cr -> `Field (dofield cr) + + | `General (tc, syms) -> + let tc = fortypeclass tc in + let syms = + Option.map + (Mstr.map (fun (p, tys) -> + (forpath p, List.map (EcSubst.subst_etyarg subst) tys))) + syms in + `General (tc, syms) in - let scope = ove.ovre_hooks.hadd_item scope import (Th_instance ((typ, ty), tc, lc)) in - (subst, ops, proofs, scope) + let tci = { tci with tci_params; tci_type; tci_instance; } in + + let scope = + ove.ovre_hooks.hadd_item scope import (Th_instance (x, tci)) + in (subst, ops, proofs, scope) with E.InvInstPath -> (subst, ops, proofs, scope) @@ -987,8 +1016,8 @@ and replay1 (ove : _ ovrenv) (subst, ops, proofs, scope) item = | Th_typeclass (x, tc) -> replay_typeclass ove (subst, ops, proofs, scope) (item.ti_import, x, tc) - | Th_instance ((typ, ty), tc, lc) -> - replay_instance ove (subst, ops, proofs, scope) (item.ti_import, (typ, ty), tc, lc) + | Th_instance (x, tci) -> + replay_instance ove (subst, ops, proofs, scope) (item.ti_import, x, tci) | Th_theory (ox, cth) -> begin let thmode = cth.cth_mode in diff --git a/src/ecTypeClass.ml b/src/ecTypeClass.ml index f142cc94d9..efdaf16edc 100644 --- a/src/ecTypeClass.ml +++ b/src/ecTypeClass.ml @@ -1,87 +1,147 @@ (* -------------------------------------------------------------------- *) -open EcUtils +open EcIdent open EcPath +open EcUtils +open EcAst +open EcTheory (* -------------------------------------------------------------------- *) -type graph = { - tcg_nodes : Sp.t Mp.t; - tcg_closure : Sp.t Mp.t; -} +exception NoMatch -type nodes = { - tcn_graph : graph; - tcn_nodes : Sp.t; -} +(* -------------------------------------------------------------------- *) +module TyMatch(E : sig val env : EcEnv.env end) = struct + let rec doit_type (map : ty option Mid.t) (pattern : ty) (ty : ty) = + let pattern = EcEnv.ty_hnorm pattern E.env in + let ty = EcEnv.ty_hnorm ty E.env in -type node = EcPath.path + match pattern.ty_node, ty.ty_node with + | Tunivar _, _ -> + assert false -exception CycleDetected + | Tvar a, _ -> begin + match Option.get (Mid.find_opt a map) with + | None -> + Mid.add a (Some ty) map + + | Some ty' -> + if not (EcCoreEqTest.for_type E.env ty ty') then + raise NoMatch; + map -(* -------------------------------------------------------------------- *) -module Graph = struct - let empty : graph = { - tcg_nodes = Mp.empty; - tcg_closure = Mp.empty; - } - - let dump gr = - Printf.sprintf "%s\n" - (String.concat "\n" - (List.map - (fun (p, ps) -> Printf.sprintf "%s -> %s" - (EcPath.tostring p) - (String.concat ", " (List.map EcPath.tostring (Sp.elements ps)))) - (Mp.bindings gr.tcg_nodes))) - - let has_path ~src ~dst g = - if EcPath.p_equal src dst then - true - else - match Mp.find_opt src g.tcg_closure with - | None -> false - | Some m -> Mp.mem dst m - - let add ~src ~dst g = - if has_path ~src ~dst g then - raise CycleDetected; - - match Mp.find_opt src g.tcg_nodes with - | Some m when Mp.mem dst m -> g - | _ -> - let up_node m = Sp.add dst (odfl Sp.empty m) - and up_clos m = - Sp.union - (odfl Sp.empty (Mp.find_opt dst g.tcg_closure)) - (Sp.add dst (odfl Sp.empty m)) - in - { g with - tcg_nodes = Mp.change (some -| up_node) src g.tcg_nodes; - tcg_closure = Mp.change (some -| up_clos) src g.tcg_closure; } + end + + | Tglob id1, Tglob id2 when EcIdent.id_equal id1 id2 -> + map + + | Tconstr (p, args), Tconstr (p', args') -> + if not (EcPath.p_equal p p') then + raise NoMatch; + doit_etyargs map args args' + + | Ttuple ptns, Ttuple tys when List.length ptns = List.length tys -> + doit_types map ptns tys + + | Tfun (p1, p2), Tfun (ty1, ty2) -> + doit_types map [p1; p2] [ty1; ty2] + + | _, _ -> + raise NoMatch + + and doit_types (map : ty option Mid.t) (pts : ty list) (tys : ty list) = + List.fold_left2 doit_type map pts tys + + and doit_etyarg (map : ty option Mid.t) ((pattern, ptcws) : etyarg) ((ty, ttcws) : etyarg) = + let map = doit_type map pattern ty in + let map = doit_tcws map ptcws ttcws in + map + + and doit_etyargs (map : ty option Mid.t) (pts : etyarg list) (etys : etyarg list) = + List.fold_left2 doit_etyarg map pts etys + + and doit_tcw (map : ty option Mid.t) (ptcw : tcwitness) (ttcw : tcwitness) = + match ptcw, ttcw with + | TCIUni _, _ -> + assert false + + | TCIConcrete ptcw, TCIConcrete ttcw -> + if not (EcPath.p_equal ptcw.path ttcw.path) then + raise NoMatch; + doit_etyargs map ptcw.etyargs ttcw.etyargs + + | TCIAbstract _, TCIAbstract _ -> + if not (EcAst.tcw_equal ptcw ttcw) then + raise NoMatch; + map + + | _, _ -> + raise NoMatch + + and doit_tcws (map : ty option Mid.t) (ptcws : tcwitness list) (ttcws : tcwitness list) = + List.fold_left2 doit_tcw map ptcws ttcws end (* -------------------------------------------------------------------- *) -module Nodes = struct - let empty g = { - tcn_graph = g; - tcn_nodes = Sp.empty; - } - - let add n nodes = - let module E = struct exception Discard end in - - try - let aout = - Sp.filter - (fun p -> - if Graph.has_path ~src:p ~dst:n nodes.tcn_graph then raise E.Discard; - not (Graph.has_path ~src:n ~dst:p nodes.tcn_graph)) - nodes.tcn_nodes - in - { nodes with tcn_nodes = Sp.add n aout } - with E.Discard -> nodes - - let toset nodes = nodes.tcn_nodes - - let reduce set g = - toset (Sp.fold add set (empty g)) -end +let ty_match (env : EcEnv.env) (params : ident list) ~(pattern : ty) ~(ty : ty) = + let module M = TyMatch(struct let env = env end) in + let map = Mid.of_list (List.map (fun a -> (a, None)) params) in + M.doit_type map pattern ty + +(* -------------------------------------------------------------------- *) +let etyargs_match + (env : EcEnv.env) + (params : ident list) + ~(patterns : etyarg list) + ~(etyargs : etyarg list) += + let module M = TyMatch(struct let env = env end) in + let map = Mid.of_list (List.map (fun a -> (a, None)) params) in + M.doit_etyargs map patterns etyargs + +(* -------------------------------------------------------------------- *) +let rec check_tcinstance + (env : EcEnv.env) + (ty : ty) + (tc : typeclass) + ((p, tci) : path option * tcinstance) += + let exception Bailout in + + try + let p = oget ~exn:Bailout p in + + let tgargs = + match tci.tci_instance with + | `General (tgp, _) -> + if not (EcPath.p_equal tc.tc_name tgp.tc_name) then + raise Bailout; + tgp.tc_args + | _ -> raise Bailout in + + let map = + etyargs_match env (List.fst tci.tci_params) + ~patterns:tgargs ~etyargs:tc.tc_args in + + let map = + let module M = TyMatch(struct let env = env end) in + M.doit_type map tci.tci_type ty in + + + let _, args = List.fold_left_map (fun subst (a, aargs) -> + let aty = oget ~exn:Bailout (Mid.find a map) in + let aargs = List.map (fun aarg -> + let aarg = EcCoreSubst.Tvar.subst_tc subst aarg in + oget ~exn:Bailout (infer env aty aarg) + ) aargs in + let subst = Mid.add a (aty, aargs) subst in + (subst, (aty, aargs)) + ) Mid.empty tci.tci_params in + + Some (TCIConcrete { path = p; etyargs = args; }) + + with Bailout | NoMatch -> None + +(* -------------------------------------------------------------------- *) +and infer (env : EcEnv.env) (ty : ty) (tc : typeclass) = + List.find_map_opt + (check_tcinstance env ty tc) + (EcEnv.TcInstance.get_all env) diff --git a/src/ecTypeClass.mli b/src/ecTypeClass.mli index 9c8b566600..66c7ed7f42 100644 --- a/src/ecTypeClass.mli +++ b/src/ecTypeClass.mli @@ -1,23 +1,7 @@ (* -------------------------------------------------------------------- *) -open EcPath +open EcAst +open EcDecl +open EcEnv -type node = path - -type graph -type nodes - -exception CycleDetected - -module Graph : sig - val empty : graph - val add : src:node -> dst:node -> graph -> graph - val has_path : src:node -> dst:node -> graph -> bool - val dump : graph -> string -end - -module Nodes : sig - val empty : graph -> nodes - val add : node -> nodes -> nodes - val toset : nodes -> Sp.t - val reduce : Sp.t -> graph -> Sp.t -end +(* -------------------------------------------------------------------- *) +val infer : env -> ty -> typeclass -> tcwitness option diff --git a/src/ecTypes.ml b/src/ecTypes.ml index 3da35d7287..75b30cfdb3 100644 --- a/src/ecTypes.ml +++ b/src/ecTypes.ml @@ -42,7 +42,7 @@ let rec dump_ty ty = EcIdent.tostring p | Tunivar i -> - Printf.sprintf "#%d" i + Printf.sprintf "#%d" (i :> int) | Tvar id -> EcIdent.tostring id @@ -52,17 +52,18 @@ let rec dump_ty ty = | Tconstr (p, tys) -> Printf.sprintf "%s[%s]" (EcPath.tostring p) - (String.concat ", " (List.map dump_ty tys)) + (String.concat ", " (List.map dump_ty (List.fst tys))) | Tfun (t1, t2) -> Printf.sprintf "(%s) -> (%s)" (dump_ty t1) (dump_ty t2) (* -------------------------------------------------------------------- *) -let tuni uid = mk_ty (Tunivar uid) -let tvar id = mk_ty (Tvar id) -let tconstr p lt = mk_ty (Tconstr (p, lt)) -let tfun t1 t2 = mk_ty (Tfun (t1, t2)) -let tglob m = mk_ty (Tglob m) +let tuni uid = mk_ty (Tunivar uid) +let tvar id = mk_ty (Tvar id) +let tconstr p lt = mk_ty (Tconstr (p, List.map (fun ty -> (ty, [])) lt)) +let tconstr_tc p lt = mk_ty (Tconstr (p, lt)) +let tfun t1 t2 = mk_ty (Tfun (t1, t2)) +let tglob m = mk_ty (Tglob m) (* -------------------------------------------------------------------- *) let tunit = tconstr EcCoreLib.CI_Unit .p_unit [] @@ -103,7 +104,7 @@ let rec tyfun_flat (ty : ty) = (* -------------------------------------------------------------------- *) let as_tdistr (ty : ty) = match ty.ty_node with - | Tconstr (p, [sty]) + | Tconstr (p, [sty, []]) when EcPath.p_equal p EcCoreLib.CI_Distr.p_distr -> Some sty @@ -112,7 +113,7 @@ let as_tdistr (ty : ty) = let is_tdistr (ty : ty) = as_tdistr ty <> None (* -------------------------------------------------------------------- *) -let ty_map f t = +let rec ty_map (f : ty -> ty) (t : ty) : ty = match t.ty_node with | Tglob _ | Tunivar _ | Tvar _ -> t @@ -120,39 +121,88 @@ let ty_map f t = ttuple (List.Smart.map f lty) | Tconstr (p, lty) -> - let lty = List.Smart.map f lty in - tconstr p lty + let lty = List.Smart.map (etyarg_map f) lty in + tconstr_tc p lty | Tfun (t1, t2) -> tfun (f t1) (f t2) -let ty_fold f s ty = +and etyarg_map (f : ty -> ty) ((ty, tcw) : etyarg) : etyarg = + let ty = f ty in + let tcw = List.Smart.map (tcw_map f) tcw in + (ty, tcw) + +and tcw_map (f : ty -> ty) (tcw : tcwitness) : tcwitness = + match tcw with + | TCIUni _ -> + tcw + + | TCIConcrete { path; etyargs; } -> + let etyargs = List.Smart.map (etyarg_map f) etyargs in + TCIConcrete { path; etyargs; } + + | TCIAbstract _ -> + tcw + +(* -------------------------------------------------------------------- *) +let rec ty_fold (f : 'a -> ty -> 'a) (v : 'a) (ty : ty) : 'a = match ty.ty_node with - | Tglob _ | Tunivar _ | Tvar _ -> s - | Ttuple lty -> List.fold_left f s lty - | Tconstr(_, lty) -> List.fold_left f s lty - | Tfun(t1,t2) -> f (f s t1) t2 + | Tglob _ | Tunivar _ | Tvar _ -> v + | Ttuple lty -> List.fold_left f v lty + | Tconstr (_, lty) -> List.fold_left (etyarg_fold f) v lty + | Tfun (t1, t2) -> f (f v t1) t2 + +and etyarg_fold (f : 'a -> ty -> 'a) (v : 'a) (ety : etyarg) : 'a = + let (ty, tcw) = ety in + List.fold_left (tcw_fold f) (f v ty) tcw + +and tcw_fold (f : 'a -> ty -> 'a) (v : 'a) (tcw : tcwitness) : 'a = + match tcw with + | TCIConcrete { etyargs } -> + List.fold_left (etyarg_fold f) v etyargs + + | TCIUni _ | TCIAbstract _ -> + v -let ty_sub_exists f t = - match t.ty_node with - | Tglob _ | Tunivar _ | Tvar _ -> false - | Ttuple lty -> List.exists f lty - | Tconstr (_, lty) -> List.exists f lty - | Tfun (t1, t2) -> f t1 || f t2 +(* -------------------------------------------------------------------- *) +let ty_iter (f : ty -> unit) (ty : ty) : unit = + ty_fold (fun () -> f) () ty -let ty_iter f t = - match t.ty_node with - | Tglob _ | Tunivar _ | Tvar _ -> () - | Ttuple lty -> List.iter f lty - | Tconstr (_, lty) -> List.iter f lty - | Tfun (t1,t2) -> f t1; f t2 +let etyarg_iter (f : ty -> unit) (ety : etyarg) : unit = + etyarg_fold (fun () -> f) () ety +let tcw_iter (f : ty -> unit) (tcw : tcwitness) : unit = + tcw_fold (fun () -> f) () tcw + +(* -------------------------------------------------------------------- *) +let ty_sub_exists (f : ty -> bool) (ty : ty) = + let exception Exists in + try + ty_iter (fun ty -> if f ty then raise Exists) ty; + false + with Exists -> true + +let etyarg_sub_exists (f : ty -> bool) (ety : etyarg) = + let exception Exists in + try + etyarg_iter (fun ty -> if f ty then raise Exists) ety; + false + with Exists -> true + +let tcw_sub_exists (f : ty -> bool) (tcw : tcwitness) = + let exception Exists in + try + tcw_iter (fun ty -> if f ty then raise Exists) tcw; + false + with Exists -> true + +(* -------------------------------------------------------------------- *) exception FoundUnivar -let rec ty_check_uni t = - match t.ty_node with +let rec ty_check_uni (ty : ty) : unit = + match ty.ty_node with | Tunivar _ -> raise FoundUnivar - | _ -> ty_iter ty_check_uni t + | _ -> ty_iter ty_check_uni ty (* -------------------------------------------------------------------- *) let symbol_of_ty (ty : ty) = @@ -197,7 +247,6 @@ let ovar_of_var { v_name = n; v_type = t } = { ov_name = Some n; ov_type = t } module Tvar = struct - let rec fv_rec fv t = match t.ty_node with | Tvar id -> Sid.add id fv @@ -209,6 +258,34 @@ end let ty_fv_and_tvar (ty : ty) = EcIdent.fv_union ty.ty_fv (Mid.map (fun () -> 1) (Tvar.fv ty)) +(* -------------------------------------------------------------------- *) +let rec etyargs_tvar_fv (etyargs : etyarg list) = + List.fold_left + (fun fv etyarg -> Sid.union fv (etyarg_tvar_fv etyarg)) + Sid.empty etyargs + +and etyarg_tvar_fv ((ty, tcws) : etyarg) : Sid.t = + Sid.union (Tvar.fv ty) (tcws_tvar_fv tcws) + +and tcws_tvar_fv (tcws : tcwitness list) = + List.fold_left + (fun fv tcw -> Sid.union fv (tcw_tvar_fv tcw)) + Sid.empty tcws + +and tcw_tvar_fv (tcw : tcwitness) : Sid.t = + match tcw with + | TCIUni _ -> + Sid.empty + + | TCIConcrete { etyargs } -> + etyargs_tvar_fv etyargs + + | TCIAbstract { support = `Var tyvar } -> + Sid.singleton tyvar + + | TCIAbstract { support = (`Abs _) } -> + Sid.empty + (* -------------------------------------------------------------------- *) type pvar_kind = EcAst.pvar_kind @@ -310,38 +387,54 @@ let lp_bind = function List.pmap (fun (x, ty) -> omap (fun x -> (x, ty)) x) b (* -------------------------------------------------------------------- *) -type expr = EcAst.expr - +type expr = EcAst.expr type expr_node = EcAst.expr_node - type equantif = EcAst.equantif type ebinding = EcAst.ebinding type ebindings = EcAst.ebindings type closure = (EcIdent.t * ty) list * expr +(* -------------------------------------------------------------------- *) +type etyarg = EcAst.etyarg + +let etyarg_fv = EcAst.etyarg_fv +let etyargs_fv = EcAst.etyargs_fv +let etyarg_hash = EcAst.etyarg_hash +let etyarg_equal = EcAst.etyarg_equal + +(* -------------------------------------------------------------------- *) +type tcwitness = EcAst.tcwitness + +let tcw_fv = EcAst.tcw_fv +let tcw_hash = EcAst.tcw_hash +let tcw_equal = EcAst.tcw_equal + (* -------------------------------------------------------------------- *) let e_equal = EcAst.e_equal -let e_hash = EcAst.e_hash let e_compare = fun e1 e2 -> e_hash e1 - e_hash e2 let e_fv = EcAst.e_fv +let e_hash = EcAst.e_hash let e_ty e = e.e_ty (* -------------------------------------------------------------------- *) let lp_fv = EcAst.lp_fv - let pv_fv = EcAst.pv_fv (* -------------------------------------------------------------------- *) let eqt_equal = EcAst.eqt_equal -(* -------------------------------------------------------------------- *) - let e_tt = mk_expr (Eop (EcCoreLib.CI_Unit.p_tt, [])) tunit let e_int = fun i -> mk_expr (Eint i) tint let e_local = fun x ty -> mk_expr (Elocal x) ty let e_var = fun x ty -> mk_expr (Evar x) ty -let e_op = fun x targs ty -> mk_expr (Eop (x, targs)) ty + +let e_op_tc x targs ty = + mk_expr (Eop (x, targs)) ty + +let e_op x targs ty = + e_op_tc x (List.map (fun ty -> (ty, [])) targs) ty + let e_let = fun pt e1 e2 -> mk_expr (Elet (pt, e1, e2)) e2.e_ty let e_tuple = fun es -> match es with @@ -359,13 +452,6 @@ let e_proj_simpl e i ty = | _ -> e_proj e i ty let e_quantif q b e = - if List.is_empty b then e else - - let b, e = - match e.e_node with - | Equant (q', b', e) when eqt_equal q q' -> (b@b', e) - | _ -> b, e in - let ty = match q with | `ELambda -> toarrow (List.map snd b) e.e_ty @@ -378,11 +464,7 @@ let e_exists b e = e_quantif `EExists b e let e_lam b e = e_quantif `ELambda b e let e_app x args ty = - if args = [] then x - else - match x.e_node with - | Eapp(x', args') -> mk_expr (Eapp (x', (args'@args))) ty - | _ -> mk_expr (Eapp (x, args)) ty + mk_expr (Eapp (x, args)) ty let e_app_op ?(tyargs=[]) op args ty = e_app (e_op op tyargs (toarrow (List.map e_ty args) ty)) args ty @@ -438,54 +520,33 @@ let e_oget (e : expr) (ty : ty) : expr = e_app op [e] ty (* -------------------------------------------------------------------- *) -let e_map fty fe e = +let e_map (fe : expr -> expr) (e : expr) : expr = match e.e_node with - | Eint _ | Elocal _ | Evar _ -> e - - | Eop (p, tys) -> - let tys' = List.Smart.map fty tys in - let ty' = fty e.e_ty in - e_op p tys' ty' + | Eint _ -> e + | Elocal _ -> e + | Evar _ -> e + | Eop _ -> e | Eapp (e1, args) -> - let e1' = fe e1 in - let args' = List.Smart.map fe args in - let ty' = fty e.e_ty in - e_app e1' args' ty' + e_app (fe e1) (List.Smart.map fe args) e.e_ty | Elet (lp, e1, e2) -> - let e1' = fe e1 in - let e2' = fe e2 in - e_let lp e1' e2' + e_let lp (fe e1) (fe e2) | Etuple le -> - let le' = List.Smart.map fe le in - e_tuple le' + e_tuple (List.Smart.map fe le) | Eproj (e1, i) -> - let e' = fe e1 in - let ty = fty e.e_ty in - e_proj e' i ty + e_proj (fe e1) i e.e_ty | Eif (e1, e2, e3) -> - let e1' = fe e1 in - let e2' = fe e2 in - let e3' = fe e3 in - e_if e1' e2' e3' + e_if (fe e1) (fe e2) (fe e3) - | Ematch (b, es, ty) -> - let ty' = fty ty in - let b' = fe b in - let es' = List.Smart.map fe es in - e_match b' es' ty' + | Ematch (e, bs, ty) -> + e_match (fe e) (List.Smart.map fe bs) ty | Equant (q, b, bd) -> - let dop (x, ty as xty) = - let ty' = fty ty in - if ty == ty' then xty else (x, ty') in - let b' = List.Smart.map dop b in - let bd' = fe bd in - e_quantif q b' bd' + e_quantif q b (fe bd) let e_fold (fe : 'a -> expr -> 'a) (state : 'a) (e : expr) = match e.e_node with @@ -504,6 +565,7 @@ let e_fold (fe : 'a -> expr -> 'a) (state : 'a) (e : expr) = let e_iter (fe : expr -> unit) (e : expr) = e_fold (fun () e -> fe e) () e +(* -------------------------------------------------------------------- *) module MSHe = EcMaps.MakeMSH(struct type t = expr let tag e = e.e_tag end) module Me = MSHe.M module Se = MSHe.S @@ -554,3 +616,4 @@ let split_args e = match e.e_node with | Eapp (e, args) -> (e, args) | _ -> (e, []) + \ No newline at end of file diff --git a/src/ecTypes.mli b/src/ecTypes.mli index 34b7b4cbf2..2fc4295516 100644 --- a/src/ecTypes.mli +++ b/src/ecTypes.mli @@ -1,4 +1,6 @@ (* -------------------------------------------------------------------- *) + +open EcAst open EcBigInt open EcMaps open EcSymbols @@ -27,13 +29,14 @@ val dump_ty : ty -> string val ty_equal : ty -> ty -> bool val ty_hash : ty -> int -val tuni : EcUid.uid -> ty -val tvar : EcIdent.t -> ty -val ttuple : ty list -> ty -val tconstr : EcPath.path -> ty list -> ty -val tfun : ty -> ty -> ty -val tglob : EcIdent.t -> ty -val tpred : ty -> ty +val tuni : tyuni -> ty +val tvar : EcIdent.t -> ty +val ttuple : ty list -> ty +val tconstr : EcPath.path -> ty list -> ty +val tconstr_tc : EcPath.path -> EcAst.etyarg list -> ty +val tfun : ty -> ty -> ty +val tglob : EcIdent.t -> ty +val tpred : ty -> ty val ty_fv_and_tvar : ty -> int Mid.t @@ -64,20 +67,30 @@ exception FoundUnivar val ty_check_uni : ty -> unit (* -------------------------------------------------------------------- *) - module Tvar : sig - val fv : ty -> Sid.t + val fv : ty -> Sid.t end (* -------------------------------------------------------------------- *) (* [map f t] applies [f] on strict subterms of [t] (not recursive) *) val ty_map : (ty -> ty) -> ty -> ty +val etyarg_map : (ty -> ty) -> etyarg -> etyarg +val tcw_map : (ty -> ty) -> tcwitness -> tcwitness (* [sub_exists f t] true if one of the strict-subterm of [t] valid [f] *) val ty_sub_exists : (ty -> bool) -> ty -> bool +val etyarg_sub_exists : (ty -> bool) -> etyarg -> bool +val tcw_sub_exists : (ty -> bool) -> tcwitness -> bool +(* -------------------------------------------------------------------- *) val ty_fold : ('a -> ty -> 'a) -> 'a -> ty -> 'a +val etyarg_fold : ('a -> ty -> 'a) -> 'a -> etyarg -> 'a +val tcw_fold : ('a -> ty -> 'a) -> 'a -> tcwitness -> 'a + +(* -------------------------------------------------------------------- *) val ty_iter : (ty -> unit) -> ty -> unit +val etyarg_iter : (ty -> unit) -> etyarg -> unit +val tcw_iter : (ty -> unit) -> tcwitness -> unit (* -------------------------------------------------------------------- *) val symbol_of_ty : ty -> string @@ -158,6 +171,27 @@ type closure = (EcIdent.t * ty) list * expr (* -------------------------------------------------------------------- *) val eqt_equal : equantif -> equantif -> bool +(* -------------------------------------------------------------------- *) +type etyarg = EcAst.etyarg + +val etyarg_fv : etyarg -> int Mid.t +val etyargs_fv : etyarg list -> int Mid.t +val etyarg_hash : etyarg -> int +val etyarg_equal : etyarg -> etyarg -> bool + +(* -------------------------------------------------------------------- *) +type tcwitness = EcAst.tcwitness + +val tcw_fv : tcwitness -> int Mid.t +val tcw_hash : tcwitness -> int +val tcw_equal : tcwitness -> tcwitness -> bool + +(* -------------------------------------------------------------------- *) +val etyargs_tvar_fv : etyarg list -> Sid.t +val etyarg_tvar_fv : etyarg -> Sid.t +val tcws_tvar_fv : tcwitness list -> Sid.t +val tcw_tvar_fv : tcwitness -> Sid.t + (* -------------------------------------------------------------------- *) val e_equal : expr -> expr -> bool val e_compare : expr -> expr -> int @@ -171,6 +205,7 @@ val e_int : zint -> expr val e_decimal : zint * (int * zint) -> expr val e_local : EcIdent.t -> ty -> expr val e_var : prog_var -> ty -> expr +val e_op_tc : EcPath.path -> etyarg list -> ty -> expr val e_op : EcPath.path -> ty list -> ty -> expr val e_app : expr -> expr list -> ty -> expr val e_let : lpattern -> expr -> expr -> expr @@ -208,8 +243,7 @@ val split_args : expr -> expr * expr list (* -------------------------------------------------------------------- *) val e_map : - (ty -> ty ) (* 1-subtype op. *) - -> (expr -> expr) (* 1-subexpr op. *) + (expr -> expr) (* 1-subexpr op. *) -> expr -> expr @@ -217,5 +251,3 @@ val e_fold : ('state -> expr -> 'state) -> 'state -> expr -> 'state val e_iter : (expr -> unit) -> expr -> unit - -(* -------------------------------------------------------------------- *) diff --git a/src/ecTyping.ml b/src/ecTyping.ml index 2b912b7583..a99e2d6bde 100644 --- a/src/ecTyping.ml +++ b/src/ecTyping.ml @@ -26,7 +26,7 @@ let wp = (ref (None : wp option)) (* -------------------------------------------------------------------- *) type opmatch = [ - | `Op of EcPath.path * EcTypes.ty list + | `Op of EcPath.path * EcTypes.etyarg list | `Lc of EcIdent.t | `Var of EcTypes.prog_var | `Proj of EcTypes.prog_var * EcMemory.proj_arg @@ -114,7 +114,7 @@ type filter_error = type tyerror = | UniVarNotAllowed -| FreeTypeVariables +| FreeUniVariables of EcUnify.uniflags | TypeVarNotAllowed | OnlyMonoTypeAllowed of symbol option | NoConcreteAnonParams @@ -171,6 +171,8 @@ type tyerror = | ModuleNotAbstract of symbol | ProcedureUnbounded of symbol * symbol | LvMapOnNonAssign +| TCArgsCountMismatch of qsymbol * ty_params * ty list +| CannotInferTC of ty * typeclass | NoDefaultMemRestr | ProcAssign of qsymbol | PositiveShouldBeBeforeNegative @@ -201,7 +203,7 @@ let unify_or_fail (env : EcEnv.env) ue loc ~expct:ty1 ty2 = let tyinst = ty_subst (Tuni.subst uidmap) in tyerror loc env (TypeMismatch ((tyinst ty1, tyinst ty2), (tyinst t1, tyinst t2))) - | `TcCtt _ -> + | `TcCtt _ | `TcTw _ -> (* FIXME: proper error message *) tyerror loc env TypeClassMismatch (* -------------------------------------------------------------------- *) @@ -326,7 +328,7 @@ module OpSelect = struct type opsel = [ | `Pv of EcMemory.memory option * pvsel - | `Op of (EcPath.path * ty list) + | `Op of (EcPath.path * etyarg list) | `Lc of EcIdent.ident | `Nt of EcUnify.sbody ] @@ -354,7 +356,7 @@ let gen_select_op let fpv me (pv, ty, ue) : OpSelect.gopsel = (`Pv (me, pv), ty, ue, (pv :> opmatch)) - and fop (op, ty, ue, bd) : OpSelect.gopsel= + and fop ((op : path * etyarg list), ty, ue, bd) : OpSelect.gopsel = match bd with | None -> (`Op op, ty, ue, (`Op op :> opmatch)) | Some bd -> (`Nt bd, ty, ue, (`Op op :> opmatch)) @@ -376,7 +378,7 @@ let gen_select_op and by_tc ((p, _), _, _, _) = match oget (EcEnv.Op.by_path_opt p env) with - | { op_kind = OB_oper (Some OP_TC) } -> false + | { op_kind = OB_oper (Some (OP_TC _)) } -> false | _ -> true in @@ -468,26 +470,6 @@ let tp_uni = { tp_uni = true ; tp_tvar = false; } (* params/local vars. *) (* -------------------------------------------------------------------- *) type ismap = (instr list) Mstr.t -(* -------------------------------------------------------------------- *) -let transtcs (env : EcEnv.env) tcs = - let for1 tc = - match EcEnv.TypeClass.lookup_opt (unloc tc) env with - | None -> tyerror tc.pl_loc env (UnknownTypeClass (unloc tc)) - | Some (p, _) -> p (* FIXME: TC HOOK *) - in - Sp.of_list (List.map for1 tcs) - -(* -------------------------------------------------------------------- *) -let transtyvars (env : EcEnv.env) (loc, tparams) = - let tparams = tparams |> omap - (fun tparams -> - let for1 ({ pl_desc = x }, tc) = (EcIdent.create x, transtcs env tc) in - if not (List.is_unique (List.map (unloc |- fst) tparams)) then - tyerror loc env DuplicatedTyVar; - List.map for1 tparams) - in - EcUnify.UniEnv.create tparams - (* -------------------------------------------------------------------- *) exception TymodCnvFailure of tymod_cnv_failure @@ -976,7 +958,7 @@ let trans_msymbol env msymb = (m,mt) (* -------------------------------------------------------------------- *) -let rec transty (tp : typolicy) (env : EcEnv.env) ue ty = +let rec transty (tp : typolicy) (env : EcEnv.env) (ue : EcUnify.unienv) (ty : pty) : ty = match ty.pl_desc with | PTunivar -> if tp.tp_uni @@ -1035,6 +1017,47 @@ let transty_for_decl env ty = let ue = UE.create (Some []) in transty tp_nothing env ue ty +(* -------------------------------------------------------------------- *) +let transtc (env : EcEnv.env) ue ((tc_name, args) : ptcparam) : typeclass = + match EcEnv.TypeClass.lookup_opt (unloc tc_name) env with + | None -> + tyerror (loc tc_name) env (UnknownTypeClass (unloc tc_name)) + + | Some (p, decl) -> + let args = List.map (transty tp_tydecl env ue) args in + + if List.length decl.tc_tparams <> List.length args then begin + tyerror (loc tc_name) env + (TCArgsCountMismatch (unloc tc_name, decl.tc_tparams, args)); + end; + + let tvi = EcUnify.UniEnv.opentvi ue decl.tc_tparams None in + + (* FIXME:TC can raise an exception *) + List.iter2 + (fun (ty, _) aty -> EcUnify.unify env ue ty aty) + tvi.args args; + + { tc_name = p; tc_args = tvi.args; } + +(* -------------------------------------------------------------------- *) +let transtyvars (env : EcEnv.env) (loc, (tparams : ptyparams option)) = + match tparams with + | None -> + UE.create None + + | Some tparams -> + let ue = UE.create (Some []) in + + let for1 ({ pl_desc = x }, tc) = + let x = EcIdent.create x in + let tc = List.map (transtc env ue) tc in + UE.push (x, tc) ue in + if not (List.is_unique (List.map (unloc |- fst) tparams)) then + tyerror loc env DuplicatedTyVar; + List.iter for1 tparams; + ue + (* -------------------------------------------------------------------- *) let transpattern1 env ue (p : EcParsetree.plpattern) = match p.pl_desc with @@ -1085,7 +1108,8 @@ let transpattern1 env ue (p : EcParsetree.plpattern) = let recty = oget (EcEnv.Ty.by_path_opt recp env) in let rec_ = snd (oget (EcDecl.tydecl_as_record recty)) in let reccty = tconstr recp (List.map (tvar |- fst) recty.tyd_params) in - let reccty, rectvi = EcUnify.UniEnv.openty ue recty.tyd_params None reccty in + let reccty, recopnd = EcUnify.UniEnv.openty ue recty.tyd_params None reccty in + let fields = List.fold_left (fun map (((_, idx), _, _) as field) -> @@ -1105,8 +1129,9 @@ let transpattern1 env ue (p : EcParsetree.plpattern) = let pty = EcUnify.UniEnv.fresh ue in let fty = snd (List.nth rec_ i) in let fty, _ = - EcUnify.UniEnv.openty ue recty.tyd_params - (Some (EcUnify.TVIunamed rectvi)) fty + EcUnify.UniEnv.openty + ue recty.tyd_params + (Some (EcUnify.tvi_unamed recopnd.args)) fty in (try EcUnify.unify env ue pty fty with EcUnify.UnificationFailure _ -> assert false); @@ -1139,7 +1164,9 @@ let transpattern env ue (p : EcParsetree.plpattern) = let transtvi env ue tvi = match tvi.pl_desc with | TVIunamed lt -> - EcUnify.TVIunamed (List.map (transty tp_relax env ue) lt) + let tys = List.map (transty tp_relax env ue) lt in + let tvi = List.map (fun ty -> (Some ty, None)) tys in + EcUnify.TVIunamed tvi | TVInamed lst -> let add locals (s, t) = @@ -1148,8 +1175,9 @@ let transtvi env ue tvi = (s, transty tp_relax env ue t) :: locals in - let lst = List.fold_left add [] lst in - EcUnify.TVInamed (List.rev_map (fun (s,t) -> unloc s, t) lst) + let tvi = List.fold_left add [] lst in + let tvi = List.map (snd_map (fun ty -> (Some ty, None))) tvi in + EcUnify.TVInamed (List.rev_map (fun (s, t) -> unloc s, t) tvi) let rec destr_tfun env ue tf = match tf.ty_node with @@ -1224,9 +1252,8 @@ let trans_record env ue (subtt, proj) (loc, b, fields) = let recty = oget (EcEnv.Ty.by_path_opt recp env) in let rec_ = snd (oget (EcDecl.tydecl_as_record recty)) in - let reccty = tconstr recp (List.map (tvar |- fst) recty.tyd_params) in - let reccty, rtvi = EcUnify.UniEnv.openty ue recty.tyd_params None reccty in - let tysopn = Tvar.init (List.map fst recty.tyd_params) rtvi in + let reccty = tconstr_tc recp (EcDecl.etyargs_of_tparams recty.tyd_params) in + let reccty, ropnd = EcUnify.UniEnv.openty ue recty.tyd_params None reccty in let fields = List.fold_left @@ -1255,7 +1282,7 @@ let trans_record env ue (subtt, proj) (loc, b, fields) = | None -> match dflrec with | None -> tyerror loc env (MissingRecField name) - | Some _ -> `Dfl (Tvar.subst tysopn rty, name) + | Some _ -> `Dfl (Tvar.subst ropnd.subst rty, name) in List.mapi (fun i (name, rty) -> get_field i name rty) rec_ in @@ -1271,7 +1298,7 @@ let trans_record env ue (subtt, proj) (loc, b, fields) = | `Dfl (rty, name) -> let nm = oget (EcPath.prefix recp) in - (proj (nm, name, (rtvi, reccty), rty, oget dflrec), rty) + (proj (nm, name, (ropnd.args, reccty), rty, oget dflrec), rty) in List.map for1 fields @@ -1282,7 +1309,7 @@ let trans_record env ue (subtt, proj) (loc, b, fields) = (EcPath.prefix recp) (Printf.sprintf "mk_%s" (EcPath.basename recp)) in - (ctor, fields, (rtvi, reccty)) + (ctor, fields, (ropnd.args, reccty)) (* -------------------------------------------------------------------- *) let trans_branch ~loc env ue gindty ((pb, body) : ppattern * _) = @@ -1321,8 +1348,8 @@ let trans_branch ~loc env ue gindty ((pb, body) : ppattern * _) = EcUnify.UniEnv.restore ~src:subue ~dst:ue; let ctorty = - let tvi = Some (EcUnify.TVIunamed tvi) in - fst (EcUnify.UniEnv.opentys ue indty.tyd_params tvi ctorty) in + let tvi = Some (EcUnify.tvi_unamed tvi) in + fst (EcUnify.UniEnv.opentys ue indty.tyd_params tvi ctorty) in let pty = EcUnify.UniEnv.fresh ue in (try EcUnify.unify env ue (toarrow ctorty pty) opty @@ -1379,7 +1406,6 @@ let trans_if_match ~loc env ue (gindty, gind) (c, b1, b2) = gind.tydt_ctors (*-------------------------------------------------------------------- *) - let var_or_proj fvar fproj pv ty = match pv with | `Var pv -> fvar pv ty @@ -1593,7 +1619,7 @@ let form_of_opselect in (f_lambda flam (Fsubst.f_subst subst body), args) | (`Op _ | `Lc _ | `Pv _) as sel -> let op = match sel with - | `Op (p, tys) -> f_op p tys ty + | `Op (p, tys) -> f_op_tc p tys ty | `Lc id -> f_local id ty | `Pv (me, pv) -> var_or_proj (fun x ty -> f_pvar x ty (oget me)) f_proj pv ty @@ -1610,7 +1636,7 @@ let form_of_opselect * - e is the index to update * - ty is the type of the value [x] *) -type lvmap = (path * ty list) * prog_var * expr * ty +type lvmap = (path * etyarg list) * prog_var * expr * ty type lVAl = | Lval of lvalue @@ -1620,7 +1646,7 @@ let i_asgn_lv (_loc : EcLocation.t) (_env : EcEnv.env) lv e = match lv with | Lval lv -> i_asgn (lv, e) | LvMap ((op,tys), x, ei, ty) -> - let op = e_op op tys (toarrow [ty; ei.e_ty; e.e_ty] ty) in + let op = e_op_tc op tys (toarrow [ty; ei.e_ty; e.e_ty] ty) in i_asgn (LvVar (x,ty), e_app op [e_var x ty; ei; e] ty) let i_rnd_lv loc env lv e = @@ -2323,7 +2349,7 @@ and fundef_add_symbol env (memenv : memenv) xtys : memenv = and fundef_check_type subst_uni env os (ty, loc) = let ty = subst_uni ty in - if not (EcUid.Suid.is_empty (Tuni.fv ty)) then + if not (TyUni.Suid.is_empty (Tuni.fv ty)) then tyerror loc env (OnlyMonoTypeAllowed os); ty @@ -3065,12 +3091,12 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt = let (ctor, fields, (rtvi, reccty)) = let proj (recp, name, (rtvi, reccty), pty, arg) = let proj = EcPath.pqname recp name in - let proj = f_op proj rtvi (tfun reccty pty) in + let proj = f_op_tc proj rtvi (tfun reccty pty) in f_app proj [arg] pty in trans_record env ue ((fun f -> let f = transf env f in (f, f.f_ty)), proj) (f.pl_loc, b, fields) in - let ctor = f_op ctor rtvi (toarrow (List.map snd fields) reccty) in + let ctor = f_op_tc ctor rtvi (toarrow (List.map snd fields) reccty) in f_app ctor (List.map fst fields) reccty | PFproj (subf, x) -> begin @@ -3088,7 +3114,7 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt = let rty = EcUnify.UniEnv.fresh ue in (try EcUnify.unify env ue (tfun subf.f_ty rty) pty with EcUnify.UnificationFailure _ -> assert false); - f_app (f_op op tvi pty) [subf] rty + f_app (f_op_tc op tvi pty) [subf] rty end | PFproji (psubf, i) -> begin @@ -3317,15 +3343,21 @@ let trans_dcodepos1 ?(memory : memory option) (env : EcEnv.env) (p : pcodepos1 d (* -------------------------------------------------------------------- *) let get_instances (tvi, bty) env = - let inst = List.pmap - (function - | (_, (`Ring _ | `Field _)) as x -> Some x - | _ -> None) - (EcEnv.TypeClass.get_instances env) in + let inst = + let filter ((_, tci) : path option * EcTheory.tcinstance) = + match tci with + | EcTheory.{ + tci_params = []; + tci_instance = (`Ring _ | `Field _) as bd + } -> Some (tci.tci_type, bd) + + | _ -> None + + in List.pmap filter (EcEnv.TcInstance.get_all env) in - List.pmap (fun ((typ, gty), cr) -> + List.pmap (fun (gty, cr) -> let ue = EcUnify.UniEnv.create (Some tvi) in - let (gty, _typ) = EcUnify.UniEnv.openty ue typ None gty in + let (gty, _) = EcUnify.UniEnv.openty ue [] None gty in try EcUnify.unify env ue bty gty; let ts = Tuni.subst (UE.close ue) in diff --git a/src/ecTyping.mli b/src/ecTyping.mli index eb3e48f9f1..da425bf7a8 100644 --- a/src/ecTyping.mli +++ b/src/ecTyping.mli @@ -18,7 +18,7 @@ val wp : wp option ref (* -------------------------------------------------------------------- *) type opmatch = [ - | `Op of EcPath.path * EcTypes.ty list + | `Op of EcPath.path * EcTypes.etyarg list | `Lc of EcIdent.t | `Var of EcTypes.prog_var | `Proj of EcTypes.prog_var * EcMemory.proj_arg @@ -27,7 +27,7 @@ type opmatch = [ type 'a mismatch_sets = [`Eq of 'a * 'a | `Sub of 'a ] -type 'a suboreq = [`Eq of 'a | `Sub of 'a ] +type 'a suboreq = [`Eq of 'a | `Sub of 'a ] type mismatch_funsig = | MF_targs of ty * ty (* expected, got *) @@ -106,7 +106,7 @@ type filter_error = type tyerror = | UniVarNotAllowed -| FreeTypeVariables +| FreeUniVariables of EcUnify.uniflags | TypeVarNotAllowed | OnlyMonoTypeAllowed of symbol option | NoConcreteAnonParams @@ -163,6 +163,8 @@ type tyerror = | ModuleNotAbstract of symbol | ProcedureUnbounded of symbol * symbol | LvMapOnNonAssign +| TCArgsCountMismatch of qsymbol * ty_params * ty list +| CannotInferTC of ty * typeclass | NoDefaultMemRestr | ProcAssign of qsymbol | PositiveShouldBeBeforeNegative @@ -183,6 +185,9 @@ val tp_tydecl : typolicy val tp_relax : typolicy (* -------------------------------------------------------------------- *) +val transtc: + env -> EcUnify.unienv -> ptcparam -> typeclass + val transtyvars: env -> (EcLocation.t * ptyparams option) -> EcUnify.unienv diff --git a/src/ecUid.ml b/src/ecUid.ml index 6e9124b62c..8b4643cfd0 100644 --- a/src/ecUid.ml +++ b/src/ecUid.ml @@ -6,37 +6,84 @@ open EcSymbols (* -------------------------------------------------------------------- *) let unique () = Oo.id (object end) +(* -------------------------------------------------------------------- *) +module type ICore = sig + type uid + + (* ------------------------------------------------------------------ *) + val unique : unit -> uid + val uid_equal : uid -> uid -> bool + val uid_compare : uid -> uid -> int + + (* ------------------------------------------------------------------ *) + module Muid : Map.S with type key = uid + module Suid : Set.S with module M = Map.MakeBase(Muid) + + (* ------------------------------------------------------------------ *) + module SMap : sig + type uidmap + + val create : unit -> uidmap + val lookup : uidmap -> symbol -> uid option + val forsym : uidmap -> symbol -> uid + val pp_uid : Format.formatter -> uid -> unit + end +end + (* -------------------------------------------------------------------- *) type uid = int -type uidmap = { - (*---*) um_tbl : (symbol, uid) Hashtbl.t; - mutable um_uid : int; -} +(* -------------------------------------------------------------------- *) +module Core : ICore with type uid := uid = struct + (* ------------------------------------------------------------------ *) + let unique () : uid = + unique () + + let uid_equal x y = x == y + let uid_compare x y = x - y + + (* ------------------------------------------------------------------ *) + module Muid = Mint + module Suid = Set.MakeOfMap(Muid) + + (* ------------------------------------------------------------------ *) + module SMap = struct + type uidmap = { + (*---*) um_tbl : (symbol, uid) Hashtbl.t; + mutable um_uid : int; + } + + let create () = + { um_tbl = Hashtbl.create 0; + um_uid = 0; } -let create () = - { um_tbl = Hashtbl.create 0; - um_uid = 0; } + let lookup (um : uidmap) (x : symbol) = + try Some (Hashtbl.find um.um_tbl x) + with Not_found -> None -let lookup (um : uidmap) (x : symbol) = - try Some (Hashtbl.find um.um_tbl x) - with Not_found -> None + let forsym (um : uidmap) (x : symbol) = + match lookup um x with + | Some uid -> uid + | None -> + let uid = um.um_uid in + um.um_uid <- um.um_uid + 1; + Hashtbl.add um.um_tbl x uid; + uid -let forsym (um : uidmap) (x : symbol) = - match lookup um x with - | Some uid -> uid - | None -> - let uid = um.um_uid in - um.um_uid <- um.um_uid + 1; - Hashtbl.add um.um_tbl x uid; - uid + let pp_uid fmt u = + Format.fprintf fmt "#%d" u + end +end (* -------------------------------------------------------------------- *) -let uid_equal x y = x == y -let uid_compare x y = x - y +module CoreGen() : ICore with type uid = private uid = struct + type nonrec uid = uid + + include Core +end -module Muid = Mint -module Suid = Set.MakeOfMap(Muid) +(* -------------------------------------------------------------------- *) +include Core (* -------------------------------------------------------------------- *) module NameGen = struct diff --git a/src/ecUid.mli b/src/ecUid.mli index 885bcbd99f..429132eef9 100644 --- a/src/ecUid.mli +++ b/src/ecUid.mli @@ -5,20 +5,37 @@ open EcSymbols (* -------------------------------------------------------------------- *) val unique : unit -> int +module type ICore = sig + type uid + + (* ------------------------------------------------------------------ *) + val unique : unit -> uid + val uid_equal : uid -> uid -> bool + val uid_compare : uid -> uid -> int + + (* ------------------------------------------------------------------ *) + module Muid : Map.S with type key = uid + module Suid : Set.S with module M = Map.MakeBase(Muid) + + (* ------------------------------------------------------------------ *) + module SMap : sig + type uidmap + + val create : unit -> uidmap + val lookup : uidmap -> symbol -> uid option + val forsym : uidmap -> symbol -> uid + val pp_uid : Format.formatter -> uid -> unit + end +end + (* -------------------------------------------------------------------- *) type uid = int -type uidmap - -val create : unit -> uidmap -val lookup : uidmap -> symbol -> uid option -val forsym : uidmap -> symbol -> uid (* -------------------------------------------------------------------- *) -val uid_equal : uid -> uid -> bool -val uid_compare : uid -> uid -> int +include ICore with type uid := uid -module Muid : Map.S with type key = uid -module Suid : Set.S with module M = Map.MakeBase(Muid) +(* -------------------------------------------------------------------- *) +module CoreGen() : ICore with type uid = private uid (* -------------------------------------------------------------------- *) module NameGen : sig diff --git a/src/ecUnify.ml b/src/ecUnify.ml index cd557aadef..f092b79d8a 100644 --- a/src/ecUnify.ml +++ b/src/ecUnify.ml @@ -3,287 +3,448 @@ open EcSymbols open EcIdent open EcMaps open EcUtils -open EcUid open EcAst open EcTypes open EcCoreSubst open EcDecl module Sp = EcPath.Sp -module TC = EcTypeClass -(* -------------------------------------------------------------------- *) -exception UnificationFailure of [`TyUni of ty * ty | `TcCtt of ty * Sp.t] -exception UninstanciateUni +(* ==================================================================== *) +type problem = [ + | `TyUni of ty * ty + | `TcTw of tcwitness * tcwitness + | `TcCtt of tcuni * ty * typeclass +] -(* -------------------------------------------------------------------- *) -type pb = [ `TyUni of ty * ty | `TcCtt of ty * Sp.t ] +(* ==================================================================== *) +type uniflags = { tyvars: bool; tcvars: bool; } -module UFArgs = struct - module I = struct - type t = uid +exception UnificationFailure of problem +exception UninstanciateUni of uniflags - let equal = uid_equal - let compare = uid_compare - end +(* ==================================================================== *) +module Unify = struct + module UFArgs = struct + module I = struct + type t = tyuni + + let equal = TyUni.uid_equal + let compare = TyUni.uid_compare + end + + module D = struct + type data = ty option + type effects = problem list - module D = struct - type data = Sp.t * ty option - type effects = pb list + let default : data = + None - let default : data = - (Sp.empty, None) + let isvoid (x : data) = + Option.is_none x - let isvoid ((_, x) : data) = - (x = None) + let noeffects : effects = [] - let noeffects : effects = [] + let union (ty1 : data) (ty2 : data) : data * effects = + let ty, cts = + match ty1, ty2 with + | None, None -> + (None, []) + | Some ty1, Some ty2 -> + Some ty1, [(ty1, ty2)] - let union d1 d2 = - match d1, d2 with - | (tc1, None), (tc2, None) -> - ((Sp.union tc1 tc2, None), []) + | None, Some ty | Some ty, None -> + Some ty, [] in - | (tc1, Some ty1), (tc2, Some ty2) -> - ((Sp.union tc1 tc2, Some ty1), [`TyUni (ty1, ty2)]) + let cts = List.map (fun x -> `TyUni x) cts in - | (tc1, None ), (tc2, Some ty) - | (tc2, Some ty), (tc1, None ) -> - let tc = Sp.diff tc1 tc2 in - if Sp.is_empty tc - then ((Sp.union tc1 tc2, Some ty), []) - else ((Sp.union tc1 tc2, Some ty), [`TcCtt (ty, tc)]) + ty, (cts :> effects) + end end -end -module UF = EcUFind.Make(UFArgs.I)(UFArgs.D) + (* ------------------------------------------------------------------ *) + module UF = EcUFind.Make(UFArgs.I)(UFArgs.D) + + (* ------------------------------------------------------------------ *) + type ucore = { + uf : UF.t; + tvtc : typeclass list Mid.t; + tcenv : tcenv; + } + + and tcenv = { + (* Map from UID to TC problems. *) + problems : (ty * typeclass) TcUni.Muid.t; + + (* Map from univars to TC problems that depend on them. *) + byunivar : TcUni.Suid.t TyUni.Muid.t; + + (* Map from problems UID to type-class instance witness *) + resolution : tcwitness TcUni.Muid.t + } + + (* ------------------------------------------------------------------ *) + let tcenv_empty : tcenv = + { problems = TcUni.Muid.empty + ; byunivar = TyUni.Muid.empty + ; resolution = TcUni.Muid.empty } + + (* ------------------------------------------------------------------ *) + let tcenv_closed (tcenv : tcenv) : bool = (* FIXME:TC *) + TcUni.Muid.cardinal tcenv.resolution + = TcUni.Muid.cardinal tcenv.problems + + (* ------------------------------------------------------------------ *) + let create_tcproblem + (tcenv : tcenv) + (ty : ty) + (tcw : typeclass * tcwitness option) + : tcenv * tcwitness + = + let tc, tw = tcw in + let uid = TcUni.unique () in + let deps = Tuni.univars ty in (* FIXME:TC *) + + let tcenv = { + problems = TcUni.Muid.add uid (ty, tc) tcenv.problems; + byunivar = TyUni.Suid.fold (fun duni byunivar -> + TyUni.Muid.change (fun pbs -> + Some (TcUni.Suid.add uid (Option.value ~default:TcUni.Suid.empty pbs)) + ) duni byunivar + ) deps tcenv.byunivar; + resolution = + ofold + (fun tw map -> TcUni.Muid.add uid tw map) + tcenv.resolution tw; + } in + + tcenv, TCIUni uid + + (* ------------------------------------------------------------------ *) + let initial_ucore ?(tvtc = Mid.empty) () : ucore = + { uf = UF.initial; tcenv = tcenv_empty; tvtc; } + + (* ------------------------------------------------------------------ *) + let fresh + ?(tcs : (typeclass * tcwitness option) list option) + ?(ty : ty option) + ({ uf; tcenv } as uc : ucore) + = + let uid = TyUni.unique () in -(* -------------------------------------------------------------------- *) -module UnifyCore = struct - let fresh ?(tc = Sp.empty) ?ty uf = - let uid = EcUid.unique () in let uf = match ty with | Some { ty_node = Tunivar id } -> - let uf = UF.set uid (tc, None) uf in - fst (UF.union uid id uf) - | None | Some _ -> UF.set uid (tc, ty) uf + let uf = UF.set uid None uf in + let ty, effects = UF.union uid id uf in + assert (List.is_empty effects); + ty + + | (None | Some _) as ty -> + UF.set uid ty uf in - (uf, tuni uid) -end -(* -------------------------------------------------------------------- *) -let rec unify_core (env : EcEnv.env) (tvtc : Sp.t Mid.t) (uf : UF.t) pb = - let failure () = raise (UnificationFailure pb) in + let ty = Option.value ~default:(tuni uid) (UF.data uid uf) in - let gr = EcEnv.TypeClass.graph env in - let inst = EcEnv.TypeClass.get_instances env in + let tcenv, tws = + List.fold_left_map + (fun tcenv tcw -> create_tcproblem tcenv ty tcw) + tcenv (Option.value ~default:[] tcs) in - let uf = ref uf in - let pb = let x = Queue.create () in Queue.push pb x; x in + ({ uc with uf; tcenv; }, (tuni uid, tws)) - let instances_for_tcs tcs = - let tcfilter (i, tc) = - match tc with `General p -> Some (i, p) | _ -> None - in - List.filter - (fun (_, tc1) -> - Sp.for_all - (fun tc2 -> TC.Graph.has_path ~src:tc1 ~dst:tc2 gr) - tcs) - (List.pmap tcfilter inst) - in + (* ------------------------------------------------------------------ *) + let unify_core (env : EcEnv.env) (uc : ucore) (pb : problem) : ucore = + let failure () = raise (UnificationFailure pb) in - let has_tcs ~src ~dst = - Sp.for_all - (fun dst1 -> - Sp.exists - (fun src1 -> TC.Graph.has_path ~src:src1 ~dst:dst1 gr) - src) - dst - in + let uc = ref uc in + let pb = let x = Queue.create () in Queue.push pb x; x in - let ocheck i t = - let i = UF.find i !uf in - let map = Hint.create 0 in + let ocheck i t = + let i = UF.find i (!uc).uf in + let map = Hint.create 0 in - let rec doit t = - match t.ty_node with - | Tunivar i' -> begin - let i' = UF.find i' !uf in + let rec doit t = + match t.ty_node with + | Tunivar i' -> begin + let i' = UF.find i' (!uc).uf in match i' with | _ when i = i' -> true - | _ when Hint.mem map i' -> false + | _ when Hint.mem map (i' :> int) -> false | _ -> - match snd (UF.data i' !uf) with - | None -> Hint.add map i' (); false + match UF.data i' (!uc).uf with + | None -> Hint.add map (i' :> int) (); false | Some t -> match doit t with | true -> true - | false -> Hint.add map i' (); false - end + | false -> Hint.add map (i' :> int) (); false + end - | _ -> EcTypes.ty_sub_exists doit t + | _ -> EcTypes.ty_sub_exists doit t + in + doit t in - doit t - in - - let setvar i t = - let (ti, effects) = UFArgs.D.union (UF.data i !uf) (Sp.empty, Some t) in - if odfl false (snd ti |> omap (ocheck i)) then failure (); - List.iter (Queue.push^~ pb) effects; - uf := UF.set i ti !uf - and getvar t = - match t.ty_node with - | Tunivar i -> snd_map (odfl t) (UF.data i !uf) - | _ -> (Sp.empty, t) - - in + let setvar (i : tyuni) (t : ty) = + let (ti, effects) = + UFArgs.D.union (UF.data i (!uc).uf) (Some t) + in + if odfl false (ti |> omap (ocheck i)) then failure (); + List.iter (Queue.push^~ pb) effects; + + begin + (* FIXME:TC (cache!)*) + match TyUni.Muid.find i (!uc).tcenv.byunivar with + | tcpbs -> + uc := { !uc with tcenv = { (!uc).tcenv with + byunivar = TyUni.Muid.remove i (!uc).tcenv.byunivar + } }; + let tcpbs = TcUni.Suid.elements tcpbs in + let tcpbs = List.map (fun uid -> + let pb = TcUni.Muid.find uid (!uc).tcenv.problems in + (uid, pb) + ) tcpbs in + List.iter (fun (uid, (ty, tc)) -> Queue.push (`TcCtt (uid, ty, tc)) pb) tcpbs + + | exception Not_found -> () + end; + + uc := { !uc with uf = UF.set i ti (!uc).uf } + + and getvar t = + match t.ty_node with + | Tunivar i -> odfl t (UF.data i (!uc).uf) + | _ -> t + in - let doit () = - while not (Queue.is_empty pb) do - match Queue.pop pb with - | `TyUni (t1, t2) -> begin - let (t1, t2) = (snd (getvar t1), snd (getvar t2)) in - - match ty_equal t1 t2 with - | true -> () - | false -> begin - match t1.ty_node, t2.ty_node with - | Tunivar id1, Tunivar id2 -> begin - if not (uid_equal id1 id2) then - let effects = reffold (swap |- UF.union id1 id2) uf in + let doit () = + while not (Queue.is_empty pb) do + match Queue.pop pb with + | `TyUni (t1, t2) -> begin + let (t1, t2) = (getvar t1, getvar t2) in + + match ty_equal t1 t2 with + | true -> () + | false -> begin + match t1.ty_node, t2.ty_node with + | Tunivar id1, Tunivar id2 -> begin + if not (TyUni.uid_equal id1 id2) then + let effects = + reffold (fun uc -> + let uf, effects = UF.union id1 id2 uc.uf in + effects, { uc with uf } + ) uc in List.iter (Queue.push^~ pb) effects - end + end - | Tunivar id, _ -> setvar id t2 - | _, Tunivar id -> setvar id t1 + | Tunivar id, _ -> setvar id t2 + | _, Tunivar id -> setvar id t1 - | Ttuple lt1, Ttuple lt2 -> - if List.length lt1 <> List.length lt2 then failure (); - List.iter2 (fun t1 t2 -> Queue.push (`TyUni (t1, t2)) pb) lt1 lt2 + | Ttuple lt1, Ttuple lt2 -> + if List.length lt1 <> List.length lt2 then failure (); + List.iter2 (fun t1 t2 -> Queue.push (`TyUni (t1, t2)) pb) lt1 lt2 - | Tfun (t1, t2), Tfun (t1', t2') -> - Queue.push (`TyUni (t1, t1')) pb; - Queue.push (`TyUni (t2, t2')) pb + | Tfun (t1, t2), Tfun (t1', t2') -> + Queue.push (`TyUni (t1, t1')) pb; + Queue.push (`TyUni (t2, t2')) pb - | Tconstr (p1, lt1), Tconstr (p2, lt2) when EcPath.p_equal p1 p2 -> + | Tconstr (p1, lt1), Tconstr (p2, lt2) when EcPath.p_equal p1 p2 -> if List.length lt1 <> List.length lt2 then failure (); - List.iter2 (fun t1 t2 -> Queue.push (`TyUni (t1, t2)) pb) lt1 lt2 - | Tconstr (p, lt), _ when EcEnv.Ty.defined p env -> - Queue.push (`TyUni (EcEnv.Ty.unfold p lt env, t2)) pb + let ty1, tws1 = List.split lt1 in + let ty2, tws2 = List.split lt2 in - | _, Tconstr (p, lt) when EcEnv.Ty.defined p env -> - Queue.push (`TyUni (t1, EcEnv.Ty.unfold p lt env)) pb + List.iter2 (fun t1 t2 -> Queue.push (`TyUni (t1, t2)) pb) ty1 ty2; - | _, _ -> failure () - end - end + List.iter2 (fun tw1 tw2 -> + if List.length tw1 <> List.length tw2 then failure (); + List.iter2 (fun w1 w2 -> Queue.push (`TcTw (w1, w2)) pb) tw1 tw2 + ) tws1 tws2 - | `TcCtt (ty, tc) -> begin - let tytc, ty = getvar ty in - - match ty.ty_node with - | Tunivar i -> - uf := UF.set i (Sp.union tc tytc, None) !uf - - | Tvar x -> - let xtcs = odfl Sp.empty (Mid.find_opt x tvtc) in - if not (has_tcs ~src:xtcs ~dst:tc) then - failure () - - | _ -> - if not (has_tcs ~src:tytc ~dst:tc) then - let module E = struct exception Failure end in - - let inst = instances_for_tcs tc in - - let for1 uf p = - let for_inst ((typ, gty), p') = - try - if not (TC.Graph.has_path ~src:p' ~dst:p gr) then - raise E.Failure; - let (uf, gty) = - let (uf, subst) = - List.fold_left - (fun (uf, s) (v, tc) -> - let (uf, uid) = UnifyCore.fresh ~tc uf in - (uf, Mid.add v uid s)) - (uf, Mid.empty) typ - in - (uf, Tvar.subst subst gty) - in - try Some (unify_core env tvtc uf (`TyUni (gty, ty))) - with UnificationFailure _ -> raise E.Failure - with E.Failure -> None - in - try List.find_map for_inst inst - with Not_found -> failure () - in - uf := List.fold_left for1 !uf (Sp.elements tc) - end - done - in - doit (); !uf + | Tconstr (p, lt), _ when EcEnv.Ty.defined p env -> + Queue.push (`TyUni (EcEnv.Ty.unfold p lt env, t2)) pb -(* -------------------------------------------------------------------- *) -let close (uf : UF.t) = - let map = Hint.create 0 in + | _, Tconstr (p, lt) when EcEnv.Ty.defined p env -> + Queue.push (`TyUni (t1, EcEnv.Ty.unfold p lt env)) pb - let rec doit t = - match t.ty_node with - | Tunivar i -> begin - match Hint.find_opt map i with - | Some t -> t - | None -> begin - let t = - match snd (UF.data i uf) with - | None -> tuni (UF.find i uf) - | Some t -> doit t - in - Hint.add map i t; t - end - end + | _, _ -> failure () + end + end - | _ -> ty_map doit t - in - fun t -> doit t + | `TcCtt (uid, ty, tc) -> + if not (List.is_empty tc.tc_args) then + failure (); + + let deps = ref TyUni.Suid.empty in + + let rec check (ty : ty) : ty = + match ty.ty_node with + | Tunivar tyuvar -> begin + match UF.data tyuvar (!uc).uf with + | None -> + deps := TyUni.Suid.add tyuvar !deps; + ty + | Some ty -> + check ty + end + | _ -> ty_map check ty in + + let ty = check ty in + let deps = !deps in + + if TyUni.Suid.is_empty deps then begin + match ty.ty_node with + | Tvar a -> + let tcs = ofdfl failure (Mid.find_opt a (!uc).tvtc) in + let idx = + let eq (tc' : typeclass) = + EcPath.p_equal tc.tc_name tc'.tc_name + && List.for_all2 (EcCoreEqTest.for_etyarg env) tc.tc_args tc'.tc_args in + ofdfl failure (List.find_index eq tcs) in + + uc := { !uc with tcenv = { (!uc).tcenv with resolution = + TcUni.Muid.add + uid + (TCIAbstract { support = `Var a; offset = idx; }) + (!uc).tcenv.resolution + } } + + | _-> + let tci = ofdfl failure (EcTypeClass.infer env ty tc) in + uc := { !uc with tcenv = { (!uc).tcenv with resolution = + TcUni.Muid.add uid tci (!uc).tcenv.resolution + } } + end else begin + TyUni.Suid.iter (fun tyvar -> + uc := { !uc with tcenv = { (!uc).tcenv with byunivar = + TyUni.Muid.change (fun map -> + let map = Option.value ~default:TcUni.Suid.empty map in + Some (TcUni.Suid.add uid map) + ) tyvar (!uc).tcenv.byunivar + } } + ) deps + end + + | _ -> + () (* FIXME:TC *) + done + in + doit (); !uc + (* -------------------------------------------------------------------- *) + type closed = { tyuni : ty -> ty; tcuni : tcwitness -> tcwitness; } + (* -------------------------------------------------------------------- *) + let close (uc : ucore) : closed = + let tymap = Hint.create 0 in + let tcmap = Hint.create 0 in -(* -------------------------------------------------------------------- *) -let subst_of_uf (uf : UF.t) = - let close = close uf in - let uids = UF.domain uf in - List.fold_left - (fun m uid -> - match close (tuni uid) with - | { ty_node = Tunivar uid' } when uid_equal uid uid' -> m - | t -> Muid.add uid t m - ) - Muid.empty - uids + let rec doit_ty t = + match t.ty_node with + | Tunivar i -> begin + match Hint.find_opt tymap (i :> int) with + | Some t -> t + | None -> begin + let t = + match UF.data i uc.uf with + | None -> tuni (UF.find i uc.uf) + | Some t -> doit_ty t + in + Hint.add tymap (i :> int) t; t + end + end + + | _ -> ty_map doit_ty t + + and doit_tc (tw : tcwitness) = + match tw with + | TCIUni uid -> begin + match Hint.find_opt tcmap (uid :> int) with + | Some tw -> tw + | None -> + let tw = + match TcUni.Muid.find_opt uid uc.tcenv.resolution with + | None -> tw + | Some (TCIUni uid') when TcUni.uid_equal uid uid' -> tw (* FIXME:TC *) + | Some tw -> doit_tc tw + in + Hint.add tcmap (uid :> int) tw; tw + end + | TCIConcrete { path; etyargs } -> + let etyargs = + List.map + (fun (ty, tws) -> (doit_ty ty, List.map doit_tc tws)) + etyargs + in TCIConcrete { path; etyargs; } + + | TCIAbstract { support = (`Var _ | `Abs _) } -> + tw + + in { tyuni = doit_ty; tcuni = doit_tc; } + + (* ------------------------------------------------------------------ *) + let subst_of_uf (uc : ucore) : unisubst = + let close = close uc in + + let dereference_tyuni (uid : tyuni) = + match close.tyuni (tuni uid) with + | { ty_node = Tunivar uid' } when TyUni.uid_equal uid uid' -> None + | ty -> Some ty in + + let dereference_tcuni (uid : tcuni) = + match close.tcuni (TCIUni uid) with + | TCIUni uid' when TcUni.uid_equal uid uid' -> None + | tw -> Some tw in + + let uvars = + let bindings = + List.filter_map (fun uid -> + Option.map (fun ty -> (uid, ty)) (dereference_tyuni uid) + ) (UF.domain uc.uf) in + TyUni.Muid.of_list bindings in + + let utcvars = + let bindings = + List.filter_map (fun uid -> + Option.map (fun tw -> (uid, tw)) (dereference_tcuni uid) + ) (TcUni.Muid.keys uc.tcenv.problems) in + TcUni.Muid.of_list bindings in + + { uvars; utcvars; } + + (* -------------------------------------------------------------------- *) + let check_closed (uc : ucore) = + let tyvars = not (UF.closed uc.uf) in + let tcvars = not (tcenv_closed uc.tcenv) in + + if tyvars || tcvars then + raise (UninstanciateUni { tyvars; tcvars }) +end (* -------------------------------------------------------------------- *) type unienv_r = { - ue_uf : UF.t; + ue_uc : Unify.ucore; ue_named : EcIdent.t Mstr.t; - ue_tvtc : Sp.t Mid.t; ue_decl : EcIdent.t list; ue_closed : bool; } type unienv = unienv_r ref +type petyarg = ty option * tcwitness option list option + type tvar_inst = -| TVIunamed of ty list -| TVInamed of (EcSymbols.symbol * ty) list +| TVIunamed of petyarg list +| TVInamed of (EcSymbols.symbol * petyarg) list type tvi = tvar_inst option -type uidmap = uid -> ty option + +let tvi_unamed (ety : etyarg list) : tvar_inst = + TVIunamed (List.map + (fun (ty, tcw) -> Some ty, Some (List.map Option.some tcw)) + ety + ) module UniEnv = struct let copy (ue : unienv) : unienv = @@ -292,7 +453,7 @@ module UniEnv = struct let restore ~(dst:unienv) ~(src:unienv) = dst := !src - let getnamed ue x = + let getnamed (ue : unienv) (x : symbol) = match Mstr.find_opt x (!ue).ue_named with | Some a -> a | None -> begin @@ -304,107 +465,190 @@ module UniEnv = struct }; id end - let create (vd : (EcIdent.t * Sp.t) list option) = - let ue = { - ue_uf = UF.initial; - ue_named = Mstr.empty; - ue_tvtc = Mid.empty; - ue_decl = []; - ue_closed = false; - } in - + let create (vd : (EcIdent.t * typeclass list) list option) : unienv = let ue = match vd with - | None -> ue + | None -> + { ue_uc = Unify.initial_ucore () + ; ue_named = Mstr.empty + ; ue_decl = [] + ; ue_closed = false + } + | Some vd -> let vdmap = List.map (fun (x, _) -> (EcIdent.name x, x)) vd in - { ue with - ue_named = Mstr.of_list vdmap; - ue_tvtc = Mid.of_list vd; - ue_decl = List.rev_map fst vd; - ue_closed = true; } + let tvtc = Mid.of_list vd in + { ue_uc = Unify.initial_ucore ~tvtc () + ; ue_named = Mstr.of_list vdmap + ; ue_decl = List.rev_map fst vd + ; ue_closed = true; + } + in ref ue + + let push ((x, tc) : ident * typeclass list) (ue : unienv) = + assert (not (Mstr.mem (EcIdent.name x) (!ue).ue_named)); + assert ((!ue).ue_closed); + + (* FIXME:TC use API for pushing a variable*) + ue := + { ue_uc = { (!ue).ue_uc with tvtc = Mid.add x tc (!ue).ue_uc.tvtc } + ; ue_named = Mstr.add (EcIdent.name x) x (!ue).ue_named + ; ue_decl = x :: (!ue).ue_decl + ; ue_closed = true } + + let xfresh + ?(tcs : (typeclass * tcwitness option) list option) + ?(ty : ty option) + (ue : unienv) + = + let (uc, tytw) = Unify.fresh ?tcs ?ty (!ue).ue_uc in + ue := { !ue with ue_uc = uc }; tytw + + let fresh ?(ty : ty option) (ue : unienv) = + let (uc, (ty, tw)) = Unify.fresh ?ty (!ue).ue_uc in + assert (List.is_empty tw); + ue := { !ue with ue_uc = uc }; ty + + type opened = { + subst : etyarg Mid.t; + params : (ty * typeclass list) list; + args : etyarg list; + } + + let subst_tv (subst : etyarg Mid.t) (params : ty_params) = + List.map (fun (tv, tcs) -> + let tv = Tvar.subst subst (tvar tv) in + let tcs = + List.map + (fun tc -> + let tc_args = + List.map (Tvar.subst_etyarg subst) tc.tc_args + in { tc with tc_args }) + tcs + in (tv, tcs)) params + + let opentvi (ue : unienv) (params : ty_params) (tvi : tvi) : opened = + let tvi = + match tvi with + | None -> + List.map (fun (v, tcs) -> + (v, (None, List.map (fun x -> (x, None)) tcs)) + ) params + + | Some (TVIunamed lt) -> + let combine (v, tc) (ty, tcw) = + let tctcw = + match tcw with + | None -> + List.map (fun tc -> (tc, None)) tc + | Some tcw -> + List.combine tc tcw + in (v, (ty, tctcw)) in + + List.map2 combine params lt + + | Some (TVInamed lt) -> + List.map (fun (v, tc) -> + let ty, tcw = + List.assoc_opt (EcIdent.name v) lt + |> Option.value ~default:(None, None) in + + let tcw = + match tcw with + | None -> + List.map (fun _ -> None) tc + | Some tcw -> + tcw in + + (v, (ty, List.map2 (fun x y -> (x, y)) tc tcw)) + ) params in - ref ue - - let fresh ?tc ?ty ue = - let (uf, uid) = UnifyCore.fresh ?tc ?ty (!ue).ue_uf in - ue := { !ue with ue_uf = uf }; uid - - let opentvi ue (params : ty_params) tvi = - match tvi with - | None -> - List.fold_left - (fun s (v, tc) -> Mid.add v (fresh ~tc ue) s) - Mid.empty params - - | Some (TVIunamed lt) -> - List.fold_left2 - (fun s (v, tc) ty -> Mid.add v (fresh ~tc ~ty ue) s) - Mid.empty params lt - - | Some (TVInamed lt) -> - let for1 s (v, tc) = - let t = - try fresh ~tc ~ty:(List.assoc (EcIdent.name v) lt) ue - with Not_found -> fresh ~tc ue - in - Mid.add v t s - in - List.fold_left for1 Mid.empty params - let subst_tv subst params = - List.map (fun (tv, _) -> subst (tvar tv)) params + let subst = + List.fold_left (fun s (v, (ty, tcws)) -> + let tcs = + let for1 (tc, tcw) = + let tc = + { tc_name = tc.tc_name; + tc_args = List.map (Tvar.subst_etyarg s) tc.tc_args } in + (tc, tcw) + in List.map for1 tcws + in Mid.add v (xfresh ?ty ~tcs ue) s + ) Mid.empty tvi in + + let args = List.map (fun (x, _) -> oget (Mid.find_opt x subst)) params in + let params = subst_tv subst params in - let openty_r ue params tvi = - let subst = f_subst_init ~tv:(opentvi ue params tvi) () in - (subst, subst_tv (ty_subst subst) params) + { subst; args; params; } - let opentys ue params tvi tys = - let (subst, tvs) = openty_r ue params tvi in - (List.map (ty_subst subst) tys, tvs) + let opentys (ue : unienv) (params : ty_params) (tvi : tvi) (tys : ty list) = + let opened = opentvi ue params tvi in + let tys = List.map (Tvar.subst opened.subst) tys in + tys, opened - let openty ue params tvi ty = - let (subst, tvs) = openty_r ue params tvi in - (ty_subst subst ty, tvs) + let openty (ue : unienv) (params : ty_params) (tvi : tvi) (ty : ty) = + let opened = opentvi ue params tvi in + Tvar.subst opened.subst ty, opened let repr (ue : unienv) (t : ty) : ty = match t.ty_node with - | Tunivar id -> odfl t (snd (UF.data id (!ue).ue_uf)) + | Tunivar id -> odfl t (Unify.UF.data id (!ue).ue_uc.uf) | _ -> t - let closed (ue : unienv) = - UF.closed (!ue).ue_uf + let xclosed (ue : unienv) = + try Unify.check_closed (!ue).ue_uc; None + with UninstanciateUni infos -> Some infos - let close (ue : unienv) = - if not (closed ue) then raise UninstanciateUni; - (subst_of_uf (!ue).ue_uf) + let closed (ue : unienv) = + Option.is_none (xclosed ue) - let assubst ue = subst_of_uf (!ue).ue_uf + let assubst (ue : unienv) : unisubst = + Unify.subst_of_uf (!ue).ue_uc - let tparams ue = - let fortv x = odfl Sp.empty (Mid.find_opt x (!ue).ue_tvtc) in - List.map (fun x -> (x, fortv x)) (List.rev (!ue).ue_decl) + let close (ue : unienv) = + Unify.check_closed (!ue).ue_uc; + assubst ue + + let tparams (ue : unienv) = + let subst = EcCoreSubst.f_subst_init ~tu:(assubst ue) () in + let fortv x = + let tvtc = odfl [] (Mid.find_opt x (!ue).ue_uc.tvtc) in + List.map (EcCoreSubst.tc_subst subst) tvtc in + List.map (fun x -> (x, fortv x)) (List.rev (!ue).ue_decl) end (* -------------------------------------------------------------------- *) -let unify env ue t1 t2 = - let uf = unify_core env (!ue).ue_tvtc (!ue).ue_uf (`TyUni (t1, t2)) in - ue := { !ue with ue_uf = uf; } +let unify_core (env : EcEnv.env) (ue : unienv) (pb : problem) = + let uc = Unify.unify_core env (!ue).ue_uc pb in + ue := { !ue with ue_uc = uc; } -let hastc env ue ty tc = - let uf = unify_core env (!ue).ue_tvtc (!ue).ue_uf (`TcCtt (ty, tc)) in - ue := { !ue with ue_uf = uf; } +(* -------------------------------------------------------------------- *) +let unify (env : EcEnv.env) (ue : unienv) (t1 : ty) (t2 : ty) = + unify_core env ue (`TyUni (t1, t2)) (* -------------------------------------------------------------------- *) -let tfun_expected ue psig = - let tres = UniEnv.fresh ue in - EcTypes.toarrow psig tres +let tfun_expected (ue : unienv) (psig : ty list) = + EcTypes.toarrow psig (UniEnv.fresh ue) (* -------------------------------------------------------------------- *) type sbody = ((EcIdent.t * ty) list * expr) Lazy.t (* -------------------------------------------------------------------- *) -let select_op ?(hidden = false) ?(filter = fun _ _ -> true) tvi env name ue psig = +type select_filter_t = EcPath.path -> operator -> bool + +type select_t = + ((EcPath.path * etyarg list) * ty * unienv * sbody option) list + +let select_op + ?(hidden : bool = false) + ?(filter : select_filter_t = fun _ _ -> true) + (tvi : tvi) + (env : EcEnv.env) + (name : qsymbol) + (ue : unienv) + (psig : dom) + : select_t += ignore hidden; (* FIXME *) let module D = EcDecl in @@ -419,12 +663,12 @@ let select_op ?(hidden = false) ?(filter = fun _ _ -> true) tvi env name ue psig let len = List.length lt in fun op -> let tparams = op.D.op_tparams in - List.length tparams = len + List.length tparams = len | Some (TVInamed ls) -> fun op -> let tparams = List.map (fst_map EcIdent.name) op.D.op_tparams in let tparams = Msym.of_list tparams in - List.for_all (fun (x, _) -> Msym.mem x tparams) ls + List.for_all (fun (x, _) -> Msym.mem x tparams) ls in filter oppath op && filter_on_tvi op @@ -436,27 +680,19 @@ let select_op ?(hidden = false) ?(filter = fun _ _ -> true) tvi env name ue psig let subue = UniEnv.copy ue in try - begin try - match tvi with - | None -> - () - - | Some (TVIunamed lt) -> - List.iter2 - (fun ty (_, tc) -> hastc env subue ty tc) - lt op.D.op_tparams - - | Some (TVInamed ls) -> - let tparams = List.map (fst_map EcIdent.name) op.D.op_tparams in - let tparams = Msym.of_list tparams in - List.iter (fun (x, ty) -> - hastc env subue ty (oget (Msym.find_opt x tparams))) - ls - with UnificationFailure _ -> raise E.Failure - end; - - let (tip, tvs) = UniEnv.openty_r subue op.D.op_tparams tvi in - let top = ty_subst tip op.D.op_ty in + let UniEnv.{ subst = tip; args } = + UniEnv.opentvi subue op.D.op_tparams tvi in + let tip = f_subst_init ~tv:tip () in + + (* + List.iter + (fun (tv, tcs) -> + try hastcs_r env subue tv tcs + with UnificationFailure _ -> raise E.Failure) + tvtcs; + *) + + let top = EcCoreSubst.ty_subst tip op.D.op_ty in let texpected = tfun_expected subue psig in (try unify env subue top texpected @@ -473,8 +709,9 @@ let select_op ?(hidden = false) ?(filter = fun _ _ -> true) tvi env name ue psig in Some (Lazy.from_fun substnt) | _ -> None + in - in Some ((path, tvs), top, subue, bd) + Some ((path, args), top, subue, bd) (* FIXME:TC *) with E.Failure -> None diff --git a/src/ecUnify.mli b/src/ecUnify.mli index 2c7fbdb1a1..92f81fde77 100644 --- a/src/ecUnify.mli +++ b/src/ecUnify.mli @@ -1,41 +1,60 @@ (* -------------------------------------------------------------------- *) -open EcUid +open EcIdent open EcSymbols -open EcPath open EcTypes +open EcAst open EcDecl -(* -------------------------------------------------------------------- *) -exception UnificationFailure of [`TyUni of ty * ty | `TcCtt of ty * Sp.t] -exception UninstanciateUni +(* ==================================================================== *) +type problem = [ + | `TyUni of ty * ty + | `TcTw of tcwitness * tcwitness + | `TcCtt of EcAst.tcuni * ty * typeclass +] + +type uniflags = { tyvars: bool; tcvars: bool; } + +exception UnificationFailure of problem +exception UninstanciateUni of uniflags type unienv +type petyarg = ty option * tcwitness option list option + type tvar_inst = -| TVIunamed of ty list -| TVInamed of (EcSymbols.symbol * ty) list +| TVIunamed of petyarg list +| TVInamed of (EcSymbols.symbol * petyarg) list type tvi = tvar_inst option -type uidmap = uid -> ty option + +val tvi_unamed : etyarg list -> tvar_inst module UniEnv : sig - val create : (EcIdent.t * Sp.t) list option -> unienv + type opened = { + subst : etyarg Mid.t; + params : (ty * typeclass list) list; + args : etyarg list; + } + + val create : (EcIdent.t * typeclass list) list option -> unienv + val push : (EcIdent.t * typeclass list) -> unienv -> unit val copy : unienv -> unienv (* constant time *) val restore : dst:unienv -> src:unienv -> unit (* constant time *) - val fresh : ?tc:EcPath.Sp.t -> ?ty:ty -> unienv -> ty + val xfresh : ?tcs:(typeclass * EcTypes.tcwitness option) list -> ?ty:ty -> unienv -> etyarg + val fresh : ?ty:ty -> unienv -> ty val getnamed : unienv -> symbol -> EcIdent.t val repr : unienv -> ty -> ty - val opentvi : unienv -> ty_params -> tvi -> ty EcIdent.Mid.t - val openty : unienv -> ty_params -> tvi -> ty -> ty * ty list - val opentys : unienv -> ty_params -> tvi -> ty list -> ty list * ty list + val opentvi : unienv -> ty_params -> tvi -> opened + val openty : unienv -> ty_params -> tvi -> ty -> ty * opened + val opentys : unienv -> ty_params -> tvi -> ty list -> ty list * opened val closed : unienv -> bool - val close : unienv -> ty Muid.t - val assubst : unienv -> ty Muid.t + val xclosed : unienv -> uniflags option + val close : unienv -> EcCoreSubst.unisubst + val assubst : unienv -> EcCoreSubst.unisubst val tparams : unienv -> ty_params end val unify : EcEnv.env -> unienv -> ty -> ty -> unit -val hastc : EcEnv.env -> unienv -> ty -> Sp.t -> unit val tfun_expected : unienv -> EcTypes.ty list -> EcTypes.ty @@ -43,10 +62,10 @@ type sbody = ((EcIdent.t * ty) list * expr) Lazy.t val select_op : ?hidden:bool - -> ?filter:(path -> operator -> bool) + -> ?filter:(EcPath.path -> operator -> bool) -> tvi -> EcEnv.env -> qsymbol -> unienv -> dom - -> ((EcPath.path * ty list) * ty * unienv * sbody option) list + -> ((EcPath.path * etyarg list) * ty * unienv * sbody option) list diff --git a/src/ecUserMessages.ml b/src/ecUserMessages.ml index 6973f029ee..2cee8c036f 100644 --- a/src/ecUserMessages.ml +++ b/src/ecUserMessages.ml @@ -1,8 +1,8 @@ (* -------------------------------------------------------------------- *) open EcSymbols -open EcUid open EcPath open EcUtils +open EcAst open EcTypes open EcCoreSubst open EcEnv @@ -21,6 +21,7 @@ let set_ppo (newppo : pp_options) = module TypingError : sig open EcTyping + val pp_uniflags : Format.formatter -> EcUnify.uniflags -> unit val pp_fxerror : env -> Format.formatter -> fxerror -> unit val pp_tyerror : env -> Format.formatter -> tyerror -> unit val pp_cnv_failure : env -> Format.formatter -> tymod_cnv_failure -> unit @@ -30,6 +31,16 @@ module TypingError : sig end = struct open EcTyping + let pp_uniflags (fmt : Format.formatter) ({ tyvars; tcvars; } : EcUnify.uniflags) = + let msg = + match tyvars, tcvars with + | false, false -> None + | true, false -> Some "type" + | false, true -> Some "type-class" + | true, true -> Some "type&type-class" in + + Option.iter (Format.fprintf fmt "%s") msg + let pp_mismatch_funsig env0 fmt error = let ppe0 = EcPrinting.PPEnv.ofenv env0 in @@ -235,8 +246,10 @@ end = struct | UniVarNotAllowed -> msg "type place holders not allowed" - | FreeTypeVariables -> - msg "this expression contains free type variables" + | FreeUniVariables infos -> + msg + "this expression contains free %a variables" + pp_uniflags infos | TypeVarNotAllowed -> msg "type variables not allowed" @@ -348,7 +361,7 @@ end = struct | MultipleOpMatch (name, tys, matches) -> begin let uvars = List.map Tuni.univars tys in - let uvars = List.fold_left Suid.union Suid.empty uvars in + let uvars = List.fold_left TyUni.Suid.union TyUni.Suid.empty uvars in begin match tys with | [] -> @@ -366,7 +379,7 @@ end = struct let pp_op fmt ((op, inst), subue) = let uidmap = EcUnify.UniEnv.assubst subue in - let inst = Tuni.subst_dom uidmap inst in + let inst = Tuni.subst_dom uidmap (List.fst inst) in begin match inst with | [] -> @@ -379,8 +392,8 @@ end = struct end; let myuvars = List.map Tuni.univars inst in - let myuvars = List.fold_left Suid.union uvars myuvars in - let myuvars = Suid.elements myuvars in + let myuvars = List.fold_left TyUni.Suid.union uvars myuvars in + let myuvars = TyUni.Suid.elements myuvars in let uidmap = EcUnify.UniEnv.assubst subue in let tysubst = ty_subst (Tuni.subst uidmap) in @@ -506,6 +519,14 @@ end = struct | LvMapOnNonAssign -> msg "map-style left-value cannot be used with assignments" + | TCArgsCountMismatch (_, typarams, tys) -> + msg "typeclass expects %d arguments, got %d" + (List.length typarams) (List.length tys) + + | CannotInferTC (ty, tc) -> + msg "cannot infer typeclass `%a' for type `%a'" + (EcPrinting.pp_typeclass env) tc pp_type ty + | NoDefaultMemRestr -> msg "no default sign for memory restriction. Use '+' or '-', or \ set the %s pragma to retrieve the old behaviour" @@ -613,8 +634,10 @@ end = struct let pp_tperror (env : env) fmt = function | TPE_Typing e -> TypingError.pp_tyerror env fmt e - | TPE_TyNotClosed -> - Format.fprintf fmt "this predicate type contains free type variables" + | TPE_TyNotClosed infos -> + Format.fprintf fmt + "this predicate type contains free %a variables" + TypingError.pp_uniflags infos | TPE_DuplicatedConstr x -> Format.fprintf fmt "duplicated constructor name: `%s'" x end @@ -633,8 +656,10 @@ end = struct match error with | NTE_Typing e -> TypingError.pp_tyerror env fmt e - | NTE_TyNotClosed -> - msg "this notation type contains free type variables" + | NTE_TyNotClosed infos -> + msg + "this notation type contains free %a variables" + TypingError.pp_uniflags infos | NTE_DupIdent -> msg "an ident is bound several time" | NTE_UnknownBinder x -> diff --git a/src/ecUserMessages.mli b/src/ecUserMessages.mli index efe97e0efc..97d3e0d10b 100644 --- a/src/ecUserMessages.mli +++ b/src/ecUserMessages.mli @@ -14,6 +14,7 @@ val set_ppo : pp_options -> unit module TypingError : sig open EcTyping + val pp_uniflags : Format.formatter -> EcUnify.uniflags -> unit val pp_tyerror : env -> Format.formatter -> tyerror -> unit val pp_cnv_failure : env -> Format.formatter -> tymod_cnv_failure -> unit val pp_mismatch_funsig : env -> Format.formatter -> mismatch_funsig -> unit diff --git a/src/ecUtils.ml b/src/ecUtils.ml index 6213d2f966..41ccf352cb 100644 --- a/src/ecUtils.ml +++ b/src/ecUtils.ml @@ -116,6 +116,12 @@ type 'a tuple8 = 'a * 'a * 'a * 'a * 'a * 'a * 'a * 'a type 'a tuple9 = 'a * 'a * 'a * 'a * 'a * 'a * 'a * 'a * 'a type 'a pair = 'a * 'a +(* -------------------------------------------------------------------- *) +module SmartPair = struct + let mk ((a, b) as p) a' b' = + if a == a' && b == b' then p else (a', b') +end + (* -------------------------------------------------------------------- *) let t2_map (f : 'a -> 'b) (x, y) = (f x, f y) @@ -481,6 +487,17 @@ module List = struct | None -> failwith "List.last" | Some x -> x + let betail = + let rec aux (acc : 'a list) (s : 'a list) = + match s, acc with + | [], [] -> + failwith "List.betail" + | [], v :: vs-> + List.rev vs, v + | x :: xs, _ -> + aux (x :: acc) xs + in fun s -> aux [] s + let mbfilter (p : 'a -> bool) (s : 'a list) = match s with [] | [_] -> s | _ -> List.filter p s diff --git a/src/ecUtils.mli b/src/ecUtils.mli index fe135ee604..2c7bbe65f5 100644 --- a/src/ecUtils.mli +++ b/src/ecUtils.mli @@ -64,6 +64,11 @@ type 'a tuple8 = 'a * 'a * 'a * 'a * 'a * 'a * 'a * 'a type 'a tuple9 = 'a * 'a * 'a * 'a * 'a * 'a * 'a * 'a * 'a type 'a pair = 'a tuple2 +(* -------------------------------------------------------------------- *) +module SmartPair : sig + val mk : 'a * 'b -> 'a -> 'b -> 'a * 'b +end + (* -------------------------------------------------------------------- *) val in_seq1: ' a -> 'a list @@ -281,6 +286,7 @@ module List : sig val min : ?cmp:('a -> 'a -> int) -> 'a list -> 'a val max : ?cmp:('a -> 'a -> int) -> 'a list -> 'a + val betail : 'a list -> 'a list * 'a val destruct : 'a list -> 'a * 'a list val nth_opt : 'a list -> int -> 'a option val mbfilter : ('a -> bool) -> 'a list -> 'a list diff --git a/src/phl/ecPhlCond.ml b/src/phl/ecPhlCond.ml index abf5e4ddc2..b83903f552 100644 --- a/src/phl/ecPhlCond.ml +++ b/src/phl/ecPhlCond.ml @@ -226,8 +226,8 @@ let t_equiv_match_same_constr tc = let bhl = List.map (fst_map EcIdent.fresh) cl in let bhr = List.map (fst_map EcIdent.fresh) cr in let cop = EcPath.pqoname (EcPath.prefix pl) c in - let copl = f_op cop tyl (toarrow (List.snd cl) fl.f_ty) in - let copr = f_op cop tyr (toarrow (List.snd cr) fr.f_ty) in + let copl = f_op_tc cop tyl (toarrow (List.snd cl) fl.f_ty) in + let copr = f_op_tc cop tyr (toarrow (List.snd cr) fr.f_ty) in let lhs = f_eq fl (f_app copl (List.map (curry f_local) bhl) fl.f_ty) in let lhs = f_exists (List.map (snd_map gtty) bhl) lhs in @@ -242,8 +242,8 @@ let t_equiv_match_same_constr tc = let sb, bhl = add_elocals sb cl in let sb, bhr = add_elocals sb cr in let cop = EcPath.pqoname (EcPath.prefix pl) c in - let copl = f_op cop tyl (toarrow (List.snd cl) fl.f_ty) in - let copr = f_op cop tyr (toarrow (List.snd cr) fr.f_ty) in + let copl = f_op_tc cop tyl (toarrow (List.snd cl) fl.f_ty) in + let copr = f_op_tc cop tyr (toarrow (List.snd cr) fr.f_ty) in let pre = f_ands_simpl [ f_eq fl (f_app copl (List.map (curry f_local) bhl) fl.f_ty); f_eq fr (f_app copr (List.map (curry f_local) bhr) fr.f_ty) ] @@ -305,8 +305,8 @@ let t_equiv_match_eq tc = sb cl cr in let cop = EcPath.pqoname (EcPath.prefix pl) c in - let copl = f_op cop tyl (toarrow (List.snd cl) fl.f_ty) in - let copr = f_op cop tyr (toarrow (List.snd cr) fr.f_ty) in + let copl = f_op_tc cop tyl (toarrow (List.snd cl) fl.f_ty) in + let copr = f_op_tc cop tyr (toarrow (List.snd cr) fr.f_ty) in let pre = f_ands_simpl [ f_eq fl (f_app copl (List.map (curry f_local) bh) fl.f_ty); f_eq fr (f_app copr (List.map (curry f_local) bh) fr.f_ty) ] diff --git a/src/phl/ecPhlEqobs.ml b/src/phl/ecPhlEqobs.ml index b9efd3efdd..c928d03439 100644 --- a/src/phl/ecPhlEqobs.ml +++ b/src/phl/ecPhlEqobs.ml @@ -249,7 +249,7 @@ and i_eqobs_in il ir sim local (eqo:Mpv2.t) = let typr, _, tyinstr = oget (EcEnv.Ty.get_top_decl el.e_ty env) in let test = EcPath.p_equal typl typr && - List.for_all2 (EcReduction.EqTest.for_type env) tyinstl tyinstr in + List.for_all2 (EcReduction.EqTest.for_etyarg env) tyinstl tyinstr in if not test then raise EqObsInError; let rsim = ref sim in let doit eqs1 (argsl,sl) (argsr, sr) = diff --git a/src/phl/ecPhlInline.ml b/src/phl/ecPhlInline.ml index 4e7f6d0276..36fb790d89 100644 --- a/src/phl/ecPhlInline.ml +++ b/src/phl/ecPhlInline.ml @@ -32,7 +32,7 @@ module LowSubst = struct let rec esubst m e = match e.e_node with | Evar pv -> e_var (pvsubst m pv) e.e_ty - | _ -> EcTypes.e_map (fun ty -> ty) (esubst m) e + | _ -> EcTypes.e_map (esubst m) e let lvsubst m lv = match lv with diff --git a/src/phl/ecPhlOutline.ml b/src/phl/ecPhlOutline.ml index 6774ad118b..7b6091423d 100644 --- a/src/phl/ecPhlOutline.ml +++ b/src/phl/ecPhlOutline.ml @@ -279,8 +279,8 @@ let process_outline info tc = let sty = f_subst_init ~tu () in let es = e_subst sty in Some (lv_of_expr (es res)) - with EcUnify.UninstanciateUni -> - EcTyping.tyerror loc env EcTyping.FreeTypeVariables + with EcUnify.UninstanciateUni infos -> + EcTyping.tyerror loc env (FreeUniVariables infos) end | None, _ -> None | _, _ -> raise (OutlineError OE_UnnecessaryReturn) diff --git a/src/phl/ecPhlRCond.ml b/src/phl/ecPhlRCond.ml index e28a079391..f12b272d56 100644 --- a/src/phl/ecPhlRCond.ml +++ b/src/phl/ecPhlRCond.ml @@ -162,7 +162,7 @@ module LowMatch = struct in (x, xty)) cvars in let vars = List.map (curry f_local) names in let cty = toarrow (List.snd names) f.f_ty in - let po = f_op cname (List.snd tyinst) cty in + let po = f_op_tc cname (List.snd tyinst) cty in let po = f_app po vars f.f_ty in f_exists (List.map (snd_map gtty) names) (f_eq f po) in @@ -191,7 +191,7 @@ module LowMatch = struct let epr, asgn = if frame then begin let vars = List.map (fun (pv, ty) -> f_pvar pv ty (fst me)) pvs in - let epr = f_op cname (List.snd tyinst) f.f_ty in + let epr = f_op_tc cname (List.snd tyinst) f.f_ty in let epr = f_app epr vars f.f_ty in Some (f_eq f epr), [] end else begin @@ -200,7 +200,7 @@ module LowMatch = struct (* FIXME: factorize out *) let rty = ttuple (List.snd cvars) in let proj = EcInductive.datatype_proj_path typ (EcPath.basename cname) in - let proj = e_op proj (List.snd tyinst) (tfun e.e_ty (toption rty)) in + let proj = e_op_tc proj (List.snd tyinst) (tfun e.e_ty (toption rty)) in let proj = e_app proj [e] (toption rty) in let proj = e_oget proj rty in i_asgn (lv, proj)) in diff --git a/src/phl/ecPhlRwEquiv.ml b/src/phl/ecPhlRwEquiv.ml index 3e38064377..f7b63d3f06 100644 --- a/src/phl/ecPhlRwEquiv.ml +++ b/src/phl/ecPhlRwEquiv.ml @@ -145,8 +145,8 @@ let process_rewrite_equiv info tc = let res = omap (fun v -> EcTyping.transexpcast subenv `InProc ue ret_ty v) pres in let es = e_subst (Tuni.subst (EcUnify.UniEnv.close ue)) in Some (List.map es args, omap (EcModules.lv_of_expr |- es) res) - with EcUnify.UninstanciateUni -> - EcTyping.tyerror (loc pargs) env EcTyping.FreeTypeVariables + with EcUnify.UninstanciateUni infos -> + EcTyping.tyerror (loc pargs) env (FreeUniVariables infos) end in diff --git a/src/phl/ecPhlWhile.ml b/src/phl/ecPhlWhile.ml index 9dea44ff9e..f48e16adcd 100644 --- a/src/phl/ecPhlWhile.ml +++ b/src/phl/ecPhlWhile.ml @@ -447,7 +447,7 @@ module ASyncWhile = struct | Fint z -> e_int z | Flocal x -> e_local x fp.f_ty - | Fop (p, tys) -> e_op p tys fp.f_ty + | Fop (p, tys) -> e_op_tc p tys fp.f_ty | Fapp (f, fs) -> e_app (aux f) (List.map aux fs) fp.f_ty | Ftuple fs -> e_tuple (List.map aux fs) | Fproj (f, i) -> e_proj (aux f) i fp.f_ty diff --git a/subtypes/subtype.ec b/subtypes/subtype.ec new file mode 100644 index 0000000000..1f4c2f2535 --- /dev/null +++ b/subtypes/subtype.ec @@ -0,0 +1,107 @@ +(* ==================================================================== *) +subtype 'a word (n : int) = { + w : 'a list | size w = n +} + witness. + +op cat ['a] [n m : int] (x : {'a word n}) (y : {'a word m}) : {'a word (n+m)} = + x ++ y. + +==> (traduction) + +op cat ['a] (x : 'a word) (y : 'a word) : 'a word = + x ++ y. + +lemma cat_spec ['a] : + forall (n m : int) (x y : 'a word), + size x = n => size y = m => size (cat x y) = (n + m). + +op xor [n m : int] (w1 : {word n}) (w2 : {word m}) : {word (min (n, m))} = + ... + +lemma foo ['a] [n : int] (w1 w2 : {'a word n}) : + xor w1 w2 = xor w2 w1. + +op vectorize ['a] [n m : int] (w : {'a word (n * m)}) : {{'a word n} word m}. + +lemma vectorize_spec ['a] (w : 'a list) : size w = (n * m) => + size (vectorize w) = m + /\ (all (fun w' => size w' = n) (vectorize w)). + +-> Keeping information in application? Yes + -> should provide a syntax for giving the arguments + + {w : word 256} + + vectorize<:int, n = 4> w ==> infer: m = 64 + +-> What to do when the inference fails + 1. we reject (most likely) + 2. we open a goal + +-> In a proof script (apply: foo) or (rewrite foo) + 1. inference des dépendances (n, m, ...) + 2. décharger les conditions de bord (size w1 = n, size w2 = n) + +-> Goal + n : int + m : int + w1 : {word n} + w2 : {word m} + ==================================================================== + E[xor (cat w1 w2) (cat w2 w1)] + + rewrite foo + + n : int + m : int + w1 : {word n} + w2 : {word m} + ==================================================================== + E[xor (cat w2 w1) (cat w1 w2)] + + under condition: + exists p . size (cat w1 w2) = p /\ size (cat w2 w1) = p. + + ?p = size (cat w1 w2) + ?p = size (cat w2 w1) + +-> can be solved using a extended prolog-like engine + 1. declarations of variables (w1 : {word n}) (w2 : {word m}) + 2. prolog-like facts from operators types (-> ELPI) + 3. theories (ring / int) + +-> subtypes in procedures + + We can only depend on operators / constants. I.e. the following + program should be rejected: + + module M = { + var n : int + + proc f(x : {bool word M.n}) = { + } + } + + Question: + - What about dependent types in the type for results: + we reject programs if we cannot statically check the condition + - What about the logics? we have to patch them. + +(* ==================================================================== *) +all : 'a t * 'a -> bool + +axiom all_spec ['a] : forall (f : 'a t -> 'a) (s : 'a t), all (s, f s). + +nth ['a] 'a -> 'a list -> int -> 'a + +lemma nth_spec ['a] (x : 'a) (s : 'a list) (i : int) : + forall P, + (forall y, all<: 'a> (y, x) -> P y) -> + P x -> (forall y, all<: 'a list> (s, y) -> P y) -> P (nth x s i). + +ws : {word n} list + +nth<:word> witness ws 2 : word +nth<:{word n}> + +coercion : 'a word n -> 'a list