diff --git a/.gitignore b/.gitignore index bc89608..81c72b9 100644 --- a/.gitignore +++ b/.gitignore @@ -21,5 +21,6 @@ plugin/Makefile.coq plugin/Makefile.conf plugin/Makefile plugin/.merlin +*.merlin *.out _opam diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..e69de29 diff --git a/LICENSE b/LICENSE index cbee369..3fa32cf 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2019 Talia Ringer, Nate Yazdani +Copyright (c) 2021 PUMPKIN PATCH Team Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index ebb8db3..36fa2d6 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,8 @@ See [PUMPKIN PATCH](https://github.com/uwplse/PUMPKIN-PATCH) and [DEVOID](https: ## Guide +This code guide is out of date and needs to be updated due to the switch to use Dune. + * [LICENSE](/LICENSE): License * [README.md](/README.md): You are here! * [build.sh](/build.sh): Build script for example plugin @@ -32,7 +34,7 @@ See [PUMPKIN PATCH](https://github.com/uwplse/PUMPKIN-PATCH) and [DEVOID](https: ## Contributors -This library was developed by Talia Ringer, Nate Yazdani, and RanDair Porter. +This library was developed by Talia Ringer, Nate Yazdani, RanDair Porter, and Emily First. Probably none of it would build without Emilio's help. 
## Licensing diff --git a/_CoqProject b/_CoqProject deleted file mode 100644 index 339b268..0000000 --- a/_CoqProject +++ /dev/null @@ -1,97 +0,0 @@ --I src/utilities --I src/coq --I src/coq/termutils --I src/coq/constants --I src/coq/logicutils --I src/coq/logicutils/contexts --I src/coq/logicutils/typesandequality --I src/coq/logicutils/hofs --I src/coq/logicutils/inductive --I src/coq/logicutils/transformation --I src/coq/devutils --I src/coq/representationutils --I src/coq/decompiler --I src --R src Plibrary --Q theories Plibrary - -src/utilities/utilities.mli -src/utilities/utilities.ml - -src/coq/termutils/apputils.mli -src/coq/termutils/apputils.ml -src/coq/termutils/constutils.mli -src/coq/termutils/constutils.ml -src/coq/termutils/funutils.mli -src/coq/termutils/funutils.ml - -src/coq/representationutils/defutils.mli -src/coq/representationutils/defutils.ml -src/coq/representationutils/nameutils.mli -src/coq/representationutils/nameutils.ml - -src/coq/logicutils/typesandequality/inference.mli -src/coq/logicutils/typesandequality/inference.ml -src/coq/logicutils/typesandequality/convertibility.mli -src/coq/logicutils/typesandequality/convertibility.ml -src/coq/logicutils/typesandequality/checking.mli -src/coq/logicutils/typesandequality/checking.ml - -src/coq/constants/equtils.mli -src/coq/constants/equtils.ml -src/coq/constants/sigmautils.mli -src/coq/constants/sigmautils.ml -src/coq/constants/produtils.mli -src/coq/constants/produtils.ml -src/coq/constants/idutils.mli -src/coq/constants/idutils.ml -src/coq/constants/proputils.ml -src/coq/constants/proputils.mli - -src/coq/logicutils/contexts/stateutils.mli -src/coq/logicutils/contexts/stateutils.ml -src/coq/logicutils/contexts/envutils.mli -src/coq/logicutils/contexts/envutils.ml -src/coq/logicutils/contexts/contextutils.mli -src/coq/logicutils/contexts/contextutils.ml - -src/coq/logicutils/hofs/hofs.mli -src/coq/logicutils/hofs/hofs.ml -src/coq/logicutils/hofs/hofimpls.mli 
-src/coq/logicutils/hofs/hofimpls.ml -src/coq/logicutils/hofs/debruijn.mli -src/coq/logicutils/hofs/debruijn.ml -src/coq/logicutils/hofs/substitution.mli -src/coq/logicutils/hofs/substitution.ml -src/coq/logicutils/hofs/reducers.mli -src/coq/logicutils/hofs/reducers.ml -src/coq/logicutils/hofs/typehofs.mli -src/coq/logicutils/hofs/typehofs.ml -src/coq/logicutils/hofs/zooming.mli -src/coq/logicutils/hofs/zooming.ml -src/coq/logicutils/hofs/hypotheses.mli -src/coq/logicutils/hofs/hypotheses.ml -src/coq/logicutils/hofs/filters.mli -src/coq/logicutils/hofs/filters.ml - -src/coq/logicutils/inductive/indexing.mli -src/coq/logicutils/inductive/indexing.ml -src/coq/logicutils/inductive/indutils.mli -src/coq/logicutils/inductive/indutils.ml - -src/coq/logicutils/contexts/modutils.mli -src/coq/logicutils/contexts/modutils.ml - -src/coq/logicutils/transformation/transform.mli -src/coq/logicutils/transformation/transform.ml - -src/coq/devutils/printing.mli -src/coq/devutils/printing.ml - -src/coq/decompiler/decompiler.mli -src/coq/decompiler/decompiler.ml - -src/plibrary.ml4 -src/plib.mlpack - -theories/Plib.v diff --git a/build.sh b/build.sh index 99853cb..d280d59 100755 --- a/build.sh +++ b/build.sh @@ -1,2 +1,7 @@ -coq_makefile -f _CoqProject -o Makefile -make clean && make && make install +opam pin dune 2.7.1 +opam pin coq-serapi 8.9.0+0.6.1 +opam install lymp +dune clean +dune build @all +dune build @all +dune install diff --git a/coq-plugin-lib.opam b/coq-plugin-lib.opam new file mode 100644 index 0000000..7577de2 --- /dev/null +++ b/coq-plugin-lib.opam @@ -0,0 +1,19 @@ +synopsis: "Coq Plugin Library" +description: "Coq Plugin Library" +name: "coq-plugin-lib" +opam-version: "2.0" +maintainer: "talia@dependenttyp.es" +authors: "Talia Ringer" +homepage: "https://github.com/uwplse/coq-plugin-lib" +bug-reports: "https://github.com/uwplse/coq-plugin-lib" +dev-repo: "git+https://github.com/uwplse/coq-plugin-lib" +license: "MIT" +doc: "https://github.com/uwplse/coq-plugin-lib" 
+ +depends: [ + "ocaml" { = "4.07.1+flambda" } + "coq" { = "8.9.1" } + "dune" { build & >= "1.9.0" & <= "2.7.1" } +] + +build: [ "dune" "build" "-p" name "-j" jobs ] diff --git a/dune b/dune new file mode 100644 index 0000000..d8d3aa2 --- /dev/null +++ b/dune @@ -0,0 +1,3 @@ +(env + (dev (flags (:standard -rectypes -w -8-33-3-27-28-32))) + (release (flags (:standard -rectypes -w -8-33-3-27-28-32)))) diff --git a/dune-project b/dune-project new file mode 100644 index 0000000..709e0e9 --- /dev/null +++ b/dune-project @@ -0,0 +1,3 @@ +(lang dune 2.7.1) +(using coq 0.2) +(name coq-plugin-lib) diff --git a/src/coq/constants/dune b/src/coq/constants/dune new file mode 100644 index 0000000..6d2f28a --- /dev/null +++ b/src/coq/constants/dune @@ -0,0 +1,8 @@ +(library + (name constants) + (public_name coq-plugin-lib.constants) + (libraries + coq-plugin-lib.inference + coq-plugin-lib.termutils + coq.kernel) + (wrapped false)) diff --git a/src/coq/decompiler/decompiler.ml b/src/coq/decompiler/decompiler.ml index a7c24bd..1994b0a 100644 --- a/src/coq/decompiler/decompiler.ml +++ b/src/coq/decompiler/decompiler.ml @@ -15,6 +15,14 @@ open Zooming open Nameutils open Ltac_plugin open Stateutils +open List +open Tactok + +(* + * This is a minimal, sound version of the decompiler with our own heuristics + * and improvements disabled. We can add those back later, but I think it's + * easier to start with this to experiment with the decompiler. + *) (* Monadic bind on option types. 
*) let (>>=) = Option.bind @@ -35,16 +43,23 @@ let parse_tac_str (s : string) : unit Proofview.tactic = (* Run a coq tactic against a given goal, returning generated subgoals *) let run_tac env sigma (tac : unit Proofview.tactic) (goal : constr) - : Goal.goal list * Evd.evar_map = + : Goal.goal list state = let p = Proof.start sigma [(env, EConstr.of_constr goal)] in let (p', _) = Proof.run_tactic env tac p in let (subgoals, _, _, _, sigma) = Proof.proof p' in - subgoals, sigma + sigma, subgoals + +(* Convert a coq-generated subgoal into its context environment and goal type. *) +let get_context_goal env sigma (g : Goal.goal) : env * types = + let context_size = List.length (named_context (Goal.V82.env sigma g)) in + let abstr = EConstr.to_constr sigma (Goal.V82.abstract_type sigma g) in + Zooming.zoom_n_prod (reset_context env) context_size abstr + (* Returns true if the given tactic solves the goal. *) let solves env sigma (tac : unit Proofview.tactic) (goal : constr) : bool state = try - let subgoals, sigma = run_tac env sigma tac goal in + let sigma, subgoals = run_tac env sigma tac goal in sigma, subgoals = [] with _ -> sigma, false @@ -88,7 +103,7 @@ type tactical = stored in reverse order to push in constant time. *) | Compose of tact list * (tactical list) -(* Return the string representation of a single tactic. *) +(* Return the Pp.t representation of a single tactic. *) let show_tactic sigma tac : Pp.t = let prnt e = Printer.pr_constr_env e sigma in match tac with @@ -132,18 +147,15 @@ let show_tactic sigma tac : Pp.t = | Auto -> str "auto" | Expr s -> str s +(* Return the string representation of a single tactic. *) +let show_tactic_string sigma t = + let s = show_tactic sigma t in + Format.asprintf "%a" Pp.pp_with s + (* Convert IR tactic to coq tactic by printing and parsing. 
*) let coq_tac sigma t prefix = - let s = show_tactic sigma t in - let s' = Format.asprintf "%a" Pp.pp_with s in - parse_tac_str (prefix ^ s') + parse_tac_str (prefix ^ show_tactic_string sigma t) -(* True if both tactics are "equal" (syntactically). *) -let compare_tact sigma (t1 : tact) (t2 : tact) : bool = - let s1 = show_tactic sigma t1 in - let s2 = show_tactic sigma t2 in - Pp.string_of_ppcmds s1 = Pp.string_of_ppcmds s2 - (* Option monad over function application. *) let try_app (trm : constr) : (constr * constr array) option = match kind trm with @@ -172,92 +184,11 @@ let guard (b : bool) : unit option = if b then Some () else None (* Single dotted tactic. *) -let dot tac next = Some (Compose ([ tac ], [ next ])) +let dot sigma tac next : tactical state option = Some (sigma, Compose ([ tac ], [ next ])) (* Single tactic to finish proof. *) -let qed tac = Some (Compose ([ tac ], [])) +let qed sigma tac = Some (sigma, Compose ([ tac ], [])) -(* Inserts "simpl." before every rewrite. 
*) -let rec simpl sigma (t : tactical) : tactical = - match t with - (*| Compose ( [ Rewrite (env, b, c, Some goal) ], goal_prfs) -> - let r = Rewrite (env, b, c, Some goal) in - let rest = Compose ([ r ], List.map (simpl sigma) goal_prfs) in - (try - Printing.debug_term env b "REWRITE: "; - Printing.debug_term env goal "GOAL: "; - let goals1, sigma = run_tac env sigma (coq_tac sigma r "") goal in - let goals2, sigma = run_tac env sigma (coq_tac sigma r "simpl;") goal in - let goals1 = List.map (Goal.V82.abstract_type sigma) goals1 in - let goals2 = List.map (Goal.V82.abstract_type sigma) goals2 in - if list_eq (EConstr.eq_constr sigma) goals1 goals2 - then rest else Compose ([ Simpl ], [ rest ]) - with _ -> rest) *) - | Compose ( [ Rewrite (a, b, c, d) ], goals) -> - Compose ([ Simpl ], [ Compose ([ Rewrite (a, b, c, d) ], - List.map (simpl sigma) goals)]) - | Compose (tacs, goals) -> - Compose (tacs, List.map (simpl sigma) goals) - - -(* Combine adjacent intros and revert tactics if possible. *) -let rec intros_revert (t : tactical) : tactical = - match t with - | Compose ( [ Intros xs ], [ Compose ([ Revert ys ], goals) ]) -> - let n = count_shared_prefix Id.equal (List.rev xs) ys in - let xs' = take (List.length xs - n) xs in - let ys' = drop n ys in - let goals' = List.map intros_revert goals in - (* Don't include empty name lists! *) - let c1 = if ys' == [] then goals' else [ Compose ([ Revert ys' ], goals') ] in - if xs' == [] then List.hd c1 else Compose ([ Intros xs' ], c1) - | Compose (tacs, goals) -> - Compose (tacs, List.map intros_revert goals) - -(* Combine common subgoal tactics into semicolons. 
*) -let rec semicolons sigma (t : tactical) : tactical = - let first t = match t with - | Compose ( [ tac ], _) -> tac in - let subgoals t = match t with - | Compose ( _, goals) -> goals in - match t with - (* end of proof *) - | Compose (_, []) -> t - (* single subgoal, don't bother *) - | Compose ( tacs, [ goal ]) -> - Compose ( tacs, [ semicolons sigma goal ]) - (* compare first tactic of each subgoal *) - | Compose ( tacs, goals ) -> - let firsts = List.map first goals in - if all_eq (compare_tact sigma) firsts - then - let goals' = List.concat (List.map subgoals goals) in - semicolons sigma (Compose ( (List.hd firsts) :: tacs, goals')) - else - Compose (tacs, List.map (semicolons sigma) goals) - -(* Try implicit arguments to rewrite functions. *) -let rec rewrite_implicit sigma (t : tactical) : tactical = - try - match t with - | Compose ( [ Rewrite (env, fx, dir, Some goal) ], [ goal_prf ]) -> - let rest = [ rewrite_implicit sigma goal_prf ] in - let r1 = Rewrite (env, fx, dir, Some goal) in - (match kind fx with - | App (f, args) -> - let r2 = Rewrite (env, f, dir, Some goal) in - let goals1, sigma = run_tac env sigma (coq_tac sigma r1 "") goal in - let goals2, sigma = run_tac env sigma (coq_tac sigma r2 "") goal in - let goals1 = List.map (Goal.V82.abstract_type sigma) goals1 in - let goals2 = List.map (Goal.V82.abstract_type sigma) goals2 in - let choice = if list_eq (EConstr.eq_constr sigma) goals1 goals2 - then r2 else r1 in - Compose ( [ choice ], rest ) - | _ -> Compose ( [ r1 ], rest )) - | Compose ( tacs, goals ) -> - Compose ( tacs, List.map (rewrite_implicit sigma) goals ) - with _ -> t - (* Given the list of tactics and their corresponding string expressions, try to solve the goal (type of trm), return None otherwise. *) @@ -275,18 +206,32 @@ let try_solve env sigma opts trm = in aux sigma opts with _ -> None +exception RunTacExc of string * env * Evd.evar_map * types + +(* Generate the new subgoals after applying a tactic to a goal. 
*) +let next_context_goals env sigma (t : tact) (goal : types) : (env * types) list state = + try + let sigma, subgoals = run_tac env sigma (coq_tac sigma t "") goal in + sigma, List.map (get_context_goal env sigma) subgoals + with e -> + let s = show_tactic sigma t in + let s' = Format.asprintf "%a" Pp.pp_with s in + raise (RunTacExc (s', env, sigma, goal)) + (* Generates an apply tactic with implicit arguments if possible. *) -let apply_implicit env sigma trm = +let apply_implicit env sigma trm : tactical state option = try try_app trm >>= fun (f, args) -> try_name env f >>= fun name -> let s = String.concat " " [ "apply" ; name ] in let opt = parse_tac_str s in try_solve env sigma [ (opt, s) ] trm >>= fun tac -> - qed tac + qed sigma tac with _ -> None - +let get_hints env sigma prev goal = + sigma, (beam_search env (List.rev prev) goal) + (* Performs the bulk of decompilation on a proof term. Opts are the optional goal solving tactics that can be inserted into the generated script. If one of these tactics solves the focused goal or @@ -295,53 +240,83 @@ let apply_implicit env sigma trm = let rec first_pass (env : env) (sigma : Evd.evar_map) - (get_hints : env -> Evd.evar_map -> constr -> (unit Proofview.tactic * string) list state) - (trm : constr) = + (get_hints : env -> Evd.evar_map -> string list -> constr -> + (unit Proofview.tactic * string) list state) + (prev : string list) + (goal : types) + (trm : constr) : tactical state = (* Apply single reduction to terms that *might* be in eta expanded form. 
*) let trm = Reduction.whd_betaiota env trm in - let sigma, hints = get_hints env sigma trm in - let custom = try_custom_tacs env sigma get_hints hints trm in + let custom = try_custom_tacs env sigma get_hints prev goal trm in if Option.has_some custom then Option.get custom else - let def = Option.default (Compose ([ Apply (env, trm) ], [])) + let def = Option.default (sigma, Compose ([ Apply (env, trm) ], [])) (apply_implicit env sigma trm) in - let choose f x = - Option.default def (f x (env, sigma, get_hints)) in - match kind trm with - (* "fun x => ..." -> "intro x." *) - | Lambda (n, t, b) -> - let (env', trm', names) = zoom_lambda_names env 0 trm in - Compose ([ Intros names ], [ first_pass env' sigma get_hints trm' ]) - (* Match on well-known functions used in the proof. *) - | App (f, args) -> - choose (rewrite <|> induction <|> left <|> right <|> split - <|> reflexivity <|> symmetry <|> exists) (f, args) - (* Hypothesis transformations or generation tactics. *) - | LetIn (n, valu, typ, body) -> - choose (rewrite_in <|> apply_in <|> pose) (n, valu, typ, body) - (* Remainder of body, simply apply it. *) - | _ -> def + try + let choose f x = + Option.default def (f x (env, sigma, get_hints, prev, goal)) in + match kind trm with + (* "fun x => ..." -> "intro x." *) + | Lambda (n, t, b) -> + let (env', trm', names) = zoom_lambda_names env 0 trm in + let t = Intros names in + let sigma, next = next_context_goals env sigma t goal in + let _, goal' = List.hd next in + let sigma, rest = first_pass env' sigma get_hints (show_tactic_string sigma t :: prev) goal' trm' in + sigma, Compose ([ t ], [ rest ]) + (* Match on well-known functions used in the proof. *) + | App (f, args) -> + choose (rewrite <|> induction <|> left <|> right <|> split + <|> reflexivity <|> symmetry <|> exists) (f, args) + (* Hypothesis transformations or generation tactics. 
*) + | LetIn (n, valu, typ, body) -> + choose (rewrite_in <|> apply_in <|> pose) (n, valu, typ, body) + (* Remainder of body, simply apply it. *) + | _ -> def + with + | (RunTacExc (s, env, sigma, goal)) -> + Feedback.msg_warning (str "Failed to execute: " ++ str s); + Feedback.msg_warning (str "on the goal: " ++ Printer.pr_constr_env env sigma goal); + def + | e -> + Feedback.msg_warning (str "Error occured while decompilin: "); + Feedback.msg_warning (str (Printexc.to_string e)); + def + + +(* Pass the updated goal to the next stage of decompilation. *) +and one_subgoal env sigma opts prev goal t trm = + let sigma, next = next_context_goals env sigma t goal in + let env', goal' = List.hd next in + let sigma, rest = first_pass env' sigma opts (show_tactic_string sigma t :: prev) goal' trm in + sigma, Compose ([ t ], [ rest ]) +(* Pass the updated goal to the next stages of decompilation. *) +and many_subgoals env sigma opts prev goal t trms = + let sigma, next = next_context_goals env sigma t goal in + let sigma, rests = + map2_state (fun (_, g) trm sigma -> + first_pass env sigma opts (show_tactic_string sigma t :: prev) g trm) next trms sigma in + sigma, Compose ([ t ], rests) + (* If successful, uses a custom tactic and decompiles subterms solving any generated subgoals. 
*) -and try_custom_tacs env sigma get_hints all_opts trm = +and try_custom_tacs env sigma get_hints prev goal trm : tactical state option = guard (not (isLambda trm)) >>= fun _ -> - try - let goal = (Typeops.infer env trm).uj_type in - let goal_env env sigma g = - let typ = EConstr.to_constr sigma (Goal.V82.abstract_type sigma g) in - Zooming.zoom_product_type (Environ.reset_context env) typ in - let rec aux opts = + try + let goal = (Typeops.infer env trm).uj_type in + let sigma, hints = get_hints env sigma prev goal in + let rec aux opts : tactical state option = match opts with | [] -> None | (tac, expr) :: opts' -> try - let subgoals, sigma = run_tac env sigma tac goal in - let subgoals = List.map (goal_env env sigma) subgoals in + let sigma, subgoals = run_tac env sigma tac goal in + let subgoals = List.map (get_context_goal env sigma) subgoals in if subgoals = [] then (* Goal solving *) - Some (Compose ([ Expr expr ], [])) + Some (sigma, Compose ([ Expr expr ], [])) else let new_env = fst (List.hd subgoals) in let sigma, same_env = Envutils.compare_envs env new_env sigma in @@ -349,31 +324,35 @@ and try_custom_tacs env sigma get_hints all_opts trm = then (* Both goal and context are unchanged *) aux opts' else (* Intermediate goal generating or context modifying tactic *) + (* NOTE: These produce a distinct sigma for each subgoal. So we + return the latest sigma (from compare_envs) here in the end. *) let subterms = List.map (fun (env', goal) -> - (Typehofs.subterms_with_type env sigma goal trm, env')) + (Typehofs.subterms_with_type env sigma goal trm, env', goal)) subgoals in (* could not find subterms to satisfy all subgoals? 
*) - if List.exists (fun x -> fst x = []) subterms + if List.exists (fun (x, _, _) -> x = []) subterms then aux opts' else - (* doesn't matter which subterm we found, it's a proof of the subgoal *) - let subterms = List.map (fun (g, e) -> (list_snd g, e)) subterms in - let proofs = List.map (fun ((sigma, (_, trm)), env') -> - first_pass env' sigma get_hints trm) subterms in - Some (Compose ([ Expr expr ], proofs)) + (* Pick the second subterm we found, since the first could be the entire term. *) + let t = Expr expr in + let subterms = List.map (fun (t, e, g) -> (list_snd t, e, g)) subterms in + let proofs = List.map (fun ((sigma, (_, trm)), env', goal') -> + snd (first_pass env' sigma get_hints (show_tactic_string sigma t :: prev) + goal' trm)) subterms in + Some (sigma, Compose ([ t ], proofs)) with _ -> aux opts' - in aux all_opts + in aux hints with e -> (* raise e *) None - + (* Application of a equality eliminator. *) -and rewrite (f, args) (env, sigma, opts) : tactical option = +and rewrite (f, args) (env, sigma, opts, prev, goal) : tactical state option = let fx = mkApp (f, args) in dest_rewrite fx >>= fun rewr -> - let sigma, goal = type_of env fx sigma in - dot (Rewrite (env, rewr.eq, rewr.left, goal)) (first_pass env sigma opts rewr.px) + let t = Rewrite (env, rewr.eq, rewr.left, Some goal) in + Some (one_subgoal env sigma opts prev goal t rewr.px) (* Applying an eliminator for induction on a hypothesis in context. *) -and induction (f, args) (env, sigma, opts) : tactical option = +and induction (f, args) (env, sigma, opts, prev, goal) : tactical state option = guard (is_elim env f) >>= fun _ -> guard (not (is_rewrite f)) >>= fun _ -> let app = mkApp (f, args) in @@ -398,58 +377,69 @@ and induction (f, args) (env, sigma, opts) : tactical option = (* Compute bindings and goals for each case. 
*) let zooms = List.map (zoom_lambda_names env zoom_but) ind.cs in let names = List.map (fun (_, _, names) -> names) zooms in - let goals = List.map (fun (env, trm, _) -> first_pass env sigma opts trm) zooms in - let ind = Compose ([ Induction (env, ind_var, names) ], goals) in - if reverts == [] then Some ind else dot (Revert reverts) ind - + let finish goal reverts = + let t = Induction (env, ind_var, names) in + let sigma, next = next_context_goals env sigma t goal in + let sigma, rests = map2_state (fun (_, trm, _) (env, goal') sigma -> + first_pass env sigma opts (reverts @ show_tactic_string sigma t :: prev) + goal' trm) zooms next sigma in + Compose ([ t ], rests) in + if reverts == [] + then + Some (sigma, finish goal []) + else + let t1 = Revert reverts in + let sigma, next = next_context_goals env sigma t1 goal in + let goal = snd (List.hd next) in + Some (sigma, Compose ([ t1 ], [ finish goal [show_tactic_string sigma t1] ])) + (* Choose left proof to construct or. *) -and left (f, args) (env, sigma, opts) : tactical option = +and left (f, args) (env, sigma, opts, prev, goal) : tactical state option = dest_or_introl (mkApp (f, args)) >>= fun args -> - dot (Left) (first_pass env sigma opts args.ltrm) + Some (one_subgoal env sigma opts prev goal Left args.ltrm) (* Choose right proof to construct or. *) -and right (f, args) (env, sigma, opts) : tactical option = +and right (f, args) (env, sigma, opts, prev, goal) : tactical state option = dest_or_intror (mkApp (f, args)) >>= fun args -> - dot (Right) (first_pass env sigma opts args.rtrm) + Some (one_subgoal env sigma opts prev goal Right args.rtrm) (* Branch two goals as arguments to conj. 
*) -and split (f, args) (env, sigma, opts) : tactical option = +and split (f, args) (env, sigma, opts, prev, goal) : tactical state option = dest_conj (mkApp (f, args)) >>= fun args -> - let lhs = first_pass env sigma opts args.ltrm in - let rhs = first_pass env sigma opts args.rtrm in - Some (Compose ([ Split ], [ lhs ; rhs ])) + Some (many_subgoals env sigma opts prev goal Split [ args.ltrm ; args.rtrm ]) (* Converts "apply eq_refl." into "reflexivity." *) -and reflexivity (f, args) _ : tactical option = +and reflexivity (f, args) (_, sigma, _, _, _) : tactical state option = dest_eq_refl_opt (mkApp (f, args)) >>= fun _ -> - qed Reflexivity + qed sigma Reflexivity (* Transform x = y to y = x. *) -and symmetry (f, args) (env, sigma, opts) : tactical option = +and symmetry (f, args) (env, sigma, opts, prev, goal) : tactical state option = guard (equal f eq_sym) >>= fun _ -> let sym = dest_eq_sym (mkApp (f, args)) in - dot (Symmetry) (first_pass env sigma opts sym.eq_proof) + Some (one_subgoal env sigma opts prev goal Symmetry sym.eq_proof) (* Provide evidence for dependent pair. *) -and exists (f, args) (env, sigma, opts) : tactical option = +and exists (f, args) (env, sigma, opts, prev, goal) : tactical state option = guard (equal f Sigmautils.existT) >>= fun _ -> let exT = Sigmautils.dest_existT (mkApp (f, args)) in - dot (Exists (env, exT.index)) (first_pass env sigma opts exT.unpacked) + Some (one_subgoal env sigma opts prev goal (Exists (env, exT.index)) exT.unpacked) (* Value must be a rewrite on a hypothesis in context. 
*) -and rewrite_in (_, valu, _, body) (env, sigma, opts) : tactical option = +and rewrite_in (_, valu, _, body) (env, sigma, opts, prev, goal) : tactical state option = let valu = Reduction.whd_betaiota env valu in try_app valu >>= fun (f, args) -> dest_rewrite (mkApp (f, args)) >>= fun rewr -> try_rel rewr.px >>= fun idx -> guard (noccurn (idx + 1) body) >>= fun _ -> - let n, t = rel_name_type (lookup_rel idx env) in - let env' = push_local (n, t) env in - dot (RewriteIn (env, rewr.eq, rewr.px, rewr.left)) - (first_pass env' sigma opts body) + let t = RewriteIn (env, rewr.eq, rewr.px, rewr.left) in + let n, typ = rel_name_type (lookup_rel idx env) in + let env' = push_local (n, typ) env in + let sigma, rest = first_pass env' sigma opts (show_tactic_string sigma t :: prev) goal body in + dot sigma t rest (* Value must be an application with last argument in context. *) -and apply_in (n, valu, typ, body) (env, sigma, opts) : tactical option = +and apply_in (n, valu, _, body) (env, sigma, opts, prev, goal) : tactical state option = let valu = Reduction.whd_betaiota env valu in try_app valu >>= fun (f, args) -> let len = Array.length args in @@ -457,36 +447,46 @@ and apply_in (n, valu, typ, body) (env, sigma, opts) : tactical option = try_rel hyp >>= fun idx -> (* let H' := F H *) guard (noccurn (idx + 1) body) >>= fun _ -> (* H does not occur in body *) guard (not (noccurn 1 body)) >>= fun _ -> (* new binding DOES occur *) - let n, t = rel_name_type (lookup_rel idx env) in (* "H" *) - let env' = push_local (n, t) env in (* change type of "H" *) + let n, typ = rel_name_type (lookup_rel idx env) in (* "H" *) + let env' = push_local (n, typ) env in (* change type of "H" *) let prf = mkApp (f, Array.sub args 0 (len - 1)) in - (* let H2 := f H1 := H2 ... *) + let t = ApplyIn (env, prf, hyp) in + (* let A := f B C D ... 
in A *) let apply_binding app_in (_, sigma) = try_app body >>= fun (f, args) -> try_rel f >>= fun i -> guard (i == 1) >>= fun _ -> - let args' = List.map (first_pass env' sigma opts) (Array.to_list args) in - Some (Compose ([ ApplyIn (env, prf, hyp) ], first_pass env' sigma opts f :: args')) + let sigma, f' = first_pass env' sigma opts (show_tactic_string sigma t :: prev) goal f in + let sigma, args' = map_state (fun trm sigma -> + first_pass env' sigma opts (show_tactic_string sigma t :: prev) goal trm) + (Array.to_list args) sigma in + Some (sigma, Compose ([ t ], f' :: args')) in (* all other cases *) - let default app_in (_, sigma) = dot (ApplyIn (env, prf, hyp)) - (first_pass env' sigma opts body) + let default app_in (_, sigma) = + let sigma, rest = first_pass env' sigma opts (show_tactic_string sigma t :: prev) goal body in + dot sigma t rest in (apply_binding <|> default) () (env', sigma) -(* Last resort decompile let-in as a pose. *) -and pose (n, valu, t, body) (env, sigma, opts) : tactical option = +(* Last resort decompile let-in as a pose. *) +and pose (n, valu, typ, body) (env, sigma, opts, prev, goal) : tactical state option = let n' = fresh_name env n in - let env' = push_let_in (Name n', valu, t) env in - let decomp_body = first_pass env' sigma opts body in + let env' = push_let_in (Name n', valu, typ) env in (* If the binding is NEVER used, just skip this. *) - if noccurn 1 body then Some decomp_body - else dot (Pose (env, valu, n')) (decomp_body) - + if noccurn 1 body + then + let sigma, decomp_body = first_pass env' sigma opts prev goal body in + Some (sigma, decomp_body) + else + let t = Pose (env, valu, n') in + let sigma, decomp_body = first_pass env' sigma opts (show_tactic_string sigma t :: prev) goal body in + dot sigma t (decomp_body) + (* Decompile a term into its equivalent tactic list. *) -let tac_from_term env sigma get_hints trm : tactical = - (* Perform second pass to revise greedy tactic list. 
*) - semicolons sigma (simpl sigma (rewrite_implicit sigma (intros_revert (first_pass env sigma get_hints trm)))) +let tac_from_term env sigma get_hints trm : tactical state = + let sigma, goal = Inference.infer_type env sigma trm in + first_pass env sigma get_hints [] goal trm (* Generate indentation space before bullet. *) let indent level = @@ -513,7 +513,7 @@ let pp_concat sep xs = | x :: [] -> [ x ] | x :: xs' -> x :: sep :: aux xs' in seq (aux xs) - + (* Show tactical, composed of many tactics. *) let rec show_tactical sigma (level : int) (bulletted : bool) (t : tactical) : Pp.t = let full_indent = if bulletted @@ -527,7 +527,6 @@ let rec show_tactical sigma (level : int) (bulletted : bool) (t : tactical) : Pp tac_s ++ match goals with | [ goal ] -> show_tactical sigma level false goal | goals -> seq (List.mapi f goals) - + (* Represent tactics as a string. *) -let tac_to_string sigma = show_tactical sigma 0 false - +let tac_to_string sigma = show_tactical sigma 0 false \ No newline at end of file diff --git a/src/coq/decompiler/decompiler.mli b/src/coq/decompiler/decompiler.mli index 5ebce7e..8b721cc 100644 --- a/src/coq/decompiler/decompiler.mli +++ b/src/coq/decompiler/decompiler.mli @@ -42,9 +42,10 @@ val parse_tac_str : string -> unit Proofview.tactic Each proofview tactic in the list must be paired with their string representation. *) val tac_from_term : env -> evar_map -> - (env -> evar_map -> constr -> (unit Proofview.tactic * string) list state) -> + (env -> evar_map -> string list -> constr -> + (unit Proofview.tactic * string) list state) -> constr -> - tactical + tactical state (* Given a decompiled Ltac script, return its string representation. 
*) val tac_to_string : evar_map -> tactical -> Pp.t diff --git a/src/coq/decompiler/dune b/src/coq/decompiler/dune new file mode 100644 index 0000000..5ca0924 --- /dev/null +++ b/src/coq/decompiler/dune @@ -0,0 +1,13 @@ +(library + (name decompiler) + (public_name coq-plugin-lib.decompiler) + (libraries + coq-plugin-lib.hofimpls + coq-plugin-lib.devutils + coq-plugin-lib.constants + coq-plugin-lib.inductive + coq-plugin-lib.contexts + coq.engine + coq.kernel + coq.plugins.ltac) + (wrapped false)) diff --git a/src/coq/devutils/dune b/src/coq/devutils/dune new file mode 100644 index 0000000..173564a --- /dev/null +++ b/src/coq/devutils/dune @@ -0,0 +1,10 @@ +(library + (name devutils) + (public_name coq-plugin-lib.devutils) + (libraries + coq-plugin-lib.contexts + coq-plugin-lib.utilities + coq.printing + coq.engine + coq.kernel) + (wrapped false)) diff --git a/src/coq/devutils/printing.mli b/src/coq/devutils/printing.mli index d4dc67d..fa62803 100644 --- a/src/coq/devutils/printing.mli +++ b/src/coq/devutils/printing.mli @@ -8,7 +8,7 @@ open Evd (* --- Coq terms --- *) (* Pretty-print a `global_reference` with fancy `constr` coloring. 
*) -val pr_global_as_constr : global_reference -> Pp.t +val pr_global_as_constr : GlobRef.t -> Pp.t (* Gets a name as a string *) val name_as_string : Name.t -> string diff --git a/src/coq/dune b/src/coq/dune new file mode 100644 index 0000000..567bfbb --- /dev/null +++ b/src/coq/dune @@ -0,0 +1,21 @@ +(library + (name coq) + (public_name coq-plugin-lib.coq) + (libraries + coq-plugin-lib.constants + coq-plugin-lib.decompiler + coq-plugin-lib.devutils + coq-plugin-lib.envs + coq-plugin-lib.state + coq-plugin-lib.contexts + coq-plugin-lib.inference + coq-plugin-lib.typesandequality + coq-plugin-lib.transformation + coq-plugin-lib.hofs + coq-plugin-lib.hofimpls + coq-plugin-lib.inductive + coq-plugin-lib.representationutils + coq-plugin-lib.termutils + coq-plugin-lib.tactok + coq.kernel) + (wrapped false)) diff --git a/src/coq/logicutils/contexts/dune b/src/coq/logicutils/contexts/dune new file mode 100644 index 0000000..a6153d4 --- /dev/null +++ b/src/coq/logicutils/contexts/dune @@ -0,0 +1,14 @@ +(library + (name contexts) + (public_name coq-plugin-lib.contexts) + (libraries + coq-plugin-lib.inductive + coq-plugin-lib.state + coq-plugin-lib.inference + coq-plugin-lib.representationutils + coq-plugin-lib.utilities + coq.kernel + coq.engine + coq.interp + coq.plugins.ltac) + (wrapped false)) diff --git a/src/coq/logicutils/contexts/contextutils.ml b/src/coq/logicutils/contexts/envs/contextutils.ml similarity index 98% rename from src/coq/logicutils/contexts/contextutils.ml rename to src/coq/logicutils/contexts/envs/contextutils.ml index 4a63df0..d44c7cc 100644 --- a/src/coq/logicutils/contexts/contextutils.ml +++ b/src/coq/logicutils/contexts/envs/contextutils.ml @@ -194,7 +194,7 @@ let deanonymize_context env sigma ctxt = (* * Inductive types *) -let bindings_for_inductive env mutind_body ind_bodies : CRD.t list = +let bindings_for_inductive env mutind_body ind_bodies : rel_declaration list = Array.to_list (Array.mapi (fun i ind_body -> @@ -206,7 +206,7 @@ let 
bindings_for_inductive env mutind_body ind_bodies : CRD.t list = (* * Fixpoints *) -let bindings_for_fix (names : name array) (typs : types array) : CRD.t list = +let bindings_for_fix (names : name array) (typs : types array) : rel_declaration list = Array.to_list (CArray.map2_i (fun i name typ -> CRD.LocalAssum (name, Vars.lift i typ)) diff --git a/src/coq/logicutils/contexts/contextutils.mli b/src/coq/logicutils/contexts/envs/contextutils.mli similarity index 84% rename from src/coq/logicutils/contexts/contextutils.mli rename to src/coq/logicutils/contexts/envs/contextutils.mli index 45c5ac3..8e6b292 100644 --- a/src/coq/logicutils/contexts/contextutils.mli +++ b/src/coq/logicutils/contexts/envs/contextutils.mli @@ -78,13 +78,13 @@ val named_type : ('constr, 'types) CND.pt -> 'types * Map over a rel context with environment kept in synch *) val map_rel_context : - env -> (env -> CRD.t -> 'a) -> Context.Rel.t -> 'a list + env -> (env -> rel_declaration -> 'a) -> rel_context -> 'a list (* * Map over a named context with environment kept in synch *) val map_named_context : - env -> (env -> CND.t -> 'a) -> Context.Named.t -> 'a list + env -> (env -> named_declaration -> 'a) -> named_context -> 'a list (* --- Binding in contexts --- *) @@ -92,21 +92,21 @@ val map_named_context : * Bind all local declarations in the relative context onto the body term as * products, substituting away (i.e., zeta-reducing) any local definitions. *) -val smash_prod_assum : Context.Rel.t -> types -> types -val smash_lam_assum : Context.Rel.t -> constr -> constr +val smash_prod_assum : rel_context -> types -> types +val smash_lam_assum : rel_context -> constr -> constr (* * Decompose the first n product bindings, zeta-reducing let bindings to reveal * further product/lambda bindings when necessary. 
*) -val decompose_prod_n_zeta : int -> types -> Context.Rel.t * types -val decompose_lam_n_zeta : int -> constr -> Context.Rel.t * constr +val decompose_prod_n_zeta : int -> types -> rel_context * types +val decompose_lam_n_zeta : int -> constr -> rel_context * constr (* * Reconstruct local bindings around a term *) -val recompose_prod_assum : Context.Rel.t -> types -> types -val recompose_lam_assum : Context.Rel.t -> types -> types +val recompose_prod_assum : rel_context -> types -> types +val recompose_lam_assum : rel_context -> types -> types (* --- Names in contexts --- *) @@ -130,10 +130,10 @@ val deanonymize_context : *) val bindings_for_inductive : - env -> mutual_inductive_body -> one_inductive_body array -> CRD.t list + env -> mutual_inductive_body -> one_inductive_body array -> rel_declaration list val bindings_for_fix : - name array -> types array -> CRD.t list + Name.t array -> types array -> rel_declaration list (* --- Combining contexts --- *) @@ -144,4 +144,4 @@ val bindings_for_fix : * external indices inside the now-inner context must be shifted to pass over * the now-outer context. 
*) -val context_app : Context.Rel.t -> Context.Rel.t -> Context.Rel.t +val context_app : rel_context -> rel_context -> rel_context diff --git a/src/coq/logicutils/contexts/envs/dune b/src/coq/logicutils/contexts/envs/dune new file mode 100644 index 0000000..9ccce56 --- /dev/null +++ b/src/coq/logicutils/contexts/envs/dune @@ -0,0 +1,13 @@ +(library + (name envs) + (public_name coq-plugin-lib.envs) + (libraries + coq-plugin-lib.state + coq-plugin-lib.inference + coq-plugin-lib.representationutils + coq-plugin-lib.utilities + coq.kernel + coq.engine + coq.interp + coq.plugins.ltac) + (wrapped false)) diff --git a/src/coq/logicutils/contexts/envutils.ml b/src/coq/logicutils/contexts/envs/envutils.ml similarity index 96% rename from src/coq/logicutils/contexts/envutils.ml rename to src/coq/logicutils/contexts/envs/envutils.ml index 32c3a22..3b4ebb5 100644 --- a/src/coq/logicutils/contexts/envutils.ml +++ b/src/coq/logicutils/contexts/envs/envutils.ml @@ -15,7 +15,7 @@ open Nameutils open Stateutils (* Look up all indexes from is in env *) -let lookup_rels (is : int list) (env : env) : CRD.t list = +let lookup_rels (is : int list) (env : env) : rel_declaration list = List.map (fun i -> lookup_rel i env) is (* Return a list of all indexes in env, starting with 1 *) @@ -27,7 +27,7 @@ let mk_n_rels n = List.map mkRel (List.rev (from_one_to n)) (* Return a list of all bindings in env, starting with the closest *) -let lookup_all_rels (env : env) : CRD.t list = +let lookup_all_rels (env : env) : rel_declaration list = lookup_rels (all_rel_indexes env) env (* Return a name-type pair from the given rel_declaration. 
*) diff --git a/src/coq/logicutils/contexts/envutils.mli b/src/coq/logicutils/contexts/envs/envutils.mli similarity index 81% rename from src/coq/logicutils/contexts/envutils.mli rename to src/coq/logicutils/contexts/envs/envutils.mli index c25c8c5..f9817b2 100644 --- a/src/coq/logicutils/contexts/envutils.mli +++ b/src/coq/logicutils/contexts/envs/envutils.mli @@ -6,15 +6,14 @@ open Environ open Constr open Names -open Contextutils open Evd open Stateutils (* Look up all indexes from a list in an environment *) -val lookup_rels : int list -> env -> CRD.t list +val lookup_rels : int list -> env -> rel_declaration list (* Return a list of all bindings in an environment, starting with the closest *) -val lookup_all_rels : env -> CRD.t list +val lookup_all_rels : env -> rel_declaration list (* Return a list of all indexes in an environment, starting with 1 *) val all_rel_indexes : env -> int list @@ -23,18 +22,18 @@ val all_rel_indexes : env -> int list val mk_n_rels : int -> types list (* Return a name-type pair from the given rel_declaration. 
*) -val rel_name_type : CRD.t -> Name.t * types +val rel_name_type : rel_declaration -> Name.t * types (* * Push to an environment *) -val push_local : (name * types) -> env -> env -val push_let_in : (name * types * types) -> env -> env +val push_local : (Name.t * types) -> env -> env +val push_let_in : (Name.t * types * types) -> env -> env (* * Lookup from an environment *) -val lookup_pop : int -> env -> (env * CRD.t list) +val lookup_pop : int -> env -> (env * rel_declaration list) val lookup_definition : env -> types -> types val unwrap_definition : env -> types -> types diff --git a/src/coq/logicutils/contexts/modutils.mli b/src/coq/logicutils/contexts/modutils.mli index 96ec6bd..b1becf4 100644 --- a/src/coq/logicutils/contexts/modutils.mli +++ b/src/coq/logicutils/contexts/modutils.mli @@ -24,7 +24,7 @@ val declare_module_structure : ?params:(Constrexpr.module_ast Declaremods.module * * Elimination schemes (e.g., `Ind_rect`) are filtered out from the definitions. *) -val fold_module_structure_by_decl : 'a -> ('a -> constant -> constant_body -> 'a) -> ('a -> inductive -> Inductive.mind_specif -> 'a) -> module_body -> 'a +val fold_module_structure_by_decl : 'a -> ('a -> Constant.t -> constant_body -> 'a) -> ('a -> inductive -> Inductive.mind_specif -> 'a) -> module_body -> 'a (* * Same as `fold_module_structure_by_decl` except a single step function diff --git a/src/coq/logicutils/contexts/state/dune b/src/coq/logicutils/contexts/state/dune new file mode 100644 index 0000000..7740d54 --- /dev/null +++ b/src/coq/logicutils/contexts/state/dune @@ -0,0 +1,9 @@ +(library + (name state) + (public_name coq-plugin-lib.state) + (libraries + coq-plugin-lib.utilities + coq.kernel + coq.engine + coq.interp) + (wrapped false)) diff --git a/src/coq/logicutils/contexts/stateutils.ml b/src/coq/logicutils/contexts/state/stateutils.ml similarity index 100% rename from src/coq/logicutils/contexts/stateutils.ml rename to src/coq/logicutils/contexts/state/stateutils.ml diff --git 
a/src/coq/logicutils/contexts/stateutils.mli b/src/coq/logicutils/contexts/state/stateutils.mli similarity index 100% rename from src/coq/logicutils/contexts/stateutils.mli rename to src/coq/logicutils/contexts/state/stateutils.mli diff --git a/src/coq/logicutils/hofs/debruijn.ml b/src/coq/logicutils/debruijn/debruijn.ml similarity index 100% rename from src/coq/logicutils/hofs/debruijn.ml rename to src/coq/logicutils/debruijn/debruijn.ml diff --git a/src/coq/logicutils/hofs/debruijn.mli b/src/coq/logicutils/debruijn/debruijn.mli similarity index 100% rename from src/coq/logicutils/hofs/debruijn.mli rename to src/coq/logicutils/debruijn/debruijn.mli diff --git a/src/coq/logicutils/debruijn/dune b/src/coq/logicutils/debruijn/dune new file mode 100644 index 0000000..3e65d0d --- /dev/null +++ b/src/coq/logicutils/debruijn/dune @@ -0,0 +1,11 @@ +(library + (name debruijn) + (public_name coq-plugin-lib.debruijn) + (libraries + coq-plugin-lib.envs + coq-plugin-lib.hofs + coq-plugin-lib.typesandequality + coq-plugin-lib.utilities + coq.kernel + coq.engine) + (wrapped false)) diff --git a/src/coq/logicutils/hofs/dune b/src/coq/logicutils/hofs/dune new file mode 100644 index 0000000..3b4141b --- /dev/null +++ b/src/coq/logicutils/hofs/dune @@ -0,0 +1,11 @@ +(library + (name hofs) + (public_name coq-plugin-lib.hofs) + (libraries + coq-plugin-lib.state + coq-plugin-lib.envs + coq-plugin-lib.typesandequality + coq-plugin-lib.utilities + coq.kernel + coq.engine) + (wrapped false)) diff --git a/src/coq/logicutils/hofs/impls/dune b/src/coq/logicutils/hofs/impls/dune new file mode 100644 index 0000000..d762669 --- /dev/null +++ b/src/coq/logicutils/hofs/impls/dune @@ -0,0 +1,13 @@ +(library + (name hofimpls) + (public_name coq-plugin-lib.hofimpls) + (libraries + coq-plugin-lib.termutils + coq-plugin-lib.constants + coq-plugin-lib.debruijn + coq-plugin-lib.hofs + coq-plugin-lib.typesandequality + coq-plugin-lib.utilities + coq.kernel + coq.engine) + (wrapped false)) diff --git 
a/src/coq/logicutils/hofs/filters.ml b/src/coq/logicutils/hofs/impls/filters.ml similarity index 100% rename from src/coq/logicutils/hofs/filters.ml rename to src/coq/logicutils/hofs/impls/filters.ml diff --git a/src/coq/logicutils/hofs/filters.mli b/src/coq/logicutils/hofs/impls/filters.mli similarity index 100% rename from src/coq/logicutils/hofs/filters.mli rename to src/coq/logicutils/hofs/impls/filters.mli diff --git a/src/coq/logicutils/hofs/hofimpls.ml b/src/coq/logicutils/hofs/impls/hofimpls.ml similarity index 100% rename from src/coq/logicutils/hofs/hofimpls.ml rename to src/coq/logicutils/hofs/impls/hofimpls.ml diff --git a/src/coq/logicutils/hofs/hofimpls.mli b/src/coq/logicutils/hofs/impls/hofimpls.mli similarity index 100% rename from src/coq/logicutils/hofs/hofimpls.mli rename to src/coq/logicutils/hofs/impls/hofimpls.mli diff --git a/src/coq/logicutils/hofs/hypotheses.ml b/src/coq/logicutils/hofs/impls/hypotheses.ml similarity index 100% rename from src/coq/logicutils/hofs/hypotheses.ml rename to src/coq/logicutils/hofs/impls/hypotheses.ml diff --git a/src/coq/logicutils/hofs/hypotheses.mli b/src/coq/logicutils/hofs/impls/hypotheses.mli similarity index 100% rename from src/coq/logicutils/hofs/hypotheses.mli rename to src/coq/logicutils/hofs/impls/hypotheses.mli diff --git a/src/coq/logicutils/hofs/reducers.ml b/src/coq/logicutils/hofs/impls/reducers.ml similarity index 100% rename from src/coq/logicutils/hofs/reducers.ml rename to src/coq/logicutils/hofs/impls/reducers.ml diff --git a/src/coq/logicutils/hofs/reducers.mli b/src/coq/logicutils/hofs/impls/reducers.mli similarity index 100% rename from src/coq/logicutils/hofs/reducers.mli rename to src/coq/logicutils/hofs/impls/reducers.mli diff --git a/src/coq/logicutils/hofs/substitution.ml b/src/coq/logicutils/hofs/impls/substitution.ml similarity index 100% rename from src/coq/logicutils/hofs/substitution.ml rename to src/coq/logicutils/hofs/impls/substitution.ml diff --git 
a/src/coq/logicutils/hofs/substitution.mli b/src/coq/logicutils/hofs/impls/substitution.mli similarity index 100% rename from src/coq/logicutils/hofs/substitution.mli rename to src/coq/logicutils/hofs/impls/substitution.mli diff --git a/src/coq/logicutils/hofs/typehofs.ml b/src/coq/logicutils/hofs/impls/typehofs.ml similarity index 100% rename from src/coq/logicutils/hofs/typehofs.ml rename to src/coq/logicutils/hofs/impls/typehofs.ml diff --git a/src/coq/logicutils/hofs/typehofs.mli b/src/coq/logicutils/hofs/impls/typehofs.mli similarity index 100% rename from src/coq/logicutils/hofs/typehofs.mli rename to src/coq/logicutils/hofs/impls/typehofs.mli diff --git a/src/coq/logicutils/hofs/zooming.ml b/src/coq/logicutils/hofs/impls/zooming.ml similarity index 100% rename from src/coq/logicutils/hofs/zooming.ml rename to src/coq/logicutils/hofs/impls/zooming.ml diff --git a/src/coq/logicutils/hofs/zooming.mli b/src/coq/logicutils/hofs/impls/zooming.mli similarity index 100% rename from src/coq/logicutils/hofs/zooming.mli rename to src/coq/logicutils/hofs/impls/zooming.mli diff --git a/src/coq/logicutils/inductive/dune b/src/coq/logicutils/inductive/dune new file mode 100644 index 0000000..750663b --- /dev/null +++ b/src/coq/logicutils/inductive/dune @@ -0,0 +1,11 @@ +(library + (name inductive) + (public_name coq-plugin-lib.inductive) + (libraries + coq-plugin-lib.hofimpls + coq-plugin-lib.debruijn + coq-plugin-lib.termutils + coq-plugin-lib.utilities + coq.kernel + coq.engine) + (wrapped false)) diff --git a/src/coq/logicutils/inductive/indutils.ml b/src/coq/logicutils/inductive/indutils.ml index 751b2ef..c8c1b5b 100644 --- a/src/coq/logicutils/inductive/indutils.ml +++ b/src/coq/logicutils/inductive/indutils.ml @@ -28,7 +28,7 @@ let check_inductive_supported mutind_body : unit = * Check if a constant is an inductive elminator * If so, return the inductive type *) -let inductive_of_elim (env : env) (pc : pconstant) : mutual_inductive option = +let inductive_of_elim 
(env : env) (pc : pconstant) : MutInd.t option = let (c, u) = pc in let kn = Constant.canonical c in let (modpath, dirpath, label) = KerName.repr kn in @@ -43,7 +43,7 @@ let inductive_of_elim (env : env) (pc : pconstant) : mutual_inductive option = let ind_label_string = String.sub label_string 0 split_index in let ind_label = Label.of_id (Id.of_string_soft ind_label_string) in let ind_name = MutInd.make1 (KerName.make modpath dirpath ind_label) in - lookup_mind ind_name env; + let _ = lookup_mind ind_name env in Some ind_name else if not is_rev then diff --git a/src/coq/logicutils/inductive/indutils.mli b/src/coq/logicutils/inductive/indutils.mli index fbffa54..502e99b 100644 --- a/src/coq/logicutils/inductive/indutils.mli +++ b/src/coq/logicutils/inductive/indutils.mli @@ -29,7 +29,7 @@ val is_elim : env -> types -> bool (* * Get an inductive type from an eliminator, if possible *) -val inductive_of_elim : env -> pconstant -> mutual_inductive option +val inductive_of_elim : env -> pconstant -> MutInd.t option (* * Lookup the eliminator over the type sort diff --git a/src/coq/logicutils/inference/dune b/src/coq/logicutils/inference/dune new file mode 100644 index 0000000..65bc89a --- /dev/null +++ b/src/coq/logicutils/inference/dune @@ -0,0 +1,8 @@ +(library + (name inference) + (public_name coq-plugin-lib.inference) + (libraries + coq.pretyping + coq.kernel + coq.engine) + (wrapped false)) diff --git a/src/coq/logicutils/typesandequality/inference.ml b/src/coq/logicutils/inference/inference.ml similarity index 100% rename from src/coq/logicutils/typesandequality/inference.ml rename to src/coq/logicutils/inference/inference.ml diff --git a/src/coq/logicutils/typesandequality/inference.mli b/src/coq/logicutils/inference/inference.mli similarity index 100% rename from src/coq/logicutils/typesandequality/inference.mli rename to src/coq/logicutils/inference/inference.mli diff --git a/src/coq/logicutils/transformation/dune b/src/coq/logicutils/transformation/dune new 
file mode 100644 index 0000000..ebd6a03 --- /dev/null +++ b/src/coq/logicutils/transformation/dune @@ -0,0 +1,10 @@ +(library + (name transformation) + (public_name coq-plugin-lib.transformation) + (libraries + coq-plugin-lib.hofimpls + coq-plugin-lib.inductive + coq-plugin-lib.contexts + coq.kernel + coq.engine) + (wrapped false)) diff --git a/src/coq/logicutils/transformation/transform.ml b/src/coq/logicutils/transformation/transform.ml index c902c67..1a786ad 100644 --- a/src/coq/logicutils/transformation/transform.ml +++ b/src/coq/logicutils/transformation/transform.ml @@ -134,6 +134,7 @@ let transform_module_structure ?(init=const Globnames.Refmap.empty) ?(opaques=Gl in assert (List.is_empty mod_arity); (* Functors are not yet supported *) let transform_module_element mod_path' subst (label, body) = + let open GlobRef in Feedback.msg_info (Pp.(str "Transforming " ++ Label.print label)); let ident = Label.to_id label in let tr_constr env sigma = subst_globals subst %> tr_constr env sigma in diff --git a/src/coq/logicutils/typesandequality/convertibility.ml b/src/coq/logicutils/typesandequality/convertibility.ml index c1f6b57..8b1a854 100644 --- a/src/coq/logicutils/typesandequality/convertibility.ml +++ b/src/coq/logicutils/typesandequality/convertibility.ml @@ -15,7 +15,9 @@ open Inference let convertible env sigma trm1 trm2 : evar_map * bool = let etrm1 = EConstr.of_constr trm1 in let etrm2 = EConstr.of_constr trm2 in - Reductionops.infer_conv env sigma etrm1 etrm2 + match Reductionops.infer_conv env sigma etrm1 etrm2 with + | Some sigma -> sigma, true + | None -> sigma, false (* * Checks whether the conclusions of two dependent types are convertible, diff --git a/src/coq/logicutils/typesandequality/dune b/src/coq/logicutils/typesandequality/dune new file mode 100644 index 0000000..22e8fce --- /dev/null +++ b/src/coq/logicutils/typesandequality/dune @@ -0,0 +1,9 @@ +(library + (name typesandequality) + (public_name coq-plugin-lib.typesandequality) + (libraries 
+ coq-plugin-lib.inference + coq-plugin-lib.envs + coq.kernel + coq.engine) + (wrapped false)) diff --git a/src/coq/representationutils/defutils.ml b/src/coq/representationutils/defutils.ml index 66423fb..a18b343 100644 --- a/src/coq/representationutils/defutils.ml +++ b/src/coq/representationutils/defutils.ml @@ -49,7 +49,7 @@ let edeclare ident (_, poly, _ as k) ~opaque sigma udecl body tyopt imps hook re let body = to_constr sigma body in let tyopt = Option.map (to_constr sigma) tyopt in let uvars_fold uvars c = - Univ.LSet.union uvars (Univops.universes_of_constr env c) in + Univ.LSet.union uvars (Univops.universes_of_constr c) in let uvars = List.fold_left uvars_fold Univ.LSet.empty (Option.List.cons tyopt [body]) in let sigma = Evd.restrict_universe_context sigma uvars in @@ -61,7 +61,7 @@ let edeclare ident (_, poly, _ as k) ~opaque sigma udecl body tyopt imps hook re (* Define a new Coq term *) let define_term ?typ (n : Id.t) (evm : evar_map) (trm : types) (refresh : bool) = let k = (Global, Flags.is_universe_polymorphism(), Definition) in - let udecl = Univdecls.default_univ_decl in + let udecl = UState.default_univ_decl in let nohook = Lemmas.mk_hook (fun _ x -> x) in let etrm = EConstr.of_constr trm in let etyp = Option.map EConstr.of_constr typ in @@ -70,7 +70,7 @@ let define_term ?typ (n : Id.t) (evm : evar_map) (trm : types) (refresh : bool) (* Define a Canonical Structure *) let define_canonical ?typ (n : Id.t) (evm : evar_map) (trm : types) (refresh : bool) = let k = (Global, Flags.is_universe_polymorphism (), CanonicalStructure) in - let udecl = Univdecls.default_univ_decl in + let udecl = UState.default_univ_decl in let hook = Lemmas.mk_hook (fun _ x -> declare_canonical_structure x; x) in let etrm = EConstr.of_constr trm in let etyp = Option.map EConstr.of_constr typ in @@ -99,7 +99,7 @@ let expr_of_global (g : global_reference) : constr_expr = (* Convert a term into a global reference with universes (or raise Not_found) *) let pglobal_of_constr 
term = match Constr.kind term with - | Const (const, univs) -> ConstRef const, univs + | Const (const, univs) -> Globnames.ConstRef const, univs | Ind (ind, univs) -> IndRef ind, univs | Construct (cons, univs) -> ConstructRef cons, univs | Var id -> VarRef id, Univ.Instance.empty @@ -108,7 +108,7 @@ let pglobal_of_constr term = (* Convert a global reference with universes into a term *) let constr_of_pglobal (glob, univs) = match glob with - | ConstRef const -> mkConstU (const, univs) + | Globnames.ConstRef const -> mkConstU (const, univs) | IndRef ind -> mkIndU (ind, univs) | ConstructRef cons -> mkConstructU (cons, univs) | VarRef id -> mkVar id diff --git a/src/coq/representationutils/dune b/src/coq/representationutils/dune new file mode 100644 index 0000000..aae2323 --- /dev/null +++ b/src/coq/representationutils/dune @@ -0,0 +1,9 @@ +(library + (name representationutils) + (public_name coq-plugin-lib.representationutils) + (libraries + coq.vernac + coq.interp + coq.kernel + coq.engine) + (wrapped false)) diff --git a/src/coq/representationutils/nameutils.ml b/src/coq/representationutils/nameutils.ml index 58d49e8..9432148 100644 --- a/src/coq/representationutils/nameutils.ml +++ b/src/coq/representationutils/nameutils.ml @@ -26,15 +26,4 @@ let expect_name = function | Name n -> n | Anonymous -> failwith "Unexpected Anonymous Name.t." 
- -(* Turn an identifier into an external (i.e., surface-level) reference *) -let reference_of_ident id = - Libnames.Ident id |> CAst.make -(* Turn a name into an optional external (i.e., surface-level) reference *) -let reference_of_name = - ident_of_name %> Option.map reference_of_ident - -(* Convert an external reference into a qualid *) -let qualid_of_reference = - Libnames.qualid_of_reference %> CAst.with_val identity diff --git a/src/coq/representationutils/nameutils.mli b/src/coq/representationutils/nameutils.mli index 3e958bd..4dfc713 100644 --- a/src/coq/representationutils/nameutils.mli +++ b/src/coq/representationutils/nameutils.mli @@ -17,11 +17,3 @@ val ident_of_name : Name.t -> Id.t option (* Unwrap a Name.t expecting an Id.t. Fails if anonymous. *) val expect_name : Name.t -> Id.t -(* Turn an identifier into an external (i.e., surface-level) reference *) -val reference_of_ident : Id.t -> Libnames.reference - -(* Turn a name into an optional external (i.e., surface-level) reference *) -val reference_of_name : Name.t -> Libnames.reference option - -(* Convert an external reference into a qualid *) -val qualid_of_reference : Libnames.reference -> Libnames.qualid diff --git a/src/coq/tactok/agent_utils.py b/src/coq/tactok/agent_utils.py new file mode 100644 index 0000000..8086cb7 --- /dev/null +++ b/src/coq/tactok/agent_utils.py @@ -0,0 +1,68 @@ +import torch +import argparse +from gallina import GallinaTermParser +from models.prover import Prover +import string + +term_parser = GallinaTermParser(caching=True) +sexp_cache = SexpCache('../../../../sexp_cache', readonly=True) # include this outside top directory + +def get_opts(): + # TODO: need to update opts as we go + parser = argparse.ArgumentParser() + parser.add_argument('--path', type=str, default='model.pth') + parser.add_argument('--beam_width', type=int, default=10) + parser.add_argument('--tac_grammar', type=str, default='tactics.ebnf') + parser.add_argument('--term_embedding_dim', type=int, 
default=256) + parser.add_argument('--embedding_dim', type=int, default=256, help='dimension of the grammar embeddings') + parser.add_argument('--symbol_dim', type=int, default=256, help='dimension of the terminal/nonterminal symbol embeddings') + parser.add_argument('--hidden_dim', type=int, default=256, help='dimension of the LSTM controller') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--num_tactics', type=int, default=15025) + parser.add_argument('--tac_vocab_file', type=str, default='token_vocab.pickle') + parser.add_argument('--cutoff_len', type=int, default=30) + opts = parser.parse_args() + opts.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + return opts + +def import_model(): + opts = get_opts() + model = Prover(opts) + if opts.device.type == 'cpu': + checkpoint = torch.load(opts.path, map_location='cpu') + else: + checkpoint = torch.load(opts.path) + model.load_state_dict(checkpoint['state_dict']) + model.to(opts.device) + return model + +def filter_env(env): + filtered_env = [] + for const in [const for const in env['constants'] if const['qualid'].startswith('SerTop')][-10:]: + ast = sexp_cache[const['sexp']] + filtered_env.append({'qualid': const['qualid'], 'ast': term_parser.parse(ast)}) + return filtered_env + +def parse_goal(g): + goal = {'id': g['id'], 'text': g['type'], 'ast': term_parser.parse(g['sexp'])} + local_context = [] + for i, h in enumerate(g['hypotheses']): + for ident in h['idents']: + local_context.append({'ident': ident, 'text': h['type'], 'ast': term_parser.parse(h['sexp'])}) + return local_context, goal['ast'] + +rem_punc = string.punctuation.replace('\'','').replace('_', '') +table = str.maketrans('', '', rem_punc) + +def tokenize_text(raw_text): + without_punc = raw_text.translate(table) + words = without_punc.split() + return words + +def parse_script(script): + prev_seq = [] + for tac in script: + tac_words = tokenize_text(tac) + prev_seq += tac_words + + return prev_seq \ 
No newline at end of file diff --git a/src/coq/tactok/coq_gym.yml b/src/coq/tactok/coq_gym.yml new file mode 100644 index 0000000..c5cc59d --- /dev/null +++ b/src/coq/tactok/coq_gym.yml @@ -0,0 +1,16 @@ +name: coq_gym +channels: + - defaults +dependencies: + - numpy=1.16.2 + - numpy-base=1.16.2 + - python=3.7.1 + - pip=19.0.3 + - pip: + - lark-parser==0.6.5 + - lmdb==0.94 + - pandas==0.24.2 + - pexpect==4.6.0 + - progressbar2==3.39.3 + - sexpdata==0.0.3 + - torch \ No newline at end of file diff --git a/src/coq/tactok/dune b/src/coq/tactok/dune new file mode 100644 index 0000000..90c6d67 --- /dev/null +++ b/src/coq/tactok/dune @@ -0,0 +1,16 @@ +(library + (name tactok) + (public_name coq-plugin-lib.tactok) + (libraries + coq-serapi.serlib + lymp + coq-plugin-lib.hofimpls + coq-plugin-lib.devutils + coq-plugin-lib.constants + coq-plugin-lib.inductive + coq-plugin-lib.contexts + coq-plugin-lib.decompiler + coq.engine + coq.kernel + coq.plugins.ltac) + (wrapped false)) \ No newline at end of file diff --git a/src/coq/tactok/gallina.py b/src/coq/tactok/gallina.py new file mode 100644 index 0000000..9acc5e6 --- /dev/null +++ b/src/coq/tactok/gallina.py @@ -0,0 +1,104 @@ +# Utilities for reconstructing Gallina terms from their serialized S-expressions in CoqGym +from io import StringIO +from vernac_types import Constr__constr +from lark import Lark, Transformer, Visitor, Discard +from lark.lexer import Token +from lark.tree import Tree +from lark.tree import pydot__tree_to_png +import logging +logging.basicConfig(level=logging.DEBUG) +from collections import defaultdict +import re +import pdb + + +def traverse_postorder(node, callback): + for c in node.children: + if isinstance(c, Tree): + traverse_postorder(c, callback) + callback(node) + + +class GallinaTermParser: + + def __init__(self, caching=True): + self.caching = caching + t = Constr__constr() + self.grammar = t.to_ebnf(recursive=True) + ''' + %import common.STRING_INNER + %import common.ESCAPED_STRING + %import 
common.SIGNED_INT + %import common.WS + %ignore WS + ''' + self.parser = Lark(StringIO(self.grammar), start='constr__constr', parser='lalr') + if caching: + self.cache = {} + + + def parse_no_cache(self, term_str): + ast = self.parser.parse(term_str) + + ast.quantified_idents = set() + + def get_quantified_idents(node): + if node.data == 'constructor_prod' and node.children != [] and node.children[0].data == 'constructor_name': + ident = node.children[0].children[0].value + if ident.startswith('"') and ident.endswith('"'): + ident = ident[1:-1] + ast.quantified_idents.add(ident) + + traverse_postorder(ast, get_quantified_idents) + ast.quantified_idents = list(ast.quantified_idents) + + def compute_height_remove_toekn(node): + children = [] + node.height = 0 + for c in node.children: + if isinstance(c, Tree): + node.height = max(node.height, c.height + 1) + children.append(c) + node.children = children + + traverse_postorder(ast, compute_height_remove_toekn) + return ast + + + def parse(self, term_str): + if self.caching: + if term_str not in self.cache: + self.cache[term_str] = self.parse_no_cache(term_str) + return self.cache[term_str] + else: + return self.parse_no_cache(term_str) + + + def print_grammar(self): + print(self.grammar) + + +class Counter(Visitor): + + def __init__(self): + super().__init__() + self.counts_nonterminal = defaultdict(int) + self.counts_terminal = defaultdict(int) + + def __default__(self, tree): + self.counts_nonterminal[tree.data] += 1 + for c in tree.children: + if isinstance(c, Token): + self.counts_terminal[c.value] += 1 + + +class TreeHeight(Transformer): + + def __default__(self, symbol, children, meta): + return 1 + max([0 if isinstance(c, Token) else c for c in children] + [-1]) + + +class TreeNumTokens(Transformer): + + def __default__(self, symbol, children, meta): + return sum([1 if isinstance(c, Token) else c for c in children]) + diff --git a/src/coq/tactok/models/embedding_map.py b/src/coq/tactok/models/embedding_map.py 
new file mode 100644 index 0000000..ac3db72 --- /dev/null +++ b/src/coq/tactok/models/embedding_map.py @@ -0,0 +1,17 @@ +import torch +import torch.nn as nn + +class EmbeddingMap(nn.Module): + + def __init__(self, dim, opts): + self.dim = dim + self.opts = opts + self.mapping = {} + + def __getitem__(self, key): + if key not in self.mapping: + embedding = nn.Parameters(torch.Tensor(self.dim).to(self.opts.device)) + nn.init.normal_(embedding, std=0.1) + self.mapping[key] = embedding + self.register_parameter(key, embedding) + return self.mapping[key] diff --git a/src/coq/tactok/models/prover.py b/src/coq/tactok/models/prover.py new file mode 100644 index 0000000..e53e0af --- /dev/null +++ b/src/coq/tactok/models/prover.py @@ -0,0 +1,118 @@ +import torch +import torch.nn as nn +from tac_grammar import CFG +from .tactic_decoder import TacticDecoder +from .term_encoder import TermEncoder +import pdb +import os +from itertools import chain +import sys +sys.path.append(os.path.abspath('.')) +from time import time +from torch.nn import Embedding, LSTM +import pickle + +class Prover(nn.Module): + + def __init__(self, opts): + super().__init__() + self.opts = opts + self.tactic_decoder = TacticDecoder(CFG(opts.tac_grammar, 'tactic_expr'), opts) + self.term_encoder = TermEncoder(opts) + self.tactic_embedding = Embedding(opts.num_tactics, 256, padding_idx=0) + self.tactic_LSTM = LSTM(256, 256, 1, batch_first=True, bidirectional=True) + self.tac_vocab = pickle.load(open(opts.tac_vocab_file, 'rb')) + self.cutoff_len = opts.cutoff_len + + def create_tactic_batch(self, tok_seq): + mod_tok_seq = [] + if '' in self.tac_vocab: + for item in tok_seq: + mod_item = [self.tac_vocab[i] if i in self.tac_vocab else self.tac_vocab[''] for i in item] + mod_tok_seq.append(mod_item) + else: + for item in tok_seq: + mod_item = [self.tac_vocab[i] for i in item if i in self.tac_vocab] + mod_tok_seq.append(mod_item) + + max_len = min(max([len(item) for item in mod_tok_seq]), self.cutoff_len) + 
batch = [] + lens = [] + for item in mod_tok_seq: + idx = self.cutoff_len - 1 # ex: 29, for len 30 + lens.append(len(item[-idx:]) + 1) + new_item = [self.tac_vocab['']] + item[-idx:] + [self.tac_vocab['']]*(max_len-len(item[-idx:])-1) + batch.append(new_item) + + return torch.tensor(batch, device=self.opts.device), lens + + def embed_terms(self, environment, local_context, goal, tok_seq=None): + all_asts = list(chain([env['ast'] for env in chain(*environment)], [context['ast'] for context in chain(*local_context)], goal)) + all_embeddings = self.term_encoder(all_asts) + + batchsize = len(environment) + environment_embeddings = [] + j = 0 + for n in range(batchsize): + size = len(environment[n]) + environment_embeddings.append(torch.cat([torch.zeros(size, 3, device=self.opts.device), + all_embeddings[j : j + size]], dim=1)) + environment_embeddings[-1][:, 0] = 1.0 + j += size + + context_embeddings = [] + for n in range(batchsize): + size = len(local_context[n]) + context_embeddings.append(torch.cat([torch.zeros(size, 3, device=self.opts.device), + all_embeddings[j : j + size]], dim=1)) + context_embeddings[-1][:, 1] = 1.0 + j += size + + goal_embeddings = [] + for n in range(batchsize): + goal_embeddings.append(torch.cat([torch.zeros(3, device=self.opts.device), all_embeddings[j]], dim=0)) + goal_embeddings[-1][2] = 1.0 + j += 1 + goal_embeddings = torch.stack(goal_embeddings) + + if tok_seq: + tactic_batch, lens = self.create_tactic_batch(tok_seq) + tactic_embeddings = self.tactic_embedding(tactic_batch) + X = torch.nn.utils.rnn.pack_padded_sequence(tactic_embeddings, lens, batch_first=True, enforce_sorted=False) + tactic_seq_embeddings, _ = self.tactic_LSTM(X) + tactic_seq_embeddings, _ = torch.nn.utils.rnn.pad_packed_sequence(tactic_seq_embeddings, batch_first=True) + tactic_seq_embeddings = tactic_seq_embeddings[:, -1, :] + return environment_embeddings, context_embeddings, goal_embeddings, tactic_seq_embeddings + + + return environment_embeddings, 
context_embeddings, goal_embeddings, None
+
+
+    def forward(self, environment, local_context, goal, actions, teacher_forcing, tok_seq=None):
+        environment_embeddings, context_embeddings, goal_embeddings, seq_embeddings = \
+            self.embed_terms(environment, local_context, goal, tok_seq)
+        environment = [{'idents': [v['qualid'] for v in env],
+                        'embeddings': environment_embeddings[i],
+                        'quantified_idents': [v['ast'].quantified_idents for v in env]}
+                       for i, env in enumerate(environment)]
+        local_context = [{'idents': [v['ident'] for v in context],
+                          'embeddings': context_embeddings[i],
+                          'quantified_idents': [v['ast'].quantified_idents for v in context]}
+                         for i, context in enumerate(local_context)]
+        goal = {'embeddings': goal_embeddings, 'quantified_idents': [g.quantified_idents for g in goal]}
+        asts, loss = self.tactic_decoder(environment, local_context, goal, actions, teacher_forcing, seq_embeddings)
+        return asts, loss
+
+
+    def beam_search(self, environment, local_context, goal, tok_seq=None):
+        environment_embeddings, context_embeddings, goal_embeddings, seq_embeddings = \
+            self.embed_terms([environment], [local_context], [goal], [tok_seq] if tok_seq is not None else None)
+        environment = {'idents': [v['qualid'] for v in environment],
+                       'embeddings': environment_embeddings[0],
+                       'quantified_idents': [v['ast'].quantified_idents for v in environment]}
+        local_context = {'idents': [v['ident'] for v in local_context],
+                         'embeddings': context_embeddings[0],
+                         'quantified_idents': [v['ast'].quantified_idents for v in local_context]}
+        goal = {'embeddings': goal_embeddings, 'quantified_idents': goal.quantified_idents}
+        asts = self.tactic_decoder.beam_search(environment, local_context, goal, seq_embeddings)
+        return asts
diff --git a/src/coq/tactok/models/tactic_decoder.py b/src/coq/tactok/models/tactic_decoder.py
new file mode 100644
index 0000000..3253dcd
--- /dev/null
+++ b/src/coq/tactok/models/tactic_decoder.py
@@ -0,0 +1,438 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math +import random +import pdb +from copy import deepcopy +from tac_grammar import TerminalNode, NonterminalNode +from lark.lexer import Token + +class AvgLoss: + 'Maintaining the average of a set of losses' + + def __init__(self, device): + self.sum = torch.tensor(0., device=device) + self.num = 0 + + + def add(self, v): + self.sum += v + self.num += 1 + + def value(self): + return self.sum / self.num + + +class ContextReader(nn.Module): + + def __init__(self, opts): + super().__init__() + self.opts = opts + self.linear1 = nn.Linear(opts.hidden_dim + opts.term_embedding_dim + 3, opts.hidden_dim) + self.relu1 = nn.ReLU() + self.linear2 = nn.Linear(opts.hidden_dim, 1) + self.default_context = torch.zeros(self.opts.term_embedding_dim + 3, device=self.opts.device) + + + def forward(self, states, embeddings): + assert states.size(0) == len(embeddings) + context = [] + for state, embedding in zip(states, embeddings): + if embedding.size(0) == 0: # no premise + context.append(self.default_context) + else: + input = torch.cat([state.unsqueeze(0).expand(embedding.size(0), -1), embedding], dim=1) + weights = self.linear2(self.relu1(self.linear1(input))) + weights = F.softmax(weights, dim=0) + context.append(torch.matmul(embedding.t(), weights).squeeze()) + context = torch.stack(context) + return context + + +class ContextRetriever(nn.Module): + + def __init__(self, opts): + super().__init__() + self.opts = opts + self.linear1 = nn.Linear(opts.hidden_dim + opts.term_embedding_dim + 3, opts.hidden_dim) + self.relu1 = nn.ReLU() + self.linear2 = nn.Linear(opts.hidden_dim, 1) + + + def forward(self, state, embeddings): + input = torch.cat([state.unsqueeze(0).expand(embeddings.size(0), -1), embeddings], dim=1) + logits = self.linear2(self.relu1(self.linear1(input))) + return logits.view(logits.size(0)) + + +def clear_state(node): + del node.state + + +class TacticDecoder(nn.Module): + + def __init__(self, grammar, opts): + super().__init__() + self.opts = opts + 
self.grammar = grammar + self.symbol_embeddings = nn.Embedding(len(self.grammar.symbols), opts.symbol_dim) + self.production_rule_embeddings = nn.Embedding(len(self.grammar.production_rules), opts.embedding_dim) + self.lex_rule_embeddings = nn.Embedding(len(self.grammar.terminal_symbols), opts.embedding_dim) + self.default_action_embedding = torch.zeros(self.opts.embedding_dim, device=self.opts.device) + self.default_state = torch.zeros(self.opts.hidden_dim, device=self.opts.device) + self.controller = nn.GRUCell(2 * opts.embedding_dim + 2 * opts.term_embedding_dim + 6 + opts.hidden_dim + opts.symbol_dim + 512, opts.hidden_dim) + self.state_decoder = nn.Sequential(nn.Linear(opts.hidden_dim, opts.embedding_dim), nn.Tanh()) + self.context_reader = ContextReader(opts) + self.context_retriever = ContextRetriever(opts) + self.INT_classifier = nn.Sequential(nn.Linear(opts.hidden_dim, opts.hidden_dim // 2), + nn.ReLU(inplace=True), + nn.Linear(opts.hidden_dim // 2, 4)) + + self.hint_dbs = ['arith', 'zarith', 'algebra', 'real', 'sets', 'core', 'bool', 'datatypes', 'coc', 'set', 'zfc'] + self.HINT_DB_classifier = nn.Sequential(nn.Linear(opts.hidden_dim, opts.hidden_dim // 2), + nn.ReLU(inplace=True), + nn.Linear(opts.hidden_dim // 2, len(self.hint_dbs))) + + def action2embedding(self, action): + if isinstance(action, tuple): # a production rule + idx = self.grammar.production_rules.index(action) + return self.production_rule_embeddings(torch.LongTensor([idx]).to(self.opts.device)).squeeze() + else: # a token + idx = self.grammar.terminal_symbols.index(action) + return self.lex_rule_embeddings(torch.LongTensor([idx]).to(self.opts.device)).squeeze() + + + def gather_frontier_info(self, frontiers): + indice = [] # indice for incomplete ASTs + s_tm1 = [] + a_tm1 = [] + p_t = [] + symbols = [] + + for i, stack in enumerate(frontiers): + if stack == []: + continue + indice.append(i) + node = stack[-1] # the next node to expand + if node.pred is None: # root + assert node.parent 
is None + s_tm1.append(self.default_state) + a_tm1.append(self.default_action_embedding) + p_t.append(torch.cat([self.default_state, self.default_action_embedding])) + else: + s_tm1.append(node.pred.state) + a_tm1.append(self.action2embedding(node.pred.action)) + p_t.append(torch.cat([node.parent.state, self.action2embedding(node.parent.action)])) + symbols.append(node.symbol) + + if indice == []: # all trees are complete + return [], None, None, None, None + + symbol_indice = torch.LongTensor([self.grammar.symbols.index(s) for s in symbols]) + n_t = self.symbol_embeddings(symbol_indice.to(self.opts.device)) + s_tm1 = torch.stack(s_tm1) + a_tm1 = torch.stack(a_tm1) + p_t = torch.stack(p_t) + return indice, s_tm1, a_tm1, p_t, n_t + + + def initialize_trees(self, batchsize): + asts = [NonterminalNode(self.grammar.start_symbol, parent=None) for i in range(batchsize)] # partial results + frontiers = [[asts[i]] for i in range(batchsize)] # the stacks for DFS, whose top are the next nodes + return asts, frontiers + + + def expand_node_set_pred(self, node, rule, stack): + node.expand(rule) + + # updat the links to the predecessor + for c in node.children[::-1]: + if isinstance(c, Token): + continue + if stack != []: + stack[-1].pred = c + stack.append(c) + + if stack != []: + stack[-1].pred = node + + + def expand_nonterminal(self, node, expansion_step, nonterminal_expansion_step, actions_gt, teacher_forcing, stack): + # selcet a production rule and compute the loss + applicable_rules = self.grammar.get_applicable_rules(node.symbol) + + if teacher_forcing: + logits = torch.matmul(self.production_rule_embeddings.weight[applicable_rules], self.state_decoder(node.state)) + action_idx = actions_gt[expansion_step] + rule = self.grammar.production_rules[action_idx] # expand the tree using the ground truth action + action_gt_onehot = torch.LongTensor([applicable_rules.index(action_idx)]).to(self.opts.device) + loss = F.cross_entropy(logits.unsqueeze(0), action_gt_onehot) + + 
else:
+            logits = torch.matmul(self.production_rule_embeddings.weight, self.state_decoder(node.state))
+            rule_idx = applicable_rules[logits[applicable_rules].argmax().item()]
+            rule = self.grammar.production_rules[rule_idx]
+            if nonterminal_expansion_step < len(actions_gt):
+                action_idx = actions_gt[nonterminal_expansion_step]
+                action_gt_onehot = torch.LongTensor([action_idx]).to(self.opts.device)
+                loss = F.cross_entropy(logits.unsqueeze(0), action_gt_onehot)
+            else:
+                loss = 0.
+
+        if expansion_step > self.opts.size_limit:  # end the generation process asap
+            rule_idx = applicable_rules[0]
+            rule = self.grammar.production_rules[rule_idx]
+
+        self.expand_node_set_pred(node, rule, stack)
+
+        return loss
+
+
+    def expand_terminal(self, node, expansion_step, environment, local_context, goal, actions_gt, teacher_forcing):
+        loss = 0.
+        if teacher_forcing:
+            token_gt = actions_gt[expansion_step]
+
+        if node.symbol in ['QUALID', 'LOCAL_IDENT']:
+            if node.symbol == 'QUALID':
+                candidates = environment['idents'] + local_context['idents']
+            else:
+                candidates = local_context['idents']
+            if candidates == []:
+                token = random.choice(['H'] + goal['quantified_idents'])
+            else:
+                if node.symbol == 'QUALID':
+                    candidate_embeddings = torch.cat([environment['embeddings'], local_context['embeddings']])
+                else:
+                    candidate_embeddings = local_context['embeddings']
+                context_scores = self.context_retriever(node.state, candidate_embeddings)
+                if teacher_forcing:
+                    target = torch.zeros_like(context_scores)
+                    if token_gt in candidates:
+                        target[candidates.index(token_gt)] = 1.0
+                    loss = F.binary_cross_entropy_with_logits(context_scores, target)
+                else:
+                    token = candidates[context_scores.argmax()]
+
+        elif node.symbol == 'INT':
+            cls = self.INT_classifier(node.state)
+            if teacher_forcing:
+                cls_gt = torch.LongTensor([int(token_gt) - 1]).to(self.opts.device)
+                loss = F.cross_entropy(cls.unsqueeze(0), cls_gt)
+            else:
+                token = str(cls.argmax().item() + 1)
+
+        elif node.symbol == 'HINT_DB':
+            cls = self.HINT_DB_classifier(node.state)
+            if teacher_forcing:
+                cls_gt = torch.LongTensor([self.hint_dbs.index(token_gt)]).to(self.opts.device)
+                loss = F.cross_entropy(cls.unsqueeze(0), cls_gt)
+            else:
+                token = self.hint_dbs[cls.argmax().item()]
+
+        elif node.symbol == 'QUANTIFIED_IDENT':
+            if goal['quantified_idents'] == []:
+                candidates = ['x']
+            else:
+                candidates = goal['quantified_idents']
+            token = random.choice(candidates)
+
+        # generate a token with the lex rule
+        node.expand(token_gt if teacher_forcing else token)
+
+        return loss
+
+
+    def expand_partial_tree(self, node, expansion_step, nonterminal_expansion_step, environment, local_context, goal, actions_gt,
+                            teacher_forcing, stack):
+        assert node.state is not None
+        if isinstance(node, NonterminalNode):
+            return self.expand_nonterminal(node, expansion_step, nonterminal_expansion_step, actions_gt, teacher_forcing, stack)
+        else:
+            return self.expand_terminal(node, expansion_step, environment, local_context, goal, actions_gt, teacher_forcing)
+
+
+    def forward(self, environment, local_context, goal, actions, teacher_forcing, seq_embeddings=None):
+        if not teacher_forcing:
+            # when train without teacher forcing, only consider the expansion of non-terminal nodes
+            actions = [[a for a in act if isinstance(a, int)] for act in actions]
+
+        loss = AvgLoss(self.opts.device)
+
+        # initialize the trees
+        batchsize = goal['embeddings'].size(0)
+        asts, frontiers = self.initialize_trees(batchsize)
+
+        # expand the trees in a depth-first order
+        expansion_step = 0
+        nonterminal_expansion_step = [0 for i in range(batchsize)]
+        while True:
+            # in each iteration, compute the state of the frontier nodes and expand them
+            # collect inputs from all partial trees: s_{t-1}, a_{t-1}, p_t, n_t
+            indice, s_tm1, a_tm1, p_t, n_t = self.gather_frontier_info(frontiers)
+            if indice == []:  # all trees are complete
+                break
+
+            r = [torch.cat([environment[i]['embeddings'], local_context[i]['embeddings']], dim=0) for i in 
indice] + u_t = self.context_reader(s_tm1, r) + + states = self.controller(torch.cat([a_tm1, goal['embeddings'][indice], u_t, p_t, n_t, seq_embeddings[indice]], dim=1), s_tm1) + + # store states and expand nodes + for j, idx in enumerate(indice): + stack = frontiers[idx] + node = stack.pop() + node.state = states[j] + g = {k: v[idx] for k, v in goal.items()} + loss.add(self.expand_partial_tree(node, expansion_step, nonterminal_expansion_step[idx], + environment[idx], local_context[idx], g, actions[idx], teacher_forcing, stack)) + if isinstance(node, NonterminalNode): + nonterminal_expansion_step[idx] += 1 + + expansion_step += 1 + + for ast in asts: + ast.traverse_pre(clear_state) + + return asts, loss.value() + + + def duplicate(self, ast, stack): + old2new = {} + def recursive_duplicate(node, parent=None): + if isinstance(node, Token): + new_node = deepcopy(node) + old2new[node] = new_node + return new_node + elif isinstance(node, TerminalNode): + new_node = TerminalNode(node.symbol, parent) + new_node.token = node.token + else: + assert isinstance(node, NonterminalNode) + new_node = NonterminalNode(node.symbol, parent) + + old2new[node] = new_node + new_node.action = node.action + if node.pred is None: + new_node.pred = None + else: + new_node.pred = old2new[node.pred] + new_node.state = node.state + if isinstance(node, NonterminalNode): + for c in node.children: + new_node.children.append(recursive_duplicate(c, new_node)) + return new_node + + new_ast = recursive_duplicate(ast) + new_stack = [old2new[node] for node in stack] + return new_ast, new_stack + + + def beam_search(self, environment, local_context, goal, seq_embeddings=None): + # initialize the trees in the beam + assert goal['embeddings'].size(0) == 1 # only support batchsize == 1 + beam, frontiers = self.initialize_trees(1) + log_likelihood = [0.] 
# the (unnormalized) objective function maximized by the beam search
+        complete_trees = []  # the complete ASTs generated during the beam search
+
+        expansion_step = 0
+        while True:
+            # collect inputs from all partial trees
+            indice, s_tm1, a_tm1, p_t, n_t = self.gather_frontier_info(frontiers)
+            # check if there are complete trees
+            for i in range(len(beam)):
+                if i not in indice:
+                    normalized_log_likelihood = log_likelihood[i] / (expansion_step ** self.opts.lens_norm)  # length normalization
+                    beam[i].traverse_pre(clear_state)
+                    complete_trees.append((beam[i], normalized_log_likelihood))
+            if indice == []:  # all trees are complete, terminate the beam search
+                break
+
+            r = [torch.cat([environment['embeddings'], local_context['embeddings']], dim=0) for i in indice]
+            u_t = self.context_reader(s_tm1, r)
+
+            states = self.controller(torch.cat([a_tm1, goal['embeddings'].expand(len(indice), -1), u_t, p_t, n_t, seq_embeddings.expand(len(indice), -1)], dim=1), s_tm1)
+
+            # compute the log likelihood and pick the top candidates
+            beam_candidates = []
+            for j, idx in enumerate(indice):
+                stack = frontiers[idx]
+                node = stack[-1]
+                node.state = states[j]
+
+                if isinstance(node, NonterminalNode):
+                    applicable_rules = self.grammar.get_applicable_rules(node.symbol)
+                    if expansion_step > self.opts.size_limit:  # end the generation process asap
+                        beam_candidates.append((idx, log_likelihood[idx], applicable_rules[0]))
+                    else:
+                        logits = torch.matmul(self.production_rule_embeddings.weight[applicable_rules], self.state_decoder(node.state))
+                        log_cond_prob = logits - logits.logsumexp(dim=0)
+                        for n, cand in enumerate(applicable_rules):
+                            beam_candidates.append((idx, log_likelihood[idx] + log_cond_prob[n].item(), cand))
+
+                elif node.symbol in ['QUALID', 'LOCAL_IDENT']:
+                    if node.symbol == 'QUALID':
+                        candidates = environment['idents'] + local_context['idents']
+                    else:
+                        candidates = local_context['idents']
+                    if candidates == []:
+                        candidates = ['H'] + goal['quantified_idents']
+
log_cond_prob = - math.log(len(candidates)) + for cand in candidates: + beam_candidates.append((idx, log_likelihood[idx] + log_cond_prob, cand)) + else: + if node.symbol == 'QUALID': + candidate_embeddings = torch.cat([environment['embeddings'], local_context['embeddings']]) + else: + candidate_embeddings = local_context['embeddings'] + context_scores = self.context_retriever(node.state, candidate_embeddings) + log_cond_prob = context_scores - context_scores.logsumexp(dim=0) + for n, cand in enumerate(candidates): + beam_candidates.append((idx, log_likelihood[idx] + log_cond_prob[n].item(), cand)) + + elif node.symbol == 'INT': + cls = self.INT_classifier(node.state) + log_cond_prob = cls - cls.logsumexp(dim=0) + for n in range(cls.size(0)): + beam_candidates.append((idx, log_likelihood[idx] + log_cond_prob[n].item(), str(n + 1))) + + elif node.symbol == 'HINT_DB': + cls = self.HINT_DB_classifier(node.state) + log_cond_prob = cls - cls.logsumexp(dim=0) + for n in range(cls.size(0)): + beam_candidates.append((idx, log_likelihood[idx] + log_cond_prob[n].item(), self.hint_dbs[n])) + + elif node.symbol == 'QUANTIFIED_IDENT': + if len(goal['quantified_idents']) > 0: + candidates = list(goal['quantified_idents']) + else: + candidates = ['x'] + log_cond_prob = - math.log(len(candidates)) + for cand in candidates: + beam_candidates.append((idx, log_likelihood[idx] + log_cond_prob, cand)) + + # expand the nodes and update the beam + beam_candidates = sorted(beam_candidates, key=lambda x: x[1], reverse=True)[:self.opts.beam_width] + new_beam = [] + new_frontiers = [] + new_log_likelihood = [] + for idx, log_cond_prob, action in beam_candidates: + ast, stack = self.duplicate(beam[idx], frontiers[idx]) + node = stack.pop() + if isinstance(action, int): # expand a nonterimial node + rule = self.grammar.production_rules[action] + self.expand_node_set_pred(node, rule, stack) + else: # expand a terminal node + node.expand(action) + new_beam.append(ast) + 
new_frontiers.append(stack) + new_log_likelihood.append(log_likelihood[idx] + log_cond_prob) + beam = new_beam + frontiers = new_frontiers + log_likelihood = new_log_likelihood + expansion_step += 1 + + complete_trees = sorted(complete_trees, key=lambda x: x[1], reverse=True) # pick the top ASTs + return [t[0] for t in complete_trees[:self.opts.num_tactic_candidates]] + diff --git a/src/coq/tactok/models/term_encoder.py b/src/coq/tactok/models/term_encoder.py new file mode 100644 index 0000000..cfc52ba --- /dev/null +++ b/src/coq/tactok/models/term_encoder.py @@ -0,0 +1,187 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from collections import defaultdict +from time import time +from itertools import chain +from lark.tree import Tree +import os +from gallina import traverse_postorder +import pdb + + +nonterminals = [ + 'constr__constr', + 'constructor_rel', + 'constructor_var', + 'constructor_meta', + 'constructor_evar', + 'constructor_sort', + 'constructor_cast', + 'constructor_prod', + 'constructor_lambda', + 'constructor_letin', + 'constructor_app', + 'constructor_const', + 'constructor_ind', + 'constructor_construct', + 'constructor_case', + 'constructor_fix', + 'constructor_cofix', + 'constructor_proj', + 'constructor_ser_evar', + 'constructor_prop', + 'constructor_set', + 'constructor_type', + 'constructor_ulevel', + 'constructor_vmcast', + 'constructor_nativecast', + 'constructor_defaultcast', + 'constructor_revertcast', + 'constructor_anonymous', + 'constructor_name', + 'constructor_constant', + 'constructor_mpfile', + 'constructor_mpbound', + 'constructor_mpdot', + 'constructor_dirpath', + 'constructor_mbid', + 'constructor_instance', + 'constructor_mutind', + 'constructor_letstyle', + 'constructor_ifstyle', + 'constructor_letpatternstyle', + 'constructor_matchstyle', + 'constructor_regularstyle', + 'constructor_projection', + 'bool', + 'int', + 'names__label__t', + 'constr__case_printing', + 'univ__universe__t', + 
'constr__pexistential___constr__constr', + 'names__inductive', + 'constr__case_info', + 'names__constructor', + 'constr__prec_declaration___constr__constr____constr__constr', + 'constr__pfixpoint___constr__constr____constr__constr', + 'constr__pcofixpoint___constr__constr____constr__constr', +] + + +class InputOutputUpdateGate(nn.Module): + + def __init__(self, hidden_dim, nonlinear): + super().__init__() + self.nonlinear = nonlinear + k = 1. / math.sqrt(hidden_dim) + self.W = nn.Parameter(torch.Tensor(hidden_dim, len(nonterminals) + hidden_dim)) + nn.init.uniform_(self.W, -k, k) + self.b = nn.Parameter(torch.Tensor(hidden_dim)) + nn.init.uniform_(self.b, -k, k) + + + def forward(self, xh): + return self.nonlinear(F.linear(xh, self.W, self.b)) + + +class ForgetGates(nn.Module): + + def __init__(self, hidden_dim, opts): + super().__init__() + self.hidden_dim = hidden_dim + self.opts = opts + k = 1. / math.sqrt(hidden_dim) + # the weight for the input + self.W_if = nn.Parameter(torch.Tensor(hidden_dim, len(nonterminals))) + nn.init.uniform_(self.W_if, -k, k) + # the weight for the hidden + self.W_hf = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim)) + nn.init.uniform_(self.W_hf, -k, k) + # the bias + self.b_f = nn.Parameter(torch.Tensor(hidden_dim)) + nn.init.uniform_(self.b_f, -k, k) + + + def forward(self, x, h_children, c_children): + c_remain = torch.zeros(x.size(0), self.hidden_dim).to(self.opts.device) + + Wx = F.linear(x, self.W_if) + all_h = list(chain(*h_children)) + if all_h == []: + return c_remain + Uh = F.linear(torch.stack(all_h), self.W_hf, self.b_f) + i = 0 + for j, h in enumerate(h_children): + if h == []: + continue + f_gates = torch.sigmoid(Wx[j] + Uh[i : i + len(h)]) + i += len(h) + c_remain[j] = (f_gates * torch.stack(c_children[j])).sum(dim=0) + + return c_remain + + +class TermEncoder(nn.Module): + + def __init__(self, opts): + super().__init__() + self.opts = opts + self.input_gate = InputOutputUpdateGate(opts.term_embedding_dim, 
nonlinear=torch.sigmoid) + self.forget_gates = ForgetGates(opts.term_embedding_dim, opts) + self.output_gate = InputOutputUpdateGate(opts.term_embedding_dim, nonlinear=torch.sigmoid) + self.update_cell = InputOutputUpdateGate(opts.term_embedding_dim, nonlinear=torch.tanh) + + + def forward(self, term_asts): + # the height of a node determines when it can be processed + height2nodes = defaultdict(set) + + def get_height(node): + height2nodes[node.height].add(node) + + for ast in term_asts: + traverse_postorder(ast, get_height) + + memory_cells = {} # node -> memory cell + hidden_states = {} # node -> hidden state + #return torch.zeros(len(term_asts), self.opts.term_embedding_dim).to(self.opts.device) + + # compute the embedding for each node + for height in sorted(height2nodes.keys()): + nodes_at_height = list(height2nodes[height]) + # sum up the hidden states of the children + h_sum = [] + c_remains = [] + x = torch.zeros(len(nodes_at_height), len(nonterminals), device=self.opts.device) \ + .scatter_(1, torch.tensor([nonterminals.index(node.data) for node in nodes_at_height], + device=self.opts.device).unsqueeze(1), 1.0) + + h_sum = torch.zeros(len(nodes_at_height), self.opts.term_embedding_dim).to(self.opts.device) + h_children = [] + c_children = [] + for j, node in enumerate(nodes_at_height): + h_children.append([]) + c_children.append([]) + for c in node.children: + h = hidden_states[c] + h_sum[j] += h + h_children[-1].append(h) + c_children[-1].append(memory_cells[c]) + c_remains = self.forget_gates(x, h_children, c_children) + + # gates + xh = torch.cat([x, h_sum], dim=1) + i_gate = self.input_gate(xh) + o_gate = self.output_gate(xh) + u = self.update_cell(xh) + cells = i_gate * u + c_remains + hiddens = o_gate * torch.tanh(cells) + + + for i, node in enumerate(nodes_at_height): + memory_cells[node] = cells[i] + hidden_states[node] = hiddens[i] + + return torch.stack([hidden_states[ast] for ast in term_asts]) diff --git a/src/coq/tactok/options.ml 
b/src/coq/tactok/options.ml
new file mode 100644
index 0000000..d8386f5
--- /dev/null
+++ b/src/coq/tactok/options.ml
@@ -0,0 +1,15 @@
+(* --- Options for Decompiler with TacTok --- *)
+
+let default_beam_width = 10
+let opt_beam_width = ref default_beam_width
+let set_beam_width = (:=) opt_beam_width
+let get_beam_width () = !opt_beam_width
+
+let _ = Goptions.declare_int_option {
+  Goptions.optdepr = false;
+  Goptions.optname = "Beam width";
+  Goptions.optkey = ["beam"];
+  Goptions.optread = (fun () -> get_beam_width ());
+  Goptions.optwrite = (fun o ->
+    let wid = o in
+    set_beam_width wid) }
\ No newline at end of file
diff --git a/src/coq/tactok/options.mli b/src/coq/tactok/options.mli
new file mode 100644
index 0000000..dddc748
--- /dev/null
+++ b/src/coq/tactok/options.mli
@@ -0,0 +1,10 @@
+(* --- Options for Decompiler with TacTok --- *)
+
+(*
+ * Beam width is the number of tactics TacTok will predict in order
+ * of probability
+ *)
+
+val default_beam_width : int
+val set_beam_width : int -> unit
+val get_beam_width : unit -> int
diff --git a/src/coq/tactok/tac_grammar.py b/src/coq/tactok/tac_grammar.py
new file mode 100644
index 0000000..87e138a
--- /dev/null
+++ b/src/coq/tactok/tac_grammar.py
@@ -0,0 +1,290 @@
+from glob import glob
+import re
+from io import StringIO
+from lark import Lark, Transformer
+from lark.lexer import Token
+import logging
+logging.basicConfig(level=logging.DEBUG)
+import pdb
+from progressbar import ProgressBar
+import sys
+
+
+class RuleBuilder(Transformer):
+
+    def symbol(self, children):
+        assert len(children) == 1
+        if children[0].type == 'TERMINAL':
+            assert children[0].value.isupper()
+            return children[0].value
+        elif children[0].type == 'NONTERMINAL':
+            assert children[0].value.islower()
+            return children[0].value
+        else:
+            assert children[0].type in ['ESCAPED_STRING', 'REGEXP']
+            return children[0].value
+
+
+    def rhs(self, children):
+        return children
+
+
+    def nonterminal_rule(self, children):
+        assert 
children[0].type == 'NONTERMINAL' + return [(children[0].value, c) for c in children[1:]] + + + def terminal_rule(self, children): + assert children[0].type == 'TERMINAL' + return [(children[0].value, children[1].value)] + + + def rule(self, children): + assert len(children) == 1 + return children[0] + + +class CFG: + + def __init__(self, grammar_file, start_symbol): + self.terminal_symbols = [] + self.nonterminal_symbols = [] + self.production_rules = [] + self.start_symbol = start_symbol + + meta_grammar = ''' + rule : nonterminal_rule + | terminal_rule + nonterminal_rule : "!"? NONTERMINAL ":" rhs ("|" rhs)* + NONTERMINAL : /[a-z0-9_]+/ + ALIAS : /[a-z_]+/ + symbol : TERMINAL + | NONTERMINAL + | ESCAPED_STRING + | REGEXP + rhs : symbol* ["->" ALIAS] + terminal_rule : TERMINAL ":" ESCAPED_STRING + | TERMINAL ":" REGEXP + TERMINAL : /[A-Z_]+/ + REGEXP : "/" STRING_INNER+ "/" + %import common.STRING_INNER + %import common.ESCAPED_STRING + %import common.WS + %ignore WS + ''' + meta_parser = Lark(meta_grammar, start='rule', parser='earley') + t = RuleBuilder() + self.ebnf = open(grammar_file).read() + + for rule_ebnf in self.ebnf.split('\n\n'): + if rule_ebnf.startswith('%'): + continue + rules = t.transform(meta_parser.parse(rule_ebnf)) + if rules[0][0].islower(): + self.nonterminal_symbols.append(rules[0][0]) + self.production_rules.extend(rules) + else: + self.terminal_symbols.append(rules[0][0]) + self.symbols = self.nonterminal_symbols + self.terminal_symbols + + + self.parser = Lark(StringIO(self.ebnf), start=self.start_symbol, parser='earley', debug=True) + + + def get_applicable_rules(self, symbol): + return [i for i, rule in enumerate(self.production_rules) if rule[0] == symbol] + + + def __str__(self): + return self.ebnf + + +class Node: + + def __init__(self, symbol, parent): + self.symbol = symbol + self.parent = parent + self.pred = None # predecessor in depth-first search + self.state = None # the hidden state of GRU + self.action = None # the 
production rule used to expand this node + + +class NonterminalNode(Node): + + def __init__(self, symbol, parent): + super().__init__(symbol, parent) + self.children = [] + + + def __str__(self): + return 'NonterminalNode(%s, children=%s)' % (self.symbol, str(self.children)) + + + def __repr__(self): + return str(self) + + + def expand(self, rule): + assert rule[0] == self.symbol and self.action is None and self.children == [] + self.action = rule + for entry in rule[1]: + if entry.startswith('"') and entry.endswith('"'): # token + self.children.append(Token('literal', entry[1:-1])) + elif entry.islower(): # nonterminal symbol + self.children.append(NonterminalNode(entry, self)) + else: # terminal symbol + assert entry.isupper() + self.children.append(TerminalNode(entry, self)) + + + def to_tokens(self): + assert self.action is not None + fields = [] + for r, c in zip(self.action[1], self.children): + if isinstance(c, Token): + assert c.value == r[1:-1] + fields.append(c.value) + elif isinstance(c, NonterminalNode): + assert c.symbol == r + fields.append(c.to_tokens()) + else: + assert isinstance(c, TerminalNode) + fields.append(c.token) + return ' '.join(fields).strip() + + + def traverse_pre(self, callback): + callback(self) + for c in self.children: + if isinstance(c, Node): + c.traverse_pre(callback) + + + def height(self): + return 1 + max([-1] + [0 if isinstance(c, Token) else c.height() for c in self.children]) + + + def num_tokens(self): + n = 0 + for c in self.children: + if isinstance(c, Token) or isinstance(c, TerminalNode): + n += 1 + else: + assert isinstance(c, NonterminalNode) + n += c.num_tokens() + return n + + def has_argument(self): + result = False + for c in self.children: + if isinstance(c, TerminalNode): + return True + if isinstance(c, NonterminalNode): + result = result or c.has_argument() + return result + + +class TerminalNode(Node): + + def __init__(self, symbol, parent): + super().__init__(symbol, parent) + self.action = symbol + 
self.token = None + + + def expand(self, token): + self.token = token + + + def __str__(self): + return 'TerminalNode(%s, token=%s)' % (self.symbol, str(self.action)) + + + def __repr__(self): + return str(self) + + + def traverse_pre(self, callback): + callback(self) + + + def height(self): + return 0 + + +def find_rule(symbol, children, production_rules): + matches = [] + for rule in production_rules: + if rule[0] != symbol: + continue + if len(rule[1]) != len(children): + continue + for r, c in zip(rule[1], children): + if isinstance(c, Token): + if not isinstance(r, str): + break + if not (r.startswith('"') and r.endswith('"')): + break + if c.value != r[1:-1]: + break + elif isinstance(c, TerminalNode): + if not r.isupper() or r != c.symbol: + break + elif isinstance(c, NonterminalNode): + if not r.islower() or r != c.symbol: + break + else: + raise TypeError + else: + matches.append(rule) + assert len(matches) == 1 + return matches[0] + + +class TreeBuilder(Transformer): + + def __init__(self, grammar): + super().__init__() + self.grammar = grammar + + + def __default__(self, symbol, children, meta): + node = NonterminalNode(symbol, parent=None) + + for c in children: + if isinstance(c, Token): + if c.type in self.grammar.terminal_symbols: + t = TerminalNode(c.type, parent=None) + t.token = c.value + c = t + else: + assert isinstance(c, NonterminalNode) and c.parent is None + c.parent = node + node.children.append(c) + + node.action = find_rule(symbol, node.children, self.grammar.production_rules) + return node + + +if __name__ == '__main__': + grammar = CFG('tactics.ebnf', 'tactic_expr') + print(grammar) + + oup = open('fails.txt', 'wt') + num_failed = 0 + num_succeeded = 0 + + for tac_str in open('correct_tacs.txt'): + tac_str = tac_str.strip() + #assert tac_str.endswith('.') + #tac_str = tac_str[:-1] + try: + tree = grammar.parser.parse(tac_str) + pdb.set_trace() + num_succeeded += 1 + except Exception as ex: + oup.write(tac_str + '\n') + num_failed += 1 + 
+ print(num_succeeded, num_failed) + diff --git a/src/coq/tactok/tactics.ebnf b/src/coq/tactok/tactics.ebnf new file mode 100644 index 0000000..3f2d056 --- /dev/null +++ b/src/coq/tactok/tactics.ebnf @@ -0,0 +1,135 @@ +!tactic_expr : intro + | "apply" term_commalist1 reduced_in_clause + | "auto" using_clause with_hint_dbs + | "rewrite" rewrite_term_list1 in_clause + | "simpl" in_clause + | "unfold" qualid_list1 in_clause + | destruct + | induction + | "elim" QUALID + | "split" + | "assumption" + | trivial + | "reflexivity" + | "case" QUALID + | clear + | "subst" local_ident_list + | "generalize" term_list1 + | "exists" LOCAL_IDENT + | "red" in_clause + | "omega" + | discriminate + | inversion + | simple_induction + | constructor + | "congruence" + | "left" + | "right" + | "ring" + | "symmetry" + | "f_equal" + | "tauto" + | "revert" local_ident_list1 + | "specialize" "(" LOCAL_IDENT QUALID ")" + | "idtac" + | "hnf" in_clause + | inversion_clear + | contradiction + | "injection" LOCAL_IDENT + | "exfalso" + | "cbv" + | "contradict" LOCAL_IDENT + | "lia" + | "field" + | "easy" + | "cbn" + | "exact" QUALID + | "intuition" + | "eauto" using_clause with_hint_dbs + +LOCAL_IDENT : /[A-Za-z_][A-Za-z0-9_']*/ + +QUANTIFIED_IDENT : /[A-Za-z_][A-Za-z0-9_']*/ + +INT : /1|2|3|4/ + +QUALID : /([A-Za-z_][A-Za-z0-9_']*\.)*[A-Za-z_][A-Za-z0-9_']*/ + +HINT_DB : /arith|zarith|algebra|real|sets|core|bool|datatypes|coc|set|zfc/ + +!local_ident_list : + | LOCAL_IDENT local_ident_list + +!local_ident_list1 : LOCAL_IDENT + | LOCAL_IDENT local_ident_list1 + +!qualid_list1 : QUALID + | QUALID "," qualid_list1 + +!term_list1 : QUALID + | QUALID term_list1 + +!term_commalist1 : QUALID + | QUALID "," term_commalist1 + +!hint_db_list1 : HINT_DB + | HINT_DB hint_db_list1 + +!reduced_in_clause : + | "in" LOCAL_IDENT + +!in_clause : + | "in" LOCAL_IDENT + | "in" "|- *" + | "in" "*" + +!at_clause : + | "at" INT + +!using_clause : + | "using" qualid_list1 + +!with_hint_dbs : + | "with" hint_db_list1 + 
| "with" "*" + +!intro : "intro" + | "intros" + +!rewrite_term : QUALID + | "->" QUALID + | "<-" QUALID + +!rewrite_term_list1 : rewrite_term + | rewrite_term "," rewrite_term_list1 + +!destruct : "destruct" term_commalist1 + +!induction : "induction" LOCAL_IDENT + | "induction" INT + +!trivial : "trivial" + +!clear : "clear" + | "clear" local_ident_list1 + +!discriminate : "discriminate" + | "discriminate" LOCAL_IDENT + +!inversion : "inversion" LOCAL_IDENT + | "inversion" INT + +!simple_induction : "simple induction" QUANTIFIED_IDENT + | "simple induction" INT + +!constructor : "constructor" + | "constructor" INT + +!inversion_clear : "inversion_clear" LOCAL_IDENT + | "inversion_clear" INT + +!contradiction : "contradiction" + | "contradiction" LOCAL_IDENT + +%import common.WS +%ignore WS diff --git a/src/coq/tactok/tactok.ml b/src/coq/tactok/tactok.ml new file mode 100644 index 0000000..f1cdf97 --- /dev/null +++ b/src/coq/tactok/tactok.ml @@ -0,0 +1,45 @@ +open Names +open Constr +open Environ +open Envutils +open Pp +open Equtils +open Proputils +open Indutils +open Funutils +open Inference +open Vars +open Utilities +open Zooming +open Nameutils +open Ltac_plugin +open Stateutils +open Lymp +open Ser_names +open Ser_environ +open Ser_goal + +let py = init "." 
+let agent_utils = Lymp.get_module py "agent_utils" +let prover = Lymp.get_module py "prover" + +(* Follow this format *) +(* let module = Lymp.get_module py in +let obj = Lymp.get_ref module in +let result = Lymp.get_ obj in *) + +let parse_script script = + get_list agent_utils "parse_script" [Pylist script] + +let import_model () = + get_ref agent_utils "import_model" [] + +let beam_search env prev goal = + let model = import_model () in + let script = parse_script prev in (* need to reverse prev *) + (* ser api calls to serialize *) + let ser_env = sexp_of_env env in + let ser_goal = sexp_of_goal goal in + let filter_env = Lymp.get_ref agent_utils "filter_env" [Pyref ser_env] in + let local_context, parsed_goal = Lymp.get_ref agent_utils "parse_goal" [Pyref ser_goal] in + get_list prover "beam_search" [Pyref model; Pyref filter_env; Pyref local_context; Pyref parsed_goal; Pylist script] diff --git a/src/coq/tactok/tactok.mli b/src/coq/tactok/tactok.mli new file mode 100644 index 0000000..cf23a4c --- /dev/null +++ b/src/coq/tactok/tactok.mli @@ -0,0 +1,22 @@ +open Names +open Constr +open Environ +open Envutils +open Pp +open Equtils +open Proputils +open Indutils +open Funutils +open Inference +open Vars +open Utilities +open Zooming +open Nameutils +open Ltac_plugin +open Stateutils +open Lymp +open Ser_names +open Ser_environ +open Ser_goal + +val beam_search : env -> string list -> types -> string list diff --git a/src/coq/tactok/token_vocab.pickle b/src/coq/tactok/token_vocab.pickle new file mode 100644 index 0000000..ca73c5b Binary files /dev/null and b/src/coq/tactok/token_vocab.pickle differ diff --git a/src/coq/termutils/dune b/src/coq/termutils/dune new file mode 100644 index 0000000..53d2b07 --- /dev/null +++ b/src/coq/termutils/dune @@ -0,0 +1,8 @@ +(library + (name termutils) + (public_name coq-plugin-lib.termutils) + (libraries + coq-plugin-lib.utilities + coq.engine + coq.kernel) + (wrapped false)) diff --git a/src/dune b/src/dune new file mode 
100644 index 0000000..6eb1bfa --- /dev/null +++ b/src/dune @@ -0,0 +1,65 @@ +(library + (name plib) + (public_name coq-plugin-lib.plugin) + (synopsis "Coq Plugin Lib") + (flags :standard -w -27 -warn-error -A) ; CoqPP codes requires this + (modules ("plibrary")) + (modes native) + (library_flags -linkall) + (libraries + coq.vernac ; needed for vernac extend + coq-serapi.serlib + coq-plugin-lib.utilities + coq-plugin-lib.coq + coq-plugin-lib.debruijn + coq-plugin-lib.constants + coq-plugin-lib.tactok + coq-plugin-lib.decompiler + coq-plugin-lib.devutils + coq-plugin-lib.envs + coq-plugin-lib.state + coq-plugin-lib.contexts + coq-plugin-lib.inference + coq-plugin-lib.typesandequality + coq-plugin-lib.transformation + coq-plugin-lib.hofs + coq-plugin-lib.hofimpls + coq-plugin-lib.inductive + coq-plugin-lib.representationutils + coq-plugin-lib.termutils + lymp +)) + +(rule + (targets plibrary.ml) + (deps (:pp-file plibrary.ml4) ) + (action (bash "camlp5 pa_o.cmo pr_o.cmo pa_op.cmo pr_dump.cmo pa_extend.cmo q_MLast.cmo pa_macro.cmo pa_op.cmo pr_dump.cmo pa_extend.cmo q_MLast.cmo pa_macro.cmo %{lib:coq.grammar:grammar.cma} -loc loc -impl %{pp-file} -o %{targets}"))) + +(rule + (targets plib_full_plugin.cmxs) + (action (run %{ocamlopt} -shared -linkall -o %{targets} + %{lib:lymp:lymp.cmxa} + %{lib:coq-plugin-lib.utilities:utilities.cmxa} + %{lib:coq-plugin-lib.representationutils:representationutils.cmxa} + %{lib:coq-plugin-lib.state:state.cmxa} + %{lib:coq-plugin-lib.termutils:termutils.cmxa} + %{lib:coq-plugin-lib.inference:inference.cmxa} + %{lib:coq-plugin-lib.constants:constants.cmxa} + %{lib:coq-plugin-lib.envs:envs.cmxa} + %{lib:coq-plugin-lib.typesandequality:typesandequality.cmxa} + %{lib:coq-plugin-lib.hofs:hofs.cmxa} + %{lib:coq-plugin-lib.debruijn:debruijn.cmxa} + %{lib:coq-plugin-lib.hofimpls:hofimpls.cmxa} + %{lib:coq-plugin-lib.inductive:inductive.cmxa} + %{lib:coq-plugin-lib.contexts:contexts.cmxa} + %{lib:coq-plugin-lib.transformation:transformation.cmxa} 
+ %{lib:coq-plugin-lib.devutils:devutils.cmxa} + %{lib:coq-plugin-lib.decompiler:decompiler.cmxa} + %{lib:coq-plugin-lib.coq:coq.cmxa} + %{cmxa:plib}))) + +(install + (section lib_root) + (package coq-plugin-lib) + (files + (plib_full_plugin.cmxs as coq/user-contrib/plib_full_plugin.cmxs))) diff --git a/src/plib.mlpack b/src/plib.mlpack index 2d376a0..e69de29 100644 --- a/src/plib.mlpack +++ b/src/plib.mlpack @@ -1,47 +0,0 @@ -Utilities - -Apputils -Constutils -Funutils - -Defutils -Nameutils - -Inference -Convertibility -Checking - -Equtils -Sigmautils -Produtils -Idutils -Proputils - -Stateutils -Envutils -Contextutils - -Hofs -Debruijn -Hofimpls -Substitution -Reducers -Typehofs -Filters -Zooming -Hypotheses -Filters - -Indexing -Indutils - -Modutils - -Transform - -Printing - -Decompiler - -Plibrary - diff --git a/src/plibrary.ml4 b/src/plibrary.ml4 index 197773c..311e696 100644 --- a/src/plibrary.ml4 +++ b/src/plibrary.ml4 @@ -1 +1,61 @@ DECLARE PLUGIN "plib" + +open Decompiler +open Constr +open Names +open Environ +open Assumptions +open Search +open Evd +open Printing +open Reducers +open Stdarg +open Utilities +open Zooming +open Defutils +open Envutils +open Stateutils +open Inference +open Tactics +open Pp +open Ltac_plugin +open Nameutils + +open Class_tactics +open Stdarg +open Tacarg + +open List + +open Ser_names +open Ser_environ +open Ser_goal +open Ser_constr +open Lymp + +(* --- Commands --- *) + +(* Decompiles a single term into a tactic list printed to console. 
*) +let decompile_command trm tacs = + let (sigma, env) = Pfedit.get_current_context () in + let sigma, trm = intern env sigma trm in + let trm = unwrap_definition env trm in + let opts = map (fun s -> (parse_tac_str s, s)) tacs in + let sigma, script = tac_from_term env sigma (fun _ sigma [] _ -> sigma, opts) trm in + (* let ser_env = sexp_of_env env in *) + (* let goal = (Typeops.infer env trm).uj_type in *) + (* let ser_goal = sexp_of_constr goal in *) + (* Feedback.msg_warning (ppx_conv_sexp env sigma goal) *) + (* Feedback.msg_warning (str "the goal: " ++ Printer.pr_constr_env env sigma goal) *) + (* Feedback.msg_debug (script) *) + Feedback.msg_debug (tac_to_string sigma script) + +(* --- Vernac syntax --- *) + +(* Decompile Command *) +VERNAC COMMAND EXTEND Decompile CLASSIFIED AS SIDEFF +| [ "Decompile" constr(trm) ] -> + [ decompile_command trm [] ] +| [ "Decompile" constr(trm) "with" string_list(l) ] -> + [ decompile_command trm l ] +END diff --git a/src/utilities/dune b/src/utilities/dune new file mode 100644 index 0000000..5bb1081 --- /dev/null +++ b/src/utilities/dune @@ -0,0 +1,6 @@ +(library + (name utilities) + (public_name coq-plugin-lib.utilities) + (libraries + coq.lib) + (wrapped false)) diff --git a/theories/Plib.v b/theories/Plib.v index 87e32df..4382fcb 100644 --- a/theories/Plib.v +++ b/theories/Plib.v @@ -1 +1,2 @@ -Declare ML Module "plib". +(* Uses the ad-hoc .cmxs file *) +Declare ML Module "plib_full_plugin". diff --git a/theories/dune b/theories/dune new file mode 100644 index 0000000..2f08575 --- /dev/null +++ b/theories/dune @@ -0,0 +1,5 @@ +(coq.theory + (name CoqPluginLib) + (package coq-plugin-lib) + (flags -q -I ./src) +)