Commit c60e42e

Author: TensorFlow Recommenders Authors
Remove special casing by sparse core as it did not work as intended
PiperOrigin-RevId: 702977131
1 parent 151a970 commit c60e42e
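
Note on the change: before this commit, the layer overrode the user-supplied batch_size with None whenever has_sparse_core() reported SparseCore (embedding feature V2) hardware; after it, batch_size is stored exactly as passed. Callers who relied on the automatic override can run the same check themselves. A minimal sketch, reusing the logic of the helper deleted in the diff below:

import tensorflow.compat.v2 as tf

def has_sparse_core() -> bool:
  """Returns True if the current strategy runs on SparseCore (V2) embedding hardware."""
  strategy = tf.distribute.get_strategy()
  # Only TPU strategies expose TPU hardware feature information.
  if not isinstance(
      strategy,
      (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy),
  ):
    return False
  # SparseCore corresponds to embedding feature version V2.
  return (
      strategy.extended.tpu_hardware_feature.embedding_feature
      == tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V2
  )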

File tree

1 file changed (+3, −18 lines)

tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py (+3 −18)
@@ -15,23 +15,8 @@
 """Keras interface for TPU Embeddings in TF2."""
 
 from typing import Any, Dict, Iterable, Optional, Union
-
-import tensorflow.compat.v2 as tf
 
-# From tensorflow/python/layers/sparse_core_util.py to avoid circular dependency
-# and avoid creating another separate file.
-def has_sparse_core() -> bool:
-  """Check to see if SparseCore is available."""
-  strategy = tf.distribute.get_strategy()
-  if not isinstance(
-      strategy,
-      (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy),
-  ):
-    return False
-  return (
-      strategy.extended.tpu_hardware_feature.embedding_feature
-      == tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V2
-  )
+import tensorflow.compat.v2 as tf
 
 _SLOT_NAME_MAPPING = {
     # Slot names in Keras optimizer v2 are different compared to the slot names
@@ -636,7 +621,7 @@ def __init__(
         will be one step old with potential correctness drawbacks). Set to True
         for improved performance.
       batch_size: Batch size of the input feature. Deprecated, support backward
-        compatibility. Set None for sparse core for proper shape inference.
+        compatibility.
       embedding_feature: EmbeddingFeature enum, indicating which version of TPU
         hardware the layer should run on.
       sparse_core_embedding_config: SparseCoreEmbeddingConfig, indicating
@@ -680,7 +665,7 @@ def __init__(
         pipeline_execution_with_tensor_core,
         sparse_core_embedding_config
     )
-    self.batch_size = None if has_sparse_core() else batch_size
+    self.batch_size = batch_size
     self._tpu_call_id = 0
 
   def _create_tpu_embedding_mid_level_api(

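Usage note: with this change, anyone targeting SparseCore hardware should pass batch_size=None explicitly, since the layer no longer does it automatically (the old docstring advised None "for proper shape inference"). A minimal construction sketch; the table and feature names are hypothetical, not part of this commit:

import tensorflow.compat.v2 as tf
import tensorflow_recommenders as tfrs

# Hypothetical single-table embedding configuration.
table = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=10_000, dim=32, name="user_table")
feature_config = {
    "user_id": tf.tpu.experimental.embedding.FeatureConfig(table=table),
}

# The layer now stores batch_size exactly as passed; pass None yourself
# when running on SparseCore so shapes are inferred from the input.
embedding_layer = tfrs.layers.embedding.TPUEmbedding(
    feature_config=feature_config,
    optimizer=tf.tpu.experimental.embedding.Adagrad(learning_rate=0.1),
    batch_size=None,
)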