Skip to content

Commit 1b12249

Browse files
committed
Filter example working
1 parent 33a6f93 commit 1b12249

File tree

2 files changed

+67
-20
lines changed

2 files changed

+67
-20
lines changed

src/ecCircuits.ml

Lines changed: 66 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ module type CircuitInterface = sig
206206

207207
(* Mapreduce/Dependecy analysis related functions *)
208208
val is_decomposable : int -> int -> cbitstring cfun -> bool
209-
val decompose : int -> int -> cbitstring cfun -> (cbitstring cfun) list
209+
val decompose : int -> int -> cbitstring cfun -> (cbitstring cfun) list * (int * int)
210210
val permute : int -> (int -> int) -> cbitstring cfun -> cbitstring cfun
211211

212212
(* Wraps the backend call to deal with args/inputs *)
@@ -320,6 +320,10 @@ module type CBackend = sig
320320
val is_splittable : int -> int -> deps -> bool
321321

322322
val are_independent : block_deps -> bool
323+
324+
val single_dep : deps -> bool
325+
(* Assumes single_dep *)
326+
val dep_range : deps -> int * int
323327
end
324328
end
325329

@@ -425,11 +429,14 @@ module TestBack : CBackend = struct
425429
let get (r: reg) (idx: int) = r.(idx)
426430

427431
let permute (w: int) (perm: int -> int) (r: reg) : reg =
432+
Format.eprintf "Applying permutation to reg of size %d with block size of %d@." (size_of_reg r) w;
428433
Array.init (size_of_reg r) (fun i ->
429-
let block_idx, bit_idx = (i / w), (i mod w) in
430-
let idx = (perm block_idx)*w + bit_idx in
431-
r.(idx)
432-
)
434+
let block_idx, bit_idx = perm (i / w), (i mod w) in
435+
if block_idx < 0 then None
436+
else
437+
let idx = block_idx*w + bit_idx in
438+
Some r.(idx)
439+
) |> Array.filter_map (fun x -> x)
433440

434441

435442
(* Node operations *)
@@ -536,17 +543,17 @@ module TestBack : CBackend = struct
536543
| 0 -> true
537544
| 1 ->
538545
let blocks = block_deps_of_deps w_out d in
539-
(* Format.eprintf "Checking block width...@."; *)
546+
Format.eprintf "Checking block width...@.";
540547
Array.for_all (fun (_, d) ->
541548
if Map.is_empty d then true
542549
else
543550
let _, bits = Map.any d in
544551
Set.is_empty bits ||
545552
let base = Set.at_rank_exn 0 bits in
546-
(* Format.eprintf "Base for current block: %d@." base; *)
553+
Format.eprintf "Base for current block: %d@." base;
547554
Set.for_all (fun bit ->
548555
let dist = bit - base in
549-
(* Format.eprintf "Current bit: %d | Current dist: %d | Limit: %d@." bit dist w_in; *)
556+
Format.eprintf "Current bit: %d | Current dist: %d | Limit: %d@." bit dist w_in;
550557
0 <= dist && dist < w_in
551558
) bits
552559
) blocks
@@ -576,6 +583,28 @@ module TestBack : CBackend = struct
576583
true
577584
with BreakOut ->
578585
false
586+
587+
588+
let single_dep (d: deps) : bool =
589+
match Set.cardinal
590+
(Array.fold_left (Set.union) Set.empty
591+
(Array.map (fun dep -> Map.keys dep |> Set.of_enum) d))
592+
with
593+
| 0 | 1 -> true
594+
| _ -> false
595+
596+
(* Assumes single_dep, returns range (bot, top) such that valid idxs are bot <= i < top *)
597+
let dep_range (d: deps) : int * int =
598+
assert (single_dep d);
599+
let idxs =
600+
Array.fold_left (fun acc d ->
601+
Set.union (Map.fold Set.union d Set.empty) acc) Set.empty d
602+
in
603+
Format.eprintf "%a@." pp_deps d;
604+
Format.eprintf "Dep range for dependencies:@.";
605+
Set.iter (fun i -> Format.eprintf "%d " i) idxs;
606+
Format.eprintf "@.Min: %d | Max: %d@." (Set.min_elt idxs) (Set.max_elt idxs);
607+
(Set.min_elt idxs, Set.max_elt idxs + 1)
579608
end
580609

