-- Decrypt.hs: OpenPGP (RFC4880) recursive packet decryption
-- Copyright © 2013-2020  Clint Adams
-- This software is released under the terms of the Expat license.
-- (See the LICENSE file).
{-# LANGUAGE FlexibleContexts #-}

module Data.Conduit.OpenPGP.Decrypt
  ( conduitDecrypt
  ) where

import Control.Monad (when)
import Control.Monad.Fail (MonadFail)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.IO.Unlift (MonadUnliftIO)
import Control.Monad.Trans.Resource (MonadResource, MonadThrow)
import qualified Crypto.Hash as CH
import qualified Crypto.Hash.Algorithms as CHA
import Data.Binary (get)
import qualified Data.ByteArray as BA
import qualified Data.ByteString as B
import qualified Data.ByteString.Base16.Lazy as B16L
import qualified Data.ByteString.Lazy as BL
import Data.Conduit
import qualified Data.Conduit.Binary as CB
import qualified Data.Conduit.Combinators as CC
import qualified Data.Conduit.List as CL
import Data.Conduit.OpenPGP.Compression (conduitDecompress)
import Data.Conduit.Serialization.Binary (conduitGet)
import Data.Maybe (fromJust, isNothing)

import Codec.Encryption.OpenPGP.CFB (decryptOpenPGPCfb, decryptPreservingNonce)
import Codec.Encryption.OpenPGP.S2K (skesk2Key)
import Codec.Encryption.OpenPGP.Types

-- | State threaded through the packet-by-packet decryption conduit.
data RecursorState = RecursorState
  { _depth :: Int
    -- ^ Current nesting level; incremented each time decrypted contents are
    -- fed back into the decryptor.
  , _lastPKESK :: Maybe PKESK
    -- ^ NOTE(review): only ever initialised to 'Nothing' in the code shown
    -- here; presumably reserved for public-key decryption.
  , _lastSKESK :: Maybe SKESK
    -- ^ Most recent symmetric-key session key packet seen.
  , _lastNonce :: Maybe B.ByteString
    -- ^ Prefix returned alongside the plaintext of the enclosing
    -- integrity-protected packet; used when checking an MDC.
  , _lastClearText :: Maybe B.ByteString
    -- ^ Plaintext of the enclosing integrity-protected packet; used when
    -- checking an MDC.
  } deriving (Eq, Show)

-- | Starting state: nesting depth zero and nothing remembered yet.
def :: RecursorState
def =
  RecursorState
    { _depth = 0
    , _lastPKESK = Nothing
    , _lastSKESK = Nothing
    , _lastNonce = Nothing
    , _lastClearText = Nothing
    }

-- | An action that, given a prompt string, produces a lazy 'BL.ByteString'
-- response. In this module it is run in 'IO' to obtain the passphrase used
-- to derive a decryption key.
type InputCallback m
   = String -> m BL.ByteString

-- | Decrypt a stream of OpenPGP packets.
--
-- SKESK packets are remembered and dropped from the output. Symmetrically
-- encrypted data packets are replaced by the (decompressed, recursively
-- decrypted) packets found in their plaintext, using a passphrase obtained
-- from the supplied callback. Other packets, including PKESK packets, are
-- passed through unchanged.
conduitDecrypt ::
     (MonadFail m, MonadUnliftIO m, MonadResource m, MonadThrow m)
  => InputCallback IO
  -> ConduitT Pkt Pkt m ()
conduitDecrypt cb = conduitDecrypt' def cb

-- | Worker for 'conduitDecrypt' that carries an explicit 'RecursorState'.
--
-- Each incoming packet is handled as follows:
--
-- * an SKESK packet is stored in the state and not emitted;
-- * a symmetrically encrypted data packet (with or without integrity
--   protection) is decrypted using the stored SKESK and a passphrase from
--   the callback, and is replaced by the packets parsed from its plaintext;
-- * a modification detection code packet is compared with the SHA-1
--   computed from the enclosing integrity-protected packet's prefix and
--   plaintext, and is emitted unchanged when they match;
-- * every other packet is passed through unchanged.
--
-- Processing fails (via 'fail') when the nesting depth exceeds 42, when an
-- encrypted data packet arrives with no preceding SKESK, when an MDC packet
-- has no plaintext to check against, or when an MDC does not match.
conduitDecrypt' ::
     (MonadFail m, MonadUnliftIO m, MonadResource m, MonadThrow m)
  => RecursorState
  -> InputCallback IO
  -> ConduitT Pkt Pkt m ()
conduitDecrypt' rs cb = CC.concatMapAccumM push rs
  where
    push ::
         (MonadFail m, MonadUnliftIO m, MonadResource m, MonadThrow m)
      => Pkt
      -> RecursorState
      -> m (RecursorState, [Pkt])
    push i s
      -- Bound the recursion; the message suggests the concern is a message
      -- whose decryption contains itself.
      | _depth s > 42 = fail "I think we've been quine-attacked"
      | otherwise =
        case i of
          SKESKPkt {} -> return (s {_lastSKESK = Just (fromPkt i)}, [])
          SymEncDataPkt bs -> do
            skesk <- lastSKESKOrFail s
            d <- decryptSEDP s cb skesk bs
            return (s, d)
          SymEncIntegrityProtectedDataPkt _ bs -> do
            skesk <- lastSKESKOrFail s
            d <- decryptSEIPDP s cb skesk bs
            return (s, d)
          mdcPkt@(ModificationDetectionCodePkt mdc) -> do
            when (isNothing (_lastClearText s)) $ fail "MDC with no referent"
            let mcalculated = calculateMDC <$> _lastNonce s <*> _lastClearText s
            when (mcalculated /= Just mdc) $
              fail $
              "MDC indicates tampering: " ++
              show (B16L.encode mdc) ++
              " versus " ++
              maybe "<empty>" (show . B16L.encode) mcalculated ++
              "  ... " ++
              show (_lastNonce s) ++ " / " ++ show (_lastClearText s)
            return (s, [mdcPkt])
          p -> return (s, [p])
    -- Fetch the stored SKESK, reporting a clear failure instead of crashing
    -- (as 'fromJust' did) when encrypted data arrives without one, e.g. in
    -- a message encrypted only to a public key.
    lastSKESKOrFail :: MonadFail m => RecursorState -> m SKESK
    lastSKESKOrFail =
      maybe (fail "encrypted data packet with no preceding SKESK") return .
      _lastSKESK

-- | Decrypt the body of a Symmetrically Encrypted Data packet and
-- recursively process the packets it contains.
--
-- A passphrase is requested from the callback and turned into a key with
-- 'skesk2Key'. The body is decrypted with 'decryptOpenPGPCfb' using the
-- SKESK's symmetric algorithm. The resulting plaintext is parsed into
-- packets, decompressed, and fed back through 'conduitDecrypt'' one
-- nesting level deeper.
--
-- A decryption error is reported with 'fail' in @m@ as soon as it occurs,
-- rather than as a lazy 'error' thrown from inside the parsing conduit.
decryptSEDP ::
     (MonadFail m, MonadUnliftIO m, MonadIO m, MonadThrow m)
  => RecursorState
  -> InputCallback IO
  -> SKESK
  -> BL.ByteString
  -> m [Pkt]
decryptSEDP rs cb skesk bs -- FIXME: this shouldn't pass the whole SKESK
 = do
  passphrase <- liftIO $ cb "Input the passphrase I want"
  let key = skesk2Key skesk passphrase
  decrypted <-
    either fail return $
    decryptOpenPGPCfb (_skeskSymmetricAlgorithm skesk) (BL.toStrict bs) key
  runConduitRes $
    CB.sourceLbs (BL.fromStrict decrypted) .|
    conduitGet get .|
    conduitDecompress .|
    conduitDecrypt' (rs {_depth = _depth rs + 1}) cb .|
    CL.consume

-- | Decrypt the body of a Symmetrically Encrypted Integrity Protected Data
-- packet and recursively process the packets it contains.
--
-- A passphrase is requested from the callback and turned into a key with
-- 'skesk2Key'. The body is decrypted with 'decryptPreservingNonce' using the
-- SKESK's symmetric algorithm, which yields the prefix (nonce) and the
-- plaintext. The plaintext is parsed into packets, decompressed, and fed
-- back through 'conduitDecrypt'' one nesting level deeper. The nonce and
-- plaintext are recorded in the nested state so that a trailing MDC packet
-- can be verified.
--
-- A decryption error is reported with 'fail' in @m@ as soon as it occurs,
-- rather than as a lazy 'error' thrown from inside the parsing conduit.
decryptSEIPDP ::
     (MonadFail m, MonadUnliftIO m, MonadIO m, MonadThrow m)
  => RecursorState
  -> InputCallback IO
  -> SKESK
  -> BL.ByteString
  -> m [Pkt]
decryptSEIPDP rs cb skesk bs -- FIXME: this shouldn't pass the whole SKESK
 = do
  passphrase <- liftIO $ cb "Input the passphrase I want"
  let key = skesk2Key skesk passphrase
  (nonce, decrypted) <-
    either fail return $
    decryptPreservingNonce
      (_skeskSymmetricAlgorithm skesk)
      (BL.toStrict bs)
      key
  let nested =
        rs
          { _depth = _depth rs + 1
          , _lastNonce = Just nonce
          , _lastClearText = Just decrypted
          }
  runConduitRes $
    CB.sourceLbs (BL.fromStrict decrypted) .|
    conduitGet get .|
    conduitDecompress .|
    conduitDecrypt' nested cb .|
    CL.consume

-- | Compute the expected modification detection code for an
-- integrity-protected packet.
--
-- @garbage@ is the decrypted packet contents, whose final 22 bytes are the
-- MDC packet itself (header bytes 0xD3 0x14 followed by a 20-byte SHA-1).
-- The result is the SHA-1 of @nonce@, then the contents preceding the MDC
-- packet, then the two MDC header bytes.
--
-- Returns an empty string when @garbage@ is too short to contain an MDC
-- packet. A 22-byte input (an MDC with no preceding plaintext) is valid and
-- is hashed. The previous threshold of 23 wrongly returned empty for it,
-- which made the caller report tampering.
calculateMDC :: B.ByteString -> B.ByteString -> BL.ByteString
calculateMDC nonce garbage
  | B.length garbage < mdcPacketLength = mempty -- FIXME: this is horrible
  | otherwise =
    BL.fromStrict . BA.convert $
    (CH.hash :: B.ByteString -> CH.Digest CHA.SHA1)
      (nonce <> B.take (B.length garbage - mdcPacketLength) garbage <> mdcHeader)
  where
    -- Two header bytes (tag/length) plus a 20-byte SHA-1 digest.
    mdcPacketLength :: Int
    mdcPacketLength = 22
    -- MDC packet tag and length octets, included in the hash input.
    mdcHeader :: B.ByteString
    mdcHeader = B.pack [0xD3, 0x14]