module Network.OAuth2.Experiment.Pkce (
  mkPkceParam,
  CodeChallenge (..),
  CodeVerifier (..),
  CodeChallengeMethod (..),
  PkceRequestParam (..),
) where

import Control.Monad.IO.Class
import Crypto.Hash qualified as H
import Crypto.Random qualified as Crypto
import Data.ByteArray qualified as ByteArray
import Data.ByteString qualified as BS
import Data.ByteString.Base64.URL qualified as B64
import Data.Text (Text)
import Data.Text.Encoding qualified as T
import Data.Word

-- | The PKCE @code_challenge@ sent with the authorization request.
-- In this module it is produced from the verifier bytes by
-- 'encodeCodeVerifier'.
newtype CodeChallenge = CodeChallenge
  { unCodeChallenge :: Text
  }

-- | The PKCE @code_verifier@ (RFC 7636, section 4.1): the random string the
-- client generates for one flow. Stored as 'Text'; 'mkPkceParam' builds it
-- from ASCII bytes.
newtype CodeVerifier = CodeVerifier
  { unCodeVerifier :: Text
  }
  deriving (Show)

-- | How the code challenge is derived from the verifier. Only the
-- SHA-256 method ('S256') is supported here.
data CodeChallengeMethod
  = S256
  deriving (Show)

-- | The full set of PKCE values for one authorization flow: the verifier
-- together with the challenge and challenge method that accompany it.
data PkceRequestParam = PkceRequestParam
  { codeVerifier :: CodeVerifier
  -- ^ The random verifier string.
  , codeChallenge :: CodeChallenge
  -- ^ The challenge derived from 'codeVerifier'.
  , codeChallengeMethod :: CodeChallengeMethod
  -- ^ The derivation used for 'codeChallenge'. RFC 7636 section 4.3 marks
  -- @code_challenge_method@ as OPTIONAL, but an absent value means
  -- @plain@, so it has to be sent explicitly when S256 is used.
  -- https://datatracker.ietf.org/doc/html/rfc7636#section-4.3
  }

-- | Generate a fresh PKCE parameter set. One random verifier is drawn via
-- 'genCodeVerifier'. The same bytes become both the textual 'CodeVerifier'
-- and, hashed and encoded by 'encodeCodeVerifier', the 'CodeChallenge'.
-- The method is always 'S256'.
mkPkceParam :: MonadIO m => m PkceRequestParam
mkPkceParam = fromVerifierBytes <$> genCodeVerifier
  where
    -- The verifier bytes are restricted to an ASCII alphabet, so UTF-8
    -- decoding is safe here.
    fromVerifierBytes verifierBytes =
      PkceRequestParam
        { codeVerifier = CodeVerifier (T.decodeUtf8 verifierBytes)
        , codeChallenge = CodeChallenge (encodeCodeVerifier verifierBytes)
        , codeChallengeMethod = S256
        }

-- | Derive the S256 code challenge from raw code-verifier bytes. Per
-- RFC 7636 (section 4.2 and Appendix A) this is
-- @BASE64URL-ENCODE(SHA256(ASCII(code_verifier)))@ without @=@ padding.
encodeCodeVerifier :: BS.ByteString -> Text
encodeCodeVerifier =
  B64.encodeBase64Unpadded
    -- 'ByteArray.convert' copies the digest straight into a ByteString.
    -- This avoids building the intermediate list of bytes that
    -- @BS.pack . ByteArray.unpack@ produced.
    . ByteArray.convert
    . hashSHA256

-- | Produce random code-verifier bytes in any 'MonadIO' context. Generation
-- runs 'getBytesInternal' from an empty accumulator, so the result has
-- exactly 'cvMaxLen' bytes.
genCodeVerifier :: MonadIO m => m BS.ByteString
genCodeVerifier = liftIO (getBytesInternal mempty)

-- | The length, in bytes, of every generated code verifier. RFC 7636 allows
-- 43 to 128 characters; this module always uses the upper bound.
cvMaxLen :: Int
cvMaxLen =
  128

-- | Extend an accumulator with random verifier-alphabet bytes until it holds
-- at least 'cvMaxLen' of them, then return exactly the first 'cvMaxLen'.
--
-- 'Crypto.getRandomBytes' yields arbitrary byte values, many of which fall
-- outside the unreserved character set that RFC 7636 section 4.1 allows:
--
-- > code-verifier = 43*128unreserved
-- >   unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~"
-- >   ALPHA = %x41-5A / %x61-7A
-- >   DIGIT = %x30-39
--
-- Each round draws 'cvMaxLen' random bytes, keeps only the ones accepted by
-- 'isUnreversed', appends them, and checks the length again. An accumulator
-- that is already long enough is truncated without re-checking its bytes.
getBytesInternal :: BS.ByteString -> IO BS.ByteString
getBytesInternal acc =
  if BS.length acc < cvMaxLen
    then do
      chunk <- Crypto.getRandomBytes cvMaxLen
      getBytesInternal (acc <> BS.filter isUnreversed chunk)
    else pure (BS.take cvMaxLen acc)

-- | Compute the SHA-256 digest of a byte string.
hashSHA256 :: BS.ByteString -> H.Digest H.SHA256
hashSHA256 = H.hashWith H.SHA256

-- | Test whether a byte is one of the code-verifier characters listed in
-- 'unreverseBS'. The name presumably means "unreserved" (RFC 3986 / RFC 7636
-- terminology).
isUnreversed :: Word8 -> Bool
isUnreversed = (`BS.elem` unreverseBS)

{-
The "unreserved" characters a PKCE code verifier may contain
(RFC 7636 section 4.1), as byte values:
a-z: 97-122
A-Z: 65-90
0-9: 48-57
-: 45
.: 46
_: 95
~: 126
-}
-- | Every byte allowed in a code verifier. DIGIT (0-9) is part of the
-- unreserved set and was previously missing. Its absence shrank the verifier
-- alphabet from 76 to 66 symbols, lowering entropy per character.
unreverseBS :: BS.ByteString
unreverseBS =
  BS.pack $
    [97 .. 122] -- a-z
      ++ [65 .. 90] -- A-Z
      ++ [48 .. 57] -- 0-9
      ++ [45, 46, 95, 126] -- - . _ ~