|
15 | 15 | """Keras interface for TPU Embeddings in TF2."""
|
16 | 16 |
|
17 | 17 | from typing import Any, Dict, Iterable, Optional, Union
|
18 |
| -import tensorflow.compat.v2 as tf |
19 |
| - |
20 | 18 |
|
21 |
| -# From tensorflow/python/layers/sparse_core_util.py to avoid circular dependency |
22 |
| -# and avoid creating another separate file. |
23 |
def has_sparse_core() -> bool:
  """Return True if the current distribution strategy runs on SparseCore.

  Duplicated from tensorflow/python/layers/sparse_core_util.py to avoid a
  circular dependency and an extra file.

  Returns:
    True only when the active strategy is a TPUStrategy whose TPU hardware
    reports the V2 embedding feature (SparseCore); False otherwise.
  """
  current = tf.distribute.get_strategy()
  tpu_strategy_types = (
      tf.distribute.TPUStrategy,
      tf.distribute.experimental.TPUStrategy,
  )
  # Non-TPU strategies can never expose SparseCore hardware.
  if isinstance(current, tpu_strategy_types):
    feature = current.extended.tpu_hardware_feature.embedding_feature
    return feature == tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V2
  return False
| 19 | +import tensorflow.compat.v2 as tf |
35 | 20 |
|
36 | 21 | _SLOT_NAME_MAPPING = {
|
37 | 22 | # Slot names in Keras optimizer v2 are different compared to the slot names
|
@@ -636,7 +621,7 @@ def __init__(
|
636 | 621 | will be one step old with potential correctness drawbacks). Set to True
|
637 | 622 | for improved performance.
|
638 | 623 | batch_size: Batch size of the input feature. Deprecated; kept for backward
|
639 |
| - compatibility. Set None for sparse core for proper shape inference. |
| 624 | + compatibility. |
640 | 625 | embedding_feature: EmbeddingFeature enum, indicating which version of TPU
|
641 | 626 | hardware the layer should run on.
|
642 | 627 | sparse_core_embedding_config: SparseCoreEmbeddingConfig, indicating
|
@@ -680,7 +665,7 @@ def __init__(
|
680 | 665 | pipeline_execution_with_tensor_core,
|
681 | 666 | sparse_core_embedding_config
|
682 | 667 | )
|
683 |
| - self.batch_size = None if has_sparse_core() else batch_size |
| 668 | + self.batch_size = batch_size |
684 | 669 | self._tpu_call_id = 0
|
685 | 670 |
|
686 | 671 | def _create_tpu_embedding_mid_level_api(
|
|
0 commit comments