Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions tensilelite/Tensile/Components/GlobalWriteBatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2012,7 +2012,7 @@ def _addSumAlphaWithCBeta(self, kernel, ss, gwvw, elementIdx, vc0, tmpVgpr, cvtV
# Generate single f32 code if edge is detected.
isPK = False
if ((vi + 1) == self.gwvw) and ((self.gwvw % 2) == 1):
if self.parentWriter.states.archCaps["NoSDWA"]: #cm review
if self.parentWriter.states.archCaps["NoSDWA"]:
sb = 0 if self.gwvw == 1 else 1
module.add(VCvtFP8toF32(dst=vgpr(tmpVgpr), src=vgpr(dataV), vop3=VOP3PModifiers(op_sel=[0,sb])))
else:
Expand All @@ -2023,7 +2023,7 @@ def _addSumAlphaWithCBeta(self, kernel, ss, gwvw, elementIdx, vc0, tmpVgpr, cvtV
continue
else:
isPK = True
if self.parentWriter.states.archCaps["NoSDWA"]: #cm review
if self.parentWriter.states.archCaps["NoSDWA"]:
sb = 0 if vi ==0 else 1
module.add(VCvtPkFP8toF32(dst=vgpr(tmpVgpr, 2), src=vgpr(dataV), vop3=VOP3PModifiers(op_sel=[sb])))
else:
Expand All @@ -2041,9 +2041,9 @@ def _addSumAlphaWithCBeta(self, kernel, ss, gwvw, elementIdx, vc0, tmpVgpr, cvtV
# Generate single f32 code if edge is detected.
isPK = False
if ((vi + 1) == self.gwvw) and ((self.gwvw % 2) == 1):
if self.parentWriter.states.archCaps["NoSDWA"]: #cm review
if self.parentWriter.states.archCaps["NoSDWA"]:
sb = 0 if self.gwvw == 1 else 1
module.add(VCvtFP8toF32(dst=vgpr(tmpVgpr), src=vgpr(dataV), vop3=VOP3PModifiers(op_sel=[0,sb])))
module.add(VCvtBF8toF32(dst=vgpr(tmpVgpr), src=vgpr(dataV), vop3=VOP3PModifiers(op_sel=[0,sb])))
else:
sb = SelectBit.BYTE_0 if self.gwvw == 1 else SelectBit.BYTE_2
module.add(VCvtBF8toF32(dst=vgpr(tmpVgpr), src=vgpr(dataV), sdwa=SDWAModifiers(src0_sel=sb)))
Expand All @@ -2052,9 +2052,9 @@ def _addSumAlphaWithCBeta(self, kernel, ss, gwvw, elementIdx, vc0, tmpVgpr, cvtV
continue
else:
isPK = True
if self.parentWriter.states.archCaps["NoSDWA"]: #cm review
if self.parentWriter.states.archCaps["NoSDWA"]:
sb = 0 if vi ==0 else 1
module.add(VCvtPkFP8toF32(dst=vgpr(tmpVgpr, 2), src=vgpr(dataV), vop3=VOP3PModifiers(op_sel=[sb])))
module.add(VCvtPkBF8toF32(dst=vgpr(tmpVgpr, 2), src=vgpr(dataV), vop3=VOP3PModifiers(op_sel=[sb])))
else:
sb = SelectBit.WORD_0 if vi == 0 else SelectBit.WORD_1
module.add(VCvtPkBF8toF32(dst=vgpr(tmpVgpr, 2), src=vgpr(dataV), sdwa=SDWAModifiers(src0_sel=sb)))
Expand Down
8 changes: 4 additions & 4 deletions tensilelite/Tensile/TensileInstructions/Instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2527,8 +2527,8 @@ def __init__(self, dst, src, sdwa: Optional[SDWAModifiers] = None, vop3: Optiona
self.setInst("v_cvt_f32_fp8")

class VCvtBF8toF32(VCvtInstruction):
def __init__(self, dst, src, sdwa: Optional[SDWAModifiers] = None, comment="") -> None:
super().__init__(CvtType.CVT_BF8_to_F32, dst, src, sdwa, None, comment)
def __init__(self, dst, src, sdwa: Optional[SDWAModifiers] = None, vop3: Optional[VOP3PModifiers] = None, comment="") -> None:
super().__init__(CvtType.CVT_BF8_to_F32, dst, src, sdwa, vop3, comment)
self.setInst("v_cvt_f32_bf8")

class VCvtPkFP8toF32(VCvtInstruction):
Expand All @@ -2537,8 +2537,8 @@ def __init__(self, dst, src, sdwa: Optional[SDWAModifiers] = None, vop3: Optiona
self.setInst("v_cvt_pk_f32_fp8")

class VCvtPkBF8toF32(VCvtInstruction):
def __init__(self, dst, src, sdwa: Optional[SDWAModifiers] = None, comment="") -> None:
super().__init__(CvtType.CVT_PK_BF8_to_F32, dst, src, sdwa, None, comment)
def __init__(self, dst, src, sdwa: Optional[SDWAModifiers] = None, vop3: Optional[VOP3PModifiers] = None, comment="") -> None:
super().__init__(CvtType.CVT_PK_BF8_to_F32, dst, src, sdwa, vop3, comment)
self.setInst("v_cvt_pk_f32_bf8")

class VCvtPkF32toFP8(VCvtInstruction):
Expand Down