I'm currently building a type inference program in F# using the following AST:
// errors
//

// Syntax error: a message plus the lexer buffer it refers to.
exception SyntaxError of string * FSharp.Text.Lexing.LexBuffer<char>
// Raised when an expression cannot be typed.
exception TypeError of string
// Raised for internal errors and unsupported cases.
exception UnexpectedError of string

// Formats a message printf-style, wraps it with the exception constructor
// [exnf] and raises the result. Never returns normally.
let throw_formatted exnf fmt =
    let raise_with (msg : string) = raise (exnf msg)
    ksprintf raise_with fmt

// Raises UnexpectedError with a printf-style formatted message.
let unexpected_error fmt = throw_formatted UnexpectedError fmt
// AST type definitions
//

// Name of a type variable; pretty_ty prints "a" as 'a.
type tyvar = string

// Types of the object language.
type ty =
    | TyName of string      // type constants: int, bool, ...
    | TyArrow of ty * ty    // function type t1 -> t2
    | TyVar of tyvar        // type variables, e.g. 'a, 'b
    | TyTuple of ty list    // tuple types
// pseudo data constructors for literal types
let TyFloat : ty = TyName "float"
let TyInt : ty = TyName "int"
let TyChar : ty = TyName "char"
let TyString : ty = TyName "string"
let TyBool : ty = TyName "bool"
let TyUnit : ty = TyName "unit"

// active pattern for literal types: succeeds when [t] is TyName [name]
let private (|TyLit|_|) (name : string) (t : ty) : unit option =
    match t with
    | TyName s when s = name -> Some ()
    | _ -> None

// one recogniser per literal type, mirroring the constructors above
let (|TyFloat|_|) t = (|TyLit|_|) "float" t
let (|TyInt|_|) t = (|TyLit|_|) "int" t
let (|TyChar|_|) t = (|TyLit|_|) "char" t
let (|TyString|_|) t = (|TyLit|_|) "string" t
let (|TyBool|_|) t = (|TyLit|_|) "bool" t
let (|TyUnit|_|) t = (|TyLit|_|) "unit" t
// Type scheme: the listed type variables are universally quantified in the type.
type scheme = Forall of tyvar list * ty

// Literal constants.
type lit =
    | LInt of int
    | LFloat of float
    | LString of string
    | LChar of char
    | LBool of bool
    | LUnit
// (is_recursive, id, optional_type_annotation, expression)
type binding = bool * string * ty option * expr

and expr =
    | Lit of lit
    | Lambda of string * ty option * expr       // fun x [: t] -> body
    | App of expr * expr                        // function application
    | Var of string
    | LetIn of binding * expr                   // let [rec] binding in body
    | IfThenElse of expr * expr * expr option   // the else branch is optional
    | Tuple of expr list
    | BinOp of expr * string * expr             // operator kept by name, e.g. "+"
    | UnOp of string * expr
// Wraps [e0] in one Lambda per parameter, first parameter outermost:
// [(x, t1); (y, t2)] with body e0 gives Lambda (x, t1, Lambda (y, t2, e0)).
let fold_params parms e0 =
    (parms, e0) ||> List.foldBack (fun (id, tyo) body -> Lambda (id, tyo, body))
// Matches a non-recursive let-in, yielding (id, annotation, bound expr, body).
let (|Let|_|) e =
    match e with
    | LetIn ((false, x, tyo, e1), e2) -> Some (x, tyo, e1, e2)
    | _ -> None

// Matches a recursive let-in, yielding (id, annotation, bound expr, body).
let (|LetRec|_|) e =
    match e with
    | LetIn ((true, x, tyo, e1), e2) -> Some (x, tyo, e1, e2)
    | _ -> None
// Association list mapping identifiers to items of type 'a.
type 'a env = (string * 'a) list

// Runtime values.
type value =
    | VLit of lit
    | VTuple of value list
    // presumably (captured env, parameter, body)
    | Closure of value env * string * expr
    // presumably (captured env, function name, parameter, body)
    | RecClosure of value env * string * string * expr

// A line of interactive input: an expression or a binding.
type interactive = IExpr of expr | IBinding of binding
// pretty printers
//

