Skip to content

Commit

Permalink
Support parallel execution of tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
twenzel committed Sep 1, 2022
1 parent 51e73fe commit 06d0491
Show file tree
Hide file tree
Showing 11 changed files with 269 additions and 6 deletions.
112 changes: 112 additions & 0 deletions src/Cake.Core.Tests/Unit/CakeEngineTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1558,6 +1558,118 @@ public async Task Should_Return_Report_That_Marks_Failed_Tasks_As_Failed()
Assert.Equal(CakeTaskExecutionStatus.Delegated, report.First(e => e.TaskName == "A").ExecutionStatus);
Assert.Equal(CakeTaskExecutionStatus.Failed, report.First(e => e.TaskName == "B").ExecutionStatus);
}

[Fact]
public async Task Should_Throw_Exception_For_Circular_Dependencies()
{
// Given
var fixture = new CakeEngineFixture();
var settings = new ExecutionSettings().SetTarget("B");
var engine = fixture.CreateEngine();
engine.RegisterTask("B").IsDependentOn("C");
engine.RegisterTask("C").IsDependentOn("D");
engine.RegisterTask("D").IsDependentOn("B");

// When
var result = await Record.ExceptionAsync(() => engine.RunTargetAsync(fixture.Context, fixture.ExecutionStrategy, settings));

// Then
Assert.IsType<CakeException>(result);
Assert.Equal("Graph contains circular references.", result?.Message);
}

[Fact]
public async Task Should_Throw_Exception_For_Circular_Dependencies_In_Parallel()
{
// Given
var fixture = new CakeEngineFixture();
var settings = new ExecutionSettings().SetTarget("B").RunInParallel();
var engine = fixture.CreateEngine();
engine.RegisterTask("B").IsDependentOn("C");
engine.RegisterTask("C").IsDependentOn("D");
engine.RegisterTask("D").IsDependentOn("B");

// When
var result = await Record.ExceptionAsync(() => engine.RunTargetAsync(fixture.Context, fixture.ExecutionStrategy, settings));

// Then
Assert.IsType<CakeException>(result);
Assert.Equal("Graph contains circular references.", result?.Message);
}

[Fact]
public async Task Should_Execute_Tasks_In_Order_In_Parallel()
{
// Given
var result = new List<string>();
var fixture = new CakeEngineFixture();
var settings = new ExecutionSettings().SetTarget("E").RunInParallel();
var engine = fixture.CreateEngine();
engine.RegisterTask("A").Does(() => result.Add("A"));
engine.RegisterTask("B").IsDependentOn("A").Does(() => result.Add("B"));
engine.RegisterTask("C").IsDependentOn("B").Does(() => result.Add("C"));
engine.RegisterTask("D").IsDependentOn("C").IsDependeeOf("E").Does(() => { result.Add("D"); });
engine.RegisterTask("E").Does(() => { result.Add("E"); });

// When
await engine.RunTargetAsync(fixture.Context, fixture.ExecutionStrategy, settings);

// Then
Assert.Equal(5, result.Count);
Assert.Equal("A", result[0]);
Assert.Equal("B", result[1]);
Assert.Equal("C", result[2]);
Assert.Equal("D", result[3]);
Assert.Equal("E", result[4]);
}

[Fact]
public async Task Should_Execute_Tasks_In_Parallel()
{
// Given
var result = new List<string>();
var fixture = new CakeEngineFixture();
var settings = new ExecutionSettings().SetTarget("E").RunInParallel();
var engine = fixture.CreateEngine();
engine.RegisterTask("A").Does(() => result.Add("A"));
engine.RegisterTask("B").IsDependentOn("A").Does(async () => { await Task.Delay(20); result.Add("B"); });
engine.RegisterTask("C").IsDependentOn("A").Does(async () => { await Task.Delay(5); result.Add("C"); });
engine.RegisterTask("D").IsDependentOn("A").Does(() => result.Add("D"));
engine.RegisterTask("E").IsDependentOn("B").IsDependentOn("C").IsDependentOn("D").Does(() => result.Add("E"));

// When
await engine.RunTargetAsync(fixture.Context, fixture.ExecutionStrategy, settings);

// Then
Assert.Equal(5, result.Count);
Assert.Equal("A", result[0]);
Assert.Equal("D", result[1]);
Assert.Equal("C", result[2]);
Assert.Equal("B", result[3]);
Assert.Equal("E", result[4]);
}

[Fact]
public async Task Should_Not_Catch_Exceptions_From_Task_If_ContinueOnError_Is_Not_Set_In_Parallel()
{
// Given
var fixture = new CakeEngineFixture();
var settings = new ExecutionSettings().SetTarget("E").RunInParallel();
var engine = fixture.CreateEngine();
engine.RegisterTask("A");
engine.RegisterTask("B").IsDependentOn("A");
engine.RegisterTask("C").IsDependentOn("A").Does(() => throw new InvalidOperationException("Whoopsie"));
engine.RegisterTask("D").IsDependentOn("A");
engine.RegisterTask("E").IsDependentOn("B").IsDependentOn("C").IsDependentOn("D");

// When
var result = await Record.ExceptionAsync(() =>
engine.RunTargetAsync(fixture.Context, fixture.ExecutionStrategy, settings));

// Then
Assert.IsType<InvalidOperationException>(result);
Assert.Equal("Whoopsie", result?.Message);
}
}

