{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}

-- |
-- Module      : Data.Rhythm.Markov
-- Description : Transition matrices and Markov chains
-- Copyright   : (c) Eric Bailey, 2025
--
-- License     : MIT
-- Maintainer  : eric@ericb.me
-- Stability   : stable
-- Portability : POSIX
--
-- Generating random numbers using a Markov chain.
module Data.Rhythm.Markov
  ( TransitionMatrix (..),
    SomeTransitionMatrix (..),
    markovGen,
    markovGen',
    someTransitionMatrix,
  )
where

import Control.Arrow (second)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Loops (unfoldrM)
import Data.Finite (Finite, finite, getFinite)
import Data.Functor ((<&>))
import Data.List (intercalate)
import Data.Maybe (fromJust, fromMaybe)
import Data.Proxy (Proxy (..))
import Data.Vector.Sized (Vector)
import Data.Vector.Sized qualified as VS
import GHC.Generics (Generic)
import GHC.IsList (IsList (..))
import GHC.TypeNats (KnownNat, Nat, SomeNat (..), natVal, someNatVal)
import Slist (len)
import System.Random (randomIO)
import Text.Printf (printf)
import Text.Trifecta (Parser, count, decimal, double, newline)

-- $setup
-- >>> import Data.Ix (inRange)

-- | A square \(n \times n\) matrix of transition probabilities. Row @i@ holds
-- the weights for moving out of state @i@; column @j@ is the destination
-- state @j@.
--
-- For example, here is a @'TransitionMatrix' 3@:
--
-- \[
--   \begin{bmatrix}
--     0.1 & 0.6 & 0.3 \\
--     0.4 & 0.4 & 0.2 \\
--     0.3 & 0.3 & 0.4
--   \end{bmatrix}
-- \]
newtype TransitionMatrix (n :: Nat)
  = TransitionMatrix {unTransitionMatrix :: Vector n (Vector n Double)}
  deriving (Generic, Eq)

-- | Renders the matrix one row per line. Entries are printed with six decimal
-- places and separated by single spaces. There is no trailing newline.
instance Show (TransitionMatrix n) where
  show :: TransitionMatrix n -> String
  show (TransitionMatrix rows) =
    intercalate "\n" [renderRow row | row <- VS.toList rows]
    where
      renderRow row = unwords [printf "%.6f" entry | entry <- VS.toList row]

-- | A square 'TransitionMatrix' whose size is only known at runtime. Pattern
-- matching on the constructor brings the size back into scope as a
-- 'KnownNat' constraint.
data SomeTransitionMatrix
  = forall n. (KnownNat n) => SomeTransitionMatrix (TransitionMatrix n)

-- | Shows the wrapped matrix exactly as its own 'Show' instance would.
instance Show SomeTransitionMatrix where
  show :: SomeTransitionMatrix -> String
  show wrapped = case wrapped of
    SomeTransitionMatrix inner -> show inner

-- | Convert between a list of rows and a 'SomeTransitionMatrix'.
--
-- 'fromList' takes the matrix size from the number of rows. It calls 'error'
-- with @"Invalid transition matrix"@ unless every row has exactly that many
-- entries. 'toList' returns the rows in order.
instance IsList SomeTransitionMatrix where
  type Item SomeTransitionMatrix = [Double]

  fromList :: [Item SomeTransitionMatrix] -> SomeTransitionMatrix
  fromList rows =
    case someNatVal (fromIntegral (len (fromList rows))) of
      SomeNat (_ :: Proxy size) ->
        fromMaybe (error "Invalid transition matrix") $ do
          -- Each row must have exactly @size@ entries ...
          vectorRows <- traverse VS.fromList rows
          -- ... and there must be exactly @size@ rows.
          square <- VS.fromList @size vectorRows
          pure (SomeTransitionMatrix (TransitionMatrix square))

  toList :: SomeTransitionMatrix -> [Item SomeTransitionMatrix]
  toList (SomeTransitionMatrix (TransitionMatrix square)) =
    [VS.toList row | row <- VS.toList square]

-- | Parse a square 'TransitionMatrix' of unknown size.
--
-- The input is a decimal dimension @n@ followed by a newline, then @n@ rows
-- of @n@ 'double' entries each, in row-major order.
someTransitionMatrix :: Parser SomeTransitionMatrix
someTransitionMatrix = do
  dimension <- fromInteger <$> decimal
  _ <- newline
  entries <- count dimension (count dimension double)
  pure (fromList entries)

