Skip to content

Commit

Permalink
Refactor justifier rule for recursion
Browse files Browse the repository at this point in the history
Use transformer for generating justifier rules for recursion reasoner.
This way, more complex rules, e.g. ones containing aggregates can
easily be transformed.

The wrapping in `model` literals is also managed in a new transformer.
  • Loading branch information
stephanzwicknagl committed Feb 29, 2024
1 parent 3df1f88 commit eb770fe
Show file tree
Hide file tree
Showing 8 changed files with 291 additions and 122 deletions.
96 changes: 25 additions & 71 deletions backend/src/viasp/asp/justify.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""This module is concerned with finding reasons for why a stable model is found."""
from collections import defaultdict
from typing import List, Collection, Dict, Iterable, Union, Set
from typing import List, Collection, Dict, Iterable, Union, Set, cast

import networkx as nx

from clingo import Control, Symbol, Model, ast

from clingo.ast import AST
from clingo.ast import AST, ASTType, parse_string

from .reify import ProgramAnalyzer, has_an_interval
from .reify import ProgramAnalyzer, has_an_interval, reify_recursion_transformation_list, transform
from .recursion import RecursionReasoner
from .utils import insert_atoms_into_nodes, identify_reasons, calculate_spacing_factor
from ..shared.model import Node, Transformation, SymbolIdentifier
Expand Down Expand Up @@ -96,7 +96,7 @@ def make_reason_path_from_facts_to_stable_model(wrapped_stable_model,
rule_mapping: Dict[int, Transformation],
fact_node: Node,
h_symbols: List[Symbol],
recursive_transformations:set,
recursive_transformations_hashes: Set[str],
h="h",
analyzer: ProgramAnalyzer = ProgramAnalyzer(),
pad=True) \
Expand All @@ -109,14 +109,14 @@ def make_reason_path_from_facts_to_stable_model(wrapped_stable_model,
g = nx.DiGraph()
if len(h_syms) == 1:
# If there is a stable model that is exactly the same as the facts.
if rule_mapping[min(rule_mapping.keys())].rules in recursive_transformations:
if rule_mapping[min(rule_mapping.keys())].hash in recursive_transformations_hashes:
fact_node.recursive = True
g.add_edge(fact_node, Node(frozenset(), min(rule_mapping.keys()), frozenset(fact_node.diff)),
transformation=rule_mapping[min(rule_mapping.keys())])
return g

for a, b in pairwise(h_syms):
if rule_mapping[b.rule_nr].rules in recursive_transformations:
if rule_mapping[b.rule_nr].hash in recursive_transformations_hashes:
b.recursive = get_recursion_subgraph(a.atoms,
b.diff,
rule_mapping[b.rule_nr],
Expand Down Expand Up @@ -157,7 +157,7 @@ def build_graph(wrapped_stable_models: List[List[str]],
transformed_prg: Collection[AST],
sorted_program: List[Transformation],
analyzer: ProgramAnalyzer,
recursion_transformations: set) -> nx.DiGraph:
recursion_transformations_hashes: Set[str]) -> nx.DiGraph:
paths: List[nx.DiGraph] = []
facts = analyzer.get_facts()
conflict_free_h = analyzer.get_conflict_free_h()
Expand All @@ -177,7 +177,7 @@ def build_graph(wrapped_stable_models: List[List[str]],
conflict_free_h,
conflict_free_h_showTerm)
new_path = make_reason_path_from_facts_to_stable_model(
model, mapping, fact_node, h_symbols, recursion_transformations,
model, mapping, fact_node, h_symbols, recursion_transformations_hashes,
conflict_free_h, analyzer)
paths.append(new_path)

Expand All @@ -198,9 +198,10 @@ def save_model(model: Model) -> Collection[str]:


def filter_body_aggregates(element: AST):
if (element.ast_type == ast.ASTType.Aggregate):
aggregate_types = [ASTType.Aggregate, ASTType.BodyAggregate, ASTType.ConditionalLiteral]
if (element.ast_type in aggregate_types):
return False
if (getattr(getattr(element, "atom", None), "ast_type",None) == ast.ASTType.Aggregate):
if (getattr(getattr(element, "atom", None), "ast_type",None) in aggregate_types):
return False
return True

Expand All @@ -217,71 +218,25 @@ def get_recursion_subgraph(facts: frozenset, supernode_symbols: frozenset,
:param transformation: The recursive transformation. An ast object.
:param conflict_free_h: The name of the h predicate.
"""
# get_conflict_free_model = analyzer.get_conflict_free_model()
# get_conflict_free_iterindex = analyzer.get_conflict_free_iterindex()

