{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Dense.Index
(
Layout
, Shape (..)
, indexIso
, shapeIndexes
, shapeIndexesFrom
, shapeIndexesBetween
, HasLayout (..)
, extent
, size
, indexes
, indexesBetween
, indexesFrom
, ArrayException (IndexOutOfBounds)
, _IndexOutOfBounds
, boundsCheck
, SizeMissmatch (..)
, AsSizeMissmatch (..)
, sizeMissmatch
, showShape
) where
import Control.Applicative
import Control.Exception
import Control.Exception.Lens
import Control.Lens
import Control.Lens.Internal.Getter
import Data.Foldable as F
import Data.Typeable
import Data.Functor.Classes
import Data.Traversable
import Linear
-- | A layout is a shape value holding the size of each dimension: one 'Int'
--   per coordinate of @f@ (e.g. @V2 rows cols@).
type Layout f = f Int
-- | Shapes of dense arrays stored in row-major order: an index @f Int@ can
--   be converted to and from a linear 'Int' offset for a given 'Layout'.
class (Eq1 f, Additive f, Traversable f) => Shape f where
  -- | Row-major linear index of @x@ within the layout @l@; the last
  --   coordinate varies fastest. Assumes @x@ is in range of @l@.
  shapeToIndex :: Layout f -> f Int -> Int
  -- Strict left fold: the accumulator is consumed at every step, so a lazy
  -- fold would only build a chain of thunks.
  shapeToIndex l x = F.foldl' (\k (e, a) -> k * e + a) 0 (liftI2 (,) l x)
  {-# INLINE shapeToIndex #-}

  -- | Convert a linear index back to an index of the layout @l@, peeling
  --   coordinates off starting from the last (fastest varying) one.
  shapeFromIndex :: Layout f -> Int -> f Int
  shapeFromIndex l i = snd $ mapAccumR quotRem i l
  {-# INLINE shapeFromIndex #-}

  -- | Componentwise minimum of two layouts.
  shapeIntersect :: Layout f -> Layout f -> Layout f
  shapeIntersect = liftU2 min
  {-# INLINE shapeIntersect #-}

  -- | Next index in row-major order, with no bounds check. Assumes @x@ is
  --   in range; stepping past the last index gives no meaningful result.
  unsafeShapeStep :: Layout f -> f Int -> f Int
  unsafeShapeStep l
    = shapeFromIndex l
    . (+1)
    . shapeToIndex l
  {-# INLINE unsafeShapeStep #-}

  -- | Next index in row-major order, or 'Nothing' when @x@ was the last
  --   index of the layout. Assumes @x@ is in range.
  shapeStep :: Layout f -> f Int -> Maybe (f Int)
  shapeStep l = fmap (shapeFromIndex l)
              . guardPure (< shapeSize l)
              . (+1)
              . shapeToIndex l
  {-# INLINE shapeStep #-}

  -- | @shapeStepBetween a b x@ is the index following @x@, in row-major
  --   order, inside the box of indexes @y@ with @a <= y < b@
  --   (componentwise), or 'Nothing' once that box is exhausted. Assumes @x@
  --   lies in the box.
  shapeStepBetween :: f Int -> Layout f -> f Int -> Maybe (f Int)
  -- Step in coordinates relative to the lower corner @a@; the box then has
  -- extent @b - a@. Using @b@ itself as the extent would describe the box
  -- @[a, a + b)@ instead, disagreeing with the hand-written 'V2' instance,
  -- which treats @b@ as an absolute exclusive upper bound.
  shapeStepBetween a b = fmap (^+^ a) . shapeStep (b ^-^ a) . (^-^ a)
  {-# INLINE shapeStepBetween #-}

  -- | @shapeInRange l i@ checks @0 <= i < l@ for every coordinate.
  shapeInRange :: Layout f -> f Int -> Bool
  shapeInRange l i = F.and $ liftI2 (\ii li -> ii >= 0 && ii < li) i l
  {-# INLINE shapeInRange #-}

  -- | Number of elements in a layout: the product of its dimensions.
  shapeSize :: Layout f -> Int
  shapeSize = F.product
  {-# INLINE shapeSize #-}
-- | @guardPure p a@ succeeds with @a@ when it satisfies @p@, and is 'empty'
--   otherwise.
guardPure :: Alternative f => (a -> Bool) -> a -> f a
guardPure p a
  | p a       = pure a
  | otherwise = empty
{-# INLINE guardPure #-}
-- | The zero-dimensional shape; every method uses the class default.
instance Shape V0
-- | One-dimensional shapes: an index is its own linear offset.
instance Shape V1 where
  {-# INLINE shapeToIndex #-}
  {-# INLINE shapeFromIndex #-}
  {-# INLINE shapeIntersect #-}
  {-# INLINE shapeStep #-}
  {-# INLINE shapeStepBetween #-}
  {-# INLINE shapeInRange #-}
  shapeToIndex _ (V1 i) = i
  shapeFromIndex _ i = V1 i
  shapeIntersect = min
  shapeStep l = guardPure (shapeInRange l) . (+1)
  -- Advance towards the exclusive upper bound @b@. A single dimension never
  -- wraps, so the lower bound is not needed. (The guard used to be @(> b)@,
  -- which stopped immediately inside the box and never stopped past it.)
  shapeStepBetween _a b i = guardPure (< b) (i + 1)
  shapeInRange m i = i >= 0 && i < m
-- | Two-dimensional shapes in row-major order: the second coordinate varies
--   fastest.
instance Shape V2 where
  shapeToIndex (V2 _rows cols) (V2 r c) = cols * r + c
  {-# INLINE shapeToIndex #-}

  shapeFromIndex (V2 _rows cols) n = uncurry V2 (n `quotRem` cols)
  {-# INLINE shapeFromIndex #-}

  -- Move along the row; on reaching its end, wrap to the start of the
  -- next row; after the last row there is nothing left.
  shapeStep (V2 rows cols) (V2 r c)
    | c' < cols = Just (V2 r c')
    | r' < rows = Just (V2 r' 0)
    | otherwise = Nothing
    where
      r' = r + 1
      c' = c + 1
  {-# INLINE shapeStep #-}

  -- Same walk without the final bound check.
  unsafeShapeStep (V2 _ cols) (V2 r c)
    | c' < cols = V2 r c'
    | otherwise = V2 (r + 1) 0
    where
      c' = c + 1
  {-# INLINE unsafeShapeStep #-}

  -- Walk the box with exclusive upper corner @(rEnd, cEnd)@; wrapping
  -- returns to the box's first column @c0@.
  shapeStepBetween (V2 _r0 c0) (V2 rEnd cEnd) (V2 r c)
    | c' < cEnd = Just (V2 r c')
    | r' < rEnd = Just (V2 r' c0)
    | otherwise = Nothing
    where
      r' = r + 1
      c' = c + 1
  {-# INLINE shapeStepBetween #-}
-- | Three-dimensional shapes in row-major order: the last coordinate varies
--   fastest.
instance Shape V3 where
  shapeStep (V3 x y z) (V3 i j k)
    | k + 1 < z = Just (V3 i j (k + 1))
    | j + 1 < y = Just (V3 i (j + 1) 0)
    | i + 1 < x = Just (V3 (i + 1) 0 0)
    | otherwise = Nothing
  {-# INLINE shapeStep #-}

  -- Walk the box whose exclusive upper corner is @(ib, jb, kb)@, wrapping
  -- each coordinate back to the box's lower corner. The comparisons use
  -- @+ 1@ so the step never lands on the (excluded) upper bound, matching
  -- 'shapeStep' and the 'V2' instance.
  shapeStepBetween (V3 _ia ja ka) (V3 ib jb kb) (V3 i j k)
    | k + 1 < kb = Just (V3 i j (k + 1))
    | j + 1 < jb = Just (V3 i (j + 1) ka)
    | i + 1 < ib = Just (V3 (i + 1) ja ka)
    | otherwise  = Nothing
  {-# INLINE shapeStepBetween #-}
-- | Four-dimensional shapes in row-major order: the last coordinate varies
--   fastest.
instance Shape V4 where
  shapeStep (V4 x y z w) (V4 i j k l)
    | l + 1 < w = Just (V4 i j k (l + 1))
    | k + 1 < z = Just (V4 i j (k + 1) 0)
    | j + 1 < y = Just (V4 i (j + 1) 0 0)
    | i + 1 < x = Just (V4 (i + 1) 0 0 0)
    | otherwise = Nothing
  {-# INLINE shapeStep #-}

  -- Walk the box whose exclusive upper corner is @(ib, jb, kb, lb)@,
  -- wrapping each coordinate back to the box's lower corner. The @+ 1@
  -- keeps the step off the (excluded) upper bound, matching 'shapeStep'
  -- and the 'V2' instance.
  shapeStepBetween (V4 _ia ja ka la) (V4 ib jb kb lb) (V4 i j k l)
    | l + 1 < lb = Just (V4 i j k (l + 1))
    | k + 1 < kb = Just (V4 i j (k + 1) la)
    | j + 1 < jb = Just (V4 i (j + 1) ka la)
    | i + 1 < ib = Just (V4 (i + 1) ja ka la)
    | otherwise  = Nothing
  {-# INLINE shapeStepBetween #-}
-- | Conversion between an index and its linear offset for a layout:
--   'shapeToIndex' one way and 'shapeFromIndex' the other. It round-trips
--   for indexes within the layout.
indexIso :: Shape f => Layout f -> Iso' (f Int) Int
indexIso l = iso toLinear fromLinear
  where
    toLinear   = shapeToIndex l
    fromLinear = shapeFromIndex l
{-# INLINE indexIso #-}
-- | Structures that carry a 'Layout' for the shape @f@. The functional
--   dependency @a -> f@ lets the structure type determine its shape type.
class Shape f => HasLayout f a | a -> f where
  -- | Lens onto the layout of a structure.
  layout :: Lens' a (Layout f)
  -- When the structure is itself a shape value (@a ~ f Int@), it is its own
  -- layout, so the default lens is the identity.
  default layout :: (a ~ f Int) => (Layout f -> g (Layout f)) -> a -> g a
  layout = id
  {-# INLINE layout #-}
-- A shape value is its own layout, so for each vector type 'layout' is the
-- identity lens. The @i ~ Int@ context lets instance selection succeed
-- before the element type is known.
instance i ~ Int => HasLayout V0 (V0 i) where
  layout = id
  {-# INLINE layout #-}

instance i ~ Int => HasLayout V1 (V1 i) where
  layout = id
  {-# INLINE layout #-}

instance i ~ Int => HasLayout V2 (V2 i) where
  layout = id
  {-# INLINE layout #-}

instance i ~ Int => HasLayout V3 (V3 i) where
  layout = id
  {-# INLINE layout #-}

instance i ~ Int => HasLayout V4 (V4 i) where
  layout = id
  {-# INLINE layout #-}
-- | The layout (size of each dimension) of a structure.
extent :: HasLayout f a => a -> f Int
extent x = x ^. layout
{-# INLINE extent #-}
-- | Number of elements described by a structure's layout.
size :: HasLayout f a => a -> Int
size x = shapeSize (x ^. layout)
{-# INLINE size #-}
-- | Indexed fold over every index of a structure's layout, as enumerated
--   by 'shapeIndexes'.
indexes :: HasLayout f a => IndexedFold Int a (f Int)
indexes k = layout (shapeIndexes k)
{-# INLINE indexes #-}
-- | Indexed fold over every index of a layout in row-major order, each
--   paired with its linear index. A layout with a zero (or negative)
--   dimension has no indexes; the zero-dimensional 'V0' layout has exactly
--   one, in agreement with @'shapeSize' V0 == 1@.
shapeIndexes :: Shape f => IndexedFold Int (Layout f) (f Int)
shapeIndexes g l = go (0 :: Int) start where
  -- The origin is an index of the layout exactly when every dimension is
  -- positive. Testing only @l == zero@ is not enough: a layout such as
  -- @V2 3 0@ is empty too, yet would otherwise yield @V2 0 0@, @V2 1 0@, ...
  start = if shapeInRange l zero then Just zero else Nothing
  -- 'shapeStep' walks in row-major order, so a running counter equals the
  -- linear index of each visited index.
  go i (Just x) = indexed g i x *> go (i + 1) (shapeStep l x)
  go _ Nothing  = noEffect
{-# INLINE shapeIndexes #-}
-- | Indexed fold over the indexes of a structure's layout starting at
--   @a@, as enumerated by 'shapeIndexesFrom'.
indexesFrom :: HasLayout f a => f Int -> IndexedFold Int a (f Int)
indexesFrom a k = layout (shapeIndexesFrom a k)
{-# INLINE indexesFrom #-}
-- | Indexed fold over the indexes of a layout starting at @a@: this is
--   'shapeIndexesBetween' with the layout itself as the upper bound. Each
--   index is paired with its linear index in the layout.
shapeIndexesFrom :: Shape f => f Int -> IndexedFold Int (Layout f) (f Int)
shapeIndexesFrom start k lay = shapeIndexesBetween start lay k lay
{-# INLINE shapeIndexesFrom #-}
-- | Indexed fold over the indexes between @lo@ and @hi@ in a structure's
--   layout, as enumerated by 'shapeIndexesBetween'; each index is paired
--   with its linear index in the full layout.
indexesBetween :: HasLayout f a => f Int -> f Int -> IndexedFold Int a (f Int)
indexesBetween lo hi k = layout (shapeIndexesBetween lo hi k)
{-# INLINE indexesBetween #-}
-- | Indexed fold over the box of indexes @x@ with @a <= x < b@
--   (componentwise), visited in row-major order. Each index is paired with
--   its linear index in the original layout, not in the box.
--
--   The upper bound is clamped to the layout, so passing the layout itself
--   as @b@ folds from @a@ to the end of every dimension. The fold is empty
--   when the box is empty (@a_i >= b_i@ in some dimension) or when @a@ is
--   not an index of the layout.
shapeIndexesBetween :: Shape f => f Int -> f Int -> IndexedFold Int (Layout f) (f Int)
shapeIndexesBetween a b g l = go start where
  -- Exclusive upper corner of the box, never beyond the layout.
  hi = shapeIntersect l b
  -- Extent of the box relative to its lower corner @a@.
  box = hi ^-^ a
  -- @a@ is in range of @hi@ iff @0 <= a@ and @a < min l b@ componentwise,
  -- i.e. the box is non-empty and lies inside the layout.
  start = if shapeInRange hi a then Just a else Nothing
  -- Walk the box with the row-major 'shapeStep' on box-relative
  -- coordinates, then shift back by @a@.
  next x = (^+^ a) <$> shapeStep box (x ^-^ a)
  go (Just x) = indexed g (shapeToIndex l x) x *> go (next x)
  go Nothing  = noEffect
{-# INLINE shapeIndexesBetween #-}
-- | @boundsCheck l i@ is 'id' when @i@ is in range of the layout @l@;
--   otherwise it throws 'IndexOutOfBounds' with a message showing both the
--   index and the layout.
boundsCheck :: Shape l => Layout l -> l Int -> a -> a
boundsCheck l i
  | not (shapeInRange l i) = throwing _IndexOutOfBounds msg
  | otherwise              = id
  where
    msg = concat ["(", showShape i, ", ", showShape l, ")"]
{-# INLINE boundsCheck #-}
-- | Exception for sizes that were required to agree but did not. The
--   'String' carries extra context and may be empty.
data SizeMissmatch = SizeMissmatch String
  deriving Typeable

instance Exception SizeMissmatch

-- | Renders as @size missmatch@, followed by @: <context>@ when the context
--   is non-empty. The precedence argument is ignored.
instance Show SizeMissmatch where
  showsPrec _ (SizeMissmatch msg)
    | null msg  = showString "size missmatch"
    | otherwise = showString "size missmatch" . showString ": " . showString msg
-- | Types that may contain a 'SizeMissmatch', exposing its context string.
class AsSizeMissmatch t where
  _SizeMissmatch :: Prism' t String

-- | A 'SizeMissmatch' always matches; its context string is the focus.
instance AsSizeMissmatch SizeMissmatch where
  _SizeMissmatch = prism' SizeMissmatch unwrap
    where
      unwrap (SizeMissmatch msg) = Just msg
  {-# INLINE _SizeMissmatch #-}

-- | Match a 'SomeException' that wraps a 'SizeMissmatch'.
instance AsSizeMissmatch SomeException where
  _SizeMissmatch = exception . inner
    where
      -- Pin the inner prism to 'SizeMissmatch' so 'exception' knows which
      -- exception type to look for.
      inner :: Prism' SizeMissmatch String
      inner = _SizeMissmatch
  {-# INLINE _SizeMissmatch #-}
-- | @sizeMissmatch m n msg@ is 'id' when @m == n@. Otherwise it throws a
--   'SizeMissmatch' carrying @msg@, via lens's 'throwing'.
sizeMissmatch :: Int -> Int -> String -> a -> a
sizeMissmatch m n msg
  | m /= n    = throwing _SizeMissmatch msg
  | otherwise = id
{-# INLINE sizeMissmatch #-}
-- | Render a shape value as @V<n>@ followed by its coordinates, e.g.
--   @"V2 3 4"@ for @V2 3 4@.
showShape :: Shape f => f Int -> String
showShape l = "V" ++ show (length xs) ++ " " ++ unwords (map show xs)
  where
    xs = F.toList l