Compositional zooming for StateT and ReaderT using lens

Tuesday, 04 September 2018, by Edsko de Vries.
Filed under coding.

Consider writing updates in a state monad where the state contains deeply nested structures. As our running example we will consider a state containing multiple “wallets”, where each wallet has multiple “accounts”, and each account has multiple “addresses”. Suppose we want to write an update that changes one of the fields in a particular address. If the address cannot be found, we want a precise error message that distinguishes between the address itself not being found and one of its parents (the account, or the wallet) not being found. Without the help of suitable abstractions, we might end up writing something monstrous like:

setUsed :: AddrId -> Update UnknownAddr DB ()
setUsed addrId@(accId@(walletId, accIx), addrIx) = do
    db <- get
    -- find the wallet
    case Map.lookup walletId db of
      Nothing ->
        throwError $ UnknownAddrParent
                   $ UnknownAccParent
                   $ UnknownWalletId walletId
      Just wallet ->
        -- find the account
        case Map.lookup accIx wallet of
          Nothing ->
            throwError $ UnknownAddrParent
                       $ UnknownAccId accId
          Just acc ->
            -- find the address
            case Map.lookup addrIx acc of
              Nothing ->
                throwError $ UnknownAddrId addrId
              Just (addr, _isUsed) -> do
                let acc'    = Map.insert addrIx (addr, True) acc
                    wallet' = Map.insert accIx acc' wallet
                    db'     = Map.insert walletId wallet' db
                put db'

In the remainder of this blog post we will show how we can develop some composable abstractions that will allow us to rewrite this as

setUsed :: AddrId -> Update UnknownAddr DB ()
setUsed addrId =
    zoomAddress id addrId $
      modify $ \(addr, _isUsed) -> (addr, True)

for an appropriate definition of zoomAddress given later.

Zooming

To obtain compositionality, we want to be able to lift updates on a smaller context (such as a particular wallet) to a larger context (the entire state). In order to do that, we will need a way to get the smaller context from the larger, and to be able to lift modifications of the smaller context to modifications of the larger context. This is of course precisely the definition of a lens, and so we arrive at the following signature:

zoom :: Lens' st st' -> State st' a -> State st  a

For the purposes of the first part of this blog post we will define State in a somewhat unusual way as

newtype Result a st = Result { getResult :: (a, st) }
type State st a = st -> Result a st

It will become evident why we choose this definition soon; for now, if you squint a bit you can hopefully see that this is equivalent to the state monad we all know and love. A somewhat naive way to write zoom is

zoom :: Lens' st st' -> State st' a -> State st  a
zoom l f large = fmap updSmall $ f (large ^. l)
  where
    updSmall small' = large & l .~ small'

This definition clearly demonstrates what we said above: we use the lens to first get the small state from the large, run the update on that smaller state, and finally use the lens once more to update the larger state with the new value of the smaller state, relying on the fact that Result is a Functor.

If we are using lenses in Van Laarhoven representation, however, we can actually write this in a more direct way. Expanding synonyms, we get

zoom :: (forall f. Functor f => (st' -> f st')  
                             -> (st  -> f st))
     -> (st' -> Result a st')
     -> (st  -> Result a st)

Note how if we take advantage of our somewhat unusual representation of the state monad, we can instantiate f to Result a, so that lens already gives us precisely what we need! In other words, we can rewrite zoom as simply

zoom :: Lens' st st' -> State st' a -> State st  a
zoom = id
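To make this concrete, here is a minimal sketch that can be loaded into GHCi without the lens package. The Lens' synonym is written out by hand, and firstL and tick are hypothetical names introduced only for this illustration. The sketch uses the eta-expanded form zoom l = l, on the assumption that newer GHC releases with simplified subsumption may reject zoom = id as written.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Hand-rolled Van Laarhoven lens type; stands in for Lens' from the lens package.
type Lens' s a = forall f. Functor f => (a -> f a) -> s -> f s

-- Hypothetical lens onto the first component of a pair.
firstL :: Lens' (a, b) a
firstL f (a, b) = fmap (\a' -> (a', b)) (f a)

newtype Result a st = Result { getResult :: (a, st) }

-- Mapping over the state component is what lets (Result a) play the role of f.
instance Functor (Result a) where
  fmap g (Result (a, st)) = Result (a, g st)

type State st a = st -> Result a st

-- Eta-expanded form of zoom = id: the lens is instantiated at f = Result a.
zoom :: Lens' st st' -> State st' a -> State st a
zoom l = l

-- Hypothetical update: return the old counter value and increment the counter.
tick :: State Int Int
tick n = Result (n, n + 1)
```

Evaluating getResult (zoom firstL tick (1, "x")) runs tick on the first component only and leaves the second component untouched.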

Dealing with failure

In order to deal with missing values, we need a variation on zoom:

zoomM :: Lens' st (Maybe st') -> State st' a -> State st (Maybe a)

We can write this in a naive way again, being very explicit about what’s happening:

zoomM :: Lens' st (Maybe st') -> State st' a -> State st (Maybe a)
zoomM l f large =
    case large ^. l of
      Nothing    -> Result (Nothing, large)
      Just small -> bimap Just (updSmall . Just) $ f small
  where
    updSmall small' = large & l .~ small'

As before, we first use the lens to get the smaller state from the larger. This may now fail; if it does, we return Nothing as the result along with the unchanged state. If the smaller state does exist, we run the update on that smaller state, and then wrap its result in Just; this relies on the fact that Result is a Bifunctor. In case you haven’t seen that class before, it’s the “obvious” generalization of Functor to datatypes with two type arguments:

class Bifunctor p where
  bimap :: (a -> b) -> (c -> d) -> p a c -> p b d  

The instance for Result is easy:

instance Bifunctor Result where
  bimap f g (Result (a, st)) = Result (f a, g st)

As before, however, we can use the lens in a more direct way. Expanding synonyms once again, we get:

zoomM :: (forall f. Functor f => (Maybe st' -> f (Maybe st'))
                              -> (st -> f st))
      -> (st' -> Result a st')
      -> (st  -> Result (Maybe a) st)

If we line up the result of the lens with the result we want from zoomM, we see that we must pick Result (Maybe a) for f; all that remains is writing a suitable wrapper:

liftMaybe :: Biapplicative p
          => (st -> p a st) -> Maybe st -> p (Maybe a) (Maybe st)
liftMaybe _ Nothing   = bipure Nothing Nothing
liftMaybe f (Just st) = bimap Just Just $ f st

This relies on Result being Biapplicative, which is again the “obvious” generalization of Applicative to datatypes with two arguments:

class Bifunctor p => Biapplicative p where
  bipure  :: a -> b -> p a b
  (<<*>>) :: p (a -> b) (c -> d) -> p a c -> p b d

The instance for Result is again straightforward:

instance Biapplicative Result where
  bipure a st = Result (a, st)
  Result (f, g) <<*>> Result (a, st) = Result (f a, g st)

This out of the way, we can now define zoomM as

zoomM :: Lens' st (Maybe st') -> State st' a -> State st (Maybe a)
zoomM l = l . liftMaybe
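The following self-contained sketch shows the pieces of this section working together on a Data.Map. It avoids the lens and bifunctors packages: Lens' is written out by hand, atKey is a hypothetical stand-in for the at combinator, and the local Biapplicative class is cut down to bipure, which is all liftMaybe needs.

```haskell
{-# LANGUAGE RankNTypes #-}

import Data.Bifunctor (Bifunctor (..))
import qualified Data.Map as Map

-- Hand-rolled Van Laarhoven lens type; stands in for Lens' from lens.
type Lens' s a = forall f. Functor f => (a -> f a) -> s -> f s

-- Hypothetical stand-in for lens's at: Nothing deletes the key, Just sets it.
atKey :: Ord k => k -> Lens' (Map.Map k v) (Maybe v)
atKey k f m = fmap set (f (Map.lookup k m))
  where
    set Nothing  = Map.delete k m
    set (Just v) = Map.insert k v m

newtype Result a st = Result { getResult :: (a, st) }

instance Functor (Result a) where
  fmap g (Result (a, st)) = Result (a, g st)

instance Bifunctor Result where
  bimap f g (Result (a, st)) = Result (f a, g st)

-- Cut-down local version of Biapplicative (bipure only).
class Bifunctor p => Biapplicative p where
  bipure :: a -> b -> p a b

instance Biapplicative Result where
  bipure a st = Result (a, st)

type State st a = st -> Result a st

liftMaybe :: Biapplicative p
          => (st -> p a st) -> Maybe st -> p (Maybe a) (Maybe st)
liftMaybe _ Nothing   = bipure Nothing Nothing
liftMaybe f (Just st) = bimap Just Just $ f st

zoomM :: Lens' st (Maybe st') -> State st' a -> State st (Maybe a)
zoomM l = l . liftMaybe

-- Hypothetical update: return the old counter value and increment the counter.
tick :: State Int Int
tick n = Result (n, n + 1)
```

Zooming to a key that is present runs tick on its value and wraps the result in Just; zooming to a missing key returns Nothing and leaves the map unchanged, because writing Nothing back through atKey deletes a key that was not there in the first place.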

Generalizing

So far we have been using a non-standard definition of the state monad. In this section we will see how we can avoid doing that and, more importantly, how we can write our zooming combinators in such a way that they can be used also in the reader monad.

Let’s define a monad for updates and a monad for queries using the standard monad transformers:

newtype Update e st a = Update {
    runUpdate :: StateT st (Except e) a
  }
  deriving ( Functor, Applicative
           , Monad, MonadState st, MonadError e )

newtype Query e st a = Query {
    runQuery :: ReaderT st (Except e) a
  }
  deriving ( Functor, Applicative
           , Monad, MonadReader st, MonadError e )

We want to be able to “zoom” in either of these two monads. We saw above that the key to being able to use the lens directly is the ability to express our update as a function

st -> f st

for some suitable functor f. For zoom we picked Result a, for zoomM we picked Result (Maybe a). The choice of Result, however, was specific to our concrete definition of State. If we want to generalize, we need to generalize away from this type:

class Biapplicative (Result z) => Zoomable z where
  type Result z :: * -> * -> *

  wrap   :: (st -> Result z a st) -> z st a
  unwrap :: z st a -> (st -> Result z a st)

In this type class we introduce a type family Result that we can instantiate to different types for different monads; wrap and unwrap are necessary because unlike our bespoke State monad definition above, the conventional definition of the state monad is isomorphic, but not equal, to a function from a state to a state. We saw above why we need Result z to be Biapplicative.

Zoomable instance for Update

In order to be able to define a Zoomable instance for Update, we need to introduce a type that captures the result of an update:

newtype UpdResult e a st = UpdResult {
    getUpdResult :: Except e (a, st)
  }

Defining the Zoomable instance for Update is now easy:

instance Zoomable (Update e) where
  type Result (Update e) = UpdResult e

  wrap   = coerce
  unwrap = coerce

Note that wrap and unwrap are simply coerce; in other words, they exist only to satisfy the type checker, but have no runtime cost.

Zoomable instance for Query

The nice thing is that we can just as easily give a Zoomable instance for Query. The only difference is that the result of the query does not have a final state:

newtype QryResult e a st = QryResult {
    getQryResult :: Except e a
  }

The Zoomable instance is just as simple:

instance Zoomable (Query e) where
  type Result (Query e) = QryResult e

  wrap   = coerce
  unwrap = coerce

Functor from Bifunctor

If we now try to define zoom for any Zoomable monad, we find that we get stuck very quickly: in order to be able to apply the lens, we need Result z a to be a functor; but all we know is that Result z is a bifunctor. Starting from ghc 8.6 we could use quantified constraints and write

class ( Biapplicative (Result z)
      , forall a. Functor (Result z a)
      )        
   => Zoomable z where (..)

to insist that Result z a must be a functor for any choice of a. We could also add a Functor (Result z a) constraint to the type of zoom itself, but this gives zoom a more messy signature than it needs to have.

If we want to be compatible with older versions of ghc but still keep the nicer signature, we can take advantage of the fact that if a datatype is a bifunctor it must also be a functor:

newtype FromBi p a st = WrapBi { unwrapBi :: p a st }

instance Bifunctor p => Functor (FromBi p a) where
  fmap f (WrapBi x) = WrapBi (second f x)

Generalizing the zoom operators

We now have everything we need to give the generalized definitions of the zoom operators. In fact, the definition is almost dictated by the types:

zoom :: Zoomable z => Lens' st st' -> z st' a -> z st  a
zoom l k = wrap $ \st -> unwrapBi $ l (WrapBi . unwrap k) st

Although this looks more complicated than the definition we had before, note that

   zoom l k
 -- definition
== wrap $ \st -> unwrapBi $ l (WrapBi . unwrap k) st
 -- wrap and unwrap are both 'coerce'
== \st -> unwrapBi $ l (WrapBi . k) st
 -- unwrapBi and WrapBi are just newtype wrappers
== \st -> l k st
 -- eta-reduce
== l k

In other words, modulo newtype wrapping, we still have zoom = id. The definition of zoomM is similar to what we had above also:

zoomM :: Zoomable z
      => Lens' st (Maybe st')
      -> z st' a
      -> z st (Maybe a)
zoomM l k = wrap $ \st -> unwrapBi $
              l (WrapBi . liftMaybe (unwrap k)) st

The proof that this is equivalent to simply l (liftMaybe k) is left as a simple exercise for the reader.

Finally, we can define a useful variation on zoomM that uses a fallback when the smaller context was not found:

zoomDef :: (Zoomable z, Monad (z st))
        => Lens' st (Maybe st')
        -> z st  a -- ^ When not found
        -> z st' a -- ^ When found
        -> z st  a
zoomDef l def k = zoomM l k `catchNothing` def

where

catchNothing :: Monad m => m (Maybe a) -> m a -> m a
catchNothing act fallback = act >>= maybe fallback return
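Putting the generalized pieces together, the sketch below is one way to exercise zoomDef in both Update and Query. It is an illustration under several assumptions rather than the code from the wallet kernel: it uses only transformers and containers, Lens' and atKey are hand-rolled stand-ins for the lens equivalents, the local Biapplicative class is cut down to bipure, and the Bifunctor and Biapplicative instances for UpdResult and QryResult are the obvious candidates, which the text above does not spell out. The names bumpAt, readAt, runUpdate and runQuery are hypothetical.

```haskell
{-# LANGUAGE FlexibleContexts, GeneralizedNewtypeDeriving, RankNTypes, TypeFamilies #-}

import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Except (Except, runExcept, throwE)
import Control.Monad.Trans.Reader (ReaderT (..), ask)
import Control.Monad.Trans.State.Strict (StateT (..), modify)
import Data.Bifunctor (Bifunctor (..))
import Data.Coerce (coerce)
import Data.Kind (Type)
import qualified Data.Map as Map

-- Hand-rolled stand-ins for Lens' and at from the lens package.
type Lens' s a = forall f. Functor f => (a -> f a) -> s -> f s

atKey :: Ord k => k -> Lens' (Map.Map k v) (Maybe v)
atKey k f m = fmap (maybe (Map.delete k m) (\v -> Map.insert k v m)) (f (Map.lookup k m))

-- Cut-down local Biapplicative: only bipure is needed for zoomM.
class Bifunctor p => Biapplicative p where
  bipure :: a -> b -> p a b

newtype Update e st a = Update (StateT st (Except e) a)
  deriving (Functor, Applicative, Monad)

newtype Query e st a = Query (ReaderT st (Except e) a)
  deriving (Functor, Applicative, Monad)

newtype UpdResult e a st = UpdResult (Except e (a, st))
newtype QryResult e a st = QryResult (Except e a)

-- Assumed instances: map over the pair inside Except.
instance Bifunctor (UpdResult e) where
  bimap f g (UpdResult x) = UpdResult (fmap (bimap f g) x)
instance Biapplicative (UpdResult e) where
  bipure a st = UpdResult (pure (a, st))

-- A query result carries no state, so the state component is ignored.
instance Bifunctor (QryResult e) where
  bimap f _ (QryResult x) = QryResult (fmap f x)
instance Biapplicative (QryResult e) where
  bipure a _ = QryResult (pure a)

class Biapplicative (Result z) => Zoomable z where
  type Result z :: Type -> Type -> Type
  wrap   :: (st -> Result z a st) -> z st a
  unwrap :: z st a -> (st -> Result z a st)

instance Zoomable (Update e) where
  type Result (Update e) = UpdResult e
  wrap   = coerce
  unwrap = coerce

instance Zoomable (Query e) where
  type Result (Query e) = QryResult e
  wrap   = coerce
  unwrap = coerce

newtype FromBi p a st = WrapBi { unwrapBi :: p a st }

instance Bifunctor p => Functor (FromBi p a) where
  fmap f (WrapBi x) = WrapBi (second f x)

liftMaybe :: Biapplicative p => (st -> p a st) -> Maybe st -> p (Maybe a) (Maybe st)
liftMaybe _ Nothing   = bipure Nothing Nothing
liftMaybe f (Just st) = bimap Just Just (f st)

zoomM :: Zoomable z => Lens' st (Maybe st') -> z st' a -> z st (Maybe a)
zoomM l k = wrap $ \st -> unwrapBi $ l (WrapBi . liftMaybe (unwrap k)) st

zoomDef :: (Zoomable z, Monad (z st)) => Lens' st (Maybe st') -> z st a -> z st' a -> z st a
zoomDef l def k = zoomM l k >>= maybe def return

-- Hypothetical demo operations over a map of counters.
bumpAt :: Int -> Update String (Map.Map Int Int) ()
bumpAt key = zoomDef (atKey key) (Update (lift (throwE "missing"))) (Update (modify (+ 1)))

readAt :: Int -> Query String (Map.Map Int Int) Int
readAt key = zoomDef (atKey key) (Query (lift (throwE "missing"))) (Query ask)

runUpdate :: Update e st a -> st -> Either e (a, st)
runUpdate (Update m) = runExcept . runStateT m

runQuery :: Query e st a -> st -> Either e a
runQuery (Query m) = runExcept . runReaderT m
```

The same zoomDef serves both monads: bumpAt increments a counter in place and readAt reads one, and each throws "missing" when the key is absent.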

Using the combinators

We will now go back to the example from the introduction and show how we can write some domain-specific zoom operators using the building blocks we just defined.

Setup

The example is a state consisting of multiple wallets, where each wallet has multiple accounts, and each account has multiple addresses. For the sake of this blog post it doesn’t really matter what “wallets”, “accounts” and “addresses” are, and we will model them very simply as

type DB      = Map WalletId Wallet
type Wallet  = Map AccIx    Account
type Account = Map AddrIx   Address
type Address = (String, Bool)

The top-level state is a mapping from wallet IDs to wallets, but a wallet is a mapping from account indices to accounts. The reason for the difference is that we will reserve the term account ID for the combination of a wallet ID and an account index, and similarly for addresses:

type AccIx  = Int
type AddrIx = Int

type WalletId = Int
type AccId    = (WalletId, AccIx)
type AddrId   = (AccId, AddrIx)

Finally, the requirements stated that we wanted to distinguish between, say, an address not found because although the account exists, it doesn’t have that particular address, and an address not found because its enclosing account (or indeed wallet) does not exist:

data UnknownWallet = UnknownWalletId   WalletId
data UnknownAcc    = UnknownAccId      AccId    
                   | UnknownAccParent  UnknownWallet
data UnknownAddr   = UnknownAddrId     AddrId   
                   | UnknownAddrParent UnknownAcc

Zooming

Ok, definitions done, we can now define our zoom combinators. Our initial attempt might be something like

zoomWallet :: WalletId
           -> Update e Wallet a
           -> Update e DB     a

If the wallet ID was not found, however, we want to be able to throw an UnknownWallet error. We could change the signature to

zoomWallet :: WalletId
           -> Update UnknownWallet Wallet a
           -> Update UnknownWallet DB     a

but now we cannot use zoomWallet for updates with a richer error type. A better solution is to take as an argument a function that allows us to embed the UnknownWallet error into e:

zoomWallet :: (UnknownWallet -> e)
           -> WalletId
           -> Update e Wallet a
           -> Update e DB     a
zoomWallet embedErr walletId k =
    zoomDef (at walletId)
            (throwError $ embedErr (UnknownWalletId walletId)) $
      k

The definition is pleasantly straightforward. We use the at combinator from lens to give us a lens into the map, and then use zoomDef with a fallback that throws the error to complete the definition.

Composition

In order to show that our new combinators are compositional we should be able to define zoomAccount in terms of zoomWallet, and indeed we can:

zoomAccount :: (UnknownAcc -> e)
            -> AccId
            -> Update e Account a
            -> Update e DB      a
zoomAccount embedErr accId@(walletId, accIx) k =
    zoomWallet (embedErr . UnknownAccParent) walletId $
      zoomDef (at accIx)
              (throwError $ embedErr (UnknownAccId accId)) $
        k

Composing the zoom combinators is effectively lens composition, which is taking care of getting the account from the DB by first getting the wallet and then the account in one direction, and updating the DB by first lifting the update on the account to an update on the wallet, and then to an update on the DB itself.

The “embed error” argument is helping with compositionality also: zoomAccount needs its embedErr to embed UnknownAcc into e, but when it calls zoomWallet it composes embedErr with UnknownAccParent to embed UnknownWallet into e.

The definition for address follows the exact same pattern:

zoomAddress :: (UnknownAddr -> e)
            -> AddrId
            -> Update e Address a
            -> Update e DB      a
zoomAddress embedErr addrId@(accId, addrIx) k =
    zoomAccount (embedErr . UnknownAddrParent) accId $
      zoomDef (at addrIx)
              (throwError $ embedErr (UnknownAddrId addrId)) $
        k

so that we can now write the definition we promised in the introduction:

setUsed :: AddrId -> Update UnknownAddr DB ()
setUsed addrId =
    zoomAddress id addrId $
      modify $ \(addr, _isUsed) -> (addr, True)

Iteration

There is one additional zoom operator that is very useful to define. Suppose we want to clear out all wallets. If we tried to write this with the combinators we have so far, we would end up with something like

emptyAllWallets :: Update UnknownWallet DB ()
emptyAllWallets = do
    walletIds <- gets Map.keys
    forM_ walletIds $ \walletId ->
      zoomWallet id walletId $
        put Map.empty

We get all wallet IDs, then zoom to each wallet in turn and empty it. However, notice the signature: it indicates that emptyAllWallets may throw an UnknownWallet error, but it never will! After all, we just read all wallet IDs, so we know for a fact that they must be present. One “solution” is to do something like

emptyAllWallets :: Update e DB ()
emptyAllWallets = do
    walletIds <- gets Map.keys
    forM_ walletIds $ \walletId ->
      zoomWallet (\_err -> error "can't happen") walletId $
        put Map.empty

but we can do much better: we need a zoom operator that gives us iteration.

Traversals

In the world of lens, iteration is captured by a Traversal'. Compare the synonyms:

type Lens'      st st' = forall f. Functor     f => (st' -> f st')
                                                 -> (st  -> f st)
type Traversal' st st' = forall f. Applicative f => (st' -> f st')
                                                 -> (st  -> f st)

A traversal will apply its argument to all occurrences of the smaller state; in order to patch the results back together it needs f to be Applicative rather than merely a Functor.
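As a small illustration of why Applicative is needed, the sketch below writes a traversal by hand and runs it with the pair functor from base, whose Applicative instance combines the first components monoidally. The names both and logAndDouble are hypothetical, and Traversal' is written out by hand rather than imported from lens.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Hand-rolled traversal type; stands in for Traversal' from the lens package.
type Traversal' s a = forall f. Applicative f => (a -> f a) -> s -> f s

-- Hypothetical traversal visiting both components of a homogeneous pair.
-- It needs <*> to stitch the two visited results back together.
both :: Traversal' (a, a) a
both f (x, y) = (,) <$> f x <*> f y

-- Record the value that was visited and replace it by its double.
-- The pair functor ((,) [Int]) accumulates the log using the list monoid.
logAndDouble :: Int -> ([Int], Int)
logAndDouble n = ([n], n * 2)
```

Here both logAndDouble (1, 2) visits the two components in order, collects the old values in the log, and rebuilds the pair from the doubled values. The standard traverse function has the same shape and can be used as a Traversal' over any Traversable container.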

Applicative from Biapplicative

Remember that the f we’re using in Zoomable is the Result z type family, which we know to be Biapplicative. We showed above that we can easily derive Functor from Bifunctor; deriving Applicative from Biapplicative, however, is not so easy! Let’s see what we need to do:

instance Biapplicative p => Applicative (FromBi p a) where
  pure st     = WrapBi $ bipure _e st
  fun <*> arg = WrapBi $ bimap _c ($) (unwrapBi fun)
                           <<*>> unwrapBi arg

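There are two problematic holes in this definition: in pure, the hole _e must have type a, but we have no value of that type to offer; in <*>, the hole _c must have type a -> a -> a, since it has to combine the result carried by the function with the result carried by the argument.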

The usual solution to this problem is to require a to be a monoid. Then we can use mempty for the absence of a result, and mappend to combine results:

instance ( Biapplicative p
         , Monoid a
         )
      => Applicative (FromBi p a) where
  pure st     = WrapBi $ bipure mempty st
  fun <*> arg = WrapBi $ bimap mappend ($) (unwrapBi fun)
                           <<*>> unwrapBi arg

Zooming

We can now define zoomAllM:

zoomAllM :: (Zoomable z, Monoid a)
         => Traversal' st st' -> z st' a -> z st a
zoomAllM l k = wrap $ \st -> unwrapBi $ l (WrapBi . unwrap k) st

Apart from the signature, the body of this function is literally identical to zoom, and is therefore also equivalent to simply id. Mind-blowing stuff.
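To see the monoid-based Applicative instance at work, here is a sketch of zoomAllM specialized to the bespoke State type from the first part of the post, so that no wrap or unwrap is needed. Traversal' and Biapplicative are defined locally rather than imported from lens and bifunctors, and tick is a hypothetical update used only for illustration.

```haskell
{-# LANGUAGE RankNTypes #-}

import Data.Bifunctor (Bifunctor (..))

-- Hand-rolled traversal type; stands in for Traversal' from lens.
type Traversal' s a = forall f. Applicative f => (a -> f a) -> s -> f s

newtype Result a st = Result { getResult :: (a, st) }

instance Bifunctor Result where
  bimap f g (Result (a, st)) = Result (f a, g st)

-- Local copy of the Biapplicative class from the bifunctors package.
infixl 4 <<*>>
class Bifunctor p => Biapplicative p where
  bipure  :: a -> b -> p a b
  (<<*>>) :: p (a -> b) (c -> d) -> p a c -> p b d

instance Biapplicative Result where
  bipure a st = Result (a, st)
  Result (f, g) <<*>> Result (a, st) = Result (f a, g st)

newtype FromBi p a st = WrapBi { unwrapBi :: p a st }

instance Bifunctor p => Functor (FromBi p a) where
  fmap f (WrapBi x) = WrapBi (second f x)

-- mempty stands for "no result yet"; mappend combines results left to right.
instance (Biapplicative p, Monoid a) => Applicative (FromBi p a) where
  pure st     = WrapBi $ bipure mempty st
  fun <*> arg = WrapBi $ bimap mappend ($) (unwrapBi fun)
                           <<*>> unwrapBi arg

type State st a = st -> Result a st

-- zoomAllM for the bespoke State: the traversal is used at f = FromBi Result a.
zoomAllM :: Monoid a => Traversal' st st' -> State st' a -> State st a
zoomAllM l k = unwrapBi . l (WrapBi . k)

-- Hypothetical update: log the old counter value and increment the counter.
tick :: State Int [Int]
tick n = Result ([n], n + 1)
```

Running getResult (zoomAllM traverse tick [1, 2, 3]) applies tick to every element, concatenates the individual logs in traversal order, and returns the list of incremented counters as the new state.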

We can define two useful wrappers for zoomAllM with slightly simpler types. The first is just a synonym which can be used when we don’t want to accumulate any results:

zoomAll_ :: Zoomable z => Traversal' st st' -> z st' () -> z st ()
zoomAll_ = zoomAllM

This works because () is trivially a monoid. Finally we can define a wrapper that accumulates results in a list:

zoomAll :: Zoomable z => Traversal' st st' -> z st' a -> z st [a]
zoomAll l k = wrap $ \st -> unwrapBi $
                l (WrapBi . first (:[]) . unwrap k) st

We could have defined zoomAll in terms of zoomAllM if we insist that z st' is a Functor; by unfolding the definition we can take advantage of the fact that Result z is a bifunctor and we keep the signature clean.

Usage example

The example function we were considering was one that cleared out all wallets. With our new combinators in hand, this is now trivial:

emptyAllWallets :: Update e DB ()
emptyAllWallets = zoomAll_ traverse $ put Map.empty

Conclusions

As Haskell programmers, compositionality is one of our most treasured principles. The ability to build larger components from smaller, and reason about larger components by reasoning about the smaller, is crucial to productivity and clean, maintainable code. When dealing with large states (for example, in an acid-state database), lenses are a powerful tool that can be used to lift operations on parts of the state to the whole state. In this blog post we defined some reusable combinators that can be used both in updates and in queries; they are used extensively in the design of the new Cardano wallet kernel.

Postscript: zoom from Control.Lens.Zoom

The lens library itself also defines a zoom operator. It has the same purpose as the zoom operator we defined here, but generalizes over the underlying monad in a different way (allowing for deeply nested occurrences of StateT in a monad stack), and is not applicable to the reader monad (the equivalent for the reader monad is magnify). However, if compatibility with ReaderT is not required then it is also possible to define zoom, zoomDef, and zoomAll in terms of the lens operator; domain specific combinators like zoomWallet can then be defined just like we have done here.