Why does converting the factorial function to iterative form by defunctionalizing the continuation give such a bad result?

I'm trying the "defunctionalize the continuation" technique on some recursive functions, to see if I can get a good iterative version to pop out. Following along with "The Best Refactoring You've Never Heard Of" (in Lua just for convenience; the technique seems to be mostly language-agnostic), I did this:

-- original
function printTree(tree)
    if tree then
        printTree(tree.left)
        print(tree.content)
        printTree(tree.right)
    end
end
tree = { left = { content = 1 }, content = 2, right = { content = 3 } }
printTree(tree)

-- make tail-recursive with CPS
function printTree(tree, kont)
    if tree then
        printTree(tree.left, function()
            print(tree.content)
            printTree(tree.right, kont)
        end)
    else
        kont()
    end
end
tree = { left = { content = 1 }, content = 2, right = { content = 3 } }
printTree(tree, function() end)

-- defunctionalize the continuation
function apply(kont)
    if kont then
        print(kont.tree.content)
        printTree(kont.tree.right, kont.next)
    end
end
function printTree(tree, kont)
    if tree then
        printTree(tree.left, { tree = tree, next = kont })
    else
        apply(kont)
    end
end
tree = { left = { content = 1 }, content = 2, right = { content = 3 } }
printTree(tree)

-- inline apply
function printTree(tree, kont)
    if tree then
        printTree(tree.left, { tree = tree, next = kont })
    elseif kont then
        print(kont.tree.content)
        printTree(kont.tree.right, kont.next)
    end
end
tree = { left = { content = 1 }, content = 2, right = { content = 3 } }
printTree(tree)

-- perform tail-call elimination
function printTree(tree, kont)
    while true do
        if tree then
            kont = { tree = tree, next = kont }
            tree = tree.left
        elseif kont then
            print(kont.tree.content)
            tree = kont.tree.right
            kont = kont.next
        else
            return
        end
    end
end
tree = { left = { content = 1 }, content = 2, right = { content = 3 } }
printTree(tree)
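To check the final printTree without relying on printed output, here is a minimal variant of the same loop that appends each node's content to a list instead of printing it. The name collectTree is hypothetical and only used for this illustration. The kont chain plays the same role as the explicit stack in a conventional iterative in-order traversal.

```lua
-- Hypothetical helper: same loop as the final printTree, but it appends
-- each node's content to a list instead of printing it.
local function collectTree(tree)
    local out = {}
    local kont = nil -- chain of nodes whose left subtree is still being visited
    while true do
        if tree then
            kont = { tree = tree, next = kont } -- push: come back to this node later
            tree = tree.left
        elseif kont then
            out[#out + 1] = kont.tree.content   -- pop: emit the node
            tree = kont.tree.right
            kont = kont.next
        else
            return out
        end
    end
end

local tree = { left = { content = 1 }, content = 2, right = { content = 3 } }
local out = collectTree(tree)
print(table.concat(out, " ")) -- prints "1 2 3"
```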

Then I tried the same technique on the factorial function:

-- original
function factorial(n)
    if n == 0 then
        return 1
    else
        return n * factorial(n - 1)
    end
end
print(factorial(6))

-- make tail-recursive with CPS
function factorial(n, kont)
    if n == 0 then
        return kont(1)
    else
        return factorial(n - 1, function(x)
            return kont(n * x)
        end)
    end
end
print(factorial(6, function(x) return x end))

-- defunctionalize the continuation
function apply(kont, x)
    if kont then
        return apply(kont.next, kont.n * x)
    else
        return x
    end
end
function factorial(n, kont)
    if n == 0 then
        return apply(kont, 1)
    else
        return factorial(n - 1, { n = n, next = kont })
    end
end
print(factorial(6))
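To make the defunctionalized continuation concrete, the sketch below runs only the descent half of factorial for n = 3 and prints the chain of n values it builds. buildKont is a hypothetical helper that returns the continuation instead of passing it to apply.

```lua
-- Hypothetical helper: performs only the "descent" half of the defunctionalized
-- factorial, returning the continuation it would hand to apply when n reaches 0.
local function buildKont(n, kont)
    if n == 0 then
        return kont
    else
        return buildKont(n - 1, { n = n, next = kont })
    end
end

-- Walk the chain from its head and record each pending multiplier.
local kont = buildKont(3)
local parts = {}
while kont do
    parts[#parts + 1] = tostring(kont.n)
    kont = kont.next
end
print(table.concat(parts, " -> ")) -- prints "1 -> 2 -> 3"
```

apply then folds this list starting from the head, computing 3 * (2 * (1 * 1)): the same multiplications, in the same order, as the returns of the original recursion.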

Here's where things start to go wrong. The next step is to inline apply, but I can't do that, since apply calls itself recursively. To keep going, I first performed tail-call elimination on apply itself.
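As an aside, the recursive call in the defunctionalized apply is already in tail position, and the Lua reference manual guarantees proper tail calls for a statement of the form return f(args). So even before rewriting it as a loop, apply should run in constant stack space. A minimal sketch to check this, using a hypothetical chain of 300000 multiply-by-1 frames as test data:

```lua
-- The recursive apply from the "defunctionalize the continuation" step.
local function apply(kont, x)
    if kont then
        return apply(kont.next, kont.n * x) -- proper tail call in Lua
    else
        return x
    end
end

-- Hypothetical test data: a long continuation chain of multiplications by 1.
local kont = nil
for _ = 1, 300000 do
    kont = { n = 1, next = kont }
end
print(apply(kont, 1)) -- prints 1
```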

-- perform tail-call elimination
function apply(kont, x)
    while kont do
        x = kont.n * x
        kont = kont.next
    end
    return x
end
function factorial(n, kont)
    if n == 0 then
        return apply(kont, 1)
    else
        return factorial(n - 1, { n = n, next = kont })
    end
end
print(factorial(6))

Okay, now we seem to be back on track.

-- inline apply
function factorial(n, kont)
    if n == 0 then
        local x = 1
        while kont do
            x = kont.n * x
            kont = kont.next
        end
        return x
    else
        return factorial(n - 1, { n = n, next = kont })
    end
end
print(factorial(6))

-- perform tail-call elimination
function factorial(n, kont)
    while n ~= 0 do
        kont = { n = n, next = kont }
        n = n - 1
    end
    local x = 1
    while kont do
        x = kont.n * x
        kont = kont.next
    end
    return x
end
print(factorial(6))
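In this version each node of kont only stores n and a pointer to the next node, and nodes are only ever added and removed at the head, so the linked list is acting as a stack: the first loop replays the calls of the original recursion and the second loop replays the returns. A sketch of the same two-phase structure using a Lua array as the stack (factorialStack is a hypothetical name) makes that explicit:

```lua
-- Same two-phase structure as the final factorial above, but the defunctionalized
-- continuation is stored as an array-backed stack instead of a linked list.
local function factorialStack(n)
    local stack = {}
    -- Phase 1 ("calls"): push each pending multiplier.
    while n ~= 0 do
        stack[#stack + 1] = n
        n = n - 1
    end
    -- Phase 2 ("returns"): pop from the top and multiply.
    local x = 1
    for i = #stack, 1, -1 do
        x = stack[i] * x
    end
    return x
end

print(factorialStack(6)) -- prints 720
```

It still uses memory proportional to n, which is exactly the cost that makes the result feel bad.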

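Okay, we got a fully iterative implementation of factorial, but it's a pretty bad one: it builds a linked list with one node per step and then walks it in a second loop, so it needs memory proportional to n. I was hoping to end up with something like this instead: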

function factorial(n)
    local x = 1
    while n ~= 0 do
        x = n * x
        n = n - 1
    end
    return x
end
print(factorial(6))

Is there any modification to the steps I followed that will let me mechanically end up with a function that looks more like this one?
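One possible modification, offered as a hedged sketch rather than a definitive answer: every continuation built in the CPS step has the shape function(x) return kont(n * x) end, starting from the identity. Assuming only that multiplication is associative with identity 1, each such continuation is equivalent to "multiply by some number c", so it could be defunctionalized to that single number instead of to a linked list. Composing a new frame then becomes c * n, applying the continuation becomes c * x, and the same inline and tail-call-elimination steps lead to the desired loop. The names factorialCPS and factorialLoop are hypothetical.

```lua
-- Hypothetical representation: a continuation of the form
-- function(x) return c * x end is defunctionalized to the number c,
-- so the identity continuation is 1.

-- Step 1: CPS + defunctionalization, with the number c as the continuation.
local function apply(c, x)
    return c * x
end

local function factorialCPS(n, c)
    c = c or 1 -- nil continuation = identity = "multiply by 1"
    if n == 0 then
        return apply(c, 1)
    else
        -- function(x) return kont(n * x) end composes "multiply by n" with
        -- "multiply by c"; by associativity that is "multiply by c * n".
        return factorialCPS(n - 1, c * n)
    end
end

-- Step 2: inline apply (apply(c, 1) is just c) and eliminate the tail call.
local function factorialLoop(n)
    local c = 1
    while n ~= 0 do
        c = c * n
        n = n - 1
    end
    return c
end

print(factorialCPS(6))  -- prints 720
print(factorialLoop(6)) -- prints 720
```

The catch is that this representation relies on an algebraic property of multiplication, so it appears to be a choice that has to be supplied by hand rather than one the purely mechanical steps discover on their own.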

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
