简体   繁体   中英

[F#][Type inference] - How to improve my program?

I'm currently building a type-inference program in F# using the following AST:

// errors
//

exception SyntaxError of string * FSharp.Text.Lexing.LexBuffer<char>
exception TypeError of string
exception UnexpectedError of string

// Formats a printf-style message and raises the exception that exnf builds from it.
// The result type is generic, so a call can stand in for a value of any type.
let throw_formatted exnf fmt =
    let fail (msg : string) = raise (exnf msg)
    Printf.ksprintf fail fmt

// Raises UnexpectedError with a printf-style formatted message.
let unexpected_error fmt = fmt |> throw_formatted UnexpectedError


// AST type definitions
//

// Name of a type variable, stored without the leading quote (pretty_ty prints "a" as 'a).
type tyvar = string

// Syntax of types.
type ty =
    | TyName of string // type constant, e.g. "int" or "bool"
    | TyArrow of ty * ty // function type: domain -> codomain
    | TyVar of tyvar // type variable, e.g. 'a, 'b
    | TyTuple of ty list // tuple type, one entry per component

// Pseudo data constructors for the literal types: ordinary values of type ty,
// each one a TyName with a fixed name, usable in expression position.
// The active patterns with the same names (declared separately) cover patterns.
let TyFloat = TyName "float"
let TyInt = TyName "int"
let TyChar = TyName "char"
let TyString = TyName "string"
let TyBool = TyName "bool"
let TyUnit = TyName "unit"

// Active patterns for the literal types.
// (|TyLit|_|) name ty succeeds (with no payload) exactly when ty is TyName name.
let private (|TyLit|_|) name ty =
    match ty with
    | TyName s -> if s = name then Some () else None
    | TyArrow _ | TyVar _ | TyTuple _ -> None

let (|TyFloat|_|) ty = (|TyLit|_|) "float" ty
let (|TyInt|_|) ty = (|TyLit|_|) "int" ty
let (|TyChar|_|) ty = (|TyLit|_|) "char" ty
let (|TyString|_|) ty = (|TyLit|_|) "string" ty
let (|TyBool|_|) ty = (|TyLit|_|) "bool" ty
let (|TyUnit|_|) ty = (|TyLit|_|) "unit" ty


// Type scheme: Forall (vars, t) universally quantifies the type variables vars in t.
type scheme = Forall of tyvar list * ty

// Literal constants.
type lit = LInt of int
         | LFloat of float
         | LString of string
         | LChar of char
         | LBool of bool
         | LUnit 

// A let binding.
type binding = bool * string * ty option * expr    // (is_recursive, id, optional_type_annotation, expression)

// Expression syntax.
and expr = 
    | Lit of lit // literal constant
    | Lambda of string * ty option * expr // fun x [: ty] -> body
    | App of expr * expr // function application
    | Var of string // identifier
    | LetIn of binding * expr // let [rec] x [: ty] = e1 in e2
    | IfThenElse of expr * expr * expr option // the else branch is optional
    | Tuple of expr list
    | BinOp of expr * string * expr // infix operator, named by its symbol
    | UnOp of string * expr // prefix operator, named by its symbol
   
// Wraps e0 in one Lambda per parameter, the first parameter outermost:
// fold_params [(x, tx); (y, ty)] e0 = Lambda (x, tx, Lambda (y, ty, e0)).
let fold_params parms e0 =
    parms
    |> List.rev
    |> List.fold (fun body (id, tyo) -> Lambda (id, tyo, body)) e0

// Views a non-recursive LetIn as (id, optional annotation, bound expression, body).
let (|Let|_|) e =
    match e with
    | LetIn ((isRec, x, tyo, e1), e2) when not isRec -> Some (x, tyo, e1, e2)
    | _ -> None

// Views a recursive LetIn ("let rec") as (id, optional annotation, bound expression, body).
let (|LetRec|_|) e =
    match e with
    | LetIn ((isRec, x, tyo, e1), e2) when isRec -> Some (x, tyo, e1, e2)
    | _ -> None

// Association list from identifiers to values of type 'a. Lookups use
// List.find, so the first (most recently prepended) binding for a name wins.
type 'a env = (string * 'a) list  

// Runtime values.
type value =
    | VLit of lit // literal constant
    | VTuple of value list // tuple of values
    | Closure of value env * string * expr // environment, (presumably) parameter name, body
    | RecClosure of value env * string * string * expr // like Closure plus (presumably) the function's own name, for recursion

