@@ -694,6 +694,47 @@ interface Convert (tm : List Name -> Type) where
694694 = do q <- newRef QVar 0
695695 convGen q defs env tm tm'
696696
697+ tryUpdate : {vars, vars' : _ } ->
698+ List (Var vars, Var vars') ->
699+ Term vars -> Maybe (Term vars')
700+ tryUpdate ms (Local fc l idx p)
701+ = do MkVar p' <- findIdx ms idx
702+ pure $ Local fc l _ p'
703+ where
704+ findIdx : List (Var vars, Var vars') -> Nat -> Maybe (Var vars')
705+ findIdx [] _ = Nothing
706+ findIdx ((MkVar {i} _ , v) :: ps) n
707+ = if i == n then Just v else findIdx ps n
708+ tryUpdate ms (Ref fc nt n) = pure $ Ref fc nt n
709+ tryUpdate ms (Meta fc n i args) = pure $ Meta fc n i ! (traverse (tryUpdate ms) args)
710+ tryUpdate ms (Bind fc x b sc)
711+ = do b' <- tryUpdateB b
712+ pure $ Bind fc x b' ! (tryUpdate (map weakenP ms) sc)
713+ where
714+ tryUpdatePi : PiInfo (Term vars) -> Maybe (PiInfo (Term vars'))
715+ tryUpdatePi Explicit = pure Explicit
716+ tryUpdatePi Implicit = pure Implicit
717+ tryUpdatePi AutoImplicit = pure AutoImplicit
718+ tryUpdatePi (DefImplicit t) = pure $ DefImplicit ! (tryUpdate ms t)
719+
720+ tryUpdateB : Binder (Term vars) -> Maybe (Binder (Term vars'))
721+ tryUpdateB (Lam r p t) = pure $ Lam r ! (tryUpdatePi p) ! (tryUpdate ms t)
722+ tryUpdateB (Let r v t) = pure $ Let r ! (tryUpdate ms v) ! (tryUpdate ms t)
723+ tryUpdateB (Pi r p t) = pure $ Pi r ! (tryUpdatePi p) ! (tryUpdate ms t)
724+ tryUpdateB _ = Nothing
725+
726+ weakenP : {n : _} -> (Var vars, Var vars') ->
727+ (Var (n :: vars), Var (n :: vars'))
728+ weakenP (v, vs) = (weaken v, weaken vs)
729+ tryUpdate ms (App fc f a) = pure $ App fc ! (tryUpdate ms f) ! (tryUpdate ms a)
730+ tryUpdate ms (As fc s a p) = pure $ As fc s ! (tryUpdate ms a) ! (tryUpdate ms p)
731+ tryUpdate ms (TDelayed fc r tm) = pure $ TDelayed fc r ! (tryUpdate ms tm)
732+ tryUpdate ms (TDelay fc r ty tm) = pure $ TDelay fc r ! (tryUpdate ms ty) ! (tryUpdate ms tm)
733+ tryUpdate ms (TForce fc r tm) = pure $ TForce fc r ! (tryUpdate ms tm)
734+ tryUpdate ms (PrimVal fc c) = pure $ PrimVal fc c
735+ tryUpdate ms (Erased fc i) = pure $ Erased fc i
736+ tryUpdate ms (TType fc) = pure $ TType fc
737+
697738mutual
698739 allConv : {vars : _} ->
699740 Ref QVar Int -> Defs -> Env Term vars ->
@@ -703,6 +744,87 @@ mutual
703744 = pure $ ! (convGen q defs env x y) && ! (allConv q defs env xs ys)
704745 allConv q defs env _ _ = pure False
705746
747+ -- If the case trees match in structure, get the list of variables which
748+ -- have to match in the call
749+ getMatchingVarAlt : {args, args' : _ } ->
750+ Defs ->
751+ List (Var args, Var args') ->
752+ CaseAlt args -> CaseAlt args' ->
753+ Core (Maybe (List (Var args, Var args')))
754+ getMatchingVarAlt defs ms (ConCase n tag cargs t) (ConCase n' tag' cargs' t')
755+ = if n == n'
756+ then do let Just ms' = extend cargs cargs' ms
757+ | Nothing => pure Nothing
758+ Just ms <- getMatchingVars defs ms' t t'
759+ | Nothing => pure Nothing
760+ -- drop the prefix from cargs/cargs' since they won't
761+ -- be in the caller
762+ pure (Just (mapMaybe (dropP cargs cargs') ms))
763+ else pure Nothing
764+ where
765+ weakenP : {c, c', args, args' : _ } ->
766+ (Var args, Var args') ->
767+ (Var (c :: args), Var (c' :: args'))
768+ weakenP (v, vs) = (weaken v, weaken vs)
769+
770+ extend : (cs : List Name) -> (cs' : List Name) ->
771+ (List (Var args, Var args')) ->
772+ Maybe (List (Var (cs ++ args), Var (cs' ++ args')))
773+ extend [] [] ms = pure ms
774+ extend (c :: cs) (c' :: cs') ms
775+ = do rest <- extend cs cs' ms
776+ pure ((MkVar First , MkVar First ) :: map weakenP rest)
777+ extend _ _ _ = Nothing
778+
779+ dropV : forall args .
780+ (cs : List Name) -> Var (cs ++ args) -> Maybe (Var args)
781+ dropV [] v = Just v
782+ dropV (c :: cs) (MkVar First ) = Nothing
783+ dropV (c :: cs) (MkVar (Later x))
784+ = dropV cs (MkVar x)
785+
786+ dropP : (cs : List Name) -> (cs' : List Name) ->
787+ (Var (cs ++ args), Var (cs' ++ args')) ->
788+ Maybe (Var args, Var args')
789+ dropP cs cs' (x, y) = pure (! (dropV cs x), ! (dropV cs' y))
790+
791+ getMatchingVarAlt defs ms (ConstCase c t) (ConstCase c' t')
792+ = if c == c'
793+ then getMatchingVars defs ms t t'
794+ else pure Nothing
795+ getMatchingVarAlt defs ms (DefaultCase t) (DefaultCase t')
796+ = getMatchingVars defs ms t t'
797+ getMatchingVarAlt defs _ _ _ = pure Nothing
798+
799+ getMatchingVarAlts : {args, args' : _ } ->
800+ Defs ->
801+ List (Var args, Var args') ->
802+ List (CaseAlt args) -> List (CaseAlt args') ->
803+ Core (Maybe (List (Var args, Var args')))
804+ getMatchingVarAlts defs ms [] [] = pure (Just ms)
805+ getMatchingVarAlts defs ms (a :: as) (a' :: as')
806+ = do Just ms <- getMatchingVarAlt defs ms a a'
807+ | Nothing => pure Nothing
808+ getMatchingVarAlts defs ms as as'
809+ getMatchingVarAlts defs _ _ _ = pure Nothing
810+
811+ getMatchingVars : {args, args' : _ } ->
812+ Defs ->
813+ List (Var args, Var args') ->
814+ CaseTree args -> CaseTree args' ->
815+ Core (Maybe (List (Var args, Var args')))
816+ getMatchingVars defs ms (Case _ p _ alts) (Case _ p' _ alts')
817+ = getMatchingVarAlts defs ((MkVar p, MkVar p') :: ms) alts alts'
818+ getMatchingVars defs ms (STerm i tm) (STerm i' tm')
819+ = do let Just tm'' = tryUpdate ms tm
820+ | Nothing => pure Nothing
821+ if !(convert defs (mkEnv (getLoc tm) args' ) tm'' tm' )
822+ then pure (Just ms)
823+ else pure Nothing
824+ getMatchingVars defs ms (Unmatched _ ) (Unmatched _ ) = pure (Just ms)
825+ getMatchingVars defs ms Impossible Impossible = pure (Just ms)
826+ getMatchingVars _ _ _ _ = pure Nothing
827+
706828 chkSameDefs : {vars : _} ->
707829 Ref QVar Int -> Defs -> Env Term vars ->
708830 Name -> Name ->
@@ -712,9 +834,32 @@ mutual
712834 | _ => pure False
713835 Just (PMDef _ args' ct' rt' _ ) <- lookupDefExact n' (gamma defs)
714836 | _ => pure False
715- if (length args == length args' && eqTree rt rt')
716- then allConv q defs env nargs nargs'
717- else pure False
837+
838+ -- If the two case blocks match in structure, get which variables
839+ -- correspond. If corresponding variables convert, the two case
840+ -- blocks convert.
841+ Just ms <- getMatchingVars defs [] ct ct'
842+ | Nothing => pure False
843+ convertMatches ms
844+ where
845+ -- We've only got the index into the argument list, and the indices
846+ -- don't match up, which is annoying. But it'll always be there!
847+ getArgPos : Nat -> List (Closure vars) -> Maybe (Closure vars)
848+ getArgPos _ [] = Nothing
849+ getArgPos Z (c :: cs) = pure c
850+ getArgPos (S k) (c :: cs) = getArgPos k cs
851+
852+ convertMatches : {vs, vs' : _ } ->
853+ List (Var vs, Var vs') ->
854+ Core Bool
855+ convertMatches [] = pure True
856+ convertMatches ((MkVar {i} p, MkVar {i= i'} p') :: vs)
857+ = do let Just varg = getArgPos i nargs
858+ | Nothing => pure False
859+ let Just varg' = getArgPos i' nargs'
860+ | Nothing => pure False
861+ pure $ ! (convGen q defs env varg varg') &&
862+ ! (convertMatches vs)
718863
719864 -- If two names are standing for case blocks, check the blocks originate
720865 -- from the same place, and have the same scrutinee
0 commit comments