Skip to content

Commit 21ad6c4

Browse files
kvshbg-awsroot
authored andcommitted
fix: remove duplicate function
1 parent 0e229b5 commit 21ad6c4

File tree

3 files changed

+36
-64
lines changed

3 files changed

+36
-64
lines changed

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -869,44 +869,18 @@ PjRtComputationClient::ExecuteComputation(
869869
return datas;
870870
}
871871

872-
namespace {
873-
874-
/**
875-
* Filters a list of device strings to include only those with IDs matching
876-
* the provided indices.
877-
*
878-
* @param devices List of device strings in format "TYPE:ID" (e.g., "TPU:0")
879-
* @param indices List of device IDs to filter by
880-
* @return Filtered list of device strings, or error status if parsing fails
881-
*
882-
* Example:
883-
* devices = ["TPU:0", "TPU:1", "TPU:2", "TPU:3"]
884-
* indices = [1, 3]
885-
* result = ["TPU:1", "TPU:3"]
886-
*/
872+
// wrapped function to handle absl::Span instead of std::vector
887873
absl::Span<const std::string> FilterDevicesByAddressableDevices(
888874
absl::Span<const std::string> devices,
889875
const std::vector<int64_t>& indices) {
890876
static std::vector<std::string> filtered_devices_;
891877
filtered_devices_.clear();
892878
filtered_devices_.reserve(indices.size());
893-
for (auto& index : indices) {
894-
for (auto& device : devices) {
895-
std::vector<std::string> device_spec_parts = absl::StrSplit(device, ':');
896-
if ((std::stoi(device_spec_parts[1]) == index) &&
897-
(std::find(filtered_devices_.begin(), filtered_devices_.end(),
898-
device) == filtered_devices_.end())) {
899-
filtered_devices_.push_back(device);
900-
break;
901-
}
902-
}
903-
}
904-
// Return a span that points to our filtered data
905-
return absl::Span<const std::string>(filtered_devices_);
879+
filtered_devices_ = torch_xla::runtime::util::FilterDevicesByAddressableDevices(
880+
devices, indices);
881+
return absl::MakeConstSpan(filtered_devices_);
906882
}
907883

908-
} // namespace
909-
910884
std::vector<ComputationClient::DataPtr>
911885
PjRtComputationClient::ExecuteReplicated(
912886
const ComputationClient::Computation& computation,

torch_xla/csrc/runtime/util.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <vector>
1616

1717
#include "absl/status/statusor.h"
18+
#include "absl/strings/str_split.h"
1819
#include "absl/types/optional.h"
1920
#include "absl/types/span.h"
2021
#include "torch_xla/csrc/runtime/types.h"
@@ -177,6 +178,36 @@ RaisePythonValueErrorOnFailure(const Func& func) {
177178
throw std::invalid_argument(std::string(result.status().message()));
178179
}
179180

181+
/**
182+
* Filters a list of device strings to include only those with IDs matching
183+
* the provided indices.
184+
*
185+
* @param devices List of device strings in format "TYPE:ID" (e.g., "TPU:0")
186+
* @param indices List of device IDs to filter by
187+
* @return Filtered list of device strings
188+
*
189+
* Example:
190+
* devices = ["TPU:0", "TPU:1", "TPU:2", "TPU:3"]
191+
* indices = [1, 3]
192+
* result = ["TPU:1", "TPU:3"]
193+
*/
194+
template<typename DeviceContainer>
195+
std::vector<std::string> FilterDevicesByAddressableDevices(
196+
const DeviceContainer& devices, const std::vector<int64_t>& indices) {
197+
std::vector<std::string> filtered_devices_;
198+
filtered_devices_.reserve(indices.size());
199+
for (auto& index : indices) {
200+
for (auto& device : devices) {
201+
std::vector<std::string> device_spec_parts = absl::StrSplit(device, ':');
202+
if (std::stoi(device_spec_parts[1]) == index) {
203+
filtered_devices_.push_back(device);
204+
break;
205+
}
206+
}
207+
}
208+
return filtered_devices_;
209+
}
210+
180211
} // namespace util
181212
} // namespace runtime
182213
} // namespace torch_xla

torch_xla/csrc/tensor_util.cpp

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -838,39 +838,6 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
838838
runtime::GetComputationClientOrDie()->TransferToDevice(source_tensors));
839839
}
840840

841-
namespace {
842-
843-
/**
844-
* Filters a list of device strings to include only those with IDs matching
845-
* the provided indices.
846-
*
847-
* @param devices List of device strings in format "TYPE:ID" (e.g., "TPU:0")
848-
* @param indices List of device IDs to filter by
849-
* @return Filtered list of device strings, or error status if parsing fails
850-
*
851-
* Example:
852-
* devices = ["TPU:0", "TPU:1", "TPU:2", "TPU:3"]
853-
* indices = [1, 3]
854-
* result = ["TPU:1", "TPU:3"]
855-
*/
856-
std::vector<std::string> FilterDevicesByAddressableDevices(
857-
std::vector<std::string> devices, const std::vector<int64_t>& indices) {
858-
std::vector<std::string> filtered_devices_;
859-
filtered_devices_.reserve(indices.size());
860-
for (auto& index : indices) {
861-
for (auto& device : devices) {
862-
std::vector<std::string> device_spec_parts = absl::StrSplit(device, ':');
863-
if (std::stoi(device_spec_parts[1]) == index) {
864-
filtered_devices_.push_back(device);
865-
break;
866-
}
867-
}
868-
}
869-
return filtered_devices_;
870-
}
871-
872-
} // namespace
873-
874841
std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
875842
const std::vector<at::Tensor>& tensors,
876843
const std::vector<XLATensor::ShardingSpecPtr>& shardings,
@@ -900,7 +867,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
900867
if ((!denormalized_tile_assignment.empty()) &&
901868
(denormalized_tile_assignment.size() !=
902869
addressable_devices.size())) {
903-
addressable_devices = FilterDevicesByAddressableDevices(
870+
addressable_devices = torch_xla::runtime::util::FilterDevicesByAddressableDevices(
904871
addressable_devices, denormalized_tile_assignment);
905872
}
906873
}

0 commit comments

Comments
 (0)