// Presumably one REPL input: either a bare expression or a top-level binding.
type interactive = IExpr of expr | IBinding of binding

// pretty printers
//

// Prints every element with p and joins the results with sep followed by a
// single space: flatten string ";" [1; 2] = "1; 2". An empty list gives "".
let flatten p sep es =
    es
    |> List.map p
    |> String.concat (sep + " ")

// Prints an environment as "[x=v; y=w]", using p to print each bound element.
let pretty_env p env =
    let binding (x, o) = sprintf "%s=%s" x (p o)
    "[" + flatten binding ";" env + "]"

// Prints the elements of a tuple with p, separated by ", " (e.g. "1, 2, 3").
// flatten already inserts a space after its separator, so only "," is passed:
// passing ", " would print a double space ("1,  2").
let pretty_tupled p l = flatten p "," l

// Prints a type in ML-like syntax. The arrow is right-associative, so an arrow
// in domain position must be parenthesised: TyArrow (TyArrow (a, b), c) prints
// as "(a -> b) -> c", not the different type "a -> b -> c".
let rec pretty_ty t =
    match t with
    | TyName s -> s
    | TyArrow (t1, t2) ->
        let dom =
            match t1 with
            | TyArrow _ -> sprintf "(%s)" (pretty_ty t1)
            | _ -> pretty_ty t1
        sprintf "%s -> %s" dom (pretty_ty t2)
    | TyVar n -> sprintf "'%s" n
    | TyTuple ts -> sprintf "(%s)" (pretty_tupled pretty_ty ts)

// Prints a literal: strings in double quotes, chars bare, unit as "()".
let pretty_lit lit =
    match lit with
    | LBool b -> if b then "true" else "false"
    | LUnit -> "()"
    | LInt n -> sprintf "%d" n
    | LFloat f -> sprintf "%g" f
    | LChar c -> string c
    | LString s -> "\"" + s + "\""

// Prints an expression in ML-like concrete syntax.
let rec pretty_expr e =
    // Prints a sub-expression used as an application operand: literals,
    // variables and (already parenthesised) tuples stand alone; anything else
    // is wrapped in parentheses, so App (f, App (g, x)) prints as "f (g x)"
    // rather than the ambiguous "f g x".
    let pretty_operand e =
        match e with
        | Lit _ | Var _ | Tuple _ -> pretty_expr e
        | _ -> sprintf "(%s)" (pretty_expr e)
    match e with
    | Lit lit -> pretty_lit lit

    | Lambda (x, None, e) -> sprintf "fun %s -> %s" x (pretty_expr e)
    | Lambda (x, Some t, e) -> sprintf "fun (%s : %s) -> %s" x (pretty_ty t) (pretty_expr e)

    | App (e1, e2) ->
        // application is left-associative, so an application in function
        // position ("f x" in "f x y") needs no parentheses
        let fn =
            match e1 with
            | App _ -> pretty_expr e1
            | _ -> pretty_operand e1
        sprintf "%s %s" fn (pretty_operand e2)

    | Var x -> x

    | Let (x, None, e1, e2) ->
        sprintf "let %s = %s in %s" x (pretty_expr e1) (pretty_expr e2)

    | Let (x, Some t, e1, e2) ->
        sprintf "let %s : %s = %s in %s" x (pretty_ty t) (pretty_expr e1) (pretty_expr e2)

    | LetRec (x, None, e1, e2) ->
        sprintf "let rec %s = %s in %s" x (pretty_expr e1) (pretty_expr e2)

    | LetRec (x, Some tx, e1, e2) ->
        sprintf "let rec %s : %s = %s in %s" x (pretty_ty tx) (pretty_expr e1) (pretty_expr e2)

    | IfThenElse (e1, e2, e3o) ->
        let s = sprintf "if %s then %s" (pretty_expr e1) (pretty_expr e2)
        match e3o with
        | None -> s
        | Some e3 -> sprintf "%s else %s" s (pretty_expr e3)
        
    | Tuple es ->        
        sprintf "(%s)" (pretty_tupled pretty_expr es)

    | BinOp (e1, op, e2) -> sprintf "%s %s %s" (pretty_expr e1) op (pretty_expr e2)
    
    | UnOp (op, e) -> sprintf "%s %s" op (pretty_expr e)
    
    // Let/LetRec are partial active patterns, so the compiler cannot prove the
    // match exhaustive. Report the raw AST here: calling pretty_expr on the
    // same e would recurse forever.
    | _ -> unexpected_error "pretty_expr: %A" e

