Skip to content

Commit

Permalink
move libcurl use to remote_handle.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed Nov 3, 2024
1 parent ab789f5 commit 0d42b43
Show file tree
Hide file tree
Showing 2 changed files with 208 additions and 153 deletions.
160 changes: 7 additions & 153 deletions cpp/include/kvikio/remote_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#include <kvikio/error.hpp>
#include <kvikio/parallel_operation.hpp>
#include <kvikio/posix_io.hpp>
#include <kvikio/shim/libcurl.hpp>
#include <kvikio/utils.hpp>

namespace kvikio {
Expand Down Expand Up @@ -128,78 +127,10 @@ class BounceBufferH2D {
}
};

/**
 * @brief State handed to the "CURLOPT_WRITEFUNCTION" callbacks through their `context` argument.
 *
 * One instance describes a single transfer: where the data goes, how many bytes are expected,
 * and how far the transfer has progressed.
 */
struct CallbackContext {
  char* buf;                                // Output buffer to read into.
  std::size_t size;                         // Total number of bytes to read.
  std::ptrdiff_t offset{0};                 // Bytes received so far; write position into `buf`.
  bool overflow_error{false};               // Set by a callback when more than `size` bytes arrive.
  BounceBufferH2D* bounce_buffer{nullptr};  // Only used by callback_device_memory

  /**
   * @param buf Output buffer; stored as `char*` so callbacks can do byte-wise pointer arithmetic.
   * @param size Total number of bytes expected for this transfer.
   */
  CallbackContext(void* buf, std::size_t size) : buf{static_cast<char*>(buf)}, size{size} {}
};

/**
 * @brief A "CURLOPT_WRITEFUNCTION" to copy downloaded data to the output host buffer.
 *
 * See <https://curl.se/libcurl/c/CURLOPT_WRITEFUNCTION.html>.
 *
 * @param data Data downloaded by libcurl that is ready for consumption.
 * @param size Size of each element in `data`; size is always 1.
 * @param nmemb Number of elements in `data` (i.e. the number of bytes, since `size` is 1).
 * @param context A pointer to an instance of `CallbackContext`.
 * @return The number of bytes consumed (`size * nmemb`), or `CURL_WRITEFUNC_ERROR` when the
 * chunk doesn't fit in the remaining part of the output buffer. In the latter case
 * `CallbackContext::overflow_error` is set and nothing is copied.
 */
inline std::size_t callback_host_memory(char* data,
                                        std::size_t size,
                                        std::size_t nmemb,
                                        void* context)
{
  // Converting a `void*` back to its original object pointer type only needs a static_cast.
  auto* ctx                = static_cast<CallbackContext*>(context);
  std::size_t const nbytes = size * nmemb;

  // Compare against the remaining capacity instead of computing `offset + nbytes`, which mixes
  // signed and unsigned operands and could wrap around. `offset` starts at zero and is only
  // advanced below after this check has passed, so `size - offset` cannot underflow as long as
  // nothing else modifies the context.
  std::size_t const remaining = ctx->size - static_cast<std::size_t>(ctx->offset);
  if (nbytes > remaining) {
    ctx->overflow_error = true;
    return CURL_WRITEFUNC_ERROR;
  }
  KVIKIO_NVTX_SCOPED_RANGE("RemoteHandle - callback_host_memory()", nbytes);
  std::memcpy(ctx->buf + ctx->offset, data, nbytes);
  ctx->offset += static_cast<std::ptrdiff_t>(nbytes);
  return nbytes;
}

/**
 * @brief A "CURLOPT_WRITEFUNCTION" to copy downloaded data to the output device buffer.
 *
 * The chunk is not written to the destination directly; it is handed to
 * `CallbackContext::bounce_buffer`, which must have been set before the transfer starts.
 *
 * See <https://curl.se/libcurl/c/CURLOPT_WRITEFUNCTION.html>.
 *
 * @param data Data downloaded by libcurl that is ready for consumption.
 * @param size Size of each element in `data`; size is always 1.
 * @param nmemb Number of elements in `data` (i.e. the number of bytes, since `size` is 1).
 * @param context A pointer to an instance of `CallbackContext`.
 * @return The number of bytes consumed (`size * nmemb`), or `CURL_WRITEFUNC_ERROR` when the
 * chunk exceeds the number of bytes still expected. In the latter case
 * `CallbackContext::overflow_error` is set and nothing is written.
 */
