{-# LANGUAGE RecordWildCards #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module Internal
(
TcPlugin.newWanted
, newGiven
, evByFiat
, lookupModule
, lookupName
, tracePlugin
, flattenGivens
, mkSubst
, mkSubst'
, substType
, substCt
)
where
import GHC.Driver.Config.Finder (initFinderOpts)
import GHC.Tc.Plugin (TcPluginM, lookupOrig, tcPluginTrace)
import qualified GHC.Tc.Plugin as TcPlugin
(newWanted, getTopEnv, tcPluginIO, findImportedModule)
import GHC.Tc.Types (TcPlugin(..), TcPluginSolveResult(..))
import Control.Arrow (first, second)
import Data.Function (on)
import Data.List (groupBy, partition, sortOn)
import GHC.Tc.Utils.TcType (TcType)
import Data.Maybe (mapMaybe)
import GhcApi.Constraint (Ct(..))
import GhcApi.GhcPlugins
import Internal.Type (substType)
import Internal.Constraint (newGiven, flatToCt, mkSubst, overEvidencePredType)
import Internal.Evidence (evByFiat)
-- | Find a module by name.
--
-- The lookup first asks 'findPluginModule', using the finder cache, finder
-- options, unit state and (optional) home unit taken from the current
-- 'HscEnv'.  If that does not produce a 'Found' result, it retries with
-- 'TcPlugin.findImportedModule', qualifying the request with the home unit
-- ('ThisPkg') when one is available and with 'NoPkgQual' otherwise.
--
-- If neither lookup produces a 'Found' result, the function panics via
-- 'panicDoc' with the message @"Couldn't find module"@.
lookupModule :: ModuleName -- ^ Name of the module to find
             -> FastString -- ^ Package name; not consulted by this implementation
             -> TcPluginM Module
lookupModule mod_nm _pkg = do
  hsc_env <- TcPlugin.getTopEnv
  let mhome_unit = hsc_home_unit_maybe hsc_env
  found_module <- TcPlugin.tcPluginIO $
    findPluginModule
      (hsc_FC hsc_env)
      (initFinderOpts (hsc_dflags hsc_env))
      (hsc_units hsc_env)
      mhome_unit
      mod_nm
  case found_module of
    Found _ h -> return h
    _         -> do
      -- Fall back to the ordinary import lookup, restricted to the home
      -- unit when there is one.
      let pkg_qual = maybe NoPkgQual (ThisPkg . homeUnitId) mhome_unit
      found_module' <- TcPlugin.findImportedModule mod_nm pkg_qual
      case found_module' of
        Found _ h -> return h
        _         -> panicDoc "Couldn't find module" (ppr mod_nm)
-- | Find a 'Name' in a 'Module' given an 'OccName'.
--
-- This is a direct alias for 'lookupOrig'.
lookupName :: Module -> OccName -> TcPluginM Name
lookupName = lookupOrig
-- | Wrap a 'TcPlugin' so that its initialisation, every solver run, and its
-- shutdown each emit a message through 'tcPluginTrace'.
--
-- The 'String' argument is appended to every trace header so that the
-- output of different plugins can be told apart.  The rewriter field is
-- passed through unchanged, and the solver's result is returned unmodified
-- after being traced.
tracePlugin :: String -> TcPlugin -> TcPlugin
tracePlugin s TcPlugin{..} =
  TcPlugin
    { tcPluginInit    = traceInit
    , tcPluginSolve   = traceSolve
    , tcPluginRewrite = tcPluginRewrite
    , tcPluginStop    = traceStop
    }
  where
    -- The helpers below are deliberately left without type signatures:
    -- the plugin state type is existentially bound by the record pattern
    -- above and cannot be named in a local signature.
    traceInit =
      tcPluginTrace ("tcPluginInit " ++ s) empty >> tcPluginInit

    traceStop z =
      tcPluginTrace ("tcPluginStop " ++ s) empty >> tcPluginStop z

    traceSolve z ev given wanted = do
      tcPluginTrace ("tcPluginSolve start " ++ s)
        (  text "given =" <+> ppr given
        $$ text "wanted =" <+> ppr wanted
        )
      r <- tcPluginSolve z ev given wanted
      case r of
        TcPluginOk solved new ->
          tcPluginTrace ("tcPluginSolve ok " ++ s)
            (  text "solved =" <+> ppr solved
            $$ text "new =" <+> ppr new
            )
        TcPluginContradiction bad ->
          tcPluginTrace ("tcPluginSolve contradiction " ++ s)
            (text "bad =" <+> ppr bad)
        -- General case: report all three components of the result.
        TcPluginSolveResult bad solved new ->
          tcPluginTrace ("tcPluginSolveResult " ++ s)
            (  text "solved =" <+> ppr solved
            $$ text "bad =" <+> ppr bad
            $$ text "new =" <+> ppr new
            )
      return r
-- | Rewrite a list of given constraints using the substitution derived from
-- those same givens.
--
-- The substitution entries produced by @mkSubst'@ are sorted and grouped by
-- their type variable.  Groups with two or more entries for the same
-- variable are handed to 'flatToCt' (results of 'Nothing' are dropped);
-- the remaining single-entry groups are collected into a plain
-- substitution that is applied to every original given with @substCt@.
--
-- The result is the constraints obtained from 'flatToCt' followed by the
-- substituted givens.
flattenGivens :: [Ct] -> [Ct]
flattenGivens givens =
  mapMaybe flatToCt flat ++ map (substCt subst') givens
  where
    subst = mkSubst' givens

    -- flat:   groups in which one type variable has several bindings
    -- subst': the bindings of all variables that are bound exactly once
    (flat, subst')
      = second (map fst . concat)
      $ partition ((>= 2) . length)
      $ groupBy ((==) `on` (fst . fst))
      $ sortOn (fst . fst) subst
-- | Build a substitution from a list of constraints, pairing every binding
-- with the constraint it came from.
--
-- Constraints for which 'mkSubst' returns 'Nothing' are skipped.  The
-- remaining bindings are combined with a right fold so that the bindings
-- are substituted into one another (see @substSubst@).
mkSubst' :: [Ct] -> [((TcTyVar, TcType), Ct)]
mkSubst' = foldr substSubst [] . mapMaybe mkSubst
  where
    -- Add one binding to an already-built substitution: the accumulated
    -- substitution is applied to the new binding's type, and the new
    -- (original) binding is applied to the type of every accumulated entry.
    substSubst :: ((TcTyVar, TcType), Ct)
               -> [((TcTyVar, TcType), Ct)]
               -> [((TcTyVar, TcType), Ct)]
    substSubst ((tv, t), ct) s =
      ((tv, substType (map fst s) t), ct)
        : map (first (second (substType [(tv, t)]))) s
-- | Apply a substitution to a constraint by mapping 'substType' over it
-- with 'overEvidencePredType'.
substCt :: [(TcTyVar, TcType)] -> Ct -> Ct
substCt subst = overEvidencePredType (substType subst)