// Prints a runtime value. Closures are shown as <|env;param;body|> and
// recursive closures as <|env;name;param;body|>.
let rec pretty_value v =
    match v with
    | VLit lit -> pretty_lit lit

    // parenthesised like tuple types and tuple expressions, so that a nested
    // tuple such as ((1, 2), 3) is distinguishable from the flat (1, 2, 3)
    | VTuple vs -> sprintf "(%s)" (pretty_tupled pretty_value vs)

    | Closure (env, x, e) -> sprintf "<|%s;%s;%s|>" (pretty_env pretty_value env) x (pretty_expr e)
    
    | RecClosure (env, f, x, e) -> sprintf "<|%s;%s;%s;%s|>" (pretty_env pretty_value env) f x (pretty_expr e)

And the type inference algorithm is the following:

// Raises TypeError with a printf-style formatted message.
let type_error fmt = fmt |> throw_formatted TypeError

// A substitution: a list of (type variable, replacement type) bindings.
// apply_subst applies them one at a time, from the last element to the first.
type subst = (tyvar * ty) list

// type inference
//

// Initial typing environment: the monomorphic types of the built-in operators,
// keyed by operator symbol (BinOp/UnOp are typed by looking the symbol up as a
// variable). Name lookups take the FIRST matching entry.
let gamma0 : scheme env = [
    // integer arithmetic
    ("+", Forall([],TyArrow (TyInt, TyArrow (TyInt, TyInt))))
    ("-", Forall([],TyArrow (TyInt, TyArrow (TyInt, TyInt))))
    ("*", Forall([],TyArrow (TyInt, TyArrow (TyInt, TyInt))))
    ("/", Forall([],TyArrow (TyInt, TyArrow (TyInt, TyInt))))
    ("%", Forall([],TyArrow (TyInt, TyArrow (TyInt, TyInt))))
    // integer comparison
    ("=", Forall([],TyArrow (TyInt, TyArrow (TyInt, TyBool))))
    ("<", Forall([],TyArrow (TyInt, TyArrow (TyInt, TyBool))))
    ("<=", Forall([],TyArrow (TyInt, TyArrow (TyInt, TyBool))))
    (">", Forall([],TyArrow (TyInt, TyArrow (TyInt, TyBool))))
    (">=", Forall([],TyArrow (TyInt, TyArrow (TyInt, TyBool))))
    ("<>", Forall([],TyArrow (TyInt, TyArrow (TyInt, TyBool))))
    // legacy misspelling of ">=", kept as an alias so that any front end
    // already emitting it keeps type-checking
    ("=>", Forall([],TyArrow (TyInt, TyArrow (TyInt, TyBool))))
    // boolean connectives
    ("and", Forall([],TyArrow (TyBool, TyArrow (TyBool, TyBool))))
    ("or", Forall([],TyArrow (TyBool, TyArrow (TyBool, TyBool))))
    ("not", Forall([],TyArrow (TyBool, TyBool)))
    // unary minus: unreachable through name lookup, because the binary "-"
    // above is found first
    ("-", Forall([],TyArrow (TyInt, TyInt)))

    // float arithmetic
    ("+.", Forall([],TyArrow (TyFloat, TyArrow (TyFloat, TyFloat))))
    ("-.", Forall([],TyArrow (TyFloat, TyArrow (TyFloat, TyFloat))))
    ("*.", Forall([],TyArrow (TyFloat, TyArrow (TyFloat, TyFloat))))
    ("/.", Forall([],TyArrow (TyFloat, TyArrow (TyFloat, TyFloat))))
    ("%.", Forall([],TyArrow (TyFloat, TyArrow (TyFloat, TyFloat))))
    // float comparison
    ("=.", Forall([],TyArrow (TyFloat, TyArrow (TyFloat, TyBool))))
    ("<.", Forall([],TyArrow (TyFloat, TyArrow (TyFloat, TyBool))))
    ("<=.", Forall([],TyArrow (TyFloat, TyArrow (TyFloat, TyBool))))
    (">.", Forall([],TyArrow (TyFloat, TyArrow (TyFloat, TyBool))))
    (">=.", Forall([],TyArrow (TyFloat, TyArrow (TyFloat, TyBool))))
    ("<>.", Forall([],TyArrow (TyFloat, TyArrow (TyFloat, TyBool))))
    // legacy misspelling of ">=.", kept as an alias (see "=>" above)
    ("=>.", Forall([],TyArrow (TyFloat, TyArrow (TyFloat, TyBool))))
    // unary float minus: likewise shadowed by the binary "-." above
    ("-.", Forall([],TyArrow (TyFloat, TyFloat)))
]

