Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 169 additions & 0 deletions SnaffCore/Concurrency/BlockingTaskScheduler.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Numerics;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Timers;
using static SnaffCore.Config.Options;

namespace SnaffCore.Concurrency
{
Expand Down Expand Up @@ -34,6 +39,8 @@ public bool Done()
Scheduler.RecalculateCounters();
TaskCounters taskCounters = Scheduler.GetTaskCounters();

Console.WriteLine($"Checking if done - queued: {taskCounters.CurrentTasksQueued}, done: {taskCounters.CurrentTasksRunning}");

if ((taskCounters.CurrentTasksQueued + taskCounters.CurrentTasksRunning == 0))
{
return true;
Expand Down Expand Up @@ -68,6 +75,168 @@ public void New(Action action)
}
}

public enum TaskFileType
{
None = 0,
Share = 1,
Tree = 2,
File = 3
}

public enum TaskFileEntryStatus
{
Pending = 0,
Completed = 1,
}

public struct TaskFileEntry
{
public TaskFileEntryStatus status;
public string guid;
public TaskFileType type;
public string input;

public override string ToString()
{
StringBuilder stringBuilder = new StringBuilder();

stringBuilder.Append(status.ToString());
stringBuilder.Append("|");
stringBuilder.Append(guid);
stringBuilder.Append("|");
stringBuilder.Append(type.ToString());
stringBuilder.Append("|");
stringBuilder.Append(input);

return stringBuilder.ToString();
}

public TaskFileEntry(TaskFileType type, string input)
{
guid = Guid.NewGuid().ToString();
status = TaskFileEntryStatus.Pending;
this.type = type;
this.input = input;
}

public TaskFileEntry(string entryLine)
{
string[] lineParts = entryLine.Split('|');

status = (TaskFileEntryStatus)Enum.Parse(typeof(TaskFileEntryStatus), lineParts[0]);
guid = lineParts[1];

type = (TaskFileType)Enum.Parse(typeof(TaskFileType), lineParts[2]);
input = lineParts[3];
}
}

public class ResumingTaskScheduler : BlockingStaticTaskScheduler
{
private static readonly Dictionary<string, Tuple<string, string>> pendingTasks = new Dictionary<string, Tuple<string, string>>();
private static readonly object writeLock = new object();
private static int pendingSaveCalls = 0;

internal BlockingMq Mq { get; }

public ResumingTaskScheduler(int threads, int maxBacklog) : base(threads, maxBacklog)
{
this.Mq = BlockingMq.GetMq();
}

public void New(string taskType, Action<string> action, string path)
{
string guid = null;

if (MyOptions.TaskFile != null)
{
guid = Guid.NewGuid().ToString();
pendingTasks.Add(guid, new Tuple<string, string>(taskType, path));
}

New(() =>
{
try
{
action(path);
}
catch (Exception e)
{
Mq.Error("Exception in " + taskType.ToString() + " task for host " + path);
Mq.Error(e.ToString());
}

if (guid != null) pendingTasks.Remove(guid);
});
}

public static void SaveState(object sender, ElapsedEventArgs e)
{
SaveState();
}

public static void SaveState()
{
// Guard against the possibility that someone forgot to check this
if (MyOptions.TaskFile == null) return;

// This blocks more than one save call from being buffered at a time
// Prevents a situation where a bunch of buffered calls wait for the lock
// But still allows for you to "write continously" if you set an interval shorter than the file write time
if (pendingSaveCalls > 1) return;
pendingSaveCalls++;

// In case the file takes longer to write than the set interval
lock (writeLock)
{
using (StreamWriter fileWriter = new StreamWriter(MyOptions.TaskFile, false))
{
// Copy the values into the array to avoid an error in case the pending tasks are changed during the write loop
Tuple<string, string>[] valuesSnapshot = pendingTasks.Values.ToArray();

foreach (Tuple<string, string> value in valuesSnapshot)
{
fileWriter.WriteLine($"{value.Item1}|{value.Item2}");
}

fileWriter.Flush();
}

pendingSaveCalls--;
}
}
}

public class ShareTaskScheduler : ResumingTaskScheduler
{
public ShareTaskScheduler(int threads, int maxBacklog) : base(threads, maxBacklog) { }

public void New(Action<string> action, string share)
{
New("share", action, share);
}
}

public class TreeTaskScheduler : ResumingTaskScheduler
{
public TreeTaskScheduler(int threads, int maxBacklog) : base(threads, maxBacklog) { }

public void New(Action<string> action, string tree)
{
New("tree", action, tree);
}
}

public class FileTaskScheduler : ResumingTaskScheduler
{
public FileTaskScheduler(int threads, int maxBacklog) : base(threads, maxBacklog) { }

public void New(Action<string> action, string file)
{
New("file", action, file);
}
}

public class TaskCounters
{
public BigInteger TotalTasksQueued { get; set; }
Expand Down
5 changes: 5 additions & 0 deletions SnaffCore/Config/Options.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ public partial class Options
{
public static Options MyOptions { get; set; }

// Pause and resume functionality
public string TaskFile { get; set; }
public double TaskFileTimeOut { get; set; } = 5;
public string ResumeFrom { get; set; }

// Manual Targeting Options
public List<string> PathTargets { get; set; } = new List<string>();
public string[] ComputerTargets { get; set; }
Expand Down
15 changes: 2 additions & 13 deletions SnaffCore/ShareFind/ShareFinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace SnaffCore.ShareFind
public class ShareFinder
{
private BlockingMq Mq { get; set; }
private BlockingStaticTaskScheduler TreeTaskScheduler { get; set; }
private TreeTaskScheduler TreeTaskScheduler { get; set; }
private TreeWalker TreeWalker { get; set; }
//private EffectivePermissions effectivePermissions { get; set; } = new EffectivePermissions(MyOptions.CurrentUser);

Expand Down Expand Up @@ -184,18 +184,7 @@ internal void GetComputerShares(string computer)
if (MyOptions.ScanFoundShares)
{
Mq.Trace("Creating a TreeWalker task for " + shareResult.SharePath);
TreeTaskScheduler.New(() =>
{
try
{
TreeWalker.WalkTree(shareResult.SharePath);
}
catch (Exception e)
{
Mq.Error("Exception in TreeWalker task for share " + shareResult.SharePath);
Mq.Error(e.ToString());
}
});
TreeTaskScheduler.New(TreeWalker.WalkTree, shareResult.SharePath);
}
Mq.ShareResult(shareResult);
}
Expand Down
Loading