Skip to content

Commit 8389fc3

Browse files
author
TensorFlow Recommenders Authors
committed
Special case batch_size setting by hardware type.
PiperOrigin-RevId: 690792941
1 parent 36a1836 commit 8389fc3

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,24 @@
1515
"""Keras interface for TPU Embeddings in TF2."""
1616

1717
from typing import Any, Dict, Iterable, Optional, Union
18-
1918
import tensorflow.compat.v2 as tf
2019

20+
21+
# From tensorflow/python/layers/sparse_core_util.py to avoid circular dependency
22+
# and avoid creating another separate file.
23+
def has_sparse_core() -> bool:
24+
"""Check to see if SparseCore is available."""
25+
strategy = tf.distribute.get_strategy()
26+
if not isinstance(
27+
strategy,
28+
(tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy),
29+
):
30+
return False
31+
return (
32+
strategy.extended.tpu_hardware_feature.embedding_feature
33+
== tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V2
34+
)
35+
2136
_SLOT_NAME_MAPPING = {
2237
# Slot names in Keras optimizer v2 are different compared to the slot names
2338
# in our API.
@@ -621,7 +636,7 @@ def __init__(
621636
will be one step old with potential correctness drawbacks). Set to True
622637
for improved performance.
623638
batch_size: Batch size of the input feature. Deprecated, support backward
624-
compatibility.
639+
compatibility. Set None for sparse core for proper shape inference.
625640
embedding_feature: EmbeddingFeature enum, inidicating which version of TPU
626641
hardware the layer should run on.
627642
sparse_core_embedding_config: SparseCoreEmbeddingConfig, inidicating
@@ -665,7 +680,7 @@ def __init__(
665680
pipeline_execution_with_tensor_core,
666681
sparse_core_embedding_config
667682
)
668-
self.batch_size = batch_size
683+
self.batch_size = None if has_sparse_core() else batch_size
669684
self._tpu_call_id = 0
670685

671686
def _create_tpu_embedding_mid_level_api(

0 commit comments

Comments
 (0)