Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
fa536d6
Hoist HoistFusedBitcasts above GemmFusionSwapOperands - this is part …
vwbaker Mar 30, 2026
4011747
[XLA:GPU] Migrate TmaAndLayoutParameterizedTritonEmitterTestSuite to …
pifon2a Mar 30, 2026
5a9c1fa
Move [[nodiscard]] attribute from definition to declaration.
akuegel Mar 30, 2026
7baf72c
[XLA:GPU]: Reapply flag flip for one shot all-reduce after fixes
sohaibiftikhar Mar 30, 2026
4fc7cf2
PR #39497: [XLA:GPU][oneAPI] Register base oneCCL collective support …
nhatleSummer22 Mar 30, 2026
56ae1db
[XLA:GPU] Add command buffer scheduling mode `CONCURRENT_REGIONS`.
thomasjoerg Mar 30, 2026
df778d5
Automated Code Change
tensorflower-gardener Mar 30, 2026
71e5bdd
PR #39025: [xla] Fix potentially expensive spin in ObjectPool
ezhulenev Mar 30, 2026
4378602
Automated Code Change
tensorflower-gardener Mar 30, 2026
5001218
PR #39956: [xla:gpu] Log thunk progress in chronological order
ezhulenev Mar 30, 2026
7dd4e95
PR #39417: [ROCm] Implement empty graph node support and safeguards f…
phambinhfin Mar 30, 2026
5eddbe7
PR #39508: [XLA:GPU][oneAPI] Remove hardcoded spirv-binary in SYCL tests
bhavani-subramanian Mar 30, 2026
27baa93
PR #39931: Add test coverage for AllGatherRemoveDegenerateDims untest…
kredd2506 Mar 30, 2026
6741fb7
Automated Code Change
tensorflower-gardener Mar 30, 2026
afe47aa
[XLA:GPU] Add support for constants via new tiling.
pifon2a Mar 30, 2026
2604111
[XLA:GPU] Fix ReorderFilterAndBiasHloTest.TestCudnnReorderFilterAndBias
Mar 30, 2026
23bb710
Disable unflattener for Shardy to GSPMD fallback.
ekayaaslan Mar 30, 2026
a1cf2c9
PR #39950: [ROCm] Fix hipblasLt Int8 GEMM support and autotuner outpu…
cj401-amd Mar 30, 2026
5a0c0b6
[XLA:GPU] Use copied device info to init members of GpuCostModelStats…
nputikhin Mar 30, 2026
0ff3f3f
[XLA:GPU] Add transposes via the new tiling.
pifon2a Mar 30, 2026
eb81b89
[xla] Add FloatNormalizationExcessPrecisionTest
penpornk Mar 30, 2026
c662557
[XLA:GPU] Move ScatterDeterminismExpander pass to backends/gpu/transf…
akuegel Mar 30, 2026
980b99c
[XLA:GPU] Reroute dots to a detailed cost model in the indexing model
nputikhin Mar 30, 2026
c1b8d76
[XLA:GPU][NFC] Pad the tile sizes before the tiling propagation.
pifon2a Mar 30, 2026
e62b00d
Make CPU devices report as cpu:0 instead of TFRT_CPU_0.
hawkinsp Mar 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions third_party/xla/xla/backends/gpu/autotuner/hipblaslt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.

#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

Expand Down
3 changes: 2 additions & 1 deletion third_party/xla/xla/backends/gpu/autotuner/hipblaslt.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ class HipblasLtBackend : public GpuCodegenBackend {
Compiler* compiler,
const Compiler::GpuTargetConfig* target_config)
: GpuCodegenBackend(autotuner::Backend::HIPBLASLT, debug_options,
compiler, target_config, stream_executor) {}
compiler, target_config, stream_executor,
/*uses_last_output_for_scratch=*/true) {}