581610
end
@@ -1272,7 +1301,7 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
12721301
let array_oflist (circs : circuit list) (dfl: circuit) (len: int) : circuit =
12731302
let circs, inps = List.split circs in
12741303
let dif = len - List.length circs in
1275-
Format.eprintf "Len, Dif in array_oflist: %d, %d@." len dif;
1304+
(* Format.eprintf "Len, Dif in array_oflist: %d, %d@." len dif; *)
12761305
let circs = circs @ (List.init dif (fun _ -> fst dfl)) in
12771306
let inps = if dif > 0 then inps @ [snd dfl] else inps in
12781307
let circs = List.map
@@ -1518,39 +1547,57 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
15181547
(* For more complex circuits, we might be able to simulate this with a int -> (int, int) map *)
15191548
let is_decomposable (in_w: width) (out_w: width) ((`CBitstring r, inps) as c: cbitstring cfun) : bool =
15201549
match inps with
1521-
| {type_=`CIBitstring w} :: [] when w mod in_w = 0 && Backend.size_of_reg r mod out_w = 0 ->
1550+
| {type_=`CIBitstring w} :: [] when (Backend.size_of_reg r mod out_w = 0) ->
15221551
let deps = Backend.Deps.deps_of_reg r in
1523-
Backend.Deps.is_splittable in_w out_w deps
1552+
Backend.Deps.is_splittable in_w out_w deps &&
1553+
let base, top = Backend.Deps.dep_range deps in
1554+
let () = Format.eprintf "Passed backend check, checking width of deps (top - base = %d | in_w = %d)@." (top - base) in_w in
1555+
(top - base) mod in_w = 0
15241556
| _ ->
15251557
Format.eprintf "Failed decomposition type check@\n";
15261558
Format.eprintf "In_w: %d | Out_w : %d | Circ: %a" in_w out_w pp_circuit c;
15271559
false
15281560

1529-
let split_renamer (n: count) (in_w: width) (inp: cinp) : (cinp array) * (Backend.inp -> cbool_type option) =
1530-
match inp with
1531-
| {type_ = `CIBitstring w; id} when w mod in_w = 0 ->
1561+
let split_renamer ?(range: (int * int) option) (n: count) (in_w: width) (inp: cinp) : (cinp array) * (Backend.inp -> cbool_type option) =
1562+
match range, inp with
1563+
| Some (start_idx, end_idx), {type_ = `CIBitstring w; id} when (end_idx - start_idx) mod in_w = 0 ->
1564+
let ids = Array.init n (fun i -> create ("split_" ^ (string_of_int i)) |> tag) in
1565+
Array.map (fun id -> {type_ = `CIBitstring in_w; id}) ids,
1566+
(fun (id_, w) ->
1567+
let w = w - start_idx in (* FIXME: check if this doesn't cause problems on the upper end *)
1568+
if id <> id_ || w < 0 || w >= end_idx then None else
1569+
let id_idx, bit_idx = (w / in_w), (w mod in_w) in
1570+
Some (Backend.input_node ~id:ids.(id_idx) bit_idx))
1571+
| None, {type_ = `CIBitstring w; id} when w mod in_w = 0 ->
15321572
let ids = Array.init n (fun i -> create ("split_" ^ (string_of_int i)) |> tag) in
15331573
Array.map (fun id -> {type_ = `CIBitstring in_w; id}) ids,
15341574
(fun (id_, w) ->
15351575
if id <> id_ then None else
15361576
let id_idx, bit_idx = (w / in_w), (w mod in_w) in
15371577
Some (Backend.input_node ~id:ids.(id_idx) bit_idx))
1578+
| _, {type_ = `CIBitstring w; id} ->
1579+
Format.eprintf "Failed to build split renamer for n=%d in_w=%d w=%d" n in_w w;
1580+
Option.may (fun (bot, top) -> Format.eprintf "range=(%d, %d)" bot top) range;
1581+
Format.eprintf "@.";
1582+
assert false
15381583
| _ -> assert false
15391584

1540-
let decompose (in_w: width) (out_w: width) ((`CBitstring r, inps) as c: cbitstring cfun) : cbitstring cfun list =
1585+
let decompose (in_w: width) (out_w: width) ((`CBitstring r, inps) as c: cbitstring cfun) : cbitstring cfun list * (int * int) =
15411586
if not (is_decomposable in_w out_w c) then
15421587
let deps = Backend.Deps.block_deps_of_reg out_w r in
15431588
Format.eprintf "Failed to decompose. in_w=%d out_w=%d Deps:@.%a" in_w out_w (Backend.Deps.pp_block_deps) deps;
15441589
assert false
15451590
else
1591+
(* TODO: don't repeat dependecy computation ? *)
1592+
let dep_range = Backend.Deps.dep_range (Backend.Deps.deps_of_reg r) in
15461593
let n = (Backend.size_of_reg r) / out_w in
15471594
let blocks = Array.init n (fun i ->
15481595
Backend.slice r (i*out_w) out_w) in
1549-
let cinps, renamer = split_renamer n in_w (List.hd inps) in
1596+
let cinps, renamer = split_renamer ~range:dep_range n in_w (List.hd inps) in
15501597
Array.map2 (fun r inp ->
15511598
let r = Backend.applys renamer r in
15521599
(`CBitstring r, [inp])
1553-
) blocks cinps |> Array.to_list
1600+
) blocks cinps |> Array.to_list, dep_range
15541601

