Skip to content
Open
116 changes: 116 additions & 0 deletions src/Build.UnitTests/BackEnd/TaskHostNodeKey_Tests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.Build.Internal;
using Shouldly;
using Xunit;

namespace Microsoft.Build.Engine.UnitTests.BackEnd
{
/// <summary>
/// Tests for TaskHostNodeKey record struct functionality.
/// </summary>
public class TaskHostNodeKey_Tests
{
[Fact]
public void TaskHostNodeKey_Equality_SameValues_AreEqual()
{
var key1 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, 1);
var key2 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, 1);

key1.ShouldBe(key2);
(key1 == key2).ShouldBeTrue();
key1.GetHashCode().ShouldBe(key2.GetHashCode());
}

[Fact]
public void TaskHostNodeKey_Equality_DifferentNodeId_AreNotEqual()
{
var key1 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, 1);
var key2 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, 2);

key1.ShouldNotBe(key2);
(key1 != key2).ShouldBeTrue();
}

[Fact]
public void TaskHostNodeKey_Equality_DifferentHandshakeOptions_AreNotEqual()
{
var key1 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, 1);
var key2 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.X64, 1);

key1.ShouldNotBe(key2);
(key1 != key2).ShouldBeTrue();
}

[Fact]
public void TaskHostNodeKey_CanBeUsedAsDictionaryKey()
{
var dict = new System.Collections.Generic.Dictionary<TaskHostNodeKey, string>();
var key1 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, 1);
var key2 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.X64, 2);

dict[key1] = "value1";
dict[key2] = "value2";

dict[key1].ShouldBe("value1");
dict[key2].ShouldBe("value2");

// Create a new key with same values as key1
var key1Copy = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, 1);
dict[key1Copy].ShouldBe("value1");
}

[Fact]
public void TaskHostNodeKey_LargeNodeId_Works()
{
// Test that we can use node IDs greater than 255 (the previous limit)
var key1 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, 256);
var key2 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, 1000);
var key3 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, int.MaxValue);

key1.NodeId.ShouldBe(256);
key2.NodeId.ShouldBe(1000);
key3.NodeId.ShouldBe(int.MaxValue);

// Ensure they are all different
key1.ShouldNotBe(key2);
key2.ShouldNotBe(key3);
key1.ShouldNotBe(key3);
}

[Fact]
public void TaskHostNodeKey_NegativeNodeId_Works()
{
// Traditional multi-proc builds use -1 for node ID
var key = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, -1);

key.NodeId.ShouldBe(-1);
key.HandshakeOptions.ShouldBe(HandshakeOptions.TaskHost | HandshakeOptions.NET);
}