// utility function for printing lists by flattening strings with a separator;
// one space always follows the separator, e.g. sep ";" gives "a; b; c"
let flatten p sep es =
    es |> List.map p |> String.concat (sep + " ")
// print the pairs of the given env as [x=v; y=w], using p to print the
// elements bound within
let pretty_env p env =
    let pretty_binding (x, o) = sprintf "%s=%s" x (p o)
    "[" + flatten pretty_binding ";" env + "]"
// print any tuple given a printer p for its elements, e.g. "1, 2, 3".
// Joined directly: passing ", " to flatten would print two spaces, because
// flatten already adds one after the separator.
let pretty_tupled p l = l |> List.map p |> String.concat ", "
// Prints a type in F#-like syntax. Arrows are right-associative, so an arrow
// in argument position must be parenthesised: (int -> int) -> int.
let rec pretty_ty t =
    match t with
    | TyName s -> s
    | TyArrow (TyArrow _ as t1, t2) -> sprintf "(%s) -> %s" (pretty_ty t1) (pretty_ty t2)
    | TyArrow (t1, t2) -> sprintf "%s -> %s" (pretty_ty t1) (pretty_ty t2)
    | TyVar n -> sprintf "'%s" n
    | TyTuple ts -> sprintf "(%s)" (pretty_tupled pretty_ty ts)
// Prints a literal; strings get double quotes, chars are printed bare.
let pretty_lit lit =
    match lit with
    | LInt n -> sprintf "%d" n
    | LFloat f -> sprintf "%g" f
    | LString s -> "\"" + s + "\""
    | LChar c -> sprintf "%c" c
    | LBool b -> if b then "true" else "false"
    | LUnit -> "()"
// Prints an expression in source-like syntax.
let rec pretty_expr e =
    // Operand of an application: atomic expressions (literals, variables,
    // tuples) print as they are, anything else is parenthesised.
    let pretty_operand e =
        match e with
        | Lit _ | Var _ | Tuple _ -> pretty_expr e
        | _ -> sprintf "(%s)" (pretty_expr e)
    match e with
    | Lit lit -> pretty_lit lit
    | Lambda (x, None, e) -> sprintf "fun %s -> %s" x (pretty_expr e)
    | Lambda (x, Some t, e) -> sprintf "fun (%s : %s) -> %s" x (pretty_ty t) (pretty_expr e)
    // Application is left-associative: an application on the left needs no
    // parentheses, any other compound operand does (f (g x), (fun x -> x) 1).
    | App (e1, e2) ->
        let head =
            match e1 with
            | App _ -> pretty_expr e1
            | _ -> pretty_operand e1
        sprintf "%s %s" head (pretty_operand e2)
    | Var x -> x
    | Let (x, None, e1, e2) ->
        sprintf "let %s = %s in %s" x (pretty_expr e1) (pretty_expr e2)
    | Let (x, Some t, e1, e2) ->
        sprintf "let %s : %s = %s in %s" x (pretty_ty t) (pretty_expr e1) (pretty_expr e2)
    | LetRec (x, None, e1, e2) ->
        sprintf "let rec %s = %s in %s" x (pretty_expr e1) (pretty_expr e2)
    | LetRec (x, Some tx, e1, e2) ->
        sprintf "let rec %s : %s = %s in %s" x (pretty_ty tx) (pretty_expr e1) (pretty_expr e2)
    | IfThenElse (e1, e2, e3o) ->
        let s = sprintf "if %s then %s" (pretty_expr e1) (pretty_expr e2)
        (match e3o with
         | None -> s
         | Some e3 -> sprintf "%s else %s" s (pretty_expr e3))
    | Tuple es ->
        sprintf "(%s)" (pretty_tupled pretty_expr es)
    | BinOp (e1, op, e2) -> sprintf "%s %s %s" (pretty_expr e1) op (pretty_expr e2)
    | UnOp (op, e) -> sprintf "%s %s" op (pretty_expr e)
    // Needed because Let/LetRec are partial active patterns. Print the raw AST:
    // calling pretty_expr on the same e here would recurse forever.
    | _ -> unexpected_error "pretty_expr: %A" e
