|
| 1 | +(* -------------------------------------------------------------------- *) |
| 2 | +require import AllCore List Distr DList Number StdOrder StdBigop. |
| 3 | +require import RealSeries. |
| 4 | +require (*--*) DynMatrix. |
| 5 | +(*---*) import IntOrder RealOrder RField Bigint Bigreal. |
| 6 | + |
| 7 | +(* -------------------------------------------------------------------- *) |
| 8 | +clone import DynMatrix as DM. |
| 9 | +(*-*) import DM.ZR. |
| 10 | + |
| 11 | +(* -------------------------------------------------------------------- *) |
| 12 | +abbrev "_.[_]" ['a] (xs : 'a list) (i : int) = nth<:'a> witness xs i. |
| 13 | + |
| 14 | +(* -------------------------------------------------------------------- *) |
| 15 | +lemma compE ['a 'b 'c] (f : 'a -> 'b) (g : 'b -> 'c) (x : 'a) : |
| 16 | + (g \o f) x = g (f x). |
| 17 | +proof. done. qed. |
| 18 | +
|
| 19 | +hint simplify compE. |
| 20 | +
|
| 21 | +(* -------------------------------------------------------------------- *) |
| 22 | +lemma dlist_ubound (n : int) (d : R distr) (E : R -> bool) : 0 <= n => |
| 23 | + mu |
| 24 | + (dlist d n) |
| 25 | + (fun xs => exists i, 0 <= i < n /\ E xs.[i]) |
| 26 | + <= n%r * mu d E. |
| 27 | +proof. |
| 28 | +elim: n => /= [|n ge0_n ih]; first by rewrite dlist0 // dunitE //#. |
| 29 | +rewrite dlistS //= dmapE /(\o) /=. |
| 30 | +pose P1 (x : R) := E x. |
| 31 | +pose P2 (xs : R list) := exists i, (0 <= i < n /\ E xs.[i]). |
| 32 | +pose P (x_xs : R * R list) := P1 x_xs.`1 \/ P2 x_xs.`2. |
| 33 | +rewrite (mu_eq_support _ _ P). |
| 34 | +- case=> [x xs] /supp_dprod /= [_]. |
| 35 | + case/(supp_dlist _ _ _ ge0_n) => [sz_xs _]. |
| 36 | + rewrite /P /=; apply/eq_iff; split; first smt(). |
| 37 | + case=> [Ex|]; first exists 0; smt(). |
| 38 | + by case=> i rg_i; exists (i+1) => //#. |
| 39 | +apply: (ler_trans _ _ _ (le_dprod_or _ _ _ _)). |
| 40 | +rewrite fromintD mulrDl /= addrC ler_add. |
| 41 | +- by apply: (ler_trans _ _ _ (ler_pimulr _ _ _ _)). |
| 42 | +- by apply: (ler_trans _ _ _ (ler_pimulr _ _ _ _)). |
| 43 | +qed. |
| 44 | +
|
| 45 | +(* -------------------------------------------------------------------- *) |
| 46 | +op dadd (d1 d2 : R distr) = |
| 47 | + dmap (d1 `*` d2) (fun xy : R * R => xy.`1 + xy.`2). |
| 48 | +
|
| 49 | +(* -------------------------------------------------------------------- *) |
| 50 | +lemma dlistD (n : int) (d1 d2 : R distr) : 0 <= n => |
| 51 | + dlet (dlist d1 n) (fun (xs : R list) => |
| 52 | + dmap (dlist d2 n) (fun (ys : R list) => |
| 53 | + mkseq (fun i => xs.[i] + ys.[i]) n)) |
| 54 | + = dlist (dadd d1 d2) n. |
| 55 | +proof. |
| 56 | +pose S n (xs ys : R list) := mkseq (fun i => xs.[i] + ys.[i]) n. |
| 57 | +pose T n (xs : R list * R list) := S n xs.`1 xs.`2. |
| 58 | +move=> ge0_n; rewrite -(dmap_dprodE _ _ (T n)). (* SLOW *) |
| 59 | +elim: n ge0_n => /= [|n ge0_n ih]; last rewrite !dlistS //. |
| 60 | +- by rewrite !dlist0 // dprod_dunit dmap_dunit /T /S /= mkseq0. |
| 61 | +pose C (x_xs : R * R list) := x_xs.`1 :: x_xs.`2. |
| 62 | +pose F (x : R * R) (xs : R list * R list) := |
| 63 | + S (n+1) (x.`1 :: xs.`1) (x.`2 :: xs.`2). |
| 64 | +pose G (xs ys : R list) := S (n+1) xs ys. |
| 65 | +rewrite dmap_dprodE; have -> := dprod_dmap_cross |
| 66 | + d1 (dlist d1 n) d2 (dlist d2 n) C C G idfun idfun F _; first by done. |
| 67 | +rewrite !dmap_id /= dmap_dprodE {1}/dadd dlet_dmap. |
| 68 | +apply/eq_dlet => // -[x y] /=; rewrite -ih. |
| 69 | +rewrite dmap_comp &(eq_dmap) => -[xs ys] /=. |
| 70 | +by rewrite /F /S /C /T /= mkseqSr //= &(eq_in_mkseq) //#. |
| 71 | +qed. |
| 72 | +
|
| 73 | +(* -------------------------------------------------------------------- *) |
| 74 | +lemma dmatrix_dlist (r c : int) (d : R distr) : |
| 75 | + 0 <= r => 0 <= c => dmatrix d r c = |
| 76 | + dmap |
| 77 | + (dlist d (r * c)) |
| 78 | + (fun vs => offunm ((fun i j => vs.[j * r + i]), r, c)). |
| 79 | +proof. |
| 80 | +move=> ge0_r ge0_c @/dmatrix @/dvector. |
| 81 | +rewrite dlist_dmap dmap_comp !lez_maxr //. |
| 82 | +rewrite -dlist_dlist // dmap_comp &(eq_dmap_in) => xss /=. |
| 83 | +case/(supp_dlist _ _ _ ge0_c) => size_xss /allP xssE. |
| 84 | +have {xssE} xssE: forall xs, xs \in xss => size xs = r. |
| 85 | +- by move=> xs /xssE /(supp_dlist _ _ _ ge0_r). |
| 86 | +apply/eq_matrixP=> @/ofcols /= i j []. |
| 87 | +rewrite !lez_maxr // => rgi rgj. |
| 88 | +rewrite !get_offunm /= ?lez_maxr //. |
| 89 | +rewrite (nth_map witness) 1:/#. |
| 90 | +rewrite (get_oflist witness) 1:#smt:(mem_nth). |
| 91 | +rewrite -nth_flatten ~-1:#smt:(mem_nth); do 2! congr. |
| 92 | +rewrite sumzE BIA.big_map predT_comp /(\o) /=. |
| 93 | +pose D := BIA.big predT (fun _ => r) (take j xss). |
| 94 | +apply: (eq_trans _ D) => @/D. |
| 95 | +- rewrite !BIA.big_seq &(BIA.eq_bigr) //=. |
| 96 | + by move=> xs /mem_take /xssE. |
| 97 | +by rewrite big_constz count_predT size_take //#. |
| 98 | +qed. |
| 99 | +
|
| 100 | +(* -------------------------------------------------------------------- *) |
| 101 | +lemma dmatrixD (r c : int) (d1 d2 : R distr) : 0 <= r => 0 <= c => |
| 102 | + dlet (dmatrix d1 r c) (fun (m1 : matrix) => |
| 103 | + dmap (dmatrix d2 r c) (fun (m2 : matrix) => m1 + m2)) |
| 104 | + = dmatrix (dadd d1 d2) r c. |
| 105 | +proof. |
| 106 | +move=> ge0_r ge0_c; rewrite 2?dmatrix_dlist //=. |
| 107 | +pose F vs := offunm (fun i j => vs.[j * r + i], r, c). |
| 108 | +rewrite dlet_dmap /= dlet_swap dlet_dmap /= dlet_swap /=. |
| 109 | +rewrite dmatrix_dlist // -/F -dlistD ~-1:/#. |
| 110 | +rewrite dmap_dlet &(eq_dlet) // => xs /=. |
| 111 | +rewrite dlet_dunit dmap_comp &(eq_dmap) => ys /=. |
| 112 | +apply/eq_matrixP; split. |
| 113 | +- by rewrite /F size_addm !size_offunm. |
| 114 | +move=> i j []; rewrite rows_addm cols_addm /=. |
| 115 | +rewrite !rows_offunm !cols_offunm !maxzz => rgi rgj. |
| 116 | +by rewrite get_addm !get_offunm //= nth_mkseq //#. |
| 117 | +qed. |
| 118 | +
|
| 119 | +(* -------------------------------------------------------------------- *) |
| 120 | +op dmul (n : int) (d1 d2 : R distr) = |
| 121 | + dmap |
| 122 | + (dlist d1 n `*` dlist d2 n) |
| 123 | + (fun vs : R list * R list => |
| 124 | + DM.Big.BAdd.big predT |
| 125 | + (fun xy : R * R => xy.`1 * xy.`2) |
| 126 | + (zip vs.`1 vs.`2)). |
| 127 | +
|
| 128 | +(* -------------------------------------------------------------------- *) |
| 129 | +lemma dmatrix_cols (d : R distr) (r c : int) : 0 <= c => 0 <= r => |
| 130 | + dmatrix d r c = dmap (dlist (dvector d r) c) (ofcols r c). |
| 131 | +proof. by move=> ge0_c ge0_r @/dmatrix; rewrite lez_maxr. qed. |
| 132 | +
|
| 133 | +(* -------------------------------------------------------------------- *) |
| 134 | +lemma dmatrix_rows (d : R distr) (r c : int) : 0 <= c => 0 <= r => |
| 135 | + dmatrix d r c = dmap (dlist (dvector d c) r) (trmx \o ofcols c r). |
| 136 | +proof. |
| 137 | +move=> ge0_r ge0_c; rewrite -dmap_comp -dmatrix_cols //. |
| 138 | +apply/eq_distr => /= m; rewrite (dmap1E _ trmx). |
| 139 | +have ->: pred1 m \o trmx = pred1 (trmx m) by smt(trmxK). |
| 140 | +case: (size m = (r, c)); last first. |
| 141 | +- by move=> ne_size; rewrite !dmatrix0E //#. |
| 142 | +case=> <<- <<-; rewrite -{2}rows_tr -{2}cols_tr !dmatrix1E /=. |
| 143 | +by rewrite BRM.exchange_big. |
| 144 | +qed. |
| 145 | +
|
| 146 | +(* -------------------------------------------------------------------- *) |
| 147 | +hint simplify drop0, take0, cats0, cat0s. |
| 148 | +
|
| 149 | +(* -------------------------------------------------------------------- *) |
| 150 | +lemma dmatrix_cols_i (i : int) (d : R distr) (r c : int) : |
| 151 | + 0 <= c => 0 <= r => 0 <= i < c => |
| 152 | + dmatrix d r c = |
| 153 | + dmap |
| 154 | + (dvector d r `*` dlist (dvector d r) (c-1)) |
| 155 | + (fun c_cs : _ * _ => ofcols r c (insert c_cs.`1 c_cs.`2 i)). |
| 156 | +proof. |
| 157 | +move=> ge0_c ge0_r rgi; rewrite dmatrix_cols //. |
| 158 | +rewrite {1}(_ : c = (c - 1) + 1) // (dlist_insert witness i) ~-1://# /=. |
| 159 | +by rewrite dmap_comp &(eq_dmap) => -[v vs]. |
| 160 | +qed. |
| 161 | +
|
| 162 | +(* -------------------------------------------------------------------- *) |
| 163 | +lemma dmatrix_rows_i (j : int) (d : R distr) (r c : int) : |
| 164 | + 0 <= c => 0 <= r => 0 <= j < r => |
| 165 | + dmatrix d r c = |
| 166 | + dmap |
| 167 | + (dvector d c `*` dlist (dvector d c) (r-1)) |
| 168 | + (fun r_rs : _ * _ => trmx (ofcols c r (insert r_rs.`1 r_rs.`2 j))). |
| 169 | +proof. |
| 170 | +move=> ge0_c ge0_r rgj; rewrite dmatrix_rows //. |
| 171 | +rewrite {1}(_ : r = (r - 1) + 1) // (dlist_insert witness j) ~-1://# /=. |
| 172 | +by rewrite dmap_comp &(eq_dmap) => -[v vs]. |
| 173 | +qed. |
| 174 | +
|
| 175 | +(* -------------------------------------------------------------------- *) |
| 176 | +lemma col_ofcols (i r c : int) (vs : vector list) : |
| 177 | + 0 <= r => 0 <= c => 0 <= i < c |
| 178 | + => size vs = c |
| 179 | + => all (fun v : vector => size v = r) vs |
| 180 | + => col (ofcols r c vs) i = vs.[i]. |
| 181 | +proof. |
| 182 | +move=> ge0_r ge0_c rgi sz_vs /allP => sz_in_vs. |
| 183 | +have sz_rows: rows (ofcols r c vs) = r. |
| 184 | +- by rewrite rows_offunm lez_maxr // sz_in_vs. |
| 185 | +apply/eq_vectorP; split=> /=. |
| 186 | +- by rewrite sz_rows sz_in_vs // &(mem_nth) sz_vs. |
| 187 | +by move=> j; rewrite sz_rows => rgj; rewrite get_offunm //#. |
| 188 | +qed. |
| 189 | +
|
| 190 | +(* -------------------------------------------------------------------- *) |
| 191 | +lemma dmatrixM (m n p : int) (d1 d2 : R distr) : |
| 192 | + 0 <= m => 0 <= n => 0 <= p => |
| 193 | +
|
| 194 | + let d = |
| 195 | + dlet (dmatrix d1 m n) (fun (m1 : matrix) => |
| 196 | + dmap (dmatrix d2 n p) (fun (m2 : matrix) => m1 * m2)) in |
| 197 | +
|
| 198 | + forall i j, 0 <= i < m => 0 <= j < p => |
| 199 | + dmap d (fun m => m.[i, j]) = |
| 200 | + ((weight d1) ^ (n * (m-1)) * (weight d2) ^ (n * (p-1))) \cdot dmul n d1 d2. |
| 201 | +proof. |
| 202 | +move=> ge0_m ge0_n ge0_p d i j rg_i rg_j. |
| 203 | +have [gt0_m gt0_p]: (0 <= m-1) /\ (0 <= p-1) by smt(). |
| 204 | +rewrite /d (dmatrix_rows_i i) //= (dmatrix_cols_i j) //=. |
| 205 | +pose D1 := dvector d1 n `*` _; pose D2 := dvector d2 n `*` _. |
| 206 | +pose F1 := fun (r_rs : _ * _) => trmx (ofcols n m (insert r_rs.`1 r_rs.`2 i)). |
| 207 | +pose F2 := fun (c_cs : _ * _) => ofcols n p (insert c_cs.`1 c_cs.`2 j). |
| 208 | +pose F r rs c cs := (trmx (ofcols n m (insert r rs i)) * ofcols n p (insert c cs j)).[i, j]. |
| 209 | +pose D := dlet D1 (fun c : _ * _ => dmap D2 (fun r : _ * _ => F c.`1 c.`2 r.`1 r.`2)). |
| 210 | +apply: (eq_trans _ D) => @/D => {D}. |
| 211 | +- rewrite dmap_dlet dlet_dmap /= &(eq_dlet) // => ? /=. |
| 212 | + by rewrite 2!dmap_comp &(eq_dmap). |
| 213 | +pose G (x_xs : (_ * _) * (_ * _)) := F x_xs.`1.`1 x_xs.`2.`1 x_xs.`1.`2 x_xs.`2.`2. |
| 214 | +rewrite dprod_cross /= => {D1 D2}; pose D1 := _ `*` _; pose D2 := _ `*` _. |
| 215 | +have @/G /= <- := dmap_dprodE D1 D2 G => {G}. |
| 216 | +pose G (vs : vector * vector) := dotp vs.`1 vs.`2. |
| 217 | +apply: (eq_trans _ (dmap (D1 `*` D2) (fun x : _ * _ => G x.`1))). |
| 218 | +- apply: eq_dmap_in=> -[[c r] [cs rs]] @/G @/F /=. |
| 219 | + case/supp_dprod=> /= /supp_dprod[/=]. |
| 220 | + case/(supp_dvector _ _ _ ge0_n) => sz_c _. |
| 221 | + case/(supp_dvector _ _ _ ge0_n) => sz_r _. |
| 222 | + move/supp_dprod=> [/=]. |
| 223 | + case/(supp_dlist _ _ _ gt0_m) => [sz_cs /allP sz_in_cs]. |
| 224 | + case/(supp_dlist _ _ _ gt0_p) => [sz_rs /allP sz_in_rs]. |
| 225 | + rewrite get_mulmx row_trmx /= !col_ofcols //. |
| 226 | + - by rewrite size_insert ?sz_cs //#. |
| 227 | + - apply/allP=> v /mem_insert [->>|] //=. |
| 228 | + by move/sz_in_cs => /(supp_dvector _ _ _ ge0_n). |
| 229 | + - by rewrite size_insert ?sz_rs //#. |
| 230 | + - apply/allP=> v /mem_insert [->>|] //=. |
| 231 | + by move/sz_in_rs => /(supp_dvector _ _ _ ge0_n). |
| 232 | + by rewrite !nth_insert // (sz_cs, sz_rs) //#. |
| 233 | +rewrite dprod_marginalL /D2 weight_dprod !weight_dlist // !weight_dmap. |
| 234 | +rewrite !weight_dlist ?lez_maxr // -!exprM. |
| 235 | +congr=> @/D1 @/G => {D1 D2 G} @/dmul. |
| 236 | +rewrite !dmap_dprodE /= dlet_dmap lez_maxr //. |
| 237 | +apply/in_eq_dlet => //= xs /(supp_dlist _ _ _ ge0_n)[sz_xs _]. |
| 238 | +rewrite dmap_comp lez_maxr //; apply/eq_dmap_in => /= ys. |
| 239 | +case/(supp_dlist _ _ _ ge0_n)=> sz_ys _ @/dotp. |
| 240 | +rewrite !size_oflist sz_xs sz_ys lez_maxr //. |
| 241 | +apply/eq_sym; rewrite (Big.BAdd.big_nth witness) predT_comp. |
| 242 | +rewrite size_zip sz_xs sz_ys lez_minr //. |
| 243 | +rewrite !Big.BAdd.big_seq /= &(Big.BAdd.eq_bigr) /=. |
| 244 | +move=> k /mem_range rg_k; rewrite !(get_oflist witness) ~-1://#. |
| 245 | +have := nth_zip witness witness xs ys k _; first by smt(). |
| 246 | +by rewrite (nth_change_dfl witness) => [|->//]; rewrite size_zip /#. |
| 247 | +qed. |
| 248 | +
|
| 249 | +(* -------------------------------------------------------------------- *) |
| 250 | +lemma dmatrixM_ll (m n p : int) (d1 d2 : R distr) : |
| 251 | + 0 <= m => 0 <= n => 0 <= p => |
| 252 | +
|
| 253 | + is_lossless d1 => is_lossless d2 => |
| 254 | +
|
| 255 | + let d = |
| 256 | + dlet (dmatrix d1 m n) (fun (m1 : matrix) => |
| 257 | + dmap (dmatrix d2 n p) (fun (m2 : matrix) => m1 * m2)) in |
| 258 | +
|
| 259 | + forall i j, 0 <= i < m => 0 <= j < p => |
| 260 | + dmap d (fun m => m.[i, j]) = dmul n d1 d2. |
| 261 | +proof. |
| 262 | +move=> *; rewrite dmatrixM //; pose c := (_ * _)%Real. |
| 263 | +rewrite (_ : c = 1%r) -1:dscalar1 // /c. |
| 264 | +by do 2! rewrite (_ : weight _ = 1%r) // expr1z. |
| 265 | +qed. |
0 commit comments