What’s a monad?

It’s almost traditional that people learning Haskell should write their own version of ‘An introduction to monads’. I think it serves to teach the writer more than the reader, but that’s fine. I’ve come to understand them in a way I haven’t seen covered in the existing introductions, so I thought I’d get it down on ‘paper’. Note that the point of this post is to demystify more than it is to enlighten – if you don’t already understand monads, you probably won’t fully understand them by the end, but you will hopefully be much closer to the point where you can.

Java has the foreach loop:

for(Foo foo : foos) doStuff(foo);

Fundamentally unexciting, but a nice bit of syntactic sugar.

A common argument for adding ‘closures’ (really first class functions) to Java is that if they had been added in Java 1.5, there would have been no need for the foreach loop, because it could have been defined as an ordinary method. Here’s an example in Scala:


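package Demo;

object ControlFlow {
    // Loop over an iterator, applying action to each element in turn.
    def forEach[T] (iter : Iterator[T], action : T => Unit) : Unit = 
        while(iter.hasNext) action(iter.next()); 

    // Overload for anything iterable: loop over its iterator.
    def forEach[T] (iter : Iterable[T], action : T => Unit) : Unit = forEach (iter.iterator, action);
}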

Still fundamentally unexciting, right? It’s just yet another loop. Except… it’s not quite, is it? I’m going to give this its own line just to make sure the point is clear:

When we introduced first class functions to the language, we gained the ability to define our own control flow mechanisms.

This is Exciting.
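To make that concrete in Haskell (the language this post turns to later), here is a minimal sketch of two home-made loops. The names forEach, whileLoop, sumWithForEach and countdown are hypothetical; Haskell's base library already has mapM_ and forM_ for the first job, so this is an illustration rather than something you'd need to write. The point is only that neither loop needs any special support from the compiler.

```haskell
import Data.IORef (newIORef, modifyIORef, readIORef)

-- A hand-rolled "foreach": run an action on each element, in order.
-- (base already provides this as mapM_ / forM_; it is redefined here
-- only to show that it is an ordinary function, not built-in syntax.)
forEach :: [a] -> (a -> IO ()) -> IO ()
forEach []       _      = return ()
forEach (x : xs) action = action x >> forEach xs action

-- A hand-rolled "while" loop: re-check cond before every run of body.
whileLoop :: IO Bool -> IO () -> IO ()
whileLoop cond body = do
  keepGoing <- cond
  if keepGoing then body >> whileLoop cond body else return ()

-- Usage: sum a list imperatively with forEach and a mutable cell.
sumWithForEach :: [Int] -> IO Int
sumWithForEach xs = do
  total <- newIORef 0
  forEach xs (\x -> modifyIORef total (+ x))
  readIORef total

-- Usage: count down from n with whileLoop, returning how many steps ran.
countdown :: Int -> IO Int
countdown n = do
  counter <- newIORef n
  steps   <- newIORef (0 :: Int)
  whileLoop (fmap (> 0) (readIORef counter)) $ do
    modifyIORef counter (subtract 1)
    modifyIORef steps (+ 1)
  readIORef steps
```

Loaded into GHCi, sumWithForEach [1, 2, 3, 4] should give 10, and countdown 5 should give 5.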

Scala introduces another neat concept, sequence comprehensions:

package Demo;

object ComprehensionTest
{
    def main (args : Array[String])={
        // Every pair (arg, arg2) of arguments, skipping pairs whose two elements are equal.
        val bar = for {
            arg <- args;
            arg2 <- args;
            if !arg.equals(arg2) }
            yield (arg, arg2);
        
        Console.println(bar.toList);
    }
}

What does this do? Well, it constructs a collection of all ordered pairs of command line arguments, omitting pairs whose two elements are equal. So

> scala Demo.ComprehensionTest foo bar

List((foo,bar), (bar,foo))

We could do much the same thing with nested for loops, but it wouldn't be as nice. For very involved collection manipulation, comprehensions simplify life a lot, so their addition to Scala is a great boon.
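For comparison, here is roughly what the nested-loop version might look like in Haskell, written imperatively with forM_ and when from Control.Monad and a mutable IORef accumulator. It's a sketch for illustration (distinctPairsLoops is a hypothetical name), but it shows the bookkeeping that the comprehension hides.

```haskell
import Control.Monad (forM_, when)
import Data.IORef (newIORef, modifyIORef, readIORef)

-- The "nested for loops" version: walk the arguments twice and
-- append each wanted pair to a mutable accumulator.
distinctPairsLoops :: [String] -> IO [(String, String)]
distinctPairsLoops args = do
  acc <- newIORef []
  forM_ args $ \arg ->
    forM_ args $ \arg2 ->
      -- Only record the pair when the two elements differ.
      when (arg /= arg2) $ modifyIORef acc (++ [(arg, arg2)])
  readIORef acc
```

Running distinctPairsLoops ["foo", "bar"] should produce [("foo","bar"),("bar","foo")], matching the Scala output above.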

But we could have defined something very similar ourselves.

Here's a rewrite that doesn't use comprehensions:


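object NoComprehension
{
    // A one-element list holding value if test is true, otherwise the empty list.
    def guard[T](test : Boolean, value : T) = 
        if (test) 
            new ::(value, Nil)
        else 
            Nil;
    
    def main (args : Array[String])={
        // Pair each arg with every arg2, keeping only pairs where the two differ.
        val bar = args.flatMap(
            (arg : String) => 
                args.flatMap(
                    (arg2 : String) => 
                        guard(!arg.equals(arg2), (arg, arg2))))

        Console.println(bar.toList);
    }
}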

In fact, these compile to very similar things: where I defined and used guard, Scala would use filter (withFilter in Scala 2.8 and later). I used guard because a) I think it's clearer and b) it supports my point. :-)
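The same trick carries over to Haskell lists, where concatMap plays the role of flatMap. This sketch puts two desugarings side by side: one with a hand-written guard helper (guardValue, a hypothetical name chosen to avoid clashing with Haskell's own guard) and one with filter, roughly the way Scala desugars the comprehension's guard. Both should give the same list.

```haskell
-- A direct transliteration of the Scala guard helper: a one-element
-- list when the test passes, an empty list otherwise.
guardValue :: Bool -> a -> [a]
guardValue test value = if test then [value] else []

-- Nested flatMaps (concatMap for Haskell lists) plus guardValue.
distinctPairsFlat :: [String] -> [(String, String)]
distinctPairsFlat args =
  concatMap (\arg ->
    concatMap (\arg2 -> guardValue (arg /= arg2) (arg, arg2)) args) args

-- The filter-based version: drop unwanted arg2 values first, then
-- map each remaining one to a pair.
distinctPairsFilter :: [String] -> [(String, String)]
distinctPairsFilter args =
  concatMap (\arg -> map (\arg2 -> (arg, arg2)) (filter (/= arg) args)) args
```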

So, what's going on here?

Here is the signature of flatMap from the Scala API documentation for the Iterable[A] class:

def flatMap [B](f : (A) => Iterable[B]) : Collection[B]
Applies the given function f to each element of this iterable, then concatenates the results.
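Read literally, that description is just map followed by concatenation. Here is a minimal Haskell sketch for lists (flatMapList and replicateEach are hypothetical names). Note the argument order: the function comes first, as with Haskell's own concatMap, rather than being called on the collection as a method.

```haskell
-- flatMap for lists, straight from the description: apply f to each
-- element, then concatenate the resulting lists. (The Prelude already
-- has this as concatMap; flatMapList is only for illustration.)
flatMapList :: (a -> [b]) -> [a] -> [b]
flatMapList f xs = concat (map f xs)

-- Example: each number n becomes n copies of itself.
replicateEach :: [Int] -> [Int]
replicateEach = flatMapList (\n -> replicate n n)
```

For example, replicateEach [1, 2, 3] gives [1,2,2,3,3,3].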

So, let's look at the inner example first.

args.flatMap ((arg2 : String) => guard(!arg.equals(arg2), (arg, arg2)))

The anonymous function takes arg2 and returns a List, which is either [(arg, arg2)] or []. flatMap then concatenates these lists together. So this has the effect of simultaneously pairing up the (arg, arg2) values and filtering out all elements for which the two are equal.

So for each value of arg we get the list of (arg, arg2) pairs we want. Doing this with flatMap over all of args concatenates those lists into the full result.
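The two steps can be pulled apart into separate functions to make the nesting visible. A small Haskell sketch, with pairsFor and distinctPairsTwoSteps as hypothetical names:

```haskell
-- The inner step, for one fixed arg: the pairs (arg, arg2) with arg2 /= arg.
-- This mirrors the inner args.flatMap(...) call.
pairsFor :: [String] -> String -> [(String, String)]
pairsFor args arg =
  concatMap (\arg2 -> if arg /= arg2 then [(arg, arg2)] else []) args

-- The outer step: run the inner step for every arg and concatenate.
distinctPairsTwoSteps :: [String] -> [(String, String)]
distinctPairsTwoSteps args = concatMap (pairsFor args) args
```

With args = ["foo", "bar"], pairsFor args "foo" is [("foo","bar")] and pairsFor args "bar" is [("bar","foo")]; concatenating them gives the full result.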

Easy, right?

The higher order function approach is much more flexible, but the comprehension syntax is a lot more readable (and concise), especially if you come from an imperative background. How do we make the two meet?

Now let's rewrite these examples in Haskell. First the higher order function one:

import Control.Monad (guard)
import System.Environment (getArgs)

main = do{
    args <- getArgs;
    print $ distinctPairs args;
}

distinctPairs :: [String] -> [(String, String)]
distinctPairs args =
    args >>=
        \arg -> 
            args >>=
                \arg2 -> 
                    guard (arg /= arg2) >>
                    return (arg, arg2)

This looks almost identical to the Scala one, once you get over superficial differences in syntax.

In particular, we replace the method call foo.flatMap(bar) with the operator foo >>= bar. The type signatures are a bit different, but for lists it does exactly the same thing.

The guard function is a little different, as we're using Haskell's built-in function of that name.

This is basically what it does:

guard :: Bool -> [()]
guard True = [()]
guard False = []

(Again, this isn't quite the real definition: it's correct for lists, but the real one is more general.)
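One way to check that the list-only definition agrees with the real one is to pin the library function's type to lists. In current versions of base, Control.Monad.guard has the type Alternative f => Bool -> f (), which is the "more general" part; guardList and libraryGuardAtList below are hypothetical names.

```haskell
import Control.Monad (guard)

-- The simplified, list-only guard from the text.
guardList :: Bool -> [()]
guardList True  = [()]
guardList False = []

-- The real guard, with its result type pinned to lists.
libraryGuardAtList :: Bool -> [()]
libraryGuardAtList = guard
```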

What >> does for lists is:

(>>) :: [a] -> [b] -> [b]
foo >> bar = foo >>= (\_ -> bar)

You may find this a bit confusing, so I'll unpack the definition with a quick reminder.

\_ -> bar is an anonymous function which takes anything and returns bar. So foo >> bar concatenates together one copy of bar for each element of foo, i.e. it's length foo copies of bar joined together.

In particular, guard test >> bar is [] if test is False and bar if test is True (as guard test has length 0 in the first case and 1 in the second).

return is very simple: for lists, return x = [x].

So, guard test >> return x is the same as guard(test, x) in our Scala method.
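Putting those pieces together for lists: a sketch of >> and return written out by hand (thenList, returnList and guardThenReturn are hypothetical names), plus the Scala guard helper rebuilt from guard, >> and return. The type signatures fix the monad to lists, so these should match the built-in operators there.

```haskell
import Control.Monad (guard)

-- >> for lists, as defined above: bind, ignoring each element of foo.
thenList :: [a] -> [b] -> [b]
thenList foo bar = foo >>= (\_ -> bar)

-- return for lists: a one-element list.
returnList :: a -> [a]
returnList x = [x]

-- The Scala guard(test, x) helper, built from guard, >> and return.
guardThenReturn :: Bool -> a -> [a]
guardThenReturn test x = guard test >> return x
```

For instance, thenList "abc" [1, 2] should be [1,2,1,2,1,2]: one copy of [1, 2] per character of "abc".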

Still with me?

Now, how do we write this so it looks like a Scala comprehension?

import Control.Monad (guard)
import System.Environment (getArgs)

main :: IO ()
main = do{
    args <- getArgs;
    print $ distinctPairs args;
}

distinctPairs :: [String] -> [(String, String)]
distinctPairs args = do{
    arg  <- args;
    arg2 <- args;
    guard(arg /= arg2);
    return (arg, arg2)
}    

Looks almost exactly like the Scala version, doesn't it?

At this point you might feel cheated if you've not seen 'do' notation before. "So... the point of this article is that Scala has these cool things called sequence comprehensions, and Haskell has them too? Who cares??".

Now look back up a little. For convenience, I'll repeat it here:

main :: IO ()
main = do{
    args <- getArgs;
    print $ distinctPairs args;
}

What's that got to do with sequence comprehensions?

Well, nothing.

It turns out that this set of operations, '>>=', '>>' and 'return', is so useful that Haskell has bundled them into their own type class, called Monad. So these operations apply to any type in this type class, including both List and IO, as well as many others. You can then apply do notation to any monad, and it just gets converted into uses of these operations in more or less the same way that we went from the sequence comprehension to the higher order functions. It works like this:

do { foo } = foo, for foo any value of a monadic type.
do { foo; bar; baz } = foo >> do { bar; baz }
do { myFoo <- foo; bar; baz } = foo >>= \myFoo -> do { bar; baz }

(Note that the last rule puts myFoo in scope for bar and baz, so we can use it in them exactly as we'd expect.)
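Here is a sketch applying those rules by hand to a small do block in the list monad (withDo and desugared are hypothetical names): the last rule twice, then the second, then the first. The two definitions should agree on every input.

```haskell
import Control.Monad (guard)

-- A do block in the list monad: all pairs (x, y) from xs with x < y.
withDo :: [Int] -> [(Int, Int)]
withDo xs = do
  x <- xs
  y <- xs
  guard (x < y)
  return (x, y)

-- The same code after desugaring by hand:
-- "x <- xs; ..."            becomes  xs >>= \x -> ...
-- "y <- xs; ..."            becomes  xs >>= \y -> ...
-- "guard ...; return ..."   becomes  guard ... >> return ...
desugared :: [Int] -> [(Int, Int)]
desugared xs =
  xs >>= \x ->
    xs >>= \y ->
      guard (x < y) >> return (x, y)
```

For example, withDo [1, 2, 3] gives [(1,2),(1,3),(2,3)].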

Why bother giving this special treatment to monads? Well, for the same reason the foreach loop was introduced: they crop up *everywhere*. It turns out that (for reasons I won't go into here) you can realise the most outrageous range of programming idioms as instances of Monad. But doing so gives you somewhat clunky syntax, so do notation exists to make that nicer. That's all it is.
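As a small example of such an idiom, here is a sketch of a home-made logging monad, similar in spirit to the Writer monad from the transformers and mtl libraries but written from scratch. Logger, runLogger, logMsg and addLogged are hypothetical names. Modern GHC requires Functor and Applicative instances before a Monad instance, so those are included; once the instance exists, do notation works for Logger with no further effort.

```haskell
-- A computation that produces a result together with a list of log messages.
newtype Logger a = Logger (a, [String])

runLogger :: Logger a -> (a, [String])
runLogger (Logger p) = p

-- Record one message; the result itself is uninteresting.
logMsg :: String -> Logger ()
logMsg s = Logger ((), [s])

instance Functor Logger where
  fmap f (Logger (a, w)) = Logger (f a, w)

instance Applicative Logger where
  pure a = Logger (a, [])
  Logger (f, w1) <*> Logger (a, w2) = Logger (f a, w1 ++ w2)

instance Monad Logger where
  -- Feed the first result to k, keeping both logs in order.
  Logger (a, w1) >>= k = let Logger (b, w2) = k a in Logger (b, w1 ++ w2)

-- An ordinary-looking do block that also accumulates a log.
addLogged :: Int -> Int -> Logger Int
addLogged x y = do
  logMsg ("got " ++ show x)
  logMsg ("got " ++ show y)
  return (x + y)
```

runLogger (addLogged 2 3) should give (5, ["got 2","got 3"]).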

So, what's a monad? Nothing special. It's a type class which gets some preferential treatment from the language because of its ubiquity. It contains some standard operations which map quite well onto common forms of control flow, so it tends to crop up quite a lot. That's all.
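As a last illustration of "it's just a type class", here is a sketch of one function written against Monad alone and then used with lists, Maybe and IO without change. pairUp is a hypothetical name; base's liftM2 (,) does the same job.

```haskell
-- Written once against the Monad class, usable at any instance.
pairUp :: Monad m => m a -> m b -> m (a, b)
pairUp ma mb = do
  a <- ma
  b <- mb
  return (a, b)
```

pairUp [1, 2] "ab" gives all four combinations, pairUp (Just 1) (Just 'a') gives Just (1,'a'), pairUp Nothing (Just 'a') gives Nothing, and in IO it simply runs both actions in order.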

This entry was posted in programming.

One thought on “What’s a monad?”

  1. Pingback: Best of drmaciver.com | David R. MacIver