absl::StatusOr<std::vector<std::unique_ptr<BackendConfig>>>
GetSupportedConfigs(const HloInstruction& instr) override;
Expand Down
3 changes: 2 additions & 1 deletion third_party/xla/xla/backends/gpu/autotuner/rocblas.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ class RocblasBackend : public GpuCodegenBackend {
const Compiler::GpuTargetConfig* target_config,
bool fp8_lt_fallback = false)
: GpuCodegenBackend(autotuner::Backend::ROCBLAS, debug_options, compiler,
target_config, stream_executor),
target_config, stream_executor,
/*uses_last_output_for_scratch=*/true),
fp8_lt_fallback_(fp8_lt_fallback) {}

absl::StatusOr<std::vector<std::unique_ptr<BackendConfig>>>
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/backends/gpu/codegen/triton/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ cc_library(
"//xla/codegen/tiling:tiling_specification",
"//xla/codegen/tiling/experimental:tile",
"//xla/codegen/tiling/experimental:tiled_hlo",
"//xla/codegen/xtile/codegen:emitter_helpers",
"//xla/codegen/xtile/codegen:fusion_emitter",
"//xla/codegen/xtile/ir:xtile",
"//xla/codegen/xtile/ir/transforms:passes",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,74 +151,6 @@ class WarpSpecializationTritonEmitterTest : public TritonEmitterTest {
}
};

struct TmaAndDotLayoutTestParams {
std::vector<int64_t> lhs_layout;
std::vector<int64_t> rhs_layout;
std::vector<int64_t> out_layout;
bool enable_tma;
};

class TmaAndLayoutParameterizedTritonEmitterTest
: public TritonEmitterTest,
public ::testing::WithParamInterface<TmaAndDotLayoutTestParams> {};

std::string TmaAndDotLayoutTestParamsToString(
const ::testing::TestParamInfo<TmaAndDotLayoutTestParams>& data) {
return absl::StrCat("lhs_", absl::StrJoin(data.param.lhs_layout, "_"),
"_rhs_", absl::StrJoin(data.param.rhs_layout, "_"),
"_out_", absl::StrJoin(data.param.out_layout, "_"),
data.param.enable_tma ? "_tma" : "");
}

INSTANTIATE_TEST_SUITE_P(
TmaAndLayoutParameterizedTritonEmitterTestSuite,
TmaAndLayoutParameterizedTritonEmitterTest,
::testing::ValuesIn({
TmaAndDotLayoutTestParams{{2, 1, 0}, {2, 1, 0}, {2, 1, 0}, false},
TmaAndDotLayoutTestParams{{2, 1, 0}, {2, 1, 0}, {2, 1, 0}, true},
TmaAndDotLayoutTestParams{{0, 2, 1}, {2, 0, 1}, {2, 1, 0}, false},
TmaAndDotLayoutTestParams{{0, 2, 1}, {2, 0, 1}, {2, 1, 0}, true},
TmaAndDotLayoutTestParams{{2, 1, 0}, {2, 1, 0}, {1, 0, 2}, false},
TmaAndDotLayoutTestParams{{2, 1, 0}, {2, 1, 0}, {1, 0, 2}, true},
TmaAndDotLayoutTestParams{{2, 0, 1}, {0, 1, 2}, {2, 0, 1}, false},
TmaAndDotLayoutTestParams{{2, 0, 1}, {0, 1, 2}, {2, 0, 1}, true},
}),
TmaAndDotLayoutTestParamsToString);

