Skip to content

Commit

Permalink
support local v4 TPUs
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewyates committed Sep 22, 2022
1 parent f7bfaaa commit 53f34c4
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 30 deletions.
31 changes: 4 additions & 27 deletions capreolus/trainer/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
KerasLCEModel,
TFLCELoss,
)
from tensorflow.keras.mixed_precision import experimental as mixed_precision
import tensorflow.keras.mixed_precision as mixed_precision


logger = get_logger(__name__)
Expand All @@ -36,29 +36,6 @@ def get_available_gpus():
return [x.name for x in local_device_protos if x.device_type == "GPU"]


class LocalTPUClusterResolver(tf.distribute.cluster_resolver.TPUClusterResolver):
    """TPUClusterResolver subclass that returns fixed, hard-coded answers.

    Every query method below returns a constant (no master address, an empty
    cluster spec, 8 TPU cores on 1 host) instead of performing any lookup.
    NOTE(review): presumably intended for a TPU attached directly to the
    local machine (a "TPU VM") with exactly 8 cores -- confirm against callers.
    """

    def __init__(self):
        # The parent initializer is not called; only these three attributes
        # are set on the instance.
        self._tpu = ""
        self.task_type, self.task_id = "worker", 0

    def master(self, task_type=None, task_id=None, rpc_layer=None):
        """Return ``None`` regardless of the arguments given."""
        return None

    def cluster_spec(self):
        """Return a ``tf.train.ClusterSpec`` built from an empty mapping."""
        no_jobs = {}
        return tf.train.ClusterSpec(no_jobs)

    def get_tpu_system_metadata(self):
        """Return TPU system metadata describing 8 cores on a single host.

        The ``devices`` field is whatever ``tf.config.list_logical_devices()``
        reports at call time; ``topology`` is left as ``None``.
        """
        metadata_cls = tf.tpu.experimental.TPUSystemMetadata
        fields = {
            "num_cores": 8,
            "num_hosts": 1,
            "num_of_cores_per_host": 8,
            "topology": None,
            "devices": tf.config.list_logical_devices(),
        }
        return metadata_cls(**fields)

    def num_accelerators(self, task_type=None, task_id=None, config_proto=None):
        """Return a fixed accelerator count of 8 TPUs; arguments are ignored."""
        accelerator_counts = {"TPU": 8}
        return accelerator_counts


@Trainer.register
class TensorflowTrainer(Trainer):
"""
Expand Down Expand Up @@ -106,10 +83,10 @@ def build(self):
# Use TPU if available, otherwise resort to GPU/CPU
if self.config["tpuname"]:
if self.config["tpuname"] == "LOCAL":
logger.debug("using TPU VM with LocalTPUClusterResolver")
self.tpu = LocalTPUClusterResolver()
logger.debug("using TPU VM")
self.tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
else:
logger.debug("using TPU with TPUClusterResolver")
logger.debug("using Cloud TPU")
self.tpu = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu=self.config["tpuname"], zone=self.config["tpuzone"]
)
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies:
- recommonmark
- google-api-python-client
- oauth2client
- tensorflow>=2.3,<2.5
- tensorflow>=2.3,<=2.10
- transformers~=4.9.2
- tensorflow-ranking==0.3.2
- Pillow
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ recommonmark
gdown
google-api-python-client
oauth2client
tensorflow>=2.3,<2.5
tensorflow>=2.3,<=2.10
transformers~=4.9.2
tensorflow-ranking==0.3.2
Pillow
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def get_version(rel_path):
"scipy",
"google-api-python-client",
"oauth2client",
"tensorflow>=2.3,<2.5",
"tensorflow>=2.3,<=2.10",
"transformers~=4.9.2",
"tensorflow-ranking==0.3.2",
"Pillow",
Expand Down

0 comments on commit 53f34c4

Please sign in to comment.