A Type Inference Implementation Adventure

I stumbled across a blog post entitled “Generic Unification” a while ago. It makes this claim:

The unification-fd package by wren gayle romano is the de-facto standard way to do unification in Haskell. You’d use it if you need to implement type inference for your DSL, for example.

This was incredibly relevant to my interests! At the time I had been thinking about how to do some sort of minimal type inference for A Thing, and was discovering that it was harder than I had expected. Running into a library to which I could offload half of the work was pretty sweet.

I wrote code that used the library to do type inference for The Thing, and it kinda sorta worked a bit. It wasn’t a great success, though, and I kept discovering things that I didn’t really understand. Continually running into things I don’t understand made it seem like a good time to step back and do a learning exercise, which I christened BabyTC. It’s a typechecker for a simple language, like one you might write as an exercise in a PL implementation course. This post is about the problems I ran into while I was doing this, and how I got through them. If you know a bit about Haskell and a bit about how type inference works, and are interested in some of the details of making it go, then this post might just be for you!

What Language Are We Checking?

The language is a super-simple lambda calculus sort of thing. A program is a single expression. In Haskell syntax, the language looks like this:

data Expr = Lam Text Expr
          | App Expr Expr
          | Var Text
          | Let [(Text, Expr)] Expr
          | Number Integer
          | Text Text

Informally, the meaning of this language goes something like this:

  • The expression can be a Lam, which has a name and an expression. When an argument is applied to the Lam, the expression will be evaluated with the argument bound to the name.
  • The expression can be an App, in which case the App’s second argument will be applied to the first. The first is going to have to evaluate to a Lam for this to work!
  • If the expression is a Var, then it evaluates to whatever the Var’s name is bound to.
  • When the expression is a Let, then the first argument is a list of names and expressions. The result of the expression is the second argument, evaluated with every expression in the first argument bound to the corresponding name. This was a late addition. We’ll come back to this soon.
  • If the expression is a Number or a Text, then it just evaluates to the value in the expression.
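The informal semantics above can be sketched as a tiny evaluator. This is a minimal sketch, not part of BabyTC: it assumes String in place of Text, treats Let as non-recursive (each binding is evaluated in the outer environment), and the names Value, Closure and eval are hypothetical.

```haskell
import qualified Data.Map as Map

data Expr = Lam String Expr
          | App Expr Expr
          | Var String
          | Let [(String, Expr)] Expr
          | Number Integer
          | Text String
          deriving (Show, Eq)

-- Values: numbers, texts, and closures that capture their defining environment.
data Value = NumV Integer
           | TextV String
           | Closure (Map.Map String Value) String Expr

instance Show Value where
  show (NumV n) = "NumV " ++ show n
  show (TextV t) = "TextV " ++ show t
  show Closure{} = "<closure>"

eval :: Map.Map String Value -> Expr -> Either String Value
eval env (Lam name body) = Right (Closure env name body)
eval env (App fun arg) = do
  funV <- eval env fun
  argV <- eval env arg
  case funV of
    -- The first argument has to evaluate to a Lam for this to work!
    Closure cloEnv name body -> eval (Map.insert name argV cloEnv) body
    other -> Left ("cannot apply non-function " ++ show other)
eval env (Var name) =
  maybe (Left ("unbound variable " ++ name)) Right (Map.lookup name env)
eval env (Let bindings body) = do
  -- Assumption: bindings are evaluated in the outer environment (no recursion).
  vals <- traverse (\(name, e) -> (,) name <$> eval env e) bindings
  -- Map.union is left-biased, so the new bindings shadow outer ones.
  eval (Map.union (Map.fromList vals) env) body
eval _ (Number n) = Right (NumV n)
eval _ (Text t) = Right (TextV t)
```

For example, applying Lam "x" (Var "x") to Number 3 evaluates to NumV 3, and looking up an unbound name produces a Left.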

It’s common in things like this to have a Case statement of some sort. I’ve deliberately forgone that for now – I’m happy to inject an ad-hoc library of functions that will constrain things to be Numbers or Texts and provide ways to manipulate them and move between them. It’s also usual to provide a more formal explanation of the semantics, but I don’t really trust myself to get that right. Let me know if you’d like a pointer to a similar language with a more formal specification!

What does unification-fd give us?

If unification-fd is the tool we’re going to use to try to crack this nut, we’d better look at what exactly it gives us. We have:

  • A type, UTerm, representing either a variable or a thing in some language we define
  • A typeclass Unifiable to let us say whether two things in the language we defined can be unified
  • Ways to create, bind and look up variables (the functions of the BindingMonad typeclass)
  • Ways to query the relationship between terms (===, =~=)
  • Ways to assert that terms have particular relationships (=:=, <:=)
  • A Big Red Button (applyBindings, applyBindingsAll)

The post that kicked off this whole caper, Roman Cheplyaka’s Generic Unification, and wren gayle romano’s tutorial on unification-fd cover how all those bits fit together, so I won’t talk about that here. How can we use these bits to do type inference?

What does a type look like?

I started with something like this classic:

data Type = LamTy Type Type
          | NumberTy
          | TextTy
          | TyVar Text
          | ForallTy Text Type

A LamTy has an argument type and a result type. NumberTy and TextTy are our primitive types. We don’t need an AppTy, because an App expression’s type is simply the result type of the LamTy belonging to its first argument (with that LamTy’s argument type matching the type of the second argument).

So what’s the deal with this TyVar and ForallTy business? If you’re writing a map function, it might have a type like (a -> b) -> [a] -> [b]. The a and b in that type are type variables, but where do they come from? It matters because the point where the variables are introduced determines how much flexibility the consumer of the map function has in picking types to send it. Foralls are a way of specifying where type variables come from. If we explicitly put some foralls into that map function’s type, we get forall a b. ((a -> b) -> [a] -> [b]). What that means is that you can pick any a and b you want, give it a function that takes an a to a b, and get a function that takes a list of as to a list of bs. But once you commit to an a and a b, you can’t pick another b halfway through.

On the other hand, I can write down the type forall a. ((a -> forall b. b) -> [a] -> [forall b. b]), but it’s not going to end well. The caller of this function has to pass in something that can return a value of any type at all, even though the function will presumably only use it to produce values of the one type the caller expects back in the list. We’ve introduced a demand for more flexibility than we need, which makes problems for the consumer of the map function. If we commit to b early, we know we’re talking about the same b throughout the whole type signature, and we don’t get those problems.
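To make the two shapes concrete in Haskell itself, here is a small sketch (myMap, showAll, lengths and weirdMap are hypothetical names). The first type is ordinary rank-1 polymorphism with the foralls written out; weirdMap is only a rank-2 approximation of the second type, since GHC won't put forall b. b inside a list without ImpredicativeTypes.

```haskell
{-# LANGUAGE RankNTypes #-}
-- RankNTypes also enables ExplicitForAll, so we can write the foralls out.

-- forall a b at the very front: each caller picks a and b once per use.
myMap :: forall a b. (a -> b) -> [a] -> [b]
myMap _ [] = []
myMap f (x : xs) = f x : myMap f xs

-- One caller commits to a = Int, b = String...
showAll :: [Int] -> [String]
showAll = myMap show

-- ...and another commits to a = String, b = Int, using the same definition.
lengths :: [String] -> [Int]
lengths = myMap length

-- The over-flexible shape: the argument must produce a value of *every*
-- type b it is asked for. Passing show here is rejected by GHC, because
-- show :: Int -> String only ever produces a String.
weirdMap :: forall a c. (forall b. a -> b) -> [a] -> [c]
weirdMap f = map f
```

The only arguments weirdMap accepts are functions that never actually return a value, such as \_ -> undefined, which is the sense in which that type is not going to end well.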

Our goal is to have a type like this associated with every expression in our program. Depending on exactly what your goal is, you might want to then throw away all except the top-level type, and that’s ok! Our general plan of attack is:

  1. Create a type variable for each expression
  2. Gather information from up and down the expression tree to constrain those type variables and feed it to unification-fd
  3. Ask unification-fd to do its thing and find values for all the type variables
  4. Dig through the variables until we have the actual type for each expression (which may still be a variable!)
  5. Remove all the unification variables, replacing them with TyVars and introducing Foralls as appropriate to get into roughly the form described above.

Rearranging Some Preliminaries

You know that Expr type I gave you before?

I lied.

We want to associate a type with not just the program as a whole, but with every expression in the program. That means rejigging the Expr type so that we associate some type information with every node. In another project, I play a recursion-schemes song that I considered reprising here, but I wondered if something simpler might be fun. The type I wound up using is as follows:

data Node t = Node t (Expr t)
data Expr t = Lam Text (Node t)
            | App (Node t) (Node t)
            | Var Text
            | Let [(Text, Node t)] (Node t)
            | Number Integer
            | Text Text

(plus a handful of derivings: Functor, Foldable, Traversable, Show and Eq. The Traversable one has the fun property of making the compiler hang if optimisations are turned on.)

Now, when I initially build a program I make a Node (), with no type information attached. We can allocate an empty unification variable to point to the type of each Node like this:

allocateTyVars :: forall t v m. BindingMonad t v m 
               => Node () -> m (Node v)
allocateTyVars = mapM $ \() -> freeVar

That’s step 1 finished!
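Because Node derives Traversable, step 1 really is just a traversal. Here is a pure stand-in for allocateTyVars that doesn't need a BindingMonad: it numbers every node with a fresh Int using mapAccumL from base. Assumptions: String instead of Text, Int standing in for unification variables, and numberNodes and allocated are hypothetical names.

```haskell
{-# LANGUAGE DeriveTraversable #-}
import Data.Traversable (mapAccumL)

data Node t = Node t (Expr t)
  deriving (Show, Eq, Functor, Foldable, Traversable)

data Expr t = Lam String (Node t)
            | App (Node t) (Node t)
            | Var String
            | Let [(String, Node t)] (Node t)
            | Number Integer
            | Text String
  deriving (Show, Eq, Functor, Foldable, Traversable)

-- Walk the tree in Traversable order, handing out 0, 1, 2, ... as "variables".
numberNodes :: Node () -> Node Int
numberNodes = snd . mapAccumL (\next () -> (next + 1, next)) 0

-- Collect the allocated variables in traversal order (handy for checking).
allocated :: Node Int -> [Int]
allocated = foldr (:) []
```

The real allocateTyVars does the same walk with mapM, but asks the binding monad for a freeVar at each node instead of bumping a counter.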

I’m afraid I wasn’t just lying about the Expr type. I lied about the Type type too. Unification-fd expects us to provide a functor describing the shape of our terms, which gives us a language to talk about what values our unification variables can hold. The following TyF works as that functor for unification-fd, and works out to be loosely equivalent to the Type I listed above, but without the TyVars (unification variables take their place) and without the ForallTys.

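{-# LANGUAGE DeriveTraversable #-}
import Control.Unification (UTerm (..), Unifiable (..))

-- TyF must be Traversable (and so Functor and Foldable) to be Unifiable.
data TyF f = LamTy f f
           | NumberTy
           | TextTy
           deriving (Show, Eq, Functor, Foldable, Traversable)

-- zipMatch checks that two layers have the same shape; each Right pair
-- is a pair of subterms that still has to be unified.
instance Unifiable TyF where
  zipMatch (LamTy a1 r1) (LamTy a2 r2) = Just (LamTy (Right (a1, a2)) (Right (r1, r2)))
  zipMatch NumberTy NumberTy = Just NumberTy
  zipMatch TextTy TextTy = Just TextTy
  zipMatch _ _ = Nothing

type Ty v = UTerm TyF v

lamTy :: Ty v -> Ty v -> Ty v
lamTy argTy bodyTy = UTerm $ LamTy argTy bodyTy

numTy :: Ty v
numTy = UTerm NumberTy

textTy :: Ty v
textTy = UTerm TextTy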

Gathering and Solving Constraints

My initial assumption was that I would use unification-fd’s functions to build up a list of constraints and variables, then pass that list of constraints to a solver and get a set of variable assignments back. This is not how it works! The library runs inside a binding monad (in effect a state monad) that keeps track of variables and their bindings, and when you use =:= it unifies the two terms on the spot, recording any new bindings in that state. Unifying two terms does return a new term, but you don’t need to keep that term around (although the documentation suggests that you can get better performance by using it in place of either of the now-unified terms).
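The bookkeeping described above can be imitated by hand to see what =:= is doing. This is a hand-rolled sketch, not unification-fd's implementation or API: bindings live in an explicit Map that stands in for the binding monad's state, unify either extends that Map or fails, and it returns only the bindings rather than a unified term. Ty, Bindings, walk, occurs and unify are hypothetical names.

```haskell
import qualified Data.Map as Map

data Ty = TVar Int
        | LamTy Ty Ty
        | NumberTy
        | TextTy
        deriving (Show, Eq)

-- Stand-in for the binding monad's state: variable -> what it is bound to.
type Bindings = Map.Map Int Ty

-- Follow a chain of variable bindings until we reach a non-variable term
-- or an unbound variable.
walk :: Bindings -> Ty -> Ty
walk bs (TVar v) | Just t <- Map.lookup v bs = walk bs t
walk _ t = t

-- Occurs check: does variable v appear anywhere in t (through bindings)?
occurs :: Bindings -> Int -> Ty -> Bool
occurs bs v t = case walk bs t of
  TVar w -> v == w
  LamTy a r -> occurs bs v a || occurs bs v r
  _ -> False

-- Unify two terms immediately, returning the extended bindings.
unify :: Bindings -> Ty -> Ty -> Either String Bindings
unify bs t1 t2 = case (walk bs t1, walk bs t2) of
  (TVar a, TVar b) | a == b -> Right bs
  (TVar a, t) -> bindVar a t
  (t, TVar b) -> bindVar b t
  (LamTy a1 r1, LamTy a2 r2) -> do
    bs' <- unify bs a1 a2
    unify bs' r1 r2
  (NumberTy, NumberTy) -> Right bs
  (TextTy, TextTy) -> Right bs
  (x, y) -> Left ("cannot unify " ++ show x ++ " with " ++ show y)
  where
    bindVar v ty
      | occurs bs v ty = Left ("infinite type: " ++ show v ++ " in " ++ show ty)
      | otherwise = Right (Map.insert v ty bs)
```

For example, unifying LamTy NumberTy (TVar 1) with the identity's type LamTy (TVar 0) (TVar 0) binds 0 to NumberTy and then 1 to NumberTy, so the invented variable 1 ends up holding the result type. That is the same shape as the App case later on.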

We’re going to walk through our Node/Expr tree, look at every expression, constrain things appropriately, and bind the resulting type to the logic variable in the node.

constrain :: (BindingMonad t v m, Fallible t v e, TinyFallible v e, MonadTrans em, Functor (em m), MonadError e (em m), t ~ TyF, Show v)
          => Map Text (NeedsFreshening, Ty v) -> Node v -> em m (Ty v)
constrain tyEnv (Node tyVar expr) = do ty' <- go expr
                                       lift $ bindVar tyVar ty'
                                       pure ty'
  where go (Lam argName lamBody) = do argTy <- UVar <$> lift freeVar
                                      bodyTy <- constrain (Map.insert argName (SoFreshAlready, argTy) tyEnv) lamBody
                                      pure (UTerm $ LamTy argTy bodyTy)

When we’re looking at a Lam, we don’t know much about the argument type. We can learn about it both from the body of the function and from the way the function is called. Since we don’t know much about it right now, we just generate a fresh variable for it and stick that variable in the environment while we figure out lamBody. That means the process of gathering constraints for lamBody both gives us a type to be the Lam’s bodyTy and discovers the details we need for that argTy.

        go (App funExp argExp) = do argTy <- constrain tyEnv argExp
                                    funBodyTy <- UVar <$> lift freeVar
                                    funExpTy <- constrain tyEnv funExp
                                    unify (UTerm $ LamTy argTy funBodyTy) funExpTy
                                    pure funBodyTy

This was by far the most interesting stanza to figure out. The key to understanding what’s going on here is that we know the type of funExp is going to be a LamTy. If we set that LamTy’s argument type to be the type of argExp, then the type of that LamTy’s body will be the type of the App expression as a whole.

My first impulse was to get the type of funExp, pattern match on it to pull the actual argTy and funBodyTy from its LamTy, and go from there. That doesn’t work because when this code runs, funExp’s type may still be an unresolved unification variable, so there is no LamTy there to pattern match on. Instead, we concoct a new LamTy with an argTy that matches this argExp’s type and a new variable to be both that LamTy’s return type and the return type for this App expression as a whole. We gather constraints for funExp, and unify the type we get for funExp with this made-up LamTy. It seems to work!

        go (Let bindings bodyExp) = do bindingTys <- traverse (\(name, exp) -> ((name,) . (NeedsFreshening,)) <$> constrain tyEnv exp) bindings
                                       constrain (Map.union (Map.fromList bindingTys) tyEnv) bodyExp

This might be a good time to talk about why we have these Let expressions. It’s mostly a way to support parametric polymorphism.

I want my program to be able to have terms whose type can vary, and for multiple uses of such a term to be able to specialise it to different types. For example, let’s say we have the identity function Lam "x" (Var "x"). We know that this function returns a value of the same type that is passed in to it, yielding a type something like forall a. a -> a in a Haskell-ish syntax, or ForallTy "a" (LamTy (TyVar "a") (TyVar "a")) in our ADT above. If we asked unification-fd to infer the type for the identity function in isolation, we would expect it to give us a type LamTy v v, where v is some logic variable. One part of our program might use it as a LamTy NumberTy NumberTy, and another as a LamTy TextTy TextTy. In other words, we want our program to be able to define the function once, use it multiple times, and have it adopt different values for the type variables each time.

If we don’t do anything to make this work, what happens? When one part of the program uses this identity function, we unify the function’s argument and return types with particular concrete types. Then, when the other part of the program tries to use it with a different type, it won’t be able to unify it. It’s too busy being a Text to hang out with the Numbers!

My first attempt at a solution was to insert a freshen in my App constrainer so that I unified my made-up term with a freshened funExpTy. The call to freshen replaces every logic variable in the term with a new, fresh logic variable, with the new fresh variables sharing the same relationships between them as the existing variables. This solved the problem, but the cure was worse than the disease. One of the ways we get information about the type of a LamTy is from what happens to the Lam when we apply an argument to it and use the result. When we freshen funExpTy we break the connection with the original LamTy. Instead of gaining information about the original LamTy, that new information is associated with the freshened copy and never makes it back.

My solution to this (which I am not at all sure is correct, but it works for at least one simple case) was inspired by Stephen Diehl’s Write You A Haskell, and its choice to treat the variable binding in a Let quite differently from that in a Lam. Instead of defining an environment of types, he defines one of type schemes. When a Lam inserts a type scheme into the environment, it does nothing at all with type variables and so forces all of the users of the type to agree on exactly what it is. When a Let inserts a type scheme into the environment, it generalises the type first, so that anyone pulling the type scheme out of the environment gets to decide anew what its type variables stand for.
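GHC's own type inference draws the same line between Let and Lam, which makes for a quick sanity check of the idea. In this sketch (pairViaLet is a hypothetical name), the let-bound identity is generalised, so its two uses can pick different types; the commented-out lambda version is rejected.

```haskell
-- GHC generalises the let-bound ident, so each use instantiates it afresh.
pairViaLet :: (Integer, String)
pairViaLet = let ident = \x -> x in (ident 1, ident "one")

-- The lambda-bound equivalent is rejected, because both uses of f must
-- agree on one type, which would have to be both a number type and String:
--
--   pairViaLam = (\f -> (f 1, f "one")) (\x -> x)
```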

Instead of a formally coherent notion of type schemes, my riff on this just has an environment of tuples. When a Let inserts a type into the environment, it inserts a tag of NeedsFreshening alongside it. When a Lam inserts a type into the environment, the tag SoFreshAlready is used, because this is a very serious and highly professional project.

        go (Var text)
         | Just (NeedsFreshening, varTy) <- Map.lookup text tyEnv = freshen varTy
         | Just (SoFreshAlready, varTy)  <- Map.lookup text tyEnv = pure varTy
         | Nothing                       <- Map.lookup text tyEnv = throwError $ undefinedVar text tyVar

When we look up the type of an expression-level variable, we inspect the tag alongside the type. If it needs freshening – which means it came from a let binding – then we freshen it before returning it. That resolves the issues about polymorphism I was talking about before.

        go (Number _) = pure numTy
        go (Text _) = pure textTy

Once we make numeric literals be numbers and textual literals be text, we’re all done for constraint gathering!
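Here is a self-contained sketch of freshening and the tagged lookup, again hand-rolled rather than unification-fd's own freshen. It assumes Int variables supplied from a counter, and that the type being freshened has already had its bindings applied (the library's freshen also follows bindings). The tag type mirrors the names used above; Ty, Env, freshen and lookupTy are hypothetical.

```haskell
import qualified Data.Map as Map

data Ty = TVar Int | LamTy Ty Ty | NumberTy | TextTy
  deriving (Show, Eq)

-- The environment tag: let-bound types need freshening, lambda-bound don't.
data NeedsFreshening = NeedsFreshening | SoFreshAlready
  deriving (Show, Eq)

type Env = Map.Map String (NeedsFreshening, Ty)

-- Replace every variable with a fresh one, consistently: each old variable
-- maps to exactly one new one, so LamTy a a becomes LamTy b b, never
-- LamTy b c. Returns the new term and the next unused variable.
freshen :: Int -> Ty -> (Ty, Int)
freshen start ty = (ty', next)
  where
    (ty', (next, _)) = go ty (start, Map.empty)
    go (TVar v) (n, seen) = case Map.lookup v seen of
      Just v' -> (TVar v', (n, seen))
      Nothing -> (TVar n, (n + 1, Map.insert v n seen))
    go (LamTy a r) st0 =
      let (a', st1) = go a st0
          (r', st2) = go r st1
      in (LamTy a' r', st2)
    go t st = (t, st)

-- Look up an expression variable's type. Let-bound types are freshened on
-- every use; lambda-bound types are shared as-is.
lookupTy :: Env -> String -> Int -> Either String (Ty, Int)
lookupTy env name next = case Map.lookup name env of
  Just (NeedsFreshening, ty) -> Right (freshen next ty)
  Just (SoFreshAlready, ty) -> Right (ty, next)
  Nothing -> Left ("undefined variable " ++ name)
```

Sharing the lambda-bound type unchanged is what lets information flow back to the Lam, while copying the let-bound one lets each use specialise it independently.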

What’s next?

Steps three and four are really easy! You just call applyBindings on the top-level type to make unification-fd do its thing, and that’s step 3 out of the way. I wrote a recursive variable lookup thing to handle step 4, but it turns out fullprune/semiprune do the same thing for you. Sweet!
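For intuition about what steps 3 and 4 produce, here is a hand-rolled substitution pass over a plain Map of bindings, standing in for applyBindings; it is not the library's code. It follows chains of bindings and leaves unbound variables in place, so the result may still contain variables. It assumes the bindings are acyclic, which an occurs check during unification guarantees; applyAll is a hypothetical name.

```haskell
import qualified Data.Map as Map

data Ty = TVar Int | LamTy Ty Ty | NumberTy | TextTy
  deriving (Show, Eq)

-- Substitute bound variables everywhere, following chains such as
-- 2 -> 1 -> NumberTy. Unbound variables are left as they are.
applyAll :: Map.Map Int Ty -> Ty -> Ty
applyAll bs (TVar v) = maybe (TVar v) (applyAll bs) (Map.lookup v bs)
applyAll bs (LamTy a r) = LamTy (applyAll bs a) (applyAll bs r)
applyAll _ t = t
```

With bindings 3 ↦ LamTy (TVar 2) (TVar 4), 2 ↦ TVar 1 and 1 ↦ NumberTy, applying them to TVar 3 gives LamTy NumberTy (TVar 4): variable 4 was never constrained, so it stays a variable for step 5 to deal with.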

I thought step 5 was easy, but my thing to handle it was wrong and there’s some other fun stuff to talk about at the same time. So I’m going to leave it as an exercise for the reader, and maybe I’ll write about it later! (edited to add: I wrote about it later! It’s called From Logic Variables to Type Variables)

So… apart from that, what’s next?

This take on let bindings doesn’t support inferring types of recursive functions. I think there are at least three options: eliminate recursion with a fixed-point combinator, allocate a type variable to represent the type of this let-bound term and stick it in the environment before you start gathering constraints for the term, or allocate type variables for all of the let-bound terms and stick them all in the environment when gathering constraints for all the terms. I’ve tried the second option, and it seems to work. The third option should let you have mutually recursive functions.
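The first option can be sketched in Haskell itself: write the recursive function as a non-recursive one that takes "itself" as a parameter, and tie the knot with fix from Data.Function. For the toy language, this assumes a fix with a type like forall a. (a -> a) -> a is injected via the ad-hoc library of functions mentioned earlier; factorialStep and factorial are hypothetical names.

```haskell
import Data.Function (fix)

-- No recursion here: 'self' is just an ordinary parameter.
factorialStep :: (Integer -> Integer) -> Integer -> Integer
factorialStep self n = if n <= 1 then 1 else n * self (n - 1)

-- fix :: (a -> a) -> a, used here at a = Integer -> Integer.
factorial :: Integer -> Integer
factorial = fix factorialStep
```

The inferencer never sees a recursive binding here: the only recursion lives in fix, so a non-recursive Let is enough, provided fix's type is available in the environment.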

Using let-bindings as a point of introducing and eliminating type variables is a pleasingly simple way to get out of the polymorphism pickle. It feels kind of arbitrary to me, though. It might be more theoretically elegant if we didn’t unify all of the types involved when we apply an argument to a function, and instead demanded that the value we’re applying be a subtype of the expected argument type. I believe this would push me to figure out what the deal is with unification-fd’s subsumes function. I’d probably learn a bunch about covariance and contravariance in the process.

Speaking of beliefs… I believe my implementation works, but it’s a pretty weak belief. I should probably test it more thoroughly. I’d also be interested in trying to get a deeper understanding of the practical difference between my NeedsFreshening hack and the type scheme business. I kind of want to implement both my approach and the one Diehl uses in Coq and try to prove their equivalence.

Useful programming languages normally offer some sort of compound type and some way to build things up and break them down. Maybe I should look into that. But for now, I’m pretty happy to have written my first type inferencer!