TEST_P(TmaAndLayoutParameterizedTritonEmitterTest, Dot) {
const std::string hlo_text = absl::Substitute(
R"(
fdot {
fdot.p0 = f32[32,16,256]{$0} parameter(0)
fdot.p1 = f32[256,16,512]{$1} parameter(1)
lhs.root = f32[32,16,256]{$0} negate(fdot.p0)
frhs.root = f32[256,16,512]{$1} abs(fdot.p1)
ROOT fdot.root = f32[16,32,512]{$2} dot(lhs.root, frhs.root),
lhs_contracting_dims={2}, rhs_contracting_dims={0},
lhs_batch_dims={1}, rhs_batch_dims={1},
algorithm=dot_f32_f32_f32, backend_config={sizes:[32]}
}

ENTRY entry {
entry.p0 = f32[32,16,256]{$0} parameter(0)
entry.p1 = f32[256,16,512]{$1} parameter(1)
ROOT fusion = f32[16,32,512]{$2} fusion(entry.p0, entry.p1),
kind=kCustom, calls=fdot, backend_config={
"fusion_backend_config":{
"kind":"__triton_nested_gemm_fusion",
"block_level_fusion_config":{
"output_tiles":[{"sizes":["1", "16", "64"]}],
"num_warps":"1",
"num_ctas":"1",
"num_stages":"1",
"is_tma_allowed":$3}}}
})",
absl::StrJoin(GetParam().lhs_layout, ","),
absl::StrJoin(GetParam().rhs_layout, ","),
absl::StrJoin(GetParam().out_layout, ","), GetParam().enable_tma);
EXPECT_TRUE(RunAndCompareNoHloPasses(
hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6}));
}

