4 changes: 4 additions & 0 deletions keras/src/backend/__init__.py
@@ -37,6 +37,8 @@
if backend() == "tensorflow":
from keras.src.backend.tensorflow import * # noqa: F403
from keras.src.backend.tensorflow.core import Variable as BackendVariable

distributed_backend = None
elif backend() == "jax":
from keras.src.backend.jax import * # noqa: F403
from keras.src.backend.jax.core import Variable as BackendVariable
@@ -50,11 +52,13 @@
from keras.src.backend.numpy.core import Variable as BackendVariable

distribution_lib = None
distributed_backend = None
elif backend() == "openvino":
from keras.src.backend.openvino import * # noqa: F403
from keras.src.backend.openvino.core import Variable as BackendVariable

distribution_lib = None
distributed_backend = None
else:
raise ValueError(f"Unable to import backend : {backend()}")

1 change: 1 addition & 0 deletions keras/src/backend/torch/__init__.py
@@ -16,6 +16,7 @@

from keras.src.backend.common.name_scope import name_scope
from keras.src.backend.torch import core
from keras.src.backend.torch import distributed_backend
from keras.src.backend.torch import image
from keras.src.backend.torch import linalg
from keras.src.backend.torch import math
257 changes: 257 additions & 0 deletions keras/src/backend/torch/distributed_backend.py
@@ -0,0 +1,257 @@
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Literal

import torch
import torch.distributed as dist


def compute_gradients(
loss: torch.Tensor, trainable_vars: List[torch.Tensor]
) -> List[torch.Tensor]:
"""Computes gradients of the loss with respect to trainable variables.

This function leverages PyTorch's `autograd.grad` for a stateless,
functional approach similar to `jax.grad`.

Args:
loss (torch.Tensor): The loss value for which to compute gradients.
trainable_vars (List[torch.Tensor]): A list of variables (tensors with
`requires_grad=True`) to compute gradients with respect to.

Returns:
List[torch.Tensor]: A list of gradients corresponding to the
trainable variables.
"""
return list(torch.autograd.grad(loss, trainable_vars))


def apply_gradients(
gradients: List[torch.Tensor],
trainable_vars: List[torch.Tensor],
learning_rate: float = 0.001,
) -> List[torch.Tensor]:
"""Applies gradients and returns the updated variables.

Updates are performed in-place within a `torch.no_grad()` context
to prevent the update operation from being part of the computation graph.
"""
with torch.no_grad():
updated_vars = []
for grad, var in zip(gradients, trainable_vars):
if grad is not None:
var.sub_(learning_rate * grad)
updated_vars.append(var)
return updated_vars
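
# Example (illustrative sketch, not part of this module; names are made up):
#
#     w = torch.tensor([1.0, 2.0], requires_grad=True)
#     loss = (w ** 2).sum()
#     grads = compute_gradients(loss, [w])            # [tensor([2., 4.])]
#     (w,) = apply_gradients(grads, [w], learning_rate=0.1)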


def create_optimizer(optimizer_class: str, **kwargs) -> Dict[str, Any]:
"""Creates a configuration dictionary for a PyTorch optimizer.

This function returns a dictionary containing the optimizer's configuration,
maintaining a consistent interface with the JAX backend. The user is
expected to instantiate the optimizer from this config.

Args:
optimizer_class (str): The name of the optimizer to create (e.g.,
`"adam"`, `"sgd"`).
**kwargs: Keyword arguments for the optimizer (e.g., `learning_rate`).

Returns:
Dict[str, Any]: A dictionary representing the optimizer configuration.
"""
config = kwargs.copy()
config["name"] = optimizer_class.lower()
config.setdefault("learning_rate", 0.001)
return config
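
# Example (illustrative sketch; `model` is a hypothetical `torch.nn.Module`):
#
#     config = create_optimizer("adam", learning_rate=3e-4)
#     optimizer = torch.optim.Adam(
#         model.parameters(), lr=config["learning_rate"]
#     )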


def get_device_info() -> Dict[str, Any]:
"""Retrieves information about the available PyTorch devices.

Returns:
Dict[str, Any]: A dictionary containing the backend name, a list of
available device strings, and the total device count.
"""
if torch.cuda.is_available():
device_count = torch.cuda.device_count()
devices = [torch.cuda.get_device_name(i) for i in range(device_count)]
else:
device_count = 1
devices = ["cpu"]
return {
"backend": "pytorch",
"devices": devices,
"device_count": device_count,
}
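
# Example return value (device names depend on the local machine):
#
#     {"backend": "pytorch", "devices": ["NVIDIA A100"], "device_count": 1}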


def is_multi_device_capable() -> bool:
"""Checks if more than one CUDA device is available.

Returns:
bool: `True` if PyTorch reports more than one CUDA device, `False`
otherwise.
"""
return torch.cuda.device_count() > 1


def get_communication_ops() -> Dict[str, Callable]:
"""Provides a dictionary of PyTorch collective communication operations.

These operations rely on the `torch.distributed` package. They are
designed to work in a multi-process, multi-device environment. If the
distributed package is not initialized, they provide a sensible fallback
for single-device execution.

Returns:
Dict[str, Callable]: A dictionary mapping operation names to their
PyTorch implementations.
"""

def _is_distributed() -> bool:
"""Checks if the default process group is initialized."""
return dist.is_available() and dist.is_initialized()

def all_reduce(
x: torch.Tensor,
op: Literal["sum", "mean"] = "sum",
) -> torch.Tensor:
"""Reduces a tensor across all devices.

Args:
x (torch.Tensor): The tensor to reduce.
op (Literal["sum", "mean"], optional): The reduction operation.
Defaults to "sum".

Returns:
torch.Tensor: The reduced tensor.
"""
if not _is_distributed():
world_size = (
torch.cuda.device_count() if torch.cuda.is_available() else 1
)
if world_size <= 1:
return x
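            # Single-process fallback: emulate the collective as if every
            # device held an identical copy of `x`. Summing N identical
            # copies scales by N; averaging them leaves `x` unchanged.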
if op == "sum":
return x * float(world_size)
elif op == "mean":
return x
else:
raise ValueError(f"Unsupported all_reduce op: {op}")

reduce_op = {"sum": dist.ReduceOp.SUM, "mean": dist.ReduceOp.AVG}.get(
op
)
if reduce_op is None:
raise ValueError(f"Unsupported all_reduce op: {op}")

result = x.clone()
dist.all_reduce(result, op=reduce_op)
return result

def all_gather(x: torch.Tensor, axis: int = 0) -> torch.Tensor:
"""Gathers tensors from all devices and concatenates them.

Args:
x (torch.Tensor): The local tensor to gather.
axis (int, optional): The axis along which to concatenate.
Defaults to 0.

Returns:
torch.Tensor: The concatenated tensor from all devices.
"""
if not _is_distributed():
world_size = (
torch.cuda.device_count() if torch.cuda.is_available() else 1
)
if world_size <= 1:
return x
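            # Single-process fallback: emulate the gather by replicating
            # the local tensor along `axis`.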
return torch.cat([x] * world_size, dim=axis)

world_size = dist.get_world_size()
tensor_list = [torch.empty_like(x) for _ in range(world_size)]
dist.all_gather(tensor_list, x)
return torch.cat(tensor_list, dim=axis)

def broadcast(x: torch.Tensor, root: int = 0) -> torch.Tensor:
"""Broadcasts a tensor from a root device to all other devices.

Args:
x (torch.Tensor): The tensor to broadcast.
root (int, optional): The rank of the source device. Defaults to 0.

Returns:
torch.Tensor: The tensor received from the root device.
"""
if not _is_distributed():
return x

# `dist.broadcast` is in-place.
dist.broadcast(x, src=root)
return x

def scatter(
x: torch.Tensor,
root: int = 0,
axis: int = 0,
) -> torch.Tensor:
"""Scatters a tensor from a root device to all devices.

        Note: `dist.scatter` expects the root process to pass a list of
        per-rank chunks (`scatter_list`) while other ranks pass `None`.
        This wrapper handles that detail by splitting `x` into equal
        chunks on the root process.

Args:
x (torch.Tensor): The tensor on the root device to be scattered.
root (int, optional): The rank of the device holding the tensor.
Defaults to 0.
axis (int, optional): The axis along which to split the tensor.
Defaults to 0.

Returns:
torch.Tensor: The chunk of the tensor for the local device.
"""
if not _is_distributed():
world_size = (
torch.cuda.device_count() if torch.cuda.is_available() else 1
)
if world_size <= 1:
return x
if x.shape[axis] % world_size != 0:
raise ValueError(
f"Tensor with shape {x.shape} cannot be scattered along "
f"axis {axis} across {world_size} devices."
)
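            # Single-process fallback: split locally and return the chunk
            # that rank 0 would receive.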
return torch.chunk(x, world_size, dim=axis)[0]

world_size = dist.get_world_size()
rank = dist.get_rank()

if x.shape[axis] % world_size != 0:
raise ValueError(
f"Tensor with shape {x.shape} cannot be scattered along "
f"axis {axis} across {world_size} devices."
)

if rank == root:
scatter_list = list(torch.chunk(x, world_size, dim=axis))
else:
scatter_list = None

chunk_shape = list(x.shape)
chunk_shape[axis] //= world_size
local_chunk = torch.empty(chunk_shape, dtype=x.dtype, device=x.device)

dist.scatter(local_chunk, scatter_list, src=root)
return local_chunk

return {
"all_reduce": all_reduce,
"all_gather": all_gather,
"broadcast": broadcast,
"scatter": scatter,
}
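

# Example usage (illustrative sketch; assumes the default process group has
# been initialized elsewhere, e.g. via `dist.init_process_group`):
#
#     ops = get_communication_ops()
#     summed = ops["all_reduce"](local_tensor, op="sum")
#     gathered = ops["all_gather"](local_tensor, axis=0)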