|
15 | 15 | """Keras interface for TPU Embeddings in TF2."""
|
16 | 16 |
|
17 | 17 | from typing import Any, Dict, Iterable, Optional, Union
|
18 |
| - |
19 | 18 | import tensorflow.compat.v2 as tf
|
20 | 19 |
|
| 20 | + |
# From tensorflow/python/layers/sparse_core_util.py to avoid circular dependency
# and avoid creating another separate file.
def has_sparse_core() -> bool:
  """Check to see if SparseCore is available."""
  # Only TPU strategies can expose SparseCore hardware at all.
  current_strategy = tf.distribute.get_strategy()
  tpu_strategy_types = (
      tf.distribute.TPUStrategy,
      tf.distribute.experimental.TPUStrategy,
  )
  if not isinstance(current_strategy, tpu_strategy_types):
    return False
  # SparseCore corresponds to the V2 embedding feature of the TPU hardware.
  feature = current_strategy.extended.tpu_hardware_feature.embedding_feature
  return feature == tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V2

21 | 36 | _SLOT_NAME_MAPPING = {
|
22 | 37 | # Slot names in Keras optimizer v2 are different compared to the slot names
|
23 | 38 | # in our API.
|
@@ -621,7 +636,7 @@ def __init__(
|
621 | 636 | will be one step old with potential correctness drawbacks). Set to True
|
622 | 637 | for improved performance.
|
623 | 638 | batch_size: Batch size of the input feature. Deprecated, support backward
|
624 |
| - compatibility. |
| 639 | + compatibility. Set to None for sparse core for proper shape inference. |
625 | 640 | embedding_feature: EmbeddingFeature enum, indicating which version of TPU
|
626 | 641 | hardware the layer should run on.
|
627 | 642 | sparse_core_embedding_config: SparseCoreEmbeddingConfig, indicating
|
@@ -665,7 +680,7 @@ def __init__(
|
665 | 680 | pipeline_execution_with_tensor_core,
|
666 | 681 | sparse_core_embedding_config
|
667 | 682 | )
|
668 |
| - self.batch_size = batch_size |
| 683 | + self.batch_size = None if has_sparse_core() else batch_size |
669 | 684 | self._tpu_call_id = 0
|
670 | 685 |
|
671 | 686 | def _create_tpu_embedding_mid_level_api(
|
|
0 commit comments