Skip to content

Commit

Permalink
Improve SDFG work-depth analysis and add SDFG simulated operational i…
Browse files Browse the repository at this point in the history
…ntensity analysis (#1607)

This PR is here to merge the reviewed PR #1495, which has remained
inactive for a long time with minor comments open. The comments have
been addressed here and merge conflicts have been resolved.

---------

Co-authored-by: Cliff Hodel <hodelcl@student.ethz.ch>
Co-authored-by: Cliff Hodel <111381329+hodelcl@users.noreply.github.com>
Co-authored-by: Cliff Hodel <hodelcl@ethz.ch>
  • Loading branch information
4 people authored Jun 27, 2024
1 parent fb074f2 commit a4e0291
Show file tree
Hide file tree
Showing 8 changed files with 1,473 additions and 297 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def propagate_assumptions_equal_symbols(condensed_assumptions):
equality_subs1.update({sym: sp.Symbol(uf.find(sym))})

equality_subs2 = {}
# In a second step, each symbol gets replace with its equal number (if present)
# In a second step, each symbol gets replaced with its equal number (if present)
# using equality_subs2.
for sym, assum in condensed_assumptions.items():
for e in assum.equal:
Expand Down Expand Up @@ -182,7 +182,7 @@ def parse_assumptions(assumptions, array_symbols):
Parses a list of assumptions into substitution dictionaries. Firstly, it gathers all assumptions and
keeps only the strongest ones. Afterwards it constructs two substitution dicts for the equality
assumptions: First dict for symbol==symbol assumptions; second dict for symbol==number assumptions.
The other assumptions get handles by N tuples of substitution dicts (N = max number of concurrent
The other assumptions get handled by N tuples of substitution dicts (N = max number of concurrent
assumptions for a single symbol). Each tuple is responsible for at most one assumption for each symbol.
First dict in the tuple substitutes the symbol with the assumption; second dict restores the initial symbol.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ def get_backedges(graph: nx.DiGraph,
return backedges


class LoopExtractionError(Exception):
pass


def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph):
"""
Detects loops in a SDFG. For each loop, it identifies (node, oNode, exit).
Expand Down Expand Up @@ -241,15 +245,15 @@ def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph):

# sanity check:
if sdfg_nx.in_degree(artificial_end_node) == 0:
raise ValueError('No end node could be determined in the SDFG')
raise LoopExtractionError('No end node could be determined in the SDFG')

# compute dominators and backedges
iDoms = nx.immediate_dominators(sdfg_nx, start)
allDom, domTree = get_domtree(sdfg_nx, start, iDoms)
allDom, _ = get_domtree(sdfg_nx, start, iDoms)

reversed_sdfg_nx = sdfg_nx.reverse()
iPostDoms = nx.immediate_dominators(reversed_sdfg_nx, artificial_end_node)
allPostDoms, postDomTree = get_domtree(reversed_sdfg_nx, artificial_end_node, iPostDoms)
_, postDomTree = get_domtree(reversed_sdfg_nx, artificial_end_node, iPostDoms)

backedges = get_backedges(sdfg_nx, start)
backedgesDstDict = {}
Expand Down Expand Up @@ -297,7 +301,7 @@ def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph):
exitCandidates.add(succ)

if len(exitCandidates) == 0:
raise ValueError('failed to find any exit nodes')
raise LoopExtractionError('failed to find any exit nodes')
elif len(exitCandidates) > 1:
# Find the exit candidate that sits highest up in the
# postdominator tree (i.e., has the lowest level).
Expand All @@ -323,7 +327,7 @@ def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph):
if len(minSet) > 0:
exitCandidates = minSet
else:
raise ValueError('failed to find exit minSet')
raise LoopExtractionError('failed to find exit minSet')

# now we have a triple (node, oNode, exitCandidates)
nodes_oNodes_exits.append((node, oNode, exitCandidates))
Expand Down
283 changes: 283 additions & 0 deletions dace/sdfg/performance_evaluation/op_in_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
""" Contains class CacheLineTracker which keeps track of all arrays of an SDFG and their cache line position
and class AccessStack which which corresponds to the stack used to compute the stack distance.
Further, provides a curve fitting method and plotting function. """

import warnings
from dace.data import Array
import sympy as sp
from collections import deque
from scipy.optimize import curve_fit
import numpy as np
from dace import symbol


class CacheLineTracker:
""" A CacheLineTracker maps data container accesses to the corresponding accessed cache line. """

def __init__(self, L) -> None:
self.array_info = {}
self.start_lines = {}
self.next_free_line = 0
self.L = L

def add_array(self, name: str, a: Array, mapping):
if name not in self.start_lines:
# new array encountered
self.array_info[name] = a
self.start_lines[name] = self.next_free_line
# increase next_free_line
self.next_free_line += (a.total_size.subs(mapping) * a.dtype.bytes + self.L - 1) // self.L # ceil division

def cache_line_id(self, name: str, access: [int], mapping):
arr = self.array_info[name]
one_d_index = 0
for dim in range(len(access)):
i = access[dim]
one_d_index += (i + sp.sympify(arr.offset[dim]).subs(mapping)) * sp.sympify(arr.strides[dim]).subs(mapping)

# divide by L to get the cache line id
return self.start_lines[name] + (one_d_index * arr.dtype.bytes) // self.L

def copy(self):
new_clt = CacheLineTracker(self.L)
new_clt.array_info = dict(self.array_info)
new_clt.start_lines = dict(self.start_lines)
new_clt.next_free_line = self.next_free_line
return new_clt


class Node:

def __init__(self, val: int, n=None) -> None:
self.v = val
self.next = n


class AccessStack:
""" A stack of cache line ids. For each memory access, we search the corresponding cache line id
in the stack, report its distance and move it to the top of the stack. If the id was not found,
we report a distance of -1. """

def __init__(self, C) -> None:
self.top = None
self.num_calls = 0
self.length = 0
self.C = C

def touch(self, id):
self.num_calls += 1
curr = self.top
prev = None
found = False
distance = 0
while curr is not None:
# check if we found id
if curr.v == id:
# take curr node out
if prev is not None:
prev.next = curr.next
curr.next = self.top
self.top = curr

found = True
break

# iterate further
prev = curr
curr = curr.next
distance += 1

if not found:
# we accessed this cache line for the first time ever
self.top = Node(id, self.top)
self.length += 1
distance = -1

return distance

def in_cache_as_list(self):
"""
Returns a list of cache ids currently in cache. Index 0 is the most recently used.
"""
res = deque()
curr = self.top
dist = 0
while curr is not None and dist < self.C:
res.append(curr.v)
curr = curr.next
dist += 1
return res

def debug_print(self):
# prints the whole stack
print('\n')
curr = self.top
while curr is not None:
print(curr.v, end=', ')
curr = curr.next
print('\n')

def copy(self):
new_stack = AccessStack(self.C)
cache_content = self.in_cache_as_list()
if len(cache_content) > 0:
new_top_value = cache_content.popleft()
new_stack.top = Node(new_top_value)
curr = new_stack.top
for x in cache_content:
curr.next = Node(x)
curr = curr.next
return new_stack


def plot(x, work_map, cache_misses, op_in_map, symbol_name, C, L, sympy_f, element, name):
plt = None
try:
import matplotlib.pyplot as plt_import
plt = plt_import
except ModuleNotFoundError:
pass

if plt is None:
warnings.warn('Plotting only possible with matplotlib installed')
return

work_map = work_map[element]
cache_misses = cache_misses[element]
op_in_map = op_in_map[element]
sympy_f = sympy_f[element]

a = np.linspace(1, max(x) + 5, max(x) * 4)

fig, ax = plt.subplots(1, 2, figsize=(12, 5))
ax[0].scatter(x, cache_misses, label=f'C={C*L}, L={L}')
b = []
for curr in a:
b.append(sp.N(sp.sympify(sympy_f).subs(symbol_name, curr)))
ax[0].plot(a, b)

c = []
for i, curr in enumerate(x):
if work_map[0].subs(symbol_name, curr) == 0:
c.append(0)
elif (cache_misses[i] * L) == 0:
c.append(9999)
else:
c.append(work_map[0].subs(symbol_name, curr) / (cache_misses[i] * L))
c = np.array(c).astype(np.float64)

ax[1].scatter(x, c, label=f'C={C*L}, L={L}')
b = []
for curr in a:
b.append(sp.N(sp.sympify(op_in_map).subs(symbol_name, curr)))
ax[1].plot(a, b)

ax[0].set_ylim(bottom=0, top=max(cache_misses) + max(cache_misses) / 10)
ax[0].set_xlim(left=0, right=max(x) + 1)
ax[0].set_xlabel(symbol_name)
ax[0].set_ylabel('Number of Cache Misses')
ax[0].set_title(name)
ax[0].legend(fancybox=True, framealpha=0.5)

ax[1].set_ylim(bottom=0, top=max(c) + max(c) / 10)
ax[1].set_xlim(left=0, right=max(x) + 1)
ax[1].set_xlabel(symbol_name)
ax[1].set_ylabel('Operational Intensity')
ax[1].set_title(name)

fig.show()


def compute_mape(f, test_x, test_y, test_set_size):
total_error = 0
for i in range(test_set_size):
pred = f(test_x[i])
err = abs(test_y[i] - pred)
total_error += err / test_y[i]
return total_error / test_set_size


def r_squared(pred, y):
if np.sum(np.square(y - y.mean())) <= 0.0001:
return 1
return 1 - np.sum(np.square(y - pred)) / np.sum(np.square(y - y.mean()))


def find_best_model(x, y, I, J, symbol_name):
""" Find the best model out of all combinations of (i, j) from I and J via leave-one-out cross validation. """
min_error = None
for i in I:
for j in J:
# current model
if i == 0 and j == 0:

def f(x, b):
return b * np.ones_like(x)
else:

def f(x, c, b):
return c * np.power(x, i) * np.power(np.log2(x), j) + b

error_sum = 0
for left_out in range(len(x)):
xx = np.delete(x, left_out)
yy = np.delete(y, left_out)
try:
param, _ = curve_fit(f, xx, yy)

# predict on left out sample
pred = f(x[left_out], *param)
squared_error = np.square(pred - y[left_out])
error_sum += squared_error
except RuntimeError:
# triggered if no fit was found --> give huge error
error_sum += 999999

mean_error = error_sum / len(x)
if min_error is None or mean_error < min_error:
# new best model found
min_error = mean_error
best_i_j = (i, j)
if best_i_j[0] == 0 and best_i_j[1] == 0:

def f_best(x, b):
return b * np.ones_like(x)
else:

def f_best(x, c, b):
return c * np.power(x, best_i_j[0]) * np.power(np.log2(x), best_i_j[1]) + b

# fit best model to all data points
final_p, _ = curve_fit(f_best, x, y)

def final_f(x):
return f_best(x, *final_p)

if best_i_j[0] == 0 and best_i_j[1] == 0:
sympy_f = final_p[0]
else:
sympy_f = sp.simplify(final_p[0] * symbol(symbol_name)**best_i_j[0] *
sp.log(symbol(symbol_name), 2)**best_i_j[1] + final_p[1])
# compute r^2
r_s = r_squared(final_f(x), y)
return final_f, sympy_f, r_s


def fit_curve(x, y, symbol_name):
"""
Fits a function throught the data set.
:param x: The independent values.
:param y: The dependent values.
:param symbol_name: The name of the SDFG symbol.
"""
x = np.array(x).astype(np.int32)
y = np.array(y).astype(np.float64)

# model search space
I = [x / 4 for x in range(13)]
J = [0, 1, 2]
final_f, sympy_final_f, r_s = find_best_model(x, y, I, J, symbol_name)

return final_f, sympy_final_f, r_s
Loading

0 comments on commit a4e0291

Please sign in to comment.