Skip to content

Commit

Permalink
Fixes to analyzer and reifier
Browse files Browse the repository at this point in the history
Literals are now added to the positive conditions in the transformer's
visiting process, not later. This way, we know the context of the
literals.

Literals in the body of a rule are only added to the positive conditions
if they are non-negated and do not occur inside an aggregate.

In the reification step, nothing changes, but datatypes are fixed
to work with the tuple of conditions and positive conditions.

Mpve the transformation of intervals to a separate function, as it is
needed at multiple places. Also, it is fixed to work for ShowTerm
statements where the dependent needs to be wrapped in a Symbolic Atom
and Literal first, before checking for intervals.

Removing the visiting of body elements, as it is not necessary anymore.

Contributes: #62
  • Loading branch information
stephanzwicknagl committed Mar 2, 2024
1 parent eb770fe commit 1d5464c
Showing 1 changed file with 57 additions and 69 deletions.
126 changes: 57 additions & 69 deletions backend/src/viasp/asp/reify.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ def visit_ConditionalLiteral(

if in_head:
# collect deps for choice rules
deps[conditional_literal.literal] = []
deps[conditional_literal.literal] = ([], [])
for condition in conditional_literal.condition:
deps[conditional_literal.literal].append(condition)
deps[conditional_literal.literal][0].append(condition)
if (not in_aggregate and not in_analyzer):
# add simple Cond.Literals from rule body to justifier rule body
conditions.append(conditional_literal)
Expand All @@ -128,13 +128,16 @@ def visit_Literal(
literal: ast.Literal, # type: ignore
**kwargs: Any) -> AST:
conditions: List[AST] = kwargs.get("conditions", [])
positive_conditions: List[AST] = kwargs.get("positive_conditions", [])
in_analyzer = kwargs.get("in_analyzer", False)
in_aggregate = kwargs.get("in_aggregate", False)

if (in_analyzer and literal.atom.ast_type
not in [ASTType.Aggregate, ASTType.BodyAggregate]):
# all non-aggregate Literals in the rule body are conditions of the rule
conditions.append(literal)
if literal.sign == ast.Sign.NoSign and not in_aggregate:
positive_conditions.append(literal)
if (not in_analyzer and not in_aggregate):
# add all Literals outside of aggregates from rule body to justifier rule body
conditions.append(literal)
Expand Down Expand Up @@ -368,12 +371,13 @@ def register_rule_dependencies(
ast.Literal, # type: ignore
List[ast.Literal]] # type: ignore
) -> None:
for uu in deps.values():
for u in filter(filter_body_arithmetic, uu):
u_sig = make_signature(u)
self.conditions[u_sig].add(rule)
if u.sign == ast.Sign.NoSign:
self.positive_conditions[u_sig].add(rule)
for (cond, pos_cond) in deps.values():
for c in filter(filter_body_arithmetic, cond):
c_sig = make_signature(c)
self.conditions[c_sig].add(rule)
for c in filter(filter_body_arithmetic, pos_cond):
c_sig = make_signature(c)
self.positive_conditions[c_sig].add(rule)

for v in deps.keys():
if v.ast_type == ASTType.Literal and v.atom.ast_type != ASTType.BooleanConstant:
Expand All @@ -387,27 +391,25 @@ def get_body_aggregate_elements(self, body: Sequence[AST]) -> List[AST]:
return body_aggregate_elements

def process_body(self, head, body, deps, in_analyzer=True):
for b in body:
self.visit(b, deps=deps)
if not len(deps) and len(body):
deps[head] = []
for _, cond in deps.items():
self.visit_sequence(body, conditions=cond, in_analyzer=in_analyzer)
deps[head] = ([], [])
for _, (cond, pos_cond) in deps.items():
self.visit_sequence(body, conditions=cond, positive_conditions=pos_cond, in_analyzer=in_analyzer)

def register_dependencies_and_append_rule(self, rule, deps):
self.register_rule_dependencies(rule, deps)
self.rules.append(rule)

def visit_Rule(self, rule: ast.Rule): # type: ignore
deps = defaultdict(list)
deps = defaultdict(tuple)
_ = self.visit(rule.head, deps=deps, in_head=True)
self.process_body(rule.head, rule.body, deps)
self.register_dependencies_and_append_rule(rule, deps)
if is_fact(rule, deps):
self.facts.add(rule.head)
self.register_dependencies_and_append_rule(rule, deps)

