-- Type Inference for the Simply-Typed Lambda Calculus

import Data.Maybe (fromJust, fromMaybe, isJust)

-- variable names: a base name plus a numeric index
type VName = (String, Int)

-- renders a variable name; index 0 is omitted, so ("t", 0) is "t"
-- and ("t", 3) is "t3"
showVar :: VName -> String
showVar (s, n) = s ++ if n == 0 then "" else show n

infixr 7 :-->

-- types: type variables and (right-associative) function types
data Type = TV VName
          | Type :--> Type
          deriving Eq

-- an arrow type on the left of an arrow is parenthesised
instance Show Type where
  show (TV v)        = showVar v
  show (TV v :--> b) = showVar v ++ " :--> " ++ show b
  show (a :--> b)    = "(" ++ show a ++ ") :--> " ++ show b

-- terms
infixl 7 :$:

data Term = V VName
          | Abs VName Term
          | Term :$: Term

instance Show Term where
  show (V v)       = showVar v
  show (Abs v t)   = "\\" ++ showVar v ++ ". " ++ show t
  show (t1 :$: t2) = showFun t1 ++ " " ++ showArg t2
    where
      -- an abstraction in function position must be parenthesised,
      -- otherwise "(\x. x) y" would be printed as "\x. x y"
      showFun t@(Abs _ _) = "(" ++ show t ++ ")"
      showFun t           = show t
      -- a variable argument is printed bare, anything else in parentheses
      showArg (V v) = showVar v
      showArg t     = "(" ++ show t ++ ")"

-- substitutions
type Subst = [(VName, Type)]

-- type environments
type Env = [(VName, Type)]

-- get the binding of v (the first one, so later bindings shadow earlier ones)
get :: VName -> Env -> Maybe Type
get v env = lookup v env

-- is variable v in the domain of substitution s?
inDom :: VName -> Subst -> Bool
inDom v s = isJust (get v s)

-- applies substitution to Type; variables outside the domain are left as is
appSubst :: Subst -> Type -> Type
appSubst s (TV v)        = fromMaybe (TV v) (get v s)
appSubst s (t1 :--> t2)  = appSubst s t1 :--> appSubst s t2

-- does variable v occur in a Type?
occurs :: VName -> Type -> Bool
occurs v (TV w)       = v == w
occurs v (t1 :--> t2) = occurs v t1 || occurs v t2

------------------------------------------
-- unification
------------------------------------------

type Unifier = Maybe Subst

-- given list of disagreement pairs, current substitution
-- produces a unifier
solve :: [(Type, Type)] -> Subst -> Unifier
solve [] s = Just s
solve ((TV a, TV b) : das) s
  | a == b    = solve das s            -- trivial pair, nothing to record
  | otherwise = elim a (TV b) das s
solve ((TV a, t) : das) s = elim a t das s
solve ((t, TV a) : das) s = elim a t das s
solve ((t1 :--> t2, u1 :--> u2) : das) s =
  solve ((t1, u1) : (t2, u2) : das) s  -- decompose arrows component-wise

-- given (v, t, list of disagreement pairs, current substitution)
-- propagates the constraint v = t and solves the remaining constraints;
-- fails (Nothing) when v occurs in t
elim :: VName -> Type -> [(Type, Type)] -> Subst -> Unifier
elim v t das s
  | occurs v t = Nothing
  | otherwise  = solve das' s'
  where
    vt   = appSubst [(v, t)]
    das' = map (\(t1, t2) -> (vt t1, vt t2)) das
    -- the new binding is also applied to the right-hand sides collected so far
    s'   = (v, t) : map (\(w, u) -> (w, vt u)) s

-- embedding of solve
unify :: (Type, Type) -> Unifier
unify (t1, t2) = solve [(t1, t2)] []

------------------------------------------
-- type inference
------------------------------------------

-- accumulate type constraints
--
-- arguments: the term, the type it is expected to have, the environment,
-- and a pair of (next fresh variable index, constraints collected so far).
-- Returns the updated pair, or Nothing if the term mentions a variable
-- that is not bound in the environment.
constraints :: Term -> Type -> Env -> (Int, [(Type, Type)]) -> Maybe (Int, [(Type, Type)])

--        inDom x env
-- ----------------------
-- env |- V x : get x env
constraints (V v) goal env (n, cs) =
  case get v env of
    Just bound -> Just (n, (goal, bound) : cs)
    Nothing    -> Nothing

--    env, x : tau1 |- t : tau2
-- ------------------------------
-- env |- Abs x t : tau1 :--> tau2
constraints (Abs v body) goal env (n, cs) =
  constraints body tau2 env' (n + 2, cs')
  where
    tau1 = TV ("t", n)
    tau2 = TV ("t", n + 1)
    env' = (v, tau1) : env
    cs'  = (goal, tau1 :--> tau2) : cs

-- env |- t1 : tau :--> ty    env |- t2 : tau
-- ----------------------------------------
--           env |- t1 t2 : ty
constraints (t1 :$: t2) goal env (n, cs) =
  constraints t1 (tau :--> goal) env (n + 1, cs) >>= constraints t2 tau env
  where
    tau = TV ("t", n)

-- type inference procedure: collects the constraints of t in the empty
-- environment against the type variable "t", solves them, and returns
-- whatever the resulting substitution binds that variable to
infer :: Term -> Maybe Type
infer t = do
  (_, cs) <- constraints t (TV tau) [] (1, [])
  s       <- solve cs []
  get tau s
  where
    tau = ("t", 0)

--
-- terms & types for testing:

-- building blocks
tx, ty, tz :: Type
tx = TV ("x", 0)
ty = TV ("y", 0)
tz = TV ("z", 0)

-- f a b is the type a :--> b
f :: Type -> Type -> Type
f a b = a :--> b

-- g a b is the type a :--> a :--> b
g :: Type -> Type -> Type
g a b = a :--> a :--> b

-- abstr ["x", "y"] t is \x. \y. t (all variables get index 0)
abstr :: Foldable c => c String -> Term -> Term
abstr names t = foldr (\name -> Abs (name, 0)) t names

-- combines a non-empty list of terms with :$:
-- note: the result nests to the right, apply [a, b, c] = a :$: (b :$: c)
apply :: [Term] -> Term
apply []       = error "apply: empty list of terms"
apply [t]      = t
apply (t : ts) = t :$: apply ts

x, y, z :: Term
x = V ("x", 0)
y = V ("y", 0)
z = V ("z", 0)

tId, omega, tFst, tSnd, comp, tEx3b :: Term
tId   = Abs ("x", 0) x
omega = Abs ("x", 0) (x :$: x)
tFst  = abstr ["x", "y"] x
tSnd  = abstr ["x", "y"] y
comp  = abstr ["x", "y", "z"] (x :$: (y :$: z))
tEx3b = abstr ["x", "y", "z"] (x :$: y :$: (y :$: z))