Remove Status from the thread pool and parallel functions. #5084

Draft · wants to merge 11 commits into base: dev
203 changes: 51 additions & 152 deletions tiledb/common/thread_pool/test/unit_thread_pool.cc

Large diffs are not rendered by default.

31 changes: 0 additions & 31 deletions tiledb/common/thread_pool/test/unit_thread_pool.h

This file was deleted.

74 changes: 10 additions & 64 deletions tiledb/common/thread_pool/thread_pool.cc
@@ -121,30 +121,18 @@ void ThreadPool::shutdown() {
   threads_.clear();
 }
 
-Status ThreadPool::wait_all(std::vector<Task>& tasks) {
-  auto statuses = wait_all_status(tasks);
-  for (auto& st : statuses) {
-    if (!st.ok()) {
-      return st;
-    }
-  }
-  return Status::Ok();
-}
-
-// Return a vector of Status. If any task returns an error value or throws an
-// exception, we save an error code in the corresponding location in the Status
-// vector. All tasks are waited on before return. Multiple error statuses may
-// be saved. We may call logger here because thread pool will not be used until
-// context is fully constructed (which will include logger).
-// Unfortunately, C++ does not have the notion of an aggregate exception, so we
-// don't throw in the case of errors/exceptions.
-std::vector<Status> ThreadPool::wait_all_status(std::vector<Task>& tasks) {
-  std::vector<Status> statuses(tasks.size());
-
+void ThreadPool::wait_all(std::vector<Task>& tasks) {
   std::queue<size_t> pending_tasks;
 
   // Create queue of ids of all the pending tasks for processing
-  for (size_t i = 0; i < statuses.size(); ++i) {
+  for (size_t i = 0; i < tasks.size(); ++i) {
     pending_tasks.push(i);
   }
 
@@ -155,33 +143,12 @@ std::vector<Status> ThreadPool::wait_all_status(std::vector<Task>& tasks) {
     auto& task = tasks[task_id];
 
     if (!task.valid()) {
-      statuses[task_id] = Status_ThreadPoolError("Invalid task future");
-      LOG_STATUS_NO_RETURN_VALUE(statuses[task_id]);
+      throw TaskException("Invalid task future");
     } else if (
         task.wait_for(std::chrono::milliseconds(0)) ==
        std::future_status::ready) {
-      // Task is completed, get result, handling possible exceptions
-
-      Status st = [&task] {
-        try {
-          return task.get();
-        } catch (const std::exception& e) {
-          return Status_TaskError(
-              "Caught std::exception: " + std::string(e.what()));
-        } catch (const std::string& msg) {
-          return Status_TaskError("Caught msg: " + msg);
-        } catch (const Status& stat) {
-          return stat;
-        } catch (...) {
-          return Status_TaskError("Unknown exception");
-        }
-      }();
-
-      if (!st.ok()) {
-        LOG_STATUS_NO_RETURN_VALUE(st);
-      }
-      statuses[task_id] = st;
-
+      // Task is completed, throw possible exception
+      task.get();
     } else {
       // If the task is not completed, try again later
       pending_tasks.push(task_id);
@@ -201,39 +168,18 @@ std::vector<Status> ThreadPool::wait_all_status(std::vector<Task>& tasks) {
       }
     }
   }
-
-  return statuses;
 }
 
-Status ThreadPool::wait(Task& task) {
+void ThreadPool::wait(Task& task) {
   while (true) {
     if (!task.valid()) {
-      return Status_ThreadPoolError("Invalid task future");
+      throw TaskException("Invalid task future");
    } else if (
        task.wait_for(std::chrono::milliseconds(0)) ==
        std::future_status::ready) {
-      // Task is completed, get result, handling possible exceptions
-
-      Status st = [&task] {
-        try {
-          return task.get();
-        } catch (const std::exception& e) {
-          return Status_TaskError(
-              "Caught std::exception: " + std::string(e.what()));
-        } catch (const std::string& msg) {
-          return Status_TaskError("Caught msg: " + msg);
-        } catch (const Status& stat) {
-          return stat;
-        } catch (...) {
-          return Status_TaskError("Unknown exception");
-        }
-      }();
-
-      if (!st.ok()) {
-        LOG_STATUS_NO_RETURN_VALUE(st);
-      }
-
-      return st;
+      // Task is completed, throw possible exception
+      task.get();
+      return;
    } else {
      // In the meantime, try to do something useful to make progress (and avoid
      // deadlock)
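For call sites, the practical effect of the two functions above is that error handling moves from checking a returned `Status` to catching an exception, and `wait_all` surfaces the first failure it encounters during the wait instead of returning a vector of per-task statuses. A minimal before/after sketch of the calling pattern (the pool object `tp` and the lambda bodies are illustrative, not taken from this diff):

```cpp
// Assumes a constructed ThreadPool `tp` and the new Task = std::future<void>.
std::vector<ThreadPool::Task> tasks;
tasks.push_back(tp.execute([]() { /* work that succeeds */ }));
tasks.push_back(tp.execute([]() {
  throw std::runtime_error("work failed");  // errors are now thrown, not returned
}));

// Before: Status st = tp.wait_all(tasks); if (!st.ok()) { /* handle */ }
// After: the first exception propagates out of wait_all.
try {
  tp.wait_all(tasks);
} catch (const std::exception& e) {
  // handle the failure reported by the first task that threw during the wait
}
```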
65 changes: 27 additions & 38 deletions tiledb/common/thread_pool/thread_pool.h
@@ -35,19 +35,27 @@
 
 #include "producer_consumer_queue.h"
 
+#include <concepts>
 #include <functional>
 #include <future>
 
 #include "tiledb/common/common.h"
 #include "tiledb/common/logger_public.h"
 #include "tiledb/common/macros.h"
 #include "tiledb/common/status.h"
 
 namespace tiledb::common {
 
+/** Class for Task status exceptions. */
+class TaskException : public StatusException {
+ public:
+  explicit TaskException(const std::string& msg)
+      : StatusException("Task", msg) {
+  }
+};
+
 class ThreadPool {
  public:
-  using Task = std::future<Status>;
+  using Task = std::future<void>;
 
   /* ********************************* */
   /* CONSTRUCTORS & DESTRUCTORS */
@@ -92,26 +100,19 @@ class ThreadPool {
    */
 
   template <class Fn, class... Args>
-  auto async(Fn&& f, Args&&... args) {
+  Task async(Fn&& f, Args&&... args)
+    requires std::same_as<std::invoke_result_t<Fn, std::decay_t<Args>...>, void>
+  {
     if (concurrency_level_ == 0) {
       Task invalid_future;
       LOG_ERROR("Cannot execute task; thread pool uninitialized.");
       return invalid_future;
     }
 
-    using R = std::invoke_result_t<std::decay_t<Fn>, std::decay_t<Args>...>;
-
-    auto task = make_shared<std::packaged_task<R()>>(
-        HERE(),
-        [f = std::forward<Fn>(f),
-         args = std::make_tuple(std::forward<Args>(args)...)]() mutable {
-          return std::apply(std::move(f), std::move(args));
-        });
-
-    std::future<R> future = task->get_future();
-
+    auto task = make_shared<std::packaged_task<void()>>(
+        HERE(), std::bind(std::forward<Fn>(f), std::forward<Args>(args)...));
+    auto future = task->get_future();
     task_queue_.push(task);
-
     return future;
   }
 
@@ -123,7 +124,9 @@ class ThreadPool {
    * @return std::future referring to the shared state created by this call
    */
   template <class Fn, class... Args>
-  auto execute(Fn&& f, Args&&... args) {
+  Task execute(Fn&& f, Args&&... args)
+    requires std::same_as<std::invoke_result_t<Fn, std::decay_t<Args>...>, void>
+  {
     return async(std::forward<Fn>(f), std::forward<Args>(args)...);
   }
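The `requires` clause added to `async` and `execute` above makes the new contract a compile-time check: only callables whose invocation result is `void` are accepted, so a lambda that still returns `Status` fails to compile rather than being silently wrapped. A small sketch of what does and does not satisfy the constraint (the pool object `tp` is an assumed, already-constructed ThreadPool):

```cpp
// OK: the callable returns void; failures must be reported by throwing.
auto t1 = tp.async([]() {
  // ... do work ...
});

// Rejected at compile time: std::invoke_result_t of this lambda is Status,
// which does not satisfy std::same_as<..., void>.
// auto t2 = tp.async([]() { return Status::Ok(); });

tp.wait(t1);  // rethrows anything t1 threw
```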

@@ -133,36 +136,22 @@ class ThreadPool {
    * waiting.
    *
    * @param tasks Task list to wait on.
-   * @return Status::Ok if all tasks returned Status::Ok, otherwise the first
-   * error status is returned
-   */
-  Status wait_all(std::vector<Task>& tasks);
-
-  /**
-   * Wait on all the given tasks to complete, returning a vector of their return
-   * Status. Exceptions caught while waiting are returned as Status_TaskError.
-   * Status are saved at the same index in the return vector as the
-   * corresponding task in the input vector. The status vector may contain more
-   * than one error Status.
-   *
-   * This function is safe to call recursively and may execute pending tasks
-   * with the calling thread while waiting.
-   *
-   * @param tasks Task list to wait on
-   * @return Vector of each task's Status.
+   * @throws This function will throw the first exception thrown by one of the
+   * tasks.
    */
-  std::vector<Status> wait_all_status(std::vector<Task>& tasks);
+  void wait_all(std::vector<Task>& tasks);
 
   /**
    * Wait on a single tasks to complete. This function is safe to call
    * recursively and may execute pending tasks on the calling thread while
    * waiting.
    *
    * @param task Task to wait on.
-   * @return Status::Ok if the task returned Status::Ok, otherwise the error
-   * status is returned
+   *
+   * @throws This function will throw the exception thrown by task.
    */
-  Status wait(Task& task);
+  void wait(Task& task);
 
   /* ********************************* */
   /* PRIVATE ATTRIBUTES */
@@ -177,8 +166,8 @@ class ThreadPool {
 
   /** Producer-consumer queue where functions to be executed are kept */
   ProducerConsumerQueue<
-      shared_ptr<std::packaged_task<Status()>>,
-      std::deque<shared_ptr<std::packaged_task<Status()>>>>
+      shared_ptr<std::packaged_task<void()>>,
+      std::deque<shared_ptr<std::packaged_task<void()>>>>
       task_queue_;
 
   /** The worker threads */
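The queue now stores `std::packaged_task<void()>`, so exception transport is handled entirely by the standard library: whatever a task throws is captured in the shared state and rethrown by `std::future<void>::get()` at the wait site. A self-contained, standard-C++-only illustration of that mechanism (independent of TileDB):

```cpp
#include <future>
#include <iostream>
#include <stdexcept>

int main() {
  // The callable returns void and reports failure by throwing.
  std::packaged_task<void()> task(
      []() { throw std::runtime_error("task failed"); });
  std::future<void> fut = task.get_future();

  task();  // the exception is captured in the shared state, not propagated here

  try {
    fut.get();  // rethrows the stored exception
  } catch (const std::exception& e) {
    std::cout << "caught: " << e.what() << '\n';  // prints "caught: task failed"
  }
  return 0;
}
```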
44 changes: 18 additions & 26 deletions tiledb/sm/array/array.cc
@@ -679,16 +679,14 @@ void Array::delete_fragments(
 
   // Delete fragments and commits
   auto vfs = &(resources.vfs());
-  throw_if_not_ok(parallel_for(
-      &resources.compute_tp(), 0, fragment_uris.size(), [&](size_t i) {
-        throw_if_not_ok(vfs->remove_dir(fragment_uris[i].uri_));
-        bool is_file = false;
-        throw_if_not_ok(vfs->is_file(commit_uris_to_delete[i], &is_file));
-        if (is_file) {
-          throw_if_not_ok(vfs->remove_file(commit_uris_to_delete[i]));
-        }
-        return Status::Ok();
-      }));
+  parallel_for(&resources.compute_tp(), 0, fragment_uris.size(), [&](size_t i) {
+    throw_if_not_ok(vfs->remove_dir(fragment_uris[i].uri_));
+    bool is_file = false;
+    throw_if_not_ok(vfs->is_file(commit_uris_to_delete[i], &is_file));
+    if (is_file) {
+      throw_if_not_ok(vfs->remove_file(commit_uris_to_delete[i]));
+    }
+  });
 }
 
 void Array::delete_fragments(
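The same convention applies to every `parallel_for` call site touched by this PR: the body lambda returns `void`, signals failure by throwing (here via `throw_if_not_ok`, which converts a non-OK `Status` from the VFS into an exception), and the result of `parallel_for` itself is no longer wrapped in `throw_if_not_ok`. A schematic before/after, with `process()` and `uris` as hypothetical stand-ins:

```cpp
// Before: both the body and parallel_for itself returned Status.
// throw_if_not_ok(parallel_for(&tp, 0, uris.size(), [&](size_t i) {
//   throw_if_not_ok(process(uris[i]));
//   return Status::Ok();
// }));

// After: the body returns void and throws on failure; nothing left to check.
parallel_for(&tp, 0, uris.size(), [&](size_t i) {
  throw_if_not_ok(process(uris[i]));  // throws on a non-OK Status
});
```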
@@ -1711,7 +1709,7 @@ std::unordered_map<std::string, uint64_t> Array::get_average_var_cell_sizes()
 
   // Load all metadata for tile var sizes among fragments.
   for (const auto& var_name : var_names) {
-    throw_if_not_ok(parallel_for(
+    parallel_for(
         &resources_.compute_tp(),
         0,
         fragment_metadata.size(),
@@ -1720,17 +1718,16 @@ std::unordered_map<std::string, uint64_t> Array::get_average_var_cell_sizes()
           // evolution that do not exists in this fragment.
           const auto& schema = fragment_metadata[f]->array_schema();
           if (!schema->is_field(var_name)) {
-            return Status::Ok();
+            return;
           }
 
           fragment_metadata[f]->loaded_metadata()->load_tile_var_sizes(
               *encryption_key(), var_name);
-          return Status::Ok();
-        }));
+        });
   }
 
   // Now compute for each var size names, the average cell size.
-  throw_if_not_ok(parallel_for(
+  parallel_for(
       &resources_.compute_tp(), 0, var_names.size(), [&](const uint64_t n) {
         uint64_t total_size = 0;
         uint64_t cell_num = 0;
@@ -1756,9 +1753,7 @@ std::unordered_map<std::string, uint64_t> Array::get_average_var_cell_sizes()
 
         uint64_t average_cell_size = total_size / cell_num;
         ret[var_name] = std::max<uint64_t>(average_cell_size, 1);
-
-        return Status::Ok();
-      }));
+      });
 
   return ret;
 }
@@ -1988,15 +1983,12 @@ void Array::do_load_metadata() {
 
   auto metadata_num = array_metadata_to_load.size();
   std::vector<shared_ptr<Tile>> metadata_tiles(metadata_num);
-  throw_if_not_ok(
-      parallel_for(&resources_.compute_tp(), 0, metadata_num, [&](size_t m) {
-        const auto& uri = array_metadata_to_load[m].uri_;
-
-        metadata_tiles[m] = GenericTileIO::load(
-            resources_, uri, 0, *encryption_key(), memory_tracker_);
-
-        return Status::Ok();
-      }));
+  parallel_for(&resources_.compute_tp(), 0, metadata_num, [&](size_t m) {
+    const auto& uri = array_metadata_to_load[m].uri_;
+
+    metadata_tiles[m] = GenericTileIO::load(
+        resources_, uri, 0, *encryption_key(), memory_tracker_);
+  });
 
   // Compute array metadata size for the statistics
   uint64_t meta_size = 0;
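`parallel_for` itself is not shown in this diff; the sketch below is only an assumption of how a void-based version can be layered on the new ThreadPool API (the name `parallel_for_sketch` is hypothetical, and the real implementation presumably partitions the range into chunks rather than submitting one task per index):

```cpp
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical helper, not from this PR: run fn(i) for each i in [begin, end)
// on the pool and surface the first failure as an exception via wait_all.
void parallel_for_sketch(
    ThreadPool* tp,
    uint64_t begin,
    uint64_t end,
    const std::function<void(uint64_t)>& fn) {
  std::vector<ThreadPool::Task> tasks;
  tasks.reserve(end - begin);
  for (uint64_t i = begin; i < end; ++i) {
    tasks.push_back(tp->execute([&fn, i]() { fn(i); }));
  }
  tp->wait_all(tasks);  // rethrows the first exception thrown by any task
}
```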