// Prints a runtime value; closures are shown as <|env;param;body|>.
let rec pretty_value v =
    match v with
    | VLit lit -> pretty_lit lit
    // Parenthesised like Tuple in pretty_expr, so nested tuples stay readable.
    | VTuple vs -> sprintf "(%s)" (pretty_tupled pretty_value vs)
    | Closure (env, x, e) -> sprintf "<|%s;%s;%s|>" (pretty_env pretty_value env) x (pretty_expr e)
    | RecClosure (env, f, x, e) -> sprintf "<|%s;%s;%s;%s|>" (pretty_env pretty_value env) f x (pretty_expr e)
And the type inference algorithm is the following:
// Raises TypeError with a printf-style formatted message.
let type_error fmt =
    throw_formatted TypeError fmt

// A substitution: a list of (type variable, replacement type) bindings.
type subst = (tyvar * ty) list
// type inference
//
// Starting environment: monomorphic types of the built-in operators, looked up
// by name when a BinOp/UnOp is typed.
// NOTE(review): lookups use List.find, which returns the FIRST match, so the
// unary "-" and "-." entries at the end are shadowed by the binary ones and
// unary minus is currently typed as the binary operator.
let gamma0 : scheme env = [
    ("+", Forall ([], TyArrow (TyInt, TyArrow (TyInt, TyInt))))
    ("-", Forall ([], TyArrow (TyInt, TyArrow (TyInt, TyInt))))
    ("*", Forall ([], TyArrow (TyInt, TyArrow (TyInt, TyInt))))
    ("/", Forall ([], TyArrow (TyInt, TyArrow (TyInt, TyInt))))
    ("%", Forall ([], TyArrow (TyInt, TyArrow (TyInt, TyInt))))
    ("=", Forall ([], TyArrow (TyInt, TyArrow (TyInt, TyBool))))
    ("<", Forall ([], TyArrow (TyInt, TyArrow (TyInt, TyBool))))
    ("<=", Forall ([], TyArrow (TyInt, TyArrow (TyInt, TyBool))))
    (">", Forall ([], TyArrow (TyInt, TyArrow (TyInt, TyBool))))
    (">=", Forall ([], TyArrow (TyInt, TyArrow (TyInt, TyBool))))
    ("<>", Forall ([], TyArrow (TyInt, TyArrow (TyInt, TyBool))))
    ("and", Forall ([], TyArrow (TyBool, TyArrow (TyBool, TyBool))))
    ("or", Forall ([], TyArrow (TyBool, TyArrow (TyBool, TyBool))))
    ("not", Forall ([], TyArrow (TyBool, TyBool)))
    ("-", Forall ([], TyArrow (TyInt, TyInt)))
    ("+.", Forall ([], TyArrow (TyFloat, TyArrow (TyFloat, TyFloat))))
    ("-.", Forall ([], TyArrow (TyFloat, TyArrow (TyFloat, TyFloat))))
    ("*.", Forall ([], TyArrow (TyFloat, TyArrow (TyFloat, TyFloat))))
    ("/.", Forall ([], TyArrow (TyFloat, TyArrow (TyFloat, TyFloat))))
    ("%.", Forall ([], TyArrow (TyFloat, TyArrow (TyFloat, TyFloat))))
    ("=.", Forall ([], TyArrow (TyFloat, TyArrow (TyFloat, TyBool))))
    ("<.", Forall ([], TyArrow (TyFloat, TyArrow (TyFloat, TyBool))))
    ("<=.", Forall ([], TyArrow (TyFloat, TyArrow (TyFloat, TyBool))))
    (">.", Forall ([], TyArrow (TyFloat, TyArrow (TyFloat, TyBool))))
    (">=.", Forall ([], TyArrow (TyFloat, TyArrow (TyFloat, TyBool))))
    ("<>.", Forall ([], TyArrow (TyFloat, TyArrow (TyFloat, TyBool))))
    ("-.", Forall ([], TyArrow (TyFloat, TyFloat)))
]
// Index of the last generated type variable (-1: none generated yet).
let mutable counter = -1

// Restarts fresh-variable generation, so the next variable is "a" again.
// Call it before inferring each top-level input if you want every result to
// start from 'a (only safe when no previously inferred types are kept around).
let reset_fresh_variables () = counter <- -1

