
Extra - λY

Today, we'll continue with the lambda calculus and explore a fascinating and challenging problem: how to construct recursive functions. This matters because recursion can express iteration and all sorts of other important computational patterns. It might initially seem easy: just have functions call themselves (indeed, while recursion in Pyret may have been confusing, it wasn't hard to do). But when we stop to think about it, we run into a problem: lambdas have no names, and therefore no way of "calling themselves".

How, then, can we construct functions that call themselves? This is the puzzle we'll figure out today.

Let's recap the definitions we have. The names in ALLCAPS are actual lambda calculus terms; tobool, ofnum, and tonum are convenient helpers that convert to and from ordinary Pyret values, to make it easier to test our code in Pyret.

Note that we make one minor change from last time. Pyret evaluates a function's arguments before calling it, so if we used the implementations of booleans (and IF, AND, OR, and NOT) that we came up with last time, we wouldn't get the short-circuiting behavior we expect: IF would evaluate both the then and the else branches, and so on. This worked in the original lambda calculus because there was no defined order of evaluation, but it would break our concrete example today: the recursive branch of factorial would be evaluated even after reaching zero, so the program would never stop.

So we make a slight tweak: TRUE and FALSE now expect their arguments to be zero-argument functions, and each calls only the one it selects (note the parentheses after x and y in the definitions of TRUE and FALSE). As a result, each branch passed to IF is wrapped in lam(): ... end.

TRUE = lam(x,y): x() end
FALSE = lam(x,y): y() end
fun tobool(cb):
  cb(lam(): true end, lam(): false end)
end

IF = lam(c,t,e): c(t,e) end
AND = lam(b1, b2): b1(lam(): b2 end, lam(): FALSE end) end
OR = lam(b1, b2): b1(lam(): TRUE end, lam(): b2 end) end
NOT = lam(b): b(lam(): FALSE end, lam(): TRUE end) end

ZERO = lam(f, x): x end
ONE = lam(f, x): f(x) end
TWO = lam(f, x): f(f(x)) end
fun ofnum(n :: Number):
  lam(f, x):
    fun r(m):
      if m == 0:
        x
      else:
        f(r(m - 1))
      end
    end
    r(n)
  end
end
fun tonum(cn) -> Number:
  cn(lam(y): y + 1 end, 0)
end

PAIR = lam(a,b): lam(z): z(a,b) end end
FIRST = lam(p): p(lam(a,b): a end) end
SECOND = lam(p): p(lam(a,b): b end) end

ADD = lam(n1, n2): lam(f, x): n2(f, n1(f, x)) end end
MUL = lam(n1, n2): n1(lam(y): ADD(n2, y) end, ZERO) end
MINUS1 = lam(n): lam(f,x): FIRST(n(lam(y): PAIR(SECOND(y), f(SECOND(y))) end, PAIR(x,x))) end end

EQUAL0 = lam(n): n(lam(y): FALSE end, TRUE) end
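To see what the thunk-based encoding buys us, here is a small standalone sketch. It repeats the three definitions it needs, so run it in a fresh file (Pyret reports a second definition of the same name as an error); pick-then and pick-else are illustrative names. The branch that should not be taken calls Pyret's raise, so both values can only be computed if IF never runs that branch.

```pyret
# Standalone copies of the thunk-based booleans defined above.
TRUE = lam(x, y): x() end
FALSE = lam(x, y): y() end
IF = lam(c, t, e): c(t, e) end

# Each branch is a zero-argument function, and TRUE/FALSE call only the one
# they select, so the raise in the untaken branch never runs.
pick-then = IF(TRUE, lam(): "then" end, lam(): raise("else branch ran") end)
pick-else = IF(FALSE, lam(): raise("then branch ran") end, lam(): "else" end)

check:
  pick-then is "then"
  pick-else is "else"
end
```

With the previous lecture's encoding (no thunks), Pyret would evaluate the raise call as an argument before IF ever ran, and the program would stop with an error.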

Let's do a little review -- we'll use some normal Pyret features in these tests, and then we'll stick with pure lambda calculus for the rest of the lecture:

check:
  IF(TRUE, lam(): 1 end, lam(): 2 end) is 1
  IF(FALSE, lam(): 1 end, lam(): 2 end) is 2
  IF(AND(OR(FALSE, TRUE), NOT(FALSE)), lam(): "a" end, lam(): "b" end) is "a"

  FOUR = ofnum(4)

  FOUR(lam(y): y + 1 end, 0) is 4
  FOUR(lam(y): y + 1 end, 3) is 7
  FOUR(lam(y): y + 2 end, 1) is 9

  TWO(lam(s): "Hi! " + s end, "Bye!") is "Hi! Hi! Bye!"

  ZERO(not, true) is true
  ONE(not, true) is false
  TWO(not, true) is true
end
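The review above doesn't exercise MUL, MINUS1, or EQUAL0, which the factorial examples below rely on. Here is a standalone sketch of a few extra checks. It repeats the needed definitions in compact one-line form (ofnum is written with direct recursion instead of the inner helper r, an assumption made only for brevity), so run it in a fresh file.

```pyret
# Compact copies of the definitions above (ofnum uses direct recursion here).
TRUE = lam(x, y): x() end
FALSE = lam(x, y): y() end
ZERO = lam(f, x): x end
fun ofnum(n): lam(f, x): if n == 0: x else: f(ofnum(n - 1)(f, x)) end end end
fun tonum(cn): cn(lam(y): y + 1 end, 0) end
fun tobool(cb): cb(lam(): true end, lam(): false end) end
PAIR = lam(a, b): lam(z): z(a, b) end end
FIRST = lam(p): p(lam(a, b): a end) end
SECOND = lam(p): p(lam(a, b): b end) end
ADD = lam(n1, n2): lam(f, x): n2(f, n1(f, x)) end end
MUL = lam(n1, n2): n1(lam(y): ADD(n2, y) end, ZERO) end
MINUS1 = lam(n): lam(f, x): FIRST(n(lam(y): PAIR(SECOND(y), f(SECOND(y))) end, PAIR(x, x))) end end
EQUAL0 = lam(n): n(lam(y): FALSE end, TRUE) end

check "arithmetic used by factorial":
  tonum(MUL(ofnum(2), ofnum(3))) is 6
  tonum(MUL(ZERO, ofnum(3))) is 0
  # MINUS1 carries a pair (previous, current) and returns the previous value.
  tonum(MINUS1(ofnum(3))) is 2
  # The predecessor of zero is zero, since the starting pair is PAIR(x, x).
  tonum(MINUS1(ZERO)) is 0
  tobool(EQUAL0(ZERO)) is true
  tobool(EQUAL0(ofnum(2))) is false
end
```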

Omega

To start our journey towards Y, we take as inspiration a simple program in the lambda calculus that must involve recursion (or something close enough), since it runs forever!

(lam(x): x(x) end)(lam(x): x(x) end)

Why does this run forever? Because the argument, lam(x): x(x) end, is substituted for x in the body of the first function, x(x), producing (lam(x): x(x) end)(lam(x): x(x) end): exactly the program we started with.
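Self-application is what keeps Omega going, and it can also drive a loop that stops. As an illustrative sketch (COUNTDOWN is a hypothetical name, and it uses Pyret's built-in numbers, if, and lists for brevity), here is Omega's self-application with a base case added. The function receives itself as x, so x(x) rebuilds the same counting function for the next step, and the body never refers to the name COUNTDOWN.

```pyret
# COUNTDOWN is a hypothetical name. It reuses Omega's self-application x(x),
# plus a base case: x is the function itself, so x(x) rebuilds the next step.
COUNTDOWN = lam(x):
  lam(k):
    if k == 0:
      empty
    else:
      link(k, x(x)(k - 1))
    end
  end
end

check:
  COUNTDOWN(COUNTDOWN)(3) is [list: 3, 2, 1]
end
```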

Moving towards recursion

How do we exploit that idea to get something useful? Let's pick a concrete function we want to write: the factorial function. This is about the simplest recursive function that produces a value.

In Pyret, we would write:

fun factorial(n :: Number) -> Number:
  if n == 0:
    1
  else:
    n * factorial(n - 1)
  end
end

Importantly, with what we did in the previous lecture, we can now express all the pieces of this in the lambda calculus except the recursive call: we have IF, EQUAL0, ONE, MUL, and MINUS1. So a hybrid Pyret/lambda calculus version might be:

fun factorial(n):
  # n is a Church numeral; only the recursive call still relies on Pyret's named recursion
  IF(EQUAL0(n),
    lam(): ONE end,
    lam(): MUL(n, factorial(MINUS1(n))) end)
end

But how do we handle the recursive call?

The key idea turns out to be: write a function that, rather than calling itself (how recursion normally works), expects to be passed the function that it should call.

This is our first attempt, which doesn't work, but gets us closer (as it eliminates the explicit recursion):

FACT0 = lam(rcall, n):
  IF(EQUAL0(n), lam(): ONE end, lam(): MUL(n, rcall(MINUS1(n))) end)
end

This almost works (and is valid lambda calculus code), but to use it, we need something to pass as rcall. It seems like we'd already need a recursive version of the function to make that work.
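One way to check that FACT0 is otherwise correct is to cheat exactly once and let Pyret's named recursion supply rcall. In the standalone sketch below (it repeats the earlier definitions in compact form, so run it in a fresh file), fact-cheat is a hypothetical helper whose only job is to pass itself to FACT0; removing the need for such a named helper is what the rest of this lecture is about.

```pyret
# Compact copies of the definitions above (ofnum uses direct recursion here).
TRUE = lam(x, y): x() end
FALSE = lam(x, y): y() end
IF = lam(c, t, e): c(t, e) end
ZERO = lam(f, x): x end
ONE = lam(f, x): f(x) end
fun ofnum(n): lam(f, x): if n == 0: x else: f(ofnum(n - 1)(f, x)) end end end
fun tonum(cn): cn(lam(y): y + 1 end, 0) end
PAIR = lam(a, b): lam(z): z(a, b) end end
FIRST = lam(p): p(lam(a, b): a end) end
SECOND = lam(p): p(lam(a, b): b end) end
ADD = lam(n1, n2): lam(f, x): n2(f, n1(f, x)) end end
MUL = lam(n1, n2): n1(lam(y): ADD(n2, y) end, ZERO) end
MINUS1 = lam(n): lam(f, x): FIRST(n(lam(y): PAIR(SECOND(y), f(SECOND(y))) end, PAIR(x, x))) end end
EQUAL0 = lam(n): n(lam(y): FALSE end, TRUE) end

FACT0 = lam(rcall, n):
  IF(EQUAL0(n), lam(): ONE end, lam(): MUL(n, rcall(MINUS1(n))) end)
end

# Hypothetical helper: Pyret's named recursion supplies rcall.
# This single use of a name is the "cheat" the rest of the lecture removes.
fun fact-cheat(n):
  FACT0(fact-cheat, n)
end

check:
  tonum(fact-cheat(ofnum(5))) is 120
end
```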

One more layer

What if, however, rcall itself also expected to be passed the function to call on the next iteration? How would it make a recursive call? Well, if at the next iteration we wanted to call rcall again, we could pass rcall both itself (as the function to call recursively) and the argument:

FACT1 = lam(rcall, n):
  IF(EQUAL0(n),
    lam(): ONE end,
    lam(): MUL(n, rcall(rcall, MINUS1(n))) end)
end

Now, how can we use this? What if we call FACT1, passing itself as the first argument? This is not recursion (we aren't cheating), since we could just as easily paste in a second copy of the code instead of referring to FACT1 by name. Remember, using Pyret constant definitions is a matter of convenience only.

check:
  tonum(FACT1(FACT1, ofnum(5))) is 120
end

And, miraculously, this works! We've figured out how to write recursive functions, in the lambda calculus!
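The self-passing trick isn't specific to factorial. As a hypothetical standalone example, SUM1 has exactly the shape of FACT1 but computes n + (n - 1) + ... + 0; it uses Pyret's built-in numbers and if rather than Church numerals, purely to keep the sketch short.

```pyret
# SUM1 is a hypothetical example with the same shape as FACT1, but it uses
# Pyret's built-in numbers and if. It computes n + (n - 1) + ... + 0.
SUM1 = lam(rcall, n):
  if n == 0:
    0
  else:
    n + rcall(rcall, n - 1)
  end
end

check:
  # As with FACT1, the function is passed itself as its first argument.
  SUM1(SUM1, 4) is 10
end
```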

But the FACT1 approach was slightly clumsy. Let's figure out how to separate the essential parts of the code from the parts needed to set up the recursion, so our code can be more natural and easier to read.

Moving towards Y

First, we see a pattern where we define a name, then call it with itself as its first argument. We used a Pyret definition (FACT1 = ...) to accomplish that, but definitions don't exist in the pure lambda calculus. We can accomplish the same thing with lambda and application, though:

check:
  tonum((lam(fact2):
      fact2(fact2, ofnum(5))
    end)(lam(rcall, n):
      IF(EQUAL0(n), lam(): ONE end, lam(): MUL(n, rcall(rcall, MINUS1(n))) end)
    end)) is 120
end

We can turn the argument into a parameter, m, and end up with something like:

FACT3 = lam(m): (lam(fact2):
    fact2(fact2, m)
  end)(lam(rcall, n):
    IF(EQUAL0(n),
      lam(): ONE end,
      lam(): MUL(n, rcall(rcall, MINUS1(n))) end)
  end)
end

Now we are getting somewhere!

check:
  tonum(FACT3(ofnum(5))) is 120
end

FACT3 can be called like a normal function! Good! Now, how can we make these easier to write? One perhaps non-intuitive step is to make all our functions take a single argument, writing rcall(rcall)(MINUS1(n)) instead of rcall(rcall, MINUS1(n)). That exposes more opportunities to factor things out:

FACT4 = lam(m): (lam(fact2):
    fact2(fact2)(m)
  end)(lam(rcall):
    lam(n):
      IF(EQUAL0(n),
        lam(): ONE end,
        lam(): MUL(n, rcall(rcall)(MINUS1(n))) end)
    end
  end)
end

One thing that isn't great about our lam(rcall): ... end is the rcall(rcall) on every recursive call. How do we abstract that out? Our first attempt would be to wrap the function in one more lambda that takes an argument (call it f), applies it to itself, and passes the result f(f) as rcall. Now rcall is the recursive function itself, so the body no longer needs the self-application. But this doesn't work: it runs forever! Pyret evaluates f(f) before binding rcall, and evaluating f(f) immediately requires evaluating f(f) again, so we never reach the body. If instead we suspend the self-application, passing lam(x): f(f)(x) end so that f(f) is only computed when rcall is actually called, this works fine:

FACT5 = lam(m): (lam(fact2):
    fact2(fact2)(m)
  end)(lam(f):
    (lam(rcall):
      lam(n):
        IF(EQUAL0(n),
          lam(): ONE end,
          lam(): MUL(n, rcall(MINUS1(n))) end)
      end
    end)(lam(x): f(f)(x) end)
  end)
end
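For contrast, here is a standalone sketch of the eager "first attempt" next to the suspended version. FACT5-EAGER and FACT5-LAZY are hypothetical names, and both use Pyret's built-in numbers and if instead of Church numerals to keep the sketch short. The only difference between them is whether rcall is f(f) or lam(x): f(f)(x) end. FACT5-EAGER is defined but deliberately never called, since a call would not return.

```pyret
# FACT5-EAGER passes f(f) directly as rcall. Pyret evaluates that argument
# before binding rcall, and computing f(f) needs f(f) again, so calling
# FACT5-EAGER never returns. We define it but never call it.
FACT5-EAGER = lam(m): (lam(fact2):
    fact2(fact2)(m)
  end)(lam(f):
    (lam(rcall):
      lam(n): if n == 0: 1 else: n * rcall(n - 1) end end
    end)(f(f))
  end)
end

# FACT5-LAZY differs only in wrapping the self-application in a lambda,
# so f(f) is computed when rcall is called, not when it is passed.
FACT5-LAZY = lam(m): (lam(fact2):
    fact2(fact2)(m)
  end)(lam(f):
    (lam(rcall):
      lam(n): if n == 0: 1 else: n * rcall(n - 1) end end
    end)(lam(x): f(f)(x) end)
  end)
end

check:
  FACT5-LAZY(5) is 120
end
```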

At this point, if we squint, we can see in the middle of FACT5 the part that we actually want to write. Pulled out on its own, with rcall renamed to factorial, it looks pretty ideal:

FACT = lam(factorial):
  lam(n):
    IF(EQUAL0(n),
      lam(): ONE end,
      lam(): MUL(n, factorial(MINUS1(n))) end)
  end
end

This is exactly the code we want to write. So let's abstract it out as a parameter, say F, leaving the remainder of FACT5 (the code we don't really want to write) as a function of F:

Y1 = lam(F):
  lam(m):
    (lam(fact2):
      (fact2(fact2))(m)
    end)(lam(f):
      F(lam(x): (f(f))(x) end)
    end)
  end
end

check:
  tonum(Y1(FACT)(ofnum(5))) is 120
end

Now let's do some renaming: fact2 doesn't need to be called that, since it has nothing to do with factorial! Let's call it g:

Y2 = lam(F):
  lam(m):
    (lam(g): g(g)(m) end)(lam(f):
      F(lam(x): (f(f))(x) end)
    end)
  end
end

check:
  tonum(Y2(FACT)(ofnum(5))) is 120
end

Once we do that, we notice that the body of the inner lambda returns g(g)(m). If we instead return g(g), we are returning a function that takes one argument and returns whatever g(g)(m) would return. That means we can eliminate the outer lam(m): ... end and have the body just be g(g). (In general, lam(x): f(x) end behaves the same as f, provided that evaluating f itself terminates; here g(g) just produces a function, so the change is safe.)

Y3 = lam(F):
  (lam(g): g(g) end)(lam(f):
    F(lam(x): f(f)(x) end)
  end)
end

check:
  tonum(Y3(FACT)(ofnum(5))) is 120
end

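We can make this more symmetric by doing one step of evaluation by hand: apply lam(g): g(g) end to its argument, substituting the right-hand function for g, so that both halves become the same function. (Normally the two copies could be written one above the other, which gives the expression a Y shape, but Pyret requires an argument's opening parenthesis to immediately follow the function, so they can't be split across lines.) This yields our final form: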

Y = lam(F):
  (lam(f): F(lam(x): f(f)(x) end) end)(lam(f): F(lam(x): f(f)(x) end) end)
end

check:
  tonum(Y(FACT)(ofnum(5))) is 120
end
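Nothing in Y is specific to factorial. As a hypothetical standalone example (it repeats Y, and the function body uses Pyret's built-in numbers and if for brevity), here is Fibonacci: the function passed to Y receives the recursive function as fib and returns the one-step definition.

```pyret
# Standalone copy of Y from above.
Y = lam(F):
  (lam(f): F(lam(x): f(f)(x) end) end)(lam(f): F(lam(x): f(f)(x) end) end)
end

# Hypothetical example: F receives the recursive function as fib
# and returns a one-step definition of Fibonacci.
FIBONACCI = Y(lam(fib):
  lam(n):
    if n < 2:
      n
    else:
      fib(n - 1) + fib(n - 2)
    end
  end
end)

check:
  FIBONACCI(10) is 55
end
```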

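There is another instance of the same pattern inside Y: lam(x): f(f)(x) end. In theory it is equivalent to f(f), and making that replacement yields the form of the Y combinator that is usually presented. In a strict language like Pyret, which evaluates arguments before substituting them, that change makes the program run forever: applying Y to FACT would evaluate f(f) to build the argument to F, and evaluating f(f) requires evaluating f(f) again, before F is ever called. (The version that keeps the extra lambda is sometimes called the Z combinator.) So Y is our final version. Here is the complete program, which is pure lambda calculus apart from the ofnum and tonum test helpers: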

FACTORIAL = Y(lam(fact):
  lam(n):
    IF(EQUAL0(n),
      lam(): ONE end,
      lam(): MUL(n, fact(MINUS1(n))) end)
  end
end)

tonum(FACTORIAL(ofnum(5))) # evaluates to 120
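To make the strictness point concrete, here is a standalone sketch placing the textbook form, under the hypothetical name Y-NAIVE, next to the eta-expanded Y used above. Y-NAIVE is defined but never called, because under Pyret's eager evaluation any call to it would run forever; fact is an illustrative name, and its body uses Pyret's built-in numbers for brevity.

```pyret
# Hypothetical Y-NAIVE: the textbook form, with lam(x): f(f)(x) end
# replaced by f(f). Y-NAIVE(F) would have to evaluate f(f) to build F's
# argument, and that needs f(f) again, so any call would loop forever.
# We only define it here.
Y-NAIVE = lam(F):
  (lam(f): F(f(f)) end)(lam(f): F(f(f)) end)
end

# The eta-expanded Y from above delays f(f) until the recursive call.
Y = lam(F):
  (lam(f): F(lam(x): f(f)(x) end) end)(lam(f): F(lam(x): f(f)(x) end) end)
end

fact = Y(lam(recurse): lam(n): if n == 0: 1 else: n * recurse(n - 1) end end end)

check:
  fact(5) is 120
end
```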