Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 9 additions & 13 deletions tensorflow/compiler/mlir/lite/flatbuffer_export.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1904,39 +1904,35 @@ uint32_t Translator::GetOpcodeIndex(const std::string& op_name,

void CreateFlexbufferVector(
const std::unique_ptr<flexbuffers::Builder>& flex_builder,
std::string& name, const mlir::Attribute& attr) {
auto start = flex_builder->StartVector(name.c_str());
std::optional<absl::string_view> key, const mlir::Attribute& attr) {
auto start = key.has_value()
? flex_builder->StartVector(std::string(*key).c_str())
: flex_builder->StartVector();
auto array = mlir::cast<mlir::vhlo::ArrayV1Attr>(attr).getValue();

for (int i = 0; i < array.size(); i++) {
if (llvm::isa<mlir::BoolAttr>(array[i])) {
flex_builder->Bool(name.c_str(),
mlir::cast<mlir::BoolAttr>(array[i]).getValue());
} else if (llvm::isa<mlir::StringAttr>(attr)) {
flex_builder->Bool(mlir::cast<mlir::BoolAttr>(array[i]).getValue());
} else if (llvm::isa<mlir::StringAttr>(array[i])) {
flex_builder->String(
name.c_str(),
mlir::cast<mlir::StringAttr>(array[i]).getValue().str());
} else if (llvm::isa<mlir::vhlo::BooleanV1Attr>(array[i])) {
flex_builder->Bool(
name.c_str(),
mlir::cast<mlir::vhlo::BooleanV1Attr>(array[i]).getValue());
} else if (llvm::isa<mlir::vhlo::StringV1Attr>(array[i])) {
flex_builder->String(
name.c_str(),
mlir::cast<mlir::vhlo::StringV1Attr>(array[i]).getValue().str());
} else if (llvm::isa<mlir::vhlo::IntegerV1Attr>(array[i])) {
flex_builder->Int(name.c_str(),
mlir::cast<mlir::vhlo::IntegerV1Attr>(array[i])
flex_builder->Int(mlir::cast<mlir::vhlo::IntegerV1Attr>(array[i])
.getValue()
.getSExtValue());
} else if (llvm::isa<mlir::vhlo::FloatV1Attr>(array[i])) {
flex_builder->Float(name.c_str(),
mlir::cast<mlir::vhlo::FloatV1Attr>(array[i])
flex_builder->Float(mlir::cast<mlir::vhlo::FloatV1Attr>(array[i])
.getValue()
.convertToFloat());

} else if (llvm::isa<mlir::vhlo::ArrayV1Attr>(array[i])) {
CreateFlexbufferVector(flex_builder, name, array[i]);
CreateFlexbufferVector(flex_builder, std::nullopt, array[i]);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
// CHECK-NEXT: builtin_options_2: {
// CHECK-NEXT: name: "test.TEST_COMPOSITE",
// CHECK-NEXT: decomposition_subgraph_index: 2,
// CHECK-NEXT: composite_attributes: [ 0, 0, 1, 0, 0, 36, 1 ]
// CHECK-NEXT: composite_attributes: [ 109, 121, 95, 97, 114, 114, 97, 121, 0, 1, 97, 0, 1, 98, 0, 2, 6, 4, 20, 20, 1, 21, 1, 1, 1, 9, 40, 2, 36, 1 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: inputs: [ 4, 1 ],
Expand All @@ -97,7 +97,7 @@
// CHECK-NEXT: builtin_options_2: {
// CHECK-NEXT: name: "test.TEST_COMPOSITE",
// CHECK-NEXT: decomposition_subgraph_index: 1,
// CHECK-NEXT: composite_attributes: [ 0, 0, 1, 0, 0, 36, 1 ]
// CHECK-NEXT: composite_attributes: [ 109, 121, 95, 97, 114, 114, 97, 121, 0, 1, 97, 0, 1, 98, 0, 2, 6, 4, 20, 20, 1, 21, 1, 1, 1, 9, 40, 2, 36, 1 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 1,
Expand Down Expand Up @@ -294,8 +294,8 @@
func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> (tensor<10xf32>) {
%cst = arith.constant dense<1.000000e+01> : tensor<f32>
%cst_0 = arith.constant dense<2.000000e+01> : tensor<f32>
%0 = "vhlo.composite_v1"(%arg0, %arg1) <{composite_attributes = #vhlo.dict_v1<{}>, decomposition = #vhlo.string_v1<"XlaCallModule_test.TEST_COMPOSITE.impl_0_0">, name = #vhlo.string_v1<"test.TEST_COMPOSITE">, version = #vhlo.integer_v1<0 : i64>}> : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%1 = "vhlo.composite_v1"(%0, %arg1) <{composite_attributes = #vhlo.dict_v1<{}>, decomposition = #vhlo.string_v1<"XlaCallModule_test.TEST_COMPOSITE.impl_0">, name = #vhlo.string_v1<"test.TEST_COMPOSITE">, version = #vhlo.integer_v1<0 : i64>}> : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%0 = "vhlo.composite_v1"(%arg0, %arg1) <{composite_attributes = #vhlo.dict_v1<{#vhlo.string_v1<"my_array"> = #vhlo.array_v1<[#vhlo.string_v1<"a">, #vhlo.string_v1<"b">]>}>, decomposition = #vhlo.string_v1<"XlaCallModule_test.TEST_COMPOSITE.impl_0_0">, name = #vhlo.string_v1<"test.TEST_COMPOSITE">, version = #vhlo.integer_v1<0 : i64>}> : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%1 = "vhlo.composite_v1"(%0, %arg1) <{composite_attributes = #vhlo.dict_v1<{#vhlo.string_v1<"my_array"> = #vhlo.array_v1<[#vhlo.string_v1<"a">, #vhlo.string_v1<"b">]>}>, decomposition = #vhlo.string_v1<"XlaCallModule_test.TEST_COMPOSITE.impl_0">, name = #vhlo.string_v1<"test.TEST_COMPOSITE">, version = #vhlo.integer_v1<0 : i64>}> : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%2 = tfl.add(%1, %cst) <{fused_activation_function = "NONE"}> : (tensor<10xf32>, tensor<f32>) -> tensor<10xf32>
%3 = tfl.sub(%2, %cst_0) <{fused_activation_function = "NONE"}> : (tensor<10xf32>, tensor<f32>) -> tensor<10xf32>
return %3 : tensor<10xf32>
Expand Down
42 changes: 30 additions & 12 deletions tensorflow/lite/kernels/kernel_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ limitations under the License.

#include <algorithm>
#include <complex>
#include <initializer_list>
#include <limits>
#include <memory>

#ifndef TF_LITE_STATIC_MEMORY
#include <string>

#include "absl/types/span.h"
#include "tensorflow/lite/array.h"
#endif // TF_LITE_STATIC_MEMORY

Expand All @@ -34,7 +34,6 @@ limitations under the License.
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/util.h"

#if defined(__APPLE__)
#include "TargetConditionals.h"
Expand Down Expand Up @@ -598,23 +597,42 @@ bool HasUnspecifiedDimension(const TfLiteTensor* tensor) {
}

TfLiteStatus CheckedShapeProduct(TfLiteContext* context,
absl::Span<const int> dims,
std::initializer_list<int> dims,
const char* error_message, size_t& product) {
// The CheckedNumElements function already checks for negative dimensions, so
// we don't do it here.
TF_LITE_ENSURE_MSG(context, CheckedNumElements(dims, product) == kTfLiteOk,
"%s", error_message);
size_t checked_count = 1;
for (const int d : dims) {
TF_LITE_ENSURE_MSG(context, d >= 0, "Encountered a negative dimension.");
TF_LITE_ENSURE_MSG(
context,
checked_count == 0 ||
static_cast<size_t>(d) <=
std::numeric_limits<size_t>::max() / checked_count,
"%s", error_message);
checked_count *= d;
}
product = checked_count;
return kTfLiteOk;
}

TfLiteStatus CheckedShapeProductToInt(TfLiteContext* context,
absl::Span<const int> dims,
std::initializer_list<int> dims,
const char* error_message, int& product) {
for (const int dim : dims) {
TF_LITE_ENSURE_MSG(context, dim >= 0, "Encountered a negative dimension.");
size_t checked_count = 1;
for (const int d : dims) {
TF_LITE_ENSURE_MSG(context, d >= 0, "Encountered a negative dimension.");
TF_LITE_ENSURE_MSG(
context,
checked_count == 0 ||
static_cast<size_t>(d) <=
std::numeric_limits<size_t>::max() / checked_count,
"%s", error_message);
checked_count *= d;
}
TF_LITE_ENSURE_MSG(context, CheckedNumElements(dims, product) == kTfLiteOk,
"%s", error_message);
TF_LITE_ENSURE_MSG(
context,
checked_count <= static_cast<size_t>(std::numeric_limits<int>::max()),
"%s", error_message);
product = static_cast<int>(checked_count);
return kTfLiteOk;
}

Expand Down
6 changes: 3 additions & 3 deletions tensorflow/lite/kernels/kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ limitations under the License.
#include <stdint.h>

#include <cstddef>
#include <initializer_list>
#include <limits>
#ifndef TF_LITE_STATIC_MEMORY
#include <string>
#endif // TF_LITE_STATIC_MEMORY

#include "absl/types/span.h"
#include "tensorflow/lite/core/c/builtin_op_data.h"
#include "tensorflow/lite/core/c/common.h"
#ifndef NDEBUG
Expand Down Expand Up @@ -352,7 +352,7 @@ bool HasUnspecifiedDimension(const TfLiteTensor* tensor);
* @param product The output parameter to store the product.
*/
TfLiteStatus CheckedShapeProduct(TfLiteContext* context,
absl::Span<const int> dims,
std::initializer_list<int> dims,
const char* error_message, size_t& product);

/**
Expand All @@ -364,7 +364,7 @@ TfLiteStatus CheckedShapeProduct(TfLiteContext* context,
* @param product The output parameter to store the product.
*/
TfLiteStatus CheckedShapeProductToInt(TfLiteContext* context,
absl::Span<const int> dims,
std::initializer_list<int> dims,
const char* error_message, int& product);

} // namespace tflite
Expand Down
133 changes: 76 additions & 57 deletions tensorflow/python/grappler/item_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the swig wrapper of items."""
"""Tests for the pybind11 wrapper of Grappler items."""

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
Expand All @@ -30,82 +30,81 @@


class ItemTest(test.TestCase):
"""Unit tests for Grappler Item pybind11 wrapper functionality."""

def testInvalidItem(self):
with ops.Graph().as_default() as g:
a = constant_op.constant(10)
b = constant_op.constant(20)
c = a + b # pylint: disable=unused-variable
mg = meta_graph.create_meta_graph_def(graph=g)

# The train op isn't specified: this should raise an InvalidArgumentError
# exception.
with self.assertRaises(errors_impl.InvalidArgumentError):
item.Item(mg)

def testImportantOps(self):
with ops.Graph().as_default() as g:
a = constant_op.constant(10)
b = constant_op.constant(20)
c = a + b
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.append(c)
mg = meta_graph.create_meta_graph_def(graph=g)
grappler_item = item.Item(mg)
op_list = grappler_item.IdentifyImportantOps()
self.assertItemsEqual(['Const', 'Const_1', 'add'], op_list)

def testOpProperties(self):
def _create_sample_metagraph(self, include_train_op=True):
with ops.Graph().as_default() as g:
a = constant_op.constant(10)
b = constant_op.constant(20)
c = a + b
z = control_flow_ops.no_op()
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.append(c)
mg = meta_graph.create_meta_graph_def(graph=g)
grappler_item = item.Item(mg)
op_properties = grappler_item.GetOpProperties()

# All the nodes in this model have one scalar output
for node in grappler_item.metagraph.graph_def.node:
node_prop = op_properties[node.name]

if node.name == z.name:
self.assertEqual(0, len(node_prop))
else:
self.assertEqual(1, len(node_prop))
self.assertEqual(dtypes.int32, node_prop[0].dtype)
self.assertEqual(tensor_shape.TensorShape([]), node_prop[0].shape)

def testUpdates(self):
with ops.Graph().as_default() as g:
a = constant_op.constant(10)
b = constant_op.constant(20)
c = a + b
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.append(c)
mg = meta_graph.create_meta_graph_def(graph=g)
grappler_item = item.Item(mg)
if include_train_op:
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.append(c)
return meta_graph.create_meta_graph_def(graph=g), z

def test_invalid_item_missing_train_op_raises(self):
"""Verifies that Item raises InvalidArgumentError when train_op is missing."""
mg, _ = self._create_sample_metagraph(include_train_op=False)
with self.assertRaisesRegex(
errors_impl.InvalidArgumentError,
'train_op not specified in the metagraph'):
item.Item(mg)

def test_important_ops_identification(self):
"""Verifies important ops are correctly identified from the metagraph."""
mg, _ = self._create_sample_metagraph(include_train_op=True)
grappler_item = item.Item(mg)
op_list = grappler_item.IdentifyImportantOps()
self.assertCountEqual(['Const', 'Const_1', 'add'], op_list)

def test_op_properties_extraction(self):
"""Verifies op properties are correctly extracted for graph nodes."""
mg, z = self._create_sample_metagraph(include_train_op=True)
grappler_item = item.Item(mg)
op_properties = grappler_item.GetOpProperties()

z_prop = op_properties[z.name]
self.assertEmpty(z_prop)

const_prop = op_properties['Const']
self.assertLen(const_prop, 1)
self.assertEqual(dtypes.int32, const_prop[0].dtype)
self.assertEqual(tensor_shape.TensorShape([]), const_prop[0].shape)

def test_tf_item_initial_properties_equal(self):
"""Verifies initial tf_item properties are consistent and equal."""
mg, _ = self._create_sample_metagraph(include_train_op=True)
grappler_item = item.Item(mg)
initial_tf_item = grappler_item.tf_item
no_change_tf_item = grappler_item.tf_item
self.assertEqual(initial_tf_item, no_change_tf_item)

# Modify the placement.
def test_tf_item_device_modification_updates_wrapper(self):
"""Verifies modifying node placement creates a new underlying tf_item."""
mg, _ = self._create_sample_metagraph(include_train_op=True)
grappler_item = item.Item(mg)
initial_tf_item = grappler_item.tf_item
for node in grappler_item.metagraph.graph_def.node:
node.device = '/cpu:0'
new_tf_item = grappler_item.tf_item
self.assertNotEqual(initial_tf_item, new_tf_item)

# Assign the same placement.
def test_tf_item_identical_device_reassignment_unchanged(self):
"""Verifies re-assigning identical placement keeps tf_item unchanged."""
mg, _ = self._create_sample_metagraph(include_train_op=True)
grappler_item = item.Item(mg)
for node in grappler_item.metagraph.graph_def.node:
node.device = '/cpu:0'
new_tf_item = grappler_item.tf_item
for node in grappler_item.metagraph.graph_def.node:
node.device = '/cpu:0'
newest_tf_item = grappler_item.tf_item
self.assertEqual(new_tf_item, newest_tf_item)

@test_util.run_v1_only('b/120545219')
def testColocationConstraints(self):
def test_colocation_constraints(self):
"""Verifies colocation constraints are correctly grouped."""
with ops.Graph().as_default() as g:
c = constant_op.constant([10])
v = variable_v1.VariableV1([3], dtype=dtypes.int32)
Expand All @@ -116,10 +115,30 @@ def testColocationConstraints(self):
mg = meta_graph.create_meta_graph_def(graph=g)
grappler_item = item.Item(mg)
groups = grappler_item.GetColocationGroups()
self.assertEqual(len(groups), 1)
self.assertItemsEqual(
self.assertLen(groups, 1)
self.assertCountEqual(
groups[0], ['Assign', 'RefIdentity', 'Variable', 'Variable/Assign'])

@test_util.run_v1_only('b/120545219')
def test_colocation_constraints_missing_input(self):
"""Verifies standalone nodes with missing inputs are correctly grouped."""
with ops.Graph().as_default() as g:
c = constant_op.constant([10])
v = variable_v1.VariableV1([3], dtype=dtypes.int32)
i = gen_array_ops.ref_identity(v)
a = state_ops.assign(i, c)
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.append(a)
mg = meta_graph.create_meta_graph_def(graph=g)
for node in mg.graph_def.node:
if node.op == 'Assign':
del node.input[:]
grappler_item = item.Item(mg)
groups = grappler_item.GetColocationGroups()
self.assertLen(groups, 1)
self.assertCountEqual(
groups[0], ['RefIdentity', 'Variable'])


if __name__ == '__main__':
test.main()
Loading
Loading