diff --git a/BUILD b/BUILD index 46b95072..ae08e6ee 100644 --- a/BUILD +++ b/BUILD @@ -43,6 +43,15 @@ py_library( deps = [], ) +py_library( + name = "expect_ai_edge_litert_installed", + # This is a dummy rule used as a ai-edge-litert dependency in open-source. + # We expect ai-edge-litert to already be installed on the system, e.g. via + # `pip install ai-edge-litert` + visibility = ["//visibility:public"], + deps = [], +) + py_library( name = "expect_tf_keras_installed", # This is a dummy rule used as a tensorflow dependency in open-source. diff --git a/README.md b/README.md index e5ac22c6..5ab84b37 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,8 @@ Key platform requirements: TF_USE_LEGACY_KERAS=1**, see our [Keras version](tensorflow_gnn/docs/guide/keras_version.md) guide for details. * Apache Beam for distributed graph sampling. + * For some tests or scripts that requires tensorflow.lite it is required to + install ai-edge-litert by using `pip install ai-edge-litert` TF-GNN is developed and tested on Linux. Running on other platforms supported by TensorFlow may be possible. diff --git a/kokoro/github/ubuntu/cpu/build_versioned.sh b/kokoro/github/ubuntu/cpu/build_versioned.sh index 4effe283..cc00e248 100644 --- a/kokoro/github/ubuntu/cpu/build_versioned.sh +++ b/kokoro/github/ubuntu/cpu/build_versioned.sh @@ -39,6 +39,9 @@ rm -rf "$TEST_ROOT" mkdir -p "$TEST_ROOT" ln -s "$(pwd)"/tensorflow_gnn "$TEST_ROOT"/tensorflow_gnn +# Print the OS version +cat /etc/os-release + # Prepend common tag filters to a defined env_var # For example, tests for TF 2.8 shouldn't run RNG-dependent tests # These tag filters are enforced to start with a comma for separation diff --git a/requirements-dev.txt b/requirements-dev.txt index 1c2a0468..6a047f94 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,6 @@ mock wheel torch_geometric -torch < 2 \ No newline at end of file +torch < 2 +# includes glibcxx older than 3.4.29 +ai-edge-litert-nightly \ No newline at end of file diff --git a/tensorflow_gnn/graph/BUILD b/tensorflow_gnn/graph/BUILD index 0a3b0b86..dc8c7a6f 100644 --- a/tensorflow_gnn/graph/BUILD +++ b/tensorflow_gnn/graph/BUILD @@ -263,6 +263,7 @@ tf_py_test( ":graph_tensor_test_utils", "//:expect_absl_installed_testing", "//:expect_tensorflow_installed", + "//:expect_ai_edge_litert_installed", ], ) @@ -280,6 +281,7 @@ tf_py_test( ":readout", "//:expect_absl_installed_testing", "//:expect_tensorflow_installed", + "//:expect_ai_edge_litert_installed", ], ) diff --git a/tensorflow_gnn/graph/graph_tensor_ops_test.py b/tensorflow_gnn/graph/graph_tensor_ops_test.py index e2f564c0..2d03c74d 100644 --- a/tensorflow_gnn/graph/graph_tensor_ops_test.py +++ b/tensorflow_gnn/graph/graph_tensor_ops_test.py @@ -25,6 +25,9 @@ from tensorflow_gnn.graph import graph_tensor_ops as ops from tensorflow_gnn.graph import pool_ops from tensorflow_gnn.graph import readout +# pylint: disable=g-direct-tensorflow-import +from ai_edge_litert import interpreter as tfl_interpreter +# pylint: enable=g-direct-tensorflow-import as_tensor = tf.convert_to_tensor as_ragged = tf.ragged.constant @@ -642,7 +645,7 @@ def testTFLite(self): f'got TF {tf.__version__}') converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner('serving_default') obtained = signature_runner(**test_graph_dict)['final_edge_adjacency'] self.assertAllEqual(expected, obtained) @@ -1308,7 +1311,7 @@ def testTFLite(self): f'got TF {tf.__version__}') converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner('serving_default') obtained = signature_runner(**test_graph_dict)['final_edge_adjacency'] self.assertAllEqual(expected, obtained) @@ -1680,7 +1683,7 @@ def testTFLite(self): f'got TF {tf.__version__}') converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner('serving_default') obtained = signature_runner(**test_graph_dict)['final_edge_adjacency'] self.assertAllEqual(expected, obtained) diff --git a/tensorflow_gnn/graph/graph_tensor_test.py b/tensorflow_gnn/graph/graph_tensor_test.py index e9d25246..6a9ae548 100644 --- a/tensorflow_gnn/graph/graph_tensor_test.py +++ b/tensorflow_gnn/graph/graph_tensor_test.py @@ -25,6 +25,7 @@ from tensorflow_gnn.graph import graph_tensor_test_utils as tu # pylint: disable=g-direct-tensorflow-import +from ai_edge_litert import interpreter as tfl_interpreter from tensorflow.python.framework import type_spec # pylint: enable=g-direct-tensorflow-import @@ -1549,7 +1550,7 @@ def testTFLite(self): f'got TF {tf.__version__}') converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner('serving_default') obtained = signature_runner( **test_graph_dict)['private__make_graph_tensor_merged'] @@ -1752,7 +1753,7 @@ def testTFLite(self): converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner('serving_default') obtained = signature_runner( node_sizes=test_tensor)['node_set_sizes_to_test_results'] diff --git a/tensorflow_gnn/keras/layers/BUILD b/tensorflow_gnn/keras/layers/BUILD index 4e97129b..aca80548 100644 --- a/tensorflow_gnn/keras/layers/BUILD +++ b/tensorflow_gnn/keras/layers/BUILD @@ -119,6 +119,7 @@ tf_py_test( "//tensorflow_gnn/graph:adjacency", "//tensorflow_gnn/graph:graph_constants", "//tensorflow_gnn/graph:graph_tensor", + "//:expect_ai_edge_litert_installed", ], ) @@ -219,6 +220,7 @@ tf_py_test( "//:expect_tensorflow_installed", "//tensorflow_gnn/graph:graph_constants", "//tensorflow_gnn/utils:tf_test_utils", + "//:expect_ai_edge_litert_installed", ], ) @@ -257,6 +259,7 @@ tf_py_test( "//tensorflow_gnn/graph:preprocessing_common", "//tensorflow_gnn/keras:keras_tensors", "//tensorflow_gnn/utils:tf_test_utils", + "//:expect_ai_edge_litert_installed", ], ) diff --git a/tensorflow_gnn/keras/layers/graph_ops_test.py b/tensorflow_gnn/keras/layers/graph_ops_test.py index 6422f547..f1144d32 100644 --- a/tensorflow_gnn/keras/layers/graph_ops_test.py +++ b/tensorflow_gnn/keras/layers/graph_ops_test.py @@ -23,6 +23,9 @@ from tensorflow_gnn.graph import graph_constants as const from tensorflow_gnn.graph import graph_tensor as gt from tensorflow_gnn.keras.layers import graph_ops +# pylint: disable=g-direct-tensorflow-import +from ai_edge_litert import interpreter as tfl_interpreter +# pylint: enable=g-direct-tensorflow-import class ReadoutTest(tf.test.TestCase, parameterized.TestCase): @@ -170,7 +173,7 @@ def testTFLite(self, location): f"got TF {tf.__version__}") converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner("serving_default") obtained = signature_runner(**test_graph_134_dict)["test_readout"] self.assertAllEqual(expected, obtained) @@ -304,7 +307,7 @@ def testTFLite(self): f"got TF {tf.__version__}") converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner("serving_default") obtained = signature_runner(**test_graph_22_dict)["test_readout_first"] self.assertAllEqual(expected, obtained) @@ -431,7 +434,7 @@ def testTFLite(self): converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner("serving_default") actual = signature_runner( **test_graph_structured_readout_dict)["output_layer"] @@ -565,7 +568,7 @@ def testTFLite(self): converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner("serving_default") actual = signature_runner( **test_graph_structured_readout_dict)["output_layer"] @@ -628,7 +631,7 @@ def testTFLite(self): converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner("serving_default") actual = signature_runner( **test_graph_structured_readout_dict)["output_layer"] @@ -750,7 +753,7 @@ def testTFLite(self): f"got TF {tf.__version__}") converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner("serving_default") obtained = signature_runner(**test_graph_134_dict)["final_edge_states"] self.assertAllEqual(expected, obtained) @@ -952,7 +955,7 @@ def testTFLite(self, tag, location): f"got TF {tf.__version__}") converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner("serving_default") obtained = signature_runner(**test_values)["test_broadcast"] self.assertAllEqual(expected, obtained) @@ -1260,7 +1263,7 @@ def testTFLite(self, tag, location, reduce_type): f"got TF {tf.__version__}") converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner("serving_default") obtained = signature_runner(**test_values)["test_pool"] self.assertAllEqual(expected, obtained) diff --git a/tensorflow_gnn/keras/layers/next_state_test.py b/tensorflow_gnn/keras/layers/next_state_test.py index 5bd35a64..184cb1c7 100644 --- a/tensorflow_gnn/keras/layers/next_state_test.py +++ b/tensorflow_gnn/keras/layers/next_state_test.py @@ -19,6 +19,9 @@ from tensorflow_gnn.graph import graph_constants as const from tensorflow_gnn.keras.layers import next_state as next_state_lib from tensorflow_gnn.utils import tf_test_utils as tftu +# pylint: disable=g-direct-tensorflow-import +from ai_edge_litert import interpreter as tfl_interpreter +# pylint: enable=g-direct-tensorflow-import class NextStateFromConcatTest(tf.test.TestCase, parameterized.TestCase): @@ -179,7 +182,7 @@ def testTFLite(self): converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner("serving_default") obtained = signature_runner(**test_input_dict)["residual_next_state"] self.assertAllClose(expected, obtained) diff --git a/tensorflow_gnn/keras/layers/padding_ops_test.py b/tensorflow_gnn/keras/layers/padding_ops_test.py index 6264bf20..5101378f 100644 --- a/tensorflow_gnn/keras/layers/padding_ops_test.py +++ b/tensorflow_gnn/keras/layers/padding_ops_test.py @@ -23,6 +23,9 @@ from tensorflow_gnn.keras import keras_tensors # For registration. pylint: disable=unused-import from tensorflow_gnn.keras.layers import padding_ops from tensorflow_gnn.utils import tf_test_utils as tftu +# pylint: disable=g-direct-tensorflow-import +from ai_edge_litert import interpreter as tfl_interpreter +# pylint: enable=g-direct-tensorflow-import class PadToTotalSizesTest(tf.test.TestCase, parameterized.TestCase): @@ -119,7 +122,7 @@ def testBasic(self): self.skipTest("Padding ops are unsupported in TFLite.") converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner("serving_default") obtained = signature_runner(**test_graph_1_dict)["final_node_states"] self.assertAllClose(expected, obtained) diff --git a/tensorflow_gnn/models/gat_v2/BUILD b/tensorflow_gnn/models/gat_v2/BUILD index 2f6ffb10..67fb29a5 100644 --- a/tensorflow_gnn/models/gat_v2/BUILD +++ b/tensorflow_gnn/models/gat_v2/BUILD @@ -110,5 +110,6 @@ tf_py_test( "//:expect_tensorflow_installed", "//tensorflow_gnn", "//tensorflow_gnn/utils:tf_test_utils", + "//:expect_ai_edge_litert_installed", ], ) diff --git a/tensorflow_gnn/models/gat_v2/layers_test.py b/tensorflow_gnn/models/gat_v2/layers_test.py index b73d4821..1af9833a 100644 --- a/tensorflow_gnn/models/gat_v2/layers_test.py +++ b/tensorflow_gnn/models/gat_v2/layers_test.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for GATv2.""" - from absl.testing import parameterized import tensorflow as tf import tensorflow_gnn as tfgnn from tensorflow_gnn.models import gat_v2 from tensorflow_gnn.utils import tf_test_utils as tftu +# pylint: disable=g-direct-tensorflow-import +from ai_edge_litert import interpreter as tfl_interpreter +# pylint: enable=g-direct-tensorflow-import class GATv2Test(tf.test.TestCase, parameterized.TestCase): @@ -706,7 +707,7 @@ def testBasic(self): f"got TF {tf.__version__}") converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner("serving_default") obtained = signature_runner(**test_graph_1_dict)["final_node_states"] self.assertAllClose(expected, obtained) diff --git a/tensorflow_gnn/models/gcn/BUILD b/tensorflow_gnn/models/gcn/BUILD index a89ebd63..708c3aac 100644 --- a/tensorflow_gnn/models/gcn/BUILD +++ b/tensorflow_gnn/models/gcn/BUILD @@ -47,5 +47,6 @@ tf_py_test( "//:expect_tensorflow_installed", "//tensorflow_gnn", "//tensorflow_gnn/utils:tf_test_utils", + "//:expect_ai_edge_litert_installed", ], ) diff --git a/tensorflow_gnn/models/gcn/gcn_conv_test.py b/tensorflow_gnn/models/gcn/gcn_conv_test.py index affe1142..cfc651f5 100644 --- a/tensorflow_gnn/models/gcn/gcn_conv_test.py +++ b/tensorflow_gnn/models/gcn/gcn_conv_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for gcn_conv.""" import math from absl.testing import parameterized @@ -20,6 +19,9 @@ import tensorflow_gnn as tfgnn from tensorflow_gnn.models.gcn import gcn_conv from tensorflow_gnn.utils import tf_test_utils as tftu +# pylint: disable=g-direct-tensorflow-import +from ai_edge_litert import interpreter as tfl_interpreter +# pylint: enable=g-direct-tensorflow-import class GcnConvTest(tf.test.TestCase, parameterized.TestCase): @@ -873,7 +875,7 @@ def testBasic(self, add_self_loops, edge_weight_feature_name): f'got TF {tf.__version__}') converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner('serving_default') obtained = signature_runner(**test_graph_1_dict)['final_node_states'] self.assertAllClose(expected, obtained) diff --git a/tensorflow_gnn/models/graph_sage/BUILD b/tensorflow_gnn/models/graph_sage/BUILD index a996f0ec..676c302c 100644 --- a/tensorflow_gnn/models/graph_sage/BUILD +++ b/tensorflow_gnn/models/graph_sage/BUILD @@ -46,5 +46,6 @@ tf_py_test( "//:expect_tensorflow_installed", "//tensorflow_gnn", "//tensorflow_gnn/utils:tf_test_utils", + "//:expect_ai_edge_litert_installed", ], ) diff --git a/tensorflow_gnn/models/graph_sage/layers_test.py b/tensorflow_gnn/models/graph_sage/layers_test.py index 2d6745e4..d257e41a 100644 --- a/tensorflow_gnn/models/graph_sage/layers_test.py +++ b/tensorflow_gnn/models/graph_sage/layers_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for graph_sage.""" - import math from absl.testing import parameterized @@ -21,7 +19,9 @@ import tensorflow_gnn as tfgnn from tensorflow_gnn.models.graph_sage import layers as graph_sage from tensorflow_gnn.utils import tf_test_utils as tftu - +# pylint: disable=g-direct-tensorflow-import +from ai_edge_litert import interpreter as tfl_interpreter +# pylint: enable=g-direct-tensorflow-import _FEATURE_NAME = "f" @@ -626,7 +626,7 @@ def testBasic(self, use_pooling, hidden_units, combine_type): f"got TF {tf.__version__}") converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner("serving_default") obtained = signature_runner(**test_graph_1_dict)["final_node_states"] self.assertAllClose(expected, obtained) diff --git a/tensorflow_gnn/models/hgt/BUILD b/tensorflow_gnn/models/hgt/BUILD index ebfc2e0a..235744e3 100644 --- a/tensorflow_gnn/models/hgt/BUILD +++ b/tensorflow_gnn/models/hgt/BUILD @@ -106,5 +106,6 @@ tf_py_test( "//:expect_tensorflow_installed", "//tensorflow_gnn", "//tensorflow_gnn/utils:tf_test_utils", + "//:expect_ai_edge_litert_installed", ], ) diff --git a/tensorflow_gnn/models/hgt/layers_test.py b/tensorflow_gnn/models/hgt/layers_test.py index d185df44..ce518b70 100644 --- a/tensorflow_gnn/models/hgt/layers_test.py +++ b/tensorflow_gnn/models/hgt/layers_test.py @@ -13,13 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for hgt.""" from absl.testing import parameterized import numpy as np import tensorflow as tf import tensorflow_gnn as tfgnn from tensorflow_gnn.models.hgt import layers from tensorflow_gnn.utils import tf_test_utils as tftu +# pylint: disable=g-direct-tensorflow-import +from ai_edge_litert import interpreter as tfl_interpreter +# pylint: enable=g-direct-tensorflow-import def _homogeneous_cycle_graph(node_state, edge_state=None): @@ -806,7 +808,7 @@ def testBasic(self): converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner("serving_default") obtained = signature_runner(**test_graph_1_dict)["final_engine_states"] self.assertAllClose(expected, obtained) diff --git a/tensorflow_gnn/models/mt_albis/BUILD b/tensorflow_gnn/models/mt_albis/BUILD index bb7439d4..fae90385 100644 --- a/tensorflow_gnn/models/mt_albis/BUILD +++ b/tensorflow_gnn/models/mt_albis/BUILD @@ -81,6 +81,7 @@ tf_py_test( "//:expect_tensorflow_installed", "//tensorflow_gnn", "//tensorflow_gnn/utils:tf_test_utils", + "//:expect_ai_edge_litert_installed", ], ) diff --git a/tensorflow_gnn/models/mt_albis/layers_test.py b/tensorflow_gnn/models/mt_albis/layers_test.py index 05275318..dd6ba52f 100644 --- a/tensorflow_gnn/models/mt_albis/layers_test.py +++ b/tensorflow_gnn/models/mt_albis/layers_test.py @@ -22,6 +22,9 @@ import tensorflow_gnn as tfgnn from tensorflow_gnn.models.mt_albis import layers from tensorflow_gnn.utils import tf_test_utils as tftu +# pylint: disable=g-direct-tensorflow-import +from ai_edge_litert import interpreter as tfl_interpreter +# pylint: enable=g-direct-tensorflow-import class MtAlbisNextNodeStateTest(tf.test.TestCase, parameterized.TestCase): @@ -378,7 +381,7 @@ def test(self, converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner("serving_default") obtained = signature_runner(**test_graph_1_dict)["final_node_states"] self.assertAllClose(expected, obtained) diff --git a/tensorflow_gnn/models/multi_head_attention/BUILD b/tensorflow_gnn/models/multi_head_attention/BUILD index 0201ae59..808d8f24 100644 --- a/tensorflow_gnn/models/multi_head_attention/BUILD +++ b/tensorflow_gnn/models/multi_head_attention/BUILD @@ -110,5 +110,6 @@ tf_py_test( "//:expect_tensorflow_installed", "//tensorflow_gnn", "//tensorflow_gnn/utils:tf_test_utils", + "//:expect_ai_edge_litert_installed", ], ) diff --git a/tensorflow_gnn/models/multi_head_attention/layers_test.py b/tensorflow_gnn/models/multi_head_attention/layers_test.py index cdb6634f..fec7a6ef 100644 --- a/tensorflow_gnn/models/multi_head_attention/layers_test.py +++ b/tensorflow_gnn/models/multi_head_attention/layers_test.py @@ -23,6 +23,9 @@ import tensorflow_gnn as tfgnn from tensorflow_gnn.models import multi_head_attention from tensorflow_gnn.utils import tf_test_utils as tftu +# pylint: disable=g-direct-tensorflow-import +from ai_edge_litert import interpreter as tfl_interpreter +# pylint: enable=g-direct-tensorflow-import class MultiHeadAttentionTest(tf.test.TestCase, parameterized.TestCase): @@ -1415,7 +1418,7 @@ def testBasic( converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner("serving_default") obtained = signature_runner(**test_graph_1_dict)["final_node_states"] self.assertAllClose(expected, obtained) diff --git a/tensorflow_gnn/models/vanilla_mpnn/BUILD b/tensorflow_gnn/models/vanilla_mpnn/BUILD index 5c2375d0..848fd34c 100644 --- a/tensorflow_gnn/models/vanilla_mpnn/BUILD +++ b/tensorflow_gnn/models/vanilla_mpnn/BUILD @@ -109,5 +109,6 @@ tf_py_test( ":vanilla_mpnn", "//:expect_tensorflow_installed", "//tensorflow_gnn", + "//:expect_ai_edge_litert_installed", ], ) diff --git a/tensorflow_gnn/models/vanilla_mpnn/layers_test.py b/tensorflow_gnn/models/vanilla_mpnn/layers_test.py index 96a09047..74f34527 100644 --- a/tensorflow_gnn/models/vanilla_mpnn/layers_test.py +++ b/tensorflow_gnn/models/vanilla_mpnn/layers_test.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for VanillaMPNN.""" from absl.testing import parameterized import tensorflow as tf import tensorflow_gnn as tfgnn from tensorflow_gnn.models import vanilla_mpnn +# pylint: disable=g-direct-tensorflow-import +from ai_edge_litert import interpreter as tfl_interpreter +# pylint: enable=g-direct-tensorflow-import # The components of VanillaMPNNGraphUpdate have been tested elsewhere. @@ -139,7 +141,7 @@ def testBasic(self, use_layer_normalization): converter = tf.lite.TFLiteConverter.from_keras_model(model) model_content = converter.convert() - interpreter = tf.lite.Interpreter(model_content=model_content) + interpreter = tfl_interpreter.Interpreter(model_content=model_content) signature_runner = interpreter.get_signature_runner("serving_default") obtained = signature_runner(**test_graph_1_dict)["final_node_states"] self.assertAllClose(expected, obtained)