Skip to content

Commit

Permalink
Fix some underlying issues with tensor core sample (#1336)
Browse files Browse the repository at this point in the history

Co-authored-by: Phillip Allen Lane <lane47@lassen709.coral.llnl.gov>
  • Loading branch information
computablee and Phillip Allen Lane authored Jul 29, 2023
1 parent c432824 commit 60b4045
Showing 1 changed file with 36 additions and 51 deletions.
87 changes: 36 additions & 51 deletions samples/codegen/tensor_cores.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from dace.sdfg.graph import MultiConnectorEdge
from dace.sdfg.state import StateSubgraphView
from dace.codegen.prettycode import CodeIOStream
from dace.codegen.dispatcher import DefinedType
from typing import Any, List

# Other imports
Expand Down Expand Up @@ -76,6 +77,9 @@ def __init__(self, frame_codegen: DaCeCodeGenerator, sdfg: dace.SDFG):
def allocate_array(self, sdfg: dace.SDFG, dfg: StateSubgraphView, state_id: int, node: nodes.AccessNode,
nodedesc: dt.Array, function_stream: CodeIOStream, declaration_stream: CodeIOStream,
allocation_stream: CodeIOStream):
# Make sure the codegen includes the appropriate header files
_include_mma(sdfg)

name = node.data

# Based on the hardware, the total size must be 16^2
Expand All @@ -85,14 +89,16 @@ def allocate_array(self, sdfg: dace.SDFG, dfg: StateSubgraphView, state_id: int,

# Write a fragment based on the storage type
if nodedesc.storage == dace.StorageType.TensorCore_Accumulator:
declaration_stream.write('wmma::fragment<wmma::accumulator, '
'16, 16, 16, float> {};'.format(name), sdfg, state_id, node)
ctype = 'wmma::fragment<wmma::accumulator, 16, 16, 16, float>'
declaration_stream.write(f'{ctype} {name};', sdfg, state_id, node)
else:
declaration_stream.write(
'wmma::fragment<wmma::matrix_{mat}, '
'16, 16, 16, half, wmma::{maj}_major> '
'{name};'.format(mat=('a' if 'A' in nodedesc.storage.name else 'b'), maj=maj, name=name), sdfg,
state_id, node)
ctype = 'wmma::fragment<wmma::matrix_{mat}, 16, 16, 16, half, wmma::{maj}_major>'.format(
mat=('a' if 'A' in nodedesc.storage.name else 'b'), maj=maj)
declaration_stream.write(f'{ctype} {name};', sdfg, state_id, node)

# Add the ctype to defined_vars so that the codegen can properly pass
# fragments to functions as an object reference.
self._dispatcher.defined_vars.add(name, DefinedType.Stream, ctype)

def deallocate_array(self, sdfg: dace.SDFG, dfg: StateSubgraphView, state_id: int, node: nodes.AccessNode,
nodedesc: dt.Array, function_stream: CodeIOStream, callsite_stream: CodeIOStream):
Expand Down Expand Up @@ -187,50 +193,29 @@ def _include_mma(sdfg: dace.SDFG):
sdfg.append_global_code(global_code, 'cuda')


@replaces('frag_fill')
def frag_fill(pv: ProgramVisitor, sdfg: dace.SDFG, state: dace.SDFGState, frag: str, fill: Any) -> List[str]:
# Replacement functions receive the SDFG and the current state as the first
# two arguments, followed by all the other arguments. Here we treat them as
# two strings representing the array name to fill and what to fill it with.

# NOTE: If a slice is used in the `frag` argument, the Python frontend
# automatically creates a new array for it, and uses the correct string as
# the argument.
wnode = state.add_write(frag)
tasklet = state.add_tasklet('fill',
set(), {'out'},
'''
wmma::fill_fragment(out, %s);''' % fill,
language=dace.Language.CPP)

state.add_edge(tasklet, 'out', wnode, None, dace.Memlet.from_array(frag, wnode.desc(sdfg)))

_include_mma(sdfg)

# Function has no return value
return []


@replaces('wmma')
def wmma(pv: ProgramVisitor, sdfg: dace.SDFG, state: dace.SDFGState, a_frag: str, b_frag: str,
c_frag: str) -> List[str]:
# Implemented similarly to `frag_fill`, but with inputs and outputs.
anode = state.add_read(a_frag)
bnode = state.add_read(b_frag)
cnode = state.add_write(c_frag)
tasklet = state.add_tasklet('wmma', {'afrag', 'bfrag'}, {'cfrag'},
'''
wmma::mma_sync(cfrag, afrag, bfrag, cfrag);''',
language=dace.Language.CPP)

state.add_edge(anode, None, tasklet, 'afrag', dace.Memlet.from_array(a_frag, anode.desc(sdfg)))
state.add_edge(bnode, None, tasklet, 'bfrag', dace.Memlet.from_array(b_frag, bnode.desc(sdfg)))
state.add_edge(tasklet, 'cfrag', cnode, None, dace.Memlet.from_array(c_frag, cnode.desc(sdfg)))

_include_mma(sdfg)

# Function has no return value
return []
def frag_fill(frag, fill):
# Define a tasklet with the appropriate input and output connectors.
# Then we can directly emit CUDA for the tasklet.
with dace.tasklet(dace.Language.CPP):
val << fill
out >> frag
"""
wmma::fill_fragment(out, val);
"""

def wmma(a_frag, b_frag, c_frag):
# We do the same here as we did with frag_fill. Since c_frag is used
# as both an input and an output, we specify two separate variables
# to be passed to mma_sync and declare c_frag as an input to one and
# an output to the other. This ensures proper dataflow.
with dace.tasklet(dace.Language.CPP):
afrag << a_frag
bfrag << b_frag
cfrag << c_frag
dfrag >> c_frag
"""
wmma::mma_sync(dfrag, afrag, bfrag, cfrag);
"""


############################################################################
Expand Down

0 comments on commit 60b4045

Please sign in to comment.