Skip to content

Commit a974e2c

Browse files
committed
fix: remove duplicate function
1 parent de6ef5a commit a974e2c

File tree

3 files changed

+36
-64
lines changed

3 files changed

+36
-64
lines changed

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -874,44 +874,18 @@ PjRtComputationClient::ExecuteComputation(
874874
return datas;
875875
}
876876

877-
namespace {
878-
879-
/**
880-
* Filters a list of device strings to include only those with IDs matching
881-
* the provided indices.
882-
*
883-
* @param devices List of device strings in format "TYPE:ID" (e.g., "TPU:0")
884-
* @param indices List of device IDs to filter by
885-
* @return Filtered list of device strings, or error status if parsing fails
886-
*
887-
* Example:
888-
* devices = ["TPU:0", "TPU:1", "TPU:2", "TPU:3"]
889-
* indices = [1, 3]
890-
* result = ["TPU:1", "TPU:3"]
891-
*/
877+
// wrapped function to handle absl::Span instead of std::vector
892878
absl::Span<const std::string> FilterDevicesByAddressableDevices(
893879
absl::Span<const std::string> devices,
894880
const std::vector<int64_t>& indices) {
895881
static std::vector<std::string> filtered_devices_;
896882
filtered_devices_.clear();
897883
filtered_devices_.reserve(indices.size());
898-
for (auto& index : indices) {
899-
for (auto& device : devices) {
900-
std::vector<std::string> device_spec_parts = absl::StrSplit(device, ':');
901-
if ((std::stoi(device_spec_parts[1]) == index) &&
902-
(std::find(filtered_devices_.begin(), filtered_devices_.end(),
903-
device) == filtered_devices_.end())) {
904-
filtered_devices_.push_back(device);
905-
break;
906-
}
907-
}
908-
}
909-
// Return a span that points to our filtered data
910-
return absl::Span<const std::string>(filtered_devices_);
884+
filtered_devices_ = torch_xla::runtime::util::FilterDevicesByAddressableDevices(
885+
devices, indices);
886+
return absl::MakeConstSpan(filtered_devices_);
911887
}
912888

913-
} // namespace
914-
915889
std::vector<ComputationClient::DataPtr>
916890
PjRtComputationClient::ExecuteReplicated(
917891
const ComputationClient::Computation& computation,

torch_xla/csrc/runtime/util.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <vector>
1616

1717
#include "absl/status/statusor.h"
18+
#include "absl/strings/str_split.h"
1819
#include "absl/types/optional.h"
1920
#include "absl/types/span.h"
2021
#include "torch_xla/csrc/runtime/types.h"
@@ -177,6 +178,36 @@ RaisePythonValueErrorOnFailure(const Func& func) {
177178
throw std::invalid_argument(std::string(result.status().message()));
178179
}
179180

181+
/**
182+
* Filters a list of device strings to include only those with IDs matching
183+
* the provided indices.
184+
*
185+
* @param devices List of device strings in format "TYPE:ID" (e.g., "TPU:0")
186+
* @param indices List of device IDs to filter by
187+
* @return Filtered list of device strings
188+
*
189+
* Example:
190+
* devices = ["TPU:0", "TPU:1", "TPU:2", "TPU:3"]
191+
* indices = [1, 3]
192+
* result = ["TPU:1", "TPU:3"]
193+
*/
194+
template<typename DeviceContainer>
195+
std::vector<std::string> FilterDevicesByAddressableDevices(
196+
const DeviceContainer& devices, const std::vector<int64_t>& indices) {
197+
std::vector<std::string> filtered_devices_;
198+
filtered_devices_.reserve(indices.size());
199+
for (auto& index : indices) {
200+
for (auto& device : devices) {
201+
std::vector<std::string> device_spec_parts = absl::StrSplit(device, ':');
202+
if (std::stoi(device_spec_parts[1]) == index) {
203+
filtered_devices_.push_back(device);
204+
break;
205+
}
206+
}
207+
}
208+
return filtered_devices_;
209+
}
210+
180211
} // namespace util
181212
} // namespace runtime
182213
} // namespace torch_xla

torch_xla/csrc/tensor_util.cpp

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -837,39 +837,6 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
837837
runtime::GetComputationClientOrDie()->TransferToDevice(source_tensors));
838838
}
839839

840-
namespace {
841-
842-
/**
843-
* Filters a list of device strings to include only those with IDs matching
844-
* the provided indices.
845-
*
846-
* @param devices List of device strings in format "TYPE:ID" (e.g., "TPU:0")
847-
* @param indices List of device IDs to filter by
848-
* @return Filtered list of device strings, or error status if parsing fails
849-
*
850-
* Example:
851-
* devices = ["TPU:0", "TPU:1", "TPU:2", "TPU:3"]
852-
* indices = [1, 3]
853-
* result = ["TPU:1", "TPU:3"]
854-
*/
855-
std::vector<std::string> FilterDevicesByAddressableDevices(
856-
std::vector<std::string> devices, const std::vector<int64_t>& indices) {
857-
std::vector<std::string> filtered_devices_;
858-
filtered_devices_.reserve(indices.size());
859-
for (auto& index : indices) {
860-
for (auto& device : devices) {
861-
std::vector<std::string> device_spec_parts = absl::StrSplit(device, ':');
862-
if (std::stoi(device_spec_parts[1]) == index) {
863-
filtered_devices_.push_back(device);
864-
break;
865-
}
866-
}
867-
}
868-
return filtered_devices_;
869-
}
870-
871-
} // namespace
872-
873840
std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
874841
const std::vector<at::Tensor>& tensors,
875842
const std::vector<XLATensor::ShardingSpecPtr>& shardings,
@@ -899,7 +866,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
899866
if ((!denormalized_tile_assignment.empty()) &&
900867
(denormalized_tile_assignment.size() !=
901868
addressable_devices.size())) {
902-
addressable_devices = FilterDevicesByAddressableDevices(
869+
addressable_devices = torch_xla::runtime::util::FilterDevicesByAddressableDevices(
903870
addressable_devices, denormalized_tile_assignment);
904871
}
905872
}

0 commit comments

Comments
 (0)