Skip to content

Commit

Permalink
Update tensorflow_gnn usage of tf.lite.interpreter to run ai-edge-lit…
Browse files Browse the repository at this point in the history
…ert.interpreter

PiperOrigin-RevId: 686139114
  • Loading branch information
ecalubaquib authored and tensorflower-gardener committed Oct 15, 2024
1 parent 5211122 commit 7befc1a
Show file tree
Hide file tree
Showing 25 changed files with 85 additions and 31 deletions.
9 changes: 9 additions & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ py_library(
deps = [],
)

py_library(
name = "expect_ai_edge_litert_installed",
# This is a dummy rule used as a ai-edge-litert dependency in open-source.
# We expect ai-edge-litert to already be installed on the system, e.g. via
# `pip install ai-edge-litert`
visibility = ["//visibility:public"],
deps = [],
)

py_library(
name = "expect_tf_keras_installed",
# This is a dummy rule used as a tensorflow dependency in open-source.
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ Key platform requirements:
TF_USE_LEGACY_KERAS=1**,
see our [Keras version](tensorflow_gnn/docs/guide/keras_version.md) guide for details.
* Apache Beam for distributed graph sampling.
* For some tests or scripts that requires tensorflow.lite it is required to
install ai-edge-litert by using `pip install ai-edge-litert`

TF-GNN is developed and tested on Linux. Running on other platforms supported
by TensorFlow may be possible.
Expand Down
3 changes: 3 additions & 0 deletions kokoro/github/ubuntu/cpu/build_versioned.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ rm -rf "$TEST_ROOT"
mkdir -p "$TEST_ROOT"
ln -s "$(pwd)"/tensorflow_gnn "$TEST_ROOT"/tensorflow_gnn

# Print the OS version
cat /etc/os-release

# Prepend common tag filters to a defined env_var
# For example, tests for TF 2.8 shouldn't run RNG-dependent tests
# These tag filters are enforced to start with a comma for separation
Expand Down
4 changes: 3 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
mock
wheel
torch_geometric
torch < 2
torch < 2
# includes glibcxx older than 3.4.29
ai-edge-litert-nightly
2 changes: 2 additions & 0 deletions tensorflow_gnn/graph/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ tf_py_test(
":graph_tensor_test_utils",
"//:expect_absl_installed_testing",
"//:expect_tensorflow_installed",
"//:expect_ai_edge_litert_installed",
],
)

Expand All @@ -280,6 +281,7 @@ tf_py_test(
":readout",
"//:expect_absl_installed_testing",
"//:expect_tensorflow_installed",
"//:expect_ai_edge_litert_installed",
],
)

Expand Down
9 changes: 6 additions & 3 deletions tensorflow_gnn/graph/graph_tensor_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
from tensorflow_gnn.graph import graph_tensor_ops as ops
from tensorflow_gnn.graph import pool_ops
from tensorflow_gnn.graph import readout
# pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import interpreter as tfl_interpreter
# pylint: enable=g-direct-tensorflow-import

as_tensor = tf.convert_to_tensor
as_ragged = tf.ragged.constant
Expand Down Expand Up @@ -642,7 +645,7 @@ def testTFLite(self):
f'got TF {tf.__version__}')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
model_content = converter.convert()
interpreter = tf.lite.Interpreter(model_content=model_content)
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
signature_runner = interpreter.get_signature_runner('serving_default')
obtained = signature_runner(**test_graph_dict)['final_edge_adjacency']
self.assertAllEqual(expected, obtained)
Expand Down Expand Up @@ -1308,7 +1311,7 @@ def testTFLite(self):
f'got TF {tf.__version__}')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
model_content = converter.convert()
interpreter = tf.lite.Interpreter(model_content=model_content)
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
signature_runner = interpreter.get_signature_runner('serving_default')
obtained = signature_runner(**test_graph_dict)['final_edge_adjacency']
self.assertAllEqual(expected, obtained)
Expand Down Expand Up @@ -1680,7 +1683,7 @@ def testTFLite(self):
f'got TF {tf.__version__}')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
model_content = converter.convert()
interpreter = tf.lite.Interpreter(model_content=model_content)
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
signature_runner = interpreter.get_signature_runner('serving_default')
obtained = signature_runner(**test_graph_dict)['final_edge_adjacency']
self.assertAllEqual(expected, obtained)
Expand Down
5 changes: 3 additions & 2 deletions tensorflow_gnn/graph/graph_tensor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tensorflow_gnn.graph import graph_tensor_test_utils as tu

# pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import interpreter as tfl_interpreter
from tensorflow.python.framework import type_spec
# pylint: enable=g-direct-tensorflow-import

Expand Down Expand Up @@ -1549,7 +1550,7 @@ def testTFLite(self):
f'got TF {tf.__version__}')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
model_content = converter.convert()
interpreter = tf.lite.Interpreter(model_content=model_content)
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
signature_runner = interpreter.get_signature_runner('serving_default')
obtained = signature_runner(
**test_graph_dict)['private__make_graph_tensor_merged']
Expand Down Expand Up @@ -1752,7 +1753,7 @@ def testTFLite(self):

converter = tf.lite.TFLiteConverter.from_keras_model(model)
model_content = converter.convert()
interpreter = tf.lite.Interpreter(model_content=model_content)
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
signature_runner = interpreter.get_signature_runner('serving_default')
obtained = signature_runner(
node_sizes=test_tensor)['node_set_sizes_to_test_results']
Expand Down
3 changes: 3 additions & 0 deletions tensorflow_gnn/keras/layers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ tf_py_test(
"//tensorflow_gnn/graph:adjacency",
"//tensorflow_gnn/graph:graph_constants",
"//tensorflow_gnn/graph:graph_tensor",
"//:expect_ai_edge_litert_installed",
],
)

Expand Down Expand Up @@ -219,6 +220,7 @@ tf_py_test(
"//:expect_tensorflow_installed",
"//tensorflow_gnn/graph:graph_constants",
"//tensorflow_gnn/utils:tf_test_utils",
"//:expect_ai_edge_litert_installed",
],
)

Expand Down Expand Up @@ -257,6 +259,7 @@ tf_py_test(
"//tensorflow_gnn/graph:preprocessing_common",
"//tensorflow_gnn/keras:keras_tensors",
"//tensorflow_gnn/utils:tf_test_utils",
"//:expect_ai_edge_litert_installed",
],
)

Expand Down
19 changes: 11 additions & 8 deletions tensorflow_gnn/keras/layers/graph_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from tensorflow_gnn.graph import graph_constants as const
from tensorflow_gnn.graph import graph_tensor as gt
from tensorflow_gnn.keras.layers import graph_ops
# pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import interpreter as tfl_interpreter
# pylint: enable=g-direct-tensorflow-import


class ReadoutTest(tf.test.TestCase, parameterized.TestCase):
Expand Down Expand Up @@ -170,7 +173,7 @@ def testTFLite(self, location):
f"got TF {tf.__version__}")
converter = tf.lite.TFLiteConverter.from_keras_model(model)
model_content = converter.convert()
interpreter = tf.lite.Interpreter(model_content=model_content)
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
signature_runner = interpreter.get_signature_runner("serving_default")
obtained = signature_runner(**test_graph_134_dict)["test_readout"]
self.assertAllEqual(expected, obtained)
Expand Down Expand Up @@ -304,7 +307,7 @@ def testTFLite(self):
f"got TF {tf.__version__}")
converter = tf.lite.TFLiteConverter.from_keras_model(model)
model_content = converter.convert()
interpreter = tf.lite.Interpreter(model_content=model_content)
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
signature_runner = interpreter.get_signature_runner("serving_default")
obtained = signature_runner(**test_graph_22_dict)["test_readout_first"]
self.assertAllEqual(expected, obtained)
Expand Down Expand Up @@ -431,7 +434,7 @@ def testTFLite(self):

converter = tf.lite.TFLiteConverter.from_keras_model(model)
model_content = converter.convert()
interpreter = tf.lite.Interpreter(model_content=model_content)
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
signature_runner = interpreter.get_signature_runner("serving_default")
actual = signature_runner(
**test_graph_structured_readout_dict)["output_layer"]
Expand Down Expand Up @@ -565,7 +568,7 @@ def testTFLite(self):