15551602
let permute (w: width) (perm: (int -> int)) ((`CBitstring r, inps): cbitstring cfun) : cbitstring cfun =
15561603
`CBitstring (Backend.permute w perm r), inps
@@ -2164,13 +2211,13 @@ let circuit_permute (bsz: int) (perm: int -> int) (c: circuit) : circuit =
21642211
in
21652212
(permute bsz perm c :> circuit)
21662213

2167-
let circuit_mapreduce ?(perm : (int -> int) option) (c: circuit) (w_in: width) (w_out: width) : circuit list =
2214+
let circuit_mapreduce ?(perm : (int -> int) option) (c: circuit) (w_in: width) (w_out: width) : circuit list * (int * int) =
21682215
let c = match c, perm with
21692216
| (`CBitstring _, inps) as c, None -> c
21702217
| (`CBitstring _, inps) as c, Some perm -> permute w_out perm c
21712218
| _ -> assert false
21722219
in
2173-
(decompose w_in w_out c :> circuit list)
2220+
(decompose w_in w_out c :> circuit list * (int * int))
21742221

21752222
type circuit = ExampleInterface.circuit
21762223
type pstate = ExampleInterface.PState.pstate

src/ecCircuits.mli

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ val circuit_aggregate : circuit list -> circuit
3939
val circuit_aggregate_inps : circuit -> circuit
4040
val circuit_flatten : circuit -> circuit
4141
val circuit_permute : int -> (int -> int) -> circuit -> circuit
42-
val circuit_mapreduce : ?perm:(int -> int) -> circuit -> int -> int -> circuit list
42+
val circuit_mapreduce : ?perm:(int -> int) -> circuit -> int -> int -> circuit list * (int * int)
4343

4444
(* Use circuits *)
4545
val compute : sign:bool -> circuit -> BI.zint list -> BI.zint

0 commit comments

Comments
 (0)