Skip to content

Atoms are now converted to tagged records in AtomFolding.hs #53

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: dev-integrity
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions compiler/src/AtomFolding.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,27 @@ module AtomFolding ( visitProg )
where
import Basics
import Direct
import Data.Maybe
import Control.Monad
import Data.List (find, any)

visitProg :: Prog -> Prog
visitProg (Prog imports (Atoms atms) tm) =
Prog imports (Atoms atms) (visitTerm atms tm)
visitProg (Prog imports (DataTypes datatypes) tm) =
let tcs = concat $ map snd datatypes
in Prog imports (DataTypes datatypes) (visitTerm tcs tm)

visitTerm :: [AtomName] -> Term -> Term
visitTerm :: [TypeConstructor] -> Term -> Term
visitTerm atms (Lit lit) = Lit lit
visitTerm atms (Var nm) =
if (elem nm atms)
then Lit (LAtom nm)
else Var nm
let tag = "tag"
value = "value"
var = "v"
in case find (\x -> (fst x) == nm) atms of
Nothing -> Var nm
Just (t, []) -> Record [(tag, Just (Lit (LString nm)))] True -- Convert atom into a tagged record
Just (t, _) ->
Abs (Lambda [VarPattern var] (Record [(tag, Just (Lit (LString nm)))
, (value, Just (Var var))
] True))
visitTerm atms (Abs lam) =
Abs (visitLambda atms lam)
visitTerm atms (Hnd (Handler pat maybePat maybeTerm term)) =
Expand All @@ -38,7 +46,7 @@ visitTerm atms (If t1 t2 t3) =
If (visitTerm atms t1) (visitTerm atms t2) (visitTerm atms t3)
visitTerm atms (Tuple terms) =
Tuple (map (visitTerm atms) terms)
visitTerm atms (Record fields) = Record (visitFields atms fields)
visitTerm atms (Record fields tag) = Record (visitFields atms fields) tag
visitTerm atms (WithRecord e fields) =
WithRecord (visitTerm atms e) (visitFields atms fields)
visitTerm atms (ProjField t f) =
Expand All @@ -63,10 +71,10 @@ visitFields atms fs = map visitField fs
where visitField (f, Nothing) = (f, Nothing)
visitField (f, Just t) = (f, Just (visitTerm atms t))

visitPattern :: [AtomName] -> DeclPattern -> DeclPattern
visitPattern :: [TypeConstructor] -> DeclPattern -> DeclPattern
visitPattern atms pat@(VarPattern nm) =
if (elem nm atms)
then ValPattern (LAtom nm)
if any (\x -> x == (nm, [])) atms
then RecordPattern [("tag", Just (ValPattern (LString nm)))] ExactMatch -- Convert atom match into a record match
else pat
visitPattern _ pat@(ValPattern _) = pat
visitPattern atms (AtPattern p l) = AtPattern (visitPattern atms p) l
Expand All @@ -77,7 +85,12 @@ visitPattern atms (ListPattern pats) = ListPattern (map (visitPattern atms) pats
visitPattern atms (RecordPattern fields mode) = RecordPattern (map visitField fields) mode
where visitField pat@(_, Nothing) = pat
visitField (f, Just p) = (f, Just (visitPattern atms p))
visitPattern atms (DataTypePattern nm pat) =
RecordPattern [("tag", Just (ValPattern (LString nm)))
,("value", Just (visitPattern atms pat))] ExactMatch


visitLambda :: [AtomName] -> Lambda -> Lambda
visitLambda :: [TypeConstructor] -> Lambda -> Lambda
visitLambda atms (Lambda pats term) =
(Lambda (map (visitPattern atms) pats) (visitTerm atms term))

5 changes: 5 additions & 0 deletions compiler/src/Basics.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ import Data.Serialize (Serialize)

type VarName = String
type AtomName = String
type DataTypeName = String
type TypeConstructorName = String
type TypeConstructor = (TypeConstructorName, [VarName])
type DataTypeDef = (DataTypeName, [TypeConstructor])
type FieldName = String
type ADTTag = Bool

-- | Eq and Neq: deep equality check on the two parameters, including the types (any type inequality results in false being returned).
data BinOp = Plus | Minus | Mult | Div | Mod | Eq | Neq | Le | Lt | Ge | Gt | And | Or | RaisedTo | FlowsTo | Concat| IntDiv | BinAnd | BinOr | BinXor | BinShiftLeft | BinShiftRight | BinZeroShiftRight | HasField | LatticeJoin | LatticeMeet
Expand Down
18 changes: 9 additions & 9 deletions compiler/src/CPSOpt.hs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ instance Substitutable SimpleTerm where
Bin op v1 v2 -> Bin op (fwd v1) (fwd v2)
Un op v -> Un op (fwd v)
Tuple vs -> Tuple (map fwd vs)
Record fields -> Record $ fwdFields fields
Record fields tag -> Record (fwdFields fields) tag
WithRecord x fields -> WithRecord (fwd x) $ fwdFields fields
ProjField x f -> ProjField (fwd x) f
ProjIdx x idx -> ProjIdx (fwd x) idx
Expand Down Expand Up @@ -146,7 +146,7 @@ instance CensusCollectible SimpleTerm where
Un _ v -> updateCensus v
ValSimpleTerm sv -> updateCensus sv
Tuple vs -> updateCensus vs
Record fs -> let (_,vs) = unzip fs in updateCensus vs
Record fs _ -> let (_,vs) = unzip fs in updateCensus vs
WithRecord v fs -> updateCensus v >> (let (_,vs) = unzip fs in updateCensus vs )
ProjField v _ -> updateCensus v
ProjIdx v _ -> updateCensus v
Expand Down Expand Up @@ -256,14 +256,14 @@ censusInfo x = do
fields x = do
w <- look x
case w of
St (Record xs) -> return xs
St (Record xs _) -> return xs
St (WithRecord y ys) -> do
xs <- fields y
return $ xs ++ ys
_ -> return []


isRecordTerm (St (Record _)) = True
isRecordTerm (St (Record _ _)) = True
isRecordTerm (St (WithRecord _ _ )) = True
isRecordTerm _ = False

Expand Down Expand Up @@ -327,14 +327,14 @@ simplifySimpleTerm t =
-- TODO should write out all cases
case (op,v) of
(Basics.IsTuple, St (Tuple _)) -> _ret __trueLit
(Basics.IsTuple, St (Record _)) -> _ret __falseLit
(Basics.IsTuple, St (Record _ _)) -> _ret __falseLit
(Basics.IsTuple, St (WithRecord _ _)) -> _ret __falseLit
(Basics.IsTuple, St (List _)) -> _ret __falseLit
(Basics.IsTuple, St (ListCons _ _)) -> _ret __falseLit
(Basics.IsTuple, St (ValSimpleTerm _)) -> _ret __falseLit


(Basics.IsRecord, St (Record _)) -> _ret __trueLit
(Basics.IsRecord, St (Record _ _)) -> _ret __trueLit
(Basics.IsRecord, St (WithRecord _ _)) -> _ret __trueLit
(Basics.IsRecord, St (Tuple _)) -> _ret __falseLit
(Basics.IsRecord, St (List _)) -> _ret __falseLit
Expand All @@ -344,7 +344,7 @@ simplifySimpleTerm t =

(Basics.IsList, St (List _)) -> _ret __trueLit
(Basics.IsList, St (ListCons _ _)) -> _ret __trueLit
(Basics.IsList, St (Record _)) -> _ret __falseLit
(Basics.IsList, St (Record _ _)) -> _ret __falseLit
(Basics.IsList, St (WithRecord _ _)) -> _ret __falseLit
(Basics.IsList, St (Tuple _)) -> _ret __falseLit
(Basics.IsList, St (ValSimpleTerm _)) -> _ret __falseLit
Expand Down Expand Up @@ -410,7 +410,7 @@ failFree st = case st of
Un _ _ -> False -- Unary operations can fail (e.g., head on empty list, arithmetic on non-numbers)
ValSimpleTerm _ -> True
Tuple _ -> True
Record _ -> True
Record _ _ -> True
WithRecord _ _ -> True
ProjField _ _ -> False -- Field projection can fail if field doesn't exist
ProjIdx _ _ -> False -- Index projection can fail if index out of bounds
Expand Down Expand Up @@ -546,4 +546,4 @@ iter kt =

rewrite :: Prog -> Prog
rewrite (Prog atoms kterm) =
Prog atoms (iter kterm)
Prog atoms (iter kterm)
12 changes: 6 additions & 6 deletions compiler/src/CaseElimination.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
tm'' <- transTerm tm'
return (T.Prog imports atms' tm'')

transAtoms :: S.Atoms -> Trans T.Atoms
transAtoms (S.Atoms atms) = return (T.Atoms atms)
transAtoms :: S.DataTypes -> Trans T.DataTypes
transAtoms (S.DataTypes atms) = return (T.DataTypes atms)

transLit :: S.Lit -> T.Lit
transLit (S.LInt n pi) = T.LInt n pi
Expand All @@ -41,7 +41,7 @@
transLit (S.LDCLabel dc) = T.LDCLabel dc
transLit (S.LUnit) = T.LUnit
transLit (S.LBool b) = T.LBool b
transLit (S.LAtom a) = T.LAtom a
transLit (S.LDataType a) = T.LDataType a


transLambda_aux :: S.Lambda -> ReaderT T.Term Trans Lambda
Expand Down Expand Up @@ -120,7 +120,7 @@
-- v: the term to be assigned to the pattern
-- The Reader monad stores the error term.
compilePattern :: T.Term -> (T.Term, S.DeclPattern) -> ReaderT T.Term Trans T.Term
compilePattern succ (v, (S.AtPattern p l)) = do

Check warning on line 123 in compiler/src/CaseElimination.hs

View workflow job for this annotation

GitHub Actions / build_and_test

Pattern match(es) are non-exhaustive
fail <- ask
succ' <- compilePattern succ (v, p)
return $ ifpat (Bin Eq (Un LevelOf v) (Lit (LLabel l))) succ' fail
Expand Down Expand Up @@ -260,9 +260,9 @@
transTerm (S.Tuple tms) = do
tms' <- mapM transTerm tms
return (T.Tuple tms')
transTerm (S.Record fields) = do
transTerm (S.Record fields tag) = do
fields' <- transFields fields
return (T.Record fields')
return (T.Record fields' tag)
transTerm (S.WithRecord e fields) = do
e' <- transTerm e
fields' <- transFields fields
Expand Down Expand Up @@ -302,4 +302,4 @@
(f, Nothing) -> return (f, T.Var f)
(f, Just t) -> do
t' <- transTerm t
return (f, t')
return (f, t')
4 changes: 2 additions & 2 deletions compiler/src/ClosureConv.hs
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,9 @@ cpsToIR (CPS.LetSimple vname@(VN ident) st kt) = do
CPS.Tuple lst -> do
lst' <- transVars lst
_assign (Tuple lst')
CPS.Record fields -> do
CPS.Record fields tag -> do
fields' <- transFields fields
_assign (Record fields')
_assign (Record fields' tag)
CPS.WithRecord x fields -> do
x' <- transVar x
fields' <- transFields fields
Expand Down
24 changes: 15 additions & 9 deletions compiler/src/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import ShowIndent

import TroupePositionInfo
import DCLabels
import Data.List (find)

--------------------------------------------------
-- AST is the same as Direct, but lambda are unary (or nullary)
Expand All @@ -57,7 +58,7 @@ data Lit
| LDCLabel DCLabelExp
| LUnit
| LBool Bool
| LAtom AtomName
| LAtom TypeConstructorName
deriving (Show, Generic)
instance Serialize Lit
instance Eq Lit where
Expand Down Expand Up @@ -108,7 +109,7 @@ data Term
| If Term Term Term
| AssertElseError Term Term Term PosInf
| Tuple [Term]
| Record Fields
| Record Fields ADTTag
| WithRecord Term Fields
| ProjField Term FieldName
| ProjIdx Term Word
Expand Down Expand Up @@ -157,8 +158,8 @@ lowerProg (D.Prog imports atms term) = Prog imports (trans atms) (lower term)

-- the rest of the declarations in this part are not exported

trans :: D.Atoms -> Atoms
trans (D.Atoms atms) = Atoms atms
trans :: D.DataTypes -> Atoms
trans (D.DataTypes atms) = Atoms [] -- (concat $ map snd atms)

lowerLam (D.Lambda vs t) =
case vs of
Expand All @@ -172,7 +173,7 @@ lowerLit (D.LLabel s) = LLabel s
lowerLit (D.LDCLabel dc) = LDCLabel dc
lowerLit D.LUnit = LUnit
lowerLit (D.LBool b) = LBool b
lowerLit (D.LAtom n) = LAtom n
lowerLit (D.LDataType n) = LAtom n

lower :: D.Term -> Core.Term
lower (D.Lit l) = Lit (lowerLit l)
Expand All @@ -199,7 +200,7 @@ lower (D.Let decls e) =
lower (D.If e1 e2 e3) = If (lower e1) (lower e2) (lower e3)
lower (D.AssertElseError e1 e2 e3 p) = AssertElseError (lower e1 ) (lower e2) (lower e3) p
lower (D.Tuple terms) = Tuple (map lower terms)
lower (D.Record fields) = Record (map (\(f, t) -> (f, lower t)) fields)
lower (D.Record fields tag) = Record (map (\(f, t) -> (f, lower t)) fields) tag
lower (D.WithRecord e fields) = WithRecord (lower e) (map (\(f, t) -> (f, lower t)) fields)
lower (D.ProjField t f) = ProjField (lower t) f
lower (D.ProjIdx t idx) = ProjIdx (lower t) idx
Expand Down Expand Up @@ -333,8 +334,8 @@ rename (AssertElseError t1 t2 t3 p) m = do
rename (Tuple terms) m =
Tuple <$> mapM (flip rename m) terms

rename (Record fields) m =
Record <$> mapM renameField fields
rename (Record fields tag) m =
(\x -> Record x tag) <$> mapM renameField fields
where renameField (f, t) = do
t' <- rename t m
return (f, t')
Expand Down Expand Up @@ -448,7 +449,12 @@ ppTerm' (List ts) =
PP.hcat $
PP.punctuate (text ",") (map (ppTerm 0) ts)

ppTerm' (Record fs) = PP.braces $ qqFields fs
ppTerm' (Record fs False) = PP.braces $ qqFields fs
ppTerm' (Record fs True) = -- We should not be able to git the "MissingADT" cases - 2025-08-08: ASL
case find (\x -> fst x == "tag") fs of
Just (_, Lit (LString nm)) -> text nm
Just _ -> text "MissingADT"
Nothing -> text "MissingADT"

ppTerm' (WithRecord e fs) =
PP.braces $ PP.hsep [ ppTerm 0 e, text "with", qqFields fs]
Expand Down
Loading
Loading