TEST_F(TritonEmitterTest, BitcastReduceWithStride4Tiling) {
constexpr absl::string_view kHloText = R"(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: sed 's/ENABLE_TMA/false/' %s > %t.hlo && triton_test_correctness %t.hlo --abs_error_bound=1e-4 --rel_error_bound=1e-6
// RUN: sed 's/ENABLE_TMA/true/' %s > %t.hlo && triton_test_correctness %t.hlo --abs_error_bound=1e-4 --rel_error_bound=1e-6

fusion {
fdot.p0 = f32[32,16,256]{0,2,1} parameter(0)
fdot.p1 = f32[256,16,512]{2,0,1} parameter(1)
lhs.root = f32[32,16,256]{0,2,1} negate(fdot.p0)
frhs.root = f32[256,16,512]{2,0,1} abs(fdot.p1)
ROOT fdot.root = f32[16,32,512]{2,1,0} dot(lhs.root, frhs.root),
lhs_contracting_dims={2}, rhs_contracting_dims={0},
lhs_batch_dims={1}, rhs_batch_dims={1},
algorithm=dot_f32_f32_f32, backend_config={sizes:[32]}
}

ENTRY main {
entry.p0 = f32[32,16,256]{0,2,1} parameter(0)
entry.p1 = f32[256,16,512]{2,0,1} parameter(1)
ROOT fusion = f32[16,32,512]{2,1,0} fusion(entry.p0, entry.p1),
kind=kCustom, calls=fusion, backend_config={
"fusion_backend_config":{
"kind":"__triton_nested_gemm_fusion",
"block_level_fusion_config":{
"output_tiles":[{"sizes":["1", "16", "64"]}],
"num_warps":"1",
"num_ctas":"1",
"num_stages":"1",
"is_tma_allowed":ENABLE_TMA}}}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: sed 's/ENABLE_TMA/false/' %s > %t.hlo && triton_test_correctness %t.hlo --abs_error_bound=1e-4 --rel_error_bound=1e-6
// RUN: sed 's/ENABLE_TMA/true/' %s > %t.hlo && triton_test_correctness %t.hlo --abs_error_bound=1e-4 --rel_error_bound=1e-6

fusion {
fdot.p0 = f32[32,16,256]{2,0,1} parameter(0)
fdot.p1 = f32[256,16,512]{0,1,2} parameter(1)
lhs.root = f32[32,16,256]{2,0,1} negate(fdot.p0)
frhs.root = f32[256,16,512]{0,1,2} abs(fdot.p1)
ROOT fdot.root = f32[16,32,512]{2,0,1} dot(lhs.root, frhs.root),
lhs_contracting_dims={2}, rhs_contracting_dims={0},
lhs_batch_dims={1}, rhs_batch_dims={1},
algorithm=dot_f32_f32_f32, backend_config={sizes:[32]}
}

ENTRY main {
entry.p0 = f32[32,16,256]{2,0,1} parameter(0)
entry.p1 = f32[256,16,512]{0,1,2} parameter(1)
ROOT fusion = f32[16,32,512]{2,0,1} fusion(entry.p0, entry.p1),
kind=kCustom, calls=fusion, backend_config={
"fusion_backend_config":{
"kind":"__triton_nested_gemm_fusion",
"block_level_fusion_config":{
"output_tiles":[{"sizes":["1", "16", "64"]}],
"num_warps":"1",
"num_ctas":"1",
"num_stages":"1",
"is_tma_allowed":ENABLE_TMA}}}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: sed 's/ENABLE_TMA/false/' %s > %t.hlo && triton_test_correctness %t.hlo --abs_error_bound=1e-4 --rel_error_bound=1e-6
// RUN: sed 's/ENABLE_TMA/true/' %s > %t.hlo && triton_test_correctness %t.hlo --abs_error_bound=1e-4 --rel_error_bound=1e-6

fusion {
fdot.p0 = f32[32,16,256]{2,1,0} parameter(0)
fdot.p1 = f32[256,16,512]{2,1,0} parameter(1)
lhs.root = f32[32,16,256]{2,1,0} negate(fdot.p0)
frhs.root = f32[256,16,512]{2,1,0} abs(fdot.p1)
ROOT fdot.root = f32[16,32,512]{1,0,2} dot(lhs.root, frhs.root),
lhs_contracting_dims={2}, rhs_contracting_dims={0},
lhs_batch_dims={1}, rhs_batch_dims={1},
algorithm=dot_f32_f32_f32, backend_config={sizes:[32]}
}

ENTRY main {
entry.p0 = f32[32,16,256]{2,1,0} parameter(0)
entry.p1 = f32[256,16,512]{2,1,0} parameter(1)
ROOT fusion = f32[16,32,512]{1,0,2} fusion(entry.p0, entry.p1),
kind=kCustom, calls=fusion, backend_config={
"fusion_backend_config":{
"kind":"__triton_nested_gemm_fusion",
"block_level_fusion_config":{
"output_tiles":[{"sizes":["1", "16", "64"]}],
"num_warps":"1",
"num_ctas":"1",
"num_stages":"1",
"is_tma_allowed":ENABLE_TMA}}}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: sed 's/ENABLE_TMA/false/' %s > %t.hlo && triton_test_correctness %t.hlo --abs_error_bound=1e-4 --rel_error_bound=1e-6
// RUN: sed 's/ENABLE_TMA/true/' %s > %t.hlo && triton_test_correctness %t.hlo --abs_error_bound=1e-4 --rel_error_bound=1e-6

fusion {
fdot.p0 = f32[32,16,256]{2,1,0} parameter(0)
fdot.p1 = f32[256,16,512]{2,1,0} parameter(1)
lhs.root = f32[32,16,256]{2,1,0} negate(fdot.p0)
frhs.root = f32[256,16,512]{2,1,0} abs(fdot.p1)
ROOT fdot.root = f32[16,32,512]{2,1,0} dot(lhs.root, frhs.root),
lhs_contracting_dims={2}, rhs_contracting_dims={0},
lhs_batch_dims={1}, rhs_batch_dims={1},
algorithm=dot_f32_f32_f32, backend_config={sizes:[32]}
}

ENTRY main {
entry.p0 = f32[32,16,256]{2,1,0} parameter(0)
entry.p1 = f32[256,16,512]{2,1,0} parameter(1)
ROOT fusion = f32[16,32,512]{2,1,0} fusion(entry.p0, entry.p1),
kind=kCustom, calls=fusion, backend_config={
"fusion_backend_config":{
"kind":"__triton_nested_gemm_fusion",
"block_level_fusion_config":{
"output_tiles":[{"sizes":["1", "16", "64"]}],
"num_warps":"1",
"num_ctas":"1",
"num_stages":"1",
"is_tma_allowed":ENABLE_TMA}}}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// RUN: hlo_to_xtileir %s | FileCheck %s
// RUN: triton_test_correctness %s --abs_error_bound=6e-1 --rel_error_bound=6e-1

// RUN: hlo_to_xtileir %s --use_experimental_tiling | FileCheck %s
// RUN: triton_test_correctness %s --abs_error_bound=6e-1 --rel_error_bound=6e-1 --xla_gpu_experimental_enable_tiling_propagation

fusion {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// RUN: hlo_to_xtileir %s | FileCheck %s
// RUN: triton_test_correctness %s

// RUN: hlo_to_xtileir %s --use_experimental_tiling | FileCheck %s
// RUN: triton_test_correctness %s --xla_gpu_experimental_enable_tiling_propagation

fusion {
p0 = s32[3,2,2]{2,1,0} parameter(0)
ROOT convert0 = pred[3,2,2]{2,1,0} convert(p0)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// RUN: hlo_to_xtileir %s | FileCheck %s
// RUN: triton_test_correctness %s

// RUN: hlo_to_xtileir %s --use_experimental_tiling | FileCheck %s
// RUN: triton_test_correctness %s --xla_gpu_experimental_enable_tiling_propagation

fusion {
param_0 = s32[] parameter(0)
param_1 = s32[] parameter(1)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// RUN: hlo_to_xtileir %s | FileCheck %s
// RUN: triton_test_correctness %s

// RUN: hlo_to_xtileir %s --use_experimental_tiling | FileCheck %s
// RUN: triton_test_correctness %s --xla_gpu_experimental_enable_tiling_propagation

fusion {
param_0 = s32[] parameter(0)
denominator = s32[] constant(0)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// RUN: hlo_to_xtileir %s | FileCheck %s
// RUN: triton_test_correctness %s

// RUN: hlo_to_xtileir %s --use_experimental_tiling | FileCheck %s
// RUN: triton_test_correctness %s --xla_gpu_experimental_enable_tiling_propagation

fusion {
numerator = s32[] constant(10)
denominator = s32[] constant(0)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// RUN: hlo_to_xtileir %s | FileCheck %s
// RUN: triton_test_correctness %s

// RUN: hlo_to_xtileir %s --use_experimental_tiling | FileCheck %s
// RUN: triton_test_correctness %s --xla_gpu_experimental_enable_tiling_propagation

fusion {
param_0 = s32[] parameter(0)
param_1 = s32[] parameter(1)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// RUN: hlo_to_xtileir %s | FileCheck %s
// RUN: triton_test_correctness %s

// RUN: hlo_to_xtileir %s --use_experimental_tiling | FileCheck %s
// RUN: triton_test_correctness %s --xla_gpu_experimental_enable_tiling_propagation

fusion {
param_0 = pred[] parameter(0)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// RUN: hlo_to_xtileir %s | FileCheck %s
// RUN: triton_test_correctness %s

// RUN: hlo_to_xtileir %s --use_experimental_tiling | FileCheck %s
// RUN: triton_test_correctness %s --xla_gpu_experimental_enable_tiling_propagation

fusion {
param_0 = pred[15] parameter(0)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// RUN: hlo_to_xtileir %s | FileCheck %s
// RUN: triton_test_correctness %s

// RUN: hlo_to_xtileir %s --use_experimental_tiling | FileCheck %s
// RUN: triton_test_correctness %s --xla_gpu_experimental_enable_tiling_propagation

fusion {
param_0 = f32[15] parameter(0)
param_1 = f32[15] parameter(1)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// RUN: hlo_to_xtileir %s | FileCheck %s
// RUN: triton_test_correctness %s

// RUN: hlo_to_xtileir %s --use_experimental_tiling | FileCheck %s
// RUN: triton_test_correctness %s --xla_gpu_experimental_enable_tiling_propagation

fusion {
p = f32[5,7] parameter(0)
ROOT rp = f32[5,7] reduce-precision(p), exponent_bits=2, mantissa_bits=2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,30 @@
// RUN: triton_test_correctness %s

fusion {
p0 = bf16[2048,4,256]{2,1,0} parameter(0)
p0 = bf16[128,4,256]{2,1,0} parameter(0)
c0 = bf16[] constant(0)
reduce = bf16[2048,4]{1,0} reduce(p0, c0), dimensions={2}, to_apply={
reduce = bf16[128,4]{1,0} reduce(p0, c0), dimensions={2}, to_apply={
a = bf16[] parameter(0)
b = bf16[] parameter(1)
ROOT maximum = bf16[] maximum(a, b)
}
add_unnecessary_dim = bf16[1,2048,4]{2,1,0} bitcast(reduce)
upcast = f32[1,2048,4]{2,1,0} convert(add_unnecessary_dim)
some_high_precision_op = f32[1,2048,4]{2,1,0} sqrt(upcast)
downcast = bf16[1,2048,4]{2,1,0} convert(some_high_precision_op)
remove_dim = bf16[2048,4]{1,0} bitcast(downcast)
broadcast = bf16[2048,4,256]{2,1,0} broadcast(remove_dim), dimensions={0,1}
ROOT slice = bf16[2048,4,128]{2,1,0} slice(broadcast),
slice={[0:2048], [0:4], [0:128]}
add_unnecessary_dim = bf16[1,128,4]{2,1,0} bitcast(reduce)
upcast = f32[1,128,4]{2,1,0} convert(add_unnecessary_dim)
some_high_precision_op = f32[1,128,4]{2,1,0} sqrt(upcast)
downcast = bf16[1,128,4]{2,1,0} convert(some_high_precision_op)
remove_dim = bf16[128,4]{1,0} bitcast(downcast)
broadcast = bf16[128,4,256]{2,1,0} broadcast(remove_dim), dimensions={0,1}
ROOT slice = bf16[128,4,128]{2,1,0} slice(broadcast),
slice={[0:128], [0:4], [0:128]}
}
// CHECK: xtile.extract
// CHECK: tt.reduce
// CHECK: tt.broadcast
// CHECK: xtile.insert

ENTRY main {
%p0 = bf16[2048,4,256]{2,1,0} parameter(0)
ROOT fusion = bf16[2048,4,128]{2,1,0} fusion(p0), kind=kCustom,
%p0 = bf16[128,4,256]{2,1,0} parameter(0)
ROOT fusion = bf16[128,4,128]{2,1,0} fusion(p0), kind=kCustom,
calls=fusion, backend_config={
"fusion_backend_config":{
"kind":"__triton",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// RUN: hlo_to_xtileir %s | FileCheck %s
// RUN: triton_test_correctness %s

// RUN: hlo_to_xtileir %s --use_experimental_tiling | FileCheck %s
// RUN: triton_test_correctness %s --xla_gpu_experimental_enable_tiling_propagation

fusion {
param_0 = f32[15,7,3] parameter(0)
ROOT transpose = f32[3,15,7]{2,1,0} transpose(param_0), dimensions={2,0,1}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: triton_test_correctness %s
// RUN: triton_test_correctness %s --xla_gpu_experimental_enable_tiling_propagation

fusion {
param_0 = f32[3,8,20] parameter(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1730,9 +1730,8 @@ ENTRY e {
}
EXPECT_THAT(
instr,
GmockMatch(
m::Fusion(m::Parameter(), m::Parameter(), m::Bitcast(m::Parameter()))
.WithFusionKind(HloInstruction::FusionKind::kCustom)));
GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter())
.WithFusionKind(HloInstruction::FusionKind::kCustom)));

EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-2, /*arel=*/2e-2}));
}
Expand Down
Loading
Loading