public sealed class TheSetupEvent
Expand Down
47 changes: 42 additions & 5 deletions src/Cake.Core/CakeEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Globalization;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Cake.Core.Diagnostics;
using Cake.Core.Graph;
Expand Down Expand Up @@ -211,19 +212,48 @@ public async Task<CakeReport> RunTargetAsync(ICakeContext context, IExecutionStr
{
// Execute only the target task.
var task = _tasks.FirstOrDefault(x => x.Name.Equals(settings.Target, StringComparison.OrdinalIgnoreCase));
await RunTask(context, strategy, task, target, stopWatch, report);
await RunTask(context, strategy, task, target, stopWatch, report, null);
}
else if (settings.Parallel)
{
await graph.TraverseAsync(target, async (taskName, cancellationTokenSource) =>
{
if (cancellationTokenSource.IsCancellationRequested)
{
return;
}
var task = _tasks.FirstOrDefault(_ => _.Name.Equals(taskName, StringComparison.OrdinalIgnoreCase));
Debug.Assert(task != null, "Node should not be null");
var isTarget = task.Name.Equals(target, StringComparison.OrdinalIgnoreCase);
await RunTask(context, strategy, task, target, stopWatch, report, cancellationTokenSource);
});
}
else
{
// Execute all scheduled tasks.
foreach (var task in orderedTasks)
{
await RunTask(context, strategy, task, target, stopWatch, report);
await RunTask(context, strategy, task, target, stopWatch, report, null);
}
}

return report;
}
catch (TaskCanceledException)
{
exceptionWasThrown = true;
throw;
}
catch (AggregateException ex)
{
exceptionWasThrown = true;
thrownException = ex.InnerException;

throw ex.GetBaseException();
}
catch (Exception ex)
{
exceptionWasThrown = true;
Expand All @@ -236,7 +266,7 @@ public async Task<CakeReport> RunTargetAsync(ICakeContext context, IExecutionStr
}
}

private async Task RunTask(ICakeContext context, IExecutionStrategy strategy, CakeTask task, string target, Stopwatch stopWatch, CakeReport report)
private async Task RunTask(ICakeContext context, IExecutionStrategy strategy, CakeTask task, string target, Stopwatch stopWatch, CakeReport report, CancellationTokenSource cancellationTokenSource)
{
// Is this the current target?
var isTarget = task.Name.Equals(target, StringComparison.OrdinalIgnoreCase);
Expand All @@ -255,7 +285,7 @@ private async Task RunTask(ICakeContext context, IExecutionStrategy strategy, Ca

if (!skipped)
{
await ExecuteTaskAsync(context, strategy, stopWatch, task, report).ConfigureAwait(false);
await ExecuteTaskAsync(context, strategy, stopWatch, task, report, cancellationTokenSource).ConfigureAwait(false);
}
}

Expand Down Expand Up @@ -309,7 +339,7 @@ private static bool ShouldTaskExecute(ICakeContext context, CakeTask task, CakeT
}

private async Task ExecuteTaskAsync(ICakeContext context, IExecutionStrategy strategy, Stopwatch stopWatch,
CakeTask task, CakeReport report)
CakeTask task, CakeReport report, CancellationTokenSource cancellationTokenSource)
{
stopWatch.Restart();

Expand All @@ -321,6 +351,11 @@ private async Task ExecuteTaskAsync(ICakeContext context, IExecutionStrategy str
// Execute the task.
await strategy.ExecuteAsync(task, context).ConfigureAwait(false);
}
catch (TaskCanceledException exception)
{
taskException = exception;
throw;
}
catch (Exception exception)
{
_log.Error("An error occurred when executing task '{0}'.", task.Name);
Expand All @@ -340,6 +375,8 @@ private async Task ExecuteTaskAsync(ICakeContext context, IExecutionStrategy str
}
else
{
cancellationTokenSource?.Cancel();

// No error handler defined for this task.
// Rethrow the exception and let it propagate.
throw;
Expand Down
15 changes: 15 additions & 0 deletions src/Cake.Core/ExecutionSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ public sealed class ExecutionSettings
/// </summary>
public bool Exclusive { get; private set; }

/// <summary>
/// Gets a value indicating whether the dependend task of the target should be run in parallel (if possible).
/// </summary>
public bool Parallel { get; private set; }

/// <summary>
/// Sets the target to be executed.
/// </summary>
Expand All @@ -39,5 +44,15 @@ public ExecutionSettings UseExclusiveTarget()
Exclusive = true;
return this;
}

/// <summary>
/// Whether or not to run the dependend task in parallel.
/// </summary>
/// <returns>The same <see cref="ExecutionSettings"/> instance so that multiple calls can be chained.</returns>
public ExecutionSettings RunInParallel()
{
Parallel = true;
return this;
}
}
}
79 changes: 79 additions & 0 deletions src/Cake.Core/Graph/CakeGraph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace Cake.Core.Graph
{
Expand Down Expand Up @@ -112,6 +114,29 @@ public IEnumerable<string> Traverse(string target)
return result;
}