// Returns a new type-variable name: "a".."z", then "a1".."z1", "a2", ...
// The original version produced '{', '|', ... after "z".
let generate_fresh_variable () =
    counter <- counter + 1
    let letter = string (char (int 'a' + counter % 26))
    let cycle = counter / 26
    if cycle = 0 then letter else letter + string cycle
// Occurs check: true when type variable [tv] appears anywhere inside [t].
let rec occurs (tv : tyvar) (t : ty) : bool =
    match t with
    | TyVar name -> name = tv
    | TyName _ -> false
    | TyArrow (dom, cod) -> occurs tv dom || occurs tv cod
    | TyTuple items -> List.exists (occurs tv) items
// Composition of substitutions. apply_subst applies the bindings of a list
// from the last to the first, so the concatenation s1 @ s2 behaves as
// "apply s2 first, then s1".
let compose_subst (s1 : subst) (s2 : subst) : subst =
    List.append s1 s2
// Most general unifier of [t1] and [t2], as a substitution in the same
// sequential form used by compose_subst (later bindings are applied first).
// Raises TypeError when the types cannot be unified.
let rec unify (t1 : ty) (t2 : ty) : subst =
    // Local substitution application: apply_subst is declared later in the
    // file and cannot be referenced here. Same order as apply_subst (last
    // binding first).
    let rec replace (x : tyvar) (r : ty) (t : ty) : ty =
        match t with
        | TyVar y -> if x = y then r else t
        | TyArrow (a, b) -> TyArrow (replace x r a, replace x r b)
        | TyName _ -> t
        | TyTuple ts -> TyTuple (List.map (replace x r) ts)
    let apply (s : subst) (t : ty) : ty =
        List.foldBack (fun (x, r) acc -> replace x r acc) s t
    // Unifies component pairs left to right. The substitution found so far is
    // applied to each later pair first; without this, 'a -> 'a would
    // wrongly unify with int -> bool.
    let unify_all (pairs : (ty * ty) list) : subst =
        List.fold (fun s (a, b) -> compose_subst (unify (apply s a) (apply s b)) s) [] pairs
    match t1, t2 with
    | TyName n1, TyName n2 ->
        if n1 <> n2
        then type_error "unify: unification between different variables name can't be execute"
        else []
    // A variable unifies with itself trivially; the occurs check must not reject it.
    | TyVar a, TyVar b when a = b -> []
    | TyVar tv, _ ->
        if occurs tv t2 then type_error "unify: unification fails" else [ (tv, t2) ]
    | _, TyVar tv ->
        if occurs tv t1 then type_error "unify: unification fails" else [ (tv, t1) ]
    | TyArrow (l1, r1), TyArrow (l2, r2) ->
        unify_all [ (l1, l2); (r1, r2) ]
    | TyTuple ts1, TyTuple ts2 when List.length ts1 = List.length ts2 ->
        unify_all (List.zip ts1 ts2)
    // Constructor clash (e.g. int vs a function, tuples of different arity) is
    // an ordinary type error in the user's program, not an internal error.
    | _ -> type_error "unify: cannot unify %s with %s" (pretty_ty t1) (pretty_ty t2)
(* substitute term s for all occurrences of var x in term t *)
let rec subst (s : ty) (x : tyvar) (t : ty) : ty =
    let go = subst s x
    match t with
    | TyVar y when y = x -> s
    | TyVar _ | TyName _ -> t
    | TyArrow (u, v) -> TyArrow (go u, go v)
    | TyTuple ts -> TyTuple (List.map go ts)
// Applies every binding of [s] to [t], one after the other, starting from the
// LAST binding of the list and moving towards the first.
let apply_subst (t : ty) (s : subst) : ty =
    s |> List.rev |> List.fold (fun acc (x, r) -> subst r x acc) t

// apply_subst with the arguments flipped, convenient for partial application.
let apply_subst_helper s t = apply_subst t s
// All type variables occurring in a type (its free variables).
let rec freevars_ty (t : ty) : tyvar Set =
    match t with
    | TyName _ -> Set.empty
    | TyVar tv -> Set.singleton tv
    | TyArrow (t1, t2) -> Set.union (freevars_ty t1) (freevars_ty t2)
    | TyTuple ts -> ts |> List.map freevars_ty |> Set.unionMany
