diff --git a/deps/k_release b/deps/k_release index 2a0698a53..3bcc06003 100644 --- a/deps/k_release +++ b/deps/k_release @@ -1 +1 @@ -6.1.31 +6.1.52 diff --git a/package/version b/package/version index ac7d5f0eb..466fedc56 100644 --- a/package/version +++ b/package/version @@ -1 +1 @@ -0.1.519 +0.1.543 diff --git a/poetry.lock b/poetry.lock index 2f5c23971..48b8b6e3b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "attrs" @@ -919,6 +919,17 @@ files = [ {file = "wcwidth-0.2.8.tar.gz", hash = "sha256:8705c569999ffbb4f6a87c6d1b80f324bd6db952f5eb0b95bc07517f4c1813d4"}, ] +[[package]] +name = "xdg-base-dirs" +version = "6.0.1" +description = "Variables defined by the XDG Base Directory Specification" +optional = false +python-versions = ">=3.10,<4.0" +files = [ + {file = "xdg_base_dirs-6.0.1-py3-none-any.whl", hash = "sha256:63f6ebc1721ced2e86c340856e004ef829501a30a37e17079c52cfaf0e1741b9"}, + {file = "xdg_base_dirs-6.0.1.tar.gz", hash = "sha256:b4c8f4ba72d1286018b25eea374ec6fbf4fddda3d4137edf50de95de53e195a6"}, +] + [[package]] name = "zipp" version = "3.17.0" @@ -937,4 +948,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "95ca457bf41fb60498049c04899bfb9a0d12fa3635ef1aa25c67929033140e27" +content-hash = "420202e89d34a391c914cdf05d81f67fb27867ebe5765c24bb46b31760d8ec28" diff --git a/pyproject.toml b/pyproject.toml index 21b3c8e4b..6ef193834 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "pyk" -version = "0.1.519" +version = "0.1.543" description = "" authors = [ "Runtime Verification, Inc. ", @@ -20,6 +20,7 @@ psutil = "5.9.5" pybind11 = "^2.10.3" textual = "^0.27.0" tomli = "^2.0.1" +xdg-base-dirs = "^6.0.1" [tool.poetry.group.dev.dependencies] autoflake = "*" @@ -43,6 +44,7 @@ types-psutil = "^5.9.5.10" pyk = "pyk.__main__:main" pyk-covr = "pyk.kcovr:main" kbuild = "pyk.kbuild.__main__:main" +kdist = "pyk.kdist.__main__:main" krepl = "pyk.krepl.__main__:main" kore-exec-covr = "pyk.kore_exec_covr.__main__:main" diff --git a/src/pyk/__init__.py b/src/pyk/__init__.py index 2ba1663b3..88ab000d4 100644 --- a/src/pyk/__init__.py +++ b/src/pyk/__init__.py @@ -6,4 +6,4 @@ from typing import Final -K_VERSION: Final = '6.1.31' +K_VERSION: Final = '6.1.52' diff --git a/src/pyk/kcfg/explore.py b/src/pyk/kcfg/explore.py index 8d4154e7d..ed9e6e3bc 100644 --- a/src/pyk/kcfg/explore.py +++ b/src/pyk/kcfg/explore.py @@ -3,7 +3,7 @@ import logging from abc import ABC from dataclasses import dataclass, field -from typing import TYPE_CHECKING, final +from typing import TYPE_CHECKING, NamedTuple, final from ..cterm import CSubst, CTerm from ..kast.inner import KApply, KLabel, KRewrite, KVariable, Subst @@ -47,6 +47,15 @@ _LOGGER: Final = logging.getLogger(__name__) +class CTermExecute(NamedTuple): + state: CTerm + unknown_predicate: KInner | None + next_states: tuple[CTerm, ...] + depth: int + vacuous: bool + logs: tuple[LogEntry, ...] 
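A minimal sketch of how a call site consumes the new CTermExecute result, whose named fields replace the old positional 6-tuple returned by cterm_execute (illustrative only; kcfg_explore and init_cterm are assumed to be an existing KCFGExplore instance and a CTerm — neither is defined in this patch):

    # Hypothetical call site; field names follow the CTermExecute NamedTuple above.
    res = kcfg_explore.cterm_execute(init_cterm, depth=100)
    if res.vacuous:
        print('state is vacuous')                           # execution stopped with StopReason.VACUOUS
    elif res.unknown_predicate is not None:
        print(f'undecided predicate: {res.unknown_predicate}')
    elif res.next_states:
        print(f'branched into {len(res.next_states)} states after {res.depth} steps')
    else:
        print(f'single successor after {res.depth} steps: {res.state}')
    logs = res.logs                                         # tuple[LogEntry, ...], e.g. for bug reports
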
+ + class KCFGExplore: kprint: KPrint _kore_client: KoreClient @@ -77,40 +86,42 @@ def cterm_execute( cut_point_rules: Iterable[str] | None = None, terminal_rules: Iterable[str] | None = None, module_name: str | None = None, - ) -> tuple[KInner | None, bool, int, CTerm, list[CTerm], tuple[LogEntry, ...]]: + ) -> CTermExecute: _LOGGER.debug(f'Executing: {cterm}') kore = self.kprint.kast_to_kore(cterm.kast, GENERATED_TOP_CELL) - er = self._kore_client.execute( + response = self._kore_client.execute( kore, max_depth=depth, cut_point_rules=cut_point_rules, terminal_rules=terminal_rules, module_name=module_name, - log_successful_rewrites=self._trace_rewrites if self._trace_rewrites else None, - log_failed_rewrites=self._trace_rewrites if self._trace_rewrites else None, - log_successful_simplifications=self._trace_rewrites if self._trace_rewrites else None, - log_failed_simplifications=self._trace_rewrites if self._trace_rewrites else None, + log_successful_rewrites=self._trace_rewrites, + log_failed_rewrites=self._trace_rewrites, + log_successful_simplifications=self._trace_rewrites, + log_failed_simplifications=self._trace_rewrites, ) - _is_vacuous = er.reason is StopReason.VACUOUS - depth = er.depth - next_state = CTerm.from_kast(self.kprint.kore_to_kast(er.state.kore)) - _next_states = er.next_states if er.next_states is not None else [] - next_states = [CTerm.from_kast(self.kprint.kore_to_kast(ns.kore)) for ns in _next_states] - next_states = [cterm for cterm in next_states if not cterm.is_bottom] - if len(next_states) == 1 and len(next_states) < len(_next_states): - return None, _is_vacuous, depth + 1, next_states[0], [], er.logs - elif len(next_states) == 1: - if er.reason == StopReason.CUT_POINT_RULE: - return None, _is_vacuous, depth, next_state, next_states, er.logs - else: - next_states = [] unknown_predicate = None - if isinstance(er, AbortedResult): + if isinstance(response, AbortedResult): unknown_predicate = ( - self.kprint.kore_to_kast(er.unknown_predicate) if er.unknown_predicate is not None else None + self.kprint.kore_to_kast(response.unknown_predicate) if response.unknown_predicate is not None else None ) - return unknown_predicate, _is_vacuous, depth, next_state, next_states, er.logs + + state = CTerm.from_kast(self.kprint.kore_to_kast(response.state.kore)) + resp_next_states = response.next_states or () + next_states = tuple(CTerm.from_kast(self.kprint.kore_to_kast(ns.kore)) for ns in resp_next_states) + + assert all(not cterm.is_bottom for cterm in next_states) + assert len(next_states) != 1 or response.reason is StopReason.CUT_POINT_RULE + + return CTermExecute( + state=state, + unknown_predicate=unknown_predicate, + next_states=next_states, + depth=response.depth, + vacuous=response.reason is StopReason.VACUOUS, + logs=response.logs, + ) def cterm_simplify(self, cterm: CTerm) -> tuple[KInner | None, CTerm, tuple[LogEntry, ...]]: _LOGGER.debug(f'Simplifying: {cterm}') @@ -319,16 +330,14 @@ def step( if len(successors) != 0 and type(successors[0]) is KCFG.Split: raise ValueError(f'Cannot take step from split node {self.id}: {shorten_hashes(node.id)}') _LOGGER.info(f'Taking {depth} steps from node {self.id}: {shorten_hashes(node.id)}') - _, _, actual_depth, cterm, next_cterms, next_node_logs = self.cterm_execute( - node.cterm, depth=depth, module_name=module_name - ) - if actual_depth != depth: - raise ValueError(f'Unable to take {depth} steps from node, got {actual_depth} steps {self.id}: {node.id}') - if len(next_cterms) > 0: + exec_res = self.cterm_execute(node.cterm, 
depth=depth, module_name=module_name) + if exec_res.depth != depth: + raise ValueError(f'Unable to take {depth} steps from node, got {exec_res.depth} steps {self.id}: {node.id}') + if len(exec_res.next_states) > 0: raise ValueError(f'Found branch within {depth} steps {self.id}: {node.id}') - new_node = cfg.create_node(cterm) + new_node = cfg.create_node(exec_res.state) _LOGGER.info(f'Found new node at depth {depth} {self.id}: {shorten_hashes((node.id, new_node.id))}') - logs[new_node.id] = next_node_logs + logs[new_node.id] = exec_res.logs out_edges = cfg.edges(source_id=node.id) if len(out_edges) == 0: cfg.create_edge(node.id, new_node.id, depth=depth) @@ -424,7 +433,7 @@ def extend_cterm( if len(branches) > 1: return Branch(branches, heuristic=True) - unknown_predicate, _is_vacuous, depth, cterm, next_cterms, next_node_logs = self.cterm_execute( + cterm, unknown_predicate, next_cterms, depth, vacuous, next_node_logs = self.cterm_execute( _cterm, depth=execute_depth, cut_point_rules=cut_point_rules, @@ -438,7 +447,7 @@ def extend_cterm( # Stuck, Vacuous or Undecided if not next_cterms: - if _is_vacuous: + if vacuous: return Vacuous() if unknown_predicate is not None: return Undecided(unknown_predicate) diff --git a/src/pyk/kdist/__init__.py b/src/pyk/kdist/__init__.py new file mode 100644 index 000000000..051d88777 --- /dev/null +++ b/src/pyk/kdist/__init__.py @@ -0,0 +1,2 @@ +from ._cache import target_ids +from ._kdist import KDIST_DIR, KDist, kdist diff --git a/src/pyk/kdist/__main__.py b/src/pyk/kdist/__main__.py new file mode 100644 index 000000000..1a0624310 --- /dev/null +++ b/src/pyk/kdist/__main__.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import fnmatch +import logging +from argparse import ArgumentParser +from typing import TYPE_CHECKING + +from pyk.cli.args import KCLIArgs +from pyk.cli.utils import loglevel + +from ..kdist import kdist, target_ids + +if TYPE_CHECKING: + from argparse import Namespace + from typing import Final + + +_LOGGER: Final = logging.getLogger(__name__) +_LOG_FORMAT: Final = '%(levelname)s %(asctime)s %(name)s - %(message)s' + + +def main() -> None: + args = _parse_arguments() + + logging.basicConfig(level=loglevel(args), format=_LOG_FORMAT) + + if args.command == 'build': + _exec_build(**vars(args)) + + elif args.command == 'clean': + _exec_clean(args.target) + + elif args.command == 'which': + _exec_which(args.target) + + elif args.command == 'list': + _exec_list() + + else: + raise AssertionError() + + +def _exec_build( + command: str, + targets: list[str], + args: list[str], + jobs: int, + force: bool, + verbose: bool, + debug: bool, +) -> None: + kdist.build( + target_ids=_process_targets(targets), + args=_process_args(args), + jobs=jobs, + force=force, + verbose=verbose or debug, + ) + + +def _process_targets(targets: list[str]) -> list[str]: + all_target_fqns = [target_id.full_name for target_id in target_ids()] + res = [] + for pattern in targets: + matches = fnmatch.filter(all_target_fqns, pattern) + if not matches: + raise ValueError(f'No target matches pattern: {pattern!r}') + res += matches + return res + + +def _process_args(args: list[str]) -> dict[str, str]: + res: dict[str, str] = {} + for arg in args: + segments = arg.split('=') + if len(segments) < 2: + raise ValueError(f"Expected assignment of the form 'arg=value', got: {arg!r}") + key, *values = segments + res[key] = '='.join(values) + return res + + +def _exec_clean(target: str | None) -> None: + res = kdist.clean(target) + print(res) + + +def _exec_which(target: 
str | None) -> None: + res = kdist.which(target) + print(res) + + +def _exec_list() -> None: + targets_by_plugin: dict[str, list[str]] = {} + for plugin_name, target_name in target_ids(): + targets = targets_by_plugin.get(plugin_name, []) + targets.append(target_name) + targets_by_plugin[plugin_name] = targets + + for plugin_name in targets_by_plugin: + print(plugin_name) + for target_name in targets_by_plugin[plugin_name]: + print(f'* {target_name}') + + +def _parse_arguments() -> Namespace: + def add_target_arg(parser: ArgumentParser, help_text: str) -> None: + parser.add_argument( + 'target', + metavar='TARGET', + nargs='?', + help=help_text, + ) + + k_cli_args = KCLIArgs() + + parser = ArgumentParser(prog='kdist', parents=[k_cli_args.logging_args]) + command_parser = parser.add_subparsers(dest='command', required=True) + + build_parser = command_parser.add_parser('build', help='build targets') + build_parser.add_argument('targets', metavar='TARGET', nargs='*', default='*', help='target to build') + build_parser.add_argument( + '-a', + '--arg', + dest='args', + metavar='ARG', + action='append', + default=[], + help='build with argument', + ) + build_parser.add_argument('-f', '--force', action='store_true', default=False, help='force build') + build_parser.add_argument('-j', '--jobs', metavar='N', type=int, default=1, help='maximal number of build jobs') + + clean_parser = command_parser.add_parser('clean', help='clean targets') + add_target_arg(clean_parser, 'target to clean') + + which_parser = command_parser.add_parser('which', help='print target location') + add_target_arg(which_parser, 'target to print directory for') + + command_parser.add_parser('list', help='print list of available targets') + + return parser.parse_args() + + +if __name__ == '__main__': + main() diff --git a/src/pyk/kdist/_cache.py b/src/pyk/kdist/_cache.py new file mode 100644 index 000000000..f038d7b31 --- /dev/null +++ b/src/pyk/kdist/_cache.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import logging +import time +from collections.abc import Mapping +from typing import TYPE_CHECKING, NamedTuple + +from .api import Target, TargetId, valid_id + +if TYPE_CHECKING: + from collections.abc import Iterable + from types import ModuleType + from typing import Final + + +_LOGGER: Final = logging.getLogger(__name__) + + +class CachedTarget(NamedTuple): + id: TargetId + target: Target + + +class TargetCache: + _plugins: dict[str, dict[str, CachedTarget]] + + def __init__(self, plugins: Mapping[str, Mapping[str, CachedTarget]]): + _plugins: dict[str, dict[str, CachedTarget]] = {} + for plugin_name, targets in plugins.items(): + _targets: dict[str, CachedTarget] = {} + _plugins[plugin_name] = _targets + for target_name, target in targets.items(): + _targets[target_name] = target + self._plugins = _plugins + + def resolve(self, target_id: str | TargetId) -> CachedTarget: + if isinstance(target_id, str): + target_id = TargetId.parse(target_id) + + plugin_name, target_name = target_id + try: + targets = self._plugins[plugin_name] + except KeyError: + raise ValueError(f'Undefined plugin: {plugin_name}') from None + + try: + res = targets[target_name] + except KeyError: + raise ValueError(f'Plugin {plugin_name} does not define target: {target_name}') from None + + return res + + def resolve_deps(self, target_ids: Iterable[str | TargetId]) -> dict[TargetId, list[TargetId]]: + res: dict[TargetId, list[TargetId]] = {} + pending = [self.resolve(target_id) for target_id in target_ids] + while pending: + target = 
pending.pop() + if target.id in res: + continue + deps = [self.resolve(target_fqn) for target_fqn in target.target.deps()] + res[target.id] = [dep.id for dep in deps] + pending += deps + return res + + @property + def target_ids(self) -> list[TargetId]: + return [target.id for plugin_name, targets in self._plugins.items() for target_name, target in targets.items()] + + @staticmethod + def load() -> TargetCache: + return TargetCache(TargetCache._load_plugins()) + + @staticmethod + def _load_plugins() -> dict[str, dict[str, CachedTarget]]: + import importlib + from importlib.metadata import entry_points + + plugins = entry_points(group='kdist') + + res: dict[str, dict[str, CachedTarget]] = {} + for plugin in plugins: + plugin_name = plugin.name + + if not valid_id(plugin_name): + _LOGGER.warning(f'Invalid plugin name, skipping: {plugin_name}') + continue + + _LOGGER.info(f'Loading plugin: {plugin_name}') + module_name = plugin.value + try: + _LOGGER.info(f'Importing module: {module_name}') + module = importlib.import_module(module_name) + except Exception: + _LOGGER.error(f'Module {module_name} cannot be imported', exc_info=True) + continue + + res[plugin_name] = TargetCache._load_targets(plugin_name, module) + + return res + + @staticmethod + def _load_targets(plugin_name: str, module: ModuleType) -> dict[str, CachedTarget]: + if not hasattr(module, '__TARGETS__'): + _LOGGER.warning(f'Module does not define __TARGETS__: {module.__name__}') + return {} + + targets = module.__TARGETS__ + + if not isinstance(targets, Mapping): + _LOGGER.warning(f'Invalid __TARGETS__ attribute: {module.__name__}') + return {} + + res: dict[str, CachedTarget] = {} + for target_name, target in targets.items(): + if not isinstance(target_name, str): + _LOGGER.warning(f'Invalid target name in {module.__name__}: {target_name!r}') + continue + + if not valid_id(target_name): + _LOGGER.warning(f'Invalid target name (in {module.__name__}): {target_name}') + continue + + if not isinstance(target, Target): + _LOGGER.warning(f'Invalid target in {module.__name__} for name {target_name}: {target!r}') + continue + + res[target_name] = CachedTarget(TargetId(plugin_name, target_name), target) + + return res + + +_TARGET_CACHE: TargetCache | None = None + + +def target_cache() -> TargetCache: + global _TARGET_CACHE + if not _TARGET_CACHE: + _LOGGER.info('Loading target cache') + start_time = time.time() + _TARGET_CACHE = TargetCache.load() + end_time = time.time() + delta_time = end_time - start_time + _LOGGER.info(f'Target cache loaded in {delta_time:.3f}s') + return _TARGET_CACHE + + +def target_ids() -> list[TargetId]: + return target_cache().target_ids diff --git a/src/pyk/kdist/_kdist.py b/src/pyk/kdist/_kdist.py new file mode 100644 index 000000000..2e172854c --- /dev/null +++ b/src/pyk/kdist/_kdist.py @@ -0,0 +1,208 @@ +from __future__ import annotations + +import concurrent.futures +import json +import logging +import os +import shutil +from concurrent.futures import ProcessPoolExecutor +from contextlib import contextmanager +from dataclasses import dataclass +from graphlib import CycleError, TopologicalSorter +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, final + +from filelock import SoftFileLock +from xdg_base_dirs import xdg_cache_home + +from ..utils import hash_str +from . 
import utils +from ._cache import target_cache +from .api import TargetId + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Mapping + from concurrent.futures import Future + from typing import Any, Final + + from filelock import FileLock + + from ._cache import CachedTarget + + +_LOGGER: Final = logging.getLogger(__name__) + + +@final +@dataclass(frozen=True) +class KDist: + kdist_dir: Path + + def __init__(self, kdist_dir: str | Path | None): + if kdist_dir is None: + kdist_dir = KDist.default_dir() + kdist_dir = Path(kdist_dir).resolve() + object.__setattr__(self, 'kdist_dir', kdist_dir) + + @staticmethod + def default_dir() -> Path: + import pyk + + module_dir = Path(pyk.__file__).parent + digest = hash_str({'module-dir': str(module_dir)})[:7] + return xdg_cache_home() / f'kdist-{digest}' + + def which(self, target_id: str | TargetId | None = None) -> Path: + if target_id: + target_id = target_cache().resolve(target_id).id + return self._target_dir(target_id) + return self.kdist_dir + + def clean(self, target_id: str | TargetId | None = None) -> Path: + res = self.which(target_id) + shutil.rmtree(res, ignore_errors=True) + return res + + def get(self, target_id: str | TargetId) -> Path: + if isinstance(target_id, str): + target_id = TargetId.parse(target_id) + res = self._target_dir(target_id) + if not res.exists(): + raise ValueError(f'Target undefined or not built: {target_id.full_name}') + return res + + def get_or_none(self, target_id: str | TargetId) -> Path | None: + try: + return self.get(target_id) + except ValueError: + return None + + def build( + self, + target_ids: Iterable[str | TargetId], + *, + args: Mapping[str, str] | None = None, + jobs: int = 1, + force: bool = False, + verbose: bool = False, + ) -> None: + args = dict(args) if args else {} + dep_ids = target_cache().resolve_deps(target_ids) + target_graph = TopologicalSorter(dep_ids) + try: + target_graph.prepare() + except CycleError as err: + raise RuntimeError(f'Cyclic dependencies found: {err.args[1]}') from err + + deps_fqns = [target_id.full_name for target_id in dep_ids] + _LOGGER.info(f"Building targets: {', '.join(deps_fqns)}") + + with ProcessPoolExecutor(max_workers=jobs) as pool: + pending: dict[Future[Path], TargetId] = {} + + def submit(target_id: TargetId) -> None: + future = pool.submit( + self._build_target, + target_id=target_id, + args=args, + force=force, + verbose=verbose, + ) + pending[future] = target_id + + for target_id in target_graph.get_ready(): + submit(target_id) + + while pending: + done, _ = concurrent.futures.wait(pending, return_when=concurrent.futures.FIRST_COMPLETED) + for future in done: + result = future.result() + print(result, flush=True) + target_id = pending[future] + target_graph.done(target_id) + for new_target_id in target_graph.get_ready(): + submit(new_target_id) + pending.pop(future) + + # Helpers + + def _build_target( + self, + target_id: TargetId, + args: dict[str, Any], + *, + force: bool, + verbose: bool, + ) -> Path: + target = target_cache().resolve(target_id) + output_dir = self._target_dir(target_id) + manifest_file = self._manifest_file(target_id) + + with self._lock(target_id): + manifest = self._manifest(target, args) + + if not force and self._up_to_date(target_id, manifest): + return output_dir + + shutil.rmtree(output_dir, ignore_errors=True) + output_dir.mkdir(parents=True) + manifest_file.unlink(missing_ok=True) + + with ( + self._build_dir(target_id) as build_dir, + utils.cwd(build_dir), + ): + try: + 
target.target.build(output_dir, deps=self._deps(target), args=args, verbose=verbose) + except BaseException as err: + shutil.rmtree(output_dir, ignore_errors=True) + raise RuntimeError(f'Build failed: {target_id.full_name}') from err + + manifest_file.write_text(json.dumps(manifest)) + return output_dir + + def _target_dir(self, target_id: TargetId) -> Path: + return self.kdist_dir / target_id.plugin_name / target_id.target_name + + def _manifest_file(self, target_id: TargetId) -> Path: + return self.kdist_dir / target_id.plugin_name / f'{target_id.target_name}.json' + + def _deps(self, target: CachedTarget) -> dict[str, Path]: + return {dep_fqn: self._target_dir(target_cache().resolve(dep_fqn).id) for dep_fqn in target.target.deps()} + + def _manifest(self, target: CachedTarget, args: dict[str, Any]) -> dict[str, Any]: + res = target.target.manifest() + res['args'] = dict(args) + res['deps'] = { + dep_fqn: utils.timestamp(self._manifest_file(target_cache().resolve(dep_fqn).id)) + for dep_fqn in target.target.deps() + } + return res + + def _up_to_date(self, target_id: TargetId, new_manifest: dict[str, Any]) -> bool: + if not self._target_dir(target_id).exists(): + return False + manifest_file = self._manifest_file(target_id) + if not manifest_file.exists(): + return False + old_manifest = json.loads(manifest_file.read_text()) + return new_manifest == old_manifest + + def _lock(self, target_id: TargetId) -> FileLock: + lock_file = self._target_dir(target_id).with_suffix('.lock') + lock_file.parent.mkdir(parents=True, exist_ok=True) + return SoftFileLock(lock_file) + + @contextmanager + def _build_dir(self, target_id: TargetId) -> Iterator[Path]: + tmp_dir_prefix = f'kdist-{target_id.plugin_name}-{target_id.target_name}-' + with TemporaryDirectory(prefix=tmp_dir_prefix) as build_dir_str: + build_dir = Path(build_dir_str) + yield build_dir + + +_KDIST_DIR_ENV: Final = os.getenv('KDIST_DIR') +KDIST_DIR: Final = Path(_KDIST_DIR_ENV) if _KDIST_DIR_ENV else None + +kdist: Final = KDist(KDIST_DIR) diff --git a/src/pyk/kdist/api.py b/src/pyk/kdist/api.py new file mode 100644 index 000000000..2af08d186 --- /dev/null +++ b/src/pyk/kdist/api.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, final + +from . 
import utils + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Mapping + from pathlib import Path + from typing import Any + + +_ID_PATTERN = re.compile('[a-z0-9]+(-[a-z0-9]+)*') + + +def valid_id(s: str) -> bool: + return _ID_PATTERN.fullmatch(s) is not None + + +@final +@dataclass(frozen=True) +class TargetId: + plugin_name: str + target_name: str + + def __init__(self, plugin_name: str, target_name: str): + if not valid_id(plugin_name): + raise ValueError(f'Invalid plugin name: {plugin_name!r}') + + if not valid_id(target_name): + raise ValueError(f'Invalid target name: {target_name!r}') + + object.__setattr__(self, 'plugin_name', plugin_name) + object.__setattr__(self, 'target_name', target_name) + + def __iter__(self) -> Iterator[str]: + yield self.plugin_name + yield self.target_name + + @staticmethod + def parse(fqn: str) -> TargetId: + segments = fqn.split('.') + if len(segments) != 2: + raise ValueError(f'Expected fully qualified target name, got: {fqn!r}') + + plugin_name, target_name = segments + return TargetId(plugin_name, target_name) + + @property + def full_name(self) -> str: + return f'{self.plugin_name}.{self.target_name}' + + +class Target(ABC): + @abstractmethod + def build(self, output_dir: Path, deps: dict[str, Path], args: dict[str, Any], verbose: bool) -> None: + ... + + def deps(self) -> Iterable[str]: + return () + + def source(self) -> Iterable[str | Path]: + return () + + def context(self) -> Mapping[str, str]: + return {} + + @final + def manifest(self) -> dict[str, Any]: + source = {} + package_path = utils.package_path(self) + source_files = [file.resolve() for source in self.source() for file in utils.files_for_path(source)] + for source_file in source_files: + try: + file_id = str(source_file.relative_to(package_path)) + except ValueError as err: + raise ValueError(f'Source file is not within package: {source_file}') from err + source[file_id] = utils.timestamp(source_file) + + context = dict(self.context()) + return {'context': context, 'source': source} diff --git a/src/pyk/kdist/utils.py b/src/pyk/kdist/utils.py new file mode 100644 index 000000000..446206a0b --- /dev/null +++ b/src/pyk/kdist/utils.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import inspect +import os +from contextlib import contextmanager +from pathlib import Path +from typing import TYPE_CHECKING + +from pyk.utils import check_dir_path + +if TYPE_CHECKING: + from collections.abc import Iterator + from typing import Any + + +def package_path(obj: Any) -> Path: + module = inspect.getmodule(obj) + + if not module: + raise ValueError(f'Module not found for object: {obj}') + + if not module.__file__: + raise ValueError(f'Path not found for module: {module.__name__}') + + package_path = Path(module.__file__).parent.resolve() + while True: + init_file = package_path / '__init__.py' + if not init_file.exists(): + return package_path + if not package_path.parent.exists(): + return package_path + package_path = package_path.parent + + +def files_for_path(path: str | Path) -> list[Path]: + path = Path(path) + + if not path.exists(): + raise ValueError(f'Path does not exist: {path}') + + if path.is_file(): + return [path] + + return [file for file in path.rglob('*') if file.is_file()] + + +def timestamp(path: Path) -> int: + return path.stat().st_mtime_ns + + +@contextmanager +def cwd(path: Path) -> Iterator[None]: + check_dir_path(path) + old_cwd = os.getcwd() + os.chdir(str(path)) + yield + os.chdir(old_cwd) diff --git a/src/pyk/kllvm/importer.py 
b/src/pyk/kllvm/importer.py index 08aed7b40..d637f36ad 100644 --- a/src/pyk/kllvm/importer.py +++ b/src/pyk/kllvm/importer.py @@ -6,13 +6,11 @@ from ..cli.utils import check_dir_path, check_file_path from .compiler import KLLVM_MODULE_FILE_NAME, KLLVM_MODULE_NAME, RUNTIME_MODULE_FILE_NAME, RUNTIME_MODULE_NAME +from .runtime import Runtime if TYPE_CHECKING: - from collections.abc import Callable from types import ModuleType - from .ast import Pattern - def import_from_file(module_name: str, module_file: str | Path) -> ModuleType: module_file = Path(module_file).resolve() @@ -41,62 +39,9 @@ def import_kllvm(target_dir: str | Path) -> ModuleType: return import_from_file(KLLVM_MODULE_NAME, module_file) -def import_runtime(target_dir: str | Path) -> ModuleType: +def import_runtime(target_dir: str | Path) -> Runtime: target_dir = Path(target_dir) check_dir_path(target_dir) module_file = target_dir / RUNTIME_MODULE_FILE_NAME - runtime = import_from_file(RUNTIME_MODULE_NAME, module_file) - _patch_runtime(runtime) - return runtime - - -def _patch_runtime(runtime: ModuleType) -> None: - runtime.Term = _make_term_class(runtime) # type: ignore - runtime.interpret = _make_interpreter(runtime) # type: ignore - - -def _make_interpreter(runtime: ModuleType) -> Callable[..., Pattern]: - def interpret(pattern: Pattern, *, depth: int | None = None) -> Pattern: - init_term = runtime.InternalTerm(pattern) - final_term = init_term.step(depth if depth is not None else -1) - return final_term.to_pattern() - - return interpret - - -def _make_term_class(mod: ModuleType) -> type: - class Term: - def __init__(self, pattern: Pattern): - self._block = mod.InternalTerm(pattern) - - @property - def pattern(self) -> Pattern: - return self._block.to_pattern() - - @staticmethod - def deserialize(bs: bytes) -> Term | None: - block = mod.InternalTerm.deserialize(bs) - if block is None: - return None - term = object.__new__(Term) - term._block = block - return term - - def serialize(self) -> bytes: - return self._block.serialize() - - def step(self, n: int = 1) -> None: - self._block = self._block.step(n) - - def run(self) -> None: - self.step(-1) - - def copy(self) -> Term: - other = self - other._block = self._block.step(0) - return other - - def __str__(self) -> str: - return str(self._block) - - return Term + module = import_from_file(RUNTIME_MODULE_NAME, module_file) + return Runtime(module) diff --git a/src/pyk/kllvm/runtime.py b/src/pyk/kllvm/runtime.py new file mode 100644 index 000000000..e22a24750 --- /dev/null +++ b/src/pyk/kllvm/runtime.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from types import ModuleType + from typing import Any + + from .ast import Pattern, Sort + + +class Runtime: + _module: ModuleType + + def __init__(self, module: ModuleType): + self._module = module + + def term(self, pattern: Pattern) -> Term: + return Term(self._module.InternalTerm(pattern)) + + def deserialize(self, bs: bytes) -> Term | None: + block = self._module.InternalTerm.deserialize(bs) + if block is None: + return None + return Term(block) + + def step(self, pattern: Pattern, depth: int | None = 1) -> Pattern: + term = self.term(pattern) + term.step(depth=depth) + return term.pattern + + def run(self, pattern: Pattern) -> Pattern: + return self.step(pattern, depth=None) + + def simplify(self, pattern: Pattern, sort: Sort) -> Pattern: + return self._module.simplify_pattern(pattern, sort) + + def simplify_bool(self, pattern: Pattern) -> bool: + return 
self._module.simplify_bool_pattern(pattern) + + +class Term: + _block: Any # module.InternalTerm + + def __init__(self, block: Any): + self._block = block + + @property + def pattern(self) -> Pattern: + return self._block.to_pattern() + + def serialize(self) -> bytes: + return self._block.serialize() + + def step(self, depth: int | None = 1) -> None: + self._block = self._block.step(depth if depth is not None else -1) + + def run(self) -> None: + self.step(depth=None) + + def __str__(self) -> str: + return str(self._block) diff --git a/src/pyk/kore/kompiled.py b/src/pyk/kore/kompiled.py index 231fbf19b..f03498e5e 100644 --- a/src/pyk/kore/kompiled.py +++ b/src/pyk/kore/kompiled.py @@ -32,7 +32,7 @@ _LOGGER: Final = logging.getLogger(__name__) -_PYK_DEFINITION_NAME: Final = 'pyk-definition.json' +_PYK_DEFINITION_NAME: Final = 'pyk-definition.kore.json' @final @@ -46,11 +46,19 @@ def load(definition_dir: str | Path) -> KompiledKore: definition_dir = Path(definition_dir) check_dir_path(definition_dir) + kore_file = definition_dir / 'definition.kore' + check_file_path(kore_file) + json_file = definition_dir / _PYK_DEFINITION_NAME if json_file.exists(): - return KompiledKore.load_from_json(json_file) + kore_timestamp = kore_file.stat().st_mtime_ns + json_timestamp = json_file.stat().st_mtime_ns + + if kore_timestamp < json_timestamp: + return KompiledKore.load_from_json(json_file) + + _LOGGER.warning(f'File is out of date: {json_file}') - kore_file = definition_dir / 'definition.kore' return KompiledKore.load_from_kore(kore_file) @staticmethod @@ -114,7 +122,7 @@ def add_injections(self, pattern: Pattern, sort: Sort | None = None) -> Pattern: sort = SortApp('SortK') patterns = pattern.patterns sorts = self.symbol_table.pattern_sorts(pattern) - pattern = pattern.let_patterns(self.add_injections(p, s) for p, s in zip(patterns, sorts, strict=True)) + pattern = pattern.let_patterns(tuple(self.add_injections(p, s) for p, s in zip(patterns, sorts, strict=True))) return self._inject(pattern, sort) def _inject(self, pattern: Pattern, sort: Sort) -> Pattern: diff --git a/src/pyk/kore/rpc.py b/src/pyk/kore/rpc.py index ac1ce3de3..b9ef42ebe 100644 --- a/src/pyk/kore/rpc.py +++ b/src/pyk/kore/rpc.py @@ -1061,8 +1061,7 @@ def __init__( check_dir_path(llvm_dt) if bug_report: - bug_report.add_file(llvm_definition, Path('llvm_definition/definition.kore')) - bug_report.add_file(llvm_dt, Path('llvm_definition/dt')) + self._gather_booster_report(llvm_kompiled_dir, llvm_definition, llvm_dt, bug_report) self._check_none_or_positive(smt_timeout, 'smt_timeout') self._check_none_or_positive(smt_retry_limit, 'smt_retry_limit') @@ -1090,6 +1089,15 @@ def __init__( log_axioms_file=log_axioms_file, ) + @staticmethod + def _gather_booster_report( + llvm_kompiled_dir: Path, llvm_definition: Path, llvm_dt: Path, bug_report: BugReport + ) -> None: + bug_report.add_file(llvm_definition, Path('llvm_definition/definition.kore')) + bug_report.add_file(llvm_dt, Path('llvm_definition/dt')) + llvm_version = run_process('llvm-backend-version', pipe_stderr=True, logger=_LOGGER).stdout.strip() + bug_report.add_file_contents(llvm_version, Path('llvm_version.txt')) + def kore_server( definition_dir: str | Path, diff --git a/src/pyk/kore/syntax.py b/src/pyk/kore/syntax.py index 5c5b04347..0d6efc8ba 100644 --- a/src/pyk/kore/syntax.py +++ b/src/pyk/kore/syntax.py @@ -250,7 +250,7 @@ def _tag(cls) -> str: @classmethod def from_dict(cls: type[SortApp], dct: Mapping[str, Any]) -> SortApp: cls._check_tag(dct) - return SortApp(name=dct['name'], 
sorts=(Sort.from_dict(arg) for arg in dct['args'])) + return SortApp(name=dct['name'], sorts=tuple(Sort.from_dict(arg) for arg in dct['args'])) @property def dict(self) -> dict[str, Any]: @@ -449,8 +449,8 @@ def from_dict(cls: type[App], dct: Mapping[str, Any]) -> App: cls._check_tag(dct) return App( symbol=dct['name'], - sorts=(Sort.from_dict(sort) for sort in dct['sorts']), - args=(Pattern.from_dict(arg) for arg in dct['args']), + sorts=tuple(Sort.from_dict(sort) for sort in dct['sorts']), + args=tuple(Pattern.from_dict(arg) for arg in dct['args']), ) @property @@ -847,9 +847,10 @@ def of(cls: type[And], symbol: str, sorts: Iterable[Sort] = (), patterns: Iterab @classmethod def from_dict(cls: type[And], dct: Mapping[str, Any]) -> And: cls._check_tag(dct) - sort = Sort.from_dict(dct['sort']) - ops = [Pattern.from_dict(op) for op in dct['patterns']] - return And(sort=sort, ops=ops) + return And( + sort=Sort.from_dict(dct['sort']), + ops=tuple(Pattern.from_dict(op) for op in dct['patterns']), + ) @final @@ -895,9 +896,10 @@ def of(cls: type[Or], symbol: str, sorts: Iterable[Sort] = (), patterns: Iterabl @classmethod def from_dict(cls: type[Or], dct: Mapping[str, Any]) -> Or: cls._check_tag(dct) - sort = Sort.from_dict(dct['sort']) - ops = [Pattern.from_dict(op) for op in dct['patterns']] - return Or(sort=sort, ops=ops) + return Or( + sort=Sort.from_dict(dct['sort']), + ops=tuple(Pattern.from_dict(op) for op in dct['patterns']), + ) class MLQuant(MLPattern, WithSort): diff --git a/src/pyk/ktool/kompile.py b/src/pyk/ktool/kompile.py index 6d82ae008..33107b4c1 100644 --- a/src/pyk/ktool/kompile.py +++ b/src/pyk/ktool/kompile.py @@ -20,6 +20,8 @@ from collections.abc import Iterable, Mapping from typing import Any, Final, Literal + from ..utils import BugReport + _LOGGER: Final = logging.getLogger(__name__) @@ -28,12 +30,20 @@ def __init__(self, kompile_command: str): super().__init__(f'Kompile command not found: {str}') +class TypeInferenceMode(Enum): + Z3 = 'z3' + SIMPLESUB = 'simplesub' + CHECKED = 'checked' + DEFAULT = 'default' + + def kompile( main_file: str | Path, *, command: Iterable[str] = ('kompile',), output_dir: str | Path | None = None, temp_dir: str | Path | None = None, + type_inference_mode: str | TypeInferenceMode | None = None, debug: bool = False, verbose: bool = False, cwd: Path | None = None, @@ -46,6 +56,7 @@ def kompile( command=command, output_dir=output_dir, temp_dir=temp_dir, + type_inference_mode=type_inference_mode, debug=debug, verbose=verbose, cwd=cwd, @@ -113,10 +124,12 @@ def __call__( *, output_dir: str | Path | None = None, temp_dir: str | Path | None = None, + type_inference_mode: str | TypeInferenceMode | None = None, debug: bool = False, verbose: bool = False, cwd: Path | None = None, check: bool = True, + bug_report: BugReport | None = None, ) -> Path: check_file_path(abs_or_rel_to(self.base_args.main_file, cwd or Path())) for include_dir in self.base_args.include_dirs: @@ -135,6 +148,10 @@ def __call__( temp_dir = Path(temp_dir) args += ['--temp-dir', str(temp_dir)] + if type_inference_mode is not None: + type_inference_mode = TypeInferenceMode(type_inference_mode) + args += ['--type-inference-mode', type_inference_mode.value] + if debug: args += ['--debug'] @@ -152,10 +169,14 @@ def __call__( ) from err if proc_res.stdout: - print(proc_res.stdout.rstrip()) + out = proc_res.stdout.rstrip() + print(out) + if bug_report: + bug_report.add_file_contents(out, Path('kompile.log')) definition_dir = output_dir if output_dir else Path(self.base_args.main_file.stem 
+ '-kompiled') assert definition_dir.is_dir() + return definition_dir @abstractmethod diff --git a/src/pyk/prelude/utils.py b/src/pyk/prelude/utils.py index ef734a6c2..64bd64a3d 100644 --- a/src/pyk/prelude/utils.py +++ b/src/pyk/prelude/utils.py @@ -8,8 +8,6 @@ from .string import stringToken if TYPE_CHECKING: - pass - from ..kast.inner import KToken diff --git a/src/pyk/testing/_kompiler.py b/src/pyk/testing/_kompiler.py index 1aca90948..2b0893a80 100644 --- a/src/pyk/testing/_kompiler.py +++ b/src/pyk/testing/_kompiler.py @@ -11,7 +11,7 @@ from ..kllvm.importer import import_runtime from ..kore.pool import KoreServerPool from ..kore.rpc import BoosterServer, KoreClient, KoreServer -from ..ktool.kompile import DefinitionInfo, Kompile +from ..ktool.kompile import DefinitionInfo, Kompile, TypeInferenceMode from ..ktool.kprint import KPrint from ..ktool.kprove import KProve from ..ktool.krun import KRun @@ -19,13 +19,13 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterator from pathlib import Path - from types import ModuleType from typing import Any, ClassVar from pytest import TempPathFactory from ..kast.outer import KDefinition from ..kcfg.semantics import KCFGSemantics + from ..kllvm.runtime import Runtime from ..ktool.kprint import SymbolTable from ..utils import BugReport @@ -43,7 +43,7 @@ def __call__(self, main_file: str | Path, **kwargs: Any) -> Path: kompile = Kompile.from_dict(kwargs) if kompile not in self._cache: output_dir = self._path / self._uid(kompile) - self._cache[kompile] = kompile(output_dir=output_dir) + self._cache[kompile] = kompile(output_dir=output_dir, type_inference_mode=TypeInferenceMode.CHECKED) return self._cache[kompile] @@ -212,6 +212,6 @@ class RuntimeTest(KompiledTest): KOMPILE_BACKEND = 'llvm' @pytest.fixture(scope='class') - def runtime(self, definition_dir: Path) -> ModuleType: + def runtime(self, definition_dir: Path) -> Runtime: compile_runtime(definition_dir) return import_runtime(definition_dir) diff --git a/src/tests/integration/kcfg/test_multiple_definitions.py b/src/tests/integration/kcfg/test_multiple_definitions.py index f0bbd74c3..e2de82cf1 100644 --- a/src/tests/integration/kcfg/test_multiple_definitions.py +++ b/src/tests/integration/kcfg/test_multiple_definitions.py @@ -44,12 +44,12 @@ def test_execute( kcfg_explore: KCFGExplore, test_id: str, ) -> None: - _, _, split_depth, split_post_term, split_next_terms, _logs = kcfg_explore.cterm_execute(self.config(), depth=1) + exec_res = kcfg_explore.cterm_execute(self.config(), depth=1) + split_next_terms = exec_res.next_states + split_k = kcfg_explore.kprint.pretty_print(exec_res.state.cell('K_CELL')) + split_next_k = [kcfg_explore.kprint.pretty_print(exec_res.state.cell('K_CELL')) for _ in split_next_terms] - split_k = kcfg_explore.kprint.pretty_print(split_post_term.cell('K_CELL')) - split_next_k = [kcfg_explore.kprint.pretty_print(split_post_term.cell('K_CELL')) for term in split_next_terms] - - assert split_depth == 0 + assert exec_res.depth == 0 assert len(split_next_terms) == 2 assert 'a ( X:KItem )' == split_k assert [ @@ -57,14 +57,10 @@ def test_execute( 'a ( X:KItem )', ] == split_next_k - _, _, step_1_depth, step_1_post_term, step_1_next_terms, _logs = kcfg_explore.cterm_execute( - split_next_terms[0], depth=1 - ) - step_1_k = kcfg_explore.kprint.pretty_print(step_1_post_term.cell('K_CELL')) + step_1_res = kcfg_explore.cterm_execute(split_next_terms[0], depth=1) + step_1_k = kcfg_explore.kprint.pretty_print(step_1_res.state.cell('K_CELL')) assert 'c' == step_1_k - _, _, 
step_2_depth, step_2_post_term, step_2_next_terms, _logs = kcfg_explore.cterm_execute( - split_next_terms[1], depth=1 - ) - step_2_k = kcfg_explore.kprint.pretty_print(step_1_post_term.cell('K_CELL')) + step_2_res = kcfg_explore.cterm_execute(split_next_terms[1], depth=1) + step_2_k = kcfg_explore.kprint.pretty_print(step_2_res.state.cell('K_CELL')) assert 'c' == step_2_k diff --git a/src/tests/integration/kcfg/test_simple.py b/src/tests/integration/kcfg/test_simple.py index 1e55af7d0..b28564dd7 100644 --- a/src/tests/integration/kcfg/test_simple.py +++ b/src/tests/integration/kcfg/test_simple.py @@ -78,17 +78,16 @@ def test_execute( expected_k, expected_state, *_ = expected_post # When - _, _, actual_depth, actual_post_term, actual_next_terms, _logs = kcfg_explore.cterm_execute( - self.config(kcfg_explore.kprint, *pre), depth=depth - ) - actual_k = kcfg_explore.kprint.pretty_print(actual_post_term.cell('K_CELL')) - actual_state = kcfg_explore.kprint.pretty_print(actual_post_term.cell('STATE_CELL')) + exec_res = kcfg_explore.cterm_execute(self.config(kcfg_explore.kprint, *pre), depth=depth) + actual_k = kcfg_explore.kprint.pretty_print(exec_res.state.cell('K_CELL')) + actual_state = kcfg_explore.kprint.pretty_print(exec_res.state.cell('STATE_CELL')) + actual_depth = exec_res.depth actual_next_states = [ ( kcfg_explore.kprint.pretty_print(s.cell('K_CELL')), kcfg_explore.kprint.pretty_print(s.cell('STATE_CELL')), ) - for s in actual_next_terms + for s in exec_res.next_states ] # Then diff --git a/src/tests/integration/kllvm/test_internal_term.py b/src/tests/integration/kllvm/test_internal_term.py index 72a676b11..7e8166975 100644 --- a/src/tests/integration/kllvm/test_internal_term.py +++ b/src/tests/integration/kllvm/test_internal_term.py @@ -9,17 +9,16 @@ from ..utils import K_FILES if TYPE_CHECKING: - from types import ModuleType - from pyk.kllvm.ast import Pattern + from pyk.kllvm.runtime import Runtime class TestInternalTerm(RuntimeTest): KOMPILE_MAIN_FILE = K_FILES / 'imp.k' - def test_str_llvm_backend_issue_724(self, runtime: ModuleType) -> None: + def test_str_llvm_backend_issue_724(self, runtime: Runtime) -> None: for _ in range(10000): - term = runtime.InternalTerm(start_pattern()) + term = runtime._module.InternalTerm(start_pattern()) term.step(-1) # just checking that str doesn't crash str(term) diff --git a/src/tests/integration/kllvm/test_serialize.py b/src/tests/integration/kllvm/test_serialize.py index e9badfa00..a3e907f5b 100644 --- a/src/tests/integration/kllvm/test_serialize.py +++ b/src/tests/integration/kllvm/test_serialize.py @@ -1,9 +1,22 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import pytest import pyk.kllvm.load # noqa: F401 +from pyk.kllvm import parser from pyk.kllvm.ast import Pattern from pyk.kllvm.convert import pattern_to_llvm from pyk.kore.parser import KoreParser +from pyk.testing import RuntimeTest + +from ..utils import K_FILES + +if TYPE_CHECKING: + from pathlib import Path + + from pyk.kllvm.runtime import Runtime TEST_DATA = ( '"foo"', @@ -30,3 +43,23 @@ def test_serialize(kore_text: str) -> None: # Then assert actual is not None assert str(actual) == str(pattern) + + +class TestSerializeRaw(RuntimeTest): + KOMPILE_MAIN_FILE = K_FILES / 'imp.k' + + def test_serialize_raw(self, runtime: Runtime, tmp_path: Path) -> None: + # Given + kore_text = r"""Lbl'UndsPlus'Int'Unds'{}(\dv{SortInt{}}("1"),\dv{SortInt{}}("2"))""" + pattern = parser.parse_pattern(kore_text) + term = runtime.term(pattern) + kore_file = tmp_path / 'kore' 
+ + # When + term._block._serialize_raw(str(kore_file), 'SortInt{}') + pattern = Pattern.deserialize(kore_file.read_bytes()) + pattern_with_raw = Pattern.deserialize(kore_file.read_bytes(), strip_raw_term=False) + + # Then + assert str(pattern) == r'\dv{SortInt{}}("3")' + assert str(pattern_with_raw) == r'rawTerm{}(inj{SortInt{}, SortKItem{}}(\dv{SortInt{}}("3")))' diff --git a/src/tests/integration/kllvm/test_simplify.py b/src/tests/integration/kllvm/test_simplify.py new file mode 100644 index 000000000..865d640db --- /dev/null +++ b/src/tests/integration/kllvm/test_simplify.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import pyk.kllvm.load # noqa: F401 +from pyk.kllvm import parser +from pyk.testing import RuntimeTest + +from ..utils import K_FILES + +if TYPE_CHECKING: + from pyk.kllvm.runtime import Runtime + + +SIMPLIFY_TEST_DATA = ( + ('literal', r'\dv{SortInt{}}("0")', r'inj{SortInt{}, SortKItem{}}(\dv{SortInt{}}("0"))'), + ( + 'plus', + r"""Lbl'UndsPlus'Int'Unds'{}(\dv{SortInt{}}("1"), \dv{SortInt{}}("2"))""", + r'inj{SortInt{}, SortKItem{}}(\dv{SortInt{}}("3"))', + ), +) + +SIMPLIFY_BOOL_TEST_DATA = ( + ('true', r'\dv{SortBool{}}("true")', True), + ('false', r'\dv{SortBool{}}("false")', False), + ('true andBool false', r"""Lbl'Unds'andBool'Unds'{}(\dv{SortBool{}}("true"), \dv{SortBool{}}("false"))""", False), + ('false orBool true', r"""Lbl'Unds'orBool'Unds'{}(\dv{SortBool{}}("false"), \dv{SortBool{}}("true"))""", True), +) + + +class TestSimplify(RuntimeTest): + KOMPILE_MAIN_FILE = K_FILES / 'imp.k' + + @pytest.mark.parametrize( + 'test_id,pattern_text,expected', + SIMPLIFY_TEST_DATA, + ids=[test_id for test_id, *_ in SIMPLIFY_TEST_DATA], + ) + def test_simplify(self, runtime: Runtime, test_id: str, pattern_text: str, expected: str) -> None: + # Given + pattern = parser.parse_pattern(pattern_text) + sort = parser.parse_sort('SortInt{}') + + # When + simplified = runtime.simplify(pattern, sort) + + # Then + assert str(simplified) == expected + + @pytest.mark.parametrize( + 'test_id,pattern_text,expected', + SIMPLIFY_BOOL_TEST_DATA, + ids=[test_id for test_id, *_ in SIMPLIFY_BOOL_TEST_DATA], + ) + def test_simplify_bool(self, runtime: Runtime, test_id: str, pattern_text: str, expected: bool) -> None: + # Given + pattern = parser.parse_pattern(pattern_text) + + # When + actual = runtime.simplify_bool(pattern) + + # Then + assert actual == expected diff --git a/src/tests/integration/kllvm/test_step.py b/src/tests/integration/kllvm/test_step.py index 7ff64fa2e..2415f2610 100644 --- a/src/tests/integration/kllvm/test_step.py +++ b/src/tests/integration/kllvm/test_step.py @@ -3,22 +3,21 @@ from typing import TYPE_CHECKING import pyk.kllvm.load # noqa: F401 -from pyk.kllvm.parser import Parser +from pyk.kllvm.parser import parse_pattern from pyk.testing import RuntimeTest from ..utils import K_FILES if TYPE_CHECKING: - from types import ModuleType - from pyk.kllvm.ast import Pattern + from pyk.kllvm.runtime import Runtime class TestStep(RuntimeTest): KOMPILE_MAIN_FILE = K_FILES / 'steps.k' - def test_steps_1(self, runtime: ModuleType) -> None: - term = runtime.Term(start_pattern()) + def test_steps_1(self, runtime: Runtime) -> None: + term = runtime.term(start_pattern()) term.step(0) assert str(term) == foo_output(0) term.step() @@ -27,29 +26,25 @@ def test_steps_1(self, runtime: ModuleType) -> None: term.step(200) assert str(term) == bar_output() - def test_steps_2(self, runtime: ModuleType) -> None: - term = 
runtime.Term(start_pattern()) + def test_steps_2(self, runtime: Runtime) -> None: + term = runtime.term(start_pattern()) assert str(term) == foo_output(0) term.step(50) assert str(term) == foo_output(50) term.step(-1) assert str(term) == bar_output() - def test_steps_3(self, runtime: ModuleType) -> None: - term = runtime.Term(start_pattern()) + def test_steps_3(self, runtime: Runtime) -> None: + term = runtime.term(start_pattern()) term.run() assert str(term) == bar_output() - def test_steps_to_pattern(self, runtime: ModuleType) -> None: - term = runtime.Term(start_pattern()) + def test_steps_to_pattern(self, runtime: Runtime) -> None: + term = runtime.term(start_pattern()) term.run() pattern = term.pattern assert str(pattern) == bar_output() - def test_interpret(self, runtime: ModuleType) -> None: - pattern = runtime.interpret(start_pattern()) - assert str(pattern) == bar_output() - def start_pattern() -> Pattern: """ @@ -70,7 +65,7 @@ def start_pattern() -> Pattern: ) ) """ - return Parser.from_string(text).pattern() + return parse_pattern(text) def foo_output(n: int) -> str: diff --git a/src/tests/integration/kllvm/test_term.py b/src/tests/integration/kllvm/test_term.py index 7874bb7c5..b01101b56 100644 --- a/src/tests/integration/kllvm/test_term.py +++ b/src/tests/integration/kllvm/test_term.py @@ -11,7 +11,7 @@ from ..utils import K_FILES if TYPE_CHECKING: - from types import ModuleType + from pyk.kllvm.runtime import Runtime class TestTerm(RuntimeTest): @@ -21,14 +21,14 @@ class TestTerm(RuntimeTest): } @pytest.mark.parametrize('ctor', ('one', 'two', 'three')) - def test_construct(self, runtime: ModuleType, ctor: str) -> None: + def test_construct(self, runtime: Runtime, ctor: str) -> None: # Given label = f"Lbl{ctor}'LParRParUnds'CTOR'Unds'Foo" pattern = CompositePattern(label) - term = runtime.Term(pattern) + term = runtime.term(pattern) # Then assert str(term) == str(pattern) assert str(term.pattern) == str(pattern) assert term.serialize() == pattern.serialize() - assert str(runtime.Term.deserialize(pattern.serialize())) == str(term) + assert str(runtime.deserialize(pattern.serialize())) == str(term) diff --git a/src/tests/integration/kore/test_pool.py b/src/tests/integration/kore/test_pool.py index 7e0e29236..46bdbefd9 100644 --- a/src/tests/integration/kore/test_pool.py +++ b/src/tests/integration/kore/test_pool.py @@ -24,8 +24,6 @@ from ..utils import K_FILES if TYPE_CHECKING: - pass - from pyk.kore.pool import KoreServerPool diff --git a/src/tests/integration/proof/test_cell_map.py b/src/tests/integration/proof/test_cell_map.py index 3121b933a..dd26354ca 100644 --- a/src/tests/integration/proof/test_cell_map.py +++ b/src/tests/integration/proof/test_cell_map.py @@ -104,10 +104,9 @@ def test_execute( expected_k, _, _ = expected_post # When - _, _, actual_depth, actual_post_term, _, _logs = kcfg_explore.cterm_execute( - self.config(kcfg_explore.kprint, k, aacounts, accounts), depth=depth - ) - actual_k = kcfg_explore.kprint.pretty_print(actual_post_term.cell('K_CELL')) + exec_res = kcfg_explore.cterm_execute(self.config(kcfg_explore.kprint, k, aacounts, accounts), depth=depth) + actual_k = kcfg_explore.kprint.pretty_print(exec_res.state.cell('K_CELL')) + actual_depth = exec_res.depth # Then assert actual_depth == expected_depth diff --git a/src/tests/integration/proof/test_imp.py b/src/tests/integration/proof/test_imp.py index 4f470778a..d840b38ac 100644 --- a/src/tests/integration/proof/test_imp.py +++ b/src/tests/integration/proof/test_imp.py @@ -780,18 +780,17 @@ def 
test_execute( expected_k, expected_state = expected_post # When - _, _, actual_depth, actual_post_term, actual_next_terms, _logs = kcfg_explore.cterm_execute( - self.config(kcfg_explore.kprint, k, state), depth=depth - ) - actual_k = kcfg_explore.kprint.pretty_print(actual_post_term.cell('K_CELL')) - actual_state = kcfg_explore.kprint.pretty_print(actual_post_term.cell('STATE_CELL')) + exec_res = kcfg_explore.cterm_execute(self.config(kcfg_explore.kprint, k, state), depth=depth) + actual_k = kcfg_explore.kprint.pretty_print(exec_res.state.cell('K_CELL')) + actual_state = kcfg_explore.kprint.pretty_print(exec_res.state.cell('STATE_CELL')) + actual_depth = exec_res.depth actual_next_states = [ ( kcfg_explore.kprint.pretty_print(s.cell('K_CELL')), kcfg_explore.kprint.pretty_print(s.cell('STATE_CELL')), ) - for s in actual_next_terms + for s in exec_res.next_states ] # Then diff --git a/src/tests/integration/test_bytes.py b/src/tests/integration/test_bytes.py index 92ffcb7ec..227a05e69 100644 --- a/src/tests/integration/test_bytes.py +++ b/src/tests/integration/test_bytes.py @@ -24,11 +24,11 @@ if TYPE_CHECKING: from pathlib import Path - from types import ModuleType from typing import Final from pytest import FixtureRequest + from pyk.kllvm.runtime import Runtime from pyk.kore.syntax import Pattern from pyk.testing import Kompiler @@ -78,7 +78,7 @@ def definition_dir(request: FixtureRequest, backend: str) -> Path: @pytest.fixture(scope='module') -def runtime(llvm_dir: Path) -> ModuleType: +def runtime(llvm_dir: Path) -> Runtime: import pyk.kllvm.load # noqa: F401 compile_runtime(llvm_dir) @@ -200,7 +200,7 @@ def test_krun(backend: str, definition_dir: Path, value: bytes) -> None: @pytest.mark.parametrize('value', TEST_DATA) -def test_bindings(runtime: ModuleType, value: bytes) -> None: +def test_bindings(runtime: Runtime, value: bytes) -> None: from pyk.kllvm.convert import llvm_to_pattern, pattern_to_llvm # Given @@ -208,7 +208,7 @@ def test_bindings(runtime: ModuleType, value: bytes) -> None: expected = kore_config(None, value) # When - kore_llvm = runtime.interpret(pattern_to_llvm(kore)) + kore_llvm = runtime.run(pattern_to_llvm(kore)) actual = llvm_to_pattern(kore_llvm) # Then diff --git a/src/tests/integration/test_string.py b/src/tests/integration/test_string.py index cdace2de8..f5cb51e34 100644 --- a/src/tests/integration/test_string.py +++ b/src/tests/integration/test_string.py @@ -24,11 +24,11 @@ if TYPE_CHECKING: from pathlib import Path - from types import ModuleType from typing import Final from pytest import FixtureRequest + from pyk.kllvm.runtime import Runtime from pyk.kore.syntax import Pattern from pyk.testing import Kompiler @@ -75,7 +75,7 @@ def definition_dir(request: FixtureRequest, backend: str) -> Path: @pytest.fixture(scope='module') -def runtime(llvm_dir: Path) -> ModuleType: +def runtime(llvm_dir: Path) -> Runtime: import pyk.kllvm.load # noqa: F401 compile_runtime(llvm_dir) @@ -204,7 +204,7 @@ def test_krun(backend: str, definition_dir: Path, text: str) -> None: @pytest.mark.parametrize('text', TEST_DATA, ids=TEST_DATA) -def test_bindings(runtime: ModuleType, text: str) -> None: +def test_bindings(runtime: Runtime, text: str) -> None: from pyk.kllvm.convert import llvm_to_pattern, pattern_to_llvm # Given @@ -212,7 +212,7 @@ def test_bindings(runtime: ModuleType, text: str) -> None: expected = kore_config(None, text) # When - kore_llvm = runtime.interpret(pattern_to_llvm(kore)) + kore_llvm = runtime.run(pattern_to_llvm(kore)) actual = llvm_to_pattern(kore_llvm) # 
Then diff --git a/src/tests/profiling/profile_kast_to_kore.py b/src/tests/profiling/profile_kast_to_kore.py index 44e3c06fa..8271a9423 100644 --- a/src/tests/profiling/profile_kast_to_kore.py +++ b/src/tests/profiling/profile_kast_to_kore.py @@ -1,5 +1,6 @@ from __future__ import annotations +import shutil import sys from typing import TYPE_CHECKING @@ -20,6 +21,7 @@ def test_kast_to_kore(profile: Profiler, tmp_path: Path) -> None: kast_to_kore_dir = TEST_DATA_DIR / 'kast-to-kore' kast_defn_file = kast_to_kore_dir / 'compiled.json' + kore_defn_file = kast_to_kore_dir / 'definition.kore' kinner_file = kast_to_kore_dir / 'kinner.json' sys.setrecursionlimit(10**8) @@ -27,8 +29,9 @@ def test_kast_to_kore(profile: Profiler, tmp_path: Path) -> None: with profile('init-kast-defn.prof', sort_keys=('cumtime',), limit=50): kast_defn = read_kast_definition(kast_defn_file) + shutil.copy(kore_defn_file, tmp_path) with profile('init-kore-defn.prof', sort_keys=('cumtime',), limit=50): - kore_defn = KompiledKore.load(kast_to_kore_dir) # first time from KORE + kore_defn = KompiledKore.load(tmp_path) # first time from KORE kore_defn.write(tmp_path) with profile('reinit-kore-defn.prof', sort_keys=('cumtime',), limit=25): diff --git a/src/tests/unit/test_kore_exec_covr.py b/src/tests/unit/test_kore_exec_covr.py index 98d520151..c2e57240d 100644 --- a/src/tests/unit/test_kore_exec_covr.py +++ b/src/tests/unit/test_kore_exec_covr.py @@ -7,8 +7,6 @@ from pyk.kore_exec_covr.kore_exec_covr import HaskellLogEntry, _parse_haskell_oneline_log, parse_rule_applications if TYPE_CHECKING: - pass - from pytest import TempPathFactory
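
A rough sketch of how a downstream package could register a kdist plugin target (the plugin name, target name, module path, and entry-point declaration below are hypothetical; the contract — a 'kdist' entry point whose value is a module defining a __TARGETS__ mapping from valid ids to Target instances — follows src/pyk/kdist/_cache.py and src/pyk/kdist/api.py above):

    # my_plugin/targets.py -- exposed, e.g., through the 'kdist' entry-point group as:  my-plugin = "my_plugin.targets"
    from pathlib import Path
    from typing import Any

    from pyk.kdist.api import Target


    class HelloTarget(Target):
        """Toy target: writes a single artifact into the kdist output directory."""

        def build(self, output_dir: Path, deps: dict[str, Path], args: dict[str, Any], verbose: bool) -> None:
            # 'greeting' would arrive via the CLI as:  kdist build my-plugin.hello -a greeting=hi
            (output_dir / 'hello.txt').write_text(args.get('greeting', 'hello'))


    # kdist discovers plugins via the 'kdist' entry-point group and reads this mapping.
    # Plugin and target names must match [a-z0-9]+(-[a-z0-9]+)* (see valid_id in api.py).
    __TARGETS__ = {'hello': HelloTarget()}

Building and locating the target then goes through the new CLI or the kdist singleton, e.g. `kdist build 'my-plugin.*' --arg greeting=hi` on the command line, or `kdist.build(['my-plugin.hello'], args={'greeting': 'hi'})` followed by `kdist.which('my-plugin.hello')` from Python.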