-- | Generate random numbers using a Markov chain.
--
-- @'markovGen' matrix s n@ starts in state @s@ and takes @n@ random steps. It
-- returns the state reached after each step; the start state itself is not
-- included. States are numbered from @0@.
--
-- Calls 'fail' if @s@ is not a state of @matrix@ or if @n@ is negative.
-- Without these checks, such input would crash with 'error' or an arithmetic
-- underflow.
--
-- >>> let matrix = fromList [[0.1,0.6,0.3],[0.4,0.4,0.2],[0.3,0.3,0.4]]
-- >>> let numbers = markovGen matrix 1 10
-- >>> (== 10) . length <$> numbers
-- True
-- >>> all (inRange (0,2)) <$> numbers
-- True
--
-- See 'markovGen''.
markovGen ::
  (MonadIO m, MonadFail m) =>
  SomeTransitionMatrix ->
  Integer ->
  Integer ->
  m [Integer]
markovGen (SomeTransitionMatrix (matrix :: TransitionMatrix states)) s n
  | s < 0 || s >= numStates =
      fail $
        "markovGen: start state "
          ++ show s
          ++ " is not in [0, "
          ++ show numStates
          ++ ")"
  | n < 0 =
      fail $ "markovGen: number of steps must be non-negative, got " ++ show n
  | otherwise =
      -- Lift the runtime step count to the type level for 'markovGen''.
      case someNatVal (fromIntegral n) of
        SomeNat (_ :: Proxy steps) ->
          markovGen' @states @steps matrix (finite @states s)
            <&> map getFinite . VS.toList
  where
    -- Number of states in the matrix; valid start states are @[0, numStates)@.
    numStates = toInteger (natVal (Proxy @states))

-- | Run a Markov chain for a type-level number of @steps@.
--
-- Starting from @start@, each step does two things:
--
-- * draws a uniform random 'Double';
-- * picks the next state from the current state's row of the matrix, with
--   probability proportional to the row's entries (assumed non-negative).
--
-- The result holds the state reached after each step; @start@ itself is not
-- included.
--
-- Rows need not sum to exactly @1@. The draw is scaled by the row's total, so
-- floating-point rounding in the entries cannot cut the chain short. Calls
-- 'fail' if the current row has no positive entry.
--
-- See 'markovGen'.
markovGen' ::
  forall n steps m.
  (KnownNat n, KnownNat steps, MonadIO m, MonadFail m) =>
  TransitionMatrix n ->
  Finite n ->
  m (Vector steps (Finite n))
markovGen' (TransitionMatrix matrix) start = do
  path <- unfoldrM step (natVal (Proxy @steps), start)
  -- 'step' either emits exactly @steps@ states or fails, so the 'fail' branch
  -- here is unreachable. It replaces a partial 'fromJust'.
  maybe (fail "markovGen': generated path has the wrong length") pure $
    VS.fromList @steps path
  where
    -- One transition: stop when no steps remain; otherwise sample a successor
    -- of the current state and count down.
    step :: (Nat, Finite n) -> m (Maybe (Finite n, (Nat, Finite n)))
    step (0, _) = pure Nothing
    step (remaining, current) = do
      p <- liftIO randomIO
      case pickNext p (VS.index matrix current) of
        Just next -> pure (Just (next, (remaining - 1, next)))
        Nothing ->
          fail $
            "markovGen': row "
              ++ show (getFinite current)
              ++ " of the transition matrix has no positive entry"

    -- Choose the successor for a uniform draw @p@, assumed to lie in @[0, 1]@:
    -- the first state whose cumulative weight exceeds @p@ times the row total.
    pickNext :: Double -> Vector n Double -> Maybe (Finite n)
    pickNext p row =
      case VS.findIndex (> p * total) cumulative of
        Just next -> Just next
        -- Reached when @p * total@ is not below the total (e.g. @p == 1@) or
        -- the row has no positive mass. Fall back to the last state that can
        -- actually be reached, if any.
        Nothing -> fst <$> VS.find ((> 0) . snd) (VS.reverse (VS.indexed row))
      where
        cumulative = VS.tail (VS.scanl (+) 0 row)
        -- Same left-to-right summation as the last element of 'cumulative'.
        total = VS.foldl' (+) 0 row