// Supply of fresh type-variable names: counter is the index of the last name
// handed out (-1 means none yet).
let mutable counter = -1

// Rewinds the supply so the next generated name is "a" again. Calling it before
// inferring each top-level input keeps names from continuing where the
// previous input stopped (e.g. 'b -> 'b instead of 'a -> 'a).
let reset_fresh_variables () =
    counter <- -1

// Returns a type-variable name not returned since the last reset:
// "a" .. "z", then "aa", "ab", .. "az", "ba", .. (bijective base 26).
// A plain char offset from 'a' would run past 'z' into '{', '|', ...
let generate_fresh_variable () =
    counter <- counter + 1
    let rec name n =
        let letter = string (char (int 'a' + n % 26))
        if n < 26 then letter else name (n / 26 - 1) + letter
    name counter

// Occurs check: true when the type variable tv appears anywhere inside t.
let rec occurs (tv : tyvar) (t : ty) : bool = 
    match t with
    | TyName _ -> false
    | TyVar v -> v = tv
    | TyArrow (dom, cod) -> occurs tv dom || occurs tv cod
    | TyTuple ts -> ts |> List.exists (occurs tv)

// Composes two substitutions so that the result behaves like applying s2
// first and then s1. apply_subst processes a substitution list from its last
// element to its first, so plain concatenation already is that composition.
let compose_subst (s1 : subst) (s2 : subst) : subst = List.append s1 s2

// Computes a substitution that makes t1 and t2 equal (a most general unifier),
// in the same sequential representation used by apply_subst / compose_subst
// (bindings applied one at a time, last element first).
// Raises TypeError when the two types cannot be made equal.
let rec unify (t1 : ty) (t2 : ty) : subst = 
    // Local substitution application: apply_subst is defined further down in
    // the file and is therefore not in scope here.
    let rec replace (x : tyvar) (image : ty) (t : ty) : ty =
        match t with
        | TyVar y when y = x -> image
        | TyVar _
        | TyName _ -> t
        | TyArrow (a, b) -> TyArrow (replace x image a, replace x image b)
        | TyTuple ts -> TyTuple (List.map (replace x image) ts)
    let apply (s : subst) (t : ty) : ty =
        List.foldBack (fun (x, image) acc -> replace x image acc) s t
    // Unifies two equal-length lists pair by pair. The substitution found so far
    // is applied before unifying the next pair, so a constraint learned from
    // one component is respected in the following ones.
    let unify_pairs (ts1 : ty list) (ts2 : ty list) : subst =
        List.fold2 (fun s a b -> compose_subst (unify (apply s a) (apply s b)) s) [] ts1 ts2
    match t1, t2 with
    | TyName n1, TyName n2 ->
        if n1 = n2 then []
        else type_error "unify: cannot unify type %s with type %s" n1 n2
    // a variable unifies with itself; the occurs check below would reject it
    | TyVar a, TyVar b when a = b -> []
    | TyVar tv, t
    | t, TyVar tv ->
        if occurs tv t then type_error "unify: unification fails"
        else [ (tv, t) ]
    | TyArrow (d1, c1), TyArrow (d2, c2) -> unify_pairs [ d1; c1 ] [ d2; c2 ]
    | TyTuple ts1, TyTuple ts2 when List.length ts1 = List.length ts2 -> unify_pairs ts1 ts2
    // mismatched shapes (e.g. int vs int -> int, or tuples of different arity)
    // are ordinary type errors in the user's program
    | _ -> type_error "unify: cannot unify %s with %s" (pretty_ty t1) (pretty_ty t2)

(* subst s x t: replace every occurrence of the type variable x in t with s *)
let rec subst (s : ty) (x : tyvar) (t : ty) : ty =
    let go = subst s x
    match t with
    | TyVar y when y = x -> s
    | TyVar _
    | TyName _ -> t
    | TyArrow (dom, cod) -> TyArrow (go dom, go cod)
    | TyTuple ts -> TyTuple (List.map go ts)

// Applies the substitution s to t. Bindings are applied one at a time, starting
// with the LAST element of s and ending with the first, so
// apply_subst t (s1 @ s2) = apply_subst (apply_subst t s2) s1.
let apply_subst (t : ty) (s : subst) : ty = 
    s
    |> List.rev
    |> List.fold (fun acc (x, image) -> subst image x acc) t

