State monad

2022-02-06

This post explores the state monad in Haskell. Don’t forget to check out the previous post about Functor map.

Context

Let’s consider this counter implementation in Elixir:

defmodule MyApp.Counter do
  use GenServer
  def handle_call({:increment, n}, _from, state) do
    state = state + n
    {:reply, state, state}
  end
end

This module doesn’t own its state. Instead, it receives the current state as an argument and returns a new state on each invocation. Conceptually, GenServer passes the state along with a loop like this:

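# Simplified sketch of the loop, not the actual GenServer source
defp receive_loop(state) do
  new_state =
    receive do
      {from, msg} ->
        {:reply, reply, next_state} = handle_call(msg, from, state)
        send(from, reply)
        next_state
    end
  receive_loop(new_state)
end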

This model, where a function takes its state as an input and returns a new state, can also be implemented in Python like this:

def increment(state):
    state = state + 1
    return (state, state)

(value, my_counter) = increment(0)
(value, my_counter) = increment(my_counter)
(value, my_counter) = increment(my_counter)

Notice how such stateless functions return both an output value and a new state. The output value does not have to be the state itself, for example:

def show_counter(state):
    # The state is returned unchanged; only the output value differs
    return ("Counter is at %d" % state, state)

print(show_counter(my_counter))
# Shows ('Counter is at 3', 3)

Because the state cannot be mutated from outside the function’s scope, functional programs are often easier to follow and tend to contain fewer bugs.
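Passing my_counter around by hand quickly becomes repetitive. As a rough sketch of where this is heading, the threading can be factored out into a helper. Note that run_steps is a hypothetical name introduced here for illustration, it is not part of any library:

```python
def increment(state):
    new_state = state + 1
    return (new_state, new_state)

def show_counter(state):
    return ("Counter is at %d" % state, state)

def run_steps(steps, state):
    """Hypothetical helper: run each step in order, feeding it the
    state returned by the previous step. Returns the last value and
    the final state."""
    value = None
    for step in steps:
        value, state = step(state)
    return (value, state)

result = run_steps([increment, increment, increment, show_counter], 0)
print(result)
```

The steps themselves never mention my_counter anymore; only the helper knows how the state travels from one call to the next.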

Let’s dive into the State Monad, and how it can be used for this pattern.

StateT

Start a REPL by running ghci (if missing, run sudo dnf install ghc ghc-mtl-devel).

λ> import Control.Monad.State
λ> :t StateT
StateT :: (s -> m (a, s)) -> StateT s m a

StateT has three type variables: s is the type of the state, m is the underlying context in which the computation runs, and a is the type of the returned value. Thus we can define increment and show_counter like this:

λ> type Counter a = StateT Int IO a
λ> increment    = modify (+1) >> get                         :: Counter Int
λ> show_counter = mappend "Counter is at " . show <$> get    :: Counter String

And finally we can write the three increments test as:

λ> test_counter = increment >> increment >> increment >> show_counter
λ> :t test_counter
test_counter :: Counter String

… that we can evaluate like this:

λ> :t runStateT
runStateT :: StateT s m a -> s -> m (a, s)
λ> runStateT test_counter 0
("Counter is at 3",3)
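runStateT returns both the value and the final state. Control.Monad.State also exports evalStateT, which keeps only the value, and execStateT, which keeps only the final state. To relate this to the earlier Python counter, here is a minimal sketch that ignores the m layer (no IO); the names run_state, eval_state and exec_state are made up for this illustration:

```python
def show_counter(state):
    return ("Counter is at %d" % state, state)

def run_state(step, initial):
    """Rough analogue of runStateT: return (value, final_state)."""
    return step(initial)

def eval_state(step, initial):
    """Rough analogue of evalStateT: keep only the value."""
    value, _final = step(initial)
    return value

def exec_state(step, initial):
    """Rough analogue of execStateT: keep only the final state."""
    _value, final = step(initial)
    return final

both = run_state(show_counter, 3)
only_value = eval_state(show_counter, 3)
only_state = exec_state(show_counter, 3)
print(both, only_value, only_state)
```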

StateT step by step

So how does that work? Let’s have a closer look, step by step.

We first defined a Counter a type alias for StateT Int IO a. This type alias indicates that Counter is a computation happening in a StateT context, with a state of type Int (the counter), running in an IO context. In other words, to obtain the a of a Counter a, we need to use runStateT to perform the computation.

The benefit is that the increment and show_counter implementations do not have to deal with the state explicitly. The StateT context provides a few helpers:

λ> :t modify
modify :: MonadState s m => (s -> s) -> m ()
λ> :t get
get :: MonadState s m => m s

This is possible because StateT has an instance of MonadState. This post is not going to try to explain what a Monad is, but here is its definition:

class Applicative m => Monad m where
  (>>=)  :: m a -> (a -> m b) -> m b
  (>>)   :: m a -> m b -> m b
  return :: a -> m a

Or in other words, a Monad is an Applicative context (which is where pure comes from) that also provides the >>= operator, also known as bind (or flat_map or then).

Let’s break down the increment definition.

λ> :t modify (+1)
modify (+1)        :: (MonadState s m, Num s) => m ()
λ> :t modify (+1) >> get
modify (+1) >> get :: (MonadState s m, Num s) => m s

This sequencing pattern is so powerful that it has its own syntax, called do notation. Here is the increment function defined using do notation:

increment :: Counter Int
increment = do
  modify (+1) -- increment the counter
  get         -- return the current state

To enter a multiline definition in the REPL, first enter :{, then paste the code, and finish with :}.

Similarly, the show_counter can be defined as:

show_counter :: Counter String
show_counter = do
  value <- get
  pure $ "Counter is at " <> show value

And finally, the test_counter:

test_counter :: Counter String
test_counter = do
  increment
  increment
  increment
  show_counter

This test_counter definition is equivalent to the previous version we saw. It relies on the StateT context to thread the state value in and out of each step.

Recap

First, we saw how to implement an immutable counter in Python using the same design as Elixir’s GenServer. This can be done using a function that takes the current state s and returns a new state along with the output value a.

Then, we observed that this is actually the definition of Haskell’s StateT: s -> m (a, s) (a function that takes an s and returns an (a, s) inside a context m).

Finally, we re-implemented the counter examples using MonadState, which takes care of passing the state around and simplifies the implementation down to its essential parts.

Motivating example

Such abstractions are general purpose, and Monad can be used for other things. In the previous post we saw the Functor map and how traverse can be used to penetrate nested structures. Well, we can also use traverse with StateT. Let’s consider a new function to add a number to our counter:

λ> add = modify . (+) :: Int -> Counter ()

… which can be used like this:

λ> flip execStateT 0 $ add 5
5

Then we can use traverse to apply the add function to a list of numbers:

λ> flip execStateT 0 $ traverse add [5, 11, 11, 15]
42
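For comparison with the Python counter from the beginning of the post, here is a rough sketch of what traverse does in this case: it builds one stateful step per list element and runs them left to right, threading the state along. The traverse and exec_state functions below are simplified stand-ins written for this post, specialised to lists and to state functions without the IO layer:

```python
def add(n):
    """Stateful step that adds n to the state and returns no value."""
    return lambda state: (None, state + n)

def traverse(make_step, items):
    """Simplified stand-in for traverse: apply make_step to every item,
    run the resulting steps in order, and collect their values."""
    def run(state):
        values = []
        for item in items:
            value, state = make_step(item)(state)
            values.append(value)
        return (values, state)
    return run

def exec_state(step, initial):
    """Keep only the final state, like execStateT."""
    _values, final = step(initial)
    return final

total = exec_state(traverse(add, [5, 11, 11, 15]), 0)
print(total)
```

The collected values are all None here, which mirrors the list of () that the Haskell version produces before execStateT discards it.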

Cheers o/

#haskell #blog