# A Type Inference Implementation Adventure

I stumbled across a blog post entitled “Generic Unification” a while ago. It makes this claim:

> The unification-fd package by wren gayle romano is the de-facto standard way to do unification in Haskell. You’d use it if you need to implement type inference for your DSL, for example.

This was incredibly relevant to my interests! At the time I had been thinking about how to do some sort of minimal type inference for A Thing, and discovering that it was harder than I thought it was. Running into a library to which I could offload half of the work was pretty sweet.

I wrote code that used the library to do type inference for The Thing, and it kinda sorta worked a bit. It wasn’t a great success, though, and I kept discovering things that I didn’t really understand. Continually running into things I don’t understand made it seem like a good time to step back and do a learning exercise, which I christened BabyTC. It’s a typechecker for a simple language like you might do as an exercise in a PL implementation course. This post is about the problems I ran into while I was doing this, and how I got through them. If you know a bit about Haskell and a bit about how type inference works and are interested in some of the details of making it go, then this post might just be for you!

## What Language Are We Checking?

The language is a super-simple lambda calculus sort of thing. A program is a single expression. In Haskell syntax, the language looks like this:

```haskell
data Expr = Lam Text Expr
          | App Expr Expr
          | Var Text
          | Let [(Text, Expr)] Expr
          | Number Integer
          | Text Text
```

Informally, the meaning of this language goes something like this:

• The expression can be a `Lam`, which has a name and an expression. When an argument is applied to the `Lam`, the expression will be evaluated with the argument bound to the name.
• The expression can be an `App`, in which case the `App`’s second argument will be applied to the first. The first is going to have to evaluate to a `Lam` for this to work!
• If the expression is a `Var`, then it evaluates to whatever the `Var`’s name is bound to.
• When the expression is a `Let`, then the first argument is a list of names and expressions. The result of the expression is the second argument, evaluated with every expression in the first argument bound to the corresponding name. This was a late addition. We’ll come back to this soon.
• If the expression is a `Number` or a `Text`, then it just evaluates to the value in the expression.

It’s common in things like this to have a `Case` statement of some sort. I’ve deliberately forgone that for now – I’m happy to inject an ad-hoc library of functions that will constrain things to be `Number`s or `Text`s and provide ways to manipulate them and move between them. It’s also usual to provide a more formal explanation of the semantics, but I don’t really trust myself to get that right. Let me know if you’d like a pointer to a similar language with a more formal specification!
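Short of a formal semantics, a tiny evaluator can pin down the informal description above. This is a sketch, not part of BabyTC, and it makes a few assumptions of its own: evaluation is call-by-value, `Let` is non-recursive (each binding is evaluated in the outer environment), a `Lam` evaluates to a closure over the environment it was defined in, and names are `String`s rather than `Text` so that it needs nothing beyond `containers`. `Value`, `Closure` and `eval` are made-up names.

```haskell
import qualified Data.Map as Map
import Data.Map (Map)

-- Same shape as the Expr above, with String standing in for Text.
data Expr = Lam String Expr
          | App Expr Expr
          | Var String
          | Let [(String, Expr)] Expr
          | Number Integer
          | Text String
  deriving (Eq, Show)

-- Runtime values: the two literal types, or a Lam closed over its environment.
data Value = NumberV Integer
           | TextV String
           | Closure (Map String Value) String Expr
  deriving (Eq, Show)

type Env = Map String Value

eval :: Env -> Expr -> Either String Value
eval _ (Number n) = Right (NumberV n)
eval _ (Text t) = Right (TextV t)
eval env (Var x) = maybe (Left ("unbound variable " ++ x)) Right (Map.lookup x env)
eval env (Lam x body) = Right (Closure env x body)
eval env (App f a) = do
    fv <- eval env f
    av <- eval env a  -- call-by-value: the argument is evaluated before the body runs
    case fv of
        Closure cenv x body -> eval (Map.insert x av cenv) body
        _ -> Left "tried to apply something that is not a Lam"
eval env (Let binds body) = do
    -- Non-recursive: every binding is evaluated in the outer environment.
    vals <- traverse (\(x, e) -> (,) x <$> eval env e) binds
    -- Map.union is left-biased, so let-bound names shadow outer ones.
    eval (Map.union (Map.fromList vals) env) body
```

For example, evaluating `App (Lam "x" (Var "x")) (Number 3)` in an empty environment gives `Right (NumberV 3)`.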

## What does unification-fd give us?

If unification-fd is the tool we’re going to use to try to crack this nut, we’d better look at what exactly it gives us. We have:

• A type, `UTerm`, representing either a variable or a thing in some language we define
• A typeclass `Unifiable` to let us say whether two things in the language we defined can be unified
• Ways to create, bind and look up variables (the functions of the `BindingMonad` typeclass)
• Ways to query the relationship between terms (`===`, `=~=`)
• Ways to assert that terms have particular relationships (`=:=`, `<:=`)
• A Big Red Button (`applyBinding`,`applyBindings`)

The post that kicked off this whole caper, Roman Cheplyaka’s Generic Unification, and wren gayle romano’s tutorial on unification-fd cover how all those bits fit together, so I won’t talk about that here. How can we use these bits to do type inference?

## What does a type look like?

I started with something like this classic:

```haskell
data Type = LamTy Type Type
          | NumberTy
          | TextTy
          | TyVar Text
          | ForallTy Text Type
```

A `LamTy` has an argument type and a result type. `NumberTy` and `TextTy` are our primitive types. We don’t need an `AppTy`, because an `App` expression’s type is just the result type of the `LamTy` belonging to its first argument.

So what’s the deal with this `TyVar` and `ForallTy` business? If you’re writing a map function, it might have a type like `(a -> b) -> [a] -> [b]`. The `a` and `b` in that type are type variables, but where do they come from? It matters because the point where the variables are introduced determines how much flexibility the consumer of the map function has in picking types to send it. `Forall`s are a way of specifying where type variables come from. If we explicitly put some `forall`s into that map function’s type, we get `forall a b. ((a -> b) -> [a] -> [b])`. What that means is that you can pick any `a` and `b` you want, give it a function that takes an `a` to a `b`, and get a function that takes a list of the same `a`s to a list of the same `b`s. But once you commit to an `a` and a `b`, you can’t pick another `b` halfway through.

On the other hand, I can write down the type `forall a. ((a -> forall b. b) -> [a] -> [forall b. b])`, but it’s not going to end well. The caller of this function has to pass in something that can return a value of any type, and the function will presumably only use it to produce values of the one type the caller expects in the resulting list. We’ve demanded more flexibility than we need, which makes problems for the consumer of the map function. If we commit to `b` early, we know we’re talking about the same `b` throughout the whole type signature, and we don’t get those problems.
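GHC’s `RankNTypes` extension lets us write both placements of the `forall` and see the difference from the caller’s side. In this sketch (the names are made up), `myMap` quantifies `a` and `b` on the outside, while `weirdMap` demands an argument that returns a value of every type. The type in the prose ends in `[forall b. b]`, which GHC would only accept with `ImpredicativeTypes`, so this sketch returns an ordinary `[c]` instead.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Rank-1: the caller picks a and b once, up front, for the whole call.
myMap :: forall a b. (a -> b) -> [a] -> [b]
myMap _ [] = []
myMap f (x : xs) = f x : myMap f xs

-- The forall b now sits inside the argument's type, so the caller must
-- supply a function that returns a value of *every* type b. Only functions
-- that never return normally (error, undefined, infinite loops) fit.
weirdMap :: forall a c. (a -> forall b. b) -> [a] -> [c]
weirdMap f xs = map (\x -> f x) xs
```

`myMap (+ 1) [1, 2, 3]` is the `[2, 3, 4]` you’d expect, but the only honest thing to hand `weirdMap` is something like `\_ -> error "never returns"`, because no total function produces a value of every type.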

Our goal is to have a type like this associated with every expression in our program. Depending on exactly what your goal is, you might want to then throw away all except the top-level type, and that’s ok! Our general plan of attack is:

1. Create a type variable for each expression
2. Gather information from up and down the expression tree to constrain those type variables and feed it to unification-fd
3. Ask unification-fd to do its thing and find values for all the type variables
4. Dig through the variables until we have the actual type for each expression (which may still be a variable!)
5. Remove all the unification variables, replacing them with `TyVars` and introducing `Foralls` as appropriate to get into roughly the form described above.

## Rearranging Some Preliminaries

You know that `Expr` type I gave you before?

I lied.

We want to associate a type with not just the program as a whole, but with every expression in the program. That means rejigging the `Expr` type so that we associate some type information with every node. In another project, I play a recursion-schemes song that I considered reprising here, but I wondered if something simpler might be fun. The type I wound up using is as follows:

```haskell
data Node t = Node t (Expr t)
data Expr t = Lam Text (Node t)
            | App (Node t) (Node t)
            | Var Text
            | Let [(Text, Node t)] (Node t)
            | Number Integer
            | Text Text
```

(plus a handful of derivings: Functor, Foldable, Traversable, Show, Eq. The Traversable one has the fun property of making the compiler hang if optimisations are turned on).

Now, when I first build a program, it’s a `Node ()`, with nothing interesting attached to each node. We can allocate an empty unification variable to point to the type of each `Node` like this:

```haskell
allocateTyVars :: forall t v m. BindingMonad t v m
               => Node () -> m (Node v)
allocateTyVars = mapM $ \() -> freeVar
```

That’s step 1 finished!
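The reason a single `mapM` can allocate one variable per node is that the derived `Traversable` instance visits every `t` annotation in the tree. Here is a pure sketch of that traversal, with a counter threaded by `mapAccumL` standing in for `freeVar` so that it runs without unification-fd. It uses `String` for names, and `numberNodes` is a hypothetical name.

```haskell
{-# LANGUAGE DeriveTraversable #-}
import Data.Traversable (mapAccumL)

-- The annotated tree from above, with String standing in for Text.
data Node t = Node t (Expr t)
  deriving (Show, Eq, Functor, Foldable, Traversable)

data Expr t = Lam String (Node t)
            | App (Node t) (Node t)
            | Var String
            | Let [(String, Node t)] (Node t)
            | Number Integer
            | Text String
  deriving (Show, Eq, Functor, Foldable, Traversable)

-- Hypothetical stand-in for allocateTyVars: replace each () with a distinct
-- Int, in the order the derived Traversable instance visits the nodes.
numberNodes :: Node () -> Node Int
numberNodes = snd . mapAccumL (\next () -> (next + 1, next)) 0
```

The derived instance visits a node’s own annotation before its children, and children left to right, so `App (Lam "x" (Var "x")) (Number 1)` gets its four nodes numbered 0, 1, 2 and 3 in that order.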

I’m afraid I wasn’t just lying about the `Expr` type. I lied about the `Type` type too. Unification-fd expects us to provide an algebra that gives us a language to talk about what values our unification variables can hold. The following gives a `TyF` that works as an algebra for unification-fd, and works out to be loosely equivalent to the `Type` I listed above, but without the `ForallTy`s, and with unification variables standing in for the `TyVar`s.

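```haskell
{-# LANGUAGE DeriveTraversable #-}
import Control.Unification (UTerm (..), Unifiable (..))

data TyF f = LamTy f f
           | NumberTy
           | TextTy
  deriving (Show, Functor, Foldable, Traversable)

-- zipMatch says which shapes match, and which pairs of children must unify.
instance Unifiable TyF where
  zipMatch (LamTy a1 r1) (LamTy a2 r2) = Just (LamTy (Right (a1, a2)) (Right (r1, r2)))
  zipMatch NumberTy NumberTy = Just NumberTy
  zipMatch TextTy TextTy = Just TextTy
  zipMatch _ _ = Nothing

type Ty v = UTerm TyF v

lamTy :: Ty v -> Ty v -> Ty v
lamTy argTy bodyTy = UTerm $ LamTy argTy bodyTy

numTy :: Ty v
numTy = UTerm NumberTy

textTy :: Ty v
textTy = UTerm TextTy
```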

## Gathering and Solving Constraints

My initial assumption was that I would use unification-fd’s functions to build up a list of constraints and variables, then pass that list of constraints to a solver and get a set of variable assignments back. This is not how it works! The library keeps track of variables and their bindings inside a monad (the `BindingMonad`), and when you use `=:=` it unifies the two terms straight away, recording any new variable bindings or throwing an error if the terms can’t be unified. Unifying two terms does return a new term, but you don’t need to keep that term around (although the documentation suggests that you can get better performance by using it in place of either of the now-unified terms).

We’re going to walk through our `Node`/`Expr` tree, look at every expression, constrain things appropriately, and bind the resulting type to the logic variable in the node.

```haskell
constrain :: ( BindingMonad t v m, Fallible t v e, TinyFallible v e
             , MonadTrans em, Functor (em m), MonadError e (em m)
             , t ~ TyF, Show v )
          => Map Text (NeedsFreshening, Ty v) -> Node v -> em m (Ty v)
constrain tyEnv (Node tyVar expr) = do
    ty' <- go expr
    lift $ bindVar tyVar ty'
    pure ty'
  where go (Lam argName lamBody) = do
            argTy <- UVar <$> lift freeVar
            bodyTy <- constrain (Map.insert argName (SoFreshAlready, argTy) tyEnv) lamBody
            pure (UTerm $ LamTy argTy bodyTy)
```

When we’re looking at a `Lam`, we don’t know much about the argument type. We can learn about it both from the body of the function and from the way the function is called. Since we don’t know much about it right now, we just generate a fresh variable for it and stick that variable in the environment while we figure out `lamBody`. That means the process of gathering constraints for `lamBody` both gives us a type to be the `Lam`’s `bodyTy` and discovers the details we need for that `argTy`.

```haskell
        go (App funExp argExp) = do
            argTy <- constrain tyEnv argExp
            funBodyTy <- UVar <$> lift freeVar
            funExpTy <- constrain tyEnv funExp
            unify (UTerm $ LamTy argTy funBodyTy) funExpTy
            pure funBodyTy
```

This was by far the most interesting stanza to figure out. The key to understanding what’s going on here is that we know the type of `funExp` is going to be a `LamTy`. If we set that `LamTy`’s argument type to be the type of `argExp`, then the type of that `LamTy`’s body will be the type of the `App` expression as a whole.

My first impulse was to get the type of `funExp`, pattern match on it to pull the actual `argTy` and `funBodyTy` from its `LamTy`, and go from there. That doesn’t work because when this code runs `funExp` doesn’t have a type for us to get. Instead, we concoct a new `LamTy` with an `argTy` that matches this `argExp`’s type and a new variable to be both that `LamTy`’s return type and the return type for this `App` expression as a whole. We gather constraints for `funExp`, and unify the type we get for `funExp` with this made-up `LamTy`. It seems to work!

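```haskell
        go (Let bindings bodyExp) = do
            -- (name,) and (NeedsFreshening,) need the TupleSections extension
            bindingTys <- mapM (\(name, exp) -> ((name,) . (NeedsFreshening,)) <$> constrain tyEnv exp) bindings
            constrain (Map.union (Map.fromList bindingTys) tyEnv) bodyExp
```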

This might be a good time to talk about why we have these `Let` expressions. It’s mostly a way to support parametric polymorphism.

I want my program to be able to have terms whose type can vary, and for multiple uses of such a term to be able to specialise it to different types. For example, let’s say we have the identity function `Lam "x" (Var "x")`. We know that this function returns a value of the same type that is passed in to it, yielding a type something like `forall a. a -> a` in a Haskellish syntax or `ForallTy "a" (LamTy (TyVar "a") (TyVar "a"))` in our ADT above. If we asked unification-fd to infer the type for the identity function in isolation, we’d expect it to give us a type `LamTy v v`, where `v` is some logic variable. One part of our program might use it as a `LamTy NumberTy NumberTy`, and another as a `LamTy TextTy TextTy`. In other words, we want our program to be able to define the function once, use it multiple times, and have it adopt different values for the type variables each time.

If we don’t do anything to make this work, what happens? When one part of the program uses this identity function, we unify the function’s argument and return types with particular concrete types. Then, when the other part of the program tries to use it with a different type, it won’t be able to unify it. It’s too busy being a `Text` to hang out with the `Number`s!
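Haskell itself is a handy reference point for the behaviour we’re after. In this small sketch (the names are made up), a let-bound identity function is used at two different types, which GHC accepts because it generalises the `let` binding. The commented-out version passes the same function in through a lambda instead, and GHC rejects it for the reason just described.

```haskell
-- GHC generalises the let-bound ident to forall a. a -> a,
-- so each use below picks its own a.
letBound :: (Bool, String)
letBound = let ident = \x -> x in (ident True, ident "three")

-- A lambda-bound parameter gets one monomorphic type that every use must
-- share, so uncommenting this is a type error (Bool vs String):
-- lambdaBound = (\ident -> (ident True, ident "three")) (\x -> x)
```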

My first attempt at a solution was to insert a call to `freshen` in my `App` constrainer, so that I unified my made-up term with a `freshen`ed `funExpTy`. The call to `freshen` replaces every logic variable in the term with a new, fresh logic variable, with the new fresh variables sharing the same relationships between them as the existing variables. This solved the problem, but the cure was worse than the disease. One of the ways we get information about the type of a `LamTy` is from what happens to the `Lam` when we apply an argument to it and use the result. When we freshen `funExpTy` we break the connection with the original `LamTy`. Instead of gaining information about the original `LamTy`, that new information is associated with the freshened copy and never makes it back.

My solution to this – which I am not at all sure is correct, but it works for at least one simple case – was inspired by Stephen Diehl’s Write You A Haskell’s choice to treat the variable assignment in a `Let` quite differently to that in a `Lam`. Instead of defining an environment of types, he defines one of type schemes. When a `Lam` inserts a type scheme into the environment, it does nothing at all with type variables and so forces all of the users of the type to agree on exactly what it is. When a `Let` inserts a type scheme into the environment, it does so in such a way that anyone pulling the type scheme from the environment gets to decide anew where the type variables land.

Instead of a formally-coherent notion of type schemes, my riff on this just has an environment of tuples. When `Let` inserts a type into the environment, it inserts a tag of `NeedsFreshening` alongside it. When a `Lam` inserts a type into the environment, the tag `SoFreshAlready` is used, because this is a very serious and highly professional project.

```haskell
        go (Var text)
            | Just (NeedsFreshening, varTy) <- Map.lookup text tyEnv = freshen varTy
            | Just (SoFreshAlready, varTy)  <- Map.lookup text tyEnv = pure varTy
            | Nothing                       <- Map.lookup text tyEnv = throwError $ undefinedVar text tyVar
```

When we look up the type of an expression-level variable, we inspect the tag alongside the type. If it needs freshening – which means it came from a let binding – then we freshen it before returning it. That resolves the issues about polymorphism I was talking about before.

```haskell
        go (Number _) = pure numTy
        go (Text _) = pure textTy
```

Once we make numeric literals be numbers and textual literals be text, we’re all done for constraint gathering!
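To see the whole walk in one place, here is a self-contained sketch of the same constraint-gathering scheme with a hand-rolled unifier standing in for unification-fd. It is neither the BabyTC code nor the library’s implementation: a `Map` of bindings threaded by hand plays the role of the binding monad, `zonk` follows bindings the way applying bindings does, and `freshen` copies a type with new variables. It drops the `Node` annotations and returns only the type of the whole expression, uses `String` instead of `Text`, and every name in it is made up for illustration.

```haskell
import qualified Data.Map as Map
import Data.Map (Map)

data Expr = Lam String Expr | App Expr Expr | Var String
          | Let [(String, Expr)] Expr | Number Integer | Text String

-- Types, with unification variables represented as plain Ints.
data Ty = TVar Int | LamTy Ty Ty | NumberTy | TextTy deriving (Eq, Show)

data Freshness = NeedsFreshening | SoFreshAlready

type Env = Map String (Freshness, Ty)
type Subst = Map Int Ty   -- variable bindings found so far
type St = (Int, Subst)    -- next unused variable, plus the bindings

-- Follow bindings all the way down.
zonk :: Subst -> Ty -> Ty
zonk s (TVar v) = maybe (TVar v) (zonk s) (Map.lookup v s)
zonk s (LamTy a b) = LamTy (zonk s a) (zonk s b)
zonk _ t = t

fresh :: St -> (Ty, St)
fresh (n, s) = (TVar n, (n + 1, s))

occurs :: Int -> Ty -> Bool
occurs v (TVar w) = v == w
occurs v (LamTy a b) = occurs v a || occurs v b
occurs _ _ = False

-- Unify two types, extending the bindings or failing with a message.
unify :: Ty -> Ty -> St -> Either String St
unify t1 t2 st@(n, s) = case (zonk s t1, zonk s t2) of
    (TVar a, TVar b) | a == b -> Right st
    (TVar a, t) -> bind a t
    (t, TVar b) -> bind b t
    (LamTy a1 r1, LamTy a2 r2) -> unify a1 a2 st >>= unify r1 r2
    (NumberTy, NumberTy) -> Right st
    (TextTy, TextTy) -> Right st
    (a, b) -> Left ("cannot unify " ++ show a ++ " with " ++ show b)
  where
    bind v t | occurs v t = Left "occurs check failed"
             | otherwise = Right (n, Map.insert v t s)

-- Copy a type, consistently replacing each unbound variable with a new one.
freshen :: Ty -> St -> (Ty, St)
freshen ty st0 = (ty', st1)
  where
    (ty', (_, st1)) = go (zonk (snd st0) ty) (Map.empty, st0)
    go :: Ty -> (Map Int Ty, St) -> (Ty, (Map Int Ty, St))
    go (TVar v) (renaming, st) = case Map.lookup v renaming of
        Just v' -> (v', (renaming, st))
        Nothing -> let (v', st') = fresh st in (v', (Map.insert v v' renaming, st'))
    go (LamTy a b) acc = let (a', acc1) = go a acc
                             (b', acc2) = go b acc1
                         in (LamTy a' b', acc2)
    go other acc = (other, acc)

infer :: Env -> Expr -> St -> Either String (Ty, St)
infer _ (Number _) st = Right (NumberTy, st)
infer _ (Text _) st = Right (TextTy, st)
infer env (Var x) st = case Map.lookup x env of
    Just (NeedsFreshening, t) -> Right (freshen t st)  -- let-bound: fresh copy
    Just (SoFreshAlready, t) -> Right (t, st)           -- lambda-bound: shared
    Nothing -> Left ("undefined variable " ++ x)
infer env (Lam x body) st0 = do
    let (argTy, st1) = fresh st0
    (bodyTy, st2) <- infer (Map.insert x (SoFreshAlready, argTy) env) body st1
    Right (LamTy argTy bodyTy, st2)
infer env (App f a) st0 = do
    (argTy, st1) <- infer env a st0
    let (resTy, st2) = fresh st1
    (funTy, st3) <- infer env f st2
    -- The function must be a LamTy from argTy to our fresh result variable.
    st4 <- unify (LamTy argTy resTy) funTy st3
    Right (resTy, st4)
infer env (Let binds body) st0 = do
    (tys, st1) <- inferAll binds st0
    infer (Map.union (Map.fromList tys) env) body st1
  where
    inferAll [] st = Right ([], st)
    inferAll ((x, e) : rest) st = do
        (t, st') <- infer env e st
        (ts, st'') <- inferAll rest st'
        Right ((x, (NeedsFreshening, t)) : ts, st'')

typeOf :: Expr -> Either String Ty
typeOf e = do
    (t, (_, s)) <- infer Map.empty e (0, Map.empty)
    Right (zonk s t)
```

With this sketch, `typeOf (Let [("id", Lam "x" (Var "x"))] (App (Var "id") (Number 1)))` is `Right NumberTy`, while passing the identity function in through a `Lam` and using it at both `Number` and `Text` fails to unify. One caveat that may bear on how the tag trick compares with type schemes: this `freshen` renames every unbound variable in a let-bound type, including any shared with an enclosing `Lam`’s argument, whereas Hindley-Milner generalisation only quantifies variables that are not free in the environment.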

## What’s next?

Steps three and four are really easy! You just call `applyBindings` on the top type to make unification-fd do its thing, and that’s step 3 out of the way. I wrote a recursive variable lookup thing to handle step 4, but it turns out `fullprune`/`semiprune` do the same thing for you. Sweet!

I thought step 5 was easy, but my thing to handle it was wrong and there’s some other fun stuff to talk about at the same time. So I’m going to leave it as an exercise for the reader, and maybe I’ll write about it later! (edited to add: I wrote about it later! It’s called From Logic Variables to Type Variables)

## So… apart from that, what’s next?

This take on let bindings doesn’t support inferring types of recursive functions. I think there are at least three options: eliminate recursion with a fixed-point combinator, allocate a type variable to represent the type of this let-bound term and stick it in the environment before you start gathering constraints for the term, or allocate type variables for all of the let-bound terms and stick them all in the environment when gathering constraints for all the terms. I’ve tried the second option, and it seems to work. The third option should let you have mutually recursive functions.
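For the first option, the idea is that recursion can come from a single polymorphic constant rather than from `Let` itself: put `fix` into the initial environment with the type `forall a. (a -> a) -> a` (tagged so that it gets freshened at each use), and write a recursive function as `fix` applied to a non-recursive `Lam`. Haskell’s `Data.Function.fix` shows the shape. `factorial` is only an illustration; this language has no built-in conditionals, so a real example would lean on the injected library.

```haskell
import Data.Function (fix)

-- fix :: (a -> a) -> a ties the knot, so the lambda below never refers to
-- factorial by name: 'self' is the recursive call, supplied by fix.
factorial :: Integer -> Integer
factorial = fix (\self n -> if n <= 0 then 1 else n * self (n - 1))
```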

Using let-bindings as a point of introducing and eliminating type variables is a pleasingly simple way to get out of the parametricity pickle. It feels kind of arbitrary to me, though. It might be more theoretically elegant if we didn’t unify all of the types involved when we apply an argument to a function, and instead we demanded that the value we’re applying be a subtype of the expected argument type. I believe this would push me to figure out what the deal is with unification-fd’s subsumes function. I’d probably learn a bunch about covariant and contravariant functions in the process.

Speaking of beliefs… I believe my implementation works, but it’s a pretty weak belief. I should probably test it more thoroughly. I’d also be interested in trying to get a deeper understanding of the practical difference between my NeedsFreshening hack and the Type Scheme business. I kind of want to implement both my approach and the one Diehl uses in Coq and try to prove their equivalence.

Useful programming languages normally offer some sort of compound type and some way to build things up and break them down. Maybe I should look into that. But for now, I’m pretty happy to have written my first type inferencer!