def visit_ShowTerm(self, showTerm: ast.ShowTerm): # type: ignore
deps = defaultdict(list)
deps = defaultdict(tuple)
_ = self.visit(showTerm.term, deps=deps, in_head=True)
head_literal = ast.Literal(showTerm.location, ast.Sign.NoSign,
ast.SymbolicAtom(showTerm.term))
Expand Down Expand Up @@ -588,7 +590,8 @@ def _nest_rule_head_in_h_with_explanation_tuple(
reasons.append(literal.atom)
reasons.reverse()
reasons = [r for i, r in enumerate(reasons) if r not in reasons[:i]]
reason_fun = ast.Function(loc, "", [r for r in reasons if r is not None], 0)
reason_fun = ast.Function(loc, "",
[r for r in reasons if r is not None], 0)
reason_lit = ast.Literal(loc, ast.Sign.NoSign, reason_fun)

h_attribute = self.h_showTerm if use_h_showTerm else self.h
Expand All @@ -600,6 +603,26 @@ def _nest_rule_head_in_h_with_explanation_tuple(
def post_rule_creation(self):
self.clear_temp_names()

def process_dependant_intervals(
self, loc: ast.Location,
dependant: Union[ast.Literal, ast.Function]): # type: ignore
if dependant.ast_type == ASTType.Function:
print(f"Type of dependant {dependant.ast_type}, going to make lit",
flush=True)
dependant = ast.Literal(loc, ast.Sign.NoSign, ast.SymbolicAtom(dependant))
if has_an_interval(dependant):
# replace dependant with variable: e.g. (1..3) -> X
variables = [
ast.Variable(loc, self.get_conflict_free_variable())
if arg.ast_type == ASTType.Interval else arg
for arg in dependant.atom.symbol.arguments
]
symbol = ast.SymbolicAtom(
ast.Function(loc, dependant.atom.symbol.name, variables,
False))
dependant = ast.Literal(loc, ast.Sign.NoSign, symbol)
return dependant

def visit_Rule(self, rule: ast.Rule) -> List[AST]: # type: ignore
"""
Reify a rule into a set of new rules.
Expand All @@ -612,28 +635,18 @@ def visit_Rule(self, rule: ast.Rule) -> List[AST]: # type: ignore
:param rule: The rule to reify
:return: A list of new rules"""
# Embed the head
deps = defaultdict(list)
deps = defaultdict(tuple)
loc = rule.location
_ = self.visit(rule.head, deps=deps, in_head=True)

if is_fact(rule, deps) or is_constraint(rule):
return [rule]
if not deps:
# if it's a "simple head"
deps[rule.head] = []
deps[rule.head] = ([], [])
new_rules: List[ast.Rule] = [] # type: ignore
for dependant, conditions in deps.items():
if has_an_interval(dependant):
# replace dependant with variable: e.g. (1..3) -> X
variables = [
ast.Variable(loc, self.get_conflict_free_variable())
if arg.ast_type == ASTType.Interval else arg
for arg in dependant.atom.symbol.arguments
]
symbol = ast.SymbolicAtom(
ast.Function(loc, dependant.atom.symbol.name, variables,
False))
dependant = ast.Literal(loc, ast.Sign.NoSign, symbol)
for dependant, (conditions, _) in deps.items():
dependant = self.process_dependant_intervals(loc, dependant)

_ = self.visit_sequence(
rule.body,
Expand Down Expand Up @@ -663,17 +676,7 @@ def visit_ShowTerm(self, showTerm: ast.ShowTerm): # type: ignore
_ = self.visit(showTerm.term, deps=deps, in_head=True)

new_rules = []
if has_an_interval(showTerm.term):
# replace dependant with variable: e.g. (1..3) -> X
variables = [
ast.Variable(loc, self.get_conflict_free_variable())
if arg.ast_type == ASTType.Interval else arg
for arg in showTerm.term.atom.symbol.arguments
]
symbol = ast.SymbolicAtom(
ast.Function(loc, showTerm.term.atom.symbol.name, variables,
False))
showTerm.term = ast.Literal(loc, ast.Sign.NoSign, symbol)
showTerm.term = self.process_dependant_intervals(loc, showTerm.term)

conditions: List[AST] = []
_ = self.visit_sequence(
Expand Down Expand Up @@ -724,8 +727,10 @@ class LiteralWrapper(Transformer):

def __init__(self, *args, **kwargs):
self.wrap_str: str = kwargs.pop("wrap_str", "model")
self.no_wrap_types: List[
ASTType] = [ASTType.Aggregate, ASTType.BodyAggregate, ASTType.Comparison, ASTType.BooleanConstant]
self.no_wrap_types: List[ASTType] = [
ASTType.Aggregate, ASTType.BodyAggregate, ASTType.Comparison,
ASTType.BooleanConstant
]
super().__init__(*args, **kwargs)

def visit_Literal(self,
Expand Down Expand Up @@ -756,17 +761,7 @@ def visit_Rule(self, rule: ast.Rule) -> List[AST]: # type: ignore
deps[rule.head] = []
new_rules: List[ast.Rule] = [] # type: ignore
for dependant, conditions in deps.items():
if has_an_interval(dependant):
# replace dependant with variable: e.g. (1..3) -> X
variables = [
ast.Variable(loc, self.get_conflict_free_variable())
if arg.ast_type == ASTType.Interval else arg
for arg in dependant.atom.symbol.arguments
]
symbol = ast.SymbolicAtom(
ast.Function(loc, dependant.atom.symbol.name, variables,
False))
dependant = ast.Literal(loc, ast.Sign.NoSign, symbol)
dependant = self.process_dependant_intervals(loc, dependant)

_ = self.visit_sequence(
rule.body,
Expand Down Expand Up @@ -876,23 +871,16 @@ def has_an_interval(literal: ast.Literal) -> bool: # type: ignore
return False


def reify_recursion_transformation(rule: str, **kwargs):
def reify_recursion_transformation(transformation: Transformation, **kwargs) -> List[AST]:
visitor = ProgramReifierForRecursions(**kwargs)
result: List[AST] = []
rules = []
if isinstance(rule, str):
parse_string(
rule, lambda rule: rules.append(rule)
if rule.ast_type != ASTType.Program else None)
rules = transformation.rules
if any(isinstance(r, str) for r in rules):
rules_str = rules
rules = []
for rule in rules_str:
parse_string(rule, lambda rule: rules.append(rule) if rule.ast_type != ASTType.Program else None)

for r in rules:
result.extend(cast(Iterable[AST], visitor.visit(r)))
return result


def reify_recursion_transformation_list(recursion_rules: Iterable[str],
**kwargs) -> List[AST]:
reified = []
for part in recursion_rules:
reified.extend(reify_recursion_transformation(part, **kwargs))
return reified

0 comments on commit 1d5464c

Please sign in to comment.