init = [fact.symbol for fact in facts]
justification_program = ""
model_str: str = analyzer.get_conflict_free_model() if analyzer else "model"
n_str: str = analyzer.get_conflict_free_iterindex() if analyzer else "n"

for rule in transformation.rules:
deps = defaultdict(list)
loc = rule.location

_ = analyzer.visit(rule.head, deps=deps) # type: ignore
if not deps:
deps[rule.head] = []
for dependant, conditions in deps.items():
if has_an_interval(dependant):
# replace dependant with variable: e.g. (1..3) -> X
variables = [
ast.Variable(loc, analyzer.get_conflict_free_variable())
if arg.ast_type == ast.ASTType.Interval else arg
for arg in dependant.atom.symbol.arguments
]
symbol = ast.SymbolicAtom(ast.Function(loc,
dependant.atom.symbol.name,
variables,
False))
dependant = ast.Literal(loc, ast.Sign.NoSign, symbol)

reason_literals: List[ast.Literal] = [] # type: ignore
_ = analyzer.visit_sequence(
rule.body, reasons=reason_literals, conditions=conditions, rename_variables=False)
loc_fun = ast.Function(loc, n_str, [], False)
loc_atm = ast.SymbolicAtom(loc_fun)
loc_lit = ast.Literal(loc, ast.Sign.NoSign, loc_atm)
for literal in conditions:
if literal.atom.ast_type == ast.ASTType.SymbolicAtom:
reason_literals.append(literal.atom)
reason_literals.reverse()
reason_literals = [r for i,r in enumerate(reason_literals) if r not in reason_literals[:i]]
reason_fun = ast.Function(loc, '', reason_literals, 0)
reason_lit = ast.Literal(loc, ast.Sign.NoSign, reason_fun)

new_head_s = [
ast.Function(loc, analyzer.get_conflict_free_h(),
[loc_lit, dependant, reason_lit], 0)
]

conditions.extend(conditions)
# Remove duplicates but preserve order
conditions = [
x for i, x in enumerate(conditions) if x not in conditions[:i]
]
conditions = [
ast.Function(loc, model_str, [bb], 0)
for bb in filter(filter_body_aggregates, conditions)
]
conditions.append(
ast.Function(loc, f"not {model_str}", [dependant], 0))
justification_program += "\n".join(
map(str, (ast.Rule(rule.location, new_head, conditions)
for new_head in new_head_s)))
# TODO: add proper edge generation

justification_program += f"{model_str}(@new())."
justifier_rules = reify_recursion_transformation_list(
transformation.rules,
h=analyzer.get_conflict_free_h(),
h_showTerm=analyzer.get_conflict_free_h_showTerm(),
model=analyzer.get_conflict_free_model(),
conflict_free_showTerm=analyzer.get_conflict_free_showTerm(),
get_conflict_free_variable=analyzer.get_conflict_free_variable,
clear_temp_names=analyzer.clear_temp_names,
conflict_free_model=analyzer.get_conflict_free_model(),
conflict_free_iterindex=analyzer.get_conflict_free_iterindex(),
)
justification_program += "\n".join(map(str, justifier_rules))
justification_program += f"\n{model_str}(@new())."

h_syms = set()

try:
Expand All @@ -294,7 +249,6 @@ def get_recursion_subgraph(facts: frozenset, supernode_symbols: frozenset,
return False

h_syms = collect_h_symbols_and_create_nodes(h_syms, relevant_indices = [], pad = False, supernode_symbols = supernode_symbols)
# here: rule_nr is iteration number
h_syms.sort(key=lambda node: node.rule_nr)
h_syms.insert(0, Node(frozenset(facts), -1))
insert_atoms_into_nodes(h_syms)
Expand Down
Loading

0 comments on commit eb770fe

Please sign in to comment.