-----------------------------------------------------------------------------
--
-- Module      :  Test.Tasty.TH
-- Copyright   :  Oscar Finnsson, Benno Fünfstück
-- License     :  BSD3
--
-- Maintainer  :  Benno Fünfstück
-- Stability   :
-- Portability :
--
--
-----------------------------------------------------------------------------
{-# LANGUAGE TemplateHaskell #-}

-- | This module provides TemplateHaskell functions to automatically generate
-- tasty TestTrees from specially named functions. See the README of the package
-- for examples.
--
-- Important: due to the GHC staging restriction, you must put any uses of these
-- functions at the end of the file, or you may get errors due to missing definitions.
module Test.Tasty.TH
  ( testGroupGenerator
  , defaultMainGenerator
  , testGroupGeneratorFor
  , defaultMainGeneratorFor
  , extractTestFunctions
  , locationModule
  ) where

import Control.Applicative
import Control.Monad (join)
import Data.Data (gmapQ, Data)
import qualified Data.Foldable as F
import Data.List (nub, isPrefixOf, find)
import Data.Maybe
import Data.Typeable (cast)
import System.IO (IOMode (ReadMode), hGetContents, hSetEncoding, utf8, withFile)
import Prelude

import Language.Haskell.Exts (parseFileContentsWithMode)
import Language.Haskell.Exts.Parser (ParseResult(..), defaultParseMode, parseFilename)
import qualified Language.Haskell.Exts.Syntax as S
import Language.Haskell.TH

import Test.Tasty

-- | Splice producing a complete @main@ action: it applies 'defaultMain' to the
-- 'TestTree' that 'testGroupGenerator' builds for the current module.
--
-- Example usage:
--
-- @
-- -- properties, test cases, ....
--
-- main :: IO ()
-- main = $('defaultMainGenerator')
-- @
defaultMainGenerator :: ExpQ
defaultMainGenerator = appE (varE 'defaultMain) testGroupGenerator

-- | Splice producing a 'TestTree' built from the specially named top-level
-- definitions of the current module. The resulting group carries the name of
-- the current module.
--
-- Definitions are picked up by name prefix:
--
-- * @test_something@ becomes a nested 'testGroup' called "something"
-- * @prop_something@ becomes a QuickCheck-style property called "something"
-- * @case_something@ becomes a HUnit-style assertion test called "something"
--
-- Further underscores are turned into spaces, so @prop_foo_bar@ is named
-- "foo bar". The candidate names are found by reading the module's own source
-- file (see 'extractTestFunctions').
--
-- Example usage:
--
-- @
-- prop_example :: Int -> Int -> Bool
-- prop_example a b = a + b == b + a
--
-- tests :: 'TestTree'
-- tests = $('testGroupGenerator')
-- @
testGroupGenerator :: ExpQ
testGroupGenerator = join (liftA2 testGroupGeneratorFor currentModule discovered)
 where
  -- name of the module containing the splice
  currentModule = loc_module <$> location
  -- test-like names found in that module's source file
  discovered = runIO . extractTestFunctions . loc_filename =<< location

-- | Retrieves all function names from the given file that would be discovered
-- by 'testGroupGenerator'.
--
-- The result contains each name once, grouped by prefix: first all @prop_@
-- names, then @case_@, then @test_@, each group in the order the names were
-- found in the file.
--
-- The file is always decoded as UTF-8 (the encoding GHC uses for Haskell
-- source), independent of the process locale. It is read completely before
-- its handle is closed.
extractTestFunctions :: FilePath -> IO [String]
extractTestFunctions filePath = do
  file <- readSourceUtf8 filePath
  -- we first try to parse the file using haskell-src-exts
  -- if that fails, we fallback to lexing each line, which is less
  -- accurate but is more reliable (haskell-src-exts sometimes struggles
  -- with less-common GHC extensions).
  let functions = fromMaybe (lexed file) (parsed file)
      filtered pat = filter (pat `isPrefixOf`) functions
  return . nub $ concat [filtered "prop_", filtered "case_", filtered "test_"]
 where
  -- Plain 'readFile' would decode with the locale encoding (failing on
  -- non-ASCII source under e.g. LANG=C) and keep the handle open lazily.
  -- Forcing the length reads everything before 'withFile' closes the handle.
  readSourceUtf8 :: FilePath -> IO String
  readSourceUtf8 path = withFile path ReadMode $ \h -> do
    hSetEncoding h utf8
    contents <- hGetContents h
    length contents `seq` return contents

  -- Fallback: the first lexeme of every line.
  lexed :: String -> [String]
  lexed = map fst . concatMap lex . lines

  -- Names bound by the module's top-level declarations, or 'Nothing' if
  -- haskell-src-exts cannot parse the file.
  parsed :: String -> Maybe [String]
  parsed src =
    case parseFileContentsWithMode (defaultParseMode { parseFilename = filePath }) src of
      ParseOk parsedModule -> Just (declarations parsedModule)
      ParseFailed _ _ -> Nothing

  -- Only an ordinary 'S.Module' contributes declarations.
  declarations :: Data l => S.Module l -> [String]
  declarations (S.Module _ _ _ _ decls) = concatMap testFunName decls
  declarations _ = []

  -- Names bound by a pattern binding or a function binding.
  testFunName :: Data l => S.Decl l -> [String]
  testFunName (S.PatBind _ pat _ _) = patternVariables pat
  testFunName (S.FunBind _ clauses) = nub (map clauseName clauses)
  testFunName _ = []

  -- The function name defined by one clause, prefix or infix.
  clauseName :: S.Match l -> String
  clauseName (S.Match _ name _ _ _) = nameString name
  clauseName (S.InfixMatch _ _ name _ _ _) = nameString name

-- | Plain text of a haskell-src-exts name, whether it is an ordinary
-- identifier or an operator symbol.
nameString :: S.Name l -> String
nameString name = case name of
  S.Ident _ str -> str
  S.Symbol _ str -> str

-- | Find all variables that are bound in the given pattern.
--
-- A 'S.PVar' contributes its own name, and an as-pattern contributes its
-- alias plus the variables of the inner pattern. For any other constructor,
-- the immediate children are inspected generically with 'gmapQ'. Recursion
-- goes into children that are patterns, and into children that are lists of
-- patterns, so tuple, list and constructor-application patterns are covered.
-- Children of any other type (e.g. record field lists) are not inspected.
patternVariables :: Data l => S.Pat l -> [String]
patternVariables = go
 where
  go (S.PVar _ name) = [nameString name]
  -- the alias is a 'S.Name' child, which the generic case would not see
  go (S.PAsPat _ name inner) = nameString name : go inner
  go pat = concat (gmapQ (\child -> F.foldMap go (cast child) ++ F.foldMap goList (cast child)) pat)

  -- Variables of a list of sub-patterns. The comprehension also pins the
  -- cast target above to a list of patterns.
  goList ps = [v | p <- ps, v <- go p]

-- | Splice evaluating to a string literal holding the name of the module
-- that is being compiled where the splice appears.
locationModule :: ExpQ
locationModule = fmap (LitE . StringL . loc_module) location

-- | Like 'testGroupGenerator', but generates a test group only including the specified function names.
-- The function names still need to follow the pattern of starting with one of @prop_@, @case_@ or @test_@;
-- names with any other prefix are skipped.
--
-- Each accepted name @fname@ becomes @runner "pretty name" fname@. Here @runner@ is
-- @testProperty@, @testCase@ or @testGroup@ depending on the prefix. Both @runner@ and
-- @fname@ are built with 'mkName', so they are resolved in the scope of the splice.
testGroupGeneratorFor
  :: String   -- ^ The name of the test group itself
  -> [String] -- ^ The names of the functions which should be included in the test group
  -> ExpQ
testGroupGeneratorFor name functionNames =
  [| testGroup name $(listE (mapMaybe toTest functionNames)) |]
 where
  -- Name prefix paired with the tasty combinator that wraps such a definition.
  runners = [("prop_", "testProperty"), ("case_", "testCase"), ("test_", "testGroup")]

  -- Combinator for the first prefix that the name starts with, if any.
  runnerFor fname = snd <$> find (\(prefix, _) -> prefix `isPrefixOf` fname) runners

  -- Test expression for one name, or Nothing when its prefix is not recognised.
  toTest fname = case runnerFor fname of
    Nothing -> Nothing
    Just runner ->
      Just ((varE (mkName runner) `appE` stringE (fixName fname)) `appE` varE (mkName fname))

-- | Like 'defaultMainGenerator', but only includes the specific function names in the test group.
-- The function names still need to follow the pattern of starting with one of @prop_@, @case_@ or @test_@.
-- The generated expression is 'defaultMain' applied to the group from 'testGroupGeneratorFor'.
defaultMainGeneratorFor
  :: String   -- ^ The name of the top-level test group
  -> [String] -- ^ The names of the functions which should be included in the test group
  -> ExpQ
defaultMainGeneratorFor name fns = varE 'defaultMain `appE` testGroupGeneratorFor name fns

-- | Turn a test function name into a human-readable test name.
--
-- Everything up to and including the first @_@ (the @prop_@ / @case_@ /
-- @test_@ prefix) is dropped, and every remaining @_@ becomes a space, so
-- @fixName "prop_foo_bar" == "foo bar"@. A name without any @_@ yields the
-- empty string instead of crashing (the previous 'tail' was partial here).
fixName :: String -> String
fixName = replace '_' ' ' . drop 1 . dropWhile (/= '_')

-- | @replace old new xs@ replaces every element of @xs@ equal to @old@ with @new@.
replace :: Eq a => a -> a -> [a] -> [a]
replace old new = map (\x -> if old == x then new else x)
i)