Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support parallel execution of tasks #3958

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions src/Cake.Core.Tests/Unit/CakeEngineTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1558,6 +1558,126 @@ public async Task Should_Return_Report_That_Marks_Failed_Tasks_As_Failed()
Assert.Equal(CakeTaskExecutionStatus.Delegated, report.First(e => e.TaskName == "A").ExecutionStatus);
Assert.Equal(CakeTaskExecutionStatus.Failed, report.First(e => e.TaskName == "B").ExecutionStatus);
}

[Fact]
public async Task Should_Throw_Exception_For_Circular_Dependencies()
{
// Given
var fixture = new CakeEngineFixture();
var settings = new ExecutionSettings().SetTarget("B");
var engine = fixture.CreateEngine();
engine.RegisterTask("B").IsDependentOn("C");
engine.RegisterTask("C").IsDependentOn("D");
engine.RegisterTask("D").IsDependentOn("B");

// When
var result = await Record.ExceptionAsync(() => engine.RunTargetAsync(fixture.Context, fixture.ExecutionStrategy, settings));

// Then
Assert.IsType<CakeException>(result);
Assert.Equal("Graph contains circular references.", result?.Message);
}

[Fact]
public async Task Should_Throw_Exception_For_Circular_Dependencies_In_Parallel()
{
// Given
var fixture = new CakeEngineFixture();
var settings = new ExecutionSettings().SetTarget("B").RunInParallel();
var engine = fixture.CreateEngine();
engine.RegisterTask("B").IsDependentOn("C");
engine.RegisterTask("C").IsDependentOn("D");
engine.RegisterTask("D").IsDependentOn("B");

// When
var result = await Record.ExceptionAsync(() => engine.RunTargetAsync(fixture.Context, fixture.ExecutionStrategy, settings));

// Then
Assert.IsType<CakeException>(result);
Assert.Equal("Graph contains circular references.", result?.Message);
}

[Fact]
public async Task Should_Execute_Tasks_In_Order_In_Parallel()
{
// Given
var result = new List<string>();
var fixture = new CakeEngineFixture();
var settings = new ExecutionSettings().SetTarget("E").RunInParallel();
var engine = fixture.CreateEngine();
engine.RegisterTask("A").Does(() => result.Add("A"));
engine.RegisterTask("B").IsDependentOn("A").Does(() => result.Add("B"));
engine.RegisterTask("C").IsDependentOn("B").Does(() => result.Add("C"));
engine.RegisterTask("D").IsDependentOn("C").IsDependeeOf("E").Does(() => { result.Add("D"); });
engine.RegisterTask("E").Does(() => { result.Add("E"); });

// When
await engine.RunTargetAsync(fixture.Context, fixture.ExecutionStrategy, settings);

// Then
Assert.Equal(5, result.Count);
Assert.Equal("A", result[0]);
Assert.Equal("B", result[1]);
Assert.Equal("C", result[2]);
Assert.Equal("D", result[3]);
Assert.Equal("E", result[4]);
}

[Fact]
public async Task Should_Execute_Tasks_In_Parallel()
{
// Given
var result = new List<string>();
var fixture = new CakeEngineFixture();
var settings = new ExecutionSettings().SetTarget("E").RunInParallel();
var engine = fixture.CreateEngine();
engine.RegisterTask("A").Does(() => result.Add("A"));
engine.RegisterTask("B").IsDependentOn("A").Does(async () =>
{
await Task.Delay(20);
result.Add("B");
});
engine.RegisterTask("C").IsDependentOn("A").Does(async () =>
{
await Task.Delay(5);
result.Add("C");
});
engine.RegisterTask("D").IsDependentOn("A").Does(() => result.Add("D"));
engine.RegisterTask("E").IsDependentOn("B").IsDependentOn("C").IsDependentOn("D").Does(() => result.Add("E"));

// When
await engine.RunTargetAsync(fixture.Context, fixture.ExecutionStrategy, settings);

// Then
Assert.Equal(5, result.Count);
Assert.Equal("A", result[0]);
Assert.Equal("D", result[1]);
Assert.Equal("C", result[2]);
Assert.Equal("B", result[3]);
Assert.Equal("E", result[4]);
}

[Fact]
public async Task Should_Not_Catch_Exceptions_From_Task_If_ContinueOnError_Is_Not_Set_In_Parallel()
{
// Given
var fixture = new CakeEngineFixture();
var settings = new ExecutionSettings().SetTarget("E").RunInParallel();
var engine = fixture.CreateEngine();
engine.RegisterTask("A");
engine.RegisterTask("B").IsDependentOn("A");
engine.RegisterTask("C").IsDependentOn("A").Does(() => throw new InvalidOperationException("Whoopsie"));
engine.RegisterTask("D").IsDependentOn("A");
engine.RegisterTask("E").IsDependentOn("B").IsDependentOn("C").IsDependentOn("D");

// When
var result = await Record.ExceptionAsync(() =>
engine.RunTargetAsync(fixture.Context, fixture.ExecutionStrategy, settings));

// Then
Assert.IsType<InvalidOperationException>(result);
Assert.Equal("Whoopsie", result?.Message);
}
}

