diff --git a/torch_xla/_internal/pjrt.py b/torch_xla/_internal/pjrt.py
index 71d9ddc4865e..d89ae3819b3c 100644
--- a/torch_xla/_internal/pjrt.py
+++ b/torch_xla/_internal/pjrt.py
@@ -141,6 +141,9 @@ def run_multiprocess(fn: Callable[..., R],
     Dict of the form {device_ordinal: return_value}, where return_value is
       the result of calling `fn`.
   """
+  import traceback,inspect
+  print(f"Current line: {inspect.currentframe().f_lineno}")
+  traceback.print_stack()
   if torch_xla._XLAC._xla_runtime_is_initialized():
     raise RuntimeError('Runtime is already initialized. Do not use the XLA '
                        'device before calling xmp.spawn.')
diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py
index 3699a4c89c3f..7011d27497c3 100644
--- a/torch_xla/_internal/tpu.py
+++ b/torch_xla/_internal/tpu.py
@@ -307,11 +307,17 @@ def discover_master_worker_ip(use_localhost: bool = True) -> str:
 def _spmd_find_master_ip(current_worker_hostname: str) -> str:
   import torch_xla.runtime as xr
   import torch_xla.distributed.spmd as xs
+  import traceback,inspect
+  print(f"Current line: {inspect.currentframe().f_lineno}")
+  traceback.print_stack()
   from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
   # Translate the hostname to an IP address, e.g. for TPUs on GKE.
   current_worker_ip = socket.gethostbyname(current_worker_hostname)
   ip_int = int(ip_address(current_worker_ip))
   n_dev = xr.global_runtime_device_count()
+  import traceback,inspect
+  print(f"Current line: {inspect.currentframe().f_lineno}")
+  traceback.print_stack()
   local_ndev = len(torch_xla._XLAC._xla_get_runtime_devices())
   # Create a global (n_dev x 2) tensor containing all process indices and IPs,
   # and find the process 0 IP as the master IP.
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 931115db6d85..7cbc72702bbb 100644
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -90,6 +90,9 @@ def get_xla_supported_devices(devkind: Optional[str] = None,
   # TODO(wcromar): Remove `devkind` after 2.3 release cut. We no longer support
   # multiple device types.
   if not devkind:
+    import traceback,inspect
+    print(f"Current line: {inspect.currentframe().f_lineno}")
+    traceback.print_stack()
     devices = torch_xla._XLAC._xla_get_devices()
     return [
         f'xla:{i}'
@@ -224,6 +227,9 @@ def xla_replication_devices(
         'Cannot replicate if number of devices ({}) is different from {}'.
         format(len(local_devices), len(kind_devices)))
   replication_devices = []
+  import traceback,inspect
+  print(f"Current line: {inspect.currentframe().f_lineno}")
+  traceback.print_stack()
   for device in torch_xla._XLAC._xla_get_all_devices():
     # device is like 'CUDA:0'
     xdev = _utils.parse_xla_device(device)
@@ -255,8 +261,12 @@ def set_replication(device: torch.device,
   devctx = _get_device_context(device=device)
   devices = [str(x) for x in devices]
   if devices:
+    import traceback,inspect
+    print(f"Current line: {inspect.currentframe().f_lineno}")
     # sample replication_devices: ['CUDA:0', 'CUDA:1', 'CUDA:2', 'CUDA:3']
     replication_devices = xla_replication_devices(devices)
+    traceback.print_stack()
+    print(f"Current line: {inspect.currentframe().f_lineno}")
     torch_xla._XLAC._xla_set_replication_devices(replication_devices)
     devctx.device_index = devices.index(device)
   else:
diff --git a/torch_xla/debug/metrics.py b/torch_xla/debug/metrics.py
index 11718e8376ba..e522f35fd48d 100644
--- a/torch_xla/debug/metrics.py
+++ b/torch_xla/debug/metrics.py
@@ -61,6 +61,9 @@ def clear_all():
 
 def metrics_report():
   """Retrieves a string containing the full metrics and counters report."""
+  import traceback,inspect
+  print(f"Current line: {inspect.currentframe().f_lineno}")
+  traceback.print_stack()
   return torch_xla._XLAC._xla_metrics_report()
 
 
@@ -78,6 +81,10 @@ def short_metrics_report(counter_names: list = None, metric_names: list = None):
         'CompileTime', 'ExecuteTime', 'ExecuteReplicatedTime',
         'TransferToDeviceTime', 'TransferFromDeviceTime'
     ]
+  import traceback,inspect
+  print(f"Current line: {inspect.currentframe().f_lineno}")
+  traceback.print_stack()
+
   return torch_xla._XLAC._short_xla_metrics_report(counter_names, metric_names)
 
 
diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py
index aedfd6a801e3..2f9c08c58080 100644
--- a/torch_xla/distributed/spmd/xla_sharded_tensor.py
+++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py
@@ -115,6 +115,9 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs):
   # which results from the sharding.
   @property
   def local_shards(self) -> List[XLAShard]:
+    import traceback,inspect
+    print(f"Current line: {inspect.currentframe().f_lineno}")
+    traceback.print_stack()
     shard_dev = torch_xla._XLAC._get_local_shards([self.global_tensor])[0]
     replica_ind = torch_xla._XLAC._get_local_shard_replica_and_indices(
         [self.global_tensor])[0]
@@ -128,6 +131,9 @@ def local_shards(self) -> List[XLAShard]:
   def load_local_shards_(self, shards: List[XLAShard]):
     data = [s.data for s in shards]
     devices = [s.shard_device for s in shards]
+    import traceback,inspect
+    print(f"Current line: {inspect.currentframe().f_lineno}")
+    traceback.print_stack()
     torch_xla._XLAC._load_local_shards(self.global_tensor, data, devices)
 
   @property
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index ed015f733748..357cae4f409f 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -54,6 +54,9 @@ def _extract_backend_config(
 
 def jax_import_guard():
   # Somehow, we need to grab the TPU before JAX locks it. Otherwise, any pt-xla TPU operations will hang.
+  import traceback,inspect
+  print(f"Current line: {inspect.currentframe().f_lineno}")
+  traceback.print_stack()
   torch_xla._XLAC._init_computation_client()
 
 
diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py
index 1946ae05a52b..37fd425c7a6a 100644
--- a/torch_xla/runtime.py
+++ b/torch_xla/runtime.py
@@ -45,6 +45,9 @@ def set_device_type(pjrt_device: str) -> None:
   Args:
     pjrt_device: 'TPU' or 'CPU'
   """
+  import traceback,inspect
+  print(f"Current line: {inspect.currentframe().f_lineno}")
+  traceback.print_stack()
   if torch_xla._XLAC._xla_runtime_is_initialized() and os.environ.get(
       xenv.PJRT_DEVICE) != pjrt_device:
     raise RuntimeError(
@@ -133,6 +136,9 @@ def local_process_count() -> int:
 
 def global_device_count() -> int:
   """Returns the total number of devices across all processes/hosts."""
+  import traceback,inspect
+  print(f"Current line: {inspect.currentframe().f_lineno}")
+  traceback.print_stack()
   return len(torch_xla._XLAC._xla_get_all_devices())
 
 
@@ -141,6 +147,9 @@ def world_size() -> int:
   global _WORLD_SIZE
   if _WORLD_SIZE is not None:
     return _WORLD_SIZE
+  import traceback,inspect
+  print(f"Current line: {inspect.currentframe().f_lineno}")
+  traceback.print_stack()
   if torch_xla._XLAC._xla_get_replication_devices_count() == 0:
     _WORLD_SIZE = 1
   else:
@@ -158,6 +167,9 @@ def local_device_count() -> int:
 
 def addressable_device_count() -> int:
   """Returns the number of devices visible to this process."""
+  import traceback,inspect
+  print(f"Current line: {inspect.currentframe().f_lineno}")
+  traceback.print_stack()
   return torch_xla._XLAC._xla_num_devices()
 
 
@@ -183,10 +195,16 @@ def local_ordinal() -> int:
 
 
 def process_index() -> int:
+  import traceback,inspect
+  print(f"Current line: {inspect.currentframe().f_lineno}")
+  traceback.print_stack()
   return torch_xla._XLAC._xla_get_process_index()
 
 
 def process_count() -> int:
+  import traceback,inspect
+  print(f"Current line: {inspect.currentframe().f_lineno}")
+  traceback.print_stack()
   return torch_xla._XLAC._xla_get_num_processes()
 
 
@@ -202,16 +220,25 @@ def host_index() -> int:
 
 # API below will be used to query physcial device attribute.
 def runtime_device_attributes(device: str) -> Dict[str, object]:
+  import traceback,inspect
+  print(f"Current line: {inspect.currentframe().f_lineno}")
+  traceback.print_stack()
   return torch_xla._XLAC._xla_get_device_attributes(device)
 
 
 def global_runtime_device_attributes() -> List[Dict[str, object]]:
+  import traceback,inspect
+  print(f"Current line: {inspect.currentframe().f_lineno}")
+  traceback.print_stack()
   return torch_xla._XLAC._xla_get_all_device_attributes()
 
 
 @functools.lru_cache()
 def global_runtime_device_count() -> int:
   """Returns the total number of runtime devices across all processes/hosts, especially useful for SPMD."""
+  import traceback,inspect
+  print(f"Current line: {inspect.currentframe().f_lineno}")
+  traceback.print_stack()
   return len(torch_xla._XLAC._xla_get_all_runtime_devices())
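
Note: the patch inlines the same three tracing lines at every instrumented site: print the current line number via inspect.currentframe().f_lineno, then dump the call stack with traceback.print_stack(). If this instrumentation is kept around, the snippet could be factored into a small helper; the sketch below is illustrative only and is not part of the patch (the name print_call_site is made up).

# Hypothetical helper (not in the patch): one function replacing the repeated
# "import traceback,inspect / print(f_lineno) / traceback.print_stack()" block.
import inspect
import traceback


def print_call_site(note: str = "") -> None:
  # Frame of the function that called this helper, i.e. the instrumented site.
  caller = inspect.currentframe().f_back
  print(f"Current line: {caller.f_lineno} in {caller.f_code.co_filename} {note}".rstrip())
  # Dump the stack leading into the instrumented call, as the inline version does.
  traceback.print_stack(f=caller)

A single call such as print_call_site() at the top of world_size() or global_device_count() would print the same information as the three inlined lines, and removing the instrumentation later becomes a one-line change per site.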