// apply_subst with the substitution first, convenient for List.map.
let apply_subst_helper (s : subst) (t : ty) = apply_subst t s

// The set of all type variables occurring in a type (all are free: ty has no binders).
let rec freevars_ty (t : ty) : tyvar Set =
    match t with
    | TyVar tv -> Set.singleton tv
    | TyName _ -> Set.empty
    | TyArrow (dom, cod) -> freevars_ty dom + freevars_ty cod
    | TyTuple ts -> ts |> List.map freevars_ty |> Set.unionMany

// Free type variables of a scheme: those of its body minus the quantified ones.
let freevars_scheme (Forall (tvs, t)) =
    let bound = Set.ofList tvs
    freevars_ty t - bound

// Free type variables of an environment: the union of the free variables of
// every scheme bound in it. (Replaces a nested match whose inner match was
// incomplete and triggered compiler warning FS0025.)
let freevars_env (en: scheme env) : tyvar Set =
    en
    |> List.map (fun (_, sc) -> freevars_scheme sc)
    |> Set.unionMany


// Generalizes typ over the type variables free in typ but not free in env.
let generalize (env : scheme env) (typ : ty) : scheme =
    let quantified = freevars_ty typ - freevars_env env
    Forall (Set.toList quantified, typ)

// Instantiates a scheme: each quantified variable is renamed to a newly
// generated type variable, giving a fresh copy of the scheme's type.
let instantiate (Forall (tvs, typ)) : ty =
    let renaming =
        tvs
        |> List.map (fun tv -> (tv, TyVar (generate_fresh_variable ())))
        |> Map.ofList
        |> Map.toList
    apply_subst typ renaming

// Composes the substitutions of a list of (type, substitution) results,
// preserving list order: compose_subst s0 (compose_subst s1 (... [])).
let tupleMap l: subst =
    List.foldBack (fun (_, su) acc -> compose_subst su acc) l []

// Projects the types out of a list of (type, substitution) results.
let tupleMap2 l: ty list =
    List.map fst l
// type inference
//

// Applies substitution s to a scheme, leaving its quantified variables untouched.
let apply_subst_scheme (Forall (tvs, t)) (s : subst) : scheme =
    let free_part = s |> List.filter (fun (tv, _) -> not (List.contains tv tvs))
    Forall (tvs, apply_subst t free_part)

// Applies substitution s to every scheme bound in env.
let apply_subst_env (env : scheme env) (s : subst) : scheme env =
    env |> List.map (fun (x, sch) -> (x, apply_subst_scheme sch s))

// Refines an inferred (type, substitution) pair with an optional annotation.
let unify_annotation (t : ty) (s : subst) (tyo : ty option) : ty * subst =
    match tyo with
    | None -> (t, s)
    | Some annot ->
        let sa = unify t annot
        (apply_subst t sa, compose_subst sa s)

