Skip to content

Commit 4d79300

Browse files
committed
DistrMatrix
1 parent fa96e87 commit 4d79300

File tree

1 file changed

+265
-0
lines changed

1 file changed

+265
-0
lines changed

examples/distrmatrix.ec

Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
(* -------------------------------------------------------------------- *)
2+
require import AllCore List Distr DList Number StdOrder StdBigop.
3+
require import RealSeries.
4+
require (*--*) DynMatrix.
5+
(*---*) import IntOrder RealOrder RField Bigint Bigreal.
6+
7+
(* -------------------------------------------------------------------- *)
8+
clone import DynMatrix as DM.
9+
(*-*) import DM.ZR.
10+
11+
(* -------------------------------------------------------------------- *)
12+
abbrev "_.[_]" ['a] (xs : 'a list) (i : int) = nth<:'a> witness xs i.
13+
14+
(* -------------------------------------------------------------------- *)
15+
lemma compE ['a 'b 'c] (f : 'a -> 'b) (g : 'b -> 'c) (x : 'a) :
16+
(g \o f) x = g (f x).
17+
proof. done. qed.
18+
19+
hint simplify compE.
20+
21+
(* -------------------------------------------------------------------- *)
22+
lemma dlist_ubound (n : int) (d : R distr) (E : R -> bool) : 0 <= n =>
23+
mu
24+
(dlist d n)
25+
(fun xs => exists i, 0 <= i < n /\ E xs.[i])
26+
<= n%r * mu d E.
27+
proof.
28+
elim: n => /= [|n ge0_n ih]; first by rewrite dlist0 // dunitE //#.
29+
rewrite dlistS //= dmapE /(\o) /=.
30+
pose P1 (x : R) := E x.
31+
pose P2 (xs : R list) := exists i, (0 <= i < n /\ E xs.[i]).
32+
pose P (x_xs : R * R list) := P1 x_xs.`1 \/ P2 x_xs.`2.
33+
rewrite (mu_eq_support _ _ P).
34+
- case=> [x xs] /supp_dprod /= [_].
35+
case/(supp_dlist _ _ _ ge0_n) => [sz_xs _].
36+
rewrite /P /=; apply/eq_iff; split; first smt().
37+
case=> [Ex|]; first exists 0; smt().
38+
by case=> i rg_i; exists (i+1) => //#.
39+
apply: (ler_trans _ _ _ (le_dprod_or _ _ _ _)).
40+
rewrite fromintD mulrDl /= addrC ler_add.
41+
- by apply: (ler_trans _ _ _ (ler_pimulr _ _ _ _)).
42+
- by apply: (ler_trans _ _ _ (ler_pimulr _ _ _ _)).
43+
qed.
44+
45+
(* -------------------------------------------------------------------- *)
46+
op dadd (d1 d2 : R distr) =
47+
dmap (d1 `*` d2) (fun xy : R * R => xy.`1 + xy.`2).
48+
49+
(* -------------------------------------------------------------------- *)
50+
lemma dlistD (n : int) (d1 d2 : R distr) : 0 <= n =>
51+
dlet (dlist d1 n) (fun (xs : R list) =>
52+
dmap (dlist d2 n) (fun (ys : R list) =>
53+
mkseq (fun i => xs.[i] + ys.[i]) n))
54+
= dlist (dadd d1 d2) n.
55+
proof.
56+
pose S n (xs ys : R list) := mkseq (fun i => xs.[i] + ys.[i]) n.
57+
pose T n (xs : R list * R list) := S n xs.`1 xs.`2.
58+
move=> ge0_n; rewrite -(dmap_dprodE _ _ (T n)). (* SLOW *)
59+
elim: n ge0_n => /= [|n ge0_n ih]; last rewrite !dlistS //.
60+
- by rewrite !dlist0 // dprod_dunit dmap_dunit /T /S /= mkseq0.
61+
pose C (x_xs : R * R list) := x_xs.`1 :: x_xs.`2.
62+
pose F (x : R * R) (xs : R list * R list) :=
63+
S (n+1) (x.`1 :: xs.`1) (x.`2 :: xs.`2).
64+
pose G (xs ys : R list) := S (n+1) xs ys.
65+
rewrite dmap_dprodE; have -> := dprod_dmap_cross
66+
d1 (dlist d1 n) d2 (dlist d2 n) C C G idfun idfun F _; first by done.
67+
rewrite !dmap_id /= dmap_dprodE {1}/dadd dlet_dmap.
68+
apply/eq_dlet => // -[x y] /=; rewrite -ih.
69+
rewrite dmap_comp &(eq_dmap) => -[xs ys] /=.
70+
by rewrite /F /S /C /T /= mkseqSr //= &(eq_in_mkseq) //#.
71+
qed.
72+
73+
(* -------------------------------------------------------------------- *)
74+
lemma dmatrix_dlist (r c : int) (d : R distr) :
75+
0 <= r => 0 <= c => dmatrix d r c =
76+
dmap
77+
(dlist d (r * c))
78+
(fun vs => offunm ((fun i j => vs.[j * r + i]), r, c)).
79+
proof.
80+
move=> ge0_r ge0_c @/dmatrix @/dvector.
81+
rewrite dlist_dmap dmap_comp !lez_maxr //.
82+
rewrite -dlist_dlist // dmap_comp &(eq_dmap_in) => xss /=.
83+
case/(supp_dlist _ _ _ ge0_c) => size_xss /allP xssE.
84+
have {xssE} xssE: forall xs, xs \in xss => size xs = r.
85+
- by move=> xs /xssE /(supp_dlist _ _ _ ge0_r).
86+
apply/eq_matrixP=> @/ofcols /= i j [].
87+
rewrite !lez_maxr // => rgi rgj.
88+
rewrite !get_offunm /= ?lez_maxr //.
89+
rewrite (nth_map witness) 1:/#.
90+
rewrite (get_oflist witness) 1:#smt:(mem_nth).
91+
rewrite -nth_flatten ~-1:#smt:(mem_nth); do 2! congr.
92+
rewrite sumzE BIA.big_map predT_comp /(\o) /=.
93+
pose D := BIA.big predT (fun _ => r) (take j xss).
94+
apply: (eq_trans _ D) => @/D.
95+
- rewrite !BIA.big_seq &(BIA.eq_bigr) //=.
96+
by move=> xs /mem_take /xssE.
97+
by rewrite big_constz count_predT size_take //#.
98+
qed.
99+
100+
(* -------------------------------------------------------------------- *)
101+
lemma dmatrixD (r c : int) (d1 d2 : R distr) : 0 <= r => 0 <= c =>
102+
dlet (dmatrix d1 r c) (fun (m1 : matrix) =>
103+
dmap (dmatrix d2 r c) (fun (m2 : matrix) => m1 + m2))
104+
= dmatrix (dadd d1 d2) r c.
105+
proof.
106+
move=> ge0_r ge0_c; rewrite 2?dmatrix_dlist //=.
107+
pose F vs := offunm (fun i j => vs.[j * r + i], r, c).
108+
rewrite dlet_dmap /= dlet_swap dlet_dmap /= dlet_swap /=.
109+
rewrite dmatrix_dlist // -/F -dlistD ~-1:/#.
110+
rewrite dmap_dlet &(eq_dlet) // => xs /=.
111+
rewrite dlet_dunit dmap_comp &(eq_dmap) => ys /=.
112+
apply/eq_matrixP; split.
113+
- by rewrite /F size_addm !size_offunm.
114+
move=> i j []; rewrite rows_addm cols_addm /=.
115+
rewrite !rows_offunm !cols_offunm !maxzz => rgi rgj.
116+
by rewrite get_addm !get_offunm //= nth_mkseq //#.
117+
qed.
118+
119+
(* -------------------------------------------------------------------- *)
120+
op dmul (n : int) (d1 d2 : R distr) =
121+
dmap
122+
(dlist d1 n `*` dlist d2 n)
123+
(fun vs : R list * R list =>
124+
DM.Big.BAdd.big predT
125+
(fun xy : R * R => xy.`1 * xy.`2)
126+
(zip vs.`1 vs.`2)).
127+
128+
(* -------------------------------------------------------------------- *)
129+
lemma dmatrix_cols (d : R distr) (r c : int) : 0 <= c => 0 <= r =>
130+
dmatrix d r c = dmap (dlist (dvector d r) c) (ofcols r c).
131+
proof. by move=> ge0_c ge0_r @/dmatrix; rewrite lez_maxr. qed.
132+
133+
(* -------------------------------------------------------------------- *)
134+
lemma dmatrix_rows (d : R distr) (r c : int) : 0 <= c => 0 <= r =>
135+
dmatrix d r c = dmap (dlist (dvector d c) r) (trmx \o ofcols c r).
136+
proof.
137+
move=> ge0_r ge0_c; rewrite -dmap_comp -dmatrix_cols //.
138+
apply/eq_distr => /= m; rewrite (dmap1E _ trmx).
139+
have ->: pred1 m \o trmx = pred1 (trmx m) by smt(trmxK).
140+
case: (size m = (r, c)); last first.
141+
- by move=> ne_size; rewrite !dmatrix0E //#.
142+
case=> <<- <<-; rewrite -{2}rows_tr -{2}cols_tr !dmatrix1E /=.
143+
by rewrite BRM.exchange_big.
144+
qed.
145+
146+
(* -------------------------------------------------------------------- *)
147+
hint simplify drop0, take0, cats0, cat0s.
148+
149+
(* -------------------------------------------------------------------- *)
150+
lemma dmatrix_cols_i (i : int) (d : R distr) (r c : int) :
151+
0 <= c => 0 <= r => 0 <= i < c =>
152+
dmatrix d r c =
153+
dmap
154+
(dvector d r `*` dlist (dvector d r) (c-1))
155+
(fun c_cs : _ * _ => ofcols r c (insert c_cs.`1 c_cs.`2 i)).
156+
proof.
157+
move=> ge0_c ge0_r rgi; rewrite dmatrix_cols //.
158+
rewrite {1}(_ : c = (c - 1) + 1) // (dlist_insert witness i) ~-1://# /=.
159+
by rewrite dmap_comp &(eq_dmap) => -[v vs].
160+
qed.
161+
162+
(* -------------------------------------------------------------------- *)
163+
lemma dmatrix_rows_i (j : int) (d : R distr) (r c : int) :
164+
0 <= c => 0 <= r => 0 <= j < r =>
165+
dmatrix d r c =
166+
dmap
167+
(dvector d c `*` dlist (dvector d c) (r-1))
168+
(fun r_rs : _ * _ => trmx (ofcols c r (insert r_rs.`1 r_rs.`2 j))).
169+
proof.
170+
move=> ge0_c ge0_r rgj; rewrite dmatrix_rows //.
171+
rewrite {1}(_ : r = (r - 1) + 1) // (dlist_insert witness j) ~-1://# /=.
172+
by rewrite dmap_comp &(eq_dmap) => -[v vs].
173+
qed.
174+
175+
(* -------------------------------------------------------------------- *)
176+
lemma col_ofcols (i r c : int) (vs : vector list) :
177+
0 <= r => 0 <= c => 0 <= i < c
178+
=> size vs = c
179+
=> all (fun v : vector => size v = r) vs
180+
=> col (ofcols r c vs) i = vs.[i].
181+
proof.
182+
move=> ge0_r ge0_c rgi sz_vs /allP => sz_in_vs.
183+
have sz_rows: rows (ofcols r c vs) = r.
184+
- by rewrite rows_offunm lez_maxr // sz_in_vs.
185+
apply/eq_vectorP; split=> /=.
186+
- by rewrite sz_rows sz_in_vs // &(mem_nth) sz_vs.
187+
by move=> j; rewrite sz_rows => rgj; rewrite get_offunm //#.
188+
qed.
189+
190+
(* -------------------------------------------------------------------- *)
191+
lemma dmatrixM (m n p : int) (d1 d2 : R distr) :
192+
0 <= m => 0 <= n => 0 <= p =>
193+
194+
let d =
195+
dlet (dmatrix d1 m n) (fun (m1 : matrix) =>
196+
dmap (dmatrix d2 n p) (fun (m2 : matrix) => m1 * m2)) in
197+
198+
forall i j, 0 <= i < m => 0 <= j < p =>
199+
dmap d (fun m => m.[i, j]) =
200+
((weight d1) ^ (n * (m-1)) * (weight d2) ^ (n * (p-1))) \cdot dmul n d1 d2.
201+
proof.
202+
move=> ge0_m ge0_n ge0_p d i j rg_i rg_j.
203+
have [gt0_m gt0_p]: (0 <= m-1) /\ (0 <= p-1) by smt().
204+
rewrite /d (dmatrix_rows_i i) //= (dmatrix_cols_i j) //=.
205+
pose D1 := dvector d1 n `*` _; pose D2 := dvector d2 n `*` _.
206+
pose F1 := fun (r_rs : _ * _) => trmx (ofcols n m (insert r_rs.`1 r_rs.`2 i)).
207+
pose F2 := fun (c_cs : _ * _) => ofcols n p (insert c_cs.`1 c_cs.`2 j).
208+
pose F r rs c cs := (trmx (ofcols n m (insert r rs i)) * ofcols n p (insert c cs j)).[i, j].
209+
pose D := dlet D1 (fun c : _ * _ => dmap D2 (fun r : _ * _ => F c.`1 c.`2 r.`1 r.`2)).
210+
apply: (eq_trans _ D) => @/D => {D}.
211+
- rewrite dmap_dlet dlet_dmap /= &(eq_dlet) // => ? /=.
212+
by rewrite 2!dmap_comp &(eq_dmap).
213+
pose G (x_xs : (_ * _) * (_ * _)) := F x_xs.`1.`1 x_xs.`2.`1 x_xs.`1.`2 x_xs.`2.`2.
214+
rewrite dprod_cross /= => {D1 D2}; pose D1 := _ `*` _; pose D2 := _ `*` _.
215+
have @/G /= <- := dmap_dprodE D1 D2 G => {G}.
216+
pose G (vs : vector * vector) := dotp vs.`1 vs.`2.
217+
apply: (eq_trans _ (dmap (D1 `*` D2) (fun x : _ * _ => G x.`1))).
218+
- apply: eq_dmap_in=> -[[c r] [cs rs]] @/G @/F /=.
219+
case/supp_dprod=> /= /supp_dprod[/=].
220+
case/(supp_dvector _ _ _ ge0_n) => sz_c _.
221+
case/(supp_dvector _ _ _ ge0_n) => sz_r _.
222+
move/supp_dprod=> [/=].
223+
case/(supp_dlist _ _ _ gt0_m) => [sz_cs /allP sz_in_cs].
224+
case/(supp_dlist _ _ _ gt0_p) => [sz_rs /allP sz_in_rs].
225+
rewrite get_mulmx row_trmx /= !col_ofcols //.
226+
- by rewrite size_insert ?sz_cs //#.
227+
- apply/allP=> v /mem_insert [->>|] //=.
228+
by move/sz_in_cs => /(supp_dvector _ _ _ ge0_n).
229+
- by rewrite size_insert ?sz_rs //#.
230+
- apply/allP=> v /mem_insert [->>|] //=.
231+
by move/sz_in_rs => /(supp_dvector _ _ _ ge0_n).
232+
by rewrite !nth_insert // (sz_cs, sz_rs) //#.
233+
rewrite dprod_marginalL /D2 weight_dprod !weight_dlist // !weight_dmap.
234+
rewrite !weight_dlist ?lez_maxr // -!exprM.
235+
congr=> @/D1 @/G => {D1 D2 G} @/dmul.
236+
rewrite !dmap_dprodE /= dlet_dmap lez_maxr //.
237+
apply/in_eq_dlet => //= xs /(supp_dlist _ _ _ ge0_n)[sz_xs _].
238+
rewrite dmap_comp lez_maxr //; apply/eq_dmap_in => /= ys.
239+
case/(supp_dlist _ _ _ ge0_n)=> sz_ys _ @/dotp.
240+
rewrite !size_oflist sz_xs sz_ys lez_maxr //.
241+
apply/eq_sym; rewrite (Big.BAdd.big_nth witness) predT_comp.
242+
rewrite size_zip sz_xs sz_ys lez_minr //.
243+
rewrite !Big.BAdd.big_seq /= &(Big.BAdd.eq_bigr) /=.
244+
move=> k /mem_range rg_k; rewrite !(get_oflist witness) ~-1://#.
245+
have := nth_zip witness witness xs ys k _; first by smt().
246+
by rewrite (nth_change_dfl witness) => [|->//]; rewrite size_zip /#.
247+
qed.
248+
249+
(* -------------------------------------------------------------------- *)
250+
lemma dmatrixM_ll (m n p : int) (d1 d2 : R distr) :
251+
0 <= m => 0 <= n => 0 <= p =>
252+
253+
is_lossless d1 => is_lossless d2 =>
254+
255+
let d =
256+
dlet (dmatrix d1 m n) (fun (m1 : matrix) =>
257+
dmap (dmatrix d2 n p) (fun (m2 : matrix) => m1 * m2)) in
258+
259+
forall i j, 0 <= i < m => 0 <= j < p =>
260+
dmap d (fun m => m.[i, j]) = dmul n d1 d2.
261+
proof.
262+
move=> *; rewrite dmatrixM //; pose c := (_ * _)%Real.
263+
rewrite (_ : c = 1%r) -1:dscalar1 // /c.
264+
by do 2! rewrite (_ : weight _ = 1%r) // expr1z.
265+
qed.

0 commit comments

Comments
 (0)