9 changes: 0 additions & 9 deletions sycl/source/detail/graph_impl.cpp
@@ -73,14 +73,6 @@ bool check_for_arg(const sycl::detail::ArgDesc &Arg,
}
} // anonymous namespace

void exec_graph_impl::schedule() {
if (MSchedule.empty()) {
for (auto Node : MGraphImpl->MRoots) {
Node->topology_sort(Node, MSchedule);
}
}
}

std::shared_ptr<node_impl> graph_impl::add_subgraph_nodes(
const std::list<std::shared_ptr<node_impl>> &NodeList) {
// Find all input and output nodes from the node list
@@ -564,7 +556,6 @@ command_graph<graph_state::executable>::command_graph(

void command_graph<graph_state::executable>::finalize_impl() {
// Create PI command-buffers for each device in the finalized context
impl->schedule();

auto Context = impl->get_context();
for (auto Device : impl->get_context().get_devices()) {
104 changes: 82 additions & 22 deletions sycl/source/detail/graph_impl.hpp
@@ -19,6 +19,8 @@
#include <functional>
#include <list>
#include <set>
#include <algorithm>
#include <map>
#include <optional>
#include <vector>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
@@ -73,22 +75,28 @@ class node_impl {
std::unique_ptr<sycl::detail::CG> &&CommandGroup)
: MCGType(CGType), MCommandGroup(std::move(CommandGroup)) {}

/// Recursively add nodes to execution stack.
/// @param NodeImpl Node to schedule.
/// @param Schedule Execution ordering to add node to.
void topology_sort(std::shared_ptr<node_impl> NodeImpl,
std::list<std::shared_ptr<node_impl>> &Schedule) {
for (auto Next : MSuccessors) {
// Check if we've already scheduled this node
if (std::find(Schedule.begin(), Schedule.end(), Next) == Schedule.end())
Next->topology_sort(Next, Schedule);
}
// We don't need to schedule empty nodes as they are only used when
// calculating dependencies
if (!NodeImpl->is_empty())
Schedule.push_front(NodeImpl);
}
private:
/// Depth of this node in the containing graph.
///
/// The first traversal in graph_impl::exec_order_recompute() computes and
/// caches the value via get_depth(). The cached value may become stale
/// whenever the containing graph is modified, and a single value is
/// inadequate if this node is added to multiple graphs. Caching is risky,
/// but recomputing takes O(graph size) time in the worst case.
std::optional<int> MDepth;

public:
/// Computes and caches the depth of this node.
/// @return Length of the longest dependency chain from any root node to
/// this node.
int get_depth() {
if (!MDepth.has_value()) {
int MaxPredDepth = -1;
for (const auto &P : MPredecessors) {
MaxPredDepth = std::max(MaxPredDepth, P.lock()->get_depth());
}
MDepth = MaxPredDepth + 1;
}
return MDepth.value();
}

/// Checks if this node has an argument.
/// @param Arg Argument to lookup.
/// @return True if \p Arg is used in node, false otherwise.
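As a side note, the depth rule cached by node_impl::get_depth() above can be shown with a small standalone sketch; the ToyNode type below is hypothetical and only mirrors the MPredecessors/MDepth shape, it is not the real node_impl API. Depth is 0 for a root node and 1 plus the maximum predecessor depth otherwise, memoised on first use:

#include <algorithm>
#include <cassert>
#include <optional>
#include <vector>

struct ToyNode {
  std::vector<ToyNode *> Predecessors; // edges point towards the roots
  std::optional<int> CachedDepth;      // memoised, like node_impl::MDepth

  int depth() {
    if (!CachedDepth.has_value()) {
      int MaxPred = -1;
      for (ToyNode *P : Predecessors)
        MaxPred = std::max(MaxPred, P->depth());
      CachedDepth = MaxPred + 1; // a root ends up with depth 0
    }
    return *CachedDepth;
  }
};

int main() {
  // Diamond graph: A -> {B, C} -> D, so D sits on a chain of length 2.
  ToyNode A, B{{&A}}, C{{&A}}, D{{&B, &C}};
  assert(A.depth() == 0 && B.depth() == 1 && D.depth() == 2);
}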
@@ -197,6 +205,63 @@ class graph_impl {
: MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
MEventsMap() {}

private:
/// A cache of pointers to the exit nodes of the graph.
///
/// Not used yet. It may be preferable to drive the exec_order_recompute()
/// traversal from each exit node upwards through MPredecessors, rather than
/// depth-first from each root node towards the exit nodes.
std::vector<node_impl *> MExitNodes;

/// A sorted multimap capturing the intended execution/submission order.
///
/// The key is the depth in the graph of the node_impl stored as the value,
/// where depth is the length of the longest dependency chain back to any
/// root node. Iterating the multimap therefore visits shallower nodes first.
std::multimap<int, std::shared_ptr<node_impl>> MExecOrder;

/// Depth-first recursion from \p Node to build the execution order.
/// @param Node Starting node for the depth-first recursion.
void exec_order_recompute(const std::shared_ptr<node_impl> &Node) {
// Skip nodes that have already been visited through another path.
auto Range = MExecOrder.equal_range(Node->get_depth());
for (auto It = Range.first; It != Range.second; ++It) {
if (It->second == Node)
return;
}
// Depth-first recursion to visit all nodes that succeed this node.
for (auto &Succ : Node->MSuccessors) {
exec_order_recompute(Succ);
}
// Insert this node into the execution order keyed by its depth in the graph.
MExecOrder.insert({Node->get_depth(), Node});
// Cache the exit nodes; not used yet, kept for a future bottom-up traversal.
if (Node->MSuccessors.empty()) {
MExitNodes.push_back(Node.get());
}
}

/// Recomputes the submission/execution order for the whole graph.
void exec_order_recompute() {
MExecOrder.clear();
MExitNodes.clear();
// For all root nodes, recurse towards the exit nodes.
for (auto &Root : MRoots) {
exec_order_recompute(Root);
}
}

public:
/// Recomputes the submission/execution order and builds the node schedule.
/// @return Nodes in the order they should be submitted, shallower nodes
/// first.
std::list<std::shared_ptr<node_impl>> compute_schedule() {
exec_order_recompute();
std::list<std::shared_ptr<node_impl>> Schedule;
for (auto &Next : MExecOrder) {
// Empty nodes are only used when calculating dependencies and do not
// need to be scheduled.
if (!Next.second->is_empty())
Schedule.push_back(Next.second);
}
return Schedule;
}

/// Insert node into list of root nodes.
/// @param Root Node to add to list of root nodes.
void add_root(const std::shared_ptr<node_impl> &Root);
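For reference, here is a minimal standalone sketch of why iterating a depth-keyed std::multimap and appending with push_back yields a root-first schedule, which is what compute_schedule() above relies on; the names below are illustrative only and not part of the graph_impl API:

#include <iostream>
#include <list>
#include <map>
#include <string>

int main() {
  // Depth -> node name, with depths as node_impl::get_depth() would report.
  std::multimap<int, std::string> ExecOrder{
      {1, "mid"}, {0, "rootA"}, {2, "exit"}, {0, "rootB"}};

  // std::multimap iterates keys in ascending order, so depth 0 comes first;
  // push_back preserves that order, placing roots before their dependents.
  std::list<std::string> Schedule;
  for (const auto &Entry : ExecOrder)
    Schedule.push_back(Entry.second);

  for (const auto &Name : Schedule)
    std::cout << Name << '\n'; // prints: rootA rootB mid exit
}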
@@ -314,17 +379,15 @@ class exec_graph_impl {
/// @param GraphImpl Modifiable graph implementation to create with.
exec_graph_impl(sycl::context Context,
const std::shared_ptr<graph_impl> &GraphImpl)
: MSchedule(), MGraphImpl(GraphImpl), MPiCommandBuffers(),
: MSchedule(GraphImpl->compute_schedule()),
MPiCommandBuffers(),
MPiSyncPoints(), MContext(Context) {}

/// Destructor.
///
/// Releases any PI command-buffers the object has created.
~exec_graph_impl();

/// Add nodes to MSchedule.
void schedule();

/// Enqueues the backend objects for the graph to the parametrized queue.
/// @param Queue Command-queue to submit backend objects to.
/// @return Event associated with enqueued object.
@@ -384,9 +447,6 @@ class exec_graph_impl {

/// Execution schedule of nodes in the graph.
std::list<std::shared_ptr<node_impl>> MSchedule;
/// Pointer to the modifiable graph impl associated with this executable
/// graph.
std::shared_ptr<graph_impl> MGraphImpl;
/// Map of devices to command buffers.
std::unordered_map<sycl::device, RT::PiExtCommandBuffer> MPiCommandBuffers;
/// Map of nodes in the exec graph to the sync point representing their
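Usage-wise, the schedule is now produced when the executable graph is constructed instead of by a separate schedule() pass in finalize_impl(). Below is a minimal sketch of the user-facing flow that reaches that constructor, assuming the sycl_ext_oneapi_graph experimental API; the exact spelling of the extension entry points may differ between versions of the extension:

#include <sycl/sycl.hpp>

namespace sycl_ext = sycl::ext::oneapi::experimental;

int main() {
  sycl::queue Queue;
  sycl_ext::command_graph<sycl_ext::graph_state::modifiable> Graph(
      Queue.get_context(), Queue.get_device());

  // One explicitly added node; it becomes a root with depth 0.
  Graph.add([&](sycl::handler &CGH) { CGH.single_task([]() {}); });

  // finalize() constructs exec_graph_impl; with this change its constructor
  // initialises MSchedule via graph_impl::compute_schedule().
  auto ExecGraph = Graph.finalize();

  Queue.ext_oneapi_graph(ExecGraph);
  Queue.wait();
}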