Skip to content

Commit

Permalink
[Tensilelite] use TensileInstructions instead of globalParameters when
Browse files Browse the repository at this point in the history
we need to get the CurrentISA
  • Loading branch information
vin-huang committed Jul 14, 2023
1 parent 0043cae commit d01fc1f
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 21 deletions.
31 changes: 18 additions & 13 deletions tensilelite/Tensile/Activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from collections import OrderedDict

from .TensileInstructions import Module, TextBlock, HolderContainer, RegisterContainer, \
VCC, EXEC, vgpr, sgpr, Holder, fastdeepcopy, DataType, SNop
VCC, EXEC, vgpr, sgpr, Holder, fastdeepcopy, DataType, SNop, \
TensileInstructions
from .TensileInstructions.Enums import *
from .TensileInstructions.Instructions import *
from .Common import printExit, printWarning, globalParameters
Expand Down Expand Up @@ -445,6 +446,7 @@ def getClippedReluModule(self, cDataType, vgprIn, vgprOut, activationAlpha, acti
return module

def getExpModule(self, cDataType, vgprIn, vgprOut):
ti = TensileInstructions()
module = Module("Exp")
if cDataType.isHalf():
sgprMagic = self.getSgpr(1)
Expand All @@ -457,17 +459,17 @@ def getExpModule(self, cDataType, vgprIn, vgprOut):
sdwa=SDWAModifiers(dst_sel=select_bit, dst_unused=UnusedBit.UNUSED_PRESERVE, \
src0_sel=select_bit), \
comment="exp step 2"))
if globalParameters["ArchCaps"][globalParameters["CurrentISA"]]["TransOpWait"]:
if ti.getArchCaps()["TransOpWait"]:
module.add(SNop(waitState=0, comment="1 wait states"))
else:
module.add(VMulF16(dst=self.vgprPrefix(vgprOut), src0=sgpr(Holder(idx=sgprMagic)), src1=self.vgprPrefix(vgprIn), comment="exp step 1"))
module.add(VExpF16(dst=self.vgprPrefix(vgprOut), src=self.vgprPrefix(vgprOut), comment="exp step 2"))
if globalParameters["ArchCaps"][globalParameters["CurrentISA"]]["TransOpWait"]:
if ti.getArchCaps()["TransOpWait"]:
module.add(SNop(waitState=0, comment="1 wait states"))
elif cDataType.isSingle():
module.add(VMulF32(dst=self.vgprPrefix(vgprOut), src0=math.log(math.e,2), src1=self.vgprPrefix(vgprIn), comment="exp step 1"))
module.add(VExpF32(dst=self.vgprPrefix(vgprOut), src=self.vgprPrefix(vgprOut), comment="exp step 2" ))
if globalParameters["ArchCaps"][globalParameters["CurrentISA"]]["TransOpWait"]:
if ti.getArchCaps()["TransOpWait"]:
module.add(SNop(waitState=0, comment="1 wait states"))
else:
raise RuntimeError("Unsupported data type %s."%cDataType.toDevice("HIP"))
Expand Down Expand Up @@ -575,6 +577,7 @@ def getReluModule(self, cDataType, vgprIn, vgprOut):
return module

def getSigmoidModule(self, cDataType, vgprIn, vgprOut):
ti = TensileInstructions()
self.needCombine = True
module = Module("Sigmoid")
if cDataType.isHalf():
Expand All @@ -588,27 +591,28 @@ def getSigmoidModule(self, cDataType, vgprIn, vgprOut):
module.add(VRcpF16(dst=self.vgprPrefix(vgprOut), src=self.vgprPrefix(vgprOut), \
sdwa=SDWAModifiers(dst_sel=select_bit, dst_unused=UnusedBit.UNUSED_PRESERVE, src0_sel=select_bit), \
comment="1 / (1 + exp(-x))"))
if globalParameters["ArchCaps"][globalParameters["CurrentISA"]]["TransOpWait"]:
if ti.getArchCaps()["TransOpWait"]:
module.add(SNop(waitState=0, comment="1 wait states"))
else:
module.add(VMulF16(dst=self.vgprPrefix(vgprOut), src0=-1.0, src1=self.vgprPrefix(vgprIn), comment=" x = -x"))
module.add(self.getExpModule(cDataType, vgprOut, vgprOut))
module.add(VAddF16(dst=self.vgprPrefix(vgprOut), src0=1.0, src1=self.vgprPrefix(vgprOut), comment="1 + exp(-x)"))
module.add(VRcpF16(dst=self.vgprPrefix(vgprOut), src=self.vgprPrefix(vgprOut), comment="1 / (1 + exp(-x))"))
if globalParameters["ArchCaps"][globalParameters["CurrentISA"]]["TransOpWait"]:
if ti.getArchCaps()["TransOpWait"]:
module.add(SNop(waitState=0, comment="1 wait states"))
elif cDataType.isSingle():
module.add(VMulF32(dst=self.vgprPrefix(vgprOut), src0=-1.0, src1=self.vgprPrefix(vgprIn), comment=" x = -x"))
module.add(self.getExpModule(cDataType, vgprOut, vgprOut))
module.add(VAddF32(dst=self.vgprPrefix(vgprOut), src0=1.0, src1=self.vgprPrefix(vgprOut), comment="1 + exp(-x)" ))
module.add(VRcpF32(dst=self.vgprPrefix(vgprOut), src=self.vgprPrefix(vgprOut), comment="1 / (1 + exp(-x))" ))
if globalParameters["ArchCaps"][globalParameters["CurrentISA"]]["TransOpWait"]:
if ti.getArchCaps()["TransOpWait"]:
module.add(SNop(waitState=0, comment="1 wait states"))
else:
raise RuntimeError("Unsupported data type %s."%cDataType.toDevice("HIP"))
return module