// Algorithm W (Damas-Milner). typeinfer_expr env e returns the type of e
// together with the substitution accumulated while inferring it.
// Each sub-expression is inferred in the environment updated with the
// substitutions found so far, and let-bound values are generalized against
// that updated environment.
// Raises TypeError on ill-typed programs (including unbound identifiers).
let rec typeinfer_expr (env : scheme env) (e : expr) : ty * subst =
    match e with
    | Var x ->
        let sch =
            match List.tryFind (fun (y, _) -> x = y) env with
            | Some (_, sch) -> sch
            | None -> type_error "typeinfer_expr: unbound identifier %s" x
        (instantiate sch, [])

    | Lit (LInt _) -> (TyInt, [])
    | Lit (LFloat _) -> (TyFloat, [])
    | Lit (LString _) -> (TyString, [])
    | Lit (LChar _) -> (TyChar, [])
    | Lit (LBool _) -> (TyBool, [])
    | Lit LUnit -> (TyUnit, [])

    | App (e1, e2) ->
        let codTy = TyVar (generate_fresh_variable ())
        let t1, s1 = typeinfer_expr env e1
        let t2, s2 = typeinfer_expr (apply_subst_env env s1) e2
        // t1 was inferred before s2 existed, so bring it up to date first
        let s3 = unify (apply_subst t1 s2) (TyArrow (t2, codTy))
        let s = compose_subst s3 (compose_subst s2 s1)
        (apply_subst codTy s3, s)

    | Lambda (x, tyo, body) ->
        let domTy =
            match tyo with
            | Some annot -> annot
            | None -> TyVar (generate_fresh_variable ())
        let t, s = typeinfer_expr ((x, Forall ([], domTy)) :: env) body
        (apply_subst (TyArrow (domTy, t)) s, s)

    // polymorphic let: generalize the bound value before checking the body
    | Let (x, tyo, e1, e2) ->
        let t1, s1 = typeinfer_expr env e1
        let t1, s1 = unify_annotation t1 s1 tyo
        let env1 = apply_subst_env env s1
        let sc1 = generalize env1 t1
        let t2, s2 = typeinfer_expr ((x, sc1) :: env1) e2
        (t2, compose_subst s2 s1)

    // recursive let: the name is bound monomorphically while its own
    // definition is inferred, then generalized for the body
    | LetRec (f, tyo, e1, e2) ->
        let fTy =
            match tyo with
            | Some annot -> annot
            | None -> TyVar (generate_fresh_variable ())
        let t1, s1 = typeinfer_expr ((f, Forall ([], fTy)) :: env) e1
        let s2 = unify (apply_subst fTy s1) t1
        let s12 = compose_subst s2 s1
        let t1 = apply_subst t1 s2
        let env1 = apply_subst_env env s12
        let sc1 = generalize env1 t1
        let t2, s3 = typeinfer_expr ((f, sc1) :: env1) e2
        (t2, compose_subst s3 s12)

    | IfThenElse (e1, e2, e3o) ->
        let t1, s1 = typeinfer_expr env e1
        let s2 = unify t1 TyBool
        let s12 = compose_subst s2 s1
        let t2, s3 = typeinfer_expr (apply_subst_env env s12) e2
        let s123 = compose_subst s3 s12
        let tResult, sBranches =
            match e3o with
            | None ->
                // without an else branch the then branch must have type unit
                let s4 = unify t2 TyUnit
                (apply_subst t2 s4, s4)
            | Some e3 ->
                let t3, s4 = typeinfer_expr (apply_subst_env env s123) e3
                let s5 = unify (apply_subst t2 s4) t3
                (apply_subst t3 s5, compose_subst s5 s4)
        (tResult, compose_subst sBranches s123)

    | Tuple es ->
        // infer components left to right, threading the substitution through
        let step (acc, s) ei =
            let ti, si = typeinfer_expr (apply_subst_env env s) ei
            (ti :: acc, compose_subst si s)
        let tysRev, s = List.fold step ([], []) es
        let tys = tysRev |> List.rev |> List.map (apply_subst_helper s)
        (TyTuple tys, s)

    | BinOp (e1, op, e2) ->
        typeinfer_expr env (App (App (Var op, e1), e2))

    // gamma0 binds "-" / "-." as binary operators before their unary forms, so
    // a name lookup would give negation a curried binary type; type it directly
    | UnOp (("-" | "-.") as op, e1) ->
        let numTy = if op = "-" then TyInt else TyFloat
        let t1, s1 = typeinfer_expr env e1
        let s2 = unify t1 numTy
        (numTy, compose_subst s2 s1)

    | UnOp (op, e1) ->
        typeinfer_expr env (App (Var op, e1))

    // Let/LetRec are partial active patterns, so a final wildcard is required
    | _ -> unexpected_error "typeinfer_expr: unsupported expression: %s [AST: %A]" (pretty_expr e) e

I would like an explanation of why, when I start the program and enter the following line in the shell:

let test = fun x -> x;;

the type returned by type inference is 'b -> 'b rather than 'a -> 'a.

It's probably because, during type inference, the function generate_fresh_variable () is called twice (once for the lambda and once for the variable), but I would like to understand how I can make it come out as type 'a -> 'a.

Of course, I welcome suggestions on how to improve the code and how to implement the algorithm better. I apologize for my lack of experience; this is the first time I have looked at F# and type inference.

My guess is that you inferred the type of another expression before this one, which caused the variable counter to increment. For example, consider the following test code:

// Infers and prints the type of the identity function fun x -> x.
let test () =
    let identity = Lambda ("x", None, Var "x")
    let inferred, _ = typeinfer_expr [] identity
    printfn "%s" (pretty_ty inferred)

test ()   // 'a -> 'a
test ()   // 'b -> 'b
test ()   // 'c -> 'c

Note that a fresh type variable is generated each time the test runs. Maybe you need a way to reset the counter between invocations? Or, even better, see if you can get rid of the side-effect in generate_fresh_variable entirely.

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM