Skip to content
107 changes: 106 additions & 1 deletion src/Solcore/Frontend/TypeInference/TcContract.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
module Solcore.Frontend.TypeInference.TcContract where

import Algebra.Graph.AdjacencyMap
import Algebra.Graph.AdjacencyMap.Algorithm
import Algebra.Graph.NonEmpty.AdjacencyMap qualified as NAG
import Control.Monad
import Control.Monad.Except
import Control.Monad.State
Expand All @@ -8,6 +11,8 @@ import Data.List
import Data.List.NonEmpty qualified as N
import Data.Map qualified as Map
import Data.Maybe
import Data.Set (Set)
import Data.Set qualified as Set
import Solcore.Frontend.Pretty.ShortName
import Solcore.Frontend.Pretty.SolcorePretty
import Solcore.Frontend.Syntax
Expand Down Expand Up @@ -68,6 +73,8 @@ tcCompUnit (CompUnit imps cs) =
checkSynonymCycles syns
let st = buildSynTable syns
cs' <- everywhereM (mkM (expandTyM st)) cs
checkRecursiveTypes (topLevelDts cs')
mapM_ checkRecursiveTypes (perContractDts cs')
mapM_ checkTopDecl (filter isClass cs')
mapM_ checkTopDecl (filter (not . isClass) cs')
typedDecls <- mapM tcTopDecl' cs'
Expand All @@ -79,6 +86,8 @@ tcCompUnit (CompUnit imps cs) =
isClass (TClassDef _) = True
isClass _ = False
syns = [s | TSym s <- cs]
topLevelDts cs' = [d | TDataDef d <- cs']
perContractDts cs' = [[d | CDataDecl d <- cds] | TContr (Contract _ _ cds) <- cs']
tcTopDecl' d = timeItNamed (shortName d) $ do
clearSubst
tcTopDecl d
Expand Down Expand Up @@ -127,6 +136,102 @@ recursiveSynonymError cyclePath =
" " ++ intercalate " -> " (map pretty cyclePath)
]

-- check for recursive data types

allDataTys :: [TopDecl Name] -> [DataTy]
allDataTys = concatMap collect
where
collect (TDataDef d) = [d]
collect (TContr (Contract _ _ cds)) = [d | CDataDecl d <- cds]
collect _ = []

tyVarNames :: Ty -> [Name]
tyVarNames (TyVar tv) = [tyvarName tv]
tyVarNames (TyCon _ ts) = concatMap tyVarNames ts
tyVarNames _ = []

-- Collect type variable names that appear in non-phantom argument positions.
-- Phantom positions (indices in the map for the head type constructor) are skipped.
nonPhantomVarNames :: Map.Map Name (Set Int) -> Ty -> [Name]
nonPhantomVarNames m (TyCon n args) =
let phantomIdxs = Map.findWithDefault Set.empty n m
in concatMap
( \(i, arg) ->
if Set.member i phantomIdxs then [] else nonPhantomVarNames m arg
)
(zip [0 ..] args)
nonPhantomVarNames _ (TyVar v) = [tyvarName v]
nonPhantomVarNames _ _ = []

-- Build the phantom-parameter map using fixpoint iteration so that
-- transitively-phantom positions are discovered. A parameter at index i of
-- type T is phantom when it never appears in a non-phantom position across all
-- constructor field types (using the current map to decide what counts as
-- non-phantom). Starting from the empty map and iterating monotonically to a
-- fixpoint ensures that every position that can be proved phantom eventually is.
buildPhantomMap :: [DataTy] -> Map.Map Name (Set Int)
buildPhantomMap dts = fixpoint initial
where
initial = Map.fromList [(dataName dt, Set.empty) | dt <- dts]

fixpoint m =
let m' = Map.fromList (map (refineEntry m) dts)
in if m == m' then m else fixpoint m'

refineEntry m (DataTy n params ctors) =
let allFieldTys = concatMap constrTy ctors
isPhantomParam p =
let pName = tyvarName p
in all (\ty -> pName `notElem` nonPhantomVarNames m ty) allFieldTys
phantomIdxs = Set.fromList [i | (i, p) <- zip [0 ..] params, isPhantomParam p]
in (n, phantomIdxs)
Comment thread
rodrigogribeiro marked this conversation as resolved.

nonPhantomTyNames :: Map.Map Name (Set Int) -> Ty -> [Name]
nonPhantomTyNames phantomMap (TyCon n args) =
n : concatMap processArg (zip [0 ..] args)
where
phantomIdxs = Map.findWithDefault Set.empty n phantomMap
processArg (i, arg)
| Set.member i phantomIdxs = []
| otherwise = nonPhantomTyNames phantomMap arg
nonPhantomTyNames _ _ = []

buildTypeDepsGraph :: Set Name -> [DataTy] -> AdjacencyMap Name
buildTypeDepsGraph userTypes dts =
overlay isolated edged
where
phantomMap = buildPhantomMap dts
isolated = vertices (Set.toList userTypes)
edged = stars [(dataName dt, deps dt) | dt <- dts]
deps (DataTy _ _ ctors) =
nub
. filter (`Set.member` userTypes)
. concatMap (\(Constr _ tys) -> concatMap (nonPhantomTyNames phantomMap) tys)
$ ctors
Comment thread
rodrigogribeiro marked this conversation as resolved.

checkRecursiveTypes :: [DataTy] -> TcM ()
checkRecursiveTypes dts =
case cyclicSccs of
[] -> pure ()
(c : _) -> recursiveTypeError (NAG.vertexList1 c)
where
userTypes = Set.fromList (map dataName dts)
Comment thread
rodrigogribeiro marked this conversation as resolved.
graph = buildTypeDepsGraph userTypes dts
cyclicSccs = filter (isCyclic graph) (vertexList (scc graph))
isCyclic origGraph sccComp =
case N.toList (NAG.vertexList1 sccComp) of
[v] -> hasEdge v v origGraph -- singleton SCC: cyclic only if self-loop
_ -> True -- 2+ vertices: always a mutual cycle

recursiveTypeError :: N.NonEmpty Name -> TcM a
recursiveTypeError cycleVerts =
throwError $
unlines
[ "Recursive data type detected:",
" " ++ intercalate ", " (map pretty (N.toList cycleVerts)),
" (Data types must be non-recursive)"
]

-- setting up pragmas for type checking

setupPragmas :: [Pragma] -> TcM ()
Expand Down Expand Up @@ -196,7 +301,7 @@ checkTopDecl _ = pure ()

tcContract :: Contract Name -> TcM (Contract Id, [(Name, Scheme)])
tcContract c@(Contract n vs cdecls) =
withLocalEnv $ withContractName n $ do
withLocalContractEnv $ withContractName n $ do
ctx' <- gets ctx
initializeEnv c
decls' <- mapM tcDecl' cdecls
Expand Down
12 changes: 12 additions & 0 deletions src/Solcore/Frontend/TypeInference/TcMonad.hs
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,18 @@ withLocalEnv ta =
putEnv savedCtx
pure a

-- Like withLocalEnv but also restores the typeTable, for contract scopes
-- where data type names must not leak between sibling contracts.
withLocalContractEnv :: TcM a -> TcM a
withLocalContractEnv ta =
do
savedCtx <- gets ctx
savedTypes <- gets typeTable
a <- ta
putEnv savedCtx
modify (\env -> env {typeTable = savedTypes})
pure a

envList :: TcM [(Name, Scheme)]
envList = gets (Map.toList . ctx)

Expand Down
20 changes: 13 additions & 7 deletions test/Cases.hs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ cases =
runTestForFile "const.solc" caseFolder,
runTestExpectingFailure "const-array.solc" caseFolder,
runTestForFile "constructor-weak-args.solc" caseFolder,
runTestForFile "constructors-contract.solc" caseFolder,
runTestExpectingFailure "complexproxy.solc" caseFolder,
runTestForFile "cyclical-defs.solc" caseFolder,
runTestForFile "cyclical-defs-inferred.solc" caseFolder,
Expand All @@ -157,7 +158,7 @@ cases =
runTestExpectingFailure "Enum.solc" caseFolder,
runTestExpectingFailure "Eq.solc" caseFolder,
runTestForFile "EqQual.solc" caseFolder,
runTestForFile "EvenOdd.solc" caseFolder,
runTestExpectingFailure "EvenOdd.solc" caseFolder,
runTestExpectingFailure "Filter.solc" caseFolder,
runTestForFile "foo-class.solc" caseFolder,
runTestForFile "Foo.solc" caseFolder,
Expand All @@ -177,8 +178,8 @@ cases =
runTestExpectingFailure "joinErr.solc" caseFolder,
runTestExpectingFailure "KindTest.solc" caseFolder,
runTestExpectingFailure "listeq.solc" caseFolder,
runTestForFile "ListModule.solc" caseFolder,
runTestForFile "listid.solc" caseFolder,
runTestExpectingFailure "ListModule.solc" caseFolder,
runTestExpectingFailure "listid.solc" caseFolder,
runTestForFile "Logic.solc" caseFolder,
runTestExpectingFailure "mainproxy.solc" caseFolder,
runTestForFile "MatchCall.solc" caseFolder,
Expand All @@ -192,6 +193,7 @@ cases =
runTestForFile "modifier.solc" caseFolder,
runTestForFile "morefun.solc" caseFolder,
runTestForFile "Mutuals.solc" caseFolder,
runTestForFile "rec-memory.solc" caseFolder,
runTestExpectingFailure "nano-desugared.solc" caseFolder,
runTestForFile "NegPair.solc" caseFolder,
runTestForFile "nid.solc" caseFolder,
Expand All @@ -206,8 +208,8 @@ cases =
runTestExpectingFailure "PairMatch2.solc" caseFolder,
-- failing due to missing assign constraint
runTestExpectingFailure "patterson-bug.solc" caseFolder,
runTestForFile "Peano.solc" caseFolder,
runTestForFile "PeanoMatch.solc" caseFolder,
runTestExpectingFailure "Peano.solc" caseFolder,
runTestExpectingFailure "PeanoMatch.solc" caseFolder,
runTestForFile "polymatch-error.solc" caseFolder,
runTestExpectingFailure "pragma_merge_fail_coverage.solc" caseFolder,
runTestExpectingFailure "pragma_merge_fail_patterson.solc" caseFolder,
Expand All @@ -218,6 +220,8 @@ cases =
runTestForFile "proxy.solc" caseFolder,
runTestExpectingFailure "proxy1.solc" caseFolder,
runTestForFile "rec.solc" caseFolder,
runTestExpectingFailure "recursive-type-direct.solc" caseFolder,
runTestExpectingFailure "recursive-type-mutual.solc" caseFolder,
runTestExpectingFailure "Ref.solc" caseFolder,
runTestForFile "RefDeref.solc" caseFolder,
runTestExpectingFailure "reference.solc" caseFolder,
Expand Down Expand Up @@ -251,7 +255,7 @@ cases =
runTestExpectingFailure "subject-index.solc" caseFolder,
runTestExpectingFailure "subject-reduction.solc" caseFolder,
runTestExpectingFailure "subsumption-test.solc" caseFolder,
runTestForFile "super-class.solc" caseFolder,
runTestExpectingFailure "super-class.solc" caseFolder,
runTestForFile "super-class-num.solc" caseFolder,
runTestForFile "tiamat.solc" caseFolder,
runTestForFile "tuple-trick.solc" caseFolder,
Expand Down Expand Up @@ -299,6 +303,7 @@ cases =
runTestForFile "redundant-match.solc" caseFolder,
runTestForFile "false-redundant-warning.solc" caseFolder,
runTestForFile "proxy-desugar.solc" caseFolder,
runTestForFile "box.solc" caseFolder,
runTestForFile "invokable-issue.solc" caseFolder,
runTestForFile "td.solc" caseFolder,
runTestForFile "bar.solc" caseFolder,
Expand All @@ -313,7 +318,8 @@ cases =
runTestExpectingFailure "overlap-synonym-missed-order.solc" caseFolder,
runTestExpectingFailure "overlap-synonym-missed-two-synonyms.solc" caseFolder,
runTestForFile "copytomem.solc" caseFolder,
runTestForFile "fresh-variable-shadowing.solc" caseFolder
runTestForFile "fresh-variable-shadowing.solc" caseFolder,
runTestExpectingFailure "synonym-example.solc" caseFolder
]
where
caseFolder = "./test/examples/cases"
Expand Down
4 changes: 3 additions & 1 deletion test/examples/cases/Ackermann.solc
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
data Nat = Zero | Succ(Nat) ;
data memory(a) = memory(word);