def getTanhModule(self, cDataType, vgprIn, vgprOut, activationAlpha, activationBeta):
ti = TensileInstructions()
self.needCombine = True
module = Module("Tanh")
if cDataType.isHalf():
Expand All @@ -629,7 +633,7 @@ def getTanhModule(self, cDataType, vgprIn, vgprOut, activationAlpha, activationB
sdwa=SDWAModifiers(dst_sel=select_bit, dst_unused=UnusedBit.UNUSED_PRESERVE, \
src0_sel=select_bit), \
comment="1 / (1 + exp(-x))"))
if globalParameters["ArchCaps"][globalParameters["CurrentISA"]]["TransOpWait"]:
if ti.getArchCaps()["TransOpWait"]:
module.add(SNop(waitState=0, comment="1 wait states")) #workaround for emulator
module.add(VFmaPKF16(dst=self.vgprPrefix(vgprOut), src0=-2.0, src1=self.vgprPrefix(vgprOut), src2=1.0, \
vop3=VOP3PModifiers(op_sel_hi=[0,1,0,1]), comment="tanh(x) = (1 / (e^2x + 1)) * (-2) + 1"))
Expand All @@ -644,7 +648,7 @@ def getTanhModule(self, cDataType, vgprIn, vgprOut, activationAlpha, activationB
module.add(self.getExpModule(cDataType, vgprOut, vgprOut))
module.add(VAddF16(dst=self.vgprPrefix(vgprOut), src0=1.0, src1=self.vgprPrefix(vgprOut), comment="e^2x + 1"))
module.add(VRcpF16(dst=self.vgprPrefix(vgprOut), src=self.vgprPrefix(vgprOut), comment="1 / (1 + exp(-x))"))
if globalParameters["ArchCaps"][globalParameters["CurrentISA"]]["TransOpWait"]:
if ti.getArchCaps()["TransOpWait"]:
module.add(SNop(waitState=0, comment="1 wait states")) #workaround for emulator
module.add(VFmaF16(dst=self.vgprPrefix(vgprOut), src0=-2.0, src1=self.vgprPrefix(vgprOut), src2=1.0, comment="tanh(x) = (1 / (e^2x + 1)) * (-2) + 1"))
if activationBeta:
Expand All @@ -658,7 +662,7 @@ def getTanhModule(self, cDataType, vgprIn, vgprOut, activationAlpha, activationB
module.add(self.getExpModule(cDataType, vgprOut, vgprOut))
module.add(VAddF32(dst=self.vgprPrefix(vgprOut), src0=1.0, src1=self.vgprPrefix(vgprOut), comment="e^2x + 1"))
module.add(VRcpF32(dst=self.vgprPrefix(vgprOut), src=self.vgprPrefix(vgprOut), comment="1 / (e^2x + 1)"))
if globalParameters["ArchCaps"][globalParameters["CurrentISA"]]["TransOpWait"]:
if ti.getArchCaps()["TransOpWait"]:
module.add(SNop(waitState=0, comment="1 wait states")) #workaround for emulator
module.add(VFmaF32(dst=self.vgprPrefix(vgprOut), src0=-2.0, src1=self.vgprPrefix(vgprOut), src2=1.0, comment="(-2) * (1 / (e^2x + 1)) + 1"))
if activationBeta:
Expand All @@ -668,6 +672,7 @@ def getTanhModule(self, cDataType, vgprIn, vgprOut, activationAlpha, activationB
return module

def getDGeluModule(self, cDataType, vgprIn, vgprOut):
ti = TensileInstructions()
self.needCombine = True
module = Module("Gradient Gelu")
# x1 = (0.0535161 * pow(x, 3) + 0.398942 * x)
Expand Down Expand Up @@ -698,7 +703,7 @@ def getDGeluModule(self, cDataType, vgprIn, vgprOut):
module.add(VAddF32(dst=self.vgprPrefix(vgprOut), src0=vgpr(Holder(idx=vgprTemp3)), src1=vgpr(Holder(idx=vgprTemp1)), comment="out = e^xx + e^-xx"))
module.add(VSubF32(dst=vgpr(Holder(idx=vgprTemp1)), src0=vgpr(Holder(idx=vgprTemp3)), src1=vgpr(Holder(idx=vgprTemp1)), comment="tmp1 = e^xx - e^-xx"))
module.add(VRcpF32(dst=vgpr(Holder(idx=vgprTemp3)), src=self.vgprPrefix(vgprOut), comment="tmp3 = 1/out"))
if globalParameters["ArchCaps"][globalParameters["CurrentISA"]]["TransOpWait"]:
if ti.getArchCaps()["TransOpWait"]:
module.add(SNop(waitState=0, comment="1 wait states")) #workaround for emulator
module.add(VMulF32(dst=vgpr(Holder(idx=vgprTemp3)), src0=vgpr(Holder(idx=vgprTemp1)), src1=vgpr(Holder(idx=vgprTemp3)), comment="tmp3 = tmp1 * tmp3"))
if self.enableGuard:
Expand All @@ -711,7 +716,7 @@ def getDGeluModule(self, cDataType, vgprIn, vgprOut):
module.add(VMulF32(dst=vgpr(Holder(idx=vgprTemp1)), src0=0.5, src1=vgpr(Holder(idx=vgprTemp3)), comment="tmp1 = 0.5 * tmp1"))
module.add(VMulF32(dst=self.vgprPrefix(vgprOut), src0=self.vgprPrefix(vgprOut), src1=self.vgprPrefix(vgprOut), comment="out = out * out"))
module.add(VRcpF32(dst=self.vgprPrefix(vgprOut), src=self.vgprPrefix(vgprOut), comment="out = 1/out"))
if globalParameters["ArchCaps"][globalParameters["CurrentISA"]]["TransOpWait"]:
if ti.getArchCaps()["TransOpWait"]:
module.add(SNop(waitState=0, comment="1 wait states")) #workaround for emulator
else:
module.add(self.getTanhModule(cDataType, Holder(idx=vgprTemp1), vgprOut, "", ""))
Expand All @@ -721,7 +726,7 @@ def getDGeluModule(self, cDataType, vgprIn, vgprOut):
module.add(VMulF32(dst=vgpr(Holder(idx=vgprTemp1)), src0=0.5, src1=self.vgprPrefix(vgprOut), comment="tmp1 = 0.5 * tmp1"))
module.add(VMulF32(dst=vgpr(Holder(idx=vgprTemp3)), src0=vgpr(Holder(idx=vgprTemp3)), src1=vgpr(Holder(idx=vgprTemp3)), comment="out = out * out"))
module.add(VRcpF32(dst=self.vgprPrefix(vgprOut), src=vgpr(Holder(idx=vgprTemp3)), comment="out = 1/out"))
if globalParameters["ArchCaps"][globalParameters["CurrentISA"]]["TransOpWait"]:
if ti.getArchCaps()["TransOpWait"]:
module.add(SNop(waitState=0, comment="1 wait states")) #workaround for emulator
coef = floatUnion(f=4)
module.add(VMulF32(dst=self.vgprPrefix(vgprOut), src0=hex(coef.u), src1=self.vgprPrefix(vgprOut), comment="out = 4 * out"))
Expand Down
12 changes: 6 additions & 6 deletions tensilelite/Tensile/Components/PackData.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
VOrB32, VPackF16toB32, \
VAndOrB32, VBfeU32, VLShiftLeftB16, \
VRndneF32, SNop, SMovkI32, VMovB32, VMed3I32, \
vgpr, sgpr, DataType
vgpr, sgpr, DataType, TensileInstructions
from ..Component import PackData
from ..Common import globalParameters

Expand Down Expand Up @@ -93,7 +93,7 @@ def __call__(self, gwvw, destIdx, elementSumIdx, bf16CVTVgprStruct, tmpS01, lane
class PackData_INT8(PackData):
kernel = {"ProblemType": {"DestDataType": DataType(DataType.int8)}}
def __call__(self, gwvw, destIdx, elementSumIdx, tmpVgpr, tmpS01, SaturateTypeInt8 = SaturateCastType.NORMAL, inputPrefix="", prefixOffset=0):

ti = TensileInstructions()
module = Module("PackData int8")
gwvw4 = (gwvw // 4) * 4
for vi in range(0, gwvw4):
Expand All @@ -110,19 +110,19 @@ def __call__(self, gwvw, destIdx, elementSumIdx, tmpVgpr, tmpS01, SaturateTypeIn
src1=vgpr(formatting(sumIdxV-2, inputPrefix, prefixOffset)), \
sdwa=SDWAModifiers(dst_sel=SelectBit.DWORD, dst_unused=UnusedBit.UNUSED_PAD, \
src0_sel=SelectBit.BYTE_0, src1_sel=SelectBit.DWORD)))
if globalParameters["ArchCaps"][globalParameters["CurrentISA"]]["SDWAWait"]:
if ti.getArchCaps()["SDWAWait"]:
module.add(SNop(waitState=0, comment="1 wait states"))
module.add(VOrB32(dst=vgpr(formatting(sumIdxV-2, inputPrefix, prefixOffset)), \
src0=vgpr(formatting(sumIdxV-1, inputPrefix, prefixOffset)), \
src1=vgpr(formatVgpr), \
sdwa=SDWAModifiers(dst_sel=SelectBit.WORD_1, dst_unused=UnusedBit.UNUSED_PAD, \
src0_sel=SelectBit.BYTE_0, src1_sel=SelectBit.DWORD)))
if globalParameters["ArchCaps"][globalParameters["CurrentISA"]]["SDWAWait"]:
if ti.getArchCaps()["SDWAWait"]:
module.add(SNop(waitState=0, comment="1 wait states"))
module.add(VOrB32(dst=vgpr(d), src0=vgpr(formatting(sumIdxV-3, inputPrefix, prefixOffset)), src1=vgpr(formatting(sumIdxV-2, inputPrefix, prefixOffset)), \
sdwa=SDWAModifiers(dst_sel=SelectBit.DWORD, dst_unused=UnusedBit.UNUSED_PAD, \
src0_sel=SelectBit.WORD_0, src1_sel=SelectBit.DWORD)))
if globalParameters["ArchCaps"][globalParameters["CurrentISA"]]["SDWAWait"]:
if ti.getArchCaps()["SDWAWait"]:
module.add(SNop(waitState=0, comment="1 wait states"))
# Left
for vi in range(gwvw4, gwvw):
Expand All @@ -138,7 +138,7 @@ def __call__(self, gwvw, destIdx, elementSumIdx, tmpVgpr, tmpS01, SaturateTypeIn
src1=vgpr(formatting(sumIdxV, inputPrefix, prefixOffset)), \
sdwa=SDWAModifiers(dst_sel=SelectBit.DWORD, dst_unused=UnusedBit.UNUSED_PAD, \
src0_sel=SelectBit.BYTE_0, src1_sel=SelectBit.DWORD)))
if globalParameters["ArchCaps"][globalParameters["CurrentISA"]]["SDWAWait"]:
if ti.getArchCaps()["SDWAWait"]:
module.add(SNop(waitState=0, comment="1 wait states"))
elif vi + 1 >= gwvw:
module.add(VSaturateCastInt(sumIdxV, tmpVgpr, tmpS01, -128, 127, type=SaturateTypeInt8, initGpr=True))
Expand Down
3 changes: 1 addition & 2 deletions tensilelite/Tensile/KernelWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,7 @@ def __init__( self, kernelMinNaming, kernelSerialNaming ):
##############################################################################
def makeSchedule(self, kernel, tensorParametersA, tensorParametersB, localWriteEndIter, uDu=0, skipGlobalReadInc=False, firstIter=False, lastLoop=False, lastLc=False):

currentIsa = globalParameters["CurrentISA"]
maxVmcnt = globalParameters["AsmCaps"][currentIsa]["MaxVmcnt"]
maxVmcnt = self.states.asmCaps["MaxVmcnt"]

self.codes.unrollLoopHeader = Module()
# schedule of work for each local_read iteration:
Expand Down

0 comments on commit d01fc1f

Please sign in to comment.