Skip to content

Commit 756be76

Browse files
Fix flexbuffers vector creation bug in MLIR TFLite export
--- When exporting StableHLOComposite attributes, flexbuffers::Builder was being called with the attribute name as a key for every vector element. In flexbuffers, passing a key to a Vector element pushes both the key and the value to the stack, causing the vector to be parsed with twice its intended length (with the key interleaved). This change removes the key from the vector element calls, and only uses the key when starting the parent vector, which matches the flexbuffers::Builder API expectation for Vectors. - Also fixed a typo where `isa<StringAttr>(attr)` was checking the Array attribute instead of the element. - Updated test to pass actual vector values for serialization PiperOrigin-RevId: 941177131
1 parent 482a036 commit 756be76

2 files changed

Lines changed: 13 additions & 17 deletions

File tree

tensorflow/compiler/mlir/lite/flatbuffer_export.cc

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1904,39 +1904,35 @@ uint32_t Translator::GetOpcodeIndex(const std::string& op_name,
19041904

19051905
void CreateFlexbufferVector(
19061906
const std::unique_ptr<flexbuffers::Builder>& flex_builder,
1907-
std::string& name, const mlir::Attribute& attr) {
1908-
auto start = flex_builder->StartVector(name.c_str());
1907+
std::optional<absl::string_view> key, const mlir::Attribute& attr) {
1908+
auto start = key.has_value()
1909+
? flex_builder->StartVector(std::string(*key).c_str())
1910+
: flex_builder->StartVector();
19091911
auto array = mlir::cast<mlir::vhlo::ArrayV1Attr>(attr).getValue();
19101912

19111913
for (int i = 0; i < array.size(); i++) {
19121914
if (llvm::isa<mlir::BoolAttr>(array[i])) {
1913-
flex_builder->Bool(name.c_str(),
1914-
mlir::cast<mlir::BoolAttr>(array[i]).getValue());
1915-
} else if (llvm::isa<mlir::StringAttr>(attr)) {
1915+
flex_builder->Bool(mlir::cast<mlir::BoolAttr>(array[i]).getValue());
1916+
} else if (llvm::isa<mlir::StringAttr>(array[i])) {
19161917
flex_builder->String(
1917-
name.c_str(),
19181918
mlir::cast<mlir::StringAttr>(array[i]).getValue().str());
19191919
} else if (llvm::isa<mlir::vhlo::BooleanV1Attr>(array[i])) {
19201920
flex_builder->Bool(
1921-
name.c_str(),
19221921
mlir::cast<mlir::vhlo::BooleanV1Attr>(array[i]).getValue());
19231922
} else if (llvm::isa<mlir::vhlo::StringV1Attr>(array[i])) {
19241923
flex_builder->String(
1925-
name.c_str(),
19261924
mlir::cast<mlir::vhlo::StringV1Attr>(array[i]).getValue().str());
19271925
} else if (llvm::isa<mlir::vhlo::IntegerV1Attr>(array[i])) {
1928-
flex_builder->Int(name.c_str(),
1929-
mlir::cast<mlir::vhlo::IntegerV1Attr>(array[i])
1926+
flex_builder->Int(mlir::cast<mlir::vhlo::IntegerV1Attr>(array[i])
19301927
.getValue()
19311928
.getSExtValue());
19321929
} else if (llvm::isa<mlir::vhlo::FloatV1Attr>(array[i])) {
1933-
flex_builder->Float(name.c_str(),
1934-
mlir::cast<mlir::vhlo::FloatV1Attr>(array[i])
1930+
flex_builder->Float(mlir::cast<mlir::vhlo::FloatV1Attr>(array[i])
19351931
.getValue()
19361932
.convertToFloat());
19371933

19381934
} else if (llvm::isa<mlir::vhlo::ArrayV1Attr>(array[i])) {
1939-
CreateFlexbufferVector(flex_builder, name, array[i]);
1935+
CreateFlexbufferVector(flex_builder, std::nullopt, array[i]);
19401936
}
19411937
}
19421938

tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/composite.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888
// CHECK-NEXT: builtin_options_2: {
8989
// CHECK-NEXT: name: "test.TEST_COMPOSITE",
9090
// CHECK-NEXT: decomposition_subgraph_index: 2,
91-
// CHECK-NEXT: composite_attributes: [ 0, 0, 1, 0, 0, 36, 1 ]
91+
// CHECK-NEXT: composite_attributes: [ 109, 121, 95, 97, 114, 114, 97, 121, 0, 1, 97, 0, 1, 98, 0, 2, 6, 4, 20, 20, 1, 21, 1, 1, 1, 9, 40, 2, 36, 1 ]
9292
// CHECK-NEXT: }
9393
// CHECK-NEXT: }, {
9494
// CHECK-NEXT: inputs: [ 4, 1 ],
@@ -97,7 +97,7 @@
9797
// CHECK-NEXT: builtin_options_2: {
9898
// CHECK-NEXT: name: "test.TEST_COMPOSITE",
9999
// CHECK-NEXT: decomposition_subgraph_index: 1,
100-
// CHECK-NEXT: composite_attributes: [ 0, 0, 1, 0, 0, 36, 1 ]
100+
// CHECK-NEXT: composite_attributes: [ 109, 121, 95, 97, 114, 114, 97, 121, 0, 1, 97, 0, 1, 98, 0, 2, 6, 4, 20, 20, 1, 21, 1, 1, 1, 9, 40, 2, 36, 1 ]
101101
// CHECK-NEXT: }
102102
// CHECK-NEXT: }, {
103103
// CHECK-NEXT: opcode_index: 1,
@@ -294,8 +294,8 @@
294294
func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> (tensor<10xf32>) {
295295
%cst = arith.constant dense<1.000000e+01> : tensor<f32>
296296
%cst_0 = arith.constant dense<2.000000e+01> : tensor<f32>
297-
%0 = "vhlo.composite_v1"(%arg0, %arg1) <{composite_attributes = #vhlo.dict_v1<{}>, decomposition = #vhlo.string_v1<"XlaCallModule_test.TEST_COMPOSITE.impl_0_0">, name = #vhlo.string_v1<"test.TEST_COMPOSITE">, version = #vhlo.integer_v1<0 : i64>}> : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
298-
%1 = "vhlo.composite_v1"(%0, %arg1) <{composite_attributes = #vhlo.dict_v1<{}>, decomposition = #vhlo.string_v1<"XlaCallModule_test.TEST_COMPOSITE.impl_0">, name = #vhlo.string_v1<"test.TEST_COMPOSITE">, version = #vhlo.integer_v1<0 : i64>}> : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
297+
%0 = "vhlo.composite_v1"(%arg0, %arg1) <{composite_attributes = #vhlo.dict_v1<{#vhlo.string_v1<"my_array"> = #vhlo.array_v1<[#vhlo.string_v1<"a">, #vhlo.string_v1<"b">]>}>, decomposition = #vhlo.string_v1<"XlaCallModule_test.TEST_COMPOSITE.impl_0_0">, name = #vhlo.string_v1<"test.TEST_COMPOSITE">, version = #vhlo.integer_v1<0 : i64>}> : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
298+
%1 = "vhlo.composite_v1"(%0, %arg1) <{composite_attributes = #vhlo.dict_v1<{#vhlo.string_v1<"my_array"> = #vhlo.array_v1<[#vhlo.string_v1<"a">, #vhlo.string_v1<"b">]>}>, decomposition = #vhlo.string_v1<"XlaCallModule_test.TEST_COMPOSITE.impl_0">, name = #vhlo.string_v1<"test.TEST_COMPOSITE">, version = #vhlo.integer_v1<0 : i64>}> : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
299299
%2 = tfl.add(%1, %cst) <{fused_activation_function = "NONE"}> : (tensor<10xf32>, tensor<f32>) -> tensor<10xf32>
300300
%3 = tfl.sub(%2, %cst_0) <{fused_activation_function = "NONE"}> : (tensor<10xf32>, tensor<f32>) -> tensor<10xf32>
301301
return %3 : tensor<10xf32>

0 commit comments

Comments
 (0)