converter = tf.lite.TFLiteConverter.from_keras_model(model)
model_content = converter.convert()
interpreter = tf.lite.Interpreter(model_content=model_content)
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
signature_runner = interpreter.get_signature_runner("serving_default")
actual = signature_runner(
**test_graph_structured_readout_dict)["output_layer"]
Expand Down Expand Up @@ -628,7 +631,7 @@ def testTFLite(self):

converter = tf.lite.TFLiteConverter.from_keras_model(model)
model_content = converter.convert()
interpreter = tf.lite.Interpreter(model_content=model_content)
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
signature_runner = interpreter.get_signature_runner("serving_default")
actual = signature_runner(
**test_graph_structured_readout_dict)["output_layer"]
Expand Down Expand Up @@ -750,7 +753,7 @@ def testTFLite(self):
f"got TF {tf.__version__}")
converter = tf.lite.TFLiteConverter.from_keras_model(model)
model_content = converter.convert()
interpreter = tf.lite.Interpreter(model_content=model_content)
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
signature_runner = interpreter.get_signature_runner("serving_default")
obtained = signature_runner(**test_graph_134_dict)["final_edge_states"]
self.assertAllEqual(expected, obtained)
Expand Down Expand Up @@ -952,7 +955,7 @@ def testTFLite(self, tag, location):
f"got TF {tf.__version__}")
converter = tf.lite.TFLiteConverter.from_keras_model(model)
model_content = converter.convert()
interpreter = tf.lite.Interpreter(model_content=model_content)
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
signature_runner = interpreter.get_signature_runner("serving_default")
obtained = signature_runner(**test_values)["test_broadcast"]
self.assertAllEqual(expected, obtained)
Expand Down Expand Up @@ -1260,7 +1263,7 @@ def testTFLite(self, tag, location, reduce_type):
f"got TF {tf.__version__}")
converter = tf.lite.TFLiteConverter.from_keras_model(model)
model_content = converter.convert()
interpreter = tf.lite.Interpreter(model_content=model_content)
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
signature_runner = interpreter.get_signature_runner("serving_default")
obtained = signature_runner(**test_values)["test_pool"]
self.assertAllEqual(expected, obtained)
Expand Down
5 changes: 4 additions & 1 deletion tensorflow_gnn/keras/layers/next_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from tensorflow_gnn.graph import graph_constants as const
from tensorflow_gnn.keras.layers import next_state as next_state_lib
from tensorflow_gnn.utils import tf_test_utils as tftu
# pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import interpreter as tfl_interpreter
# pylint: enable=g-direct-tensorflow-import


class NextStateFromConcatTest(tf.test.TestCase, parameterized.TestCase):
Expand Down Expand Up @@ -179,7 +182,7 @@ def testTFLite(self):

converter = tf.lite.TFLiteConverter.from_keras_model(model)
model_content = converter.convert()
interpreter = tf.lite.Interpreter(model_content=model_content)
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
signature_runner = interpreter.get_signature_runner("serving_default")
obtained = signature_runner(**test_input_dict)["residual_next_state"]
self.assertAllClose(expected, obtained)
Expand Down
5 changes: 4 additions & 1 deletion tensorflow_gnn/keras/layers/padding_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from tensorflow_gnn.keras import keras_tensors # For registration. pylint: disable=unused-import
from tensorflow_gnn.keras.layers import padding_ops
from tensorflow_gnn.utils import tf_test_utils as tftu
# pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import interpreter as tfl_interpreter
# pylint: enable=g-direct-tensorflow-import


