Ryan Marcus, assistant professor at the University of Pennsylvania (Fall '23). Using machine learning to build the next generation of data systems.
      
    ____                       __  ___                          
   / __ \__  ______ _____     /  |/  /___ _____________  _______
  / /_/ / / / / __ `/ __ \   / /|_/ / __ `/ ___/ ___/ / / / ___/
 / _, _/ /_/ / /_/ / / / /  / /  / / /_/ / /  / /__/ /_/ (__  ) 
/_/ |_|\__, /\__,_/_/ /_/  /_/  /_/\__,_/_/   \___/\__,_/____/  
      /____/                                                    
        
        

No, really, what's a monad?

When seasoned Haskell developers try to explain monads to the rest of us, they’ll often find themselves resorting to ridiculously incomprehensible definitions. James Iry satirized this unique semi-meme as:

A monad is just a monoid in the category of endofunctors, what’s the problem?

The original quote, which reads with an equally satirical tone, comes from the book Categories for the Working Mathematician:

All told, a monad in X is just a monoid in the category of endofunctors of X, with product × replaced by composition of endofunctors and unit set by the identity endofunctor.

There’s also much debate over whether monads are burritos. There are plenty of “monad tutorials” out there as well, although I’ve found they tend to be either overly simplistic (they only give an example of a monad) or mathematically impenetrable.

What follows is yet another attempt at explaining monads. In general, abstract concepts like monads require a bit of struggling to fully comprehend, but I’ve designed this particular guide with a few goals in mind:

Before we get started, pick your preferred programming language. Each example below is given in Python, Java, JavaScript, and C++, in that order. Personally, I find the Python code to be the most compact and readable, but you should pick whichever language you are most familiar with.

Our first monads

First, consider the simple program below.

def step1(x):
  return "Hello " + x
  
def step2(x):
  return x + ", monads aren't that complicated."
  
def step3(x):
  return "***" + x + "***"

def run():
  x = "friend"
  x = step1(x)
  x = step2(x)
  x = step3(x)
  return x

print(run())
  
class Main {
  public static void main(String[] args) {
    String x = "friend";
    x = step1(x);
    x = step2(x);
    x = step3(x);
    System.out.println(x);
  }
  
  public static String step1(String x) {
    return "Hello " + x;
  }
  
  public static String step2(String x) {
    return x + ", monads aren't that complicated.";
  }
  
  public static String step3(String x) {
    return "***" + x + "***";
  }
}
function step1(x) {
  return "Hello " + x;
}

function step2(x) {
  return x + ", monads aren't that complicated.";
}

function step3(x) {
  return "***" + x + "***";
}

function run() {
  let x = "friend";
  x = step1(x);
  x = step2(x);
  x = step3(x);
  return x;
}

console.log(run());
#include <iostream>
#include <string>

using namespace std;

string step1(string x) {
  return "Hello " + x;
}

string step2(string x) {
  return x + ", monads aren't that complicated."; 
}

string step3(string x) {
  return "***" + x + "***";
}

string run() {
  string x = "friend";
  x = step1(x);
  x = step2(x);
  x = step3(x);
  return x;
}

int main() {
    cout << run() << "\n";
}

There are no tricks here: we just use a few functions (static methods, in Java) to put some strings together into a message, and then we print out that message. Notice how x is set to an initial value ("friend"), and then we set x to the result of calling a few other functions.

Next, let’s take this pattern and codify it a bit. We’ll still set x to some initial value, but then we’ll introduce two new functions: wrap and wrap_call (wrapCall in the JavaScript and C++ versions).

def step1(x):
  return "Hello " + x
  
def step2(x):
  return x + ", monads aren't that complicated."
  
def step3(x):
  return "***" + x + "***"
  
def wrap(x):
    return x
    
def wrap_call(x, func):
    return func(x)
  
def run():
  x = "friend"
  
  # first, wrap up the value of x
  x = wrap(x)
  
  # now use wrap_call to run each step
  x = wrap_call(x, step1)
  x = wrap_call(x, step2)
  x = wrap_call(x, step3)
  
  return x
  
print(run())
import java.util.function.Function;

class Main {
  public static void main(String[] args) {
    String x = "friend";
    
    // first, wrap up the value of x
    x = wrap(x);
    
    x = wrap_call(x, Main::step1);
    x = wrap_call(x, Main::step2);
    x = wrap_call(x, Main::step3);
    System.out.println(x);
  }
  
  public static String wrap(String x) {
    return x;
  }
  
  public static String wrap_call(String s, Function<String, String> func) {
    return func.apply(s);
  }
  
  public static String step1(String x) {
    return "Hello " + x;
  }
  
  public static String step2(String x) {
    return x + ", monads aren't that complicated.";
  }
  
  public static String step3(String x) {
    return "***" + x + "***";
  }
}
function step1(x) {
  return "Hello " + x;
}

function step2(x) {
  return x + ", monads aren't that complicated.";
}

function step3(x) {
  return "***" + x + "***";
}

function wrap(x) {
  return x;
}

function wrapCall(x, func) {
  return func(x);
}

function run() {
  let x = "friend";
  
  // first, wrap up the value of x
  x = wrap(x);
  
  // now use wrapCall to run each step
  x = wrapCall(x, step1);
  x = wrapCall(x, step2);
  x = wrapCall(x, step3);
  return x;
}

console.log(run());
#include <iostream>
#include <string>

using namespace std;

string wrap(string x) {
  return x;
}

string wrapCall(string x, string (*func) (string)) {
  return func(x);
}

string step1(string x) {
  return "Hello " + x;
}

string step2(string x) {
  return x + ", monads aren't that complicated."; 
}

string step3(string x) {
  return "***" + x + "***";
}

string run() {
  string x = "friend";
  
  // first, wrap up the value of x
  x = wrap(x);
  
  // now use wrapCall to run each step
  x = wrapCall(x, step1);
  x = wrapCall(x, step2);
  x = wrapCall(x, step3);
  return x;
}

int main() {
    cout << run() << "\n";
}

These two functions don’t do much: wrap just returns whatever we gave it, and wrap_call just calls func with x as a parameter. You can run the code to verify it produces the same output as before. Take a minute to really convince yourself that the addition of wrap and wrap_call leaves the behavior of the program essentially unchanged.

Next, let’s get crazy and make wrap and wrap_call do something a little less trivial.

def step1(x):
  return "Hello " + x
  
def step2(x):
  return x + ", monads aren't that complicated."
  
def step3(x):
  return "***" + x + "***"
  
def wrap(x):
    return "[" + x + "]"
    
def wrap_call(x, func):
    return "[" + func(x[1:-1]) + "]"
  
def run():
  x = "friend"
  
  # first, wrap up the value of x
  x = wrap(x)
  
  # now use wrap_call to run each step
  x = wrap_call(x, step1)
  x = wrap_call(x, step2)
  x = wrap_call(x, step3)
  
  return x
  
print(run())
import java.util.function.Function;

class Main {
  public static void main(String[] args) {
    String x = "friend";
    
    // first, wrap up the value of x
    x = wrap(x);
    
    x = wrap_call(x, Main::step1);
    x = wrap_call(x, Main::step2);
    x = wrap_call(x, Main::step3);
    System.out.println(x);
  }
  
  public static String wrap(String x) {
    return "[" + x + "]";
  }
  
  public static String wrap_call(String s, Function<String, String> func) {
    return "[" + func.apply(s.substring(1,s.length() - 1)) + "]";
  }
  
  public static String step1(String x) {
    return "Hello " + x;
  }
  
  public static String step2(String x) {
    return x + ", monads aren't that complicated.";
  }
  
  public static String step3(String x) {
    return "***" + x + "***";
  }
}
function step1(x) {
  return "Hello " + x;
}

function step2(x) {
  return x + ", monads aren't that complicated.";
}

function step3(x) {
  return "***" + x + "***";
}

function wrap(x) {
  return "[" + x + "]";
}

function wrapCall(x, func) {
  return "[" + func(x.substring(1, x.length-1)) + "]";
}

function run() {
  let x = "friend";
  
  // first, wrap up the value of x
  x = wrap(x);
  
  // now use wrapCall to run each step
  x = wrapCall(x, step1);
  x = wrapCall(x, step2);
  x = wrapCall(x, step3);
  return x;
}

console.log(run());
#include <iostream>
#include <string>

using namespace std;

string wrap(string x) {
  return "[" + x + "]";
}

string wrapCall(string x, string (*func) (string)) {
  return "[" + func(x.substr(1, x.length() - 2)) + "]";
}

string step1(string x) {
  return "Hello " + x;
}

string step2(string x) {
  return x + ", monads aren't that complicated."; 
}

string step3(string x) {
  return "***" + x + "***";
}

string run() {
  string x = "friend";
  
  // first, wrap up the value of x
  x = wrap(x);
  
  // now use wrapCall to run each step
  x = wrapCall(x, step1);
  x = wrapCall(x, step2);
  x = wrapCall(x, step3);
  return x;
}

int main() {
    cout << run() << "\n";
}

All we’ve done is modify wrap to surround x with square brackets, and we’ve modified wrap_call to remove the square brackets (using string slicing in Python, and substring or substr in the other languages), call func, and add the square brackets back.

The resulting program prints the same string as the previous program, but adds square brackets around it. Without a doubt, this is a very silly change to make to a program, since we could’ve just added the square brackets in step3. Notice, however, that this approach will work no matter how many steps we add to the program.
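To make the "no matter how many steps" point concrete, here is a minimal Python sketch. The helper `run_steps` and the sample steps are hypothetical names introduced only for this illustration; `wrap` and `wrap_call` are the same functions as in the program above.

```python
def wrap(x):
    # surround the value with square brackets
    return "[" + x + "]"

def wrap_call(x, func):
    # strip the brackets, apply the step, then put the brackets back
    return "[" + func(x[1:-1]) + "]"

def run_steps(initial, steps):
    # hypothetical helper: thread a value through any number of steps
    x = wrap(initial)
    for step in steps:
        x = wrap_call(x, step)
    return x

# four arbitrary steps; none of them knows anything about brackets
steps = [
    lambda s: "Hello " + s,
    lambda s: s + "!",
    str.upper,
    lambda s: s + "?",
]

print(run_steps("friend", steps))  # [HELLO FRIEND!?]
print(run_steps("friend", []))     # [friend]
```

The bracket handling lives entirely in `wrap` and `wrap_call`, so adding or removing steps never requires touching it.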

OK, one more step before we do something that’s actually useful. Let’s rename a few things, and make wrap and wrap_call parameters.

def step1(x):
  return "Hello " + x
  
def step2(x):
  return x + ", monads aren't that complicated."
  
def step3(x):
  return "***" + x + "***"

def run(ret, bind):
  x = "friend"
  
  # first, wrap up the value of x
  x = ret(x)
  
  # next, use bind to call each step
  x = bind(x, step1)
  x = bind(x, step2)
  x = bind(x, step3)

  return x
  
  
def ret(x):
  return "[" + x + "]"
    
def bind(x, func):
  return "[" +func(x[1:-1]) + "]"
  
print(run(ret, bind))
import java.util.function.Function;
import java.util.function.BiFunction;


class Main {
  public static void main(String[] args) {
    System.out.println(run(Main::ret, Main::bind));
  }
  
  public static <T> T run(Function<String, T> ret, BiFunction<T, Function<String, String>, T> bind) {
    String x = "friend";
    
    // first, wrap up the value of x
    T res = ret.apply(x);
    
    res = bind.apply(res, Main::step1);
    res = bind.apply(res, Main::step2);
    res = bind.apply(res, Main::step3);
    return res;
  }
  
  public static String ret(String x) {
    return "[" + x + "]";
  }
  
  public static String bind(String s, Function<String, String> func) {
    return "[" + func.apply(s.substring(1,s.length() - 1)) + "]";
  }
  
  public static String step1(String x) {
    return "Hello " + x;
  }
  
  public static String step2(String x) {
    return x + ", monads aren't that complicated.";
  }
  
  public static String step3(String x) {
    return "***" + x + "***";
  }
}
function step1(x) {
  return "Hello " + x;
}

function step2(x) {
  return x + ", monads aren't that complicated.";
}

function step3(x) {
  return "***" + x + "***";
}

function ret(x) {
  return "[" + x + "]";
}

function bind(x, func) {
  return "[" + func(x.substring(1, x.length-1)) + "]";
}

function run(ret, bind) {
  let x = "friend";

  // first, wrap up the value of x
  x = ret(x);

  // now use bind to run each step
  x = bind(x, step1);
  x = bind(x, step2);
  x = bind(x, step3);
  return x;
}

console.log(run(ret, bind));
#include <iostream>
#include <string>

using namespace std;

string ret(string x) {
  return "[" + x + "]";
}

string bind(string x, string (*func) (string)) {
  return "[" + func(x.substr(1, x.length() - 2)) + "]";
}

string step1(string x) {
  return "Hello " + x;
}

string step2(string x) {
  return x + ", monads aren't that complicated."; 
}

string step3(string x) {
  return "***" + x + "***";
}

template<typename T> T run(T (*ret) (string), 
                           T (*bind) (T, string (*) (string))) {
  
  // first, wrap up the value of x
  T x = ret("friend");

  // now use bind to run each step
  x = bind(x, step1);
  x = bind(x, step2);
  x = bind(x, step3);
  return x;
}

int main() {
    cout << run<string>(ret, bind) << "\n";
}

All we did was rename wrap to ret (which is short for return), rename wrap_call to bind, and pass the two functions as parameters. Believe it or not, we’ve just written a monad (specifically, the monad is the pair of bind and ret)!

Symbolically, we can think of ret as a function that takes a value of some type $A$ and produces a value of some type $B$. In our example, $A$ is “strings” and $B$ is “strings surrounded by square brackets.”

We can think of bind as a function that takes:

  1. A value of type $B$
  2. A function that maps items of type $A$ to items of type $A$

… and produces an item of type $B$.

Going farther down the rabbit hole, we can say that bind is a mapping from the cross product of:

  1. Things of type $B$
  2. Functions from $A$ to $A$

… to things of type $B$. In symbols, $\mathrm{ret} : A \to B$ and $\mathrm{bind} : B \times (A \to A) \to B$.

It might help to “fill in the blanks” and think of ret and bind more concretely. In our example, ret took a string-typed ($A$) value and transformed it into a square-bracket-surrounded-string-typed ($B$) value. The bind function took a square-bracket-surrounded-string-typed ($B$) value and a function that mapped strings ($A$) to strings ($A$), such as step1, step2, and step3. bind produced a square-bracket-surrounded-string-typed ($B$) value. So the mathematics line up with our code.
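The same shapes can be written down as Python type hints. This is only a sketch: the labels `A` (the plain type) and `B` (the wrapped type) are chosen here for illustration, and this `run` takes its initial value and steps as arguments, unlike the run in the examples above. Python cannot tell a bracketed string from a plain one, so the hints document intent rather than enforce it.

```python
from typing import Callable, Sequence, TypeVar

A = TypeVar("A")  # the plain type (str in our example)
B = TypeVar("B")  # the wrapped type (bracketed str in our example)

def run(initial: A,
        ret: Callable[[A], B],
        bind: Callable[[B, Callable[[A], A]], B],
        steps: Sequence[Callable[[A], A]]) -> B:
    x = ret(initial)        # ret:  A -> B
    for step in steps:
        x = bind(x, step)   # bind: B x (A -> A) -> B
    return x

def ret(x: str) -> str:
    # A is str; B is "str surrounded by square brackets"
    return "[" + x + "]"

def bind(x: str, func: Callable[[str], str]) -> str:
    # unwrap B to A, apply the A -> A step, and wrap the result back into B
    return "[" + func(x[1:-1]) + "]"

print(run("friend", ret, bind, [str.upper, str.title]))  # [Friend]
```

Each step only ever sees a plain string; ret and bind are the only places where the wrapped form is built or taken apart.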

While the mathematical formulation is precise, it says little about when a monad can or should be used. To answer these more practical questions (and to build our intuitive understanding of what a monad really is in the first place), let’s look at some less trivial examples.

In the code sample below, we create a monad that prints out debugging information from each step of our string building process.

def step1(x):
  return "Hello " + x
  
def step2(x):
  return x + ", monads aren't that complicated."
  
def step3(x):
  return "***" + x + "***"

def run(ret, bind):
  x = "friend"
  
  # first, wrap up the value of x
  x = ret(x)
  
  # next, use bind to call each step
  x = bind(x, step1)
  x = bind(x, step2)
  x = bind(x, step3)

  return x
  
  
def ret(x):
  print("Initial value:", x)
  return "[" + x + "]"
  
def bind(x, func):
  print("Input to next step is:", x)
  result = func(x[1:-1])
  print("Result is:", result)
  return "[" + result + "]"
  
print(run(ret, bind))
import java.util.function.Function;
import java.util.function.BiFunction;


class Main {
  public static void main(String[] args) {
    System.out.println(run(Main::ret, Main::bind));
  }
  
  public static <T> T run(Function<String, T> ret, BiFunction<T, Function<String, String>, T> bind) {
    String x = "friend";
    
    // first, wrap up the value of x
    T res = ret.apply(x);
    
    res = bind.apply(res, Main::step1);
    res = bind.apply(res, Main::step2);
    res = bind.apply(res, Main::step3);
    return res;
  }
  
  public static String ret(String x) {
    System.out.println("Initial value: " + x);
    return "[" + x + "]";
  }
  
  public static String bind(String s, Function<String, String> func) {
    System.out.println("Input to next step is: " + s);
    String result = func.apply(s.substring(1,s.length() - 1));
    System.out.println("Result is: " + result);
    return "[" + result + "]";
  }
  
  public static String step1(String x) {
    return "Hello " + x;
  }
  
  public static String step2(String x) {
    return x + ", monads aren't that complicated.";
  }
  
  public static String step3(String x) {
    return "***" + x + "***";
  }
}
function step1(x) {
  return "Hello " + x;
}

function step2(x) {
  return x + ", monads aren't that complicated.";
}

function step3(x) {
  return "***" + x + "***";
}

function ret(x) {
  console.log("Initial value: " + x);
  return "[" + x + "]";
}

function bind(x, func) {
  console.log("Input to next step is: " + x);
  const result = "[" + func(x.substring(1, x.length-1)) + "]";
  console.log("Result is: " + result);
  return result;
}

function run(ret, bind) {
  let x = "friend";

  // first, wrap up the value of x
  x = ret(x);

  // now use bind to run each step
  x = bind(x, step1);
  x = bind(x, step2);
  x = bind(x, step3);
  return x;
}

console.log(run(ret, bind));
#include <iostream>
#include <string>

using namespace std;

string ret(string x) {
  cout << "Initial value: " << x << "\n";
  return "[" + x + "]";
}

string bind(string x, string (*func) (string)) {
  cout << "Input to next step is: " << x << "\n";
  string r = "[" + func(x.substr(1, x.length() - 2)) + "]";
  cout << "Result is: " << r << "\n";
  return r;
}

string step1(string x) {
  return "Hello " + x;
}

string step2(string x) {
  return x + ", monads aren't that complicated."; 
}

string step3(string x) {
  return "***" + x + "***";
}

template<typename T> T run(T (*ret) (string), 
                           T (*bind) (T, string (*) (string))) {
  
  // first, wrap up the value of x
  T x = ret("friend");

  // now use bind to run each step
  x = bind(x, step1);
  x = bind(x, step2);
  x = bind(x, step3);
  return x;
}

int main() {
    cout << run<string>(ret, bind) << "\n";
}

By adding a few simple lines to ret and bind, we can print out the value of x before and after each step. Run the code above to see an example. Notice that without using a monad, we would have to write this printing code separately for each step. With three steps, this isn’t too bad, but with a much larger number of steps, using a monad could save us significant time! And, since we normally only want print statements like this in our code temporarily, monads let us just swap the ret and bind functions to switch between production and debugging.

So far, the “actual” type of the value passed between each function has been a string (str, String, or std::string, depending on the language): we invented the idea of the “square-bracket-surrounded-string type”, but the language didn’t know about it. This doesn’t have to be the case. Consider another example of a simple monad, the counting monad, below.

def step1(x):
  return "Hello " + x
  
def step2(x):
  return x + ", monads aren't that complicated."
  
def step3(x):
  return "***" + x + "***"

def run(ret, bind):
  x = "friend"
  
  # first, wrap up the value of x
  x = ret(x)
  
  # next, use bind to call each step
  x = bind(x, step1)
  x = bind(x, step2)
  x = bind(x, step3)

  return x
  
  
def ret(x):
  return {"value": x, "count": 0}
  
def bind(x, func):
  return {"value": func(x["value"]), "count": x["count"] + 1}
  
print(run(ret, bind))
import java.util.function.Function;
import java.util.function.BiFunction;


class Main {
  public static void main(String[] args) {
    System.out.println(run(Main::ret, Main::bind));
  }
  
  public static <T> T run(Function<String, T> ret, BiFunction<T, Function<String, String>, T> bind) {
    String x = "friend";
    
    // first, wrap up the value of x
    T res = ret.apply(x);
    
    res = bind.apply(res, Main::step1);
    res = bind.apply(res, Main::step2);
    res = bind.apply(res, Main::step3);
    return res;
  }
  
  static class Record {
    String value;
    int count;
    
    public String toString() {
      return "value: " + value + " count: " + count;
    }
  }
  
  public static Record ret(String x) {
    Record r = new Record();
    r.value = x;
    r.count = 0;
    return r;
  }
  
  public static Record bind(Record input, Function<String, String> func) {
    Record r = new Record();
    r.value = func.apply(input.value);
    r.count = input.count + 1;
    return r;
  }
  
  public static String step1(String x) {
    return "Hello " + x;
  }
  
  public static String step2(String x) {
    return x + ", monads aren't that complicated.";
  }
  
  public static String step3(String x) {
    return "***" + x + "***";
  }
}
function step1(x) {
  return "Hello " + x;
}

function step2(x) {
  return x + ", monads aren't that complicated.";
}

function step3(x) {
  return "***" + x + "***";
}

function ret(x) {
  return {value: x, count: 0};
}

function bind(x, func) {
  return {value: func(x.value), count: x.count + 1}
}

function run(ret, bind) {
  let x = "friend";

  // first, wrap up the value of x
  x = ret(x);

  // now use bind to run each step
  x = bind(x, step1);
  x = bind(x, step2);
  x = bind(x, step3);
  return x;
}

console.log(run(ret, bind));
#include <iostream>
#include <string>
#include <memory>

using namespace std;

struct Record {
  string x;
  int count;
};

shared_ptr<Record> ret(string x) {
  shared_ptr<Record> toR(new Record);
  toR->x = x;
  toR->count = 0;
  return toR;
}

shared_ptr<Record> bind(shared_ptr<Record> p, 
                        string (*func) (string)) {
  p->x = func(p->x);
  p->count++;
  return p;
}

string step1(string x) {
  return "Hello " + x;
}

string step2(string x) {
  return x + ", monads aren't that complicated."; 
}

string step3(string x) {
  return "***" + x + "***";
}

template<typename T> T run(T (*ret) (string), 
                           T (*bind) (T, string (*) (string))) {
  
  // first, wrap up the value of x
  T x = ret("friend");
  
  // now use bind to run each step
  x = bind(x, step1);
  x = bind(x, step2);
  x = bind(x, step3);
  return x;
}

int main() {
    cout << run<shared_ptr<Record>>(ret, bind)->x << "\n";
}

Above, our ret function gives back a structure containing a value and a counter. Our bind updates the value stored inside the structure, and increments the counter. As a result, run now returns this special structure, and not a string.
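As a small usage sketch (Python, reusing the ret and bind of the counting monad; the variable names `start`, `after_one`, and `result` are introduced here for illustration), the caller can pull the two fields back out of the structure that run would return:

```python
def step1(x):
    return "Hello " + x

def step2(x):
    return x + ", monads aren't that complicated."

def step3(x):
    return "***" + x + "***"

def ret(x):
    # wrap the plain value together with a step counter
    return {"value": x, "count": 0}

def bind(x, func):
    # apply the step to the inner value and bump the counter
    return {"value": func(x["value"]), "count": x["count"] + 1}

start = ret("friend")
after_one = bind(start, step1)
result = bind(bind(after_one, step2), step3)

print(result["value"])   # ***Hello friend, monads aren't that complicated.***
print(result["count"])   # 3
print(start["count"])    # 0
```

In the Python version, bind builds a new dictionary on every call, so earlier wrapped values such as `start` are left untouched.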

Now, we get both the return value of the original run function as well as the number of steps taken. This might not seem useful at first, but consider that bind now has access to the current “step count” each time it is called! Now we can write this monad:

def step1(x):
  return "Hello " + x
  
def step2(x):
  return x + ", monads aren't that complicated."
  
def step3(x):
  return "***" + x + "***"

def run(ret, bind):
  x = "friend"
  
  # first, wrap up the value of x
  x = ret(x)
  
  # next, use bind to call each step
  x = bind(x, step1)
  x = bind(x, step2)
  x = bind(x, step3)

  return x
  
  
def ret(x):
  return {"value": x, "count": 0}
  
def bind(x, func):
  with open("f" + str(x["count"]) + ".txt", "w") as f:
    f.write("Initial value: " + str(x["value"]) + "\n")
    result = func(x["value"])
    f.write("Result: " + str(result) + "\n")
    
  return {"value": result, "count": x["count"] + 1}
  
import os
print(run(ret, bind))
print(os.listdir("."))
import java.util.function.Function;
import java.util.function.BiFunction;
import java.io.*;


class Main {
  public static void main(String[] args) {
    System.out.println(run(Main::ret, Main::bind).value);
    File[] list = (new File(".")).listFiles();
    
    for (File f : list) {
      System.out.println(f);
    }
  }
  
  public static <T> T run(Function<String, T> ret, BiFunction<T, Function<String, String>, T> bind) {
    String x = "friend";
    
    // first, wrap up the value of x
    T res = ret.apply(x);
    
    res = bind.apply(res, Main::step1);
    res = bind.apply(res, Main::step2);
    res = bind.apply(res, Main::step3);
    return res;
  }
  
  static class Record {
    String value;
    int count;
  }
  
  
  public static Record ret(String x) {
      Record r = new Record();
      r.value = x;
      r.count = 0;
      return r;
  }
  
  public static Record bind(Record s, Function<String, String> func) {
    String result;
    try (Writer writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("f" + s.count  + ".txt"), "utf-8"))) {
      writer.write("Initial value: " + s.value + "\n");
      result = func.apply(s.value);
      writer.write("Result: " + result + "\n");
    } catch (Exception e) {
      // logging failed, proceed normally
      result = func.apply(s.value);
    }
    
    Record r = new Record();
    r.value = result;
    r.count = s.count + 1;
    return r;
  }
  
  public static String step1(String x) {
    return "Hello " + x;
  }
  
  public static String step2(String x) {
    return x + ", monads aren't that complicated.";
  }
  
  public static String step3(String x) {
    return "***" + x + "***";
  }
}
const fs = require("fs");

function step1(x) {
  return "Hello " + x;
}

function step2(x) {
  return x + ", monads aren't that complicated.";
}

function step3(x) {
  return "***" + x + "***";
}

function ret(x) {
  return {value: x, count: 0};
}

function bind(x, func) {
  const result = func(x.value);
  fs.writeFileSync("f" + x.count + ".txt", "Initial value: " + x.value + "\nResult: " + result + "\n");
  return {value: result, count: x.count + 1};
}

function run(ret, bind) {
  let x = "friend";

  // first, wrap up the value of x
  x = ret(x);

  // now use bind to run each step
  x = bind(x, step1);
  x = bind(x, step2);
  x = bind(x, step3);
  return x;
}

console.log(run(ret, bind));
console.log(fs.readdirSync("."));
#include <iostream>
#include <string>
#include <memory>
#include <fstream>
#include <sstream> 
#include <stdlib.h>

// sadly, we need this because repl.it's GCC is too old
namespace std {
    template<typename T>
    std::string to_string(const T &n) {
        std::ostringstream s;
        s << n;
        return s.str();
    }
}

using namespace std;

struct Record {
  string x;
  int count;
};

shared_ptr<Record> ret(string x) {
  shared_ptr<Record> toR(new Record);
  toR->x = x;
  toR->count = 0;
  return toR;
}

shared_ptr<Record> bind(shared_ptr<Record> p, 
            string (*func) (string)) {
              
  ofstream myfile;
  myfile.open("f" + std::to_string(p->count) + ".txt");
  myfile << "Initial value: " << p->x << "\n";
  p->x = func(p->x);
  p->count++;
  myfile << "Result: " << p->x << "\n";
  return p;
}

string step1(string x) {
  return "Hello " + x;
}

string step2(string x) {
  return x + ", monads aren't that complicated."; 
}

string step3(string x) {
  return "***" + x + "***";
}

template<typename T> T run(T (*ret) (string), 
                           T (*bind) (T, string (*) (string))) {
  
  // first, wrap up the value of x
  T x = ret("friend");
  
  // now use bind to run each step
  x = bind(x, step1);
  x = bind(x, step2);
  x = bind(x, step3);
  return x;
}

int main() {
    cout << run<shared_ptr<Record>>(ret, bind)->x << "\n" << flush;
    system("ls");
}

Above, we use the count carried over from the previous step to name a file that records debugging information about each step of run. Because ret starts the counter at 0, the files are f0.txt for step1, f1.txt for step2, and f2.txt for step3.
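To inspect those files without cluttering the working directory, the sketch below writes them into a temporary directory and reads one back. The factory `make_bind` and its `log_dir` parameter are assumptions introduced for this illustration; the body of bind is otherwise the same as in the Python version above.

```python
import os
import tempfile

def step1(x):
    return "Hello " + x

def step2(x):
    return x + ", monads aren't that complicated."

def step3(x):
    return "***" + x + "***"

def ret(x):
    return {"value": x, "count": 0}

def make_bind(log_dir):
    # hypothetical factory: builds a bind that writes its files into log_dir
    def bind(x, func):
        path = os.path.join(log_dir, "f" + str(x["count"]) + ".txt")
        with open(path, "w") as f:
            f.write("Initial value: " + str(x["value"]) + "\n")
            result = func(x["value"])
            f.write("Result: " + str(result) + "\n")
        return {"value": result, "count": x["count"] + 1}
    return bind

with tempfile.TemporaryDirectory() as log_dir:
    bind = make_bind(log_dir)
    x = ret("friend")
    for step in (step1, step2, step3):
        x = bind(x, step)
    # collect what was written before the directory is cleaned up
    written = sorted(os.listdir(log_dir))
    with open(os.path.join(log_dir, "f0.txt")) as f:
        first_log = f.read()

print(written)    # ['f0.txt', 'f1.txt', 'f2.txt']
print(first_log)
```

The counter starts at 0 and is read before it is incremented, so the file for the first step is f0.txt.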

Finally, we can use the identity monad, as seen below, to run the program normally (with no files written or extra prints):

def step1(x):
  return "Hello " + x
  
def step2(x):
  return x + ", monads aren't that complicated."
  
def step3(x):
  return "***" + x + "***"

def run(ret, bind):
  x = "friend"
  
  # first, wrap up the value of x
  x = ret(x)
  
  # next, use bind to call each step
  x = bind(x, step1)
  x = bind(x, step2)
  x = bind(x, step3)

  return x
  
  
def ret(x):
  return x
  
def bind(x, func):
  return func(x)
  
print(run(ret, bind))
import java.util.function.Function;
import java.util.function.BiFunction;


class Main {
  public static void main(String[] args) {
    System.out.println(run(Main::ret, Main::bind));
  }
  
  public static <T> T run(Function<String, T> ret, BiFunction<T, Function<String, String>, T> bind) {
    String x = "friend";
    
    // first, wrap up the value of x
    T res = ret.apply(x);
    
    res = bind.apply(res, Main::step1);
    res = bind.apply(res, Main::step2);
    res = bind.apply(res, Main::step3);
    return res;
  }
  
  
  public static String ret(String x) {
     return x;
  }
  
  public static String bind(String s, Function<String, String> func) {
    return func.apply(s);
  }
  
  public static String step1(String x) {
    return "Hello " + x;
  }
  
  public static String step2(String x) {
    return x + ", monads aren't that complicated.";
  }
  
  public static String step3(String x) {
    return "***" + x + "***";
  }
}
function step1(x) {
  return "Hello " + x;
}

function step2(x) {
  return x + ", monads aren't that complicated.";
}

function step3(x) {
  return "***" + x + "***";
}

function ret(x) {
  return x;
}

function bind(x, func) {
  return func(x)
}

function run(ret, bind) {
  let x = "friend";

  // first, wrap up the value of x
  x = ret(x);

  // now use bind to run each step
  x = bind(x, step1);
  x = bind(x, step2);
  x = bind(x, step3);
  return x;
}

console.log(run(ret, bind));
#include <iostream>
#include <string>
#include <memory>

using namespace std;

string ret(string x) {
  return x;
}

string bind(string p, 
            string (*func) (string)) {
  return func(p);
}

string step1(string x) {
  return "Hello " + x;
}

string step2(string x) {
  return x + ", monads aren't that complicated."; 
}

string step3(string x) {
  return "***" + x + "***";
}

template<typename T> T run(T (*ret) (string), 
                           T (*bind) (T, string (*) (string))) {
  
  // first, wrap up the value of x
  T x = ret("friend");
  
  // now use bind to run each step
  x = bind(x, step1);
  x = bind(x, step2);
  x = bind(x, step3);
  return x;
}

int main() {
    cout << run<string>(ret, bind) << "\n";
}

Now we can change the behavior of run to print debugging information to the console, write debugging information to a file for each step, or run normally with no extra output, just by changing which ret and bind we pass to run! Notice that there’s no need to change run at all to change which monad we use. Now, of course, you could’ve used a logging library with a configuration file or a global variable to control where to send debugging information, but I think that monads provide a very clean solution to the problem. Next, we will look at a monad that can’t so easily be replaced with a logging library.
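One way to picture the swap is a small lookup table of (ret, bind) pairs. This is a sketch only: the `MONADS` dictionary, the mode names, and `run_in_mode` are hypothetical, and the two monads shown are the identity and counting monads from earlier.

```python
def step1(x):
    return "Hello " + x

def step2(x):
    return x + ", monads aren't that complicated."

def step3(x):
    return "***" + x + "***"

def run(ret, bind):
    # run never changes, whichever monad is passed in
    x = ret("friend")
    x = bind(x, step1)
    x = bind(x, step2)
    x = bind(x, step3)
    return x

# identity monad: no extra behavior
def identity_ret(x):
    return x

def identity_bind(x, func):
    return func(x)

# counting monad: tracks how many steps ran
def counting_ret(x):
    return {"value": x, "count": 0}

def counting_bind(x, func):
    return {"value": func(x["value"]), "count": x["count"] + 1}

# hypothetical registry mapping a mode name to a (ret, bind) pair
MONADS = {
    "production": (identity_ret, identity_bind),
    "counting": (counting_ret, counting_bind),
}

def run_in_mode(mode):
    ret, bind = MONADS[mode]
    return run(ret, bind)

print(run_in_mode("production"))
print(run_in_mode("counting"))
```

Switching behavior is then a one-word change at the call site, and a new monad can be added without editing run or any of the steps.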

With this basic structure, we could also easily build monads that:

Lazy monads

A common use case for monads is transforming eager programs into lazy programs. An eager program is a program which immediately materializes its result. Most programs written in Python (except for code that uses generators), Java (except for parts of the Streams API), synchronous JavaScript, and C++ are eager. A lazy program is one that doesn’t materialize its result until it absolutely must. Examples include Python generators, Java streams, and JavaScript generators (introduced in ECMAScript 2015).

Imagine that your professor asks you to write the numbers 1 through 100 on placards so they can hold them up in class to discuss each number. You could create all 100 placards ahead of time, but you know that class is only an hour long and your professor is likely to spend at least twenty minutes talking about the delicate intricacies of the number 3. Since writing a number on a placard is a quick process, you decide to just come to class with blank placards, and quickly write a number on one as your professor needs it. Making all 100 placards ahead of time is an eager approach (you materialize each card ahead of time), whereas making each placard as required is a lazy approach.
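The placard story maps directly onto a Python list comprehension (eager) versus a generator expression (lazy). In this sketch, `make_placard` and the two log lists are hypothetical names used only to record when each placard actually gets made.

```python
def make_placard(n, log):
    # record that the work happened, then "write" the placard
    log.append(n)
    return "placard " + str(n)

# eager: all 100 placards are made up front
eager_log = []
eager = [make_placard(n, eager_log) for n in range(1, 101)]
print(len(eager_log))   # 100

# lazy: nothing is made until someone asks for a placard
lazy_log = []
lazy = (make_placard(n, lazy_log) for n in range(1, 101))
print(len(lazy_log))    # 0

# the professor only gets through the first three numbers
first_three = [next(lazy) for _ in range(3)]
print(first_three)      # ['placard 1', 'placard 2', 'placard 3']
print(lazy_log)         # [1, 2, 3]
```

The lazy version does only as much work as the class actually needs, at the cost of doing that work at the moment each placard is requested.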

The two most common examples of such monads are the “thunk” monad and the continuation monad. The “thunk” monad wraps a bunch of functions together to be executed at the same time (think waiting for the professor to ask for the first card, then making all the cards right afterwards). The continuation monad turns a bunch of functions into a single function that can be called multiple times to apply each of the original functions (each call is like making a single card).

The “thunk” monad

Thunks are one of those concepts from the ancient history of computer science that have been “reclaimed” by functional programming advocates and SICP acolytes to describe a wide range of ideas, all of which vaguely deal with changing how a function is called (traditionally, thunks were used to implement call-by-name parameter passing, but this isn’t relevant to us). Generally, when we talk about a “thunk” monad, we are talking about a monad that transforms an eager program into a single lazy value. So instead of evaluating step3(step2(step1("friend"))) all at once, we create a wrapper (a LazyWrapper class in Python and Java, a plain object in JavaScript, a Record struct in C++) that holds a list of functions to apply (a list, List, array, or std::vector, respectively) and a get function that applies all of these functions to the initial value and returns the result.

Check out the code below.

def step1(x):
  return "Hello " + x
  
def step2(x):
  return x + ", monads aren't that complicated."
  
def step3(x):
  return "***" + x + "***"

def run(ret, bind):
  x = "friend"
  
  # first, wrap up the value of x
  x = ret(x)
  
  # next, use bind to call each step
  x = bind(x, step1)
  x = bind(x, step2)
  x = bind(x, step3)

  return x
  
  
def ret(x):
  class LazyWrapper:
    def __init__(self, x):
      self.initial = x
      self.__funcs = []
      
    def get(self):
      x = self.initial
      for f in self.__funcs:
        x = f(x)
      return x
      
    def add_func(self, f):
      self.__funcs.append(f)
      
  return LazyWrapper(x)
  
def bind(x, func):
  x.add_func(func)
  return x
  
lazy = run(ret, bind)
print("none of the steps have executed yet!")
print(lazy.get()) # this will execute them all
import java.util.function.Function;
import java.util.function.BiFunction;
import java.util.*;


class Main {
  public static void main(String[] args) {
    LazyWrapper lw = run(Main::ret, Main::bind);
    System.out.println("Nothing has been computed yet!");
    System.out.println(lw.get());
  }
  
  public static <T> T run(Function<String, T> ret, BiFunction<T, Function<String, String>, T> bind) {
    String x = "friend";
    
    // first, wrap up the value of x
    T res = ret.apply(x);
    
    res = bind.apply(res, Main::step1);
    res = bind.apply(res, Main::step2);
    res = bind.apply(res, Main::step3);
    return res;
  }
  
  static class LazyWrapper {
    private String initialValue;
    private List<Function<String, String>> funcs;
    
    public LazyWrapper(String initial) {
      initialValue = initial;
      funcs = new LinkedList<>();
    }
    
    public void addFunc(Function<String, String> f) {
      funcs.add(f);
    }
    
    public String get() {
      // apply the functions to a local copy so that get() can be called repeatedly
      String result = initialValue;
      for (Function<String, String> f : funcs) {
        result = f.apply(result);
      }
      
      return result;
    }
  }
  
  
  public static LazyWrapper ret(String x) {
      return new LazyWrapper(x);
  }
  
  public static LazyWrapper bind(LazyWrapper s, Function<String, String> func) {
    s.addFunc(func);
    return s;
  }
  
  public static String step1(String x) {
    return "Hello " + x;
  }
  
  public static String step2(String x) {
    return x + ", monads aren't that complicated.";
  }
  
  public static String step3(String x) {
    return "***" + x + "***";
  }
}
function step1(x) {
  return "Hello " + x;
}

function step2(x) {
  return x + ", monads aren't that complicated.";
}

function step3(x) {
  return "***" + x + "***";
}

function ret(x) {
  return { initial: x, funcs: [], 
  add: function(x) { this.funcs.push(x); },
  get: function () { 
    let x = this.initial;
    this.funcs.forEach((f) => x = f(x));
    return x;
  }};
}

function bind(x, func) {
  x.add(func);
  return x;
}

function run(ret, bind) {
  let x = "friend";
  
  // first, wrap up the value of x
  x = ret(x);
  
  // now use bind to run each step
  x = bind(x, step1);
  x = bind(x, step2);
  x = bind(x, step3);
  return x;
}

const lazy = run(ret, bind);
console.log("nothing has run yet!");
console.log(lazy.get());
#include <iostream>
#include <string>
#include <memory>
#include <fstream>
#include <vector>


using namespace std;

struct Record {
  string initial;
  vector<string (*) (string)> funcs;
  
  string get() {
    // apply each function to a local copy so get() can be called more than once
    string result = initial;
    for (auto f : funcs) {
      result = f(result);
    }
    
    return result;
  }
};

shared_ptr<Record> ret(string x) {
  shared_ptr<Record> toR(new Record);
  toR->initial = x;
  toR->funcs.clear();
  return toR;
}

shared_ptr<Record> bind(shared_ptr<Record> p, 
                        string (*func) (string)) {
              
  p->funcs.push_back(func);
  return p;
}

string step1(string x) {
  return "Hello " + x;
}

string step2(string x) {
  return x + ", monads aren't that complicated."; 
}

string step3(string x) {
  return "***" + x + "***";
}

template<typename T> T run(T (*ret) (string), 
                           T (*bind) (T, string (*) (string))) {
  
  // first, wrap up the value of x
  T x = ret("friend");
  
  // now use bind to run each step
  x = bind(x, step1);
  x = bind(x, step2);
  x = bind(x, step3);
  return x;
}

int main() {
  shared_ptr<Record> result = run<shared_ptr<Record>>(ret, bind);
  cout << "nothing has been computed yet!" << "\n";
  
  cout << result->get() << "\n";
}

This sort of monad lets you prepare a complex operation (potentially composing together an arbitrary number of functions) lazily, i.e., without actually executing any of the functions. Then, you can execute them all at once with a single call to get.

Why is this useful? Couldn’t you just execute step3(step2(step1("friend"))) only when you actually need the result? Certainly. This monad is really just a different way to bundle up (compose) a bunch of functions.
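To see that the thunk monad is "just composition", here is a small Python sketch that builds the same pipeline with `functools.reduce`. The `compose` helper is a hypothetical name introduced for this illustration; it is not part of the examples above.

```python
from functools import reduce

def step1(x):
    return "Hello " + x

def step2(x):
    return x + ", monads aren't that complicated."

def step3(x):
    return "***" + x + "***"

def compose(*funcs):
    # returns one function that applies funcs left to right;
    # nothing runs until the returned function is called
    return lambda x: reduce(lambda acc, f: f(acc), funcs, x)

pipeline = compose(step1, step2, step3)  # no step has executed yet
result = pipeline("friend")              # all three steps execute here
print(result)
```

Calling `pipeline` plays the role of `get`: the work is bundled up first and only performed on demand.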

The continuation monad

The continuation monad is a lot like the thunk monad we made previously, but instead of executing all the functions at once when get is called, it executes one function per call to get.

def step1(x):
  return "Hello " + x
  
def step2(x):
  return x + ", monads aren't that complicated."
  
def step3(x):
  return "***" + x + "***"

def run(ret, bind):
  x = "friend"
  
  # first, wrap up the value of x
  x = ret(x)
  
  # next, use bind to call each step
  x = bind(x, step1)
  x = bind(x, step2)
  x = bind(x, step3)

  return x
  
  
def ret(x):
  class ContinuationWrapper:
    def __init__(self, x):
      self.x = x
      self.__funcs = [lambda a: a]
      
    def get(self):
      # WARNING: bug! will keep trying to pop forever!
      self.x = self.__funcs.pop(0)(self.x)
      return self.x
      
    def add_func(self, f):
      self.__funcs.append(f)
      
  return ContinuationWrapper(x)
  
def bind(x, func):
  x.add_func(func)
  return x
  
lazy = run(ret, bind)
print("none of the steps have executed yet!")
print(lazy.get()) # this will execute the first step (the initial value, "friend")
print(lazy.get()) # this will execute the second step
print(lazy.get()) # this will execute the third step
print(lazy.get()) # this will execute the fourth step
import java.util.function.Function;
import java.util.function.BiFunction;
import java.util.*;


class Main {
  public static void main(String[] args) {
    ContinuationWrapper lw = run(Main::ret, Main::bind);
    System.out.println("Nothing has been computed yet!");
    System.out.println(lw.get()); // the initial value
    System.out.println(lw.get()); // step1
    System.out.println(lw.get()); // step2
    System.out.println(lw.get()); // step3

  }
  
  public static <T> T run(Function<String, T> ret, BiFunction<T, Function<String, String>, T> bind) {
    String x = "friend";
    
    // first, wrap up the value of x
    T res = ret.apply(x);
    
    res = bind.apply(res, Main::step1);
    res = bind.apply(res, Main::step2);
    res = bind.apply(res, Main::step3);
    return res;
  }
  
  static class ContinuationWrapper {
    private String x;
    private Deque<Function<String, String>> funcs;
    
    public ContinuationWrapper(String initial) {
      this.x = initial;
      funcs = new LinkedList<>();
      funcs.add((x) -> x);
    }
    
    public void addFunc(Function<String, String> f) {
      funcs.add(f);
    }
    
    public String get() {
      // WARNING: bug! will keep trying to pop forever!
      x = funcs.removeFirst().apply(x);
      return x;
    }
  }
  
  
  public static ContinuationWrapper ret(String x) {
      return new ContinuationWrapper(x);
  }
  
  public static ContinuationWrapper bind(ContinuationWrapper s, Function<String, String> func) {
    s.addFunc(func);
    return s;
  }
  
  public static String step1(String x) {
    return "Hello " + x;
  }
  
  public static String step2(String x) {
    return x + ", monads aren't that complicated.";
  }
  
  public static String step3(String x) {
    return "***" + x + "***";
  }
}
function step1(x) {
  return "Hello " + x;
}

function step2(x) {
  return x + ", monads aren't that complicated.";
}

function step3(x) {
  return "***" + x + "***";
}

function ret(x) {
  return { initial: x, funcs: [(x) => x], 
           add: function(x) { this.funcs.push(x); },
           get: function () { 
            // WARNING: bug! will keep trying to pop forever!
            this.initial = this.funcs.shift()(this.initial);
            return this.initial;
  }};
}

function bind(x, func) {
  x.add(func);
  return x;
}

function run(ret, bind) {
  let x = "friend";
  
  // first, wrap up the value of x
  x = ret(x);
  
  // now use bind to run each step
  x = bind(x, step1);
  x = bind(x, step2);
  x = bind(x, step3);
  return x;
}

const lazy = run(ret, bind);
console.log("nothing has run yet!");
console.log(lazy.get()); // initial value
console.log(lazy.get()); // step 1
console.log(lazy.get()); // step 2
console.log(lazy.get()); // step 3
#include <iostream>
#include <string>
#include <memory>
#include <fstream>
#include <vector>


using namespace std;

struct Record {
  string initial;
  vector<string (*) (string)> funcs;
  
  string get() {
    // WARNING: bug! will keep trying to pop forever!
    initial = funcs[0](initial);
    funcs.erase(funcs.begin());
    return initial;
  }
};

string id(string x) { return x; }

shared_ptr<Record> ret(string x) {
  shared_ptr<Record> toR(new Record);
  toR->initial = x;
  toR->funcs.clear();
  toR->funcs.push_back(id);
  return toR;
}

shared_ptr<Record> bind(shared_ptr<Record> p, 
                        string (*func) (string)) {
              
  p->funcs.push_back(func);
  return p;
}

string step1(string x) {
  return "Hello " + x;
}

string step2(string x) {
  return x + ", monads aren't that complicated."; 
}

string step3(string x) {
  return "***" + x + "***";
}

template<typename T> T run(T (*ret) (string), 
                           T (*bind) (T, string (*) (string))) {
  
  // first, wrap up the value of x
  T x = ret("friend");
  
  // now use bind to run each step
  x = bind(x, step1);
  x = bind(x, step2);
  x = bind(x, step3);
  return x;
}

int main() {
  shared_ptr<Record> result = run<shared_ptr<Record>>(ret, bind);
  cout << "nothing has been computed yet!" << "\n";
  
  cout << result->get() << "\n"; // initial value
  cout << result->get() << "\n"; // step1
  cout << result->get() << "\n"; // step2
  cout << result->get() << "\n"; // step3

}

The result is that the first call to get gives you the initial value, the next call gives you the result from step1, the call after that gives you the result from step2, and the final call gives you the result of step3. Many of you may recognize this as a coroutine. My implementation is not very safe: once every step has run, further calls to get will fail (with an error in Python, Java, and JavaScript, and with undefined behavior in C++).
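One possible way to guard against running off the end is sketched below in Python. This is an assumption about how one might harden the wrapper, not the approach used above: `SafeContinuationWrapper` and its `done` method are hypothetical names, and returning the last value again (instead of raising) is just one reasonable policy.

```python
def ret(x):
    class SafeContinuationWrapper:  # hypothetical name
        def __init__(self, x):
            self.x = x
            self._funcs = [lambda a: a]

        def add_func(self, f):
            self._funcs.append(f)

        def done(self):
            # True once every queued function has been applied
            return not self._funcs

        def get(self):
            # only pop when something is left; otherwise repeat the last value
            if self._funcs:
                self.x = self._funcs.pop(0)(self.x)
            return self.x

    return SafeContinuationWrapper(x)

def bind(x, func):
    x.add_func(func)
    return x

lazy = ret("friend")
lazy = bind(lazy, lambda s: "Hello " + s)
lazy = bind(lazy, str.upper)

results = []
while not lazy.done():
    results.append(lazy.get())

extra = lazy.get()  # past the end: no error, just the last value again
print(results)
print(extra)
```

Callers can either loop on `done()` or keep calling `get()` without worrying about an exception.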

The final behavior of the monad is a lot like a Java iterator, a Python generator, JavaScript generator functions, “goroutines” with channels in Go (almost), Boost.Coroutine for C++, or Ruby’s Fiber. So the continuation monad doesn’t necessarily give you anything that your language didn’t have beforehand. Again, the point is that every monad we’ve discussed (the debugging monad, the continuation monad, etc.) is entirely generic and can be substituted easily.
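For comparison, roughly the same step-by-step behavior can be written with a plain Python generator. The helper name `run_steps` is hypothetical; it yields the initial value first and then the result after each function, mirroring the four calls to get in the examples above.

```python
def step1(x):
    return "Hello " + x

def step2(x):
    return x + ", monads aren't that complicated."

def step3(x):
    return "***" + x + "***"

def run_steps(initial, funcs):
    # hypothetical helper: yield the initial value, then the value after each step
    x = initial
    yield x
    for f in funcs:
        x = f(x)  # this step only runs when the caller asks for the next value
        yield x

gen = run_steps("friend", [step1, step2, step3])
first = next(gen)  # nothing but the initial value has been produced so far
rest = list(gen)   # drives the remaining steps one at a time
print(first)
print(rest)
```

Unlike the hand-rolled wrapper, an exhausted generator signals the end with `StopIteration`, which `for` loops and `list` handle automatically.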

The monadic design pattern

While many of the more formal monads above can be useful, real-world systems are generally very large, and refactoring them to “support” monads (like the ret and bind functions in the examples above) is probably not feasible. But the insight gained from these simple examples can be applied as a “design pattern” to many problems often encountered in real-world code.

For example, consider the code below that accesses the value 6, which is deeply nested inside dictionaries (Maps in Java, a JSON object in JavaScript, std::maps in C++). In the real world, the path might look something more like users -> data -> contact -> phone -> mobile, but the idea is the same.

import json
jdata = """{"a": {"b": {"c": {"d": {"e": 6}}}}}"""
jdata = json.loads(jdata)


def extract_value(data):
  return data["a"]["b"]["c"]["d"]["e"]

print(extract_value(jdata))
import java.util.*;

class Main {
  
  // {"a": {"b": {"c": {"d": {"e": 6}}}}}
  static Record root;

  
  public static void main(String[] args) {
    System.out.println(extractValue(root));
  }
  
  public static String extractValue(Record data) {
    return data.get("a").get("b").get("c").get("d").get("e").value;
  }
  
  
  static class Record {
    public Record() {
      children = new HashMap<>();
    }
  
    public Map<String, Record> children;
    public String value;
    
    public Record get(String s) { return children.get(s); }
  }
  
  static {
    Record r1 = new Record();
    r1.value = "6";
    
    Record r2 = new Record();
    r2.children.put("e", r1);
    
    Record r3 = new Record();
    r3.children.put("d", r2);
    
    Record r4 = new Record();
    r4.children.put("c", r3);
    
    Record r5 = new Record();
    r5.children.put("b", r4);
    
    root = new Record();
    root.children.put("a", r5);
  }
}
const jdata = {a: {b: {c: {d: {e: 6}}}}};

function extractValue(data) {
  return data.a.b.c.d.e;
}

console.log(extractValue(jdata));
#include <iostream>
#include <string>
#include <map>

using namespace std;

map<string, void*>* get(string key, map<string, void*>* m){
  map<string, void*> x = *m;
  return (map<string, void*>*) x[key];
}


int* extractValue(map<string, void*>* m) {
  map<string, void*>* x1 = get("a", m);
  map<string, void*>* x2 = get("b", x1);
  map<string, void*>* x3 = get("c", x2);
  map<string, void*>* x4 = get("d", x3);
  int* x5 = (int*) get("e", x4);
  
  return x5;
}

int main() {
  int i = 6;
  map<string, void*> m;
  m["e"] = &i;
  
  map<string, void*> m1;
  m1["d"] = &m;
  
  map<string, void*> m2;
  m2["c"] = &m1;
  
  map<string, void*> m3;
  m3["b"] = &m2;
  
  map<string, void*> m4;
  m4["a"] = &m3;
  // m4 is the map: {"a": {"b": {"c": {"d": {"e": 6}}}}}

  int* value = extractValue(&m4);
  
  if (value == NULL) {
    cout << "value was null!" << "\n";
  } else {
    cout << "value: " << *value << "\n";
  }
  
}

Of course, this code isn’t safe, since if the input data doesn’t have exactly the right path, an error will occur. Verify this by changing either the access path or the data itself. So instead, we have to write code that “looks before leaping”:

import json
jdata = """{"a": {"b": {"c": {"d": {"e": 6}}}}}"""
jdata = json.loads(jdata)


def extract_value(data):
  if "a" not in data:
    return None
    
  if "b" not in data["a"]:
    return None
    
  if "c" not in data["a"]["b"]:
    return None
    
  if "d" not in data["a"]["b"]["c"]:
    return None
    
  if "e" not in data["a"]["b"]["c"]["d"]:
    return None
    
  return data["a"]["b"]["c"]["d"]["e"]

print(extract_value(jdata))
import java.util.*;

class Main {
  
  // {"a": {"b": {"c": {"d": {"e": 6}}}}}
  static Record root;

  
  public static void main(String[] args) {
    System.out.println(extractValue(root));
  }
  
  public static String extractValue(Record data) {
    if (!data.has("a"))
      return null;
  
    if (!data.get("a").has("b"))
      return null;
      
    if (!data.get("a").get("b").has("c"))
      return null;
    
    if (!data.get("a").get("b").get("c").has("d"))
      return null;
      
    if (!data.get("a").get("b").get("c").get("d").has("e"))
      return null;
      
    return data.get("a").get("b").get("c").get("d").get("e").value;
  }
  
  
  static class Record {
    public Record() {
      children = new HashMap<>();
    }
  
    public Map<String, Record> children;
    public String value;
    
    public Record get(String s) { return children.get(s); }
    public boolean has(String s) { return children.containsKey(s); }
  }
  
  static {
    Record r1 = new Record();
    r1.value = "6";
    
    Record r2 = new Record();
    r2.children.put("e", r1);
    
    Record r3 = new Record();
    r3.children.put("d", r2);
    
    Record r4 = new Record();
    r4.children.put("c", r3);
    
    Record r5 = new Record();
    r5.children.put("b", r4);
    
    root = new Record();
    root.children.put("a", r5);
  }
}
const jdata = {a: {b: {c: {d: {e: 6}}}}};

function extractValue(data) {
  if (! ("a" in data))
    return null;
    
  if (! ("b" in data.a))
    return null;
    
  if (! ("c" in data.a.b))
    return null;
    
    
  if (! ("d" in data.a.b.c))
    return null;
    
  if (! ("e" in data.a.b.c.d))
    return null;
  
  return data.a.b.c.d.e;
}

console.log(extractValue(jdata));
#include <iostream>
#include <string>
#include <map>

using namespace std;

map<string, void*>* get(string key, map<string, void*>* m){
  map<string, void*> x = *m;
  return (map<string, void*>*) x[key];
}


int* extractValue(map<string, void*>* m) {
  
  if (m->find("a") == m->end())
     return NULL;
     
  map<string, void*>* x1 = get("a", m);
  
  if (x1->find("b") == x1->end())
    return NULL;
  
  map<string, void*>* x2 = get("b", x1);
  
  if (x2->find("c") == x2->end())
    return NULL;
  
  map<string, void*>* x3 = get("c", x2);
  
  if (x3->find("d") == x3->end())
    return NULL;
  
  map<string, void*>* x4 = get("d", x3);
  
  if (x4->find("e") == x4->end())
    return NULL;
  
  int* x5 = (int*) get("e", x4);
  return x5;
}

int main() {
  int i = 6;
  map<string, void*> m;
  m["e"] = &i;
  
  map<string, void*> m1;
  m1["d"] = &m;
  
  map<string, void*> m2;
  m2["c"] = &m1;
  
  map<string, void*> m3;
  m3["b"] = &m2;
  
  map<string, void*> m4;
  m4["a"] = &m3;
  // m4 is the map: {"a": {"b": {"c": {"d": {"e": 6}}}}}

  int* value = extractValue(&m4);
  
  if (value == NULL) {
    cout << "value was null!" << "\n";
  } else {
    cout << "value: " << *value << "\n";
  }
  
}

Instead, we could use try/except (try/catch in Java, JavaScript, and C++) to “ask for forgiveness instead of permission”:

import json
jdata = """{"a": {"b": {"c": {"d": {"e": 6}}}}}"""
jdata = json.loads(jdata)


def extract_value(data):
  try:
    return data["a"]["b"]["c"]["d"]["e"]
  except (KeyError, TypeError):  # a missing key, or a non-dict value along the path
    return None
    
print(extract_value(jdata))
import java.util.*;

class Main {
  
  // {"a": {"b": {"c": {"d": {"e": 6}}}}}
  static Record root;

  
  public static void main(String[] args) {
    System.out.println(extractValue(root));
  }
  
  public static String extractValue(Record data) {
    try {
      return data.get("a").get("b").get("c").get("d").get("e").value;
    } catch (Exception e) {
      return null;
    }
  }
  
  
  static class Record {
    public Record() {
      children = new HashMap<>();
    }
  
    public Map<String, Record> children;
    public String value;
    
    public Record get(String s) { return children.get(s); }
    public boolean has(String s) { return children.containsKey(s); }
  }
  
  static {
    Record r1 = new Record();
    r1.value = "6";
    
    Record r2 = new Record();
    r2.children.put("e", r1);
    
    Record r3 = new Record();
    r3.children.put("d", r2);
    
    Record r4 = new Record();
    r4.children.put("c", r3);
    
    Record r5 = new Record();
    r5.children.put("b", r4);
    
    root = new Record();
    root.children.put("a", r5);
  }
}
const jdata = {a: {b: {c: {d: {e: 6}}}}};

function extractValue(data) {
  try {
    return data.a.b.c.d.e;
  } catch (e) {
    return null;
  }
}

console.log(extractValue(jdata));
#include <iostream>
#include <string>
#include <map>

using namespace std;

map<string, void*>* get(string key, map<string, void*>* m){
  map<string, void*> x = *m;
  return (map<string, void*>*) x.at(key);
}


int* extractValue(map<string, void*>* m) {
  try {
    map<string, void*>* x1 = get("a", m);
    map<string, void*>* x2 = get("b", x1);
    map<string, void*>* x3 = get("c", x2);
    map<string, void*>* x4 = get("d", x3);
    int* x5 = (int*) get("e", x4);
    return x5;
  } catch (const std::exception& e) {
    return NULL;
  }
}

int main() {
  int i = 6;
  map<string, void*> m;
  m["e"] = &i;
  
  map<string, void*> m1;
  m1["d"] = &m;
  
  map<string, void*> m2;
  m2["c"] = &m1;
  
  map<string, void*> m3;
  m3["b"] = &m2;
  
  map<string, void*> m4;
  m4["a"] = &m3;
  // m4 is the map: {"a": {"b": {"c": {"d": {"e": 6}}}}}

  int* value = extractValue(&m4);
  
  if (value == NULL) {
    cout << "value was null!" << "\n";
  } else {
    cout << "value: " << *value << "\n";
  }
  
}

… but this approach has performance implications in many languages (stack unwinding) and requires us to write a “lot” of boilerplate code. Let’s see if we can come up with something better by thinking with monads! Instead of thinking of bind as a function that takes in a wrapped value and another function, think of bind as a function that takes in a dictionary (a Map, JSON object, or std::map, depending on the language) and a key and produces another dictionary by looking up that key. If the dictionary passed to bind is None (null in Java and JavaScript, NULL in C++), we simply return None again.

import json
jdata = """{"a": {"b": {"c": {"d": {"e": 6}}}}}"""
jdata = json.loads(jdata)


def extract_value(data, ret, bind):
  data = ret(data)
  data = bind(data, "a")
  data = bind(data, "b")
  data = bind(data, "c")
  data = bind(data, "d")
  data = bind(data, "e")
  return data
  
def ret(x):
  return x
  
def bind(data, field):
  if data is None or field not in data:
    return None
  return data[field]
  
print(extract_value(jdata, ret, bind))
    
import java.util.*;
import java.util.function.*;

class Main {
  
  // {"a": {"b": {"c": {"d": {"e": 6}}}}}
  static Record root;

  
  public static void main(String[] args) {
    System.out.println(extractValue(root, Main::ret, Main::bind));
  }
  
  public static String extractValue(Record data, Function<Record, Record> ret, BiFunction<Record, String, Record> bind) {
    
    Record r = ret.apply(data);
    r = bind.apply(r, "a");
    r = bind.apply(r, "b");
    r = bind.apply(r, "c");
    r = bind.apply(r, "d");
    r = bind.apply(r, "e");
    
    return (r == null ? null : r.value);
  }
  
  public static Record ret(Record r) {
    return r;
  }
  
  public static Record bind(Record r, String field) {
    if (r == null || !r.has(field))
      return null;
      
    return r.get(field);
  }
  
  
  static class Record {
    public Record() {
      children = new HashMap<>();
    }
  
    public Map<String, Record> children;
    public String value;
    
    public Record get(String s) { return children.get(s); }
    public boolean has(String s) { return children.containsKey(s); }
  }
  
  static {
    Record r1 = new Record();
    r1.value = "6";
    
    Record r2 = new Record();
    r2.children.put("e", r1);
    
    Record r3 = new Record();
    r3.children.put("d", r2);
    
    Record r4 = new Record();
    r4.children.put("c", r3);
    
    Record r5 = new Record();
    r5.children.put("b", r4);
    
    root = new Record();
    root.children.put("a", r5);
  }
}
const jdata = {a: {b: {c: {d: {e: 6}}}}};

function extractValue(data, ret, bind) {
  data = ret(data);
  data = bind(data, "a");
  data = bind(data, "b");
  data = bind(data, "c");
  data = bind(data, "d");
  data = bind(data, "e");
  return data;
}

function ret(x) {
  return x;
}

function bind(data, field) {
  if (data === null || !(field in data))
    return null;
    
  return data[field];
}

console.log(extractValue(jdata, ret, bind));
#include <iostream>
#include <string>
#include <map>

using namespace std;

map<string, void*>* ret(map<string, void*>* m) {
  return m;
}

map<string, void*>* bind(map<string, void*>* m, string key){
  if (m == NULL || m->find(key) == m->end())
    return NULL;
    
  map<string, void*> x = *m;
  return (map<string, void*>*) x[key];
}


int* extractValue(map<string, void*>* m) {
  auto x = ret(m);
  auto x1 = bind(x,  "a");
  auto x2 = bind(x1, "b");
  auto x3 = bind(x2, "c");
  auto x4 = bind(x3, "d");
  auto x5 = bind(x4, "e");
  return (int*)x5;
}

int main() {
  int i = 6;
  map<string, void*> m;
  m["e"] = &i;
  
  map<string, void*> m1;
  m1["d"] = &m;
  
  map<string, void*> m2;
  m2["c"] = &m1;
  
  map<string, void*> m3;
  m3["b"] = &m2;
  
  map<string, void*> m4;
  m4["a"] = &m3;
  // m4 is the map: {"a": {"b": {"c": {"d": {"e": 6}}}}}

  int* value = extractValue(&m4);
  
  if (value == NULL) {
    cout << "value was null!" << "\n";
  } else {
    cout << "value: " << *value << "\n";
  }
  
}

Now we use several calls to bind in order to access each key in the dictionary. If the key doesn’t exist, bind gives back None (null or NULL in the other languages), and all future calls to bind also give back None. Now the code is much safer, and changing the data or the access path won’t result in a crash. This type of monad (where the value can suddenly become “bad” and all future bind calls return the same “bad” value) is often called the Maybe monad, and it solves the problem.
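The short-circuiting is easiest to see with an input that is missing part of the path. The Python sketch below reuses the ret and bind from above; the `path` parameter and the `broken` input are assumptions added for this illustration (a loop over keys replaces the five explicit bind calls).

```python
def ret(x):
    return x

def bind(data, field):
    # once data is None, or the key is missing, the result stays None
    if data is None or field not in data:
        return None
    return data[field]

def extract_value(data, path):
    data = ret(data)
    for key in path:
        data = bind(data, key)
    return data

good = {"a": {"b": {"c": {"d": {"e": 6}}}}}
broken = {"a": {"b": {}}}  # hypothetical input: the "c" branch is missing

path = ["a", "b", "c", "d", "e"]
good_result = extract_value(good, path)
broken_result = extract_value(broken, path)
print(good_result)    # the nested value is found
print(broken_result)  # the lookup of "c" fails, so every later bind passes None along
```

No exception is raised for the broken input; the "bad" value simply flows through the remaining calls to bind.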

… but we certainly didn’t reduce the amount of boilerplate from the “ask for forgiveness” example! In fact, we’ve added a lot more! Next, let’s use some object-oriented programming know-how to make our monad easier to use.

import json
jdata = """{"a": {"b": {"c": {"d": {"e": 6}}}}}"""
jdata = json.loads(jdata)


class MonadicDict:
    def __init__(self, data):
      self.data = data
    
    def __getitem__(self, field):
      if self.data is None or field not in self.data:
        return MonadicDict(None)
      return MonadicDict(self.data[field])
      
    def get(self):
      return self.data
      

print(MonadicDict(jdata)["a"]["b"]["c"]["d"]["e"].get())
import java.util.*;

class Main {
  
  // {"a": {"b": {"c": {"d": {"e": 6}}}}}
  static Record root;

  
  public static void main(String[] args) {
    System.out.println(extractValue(root));
  }
  
  public static String extractValue(Record data) {
    MonadicRecord mr = new MonadicRecord(data);
    
    return mr.get("a").get("b").get("c").get("d").get("e").getValue();
  }
  
  static class MonadicRecord {
    private Record r;
    public MonadicRecord(Record r) {
      this.r = r;
    }
    
    public MonadicRecord get(String field) {
      if (r == null)
        return new MonadicRecord(null);
        
      return new MonadicRecord(r.get(field));
    }
    
    public String getValue() {
      if (r == null)
        return null;
        
      return r.value;
    }
  }
  
  static class Record {
    public Record() {
      children = new HashMap<>();
    }
  
    public Map<String, Record> children;
    public String value;
    
    public Record get(String s) { return children.get(s); }
    public boolean has(String s) { return children.containsKey(s); }
  }
  
  static {
    Record r1 = new Record();
    r1.value = "6";
    
    Record r2 = new Record();
    r2.children.put("e", r1);
    
    Record r3 = new Record();
    r3.children.put("d", r2);
    
    Record r4 = new Record();
    r4.children.put("c", r3);
    
    Record r5 = new Record();
    r5.children.put("b", r4);
    
    root = new Record();
    root.children.put("a", r5);
  }
}
const jdata = {a: {b: {c: {d: {e: 6}}}}};

class MonadicJSON {
  constructor(data) {
    this.data = data;
  }
  
  getField(x) {
    if (this.data === null)
      return this;
    
    if (x in this.data)
      return new MonadicJSON(this.data[x]);
    
    return new MonadicJSON(null);
  }
}

const mdata = new MonadicJSON(jdata);

const result = mdata.getField("a").getField("b").getField("c").getField("d").getField("e");

console.log(result.data);
#include <iostream>
#include <string>
#include <map>
#include <memory>

using namespace std;


class Record {
public:
  Record(bool isMap, void* ptr) : isMap(isMap), ptr(ptr) { }
  bool isMap;
  void* ptr;
};

class MonadicMap {
  void* m;
  bool isValue;
public:
  MonadicMap(void* m) : m(m) { isValue = false; }
  void* asPtr() { return m; }
  
  MonadicMap operator[](const string& s) {
    if (isValue || m == NULL)
      return MonadicMap(NULL);
      
    map<string, Record*> x = *(map<string, Record*>*)m;
    
    if (x.find(s) == x.end())
      return MonadicMap(NULL);
    
    MonadicMap m = MonadicMap(x[s]->ptr);
    m.isValue = ! x[s]->isMap;
    return m;
  }
};


int* extractValue(map<string, Record*>* data) {
  MonadicMap m(data);
  return (int*) (m["a"]["b"]["c"]["d"]["e"].asPtr());
  
}

int main() {
  
  int i = 6;
  map<string, Record*> m;
  map<string, Record*> m1;
  map<string, Record*> m2;
  map<string, Record*> m3;
  map<string, Record*> m4;

  Record r1(true, &m1);
  Record r2(true, &m2);
  Record r3(true, &m3);
  Record r4(true, &m4);
  Record r5(false, &i);

  m["a"] = &r1;
  m1["b"] = &r2;
  m2["c"] = &r3;
  m3["d"] = &r4;
  m4["e"] = &r5;

  
  // m is the map: {"a": {"b": {"c": {"d": {"e": 6}}}}}

  int* value = extractValue(&m);
  
  if (value == NULL) {
    cout << "value was null!" << "\n";
  } else {
    cout << "value: " << *value << "\n";
  }
  
}

Here, we’ve created a class whose lookup method (__getitem__ in Python, get in Java, getField in JavaScript, and operator[] in C++) does the same thing as bind did in our previous example. Now we can use it (through the [] operator in Python and C++, or through chained get and getField calls in Java and JavaScript) to safely access arbitrary fields in nested data! No stack unwinding, no boilerplate, totally safe.
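As a usage sketch in Python, here is a slight variation on the class applied to the more realistic users -> data -> contact -> phone -> mobile style of path mentioned earlier. The `user` record and its phone number are invented for this example, and two details are additions rather than part of the class above: the `isinstance` check (so indexing past a leaf value is also safe) and the optional `default` argument to get.

```python
class MonadicDict:
    def __init__(self, data):
        self.data = data

    def __getitem__(self, field):
        # addition: treat any non-dict value (None, a string, a number) as a dead end
        if not isinstance(self.data, dict) or field not in self.data:
            return MonadicDict(None)
        return MonadicDict(self.data[field])

    def get(self, default=None):
        # addition: allow a fallback value when the path did not resolve
        return default if self.data is None else self.data

# hypothetical record, shaped like the "real world" path mentioned earlier
user = {"data": {"contact": {"phone": {"mobile": "555-0100"}}}}
md = MonadicDict(user)

mobile = md["data"]["contact"]["phone"]["mobile"].get()
city = md["data"]["address"]["city"].get()
fax = md["data"]["contact"]["phone"]["fax"].get("no fax on file")
print(mobile)
print(city)
print(fax)
```

The lookup code reads the same whether or not the path exists; only the final `get` reveals whether anything was found.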

Purists will raise their voices in anger and insist that this isn’t really a monad (actually, most of the purists have probably made a few complaints already). Depending on how pedantic you want to be, this class may or may not be considered monadic, but it is a useful construction that came about from monadic thinking.
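For readers who want something closer to the ret and bind shape used in the earlier sections, here is a hedged Python sketch in which bind once again takes a function, and the dictionary lookup becomes just one kind of step. The `lookup` helper is a hypothetical name, and using None to stand for the "bad" value is a simplification (it cannot distinguish a missing key from a stored None).

```python
def ret(x):
    # wrapping is trivial here: None plays the role of the "bad" value
    return x

def bind(x, func):
    # skip the function entirely once the value has gone bad
    return None if x is None else func(x)

def lookup(key):
    # hypothetical helper: builds a step that looks up `key` in a dict
    def step(d):
        return d.get(key) if isinstance(d, dict) else None
    return step

data = {"a": {"b": {"c": {"d": {"e": 6}}}}}

x = ret(data)
for key in ["a", "b", "c", "d", "e"]:
    x = bind(x, lookup(key))

y = ret(data)
for key in ["a", "oops", "c"]:
    y = bind(y, lookup(key))

doubled = bind(x, lambda n: n * 2)  # any function can be a step, not only lookups
print(x, y, doubled)
```

Seen this way, the dictionary-specific bind from the previous examples is this general bind with the lookup step baked in.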

How I Learned to Stop Worrying and Love the Monad

Hopefully you have now gained at least a little sympathy for the Haskell purist who laments the lack of first-class support for monads in Python, Java, JavaScript, and C++. In languages like Haskell, applying a monad to a function doesn’t require you to refactor that function into a bunch of calls to bind: every function actually already uses the identity/trivial monad, and you can swap out the implementation anytime you want. At the same time, monadic concepts can be used effectively in many other languages, including popular ones like Python, Java, JavaScript, and C++. While your language may lack “proper” support for monads, don’t be afraid to try “monads as a design pattern” when it seems appropriate.

Finally, try not to become part of the problem. Unless you’re in a context where you know everyone will understand what you are talking about, don’t use the word “endofunctor.” When someone asks what Java’s Optional is, don’t just smugly reply “oh, it’s the Maybe monad.” And, please, please, when someone asks “what’s a monad?”, don’t say “a monad is like a burrito.”