Skip to content

Commit 2754b01

Browse files
committed
DistrMatrix
1 parent a333999 commit 2754b01

File tree

1 file changed

+369
-0
lines changed

1 file changed

+369
-0
lines changed

examples/distrmatrix.ec

Lines changed: 369 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,369 @@
1+
(* -------------------------------------------------------------------- *)
2+
require import AllCore List Distr DList Number StdOrder StdBigop.
3+
require import RealSeries.
4+
require (*--*) DynMatrix.
5+
(*---*) import IntOrder RealOrder RField Bigint Bigreal.
6+
7+
(* -------------------------------------------------------------------- *)
8+
clone import DynMatrix as DM.
9+
(*-*) import DM.ZR.
10+
11+
(* -------------------------------------------------------------------- *)
12+
abbrev "_.[_]" ['a] (xs : 'a list) (i : int) = nth<:'a> witness xs i.
13+
14+
(* -------------------------------------------------------------------- *)
15+
lemma compE ['a 'b 'c] (f : 'a -> 'b) (g : 'b -> 'c) (x : 'a) :
16+
(g \o f) x = g (f x).
17+
proof. done. qed.
18+
19+
hint simplify compE.
20+
21+
(* -------------------------------------------------------------------- *)
22+
lemma dlist_ubound (n : int) (d : R distr) (E : R -> bool) : 0 <= n =>
23+
mu
24+
(dlist d n)
25+
(fun xs => exists i, 0 <= i < n /\ E xs.[i])
26+
<= n%r * mu d E.
27+
proof.
28+
elim: n => /= [|n ge0_n ih]; first by rewrite dlist0 // dunitE //#.
29+
rewrite dlistS //= dmapE /(\o) /=.
30+
pose P1 (x : R) := E x.
31+
pose P2 (xs : R list) := exists i, (0 <= i < n /\ E xs.[i]).
32+
pose P (x_xs : R * R list) := P1 x_xs.`1 \/ P2 x_xs.`2.
33+
rewrite (mu_eq_support _ _ P).
34+
- case=> [x xs] /supp_dprod /= [_].
35+
case/(supp_dlist _ _ _ ge0_n) => [sz_xs _].
36+
rewrite /P /=; apply/eq_iff; split; first smt().
37+
case=> [Ex|]; first exists 0; smt().
38+
by case=> i rg_i; exists (i+1) => //#.
39+
apply: (ler_trans _ _ _ (le_dprod_or _ _ _ _)).
40+
rewrite fromintD mulrDl /= addrC ler_add.
41+
- by apply: (ler_trans _ _ _ (ler_pimulr _ _ _ _)).
42+
- by apply: (ler_trans _ _ _ (ler_pimulr _ _ _ _)).
43+
qed.
44+
45+
(* -------------------------------------------------------------------- *)
46+
lemma L ['a 'b 'c 'd 'e 'ab 'ac 'bd 'cd]
47+
(da : 'a distr) (db : 'b distr) (dc : 'c distr) (dd : 'd distr)
48+
(Fab : 'a * 'b -> 'ab)
49+
(Fcd : 'c * 'd -> 'cd)
50+
(F : 'ab -> 'cd -> 'e)
51+
(Fac : 'a * 'c -> 'ac)
52+
(Fbd : 'b * 'd -> 'bd)
53+
(G : 'ac -> 'bd -> 'e)
54+
:
55+
(forall a b c d, F (Fab (a, b)) (Fcd (c, d)) = G (Fac (a, c)) (Fbd (b, d))) =>
56+
57+
dlet
58+
(dmap (da `*` db) Fab)
59+
(fun ab =>
60+
dmap
61+
(dmap (dc `*` dd) Fcd)
62+
(fun cd => F ab cd))
63+
= dlet
64+
(dmap (da `*` dc) Fac)
65+
(fun ac =>
66+
dmap
67+
(dmap (db `*` dd) Fbd)
68+
(fun bd => G ac bd)).
69+
proof.
70+
pose D1 := dlet (da `*` db)
71+
(fun ab => dlet dc (fun c => dmap dd (fun d => F (Fab ab) (Fcd (c, d))))).
72+
move=> eq; rewrite dlet_dmap /= &(eq_trans _ D1) /D1 => {D1}.
73+
- by rewrite &(eq_dlet) // => ab /=; rewrite dmap_comp dmap_dprodE.
74+
pose D2 := dlet (da `*` dc)
75+
(fun ac => dlet db (fun b => dmap dd (fun d => G (Fac ac) (Fbd (b, d))))).
76+
rewrite dlet_dmap /= &(eq_trans _ D2) /D2 => {D2}; last first.
77+
- by rewrite &(eq_dlet) // => ac /=; rewrite dmap_comp dmap_dprodE.
78+
rewrite !dprod_dlet !dlet_dlet /= &(eq_dlet) // => a /=.
79+
rewrite dlet_dlet /= dlet_swap &(eq_dlet) // => b /=.
80+
rewrite 2!(dlet_dunit, dlet_unit) /= dlet_dmap.
81+
rewrite &(eq_dlet) // => c /=; rewrite &(eq_dmap) // => d /=.
82+
by apply: eq.
83+
qed.
84+
85+
(* -------------------------------------------------------------------- *)
86+
lemma L2 ['a 'b 'c 'd 'e]
87+
(da : 'a distr) (db : 'b distr) (dc : 'c distr) (dd : 'd distr)
88+
(F : 'a -> 'b -> 'c -> 'd -> 'e)
89+
:
90+
dlet
91+
(da `*` db)
92+
(fun ab : 'a * 'b =>
93+
dmap
94+
(dc `*` dd)
95+
(fun cd : 'c * 'd => F ab.`1 ab.`2 cd.`1 cd.`2))
96+
= dlet
97+
(da `*` dc)
98+
(fun ac : 'a * 'c =>
99+
dmap
100+
(db `*` dd)
101+
(fun bd : 'b * 'd => F ac.`1 bd.`1 ac.`2 bd.`2)).
102+
proof.
103+
pose F1 (ab : 'a * 'b) (cd : 'c * 'd) := F ab.`1 ab.`2 cd.`1 cd.`2.
104+
pose F2 (ac : 'a * 'c) (bd : 'b * 'd) := F ac.`1 bd.`1 ac.`2 bd.`2.
105+
have := L da db dc dd idfun idfun F1 idfun idfun F2 _; first done.
106+
by rewrite !dmap_id.
107+
qed.
108+
109+
(* -------------------------------------------------------------------- *)
110+
lemma dprod_dunit ['a 'b] (x : 'a) (y : 'b) :
111+
dunit x `*` dunit y = dunit (x, y).
112+
proof.
113+
by apply: eq_distr => -[a b]; rewrite dprod1E !dunit1E /#.
114+
qed.
115+
116+
(* -------------------------------------------------------------------- *)
117+
op dadd (d1 d2 : R distr) =
118+
dmap (d1 `*` d2) (fun xy : R * R => xy.`1 + xy.`2).
119+
120+
(* -------------------------------------------------------------------- *)
121+
lemma dlistD (n : int) (d1 d2 : R distr) : 0 <= n =>
122+
dlet (dlist d1 n) (fun (xs : R list) =>
123+
dmap (dlist d2 n) (fun (ys : R list) =>
124+
mkseq (fun i => xs.[i] + ys.[i]) n))
125+
= dlist (dadd d1 d2) n.
126+
proof.
127+
pose S n (xs ys : R list) := mkseq (fun i => xs.[i] + ys.[i]) n.
128+
pose T n (xs : R list * R list) := S n xs.`1 xs.`2.
129+
move=> ge0_n; rewrite -(dmap_dprodE _ _ (T n)). (* SLOW *)
130+
elim: n ge0_n => /= [|n ge0_n ih]; last rewrite !dlistS //.
131+
- by rewrite !dlist0 // dprod_dunit dmap_dunit /T /S /= mkseq0.
132+
pose C (x_xs : R * R list) := x_xs.`1 :: x_xs.`2.
133+
pose F (x : R * R) (xs : R list * R list) :=
134+
S (n+1) (x.`1 :: xs.`1) (x.`2 :: xs.`2).
135+
pose G (xs ys : R list) := S (n+1) xs ys.
136+
rewrite dmap_dprodE; have -> :=
137+
L d1 (dlist d1 n) d2 (dlist d2 n) C C G idfun idfun F _; first by done.
138+
rewrite !dmap_id /= dmap_dprodE {1}/dadd dlet_dmap.
139+
apply/eq_dlet => // -[x y] /=; rewrite -ih.
140+
rewrite dmap_comp &(eq_dmap) => -[xs ys] /=.
141+
by rewrite /F /S /C /T /= mkseqSr //= &(eq_in_mkseq) //#.
142+
qed.
143+
144+
(* -------------------------------------------------------------------- *)
145+
lemma dlist_dlist ['a] (d : 'a distr) (m n : int) :
146+
0 <= m => 0 <= n =>
147+
dmap (dlist (dlist d m) n) flatten = dlist d (m * n).
148+
proof.
149+
move=> ge0_m; elim: n => /= [|n ge0_n ih].
150+
- by rewrite !dlist0 // dmap_dunit.
151+
rewrite mulrDr /= [dlist d (m * n + m)]dlist_add //.
152+
- by apply: IntOrder.mulr_ge0.
153+
rewrite dlistSr //= dmap_comp !dmap_dprodE /=.
154+
rewrite -ih dlet_dmap /= &(eq_dlet) // => xss /=.
155+
by rewrite &(eq_dmap) => xs /=; rewrite flatten_rcons.
156+
qed.
157+
158+
(* -------------------------------------------------------------------- *)
159+
lemma dmatrix_dlist (r c : int) (d : R distr) :
160+
0 <= r => 0 <= c => dmatrix d r c =
161+
dmap
162+
(dlist d (r * c))
163+
(fun vs => offunm ((fun i j => vs.[j * r + i]), r, c)).
164+
proof.
165+
move=> ge0_r ge0_c @/dmatrix @/dvector.
166+
rewrite dlist_dmap dmap_comp !lez_maxr //.
167+
rewrite -dlist_dlist // dmap_comp &(eq_dmap_in) => xss /=.
168+
case/(supp_dlist _ _ _ ge0_c) => size_xss /allP xssE.
169+
have {xssE} xssE: forall xs, xs \in xss => size xs = r.
170+
- by move=> xs /xssE /(supp_dlist _ _ _ ge0_r).
171+
apply/eq_matrixP=> @/ofcols /= i j [].
172+
rewrite !lez_maxr // => rgi rgj.
173+
rewrite !get_offunm /= ?lez_maxr //.
174+
rewrite (nth_map witness) 1:/#.
175+
rewrite (get_oflist witness) 1:#smt:(mem_nth).
176+
rewrite -nth_flatten ~-1:#smt:(mem_nth); do 2! congr.
177+
rewrite sumzE BIA.big_map predT_comp /(\o) /=.
178+
pose D := BIA.big predT (fun _ => r) (take j xss).
179+
apply: (eq_trans _ D) => @/D.
180+
- rewrite !BIA.big_seq &(BIA.eq_bigr) //=.
181+
by move=> xs /mem_take /xssE.
182+
by rewrite big_constz count_predT size_take //#.
183+
qed.
184+
185+
(* -------------------------------------------------------------------- *)
186+
lemma dmatrixD (r c : int) (d1 d2 : R distr) : 0 <= r => 0 <= c =>
187+
dlet (dmatrix d1 r c) (fun (m1 : matrix) =>
188+
dmap (dmatrix d2 r c) (fun (m2 : matrix) => m1 + m2))
189+
= dmatrix (dadd d1 d2) r c.
190+
proof.
191+
move=> ge0_r ge0_c; rewrite 2?dmatrix_dlist //=.
192+
pose F vs := offunm (fun i j => vs.[j * r + i], r, c).
193+
rewrite dlet_dmap /= dlet_swap dlet_dmap /= dlet_swap /=.
194+
rewrite dmatrix_dlist // -/F -dlistD ~-1:/#.
195+
rewrite dmap_dlet &(eq_dlet) // => xs /=.
196+
rewrite dlet_dunit dmap_comp &(eq_dmap) => ys /=.
197+
apply/eq_matrixP; split.
198+
- by rewrite /F size_addm !size_offunm.
199+
move=> i j []; rewrite rows_addm cols_addm /=.
200+
rewrite !rows_offunm !cols_offunm !maxzz => rgi rgj.
201+
by rewrite get_addm !get_offunm //= nth_mkseq //#.
202+
qed.
203+
204+
(* -------------------------------------------------------------------- *)
205+
op dmul (n : int) (d1 d2 : R distr) =
206+
dmap
207+
(dlist d1 n `*` dlist d2 n)
208+
(fun vs : R list * R list =>
209+
DM.Big.BAdd.big predT
210+
(fun xy : R * R => xy.`1 * xy.`2)
211+
(zip vs.`1 vs.`2)).
212+
213+
(* -------------------------------------------------------------------- *)
214+
lemma foo ['a 'b 'c] (da : 'a distr) (db : 'b distr) (f : 'a -> 'c) :
215+
dmap (da `*` db) (fun ab : 'a * 'b => f ab.`1)
216+
= weight db \cdot dmap da f.
217+
proof. by rewrite dmap_dprodE_swap /= dlet_cst_weight. qed.
218+
219+
(* -------------------------------------------------------------------- *)
220+
lemma dmatrix_cols (d : R distr) (r c : int) : 0 <= c => 0 <= r =>
221+
dmatrix d r c = dmap (dlist (dvector d r) c) (ofcols r c).
222+
proof. by move=> ge0_c ge0_r @/dmatrix; rewrite lez_maxr. qed.
223+
224+
(* -------------------------------------------------------------------- *)
225+
lemma dmatrix_rows (d : R distr) (r c : int) : 0 <= c => 0 <= r =>
226+
dmatrix d r c = dmap (dlist (dvector d c) r) (trmx \o ofcols c r).
227+
proof.
228+
move=> ge0_r ge0_c; rewrite -dmap_comp -dmatrix_cols //.
229+
apply/eq_distr => /= m; rewrite (dmap1E _ trmx).
230+
have ->: pred1 m \o trmx = pred1 (trmx m) by smt(trmxK).
231+
case: (size m = (r, c)); last first.
232+
- by move=> ne_size; rewrite !dmatrix0E //#.
233+
case=> <<- <<-; rewrite -{2}rows_tr -{2}cols_tr !dmatrix1E /=.
234+
by rewrite BRM.exchange_big.
235+
qed.
236+
237+
(* -------------------------------------------------------------------- *)
238+
hint simplify drop0, take0, cats0, cat0s.
239+
240+
(* -------------------------------------------------------------------- *)
241+
(* FIXME: refactor *)
242+
243+
lemma dlist_insert ['a] (i n : int) (d : 'a distr) :
244+
0 <= n => 0 <= i <= n => dlist d (n+1) =
245+
dmap (d `*` dlist d n) (fun x_xs : 'a * 'a list => insert x_xs.`1 x_xs.`2 i).
246+
proof.
247+
move=> ge0_n [ge0_i lti]; apply/eq_sym.
248+
pose f (x_xs : _ * _) := insert x_xs.`1 x_xs.`2 i.
249+
pose g (xs : 'a list) := (xs.[i], take i xs ++ drop (i+1) xs).
250+
have ge0_Sn: 0 <= n + 1 by smt(). apply: (dmap_bij _ _ f g).
251+
- case=> [x xs] /supp_dprod[/=] x_in_d.
252+
case/(supp_dlist _ _ _ ge0_n)=> sz_xs /allP xs_in_d.
253+
move=> @/f /=; apply/supp_dlist; first smt().
254+
rewrite size_insert ?sz_xs //=; apply/allP.
255+
by move=> y /mem_insert[->>//|/xs_in_d].
256+
- move=> xs /(supp_dlist _ _ _ ge0_Sn)[sz_xs /allP xs_in_d] @/g.
257+
rewrite dprod1E !dlist1E ~-1://# sz_xs /=.
258+
rewrite size_cat size_take // size_drop 1:/#.
259+
rewrite iftrue 1:/# -(BRM.big_consT (mu1 d)) &(BRM.eq_big_perm).
260+
by rewrite -cat_cons perm_eq_sym &(perm_eq_nth_take_drop) //#.
261+
- case=> x xs /supp_dprod[/=] _ /(supp_dlist _ _ _ ge0_n)[sz_xs _].
262+
rewrite /g /f /= nth_insert ?sz_xs //= take_insert_le 1:/#.
263+
by rewrite drop_insert_gt 1:/# /= cat_take_drop.
264+
- move=> xs /(supp_dlist _ _ _ ge0_Sn)[/=] sz_xs _ @/f @/g /=.
265+
have sz_take: size (take i xs) = i by rewrite size_take //#.
266+
by apply/insert_nth_take_drop => //#.
267+
qed.
268+
269+
hint simplify insert0.
270+
271+
(* -------------------------------------------------------------------- *)
272+
lemma dmatrix_cols_i (i : int) (d : R distr) (r c : int) :
273+
0 <= c => 0 <= r => 0 <= i < c =>
274+
dmatrix d r c =
275+
dmap
276+
(dvector d r `*` dlist (dvector d r) (c-1))
277+
(fun c_cs : _ * _ => ofcols r c (insert c_cs.`1 c_cs.`2 i)).
278+
proof.
279+
move=> ge0_c ge0_r rgi; rewrite dmatrix_cols //.
280+
rewrite {1}(_ : c = (c - 1) + 1) // (dlist_insert i) ~-1://# /=.
281+
by rewrite dmap_comp &(eq_dmap) => -[v vs].
282+
qed.
283+
284+
(* -------------------------------------------------------------------- *)
285+
lemma dmatrix_rows_i (j : int) (d : R distr) (r c : int) :
286+
0 <= c => 0 <= r => 0 <= j < r =>
287+
dmatrix d r c =
288+
dmap
289+
(dvector d c `*` dlist (dvector d c) (r-1))
290+
(fun r_rs : _ * _ => trmx (ofcols c r (insert r_rs.`1 r_rs.`2 j))).
291+
proof.
292+
move=> ge0_c ge0_r rgj; rewrite dmatrix_rows //.
293+
rewrite {1}(_ : r = (r - 1) + 1) // (dlist_insert j) ~-1://# /=.
294+
by rewrite dmap_comp &(eq_dmap) => -[v vs].
295+
qed.
296+
297+
(* -------------------------------------------------------------------- *)
298+
lemma col_ofcols (i r c : int) (vs : vector list) :
299+
0 <= r => 0 <= c => 0 <= i < c
300+
=> size vs = c
301+
=> all (fun v : vector => size v = r) vs
302+
=> col (ofcols r c vs) i = vs.[i].
303+
proof.
304+
move=> ge0_r ge0_c rgi sz_vs /allP => sz_in_vs.
305+
have sz_rows: rows (ofcols r c vs) = r.
306+
- by rewrite rows_offunm lez_maxr // sz_in_vs.
307+
apply/eq_vectorP; split=> /=.
308+
- by rewrite sz_rows sz_in_vs // &(mem_nth) sz_vs.
309+
by move=> j; rewrite sz_rows => rgj; rewrite get_offunm //#.
310+
qed.
311+
312+
(* -------------------------------------------------------------------- *)
313+
lemma dmatrixM (m n p : int) (d1 d2 : R distr) :
314+
0 <= m => 0 <= n => 0 <= p =>
315+
316+
let d =
317+
dlet (dmatrix d1 m n) (fun (m1 : matrix) =>
318+
dmap (dmatrix d2 n p) (fun (m2 : matrix) => m1 * m2)) in
319+
320+
forall i j, 0 <= i < m => 0 <= j < p =>
321+
dmap d (fun m => m.[i, j]) =
322+
((weight d1) ^ (n * (m-1)) * (weight d2) ^ (n * (p-1))) \cdot dmul n d1 d2.
323+
proof.
324+
move=> ge0_m ge0_n ge0_p d i j rg_i rg_j.
325+
have [gt0_m gt0_p]: (0 <= m-1) /\ (0 <= p-1) by smt().
326+
rewrite /d (dmatrix_rows_i i) //= (dmatrix_cols_i j) //=.
327+
pose D1 := dvector d1 n `*` _; pose D2 := dvector d2 n `*` _.
328+
pose F1 := fun (r_rs : _ * _) => trmx (ofcols n m (insert r_rs.`1 r_rs.`2 i)).
329+
pose F2 := fun (c_cs : _ * _) => ofcols n p (insert c_cs.`1 c_cs.`2 j).
330+
pose F r rs c cs := (trmx (ofcols n m (insert r rs i)) * ofcols n p (insert c cs j)).[i, j].
331+
pose D := dlet D1 (fun c : _ * _ => dmap D2 (fun r : _ * _ => F c.`1 c.`2 r.`1 r.`2)).
332+
apply: (eq_trans _ D) => @/D => {D}.
333+
- rewrite dmap_dlet dlet_dmap /= &(eq_dlet) // => ? /=.
334+
by rewrite 2!dmap_comp &(eq_dmap).
335+
pose G (x_xs : (_ * _) * (_ * _)) := F x_xs.`1.`1 x_xs.`2.`1 x_xs.`1.`2 x_xs.`2.`2.
336+
rewrite L2 /= => {D1 D2}; pose D1 := _ `*` _; pose D2 := _ `*` _.
337+
have @/G /= <- := dmap_dprodE D1 D2 G => {G}.
338+
pose G (vs : vector * vector) := dotp vs.`1 vs.`2.
339+
apply: (eq_trans _ (dmap (D1 `*` D2) (fun x : _ * _ => G x.`1))).
340+
- apply: eq_dmap_in=> -[[c r] [cs rs]] @/G @/F /=.
341+
case/supp_dprod=> /= /supp_dprod[/=].
342+
case/(supp_dvector _ _ _ ge0_n) => sz_c _.
343+
case/(supp_dvector _ _ _ ge0_n) => sz_r _.
344+
move/supp_dprod=> [/=].
345+
case/(supp_dlist _ _ _ gt0_m) => [sz_cs /allP sz_in_cs].
346+
case/(supp_dlist _ _ _ gt0_p) => [sz_rs /allP sz_in_rs].
347+
rewrite get_mulmx row_trmx /= !col_ofcols //.
348+
- by rewrite size_insert ?sz_cs //#.
349+
- apply/allP=> v /mem_insert [->>|] //=.
350+
by move/sz_in_cs => /(supp_dvector _ _ _ ge0_n).
351+
- by rewrite size_insert ?sz_rs //#.
352+
- apply/allP=> v /mem_insert [->>|] //=.
353+
by move/sz_in_rs => /(supp_dvector _ _ _ ge0_n).
354+
by rewrite !nth_insert // (sz_cs, sz_rs) //#.
355+
rewrite foo /D2 weight_dprod !weight_dlist // !weight_dmap.
356+
rewrite !weight_dlist ?lez_maxr // -!exprM.
357+
congr=> @/D1 @/G => {D1 D2 G} @/dmul.
358+
rewrite !dmap_dprodE /= dlet_dmap lez_maxr //.
359+
apply/in_eq_dlet => //= xs /(supp_dlist _ _ _ ge0_n)[sz_xs _].
360+
rewrite dmap_comp lez_maxr //; apply/eq_dmap_in => /= ys.
361+
case/(supp_dlist _ _ _ ge0_n)=> sz_ys _ @/dotp.
362+
rewrite !size_oflist sz_xs sz_ys lez_maxr //.
363+
apply/eq_sym; rewrite (Big.BAdd.big_nth witness) predT_comp.
364+
rewrite size_zip sz_xs sz_ys lez_minr //.
365+
rewrite !Big.BAdd.big_seq /= &(Big.BAdd.eq_bigr) /=.
366+
move=> k /mem_range rg_k; rewrite !(get_oflist witness) ~-1://#.
367+
have := nth_zip witness witness xs ys k _; first by smt().
368+
by rewrite (nth_change_dfl witness) => [|->//]; rewrite size_zip /#.
369+
qed.

0 commit comments

Comments
 (0)