data Nat = Zero | Succ(memory(Nat)) ;

function foo (x, y) {
match y, x {
Expand Down
22 changes: 18 additions & 4 deletions test/examples/cases/EitherModule.solc
Original file line number Diff line number Diff line change
@@ -1,16 +1,30 @@
contract EitherModule {
data memory(a) = memory(word);
data Either(a,b) = Left(a) | Right(b);
data List(a) = Nil | Cons(a,List(a));
data List(a) = Nil | Cons(a, memory(List(a)));

function lefts(xs) {
forall a b . function lefts(xs : List (Either(a,b))) -> List(a) {
match xs {
| Nil => return Nil ;
| Cons(y,ys) =>
match y {
| Left(z) => return Cons(z,lefts(ys)) ;
| Right(z) => return lefts(ys) ;
| Left(z) => return Cons(z,storeList(lefts(loadList(ys)))) ;
| Right(z) => return lefts(loadList(ys)) ;
}
}
}

forall a . function loadList(xs : memory(List(a))) -> List(a) {
match xs {
| memory(_) => return Nil ;
}
}

forall a . function storeList (xs : List(a)) -> memory(List(a)) {
return memory(0);
}

function main () -> word {
return 42;
}
}
3 changes: 3 additions & 0 deletions test/examples/cases/box.solc
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
data memory(a) = memory(word);
data Box(a) = Box(memory(a));
data Rec = Rec(Box(Rec));
9 changes: 9 additions & 0 deletions test/examples/cases/constructors-contract.solc
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
contract A {
data T = MkT(U);
data U = MkU(word);
}

contract B {
data T = MkT2(word);
data U = MkU2(T);
}
3 changes: 3 additions & 0 deletions test/examples/cases/rec-memory.solc
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
data memory(a) = memory(word);

data List(a) = Nil | Cons(a, memory(List(a)));
9 changes: 9 additions & 0 deletions test/examples/cases/recursive-type-direct.solc
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
-- Direct recursion: Nat refers to itself in the Succ constructor.
-- Expected: type checker rejects with "Recursive data type detected".
data Nat = Zero | Succ(Nat);

contract RecursiveTypeDirect {
function main() -> Nat {
Zero
}
}
10 changes: 10 additions & 0 deletions test/examples/cases/recursive-type-mutual.solc
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
-- Mutual recursion: Even refers to Odd and Odd refers to Even.
-- Expected: type checker rejects with "Recursive data type detected".
data Even = Zero | SuccE(Odd);
data Odd = SuccO(Even);

contract RecursiveTypeMutual {
function main() -> Even {
Zero
}
}
3 changes: 3 additions & 0 deletions test/examples/cases/synonym-example.solc
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

type Ref = T;
data T = Mk(Ref);
4 changes: 3 additions & 1 deletion test/examples/pragmas/coverage.solc
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
pragma no-coverage-condition ;

data List(a) = Nil | Cons(a,List(a));
data memory(a) = memory(word);

data List(a) = Nil | Cons(a,memory(List(a)));
data Bool = True | False ;

forall a b c . class a : C(b,c) {}
Expand Down
Loading