/// <summary>
/// Traverses the graph asynchrounus leading to the specified target.
/// </summary>
/// <param name="target">The target to traverse to.</param>
/// <param name="executeTask">Action which will be called on each task.</param>
/// <returns>A task to wait for.</returns>
public async Task TraverseAsync(string target, Func<string, CancellationTokenSource, Task> executeTask)
{
if (!Exist(target))
{
return;
}

if (HasCircularReferences(target))
{
throw new CakeException("Graph contains circular references.");
}

var cancellationTokenSource = new CancellationTokenSource();
var visitedNodes = new Dictionary<string, Task>();
await TraverseAsync(target, executeTask, cancellationTokenSource, visitedNodes);
}

private void Traverse(string node, ICollection<string> result, ISet<string> visited = null)
{
visited = visited ?? new HashSet<string>(StringComparer.OrdinalIgnoreCase);
Expand All @@ -130,5 +155,59 @@ private void Traverse(string node, ICollection<string> result, ISet<string> visi
throw new CakeException("Graph contains circular references.");
}
}

private async Task TraverseAsync(string node, Func<string, CancellationTokenSource, Task> executeTask,
CancellationTokenSource cancellationTokenSource, IDictionary<string, Task> visitedNodes)
{
if (visitedNodes.ContainsKey(node))
{
await visitedNodes[node];
return;
}

var token = cancellationTokenSource.Token;
var dependentTasks = _edges
.Where(x => x.End.Equals(node, StringComparison.OrdinalIgnoreCase))
.Select(x =>
{
var task = TraverseAsync(x.Start, executeTask, cancellationTokenSource, visitedNodes);
visitedNodes[x.Start] = task;
if (task.IsFaulted)
{
throw task.Exception;
}
return task;
})
.ToArray();

if (dependentTasks.Any())
{
TaskCompletionSource<object> tcs = new TaskCompletionSource<object>();
token.Register(() => tcs.TrySetCanceled(), false);
await Task.WhenAny(Task.WhenAll(dependentTasks), tcs.Task);
}

await executeTask(node, cancellationTokenSource);
}

private bool HasCircularReferences(string node, Stack<string> visited = null)
{
visited = visited ?? new Stack<string>();

if (visited.Contains(node))
{
return true;
}

visited.Push(node);
var hasCircularReference = _edges
.Where(x => x.End.Equals(node, StringComparison.OrdinalIgnoreCase))
.Any(x => HasCircularReferences(x.Start, visited));
visited.Pop();

return hasCircularReference;
}
}
}
5 changes: 5 additions & 0 deletions src/Cake.Frosting/Internal/Commands/DefaultCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ public override int Execute(CommandContext context, DefaultCommandSettings setti
runner.Settings.UseExclusiveTarget();
}

if (settings.Parallel)
{
runner.Settings.RunInParallel();
}

runner.Run(settings.Target);
}
catch (Exception ex)
Expand Down
4 changes: 4 additions & 0 deletions src/Cake.Frosting/Internal/Commands/DefaultCommandSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,9 @@ internal sealed class DefaultCommandSettings : CommandSettings
[CommandOption("--info")]
[Description("Displays additional information about Cake.")]
public bool Info { get; set; }

[CommandOption("--parallel|-p")]
[Description("Enables the support for parallel tasks.")]
public bool Parallel { get; set; }
}
}
2 changes: 1 addition & 1 deletion src/Cake.Tests/Unit/ProgramTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using System.Linq;
using System.Threading.Tasks;
using Autofac;
using Cake.Cli;
Expand Down Expand Up @@ -32,6 +31,7 @@ public async Task Should_Use_Default_Parameters_By_Default()
settings.BuildHostKind == BuildHostKind.Build &&
settings.Debug == false &&
settings.Exclusive == false &&
settings.Parallel == false &&
settings.Script.FullPath == "build.cake" &&
settings.Verbosity == Verbosity.Normal &&
settings.NoBootstrapping == false));
Expand Down
1 change: 1 addition & 0 deletions src/Cake/Commands/DefaultCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ public override int Execute(CommandContext context, DefaultCommandSettings setti
Script = settings.Script,
Verbosity = settings.Verbosity,
Exclusive = settings.Exclusive,
Parallel = settings.Parallel,
Debug = settings.Debug,
NoBootstrapping = settings.SkipBootstrap,
});
Expand Down
Loading

0 comments on commit 06d0491

Please sign in to comment.