class PadToTotalSizesTest(tf.test.TestCase, parameterized.TestCase):
Expand Down Expand Up @@ -119,7 +122,7 @@ def testBasic(self):
self.skipTest("Padding ops are unsupported in TFLite.")
converter = tf.lite.TFLiteConverter.from_keras_model(model)
model_content = converter.convert()
interpreter = tf.lite.Interpreter(model_content=model_content)
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
signature_runner = interpreter.get_signature_runner("serving_default")
obtained = signature_runner(**test_graph_1_dict)["final_node_states"]
self.assertAllClose(expected, obtained)
Expand Down
1 change: 1 addition & 0 deletions tensorflow_gnn/models/gat_v2/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 +110,6 @@ tf_py_test(
"//:expect_tensorflow_installed",
"//tensorflow_gnn",
"//tensorflow_gnn/utils:tf_test_utils",
"//:expect_ai_edge_litert_installed",
],
)
7 changes: 4 additions & 3 deletions tensorflow_gnn/models/gat_v2/layers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for GATv2."""

from absl.testing import parameterized
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import gat_v2
from tensorflow_gnn.utils import tf_test_utils as tftu
# pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import interpreter as tfl_interpreter
# pylint: enable=g-direct-tensorflow-import


class GATv2Test(tf.test.TestCase, parameterized.TestCase):
Expand Down Expand Up @@ -706,7 +707,7 @@ def testBasic(self):
f"got TF {tf.__version__}")
converter = tf.lite.TFLiteConverter.from_keras_model(model)
model_content = converter.convert()
interpreter = tf.lite.Interpreter(model_content=model_content)
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
signature_runner = interpreter.get_signature_runner("serving_default")
obtained = signature_runner(**test_graph_1_dict)["final_node_states"]
self.assertAllClose(expected, obtained)
Expand Down
1 change: 1 addition & 0 deletions tensorflow_gnn/models/gcn/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,6 @@ tf_py_test(
"//:expect_tensorflow_installed",
"//tensorflow_gnn",
"//tensorflow_gnn/utils:tf_test_utils",
"//:expect_ai_edge_litert_installed",
],
)
6 changes: 4 additions & 2 deletions tensorflow_gnn/models/gcn/gcn_conv_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for gcn_conv."""
import math

from absl.testing import parameterized
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models.gcn import gcn_conv
from tensorflow_gnn.utils import tf_test_utils as tftu
# pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import interpreter as tfl_interpreter
# pylint: enable=g-direct-tensorflow-import


class GcnConvTest(tf.test.TestCase, parameterized.TestCase):
Expand Down Expand Up @@ -873,7 +875,7 @@ def testBasic(self, add_self_loops, edge_weight_feature_name):
f'got TF {tf.__version__}')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
model_content = converter.convert()
interpreter = tf.lite.Interpreter(model_content=model_content)
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
signature_runner = interpreter.get_signature_runner('serving_default')
obtained = signature_runner(**test_graph_1_dict)['final_node_states']
self.assertAllClose(expected, obtained)
Expand Down
1 change: 1 addition & 0 deletions tensorflow_gnn/models/graph_sage/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,6 @@ tf_py_test(
"//:expect_tensorflow_installed",
"//tensorflow_gnn",
"//tensorflow_gnn/utils:tf_test_utils",
"//:expect_ai_edge_litert_installed",
],
)
8 changes: 4 additions & 4 deletions tensorflow_gnn/models/graph_sage/layers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for graph_sage."""

import math

from absl.testing import parameterized
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models.graph_sage import layers as graph_sage
from tensorflow_gnn.utils import tf_test_utils as tftu

# pylint: disable=g-direct-tensorflow-import
from ai_edge_litert import interpreter as tfl_interpreter
# pylint: enable=g-direct-tensorflow-import

_FEATURE_NAME = "f"

Expand Down Expand Up @@ -626,7 +626,7 @@ def testBasic(self, use_pooling, hidden_units, combine_type):
f"got TF {tf.__version__}")
converter = tf.lite.TFLiteConverter.from_keras_model(model)
model_content = converter.convert()
interpreter = tf.lite.Interpreter(model_content=model_content)
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
signature_runner = interpreter.get_signature_runner("serving_default")
obtained = signature_runner(**test_graph_1_dict)["final_node_states"]
self.assertAllClose(expected, obtained)
Expand Down
1 change: 1 addition & 0 deletions tensorflow_gnn/models/hgt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,6 @@ tf_py_test(
"//:expect_tensorflow_installed",
"//tensorflow_gnn",
"//tensorflow_gnn/utils:tf_test_utils",
"//:expect_ai_edge_litert_installed",
],
)
Loading

0 comments on commit 7befc1a

Please sign in to comment.