public sealed class TheSetupEvent
Expand Down
47 changes: 42 additions & 5 deletions src/Cake.Core/CakeEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Globalization;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Cake.Core.Diagnostics;
using Cake.Core.Graph;
Expand Down Expand Up @@ -211,19 +212,48 @@ public async Task<CakeReport> RunTargetAsync(ICakeContext context, IExecutionStr
{
// Execute only the target task.
var task = _tasks.FirstOrDefault(x => x.Name.Equals(settings.Target, StringComparison.OrdinalIgnoreCase));
await RunTask(context, strategy, task, target, stopWatch, report);
await RunTask(context, strategy, task, target, stopWatch, report, null);
}
else if (settings.Parallel)
{
await graph.TraverseAsync(target, async (taskName, cancellationTokenSource) =>
{
if (cancellationTokenSource.IsCancellationRequested)
{
return;
}

var task = _tasks.FirstOrDefault(_ => _.Name.Equals(taskName, StringComparison.OrdinalIgnoreCase));
Debug.Assert(task != null, "Node should not be null");

var isTarget = task.Name.Equals(target, StringComparison.OrdinalIgnoreCase);

await RunTask(context, strategy, task, target, stopWatch, report, cancellationTokenSource);
});
}
else
{
// Execute all scheduled tasks.
foreach (var task in orderedTasks)
{
await RunTask(context, strategy, task, target, stopWatch, report);
await RunTask(context, strategy, task, target, stopWatch, report, null);
}
}

return report;
}
catch (TaskCanceledException)
{
exceptionWasThrown = true;
throw;
}
catch (AggregateException ex)
{
exceptionWasThrown = true;
thrownException = ex.InnerException;

throw ex.GetBaseException();
}
catch (Exception ex)
{
exceptionWasThrown = true;
Expand All @@ -236,7 +266,7 @@ public async Task<CakeReport> RunTargetAsync(ICakeContext context, IExecutionStr
}
}

private async Task RunTask(ICakeContext context, IExecutionStrategy strategy, CakeTask task, string target, Stopwatch stopWatch, CakeReport report)
private async Task RunTask(ICakeContext context, IExecutionStrategy strategy, CakeTask task, string target, Stopwatch stopWatch, CakeReport report, CancellationTokenSource cancellationTokenSource)
{
// Is this the current target?
var isTarget = task.Name.Equals(target, StringComparison.OrdinalIgnoreCase);
Expand All @@ -255,7 +285,7 @@ private async Task RunTask(ICakeContext context, IExecutionStrategy strategy, Ca

if (!skipped)
{
await ExecuteTaskAsync(context, strategy, stopWatch, task, report).ConfigureAwait(false);
await ExecuteTaskAsync(context, strategy, stopWatch, task, report, cancellationTokenSource).ConfigureAwait(false);
}
}

Expand Down Expand Up @@ -309,7 +339,7 @@ private static bool ShouldTaskExecute(ICakeContext context, CakeTask task, CakeT
}

private async Task ExecuteTaskAsync(ICakeContext context, IExecutionStrategy strategy, Stopwatch stopWatch,
CakeTask task, CakeReport report)
CakeTask task, CakeReport report, CancellationTokenSource cancellationTokenSource)
{
stopWatch.Restart();

Expand All @@ -321,6 +351,11 @@ private async Task ExecuteTaskAsync(ICakeContext context, IExecutionStrategy str
// Execute the task.
await strategy.ExecuteAsync(task, context).ConfigureAwait(false);
}
catch (TaskCanceledException exception)
{
taskException = exception;
throw;
}
catch (Exception exception)
{
_log.Error("An error occurred when executing task '{0}'.", task.Name);
Expand All @@ -340,6 +375,8 @@ private async Task ExecuteTaskAsync(ICakeContext context, IExecutionStrategy str
}
else
{
cancellationTokenSource?.Cancel();

// No error handler defined for this task.
// Rethrow the exception and let it propagate.
throw;
Expand Down
15 changes: 15 additions & 0 deletions src/Cake.Core/ExecutionSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ public sealed class ExecutionSettings
/// </summary>
public bool Exclusive { get; private set; }

/// <summary>
/// Gets a value indicating whether the dependend task of the target should be run in parallel (if possible).
/// </summary>
public bool Parallel { get; private set; }

/// <summary>
/// Sets the target to be executed.
/// </summary>
Expand All @@ -39,5 +44,15 @@ public ExecutionSettings UseExclusiveTarget()
Exclusive = true;
return this;
}

/// <summary>
/// Whether or not to run the dependend task in parallel.
/// </summary>
/// <returns>The same <see cref="ExecutionSettings"/> instance so that multiple calls can be chained.</returns>
public ExecutionSettings RunInParallel()
{
Parallel = true;
return this;
}
}
}
79 changes: 79 additions & 0 deletions src/Cake.Core/Graph/CakeGraph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace Cake.Core.Graph
{
Expand Down Expand Up @@ -112,6 +114,29 @@ public IEnumerable<string> Traverse(string target)
return result;
}

