Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions torch_xla/_dynamo/dynamo_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import torch_xla.runtime as xr
import torch_xla.utils.utils as xu
from torch_xla.utils.buffer_donor_context import alias_with_buffer_donor_config
from torch_xla.distributed.spmd.xla_sharding import get_xla_sharding_specs
import torch_xla.utils.dlpack as torch_xla_dlpack

dynamo_debug = int(os.environ.get('XLA_DYNAMO_DEBUG', '0')) == 1
Expand Down Expand Up @@ -339,8 +340,7 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule,
}

if xr.is_spmd():
xla_args_sharding_spec = torch_xla._XLAC._get_xla_sharding_specs(
xla_args_tensor_only)
xla_args_sharding_spec = get_xla_sharding_specs(xla_args_tensor_only)
else:
xla_args_sharding_spec = ()

Expand Down Expand Up @@ -531,7 +531,7 @@ def optimized_mod(*args: tuple):
# if the input sharding was the same for skip_checking_input_sharding_threashold times
# we will skip checking the input sharding since it can be expensive.
if skip_checking_input_sharding_threashold > 0:
if torch_xla._XLAC._get_xla_sharding_specs(
if get_xla_sharding_specs(
xla_args_tensor_only) != xla_args_sharding_spec:
# update the xla_args with the input with new sharding and retrace
xla_model.xla_args = args
Expand Down
23 changes: 23 additions & 0 deletions torch_xla/distributed/spmd/xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,29 @@ def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor:
return t


def get_xla_sharding_specs(tensors: list) -> list:
"""
Returns XLA sharding specs for the given tensors, normalizing unsharded
tensors to '{replicated}'.

Unsharded tensors have an empty sharding spec, but after dispatch they are
annotated as '{replicated}'. Without normalization, the empty spec and the
post-dispatch replicated spec would differ under equality comparison, causing
unnecessary graph retracing on every step.

Args:
tensors (list[torch.Tensor]): XLA tensors to query sharding specs for.

Returns:
list[str]: HLO sharding spec strings, one per tensor. Unsharded tensors
are returned as '{replicated}'.
"""
return [
s if s else "{replicated}"
for s in torch_xla._XLAC._get_xla_sharding_specs(tensors)
]


def wrap_as_sharded_tensor(t: Union[torch.Tensor, XLAShardedTensor],
mesh_shape=None,
partition_spec=None) -> XLAShardedTensor:
Expand Down
Loading