diff --git a/booster/library/Booster/Definition/Ceil.hs b/booster/library/Booster/Definition/Ceil.hs index b265c40634..9b14814df2 100644 --- a/booster/library/Booster/Definition/Ceil.hs +++ b/booster/library/Booster/Definition/Ceil.hs @@ -18,6 +18,7 @@ import Booster.Log import Booster.Pattern.Bool import Booster.Pattern.Pretty import Booster.Pattern.Util (isConcrete, sortOfTerm) +import Booster.SMT.Interface import Booster.Util (Flag (..)) import Control.DeepSeq (NFData) import Control.Monad (foldM) @@ -101,7 +102,8 @@ computeCeilRule :: computeCeilRule mllvm def r@RewriteRule.RewriteRule{lhs, requires, rhs, attributes, computedAttributes} | null computedAttributes.notPreservesDefinednessReasons = pure Nothing | otherwise = do - (res, _) <- runEquationT def mllvm Nothing mempty mempty $ do + ns <- noSolver + (res, _) <- runEquationT def mllvm ns mempty mempty $ do lhsCeils <- Set.fromList <$> computeCeil lhs requiresCeils <- Set.fromList <$> concatMapM (computeCeil . coerce) (Set.toList requires) let subtractLHSAndRequiresCeils = (Set.\\ (lhsCeils `Set.union` requiresCeils)) . Set.fromList diff --git a/booster/library/Booster/JsonRpc.hs b/booster/library/Booster/JsonRpc.hs index 13f038eef3..0d1d54d8d2 100644 --- a/booster/library/Booster/JsonRpc.hs +++ b/booster/library/Booster/JsonRpc.hs @@ -20,7 +20,6 @@ import Control.Applicative ((<|>)) import Control.Concurrent (MVar, putMVar, readMVar, takeMVar) import Control.Exception qualified as Exception import Control.Monad -import Control.Monad.Extra (whenJust) import Control.Monad.IO.Class import Control.Monad.Trans.Except (catchE, except, runExcept, runExceptT, throwE, withExceptT) import Crypto.Hash (SHA256 (..), hashWith) @@ -145,10 +144,10 @@ respond stateVar request = , ceilConditions = pat.ceilConditions } - solver <- traverse (SMT.initSolver def) mSMTOptions + solver <- maybe (SMT.noSolver) (SMT.initSolver def) mSMTOptions result <- performRewrite doTracing def mLlvmLibrary solver mbDepth cutPoints terminals substPat - whenJust solver SMT.finaliseSolver + SMT.finaliseSolver solver stop <- liftIO $ getTime Monotonic let duration = if fromMaybe False req.logTiming @@ -228,7 +227,7 @@ respond stateVar request = | otherwise = Nothing - solver <- traverse (SMT.initSolver def) mSMTOptions + solver <- maybe (SMT.noSolver) (SMT.initSolver def) mSMTOptions result <- case internalised of Left patternErrors -> do @@ -299,7 +298,7 @@ respond stateVar request = pure $ Right (addHeader $ Syntax.KJAnd predicateSort result) (Left something, _) -> pure . Left . RpcError.backendError $ RpcError.Aborted $ renderText $ pretty' @mods something - whenJust solver SMT.finaliseSolver + SMT.finaliseSolver solver stop <- liftIO $ getTime Monotonic let duration = @@ -362,7 +361,7 @@ respond stateVar request = withContext CtxGetModel $ withContext CtxSMT $ logMessage ("No predicates or substitutions given, returning Unknown" :: Text) - pure $ Left SMT.Unknown + pure $ Left $ SMT.Unknown $ Just "No predicates or substitutions given" else do solver <- SMT.initSolver def smtOptions result <- SMT.getModelFor solver boolPs suppliedSubst @@ -380,12 +379,7 @@ respond stateVar request = { satisfiable = RpcTypes.Unsat , substitution = Nothing } - Left SMT.ReasonUnknown{} -> - RpcTypes.GetModelResult - { satisfiable = RpcTypes.Unknown - , substitution = Nothing - } - Left SMT.Unknown -> + Left SMT.Unknown{} -> RpcTypes.GetModelResult { satisfiable = RpcTypes.Unknown , substitution = Nothing @@ -485,7 +479,7 @@ respond stateVar request = MatchSuccess subst -> do let filteredConsequentPreds = Set.map (substituteInPredicate subst) substPatR.constraints `Set.difference` substPatL.constraints - solver <- traverse (SMT.initSolver def) mSMTOptions + solver <- maybe (SMT.noSolver) (SMT.initSolver def) mSMTOptions if null filteredConsequentPreds then implies (sortOfPattern substPatL) req.antecedent.term req.consequent.term subst @@ -555,7 +549,10 @@ handleSmtError = JsonRpcHandler $ \case let bool = externaliseSort Pattern.SortBool -- predicates are terms of sort Bool externalise = Syntax.KJAnd bool . map (externalisePredicate bool) . Set.toList allPreds = addHeader $ Syntax.KJAnd bool [externalise premises, externalise preds] - pure $ RpcError.backendError $ RpcError.SmtSolverError $ RpcError.ErrorWithTerm reason allPreds + pure $ + RpcError.backendError $ + RpcError.SmtSolverError $ + RpcError.ErrorWithTerm (fromMaybe "UNKNOWN" reason) allPreds where runtimeError prefix err = do let msg = "SMT " <> prefix <> ": " <> err diff --git a/booster/library/Booster/Pattern/ApplyEquations.hs b/booster/library/Booster/Pattern/ApplyEquations.hs index 0131e5910b..4c9d096b26 100644 --- a/booster/library/Booster/Pattern/ApplyEquations.hs +++ b/booster/library/Booster/Pattern/ApplyEquations.hs @@ -144,7 +144,7 @@ instance Pretty (PrettyWithModifiers mods EquationFailure) where data EquationConfig = EquationConfig { definition :: KoreDefinition , llvmApi :: Maybe LLVM.API - , smtSolver :: Maybe SMT.SMTContext + , smtSolver :: SMT.SMTContext , maxRecursion :: Bound "Recursion" , maxIterations :: Bound "Iterations" , logger :: Logger LogMessage @@ -281,7 +281,7 @@ runEquationT :: LoggerMIO io => KoreDefinition -> Maybe LLVM.API -> - Maybe SMT.SMTContext -> + SMT.SMTContext -> SimplifierCache -> Set Predicate -> EquationT io a -> @@ -394,7 +394,7 @@ evaluateTerm :: Direction -> KoreDefinition -> Maybe LLVM.API -> - Maybe SMT.SMTContext -> + SMT.SMTContext -> Set Predicate -> Term -> io (Either EquationFailure Term, SimplifierCache) @@ -417,7 +417,7 @@ evaluatePattern :: LoggerMIO io => KoreDefinition -> Maybe LLVM.API -> - Maybe SMT.SMTContext -> + SMT.SMTContext -> SimplifierCache -> Pattern -> io (Either EquationFailure Pattern, SimplifierCache) @@ -462,7 +462,7 @@ evaluateConstraints :: LoggerMIO io => KoreDefinition -> Maybe LLVM.API -> - Maybe SMT.SMTContext -> + SMT.SMTContext -> SimplifierCache -> Set Predicate -> io (Either EquationFailure (Set Predicate), SimplifierCache) @@ -828,7 +828,7 @@ applyEquation term rule = -- could now be syntactically present in the path constraints, filter again stillUnclear <- lift $ filterOutKnownConstraints knownPredicates unclearConditions - mbSolver :: Maybe SMT.SMTContext <- (.smtSolver) <$> lift getConfig + solver :: SMT.SMTContext <- (.smtSolver) <$> lift getConfig -- check any conditions that are still unclear with the SMT solver -- (or abort if no solver is being used), abort if still unclear after @@ -842,7 +842,7 @@ applyEquation term rule = liftIO $ Exception.throw other Right result -> pure result - in maybe (pure Nothing) (lift . checkWithSmt) mbSolver >>= \case + in lift (checkWithSmt solver) >>= \case Nothing -> do -- no solver or still unclear: abort throwE @@ -882,23 +882,22 @@ applyEquation term rule = ) ensured -- check all ensured conditions together with the path condition - whenJust mbSolver $ \solver -> do - lift (SMT.checkPredicates solver knownPredicates mempty $ Set.fromList ensuredConditions) >>= \case - Right (Just False) -> do - let falseEnsures = Predicate $ foldl1' AndTerm $ map coerce ensuredConditions - throwE - ( \ctx -> - ctx . logMessage $ - WithJsonMessage (object ["conditions" .= map (externaliseTerm . coerce) ensuredConditions]) $ - renderOneLineText ("Ensured conditions found to be false: " <> pretty' @mods falseEnsures) - , EnsuresFalse falseEnsures - ) - Right _other -> - pure () - Left SMT.SMTSolverUnknown{} -> - pure () - Left other -> - liftIO $ Exception.throw other + lift (SMT.checkPredicates solver knownPredicates mempty $ Set.fromList ensuredConditions) >>= \case + Right (Just False) -> do + let falseEnsures = Predicate $ foldl1' AndTerm $ map coerce ensuredConditions + throwE + ( \ctx -> + ctx . logMessage $ + WithJsonMessage (object ["conditions" .= map (externaliseTerm . coerce) ensuredConditions]) $ + renderOneLineText ("Ensured conditions found to be false: " <> pretty' @mods falseEnsures) + , EnsuresFalse falseEnsures + ) + Right _other -> + pure () + Left SMT.SMTSolverUnknown{} -> + pure () + Left other -> + liftIO $ Exception.throw other lift $ pushConstraints $ Set.fromList ensuredConditions pure $ substituteInTerm subst rule.rhs where @@ -1004,19 +1003,19 @@ simplifyConstraint :: LoggerMIO io => KoreDefinition -> Maybe LLVM.API -> - Maybe SMT.SMTContext -> + SMT.SMTContext -> SimplifierCache -> Set Predicate -> Predicate -> io (Either EquationFailure Predicate, SimplifierCache) -simplifyConstraint def mbApi mbSMT cache knownPredicates (Predicate p) = do - runEquationT def mbApi mbSMT cache knownPredicates $ (coerce <$>) . simplifyConstraint' True $ p +simplifyConstraint def mbApi smt cache knownPredicates (Predicate p) = do + runEquationT def mbApi smt cache knownPredicates $ (coerce <$>) . simplifyConstraint' True $ p simplifyConstraints :: LoggerMIO io => KoreDefinition -> Maybe LLVM.API -> - Maybe SMT.SMTContext -> + SMT.SMTContext -> SimplifierCache -> [Predicate] -> io (Either EquationFailure [Predicate], SimplifierCache) diff --git a/booster/library/Booster/Pattern/Rewrite.hs b/booster/library/Booster/Pattern/Rewrite.hs index 56df80f675..542b4f3b3b 100644 --- a/booster/library/Booster/Pattern/Rewrite.hs +++ b/booster/library/Booster/Pattern/Rewrite.hs @@ -79,7 +79,7 @@ newtype RewriteT io a = RewriteT data RewriteConfig = RewriteConfig { definition :: KoreDefinition , llvmApi :: Maybe LLVM.API - , smtSolver :: Maybe SMT.SMTContext + , smtSolver :: SMT.SMTContext , doTracing :: Flag "CollectRewriteTraces" , logger :: Logger LogMessage , prettyModifiers :: ModifiersRep @@ -102,7 +102,7 @@ runRewriteT :: Flag "CollectRewriteTraces" -> KoreDefinition -> Maybe LLVM.API -> - Maybe SMT.SMTContext -> + SMT.SMTContext -> SimplifierCache -> RewriteT io a -> io (Either (RewriteFailed "Rewrite") (a, SimplifierCache)) @@ -355,7 +355,7 @@ applyRule pat@Pattern{ceilConditions} rule = stillUnclear <- lift $ filterOutKnownConstraints prior unclearRequires -- check unclear requires-clauses in the context of known constraints (prior) - mbSolver <- lift $ RewriteT $ (.smtSolver) <$> ask + solver <- lift $ RewriteT $ (.smtSolver) <$> ask let smtUnclear = do withContext CtxConstraint . withContext CtxAbort . logMessage $ @@ -366,34 +366,23 @@ applyRule pat@Pattern{ceilConditions} rule = failRewrite $ RuleConditionUnclear rule . coerce . foldl1 AndTerm $ map coerce stillUnclear - case mbSolver of - Just solver -> do - checkAllRequires <- - SMT.checkPredicates solver prior mempty (Set.fromList stillUnclear) - - case checkAllRequires of - Left SMT.SMTSolverUnknown{} -> - smtUnclear -- abort rewrite if a solver result was Unknown - Left other -> - liftIO $ Exception.throw other -- fail hard on other SMT errors - Right (Just False) -> do - -- requires is actually false given the prior - withContext CtxFailure $ logMessage ("Required clauses evaluated to #Bottom." :: Text) - RewriteRuleAppT $ pure NotApplied - Right (Just True) -> - pure () -- can proceed - Right Nothing -> - smtUnclear -- no implication could be determined - Nothing -> - unless (null stillUnclear) $ do - withContext CtxConstraint . withContext CtxAbort $ - logMessage $ - WithJsonMessage (object ["conditions" .= (externaliseTerm . coerce <$> stillUnclear)]) $ - renderOneLineText $ - "Uncertain about a condition(s) in rule, no SMT solver:" - <+> (hsep . punctuate comma . map (pretty' @mods) $ stillUnclear) - failRewrite $ - RuleConditionUnclear rule (head stillUnclear) + + checkAllRequires <- + SMT.checkPredicates solver prior mempty (Set.fromList stillUnclear) + + case checkAllRequires of + Left SMT.SMTSolverUnknown{} -> + smtUnclear -- abort rewrite if a solver result was Unknown + Left other -> + liftIO $ Exception.throw other -- fail hard on other SMT errors + Right (Just False) -> do + -- requires is actually false given the prior + withContext CtxFailure $ logMessage ("Required clauses evaluated to #Bottom." :: Text) + RewriteRuleAppT $ pure NotApplied + Right (Just True) -> + pure () -- can proceed + Right Nothing -> + smtUnclear -- no implication could be determined -- check ensures constraints (new) from rhs: stop and return `Trivial` if -- any are false, remove all that are trivially true, return the rest @@ -405,17 +394,16 @@ applyRule pat@Pattern{ceilConditions} rule = catMaybes <$> mapM (checkConstraint id trivialIfBottom prior) ruleEnsures -- check all new constraints together with the known side constraints - whenJust mbSolver $ \solver -> - (lift $ SMT.checkPredicates solver prior mempty (Set.fromList newConstraints)) >>= \case - Right (Just False) -> do - withContext CtxSuccess $ logMessage ("New constraints evaluated to #Bottom." :: Text) - RewriteRuleAppT $ pure Trivial - Right _other -> - pure () - Left SMT.SMTSolverUnknown{} -> - pure () - Left other -> - liftIO $ Exception.throw other + (lift $ SMT.checkPredicates solver prior mempty (Set.fromList newConstraints)) >>= \case + Right (Just False) -> do + withContext CtxSuccess $ logMessage ("New constraints evaluated to #Bottom." :: Text) + RewriteRuleAppT $ pure Trivial + Right _other -> + pure () + Left SMT.SMTSolverUnknown{} -> + pure () + Left other -> + liftIO $ Exception.throw other -- existential variables may be present in rule.rhs and rule.ensures, -- need to strip prefixes and freshen their names with respect to variables already @@ -714,7 +702,7 @@ performRewrite :: Flag "CollectRewriteTraces" -> KoreDefinition -> Maybe LLVM.API -> - Maybe SMT.SMTContext -> + SMT.SMTContext -> -- | maximum depth Maybe Natural -> -- | cut point rule labels @@ -723,7 +711,7 @@ performRewrite :: [Text] -> Pattern -> io (Natural, Seq (RewriteTrace ()), RewriteResult Pattern) -performRewrite doTracing def mLlvmLibrary mSolver mbMaxDepth cutLabels terminalLabels pat = do +performRewrite doTracing def mLlvmLibrary smtSolver mbMaxDepth cutLabels terminalLabels pat = do (rr, RewriteStepsState{counter, traces}) <- flip runStateT rewriteStart $ doSteps False pat pure (counter, traces, rr) @@ -748,7 +736,7 @@ performRewrite doTracing def mLlvmLibrary mSolver mbMaxDepth cutLabels terminalL simplifyP p = withContext CtxSimplify $ do st <- get let cache = st.simplifierCache - evaluatePattern def mLlvmLibrary mSolver cache p >>= \(res, newCache) -> do + evaluatePattern def mLlvmLibrary smtSolver cache p >>= \(res, newCache) -> do updateCache newCache case res of Right newPattern -> do @@ -815,7 +803,7 @@ performRewrite doTracing def mLlvmLibrary mSolver mbMaxDepth cutLabels terminalL doTracing def mLlvmLibrary - mSolver + smtSolver simplifierCache (withPatternContext pat' $ rewriteStep cutLabels terminalLabels pat') >>= \case diff --git a/booster/library/Booster/SMT/Base.hs b/booster/library/Booster/SMT/Base.hs index 2aacbd872b..e2f9d3d149 100644 --- a/booster/library/Booster/SMT/Base.hs +++ b/booster/library/Booster/SMT/Base.hs @@ -81,9 +81,8 @@ data Response = Success -- for command_ | Sat | Unsat - | Unknown + | Unknown (Maybe Text) | Values [(SExpr, Value)] - | ReasonUnknown Text | Error BS.ByteString deriving stock (Eq, Ord, Show) diff --git a/booster/library/Booster/SMT/Interface.hs b/booster/library/Booster/SMT/Interface.hs index 06c6d6eaff..d88a8b06f5 100644 --- a/booster/library/Booster/SMT/Interface.hs +++ b/booster/library/Booster/SMT/Interface.hs @@ -11,6 +11,7 @@ module Booster.SMT.Interface ( defaultSMTOptions, -- re-export SMTError (..), initSolver, + noSolver, finaliseSolver, getModelFor, checkPredicates, @@ -32,6 +33,7 @@ import Data.Either.Extra (fromLeft', fromRight') import Data.IORef import Data.Map (Map) import Data.Map qualified as Map +import Data.Maybe (fromMaybe) import Data.Set (Set) import Data.Set qualified as Set import Data.Text as Text (Text, pack, unlines, unwords) @@ -52,7 +54,7 @@ import Booster.Syntax.Json.Externalise (externaliseTerm) data SMTError = GeneralSMTError Text | SMTTranslationError Text - | SMTSolverUnknown Text (Set Predicate) (Set Predicate) + | SMTSolverUnknown (Maybe Text) (Set Predicate) (Set Predicate) deriving (Eq, Show) instance Exception SMTError @@ -103,6 +105,21 @@ initSolver def smtOptions = Log.withContext Log.CtxSMT $ do Log.logMessage ("Successfully initialised SMT solver with " <> (Text.pack . show $ smtOptions)) pure ctxt +{- | Returns an @SMTContext@ with no solver handle, essentially just a dummy that always returns `Unknown` for any command that is attempted. +This can be useful for unit testing or in case the user wants to call the booster without Z3. +-} +noSolver :: MonadIO io => io SMT.SMTContext +noSolver = do + solverClose <- liftIO $ newIORef $ pure () + pure + SMTContext + { mbSolver = Nothing + , solverClose + , mbTranscriptHandle = Nothing + , prelude = [] + , options = defaultSMTOptions{retryLimit = Just 0} + } + -- | Hot-swap @SMTOptions@ in the active @SMTContext@, update the query timeout swapSmtOptions :: forall io. Log.LoggerMIO io => SMTOptions -> SMT io () swapSmtOptions smtOptions = do @@ -116,11 +133,14 @@ hardResetSolver :: forall io. Log.LoggerMIO io => SMTOptions -> SMT io () hardResetSolver smtOptions = do Log.logMessage ("Restarting SMT solver" :: Text) ctxt <- SMT get - liftIO $ join $ readIORef ctxt.solverClose - (solver, handle) <- connectToSolver - liftIO $ do - writeIORef ctxt.solver solver - writeIORef ctxt.solverClose $ Backend.close handle + case ctxt.mbSolver of + Nothing -> pure () + Just solverRef -> do + liftIO $ join $ readIORef ctxt.solverClose + (solver, handle) <- connectToSolver + liftIO $ do + writeIORef solverRef solver + writeIORef ctxt.solverClose $ Backend.close handle checkPrelude swapSmtOptions smtOptions @@ -201,9 +221,7 @@ getModelFor ctxt ps subst interactWithSolver transState smtAsserts >>= \case Left response -> case response of - -- note that 'Unknown' will never be returned by 'interactWithSolver', as it will always be - -- converted to 'ReasonUnknown{}'. - ReasonUnknown{} -> do + Unknown{} -> do case opts.retryLimit of Just x | x > 0 -> do let newOpts = opts{timeout = 2 * opts.timeout, retryLimit = Just $ x - 1} @@ -240,12 +258,7 @@ getModelFor ctxt ps subst result <- case satResponse of Error msg -> throwSMT' $ BS.unpack msg Unsat -> pure $ Left Unsat - Unknown{} -> do - -- we always request the reason, even if we later retry, - -- to avoid delayed popping from the assertion stack in 'solve' - reasonUnknown <- smtRun SMT.GetReasonUnknown - pure $ Left reasonUnknown - r@ReasonUnknown{} -> pure $ Left r + r@Unknown{} -> pure $ Left r Values{} -> throwSMT' $ "Unexpected SMT response to CheckSat: " <> show satResponse Success -> throwSMT' $ "Unexpected SMT response to CheckSat: " <> show satResponse Sat -> Right <$> extractModel transState @@ -379,21 +392,26 @@ checkPredicates ctxt givenPs givenSubst psToCheck pure Nothing (Sat, Unsat) -> pure . Just $ True (Unsat, Sat) -> pure . Just $ False - (Unknown, _) -> retry smtGiven sexprsToCheck transState - (_, Unknown) -> retry smtGiven sexprsToCheck transState + (Unknown reason, _) -> retry smtGiven sexprsToCheck transState reason + (_, Unknown reason) -> retry smtGiven sexprsToCheck transState reason other -> throwE . GeneralSMTError $ ("Unexpected result while checking a condition: " :: Text) <> Text.pack (show other) - retry :: [DeclareCommand] -> [SExpr] -> TranslationState -> ExceptT SMTError (SMT io) (Maybe Bool) - retry smtGiven sexprsToCheck transState = do + retry :: + [DeclareCommand] -> + [SExpr] -> + TranslationState -> + Maybe Text -> + ExceptT SMTError (SMT io) (Maybe Bool) + retry smtGiven sexprsToCheck transState reasonUnknown = do opts <- lift . SMT $ gets (.options) case opts.retryLimit of Just x | x > 0 -> do let newOpts = opts{timeout = 2 * opts.timeout, retryLimit = Just $ x - 1} lift $ hardResetSolver newOpts solve smtGiven sexprsToCheck transState - _ -> failBecauseUnknown + _ -> failBecauseUnknown reasonUnknown translated :: Either Text (([DeclareCommand], [SExpr]), TranslationState) translated = SMT.runTranslator $ do @@ -407,18 +425,12 @@ checkPredicates ctxt givenPs givenSubst psToCheck mapM (SMT.translateTerm . coerce) $ Set.toList psToCheck pure (smtSubst <> smtPs, toCheck) - failBecauseUnknown :: ExceptT SMTError (SMT io) (Maybe Bool) - failBecauseUnknown = - smtRun GetReasonUnknown >>= \case - ReasonUnknown reason -> do - Log.withContext Log.CtxAbort $ - Log.logMessage $ - "Returned Unknown. Reason: " <> reason - throwE $ SMTSolverUnknown reason givenPs psToCheck - other -> do - let msg = "Unexpected result while calling ':reason-unknown': " <> show other - Log.withContext Log.CtxAbort $ Log.logMessage $ Text.pack msg - throwSMT' msg + failBecauseUnknown :: Maybe Text -> ExceptT SMTError (SMT io) (Maybe Bool) + failBecauseUnknown reason = do + Log.withContext Log.CtxAbort $ + Log.logMessage $ + "Returned Unknown. Reason: " <> fromMaybe "UNKNOWN" reason + throwE $ SMTSolverUnknown reason givenPs psToCheck -- Given the known truth and the expressions to check, -- interact with the solver to establish the validity of the expressions. @@ -439,7 +451,7 @@ checkPredicates ctxt givenPs givenSubst psToCheck Unsat -> do Log.logMessage ("Inconsistent ground truth" :: Text) pure (Unsat, Unsat) - Unknown -> do + Unknown reason -> do Log.getPrettyModifiers >>= \case ModifiersRep (_ :: FromModifiersT mods => Proxy mods) -> do Log.withContext Log.CtxDetail @@ -449,7 +461,7 @@ checkPredicates ctxt givenPs givenSubst psToCheck $ Pretty.renderOneLineText $ "Unknown ground truth: " <+> (hsep . punctuate (slash <> backslash) . map (pretty' @mods) . Set.toList $ givenPs) - pure (Unknown, Unknown) + pure (Unknown reason, Unknown reason) _ -> do -- save ground truth for 2nd check smtRun_ Push diff --git a/booster/library/Booster/SMT/LowLevelCodec.hs b/booster/library/Booster/SMT/LowLevelCodec.hs index 3ed849c6f4..2c5357b6ce 100644 --- a/booster/library/Booster/SMT/LowLevelCodec.hs +++ b/booster/library/Booster/SMT/LowLevelCodec.hs @@ -36,13 +36,13 @@ responseP = A.string "success" $> Success -- UNUSED? <|> A.string "sat" $> Sat <|> A.string "unsat" $> Unsat - <|> A.string "unknown" $> Unknown + <|> A.string "unknown" $> Unknown Nothing <|> A.char '(' *> errOrValuesOrReasonUnknownP <* A.char ')' errOrValuesOrReasonUnknownP :: A.Parser Response errOrValuesOrReasonUnknownP = A.string "error " *> (Error <$> stringP) - <|> A.string ":reason-unknown " *> (ReasonUnknown . decodeUtf8 <$> stringP) + <|> A.string ":reason-unknown " *> (Unknown . Just . decodeUtf8 <$> stringP) <|> Values <$> A.many1' pairP stringP :: A.Parser BS.ByteString diff --git a/booster/library/Booster/SMT/Runner.hs b/booster/library/Booster/SMT/Runner.hs index 55155c76ed..e67880f487 100644 --- a/booster/library/Booster/SMT/Runner.hs +++ b/booster/library/Booster/SMT/Runner.hs @@ -75,7 +75,7 @@ data SMTContext = SMTContext { options :: SMTOptions , -- use IORef here to ensure we only ever retain one pointer to the solver, -- otherwise the solverClose action does not actually terminate the solver instance - solver :: IORef Backend.Solver + mbSolver :: Maybe (IORef Backend.Solver) , solverClose :: IORef (IO ()) , mbTranscriptHandle :: Maybe Handle , prelude :: [DeclareCommand] @@ -111,7 +111,7 @@ mkContext opts prelude = do liftIO $ BS.hPutStrLn h "; solver initialised\n;;;;;;;;;;;;;;;;;;;;;;;" pure SMTContext - { solver + { mbSolver = Just solver , solverClose , mbTranscriptHandle , prelude @@ -180,22 +180,27 @@ runCmd :: forall cmd io. (SMTEncode cmd, LoggerMIO io) => cmd -> SMT io Response runCmd cmd = do let cmdBS = encode cmd ctxt <- SMT get - whenJust ctxt.mbTranscriptHandle $ \h -> do - whenJust (comment cmd) $ \c -> - liftIO (BS.hPutBuilder h c) - liftIO (BS.hPutBuilder h $ cmdBS <> "\n") - output <- (liftIO $ readIORef ctxt.solver) >>= \solver -> run_ cmd solver cmdBS - let result = readResponse output - whenJust ctxt.mbTranscriptHandle $ - liftIO . flip BS.hPutStrLn (BS.pack $ "; " <> show output <> ", parsed as " <> show result <> "\n") - when (isError result) $ - logMessage $ - "SMT solver reports: " <> pack (show result) - pure result - where - isError :: Response -> Bool - isError Error{} = True - isError _other = False + case ctxt.mbSolver of + Nothing -> pure $ Unknown (Just "server started without SMT solver") + Just solverRef -> do + whenJust ctxt.mbTranscriptHandle $ \h -> do + whenJust (comment cmd) $ \c -> + liftIO (BS.hPutBuilder h c) + liftIO (BS.hPutBuilder h $ cmdBS <> "\n") + output <- (liftIO $ readIORef solverRef) >>= \solver -> run_ cmd solver cmdBS + let result = readResponse output + whenJust ctxt.mbTranscriptHandle $ + liftIO . flip BS.hPutStrLn (BS.pack $ "; " <> show output <> ", parsed as " <> show result <> "\n") + case result of + Error{} -> do + logMessage $ + "SMT solver reports: " <> pack (show result) + pure result + Unknown Nothing -> + runCmd GetReasonUnknown >>= \case + unknownWithReason@(Unknown (Just _)) -> pure unknownWithReason + _ -> pure result + _ -> pure result instance SMTEncode DeclareCommand where encode = encodeDeclaration diff --git a/booster/unit-tests/Test/Booster/Pattern/ApplyEquations.hs b/booster/unit-tests/Test/Booster/Pattern/ApplyEquations.hs index cc51b3faa6..0aac1a2696 100644 --- a/booster/unit-tests/Test/Booster/Pattern/ApplyEquations.hs +++ b/booster/unit-tests/Test/Booster/Pattern/ApplyEquations.hs @@ -19,7 +19,6 @@ import Control.Monad.Logger (runNoLoggingT) import Data.Map (Map) import Data.Map qualified as Map import Data.Text (Text) -import GHC.IO.Unsafe (unsafePerformIO) import Test.Tasty import Test.Tasty.HUnit @@ -30,9 +29,11 @@ import Booster.Pattern.Base import Booster.Pattern.Bool import Booster.Pattern.Index (CellIndex (..), TermIndex (..)) import Booster.Pattern.Util (sortOfTerm) +import Booster.SMT.Interface (noSolver) import Booster.Syntax.Json.Internalise (trm) import Booster.Util (Flag (..)) import Test.Booster.Fixture hiding (inj) +import Test.Booster.Util ((@?>>=)) inj :: Symbol inj = injectionSymbol @@ -43,21 +44,21 @@ test_evaluateFunction = "Evaluating functions using rules without side conditions" [ -- f1(a) => a testCase "Simple function evaluation" $ do - eval TopDown [trm| f1{}(con2{}(A:SomeSort{})) |] @?= Right [trm| con2{}(A:SomeSort{}) |] - eval BottomUp [trm| f1{}(con2{}(A:SomeSort{})) |] @?= Right [trm| con2{}(A:SomeSort{}) |] + eval TopDown [trm| f1{}(con2{}(A:SomeSort{})) |] @?>>= Right [trm| con2{}(A:SomeSort{}) |] + eval BottomUp [trm| f1{}(con2{}(A:SomeSort{})) |] @?>>= Right [trm| con2{}(A:SomeSort{}) |] , -- f2(f1(f1(con2(a)))) => f2(con2(a)). f2 is marked as partial, so not evaluating testCase "Nested function applications, one not to be evaluated" $ do let subj = [trm| f2{}(f1{}(f1{}(con2{}(A:SomeSort{})))) |] goal = [trm| f2{}(con2{}(A:SomeSort{})) |] - eval TopDown subj @?= Right goal - eval BottomUp subj @?= Right goal + eval TopDown subj @?>>= Right goal + eval BottomUp subj @?>>= Right goal , -- f1(f2(f1(con2(a)))) => f1(f2(con2(a))). Again f2 partial, so not evaluating, -- therefore f1(x) => x not applied to unevaluated value testCase "Nested function applications with partial function inside" $ do let subj = [trm| f1{}(f2{}(f1{}(con2{}(A:SomeSort{})))) |] goal = [trm| f1{}(f2{}(con2{}(A:SomeSort{}))) |] - eval TopDown subj @?= Right goal - eval BottomUp subj @?= Right goal + eval TopDown subj @?>>= Right goal + eval BottomUp subj @?>>= Right goal , -- f1(con1(con1(..con1(con2(a))..))) => con2(con2(..con2(a)..)) -- using f1(con1(X)) => con2(X) repeatedly testCase "Recursive evaluation" $ do @@ -66,39 +67,37 @@ test_evaluateFunction = apply f = app f . (: []) n `times` f = foldr (.) id (replicate n $ apply f) -- top-down evaluation: a single iteration is enough - eval TopDown (subj 101) @?= Right (101 `times` con2 $ a) + eval TopDown (subj 101) @?>>= Right (101 `times` con2 $ a) -- bottom-up evaluation: `depth` many iterations - eval BottomUp (subj 100) @?= Right (100 `times` con2 $ a) - isTooManyIterations $ eval BottomUp (subj 101) + eval BottomUp (subj 100) @?>>= Right (100 `times` con2 $ a) + isTooManyIterations =<< eval BottomUp (subj 101) , -- con3(f1(con2(a)), f1(con1(con2(b)))) => con3(con2(a), con2(con2(b))) testCase "Several function calls inside a constructor" $ do eval TopDown [trm| con3{}(f1{}(con2{}(A:SomeSort{})), f1{}(con1{}(con2{}(B:SomeSort{})))) |] - @?= Right [trm| con3{}(con2{}(A:SomeSort{}), con2{}(con2{}(B:SomeSort{}))) |] + @?>>= Right [trm| con3{}(con2{}(A:SomeSort{}), con2{}(con2{}(B:SomeSort{}))) |] , -- f1(inj{sub,some}(con4(a, b))) => f1(a) => a (not using f1-is-identity) testCase "Matching uses priorities" $ do eval TopDown [trm| f1{}(inj{AnotherSort{}, SomeSort{}}(con4{}(A:SomeSort{}, B:SomeSort{}))) |] - @?= Right [trm| A:SomeSort{} |] + @?>>= Right [trm| A:SomeSort{} |] , -- f1(con1("hey")) unmodified, since "hey" is concrete testCase "f1 with concrete argument, constraints prevent rule application" $ do let subj = [trm| f1{}(con1{}( \dv{SomeSort{}}("hey")) ) |] - eval TopDown subj @?= Right subj - eval BottomUp subj @?= Right subj + eval TopDown subj @?>>= Right subj + eval BottomUp subj @?>>= Right subj , testCase "f2 with symbolic argument, constraint prevents rule application" $ do let subj = [trm| f2{}(con1{}(A:SomeSort{})) |] - eval TopDown subj @?= Right subj - eval BottomUp subj @?= Right subj + eval TopDown subj @?>>= Right subj + eval BottomUp subj @?>>= Right subj , testCase "f2 with concrete argument, satisfying constraint" $ do let subj = [trm| f2{}(con1{}(\dv{SomeSort{}}("hey"))) |] result = [trm| f2{}(\dv{SomeSort{}}("hey")) |] - eval TopDown subj @?= Right result - eval BottomUp subj @?= Right result + eval TopDown subj @?>>= Right result + eval BottomUp subj @?>>= Right result ] where - eval direction = - unsafePerformIO - . runNoLoggingT - . (fst <$>) - . evaluateTerm direction funDef Nothing Nothing mempty + eval direction t = do + ns <- noSolver + runNoLoggingT $ fst <$> evaluateTerm direction funDef Nothing ns mempty t isTooManyIterations (Left (TooManyIterations _n _ _)) = pure () isTooManyIterations (Left err) = assertFailure $ "Unexpected error " <> show err @@ -110,26 +109,24 @@ test_simplify = "Performing simplifications" [ testCase "No simplification applies" $ do let subj = [trm| f1{}(f2{}(A:SomeSort{})) |] - simpl TopDown subj @?= Right subj - simpl BottomUp subj @?= Right subj + simpl TopDown subj @?>>= Right subj + simpl BottomUp subj @?>>= Right subj , -- con1(con2(f2(a))) => con2(f2(a)) testCase "Simplification of constructors" $ do let subj = app con1 [app con2 [app f2 [a]]] - simpl TopDown subj @?= Right (app con2 [app f2 [a]]) - simpl BottomUp subj @?= Right (app con2 [app f2 [a]]) + simpl TopDown subj @?>>= Right (app con2 [app f2 [a]]) + simpl BottomUp subj @?>>= Right (app con2 [app f2 [a]]) , -- con3(f2(a), f2(a)) => inj{sub,some}(con4(f2(a), f2(a))) testCase "Simplification with argument match" $ do let subj = [trm| con3{}(f2{}(A:SomeSort{}), f2{}(A:SomeSort{})) |] result = [trm| inj{AnotherSort{}, SomeSort{}}(con4{}(f2{}(A:SomeSort{}), f2{}(A:SomeSort{}))) |] - simpl TopDown subj @?= Right result - simpl BottomUp subj @?= Right result + simpl TopDown subj @?>>= Right result + simpl BottomUp subj @?>>= Right result ] where - simpl direction = - unsafePerformIO - . runNoLoggingT - . (fst <$>) - . evaluateTerm direction simplDef Nothing Nothing mempty + simpl direction t = do + ns <- noSolver + runNoLoggingT $ fst <$> evaluateTerm direction simplDef Nothing ns mempty t a = var "A" someSort test_simplifyPattern :: TestTree @@ -138,28 +135,26 @@ test_simplifyPattern = "Performing Pattern simplifications" [ testCase "No simplification applies" $ do let subj = [trm| f1{}(f2{}(A:SomeSort{})) |] - simpl (Pattern_ subj) @?= Right (Pattern_ subj) - simpl (Pattern_ subj) @?= Right (Pattern_ subj) + simpl (Pattern_ subj) @?>>= Right (Pattern_ subj) + simpl (Pattern_ subj) @?>>= Right (Pattern_ subj) , -- con1(con2(f2(a))) => con2(f2(a)) testCase "Simplification of constructors" $ do let subj = app con1 [app con2 [app f2 [a]]] simpl (Pattern_ subj) - @?= Right (Pattern_ $ app con2 [app f2 [a]]) + @?>>= Right (Pattern_ $ app con2 [app f2 [a]]) simpl (Pattern_ subj) - @?= Right (Pattern_ $ app con2 [app f2 [a]]) + @?>>= Right (Pattern_ $ app con2 [app f2 [a]]) , -- con3(f2(a), f2(a)) => inj{sub,some}(con4(f2(a), f2(a))) testCase "Simplification with argument match" $ do let subj = Pattern_ [trm| con3{}(f2{}(A:SomeSort{}), f2{}(A:SomeSort{})) |] result = Pattern_ [trm| inj{AnotherSort{}, SomeSort{}}(con4{}(f2{}(A:SomeSort{}), f2{}(A:SomeSort{}))) |] - simpl subj @?= Right result + simpl subj @?>>= Right result ] where - simpl = - unsafePerformIO - . runNoLoggingT - . (fst <$>) - . evaluatePattern simplDef Nothing Nothing mempty + simpl t = do + ns <- noSolver + runNoLoggingT $ fst <$> evaluatePattern simplDef Nothing ns mempty t a = var "A" someSort test_simplifyConstraint :: TestTree @@ -218,18 +213,17 @@ test_simplifyConstraint = [ testCase name $ let subj = EqualsK (KSeq (sortOfTerm lhs) lhs) (KSeq (sortOfTerm rhs) rhs) - in simpl (Predicate subj) @?= Right (Predicate (exp1 subj)) + in simpl (Predicate subj) @?>>= Right (Predicate (exp1 subj)) , testCase (name <> " (flipped)") $ let subj = EqualsK (KSeq (sortOfTerm rhs) rhs) (KSeq (sortOfTerm lhs) lhs) - in simpl (Predicate subj) @?= Right (Predicate (exp2 subj)) + in simpl (Predicate subj) @?>>= Right (Predicate (exp2 subj)) ] - simpl = - unsafePerformIO - . runNoLoggingT - . (fst <$>) - . simplifyConstraint testDefinition Nothing Nothing mempty mempty + simpl t = + do + ns <- noSolver + runNoLoggingT $ fst <$> simplifyConstraint testDefinition Nothing ns mempty mempty t test_errors :: TestTree test_errors = @@ -241,8 +235,8 @@ test_errors = subj = f $ app con1 [a] loopTerms = [f $ app con1 [a], f $ app con2 [a], f $ app con3 [a, a], f $ app con1 [a]] - isLoop loopTerms . unsafePerformIO . runNoLoggingT $ - fst <$> evaluateTerm TopDown loopDef Nothing Nothing mempty subj + ns <- noSolver + isLoop loopTerms =<< (runNoLoggingT $ fst <$> evaluateTerm TopDown loopDef Nothing ns mempty subj) ] where isLoop ts (Left (EquationLoop ts')) = ts @?= ts' diff --git a/booster/unit-tests/Test/Booster/Pattern/Rewrite.hs b/booster/unit-tests/Test/Booster/Pattern/Rewrite.hs index f5a4c1aa37..5c30f0abec 100644 --- a/booster/unit-tests/Test/Booster/Pattern/Rewrite.hs +++ b/booster/unit-tests/Test/Booster/Pattern/Rewrite.hs @@ -15,7 +15,6 @@ import Data.List.NonEmpty qualified as NE import Data.Map (Map) import Data.Map qualified as Map import Data.Text (Text) -import GHC.IO.Unsafe (unsafePerformIO) import Numeric.Natural import Test.Tasty import Test.Tasty.HUnit @@ -25,10 +24,12 @@ import Booster.Definition.Base import Booster.Pattern.Base import Booster.Pattern.Index (CellIndex (..), TermIndex (..)) import Booster.Pattern.Rewrite +import Booster.SMT.Interface (noSolver) import Booster.Syntax.Json.Internalise (trm) import Booster.Syntax.ParsedKore.Internalise (symb) import Booster.Util (Flag (..)) import Test.Booster.Fixture hiding (inj) +import Test.Booster.Util ((@?>>=)) test_rewriteStep :: TestTree test_rewriteStep = @@ -240,39 +241,39 @@ rulePriority = ) ] -runWith :: Term -> Either (RewriteFailed "Rewrite") (RewriteResult Pattern) +runWith :: Term -> IO (Either (RewriteFailed "Rewrite") (RewriteResult Pattern)) runWith t = - second fst $ - unsafePerformIO - ( runNoLoggingT $ - runRewriteT NoCollectRewriteTraces def Nothing Nothing mempty (rewriteStep [] [] $ Pattern_ t) - ) + second fst <$> do + ns <- noSolver + runNoLoggingT $ + runRewriteT NoCollectRewriteTraces def Nothing ns mempty (rewriteStep [] [] $ Pattern_ t) rewritesTo :: Term -> (Text, Term) -> IO () t1 `rewritesTo` (lbl, t2) = - runWith t1 @?= Right (RewriteFinished (Just lbl) (Just mockUniqueId) $ Pattern_ t2) + runWith t1 @?>>= Right (RewriteFinished (Just lbl) (Just mockUniqueId) $ Pattern_ t2) getsStuck :: Term -> IO () getsStuck t1 = - runWith t1 @?= Right (RewriteStuck $ Pattern_ t1) + runWith t1 @?>>= Right (RewriteStuck $ Pattern_ t1) branchesTo :: Term -> [(Text, Term)] -> IO () t `branchesTo` ts = runWith t - @?= Right + @?>>= Right (RewriteBranch (Pattern_ t) $ NE.fromList $ map (\(lbl, t') -> (lbl, mockUniqueId, Pattern_ t')) ts) failsWith :: Term -> RewriteFailed "Rewrite" -> IO () failsWith t err = - runWith t @?= Left err + runWith t @?>>= Left err ---------------------------------------- -- tests for performRewrite (iterated rewrite in IO with logging) runRewrite :: Term -> IO (Natural, RewriteResult Term) runRewrite t = do + ns <- noSolver (counter, _, res) <- - runNoLoggingT $ performRewrite NoCollectRewriteTraces def Nothing Nothing Nothing [] [] $ Pattern_ t + runNoLoggingT $ performRewrite NoCollectRewriteTraces def Nothing ns Nothing [] [] $ Pattern_ t pure (counter, fmap (.term) res) aborts :: RewriteFailed "Rewrite" -> Term -> IO () @@ -413,9 +414,10 @@ supportsDepthControl = where rewritesToDepth :: MaxDepth -> Steps -> Term -> t -> (t -> RewriteResult Term) -> IO () rewritesToDepth (MaxDepth depth) (Steps n) t t' f = do + ns <- noSolver (counter, _, res) <- runNoLoggingT $ - performRewrite NoCollectRewriteTraces def Nothing Nothing (Just depth) [] [] $ + performRewrite NoCollectRewriteTraces def Nothing ns (Just depth) [] [] $ Pattern_ t (counter, fmap (.term) res) @?= (n, f t') @@ -467,9 +469,10 @@ supportsCutPoints = where rewritesToCutPoint :: Text -> Steps -> Term -> t -> (t -> RewriteResult Term) -> IO () rewritesToCutPoint lbl (Steps n) t t' f = do + ns <- noSolver (counter, _, res) <- runNoLoggingT $ - performRewrite NoCollectRewriteTraces def Nothing Nothing Nothing [lbl] [] $ + performRewrite NoCollectRewriteTraces def Nothing ns Nothing [lbl] [] $ Pattern_ t (counter, fmap (.term) res) @?= (n, f t') @@ -501,5 +504,6 @@ supportsTerminalRules = rewritesToTerminal lbl (Steps n) t t' f = do (counter, _, res) <- runNoLoggingT $ do - performRewrite NoCollectRewriteTraces def Nothing Nothing Nothing [] [lbl] $ Pattern_ t + ns <- noSolver + performRewrite NoCollectRewriteTraces def Nothing ns Nothing [] [lbl] $ Pattern_ t (counter, fmap (.term) res) @?= (n, f t') diff --git a/booster/unit-tests/Test/Booster/SMT/LowLevel.hs b/booster/unit-tests/Test/Booster/SMT/LowLevel.hs index c83ac0ce15..21ece3ee36 100644 --- a/booster/unit-tests/Test/Booster/SMT/LowLevel.hs +++ b/booster/unit-tests/Test/Booster/SMT/LowLevel.hs @@ -92,7 +92,7 @@ responseParsing = "Response parsing" [ "sat" `parsesTo` Sat , "unsat" `parsesTo` Unsat - , "unknown" `parsesTo` Unknown + , "unknown" `parsesTo` Unknown Nothing , "success" `parsesTo` Success , "(error \"Something was wrong\")" `parsesTo` Error "Something was wrong" , "((x 0))" `parsesTo` Values [(Atom "x", SMT.Int 0)] diff --git a/booster/unit-tests/Test/Booster/Util.hs b/booster/unit-tests/Test/Booster/Util.hs index d6b53abb11..7bfa0f7e3a 100644 --- a/booster/unit-tests/Test/Booster/Util.hs +++ b/booster/unit-tests/Test/Booster/Util.hs @@ -6,11 +6,14 @@ module Test.Booster.Util ( gitDiff, testGoldenVsString, testGoldenVsFile, + (@?>>=), ) where import Data.ByteString.Lazy (ByteString) +import GHC.Stack (HasCallStack) import Test.Tasty import Test.Tasty.Golden +import Test.Tasty.HUnit (Assertion, (@?=)) gitDiff :: FilePath -> FilePath -> [String] gitDiff ref new = @@ -21,3 +24,8 @@ testGoldenVsString name = goldenVsStringDiff name gitDiff testGoldenVsFile :: TestName -> FilePath -> FilePath -> IO () -> TestTree testGoldenVsFile name = goldenVsFileDiff name gitDiff + +infix 9 @?>>= + +(@?>>=) :: (Eq a, Show a, HasCallStack) => IO a -> a -> Assertion +ma @?>>= a' = ma >>= \a -> a @?= a'