// Free variables of a scheme: those of its type minus the quantified ones.
let freevars_scheme (Forall (tvs, t)) =
    let bound = Set.ofList tvs
    freevars_ty t - bound
// Union of the free variables of every scheme in the environment.
let freevars_env (en : scheme env) : tyvar Set =
    en |> List.map (snd >> freevars_scheme) |> Set.unionMany
// Quantifies the type variables of [typ] that are not free in [env].
// The quantified variables are listed in Set order (sorted).
let generalize (env : scheme env) (typ : ty) : scheme =
    let env_vars = freevars_env env
    let quantified = freevars_ty typ |> Set.filter (fun tv -> not (Set.contains tv env_vars))
    Forall (Set.toList quantified, typ)
// Replaces each quantified variable of a scheme with a fresh type variable.
// The replacement is simultaneous (one traversal, Map lookup). The previous
// sequential substitution could rename a variable twice when a fresh name
// coincided with another quantified name (e.g. with user annotations 'a).
let instantiate (Forall (tvs, typ)) : ty =
    let fresh =
        tvs
        |> List.map (fun tv -> tv, TyVar (generate_fresh_variable ()))
        |> Map.ofList
    let rec rename t =
        match t with
        | TyVar tv ->
            (match Map.tryFind tv fresh with
             | Some t' -> t'
             | None -> t)
        | TyArrow (a, b) -> TyArrow (rename a, rename b)
        | TyName _ -> t
        | TyTuple ts -> TyTuple (List.map rename ts)
    rename typ
// Composes the substitutions of a list of (type, subst) pairs, the first
// pair's substitution leftmost: s1 ∘ s2 ∘ ... ∘ sn.
let tupleMap l : subst =
    List.foldBack (fun (_, su) acc -> compose_subst su acc) l []

// Extracts the types of a list of (type, subst) pairs, preserving order.
let tupleMap2 l : ty list =
    List.map fst l
// type inference
//

// Applies [s] to the free variables of a scheme. Bindings for quantified
// variables are dropped first so they are never substituted.
// NOTE(review): quantified variables are not renamed, so capture by a
// replacement type mentioning them is not prevented.
let apply_subst_to_scheme (Forall (tvs, t) : scheme) (s : subst) : scheme =
    let s_free = List.filter (fun (x, _) -> not (List.contains x tvs)) s
    Forall (tvs, apply_subst t s_free)

// Applies [s] to every scheme in the environment.
let apply_subst_to_env (env : scheme env) (s : subst) : scheme env =
    List.map (fun (x, sc) -> (x, apply_subst_to_scheme sc s)) env

