{-# LANGUAGE OverloadedStrings #-}

module Ecluse.Pilot.Osv.Compile (
    compileOsvToSqlite,
) where

import Conduit
import Control.Monad.Catch (MonadMask)
import Data.Conduit.List qualified as CL
import Data.Time (getCurrentTime)
import Data.Time.Format.ISO8601 (iso8601Show)
import Data.Version (showVersion)
import Database.SQLite.Simple
import Katip (KatipContext, Severity (..), logFM, ls)
import Paths_ecluse (version)
import System.Directory (createDirectoryIfMissing, removeFile)
import System.FilePath ((</>))
import System.IO.Error (catchIOError, ioError, isDoesNotExistError)
import UnliftIO.Exception (bracket)

import Ecluse.Core.Osv.Schema (MetaKey (..), osvDbFileName, osvSchemaEpoch, renderMetaKey)
import Ecluse.Pilot.Osv (ExtractedOsv (..))
import Ecluse.Pilot.Osv.Retry (defaultOsvRetryPolicy, withOsvRetry)
import Ecluse.Pilot.Osv.Stream (streamOsvUrl)
import Ecluse.Telemetry (Telemetry)

{- | Compile an ecosystem's OSV advisory export into the SQLite artifact and
return its path. The artifact's name, epoch stamp, and @meta@ table follow the
contract in "Ecluse.Core.Osv.Schema".

The artifact is rebuilt from scratch: any file already at the target path is
removed before the database is opened. Only a missing file is tolerated there.
Any other removal failure (permissions, a directory in the way, ...) is
rethrown as-is, instead of being swallowed and resurfacing later as a less
specific SQLite error once the schema is created over whatever was left behind.
-}
compileOsvToSqlite :: (MonadResource m, MonadMask m, MonadUnliftIO m, KatipContext m) => Telemetry -> FilePath -> Text -> String -> m FilePath
compileOsvToSqlite telemetry outDir ecosystem urlStr = do
    let dbFile = outDir </> osvDbFileName ecosystem
    logFM InfoS (ls ("Compiling OSV data for " <> ecosystem <> " to " <> toText dbFile))

    -- Ensure clean state. 'initSchema' issues plain CREATE TABLE statements,
    -- so a leftover artifact has to be gone before the path is opened.
    liftIO $ createDirectoryIfMissing True outDir
    liftIO $ removeFile dbFile `catchIOError` ignoreMissing

    bracket (liftIO $ open dbFile) (liftIO . close) $ \conn -> do
        liftIO $ initSchema conn

        -- The fetch runs under a truncated exponential backoff (see
        -- 'Ecluse.Pilot.Osv.Retry'): a transient osv.dev failure is retried with
        -- jittered, capped, and count-bounded backoff rather than tight-looping, so
        -- an outage cannot get our egress IP rate-limited or banned. Batches commit
        -- incrementally, so a mid-stream drop can leave a partial table behind; each
        -- attempt therefore wipes it first and re-streams from a clean slate. (INSERT
        -- OR IGNORE alone would not suffice: a NULL introduced/fixed bound is distinct
        -- under the composite primary key, so a re-run would duplicate those ranges.)
        withOsvRetry defaultOsvRetryPolicy $ do
            liftIO $ execute_ conn "DELETE FROM package_vulnerability_ranges"
            runConduit $
                streamOsvUrl telemetry urlStr
                    .| CL.chunksOf 2000
                    .| sinkSqlite conn

        rowCount <- liftIO $ writeMeta conn ecosystem urlStr
        logFM InfoS (ls ("Compiled " <> show rowCount <> " advisory ranges for " <> ecosystem))

    pure dbFile
  where
    -- No previous artifact is the normal first-run case. Every other
    -- removal failure is a real problem and is propagated unchanged.
    ignoreMissing :: IOError -> IO ()
    ignoreMissing err
        | isDoesNotExistError err = pure ()
        | otherwise = ioError err

-- | Lay down the artifact's tables and indexes, then stamp the database with
-- 'osvSchemaEpoch' through @PRAGMA user_version@. Statements run in list order.
-- None of the @CREATE@ statements say @IF NOT EXISTS@, so the target database
-- is expected to be empty.
initSchema :: Connection -> IO ()
initSchema conn = mapM_ (execute_ conn) (schemaStatements <> [epochStamp])
  where
    schemaStatements :: [Query]
    schemaStatements =
        [ "CREATE TABLE package_vulnerability_ranges (\
          \  package_name TEXT NOT NULL,\
          \  cve_id TEXT NOT NULL,\
          \  introduced_version TEXT,\
          \  fixed_version TEXT,\
          \  severity TEXT,\
          \  PRIMARY KEY (package_name, cve_id, introduced_version, fixed_version)\
          \)"
        , "CREATE INDEX idx_package_name ON package_vulnerability_ranges(package_name)"
        , -- The reader's remediation probe is an exact (name, fixed) equality; this
          -- index makes it one B-tree traversal. Additive, so epoch-neutral.
          "CREATE INDEX idx_package_fixed ON package_vulnerability_ranges(package_name, fixed_version)"
        , "CREATE TABLE meta (\
          \  key TEXT NOT NULL PRIMARY KEY,\
          \  value TEXT NOT NULL\
          \)"
        ]

    -- The epoch is spliced into the PRAGMA text rather than bound as a parameter.
    epochStamp :: Query
    epochStamp = fromString ("PRAGMA user_version = " <> show osvSchemaEpoch)

-- | Record the artifact's provenance (pilot version, ecosystem, build time,
-- source URL, row count) in the @meta@ table and return the number of rows in
-- @package_vulnerability_ranges@.
--
-- Written once, after the stream has completed: the row count is only
-- meaningful for a complete artifact. The count and all @meta@ inserts run in a
-- single transaction. An interruption therefore cannot leave a partially
-- populated @meta@ table, and the recorded count is the one that transaction
-- observed.
writeMeta :: Connection -> Text -> String -> IO Int
writeMeta conn ecosystem urlStr = do
    now <- getCurrentTime
    withTransaction conn $ do
        counted <- query_ conn "SELECT COUNT(*) FROM package_vulnerability_ranges" :: IO [Only Int]
        -- COUNT(*) always yields one row; the 0 fallback only keeps this total.
        let rowCount = maybe 0 fromOnly (listToMaybe counted)
        executeMany
            conn
            "INSERT INTO meta (key, value) VALUES (?, ?)"
            [ (renderMetaKey MetaPilotVersion, toText (showVersion version))
            , (renderMetaKey MetaEcosystem, ecosystem)
            , (renderMetaKey MetaBuiltAt, toText (iso8601Show now))
            , (renderMetaKey MetaSourceUrl, toText urlStr)
            , (renderMetaKey MetaRowCount, show rowCount)
            ]
        pure rowCount

-- | Conduit sink that writes each incoming batch of extracted advisories to
-- @package_vulnerability_ranges@. Every batch is committed in its own
-- transaction. @INSERT OR IGNORE@ drops rows that would violate a table
-- constraint, such as a repeated primary key.
sinkSqlite :: (MonadIO m) => Connection -> ConduitT [ExtractedOsv] o m ()
sinkSqlite conn = awaitForever (liftIO . insertBatch)
  where
    insertBatch :: [ExtractedOsv] -> IO ()
    insertBatch entries =
        withTransaction conn . executeMany conn insertRange $ fmap rangeRow entries

    insertRange :: Query
    insertRange = "INSERT OR IGNORE INTO package_vulnerability_ranges (package_name, cve_id, introduced_version, fixed_version, severity) VALUES (?, ?, ?, ?, ?)"

    -- Column order must match the placeholder order in 'insertRange'.
    rangeRow :: ExtractedOsv -> (Text, Text, Maybe Text, Maybe Text, Maybe Text)
    rangeRow entry =
        ( extPackage entry
        , extCveId entry
        , extIntroduced entry
        , extFixed entry
        , extSeverity entry
        )