/// <summary>
/// Traverses the graph asynchrounus leading to the specified target.
/// </summary>
/// <param name="target">The target to traverse to.</param>
/// <param name="executeTask">Action which will be called on each task.</param>
/// <returns>A task to wait for.</returns>
public async Task TraverseAsync(string target, Func<string, CancellationTokenSource, Task> executeTask)
{
if (!Exist(target))
{
return;
}

if (HasCircularReferences(target))
{
throw new CakeException("Graph contains circular references.");
}

var cancellationTokenSource = new CancellationTokenSource();
var visitedNodes = new Dictionary<string, Task>();
await TraverseAsync(target, executeTask, cancellationTokenSource, visitedNodes);
}

private void Traverse(string node, ICollection<string> result, ISet<string> visited = null)
{
visited = visited ?? new HashSet<string>(StringComparer.OrdinalIgnoreCase);
Expand All @@ -130,5 +155,59 @@ private void Traverse(string node, ICollection<string> result, ISet<string> visi
throw new CakeException("Graph contains circular references.");
}
}

private async Task TraverseAsync(string node, Func<string, CancellationTokenSource, Task> executeTask,
CancellationTokenSource cancellationTokenSource, IDictionary<string, Task> visitedNodes)
{
if (visitedNodes.ContainsKey(node))
{
await visitedNodes[node];
return;
}

var token = cancellationTokenSource.Token;
var dependentTasks = _edges
.Where(x => x.End.Equals(node, StringComparison.OrdinalIgnoreCase))
.Select(x =>
{
var task = TraverseAsync(x.Start, executeTask, cancellationTokenSource, visitedNodes);
visitedNodes[x.Start] = task;

if (task.IsFaulted)
{
throw task.Exception;
}

return task;
})
.ToArray();

if (dependentTasks.Any())
{
TaskCompletionSource<object> tcs = new TaskCompletionSource<object>();
token.Register(() => tcs.TrySetCanceled(), false);
await Task.WhenAny(Task.WhenAll(dependentTasks), tcs.Task);
}

await executeTask(node, cancellationTokenSource);
}

private bool HasCircularReferences(string node, Stack<string> visited = null)
{
visited = visited ?? new Stack<string>();

if (visited.Contains(node))
{
return true;
}

visited.Push(node);
var hasCircularReference = _edges
.Where(x => x.End.Equals(node, StringComparison.OrdinalIgnoreCase))
.Any(x => HasCircularReferences(x.Start, visited));
visited.Pop();

return hasCircularReference;
}
}
}
5 changes: 5 additions & 0 deletions src/Cake.Frosting/Internal/Commands/DefaultCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ public override int Execute(CommandContext context, DefaultCommandSettings setti
runner.Settings.UseExclusiveTarget();
}

if (settings.Parallel)
{
runner.Settings.RunInParallel();
}

runner.Run(settings.Target);
}
catch (Exception ex)
Expand Down
4 changes: 4 additions & 0 deletions src/Cake.Frosting/Internal/Commands/DefaultCommandSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,9 @@ internal sealed class DefaultCommandSettings : CommandSettings
[CommandOption("--info")]
[Description("Displays additional information about Cake.")]
public bool Info { get; set; }

[CommandOption("--parallel|-p")]
[Description("Enables the support for parallel tasks.")]
public bool Parallel { get; set; }
}
}
2 changes: 1 addition & 1 deletion src/Cake.Tests/Unit/ProgramTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using System.Linq;
using System.Threading.Tasks;
using Autofac;
using Cake.Cli;
Expand Down Expand Up @@ -32,6 +31,7 @@ public async Task Should_Use_Default_Parameters_By_Default()
settings.BuildHostKind == BuildHostKind.Build &&
settings.Debug == false &&
settings.Exclusive == false &&
settings.Parallel == false &&
settings.Script.FullPath == "build.cake" &&
settings.Verbosity == Verbosity.Normal &&
settings.NoBootstrapping == false));
Expand Down
Loading