// Algorithm W. Returns the type of [e] (with the returned substitution already
// applied) and the substitution. Each substitution found is applied to the
// environment before inferring the following sub-expressions; omitting this
// makes inference unsound.
// Raises TypeError for ill-typed expressions and unbound identifiers.
let rec typeinfer_expr (env : scheme env) (e : expr) : ty * subst =
    match e with
    | Var x ->
        (match List.tryFind (fun (y, _) -> x = y) env with
         | Some (_, sc) -> (instantiate sc, [])
         | None -> type_error "typeinfer_expr: unbound variable %s" x)
    | Lit (LInt _) -> (TyInt, [])
    | Lit (LFloat _) -> (TyFloat, [])
    | Lit (LString _) -> (TyString, [])
    | Lit (LChar _) -> (TyChar, [])
    | Lit (LBool _) -> (TyBool, [])
    | Lit LUnit -> (TyUnit, [])
    | App (e1, e2) ->
        let t1, s1 = typeinfer_expr env e1
        let t2, s2 = typeinfer_expr (apply_subst_to_env env s1) e2
        let codTy = TyVar (generate_fresh_variable ())
        // t1 must also see what s2 learnt about the shared variables.
        let s3 = unify (apply_subst t1 s2) (TyArrow (t2, codTy))
        let s = compose_subst s3 (compose_subst s2 s1)
        (apply_subst codTy s, s)
    | Lambda (x, tyo, body) ->
        let paramTy =
            match tyo with
            | Some typ -> typ
            | None -> TyVar (generate_fresh_variable ())
        // The parameter is monomorphic inside the body.
        let t, s = typeinfer_expr ((x, Forall ([], paramTy)) :: env) body
        (apply_subst (TyArrow (paramTy, t)) s, s)
    | Let (x, tyo, e1, e2) ->
        let t1, s1 = typeinfer_expr env e1
        let t1, s1 =
            match tyo with
            | None -> (t1, s1)
            | Some ann ->
                let sa = unify t1 ann
                (apply_subst t1 sa, compose_subst sa s1)
        // Generalize against the updated env, otherwise variables already
        // fixed by s1 would be wrongly quantified.
        let env1 = apply_subst_to_env env s1
        let t2, s2 = typeinfer_expr ((x, generalize env1 t1) :: env1) e2
        (t2, compose_subst s2 s1)
    | LetRec (f, tyo, e1, e2) ->
        let fTy =
            match tyo with
            | Some typ -> typ
            | None -> TyVar (generate_fresh_variable ())
        // f is monomorphic inside its own definition.
        let t1, s1 = typeinfer_expr ((f, Forall ([], fTy)) :: env) e1
        let s2 = unify (apply_subst fTy s1) t1
        let s12 = compose_subst s2 s1
        let env1 = apply_subst_to_env env s12
        let sc = generalize env1 (apply_subst t1 s2)
        let t2, s3 = typeinfer_expr ((f, sc) :: env1) e2
        (t2, compose_subst s3 s12)
    | IfThenElse (e1, e2, e3o) ->
        let t1, s1 = typeinfer_expr env e1
        let s2 = unify t1 TyBool
        let s12 = compose_subst s2 s1
        let t2, s3 = typeinfer_expr (apply_subst_to_env env s12) e2
        let s123 = compose_subst s3 s12
        (match e3o with
         | None ->
             // Without an else branch the then branch must be unit.
             let s4 = unify t2 TyUnit
             (apply_subst t2 s4, compose_subst s4 s123)
         | Some e3 ->
             let t3, s4 = typeinfer_expr (apply_subst_to_env env s123) e3
             let s5 = unify (apply_subst t2 s4) t3
             (apply_subst t3 s5, compose_subst s5 (compose_subst s4 s123)))
    | Tuple es ->
        // Infer elements left to right, threading the substitution and
        // applying each new one to the element types inferred so far.
        let step (tys, s) el =
            let t, s' = typeinfer_expr (apply_subst_to_env env s) el
            (t :: List.map (fun ti -> apply_subst ti s') tys, compose_subst s' s)
        let tys, s = List.fold step ([], []) es
        (TyTuple (List.rev tys), s)
    // Operators are typed as applications of the identifier bound in env
    // (e.g. gamma0); an unknown operator is reported as an unbound variable.
    | BinOp (e1, op, e2) ->
        typeinfer_expr env (App (App (Var op, e1), e2))
    | UnOp (op, e1) ->
        typeinfer_expr env (App (Var op, e1))
    | _ -> unexpected_error "typeinfer_expr: unsupported expression: %s [AST: %A]" (pretty_expr e) e
I would like an explanation of why, when I start the program and type the following line into the shell:
let test = fun x -> x;;
the type returned by the type inference is 'b -> 'b and not 'a -> 'a.
It's probably because generate_fresh_variable () is called twice during type inference, once for the lambda and once for the variable, but I would like to understand how I can make it come out as type 'a -> 'a.
Of course, I welcome suggestions on how to improve the code and how to implement the algorithm better. I apologize for my lack of experience, but this is the first time I have looked at F# and at type inference.
My guess is that you inferred the type of another expression before this one, which caused the variable counter to increment. For example, consider the following test code:
let test () =
    // fun x -> x
    let identity = Lambda ("x", None, Var "x")
    let inferred, _ = typeinfer_expr [] identity
    printfn "%s" (pretty_ty inferred)
test () // 'a -> 'a
test () // 'b -> 'b
test () // 'c -> 'c
Note that a fresh type variable is generated each time the test runs, because the global counter is never reset. Maybe you need a way to reset the counter between invocations? Or, even better, see if you can get rid of the side effect in generate_fresh_variable entirely.
The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.