[Fact]
public void TaskHostNodeKey_AllHandshakeOptions_Work()
{
// Test various HandshakeOptions combinations
HandshakeOptions[] optionsList =
[
HandshakeOptions.None,
HandshakeOptions.TaskHost,
HandshakeOptions.TaskHost | HandshakeOptions.NET,
HandshakeOptions.TaskHost | HandshakeOptions.X64,
HandshakeOptions.TaskHost | HandshakeOptions.NET | HandshakeOptions.NodeReuse,
HandshakeOptions.TaskHost | HandshakeOptions.CLR2,
HandshakeOptions.TaskHost | HandshakeOptions.Arm64
];

foreach (var options in optionsList)
{
var key = new TaskHostNodeKey(options, 42);

key.HandshakeOptions.ShouldBe(options);
key.NodeId.ShouldBe(42);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,28 +81,43 @@ internal class NodeProviderOutOfProcTaskHost : NodeProviderOutOfProcBase, INodeP

/// <summary>
/// A mapping of all the task host nodes managed by this provider.
/// The key is a TaskHostNodeKey combining HandshakeOptions and scheduled node ID.
/// </summary>
private ConcurrentDictionary<int, NodeContext> _nodeContexts;
private ConcurrentDictionary<TaskHostNodeKey, NodeContext> _nodeContexts;

/// <summary>
/// Reverse mapping from communication node ID to TaskHostNodeKey.
/// Used for O(1) lookup when handling node termination from ShutdownAllNodes.
/// </summary>
private ConcurrentDictionary<int, TaskHostNodeKey> _nodeIdToNodeKey;

/// <summary>
/// A mapping of all of the INodePacketFactories wrapped by this provider.
/// Keyed by the communication node ID (NodeContext.NodeId) for O(1) packet routing.
/// Thread-safe to support parallel taskhost creation in /mt mode where multiple thread nodes
/// can simultaneously create their own taskhosts.
/// </summary>
private ConcurrentDictionary<int, INodePacketFactory> _nodeIdToPacketFactory;

/// <summary>
/// A mapping of all of the INodePacketHandlers wrapped by this provider.
/// Keyed by the communication node ID (NodeContext.NodeId) for O(1) packet routing.
/// Thread-safe to support parallel taskhost creation in /mt mode where multiple thread nodes
/// can simultaneously create their own taskhosts.
/// </summary>
private ConcurrentDictionary<int, INodePacketHandler> _nodeIdToPacketHandler;

/// <summary>
/// Keeps track of the set of nodes for which we have not yet received shutdown notification.
/// Keeps track of the set of node IDs for which we have not yet received shutdown notification.
/// </summary>
private HashSet<int> _activeNodes;

/// <summary>
/// Counter for generating unique communication node IDs.
/// Incremented atomically for each new node created.
/// </summary>
private int _nextNodeId;

/// <summary>
/// Packet factory we use if there's not already one associated with a particular context.
/// </summary>
Expand Down Expand Up @@ -169,12 +184,23 @@ public IList<NodeInfo> CreateNodes(int nextNodeId, INodePacketFactory packetFact

/// <summary>
/// Sends data to the specified node.
/// Note: For task hosts, use the overload that takes TaskHostNodeKey instead.
/// </summary>
/// <param name="nodeId">The node to which data shall be sent.</param>
/// <param name="packet">The packet to send.</param>
public void SendData(int nodeId, INodePacket packet)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this method needed for anything now? if no, remove

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SendData(int nodeId, INodePacket packet) method is required by the INodeProvider interface but not actually used for task hosts (task hosts use SendData(TaskHostNodeKey, ...) via TaskHostTask). Similar to TaskHostNodeManager.SendData, keeping it as NotImplementedException to satisfy the interface contract.

Copy link
Member

@JanProvaznik JanProvaznik Nov 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a bit sus, but I suppose it works, flagging this for reviewers

{
ErrorUtilities.VerifyThrow(_nodeContexts.TryGetValue(nodeId, out NodeContext context), "Invalid host context specified: {0}.", nodeId);
throw new NotImplementedException("For task hosts, use the overload that takes TaskHostNodeKey.");
}

/// <summary>
/// Sends data to the specified task host node.
/// </summary>
/// <param name="nodeKey">The task host node key identifying the target node.</param>
/// <param name="packet">The packet to send.</param>
internal void SendData(TaskHostNodeKey nodeKey, INodePacket packet)
{
ErrorUtilities.VerifyThrow(_nodeContexts.TryGetValue(nodeKey, out NodeContext context), "Invalid host context specified: {0}.", nodeKey);

SendData(context, packet);
}
Expand Down Expand Up @@ -211,10 +237,12 @@ public void ShutdownAllNodes()
public void InitializeComponent(IBuildComponentHost host)
{
this.ComponentHost = host;
_nodeContexts = new ConcurrentDictionary<int, NodeContext>();
_nodeContexts = new ConcurrentDictionary<TaskHostNodeKey, NodeContext>();
_nodeIdToNodeKey = new ConcurrentDictionary<int, TaskHostNodeKey>();
_nodeIdToPacketFactory = new ConcurrentDictionary<int, INodePacketFactory>();
_nodeIdToPacketHandler = new ConcurrentDictionary<int, INodePacketHandler>();
_activeNodes = new HashSet<int>();
_activeNodes = [];
_nextNodeId = 0;

_noNodesActiveEvent = new ManualResetEvent(true);
_localPacketFactory = new NodePacketFactory();
Expand Down Expand Up @@ -569,17 +597,16 @@ private static string GetPathFromEnvironmentOrDefault(string environmentVariable
/// Make sure a node in the requested context exists.
/// </summary>
internal bool AcquireAndSetUpHost(
HandshakeOptions hostContext,
int taskHostNodeId,
TaskHostNodeKey nodeKey,
INodePacketFactory factory,
INodePacketHandler handler,
TaskHostConfiguration configuration,
in TaskHostParameters taskHostParameters)
{
bool nodeCreationSucceeded;
if (!_nodeContexts.ContainsKey(taskHostNodeId))
if (!_nodeContexts.ContainsKey(nodeKey))
{
nodeCreationSucceeded = CreateNode(hostContext, taskHostNodeId, factory, handler, configuration, taskHostParameters);
nodeCreationSucceeded = CreateNode(nodeKey, factory, handler, configuration, taskHostParameters);
}
else
{
Expand All @@ -589,9 +616,10 @@ internal bool AcquireAndSetUpHost(

if (nodeCreationSucceeded)
{
NodeContext context = _nodeContexts[taskHostNodeId];
_nodeIdToPacketFactory[taskHostNodeId] = factory;
_nodeIdToPacketHandler[taskHostNodeId] = handler;
NodeContext context = _nodeContexts[nodeKey];
// Map the transport ID directly to the handlers for O(1) packet routing
_nodeIdToPacketFactory[context.NodeId] = factory;
_nodeIdToPacketHandler[context.NodeId] = handler;

// Configure the node.
context.SendData(configuration);
Expand All @@ -604,25 +632,35 @@ internal bool AcquireAndSetUpHost(
/// <summary>
/// Expected to be called when TaskHostTask is done with host of the given context.
/// </summary>
internal void DisconnectFromHost(int nodeId)
internal void DisconnectFromHost(TaskHostNodeKey nodeKey)
{
bool successRemoveFactory = _nodeIdToPacketFactory.TryRemove(nodeId, out _);
bool successRemoveHandler = _nodeIdToPacketHandler.TryRemove(nodeId, out _);
ErrorUtilities.VerifyThrow(_nodeContexts.TryGetValue(nodeKey, out NodeContext context), "Node context not found for key: {0}. Was the node created?", nodeKey);

bool successRemoveFactory = _nodeIdToPacketFactory.TryRemove(context.NodeId, out _);
bool successRemoveHandler = _nodeIdToPacketHandler.TryRemove(context.NodeId, out _);

ErrorUtilities.VerifyThrow(successRemoveFactory && successRemoveHandler, "Why are we trying to disconnect from a context that we already disconnected from? Did we call DisconnectFromHost twice?");
}

/// <summary>
/// Instantiates a new MSBuild or MSBuildTaskHost process acting as a child node.
/// </summary>
internal bool CreateNode(HandshakeOptions hostContext, int taskHostNodeId, INodePacketFactory factory, INodePacketHandler handler, TaskHostConfiguration configuration, in TaskHostParameters taskHostParameters)
internal bool CreateNode(TaskHostNodeKey nodeKey, INodePacketFactory factory, INodePacketHandler handler, TaskHostConfiguration configuration, in TaskHostParameters taskHostParameters)
{
ErrorUtilities.VerifyThrowArgumentNull(factory);
ErrorUtilities.VerifyThrow(!_nodeIdToPacketFactory.ContainsKey(taskHostNodeId), "We should not already have a factory for this context! Did we forget to call DisconnectFromHost somewhere?");
ErrorUtilities.VerifyThrow(!_nodeContexts.ContainsKey(nodeKey), "We should not already have a node for this context! Did we forget to call DisconnectFromHost somewhere?");

HandshakeOptions hostContext = nodeKey.HandshakeOptions;

// If runtime host path is null it means we don't have MSBuild.dll path resolved and there is no need to include it in the command line arguments.
string commandLineArgsPlaceholder = "\"{0}\" /nologo /nodemode:2 /nodereuse:{1} /low:{2} ";

// Generate a unique node ID for communication purposes using atomic increment.
int communicationNodeId = Interlocked.Increment(ref _nextNodeId);

// Create callbacks that capture the TaskHostNodeKey
void OnNodeContextCreated(NodeContext context) => NodeContextCreated(context, nodeKey);

IList<NodeContext> nodeContexts;

// Handle .NET task host context
Expand All @@ -639,10 +677,10 @@ internal bool CreateNode(HandshakeOptions hostContext, int taskHostNodeId, INode
nodeContexts = GetNodes(
runtimeHostPath,
string.Format(commandLineArgsPlaceholder, Path.Combine(msbuildAssemblyPath, Constants.MSBuildAssemblyName), NodeReuseIsEnabled(hostContext), ComponentHost.BuildParameters.LowPriority),
taskHostNodeId,
communicationNodeId,
this,
handshake,
NodeContextCreated,
OnNodeContextCreated,
NodeContextTerminated,
1);

Expand All @@ -663,10 +701,10 @@ internal bool CreateNode(HandshakeOptions hostContext, int taskHostNodeId, INode
nodeContexts = GetNodes(
msbuildLocation,
string.Format(commandLineArgsPlaceholder, string.Empty, NodeReuseIsEnabled(hostContext), ComponentHost.BuildParameters.LowPriority),
taskHostNodeId,
communicationNodeId,
this,
new Handshake(hostContext),
NodeContextCreated,
OnNodeContextCreated,
NodeContextTerminated,
1);

Expand All @@ -687,9 +725,10 @@ bool NodeReuseIsEnabled(HandshakeOptions hostContext)
/// <summary>
/// Method called when a context created.
/// </summary>
private void NodeContextCreated(NodeContext context)
private void NodeContextCreated(NodeContext context, TaskHostNodeKey nodeKey)
{
_nodeContexts[context.NodeId] = context;
_nodeContexts[nodeKey] = context;
_nodeIdToNodeKey[context.NodeId] = nodeKey;

// Start the asynchronous read.
context.BeginAsyncPacketRead();
Expand All @@ -702,19 +741,20 @@ private void NodeContextCreated(NodeContext context)
}

/// <summary>
/// Method called when a context terminates.
/// Method called when a context terminates (called from CreateNode callbacks or ShutdownAllNodes).
/// </summary>
private void NodeContextTerminated(int nodeId)
{
_nodeContexts.TryRemove(nodeId, out _);
// Remove from nodeKey-based lookup if we have it
if (_nodeIdToNodeKey.TryRemove(nodeId, out TaskHostNodeKey nodeKey))
{
_nodeContexts.TryRemove(nodeKey, out _);
}

// May also be removed by unnatural termination, so don't assume it's there
lock (_activeNodes)
{
if (_activeNodes.Contains(nodeId))
{
_activeNodes.Remove(nodeId);
}
_activeNodes.Remove(nodeId);

if (_activeNodes.Count == 0)
{
Expand Down
Loading
Loading