inline std::size_t callback_device_memory(char* data,
                                          std::size_t size,
                                          std::size_t nmemb,
                                          void* context)
{
  // Converting a `void*` back to its original object pointer type only needs a static_cast.
  auto* ctx                = static_cast<CallbackContext*>(context);
  std::size_t const nbytes = size * nmemb;

  // Compare against the remaining capacity instead of computing `offset + nbytes`, which mixes
  // signed and unsigned operands and could wrap around. `offset` starts at zero and is only
  // advanced below after this check has passed, so `size - offset` cannot underflow as long as
  // nothing else modifies the context.
  std::size_t const remaining = ctx->size - static_cast<std::size_t>(ctx->offset);
  if (nbytes > remaining) {
    ctx->overflow_error = true;
    return CURL_WRITEFUNC_ERROR;
  }
  KVIKIO_NVTX_SCOPED_RANGE("RemoteHandle - callback_device_memory()", nbytes);

  ctx->bounce_buffer->write(data, nbytes);
  ctx->offset += static_cast<std::ptrdiff_t>(nbytes);
  return nbytes;
}

} // namespace detail

class CurlHandle; // Forward declaration; enough to declare functions taking `CurlHandle&`.

/**
* @brief Abstract base class for remote endpoints.
*
Expand Down Expand Up @@ -243,7 +174,7 @@ class HttpEndpoint : public RemoteEndpoint {
* @param url The full http url to the remote file.
*/
// Takes `url` by value and moves it into `_url`.
// NOTE(review): this single-argument constructor is not `explicit`, so a `std::string` converts
// implicitly to an `HttpEndpoint` - confirm that is intended.
HttpEndpoint(std::string url) : _url{std::move(url)} {}
// Sets CURLOPT_URL on `curl` to this endpoint's URL.
void setopt(CurlHandle& curl) override { curl.setopt(CURLOPT_URL, _url.c_str()); }
// Applies this endpoint's options to `curl`. Declaration only; the definition is out-of-line.
void setopt(CurlHandle& curl) override;
// Returns a copy of the URL as the endpoint's string representation.
std::string str() const override { return _url; }
// Compiler-generated destructor; `override` checks that the base class destructor is virtual.
~HttpEndpoint() override = default;
};
Expand Down Expand Up @@ -424,12 +355,7 @@ class S3Endpoint : public RemoteEndpoint {
{
}

// Configures `curl` for this S3 endpoint by forwarding three stored strings: the URL
// (CURLOPT_URL), the AWS SigV4 parameter string (CURLOPT_AWS_SIGV4) and the user/password
// string (CURLOPT_USERPWD).
void setopt(CurlHandle& curl) override
{
curl.setopt(CURLOPT_URL, _url.c_str());
curl.setopt(CURLOPT_AWS_SIGV4, _aws_sigv4.c_str());
curl.setopt(CURLOPT_USERPWD, _aws_userpwd.c_str());
}
// Applies this endpoint's options to `curl`. Declaration only; the definition is out-of-line.
void setopt(CurlHandle& curl) override;
// Returns a copy of the URL as the endpoint's string representation.
std::string str() const override { return _url; }
// Compiler-generated destructor; `override` checks that the base class destructor is virtual.
~S3Endpoint() override = default;
};
Expand Down Expand Up @@ -461,23 +387,7 @@ class RemoteHandle {
*
* @param endpoint Remote endpoint used for subsequently IO.
*/
RemoteHandle(std::unique_ptr<RemoteEndpoint> endpoint)
{
auto curl = create_curl_handle();

endpoint->setopt(curl);
curl.setopt(CURLOPT_NOBODY, 1L);
curl.setopt(CURLOPT_FOLLOWLOCATION, 1L);
curl.perform();
curl_off_t cl;
curl.getinfo(CURLINFO_CONTENT_LENGTH_DOWNLOAD_T, &cl);
if (cl < 0) {
throw std::runtime_error("cannot get size of " + endpoint->str() +
", content-length not provided by the server");
}
_nbytes = cl;
_endpoint = std::move(endpoint);
}
// Takes ownership of `endpoint`. Declaration only; the definition is out-of-line.
RemoteHandle(std::unique_ptr<RemoteEndpoint> endpoint);

// A remote handle is moveable but not copyable.
// Compiler-generated member-wise move construction from `o`.
RemoteHandle(RemoteHandle&& o) = default;
Expand Down Expand Up @@ -513,53 +423,7 @@ class RemoteHandle {
* @param file_offset File offset in bytes.
* @return Number of bytes read, which is always `size`.
*/
std::size_t read(void* buf, std::size_t size, std::size_t file_offset = 0)
{
  KVIKIO_NVTX_SCOPED_RANGE("RemoteHandle::read()", size);

  // Range check written so that it cannot wrap around, unlike `file_offset + size > _nbytes`.
  if (file_offset > _nbytes || size > _nbytes - file_offset) {
    std::stringstream ss;
    ss << "cannot read " << file_offset << "+" << size << " bytes into a " << _nbytes
       << " bytes file (" << _endpoint->str() << ")";
    throw std::invalid_argument(ss.str());
  }
  // Nothing to transfer. Returning here also keeps `file_offset + size - 1` below from wrapping
  // around, which would turn an empty read into a request for a bogus byte range.
  if (size == 0) { return 0; }

  bool const is_host_mem = is_host_memory(buf);
  auto curl              = create_curl_handle();
  _endpoint->setopt(curl);

  // Byte range of the form "first-last", both bounds inclusive.
  std::string const byte_range =
    std::to_string(file_offset) + "-" + std::to_string(file_offset + size - 1);
  curl.setopt(CURLOPT_RANGE, byte_range.c_str());

  // Pick the write callback that matches where `buf` lives.
  if (is_host_mem) {
    curl.setopt(CURLOPT_WRITEFUNCTION, detail::callback_host_memory);
  } else {
    curl.setopt(CURLOPT_WRITEFUNCTION, detail::callback_device_memory);
  }
  detail::CallbackContext ctx{buf, size};
  curl.setopt(CURLOPT_WRITEDATA, &ctx);

  try {
    if (is_host_mem) {
      curl.perform();
    } else {
      PushAndPopContext c(get_context_from_pointer(buf));
      // We use a bounce buffer to avoid many small memory copies to device. Libcurl has a
      // maximum chunk size of 16kb (`CURL_MAX_WRITE_SIZE`) but chunks are often much smaller.
      detail::BounceBufferH2D bounce_buffer(detail::StreamsByThread::get(), buf);
      ctx.bounce_buffer = &bounce_buffer;
      curl.perform();
    }
  } catch (std::runtime_error const& e) {
    // A callback flags `overflow_error` when it receives more data than requested; report that
    // case with a more helpful message and rethrow everything else unchanged.
    if (ctx.overflow_error) {
      std::stringstream ss;
      ss << "maybe the server doesn't support file ranges? [" << e.what() << "]";
      throw std::overflow_error(ss.str());
    }
    throw;
  }
  return size;
}
// Declaration only; the definition is out-of-line. `file_offset` defaults to the start of the file.
std::size_t read(void* buf, std::size_t size, std::size_t file_offset = 0);

/**
* @brief Read from remote source into buffer (host or device memory) in parallel.
Expand All @@ -576,17 +440,7 @@ class RemoteHandle {
std::future<std::size_t> pread(void* buf,
                               std::size_t size,
                               std::size_t file_offset = 0,
                               std::size_t task_size   = defaults::task_size())
{
  KVIKIO_NVTX_SCOPED_RANGE("RemoteHandle::pread()", size);

  // Unit of work handed to `parallel_io`: read `chunk_size` bytes starting at
  // `chunk_file_offset` into the destination at `base + chunk_buf_offset`.
  // NOTE(review): presumably `parallel_io` invokes this once per `task_size`-sized chunk of
  // the requested range - confirm against parallel_operation.hpp.
  auto read_chunk = [this](void* base,
                           std::size_t chunk_size,
                           std::size_t chunk_file_offset,
                           std::size_t chunk_buf_offset) -> std::size_t {
    char* const destination = static_cast<char*>(base) + chunk_buf_offset;
    return read(destination, chunk_size, chunk_file_offset);
  };

  return parallel_io(read_chunk, buf, size, file_offset, task_size, 0);
}
std::size_t task_size = defaults::task_size());
};

} // namespace kvikio
Loading

0 comments on commit 0d42b43

Please sign in to comment.