Skip to content

Commit

Permalink
Moving details in file_handle.hpp to .cpp (#539)
Browse files Browse the repository at this point in the history
... also don't build `remote_handle.cpp` if `KvikIO_REMOTE_SUPPORT=OFF`, which fixes #538

Authors:
  - Mads R. B. Kristensen (https://github.com/madsbk)

Approvers:
  - Kyle Edwards (https://github.com/KyleFromNVIDIA)
  - Bradley Dice (https://github.com/bdice)

URL: #539
  • Loading branch information
madsbk authored Nov 6, 2024
1 parent 1c99841 commit 12ca83b
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 130 deletions.
7 changes: 6 additions & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,12 @@ include(cmake/thirdparty/get_thread_pool.cmake)
# ##################################################################################################
# * library targets --------------------------------------------------------------------------------

file(GLOB SOURCES "src/*.cpp")
set(SOURCES "src/file_handle.cpp")

if(KvikIO_REMOTE_SUPPORT)
list(APPEND SOURCES "src/remote_handle.cpp")
endif()

add_library(kvikio ${SOURCES})

# To avoid symbol conflicts when statically linking to libcurl.a (see get_libcurl.cmake) and its
Expand Down
130 changes: 3 additions & 127 deletions cpp/include/kvikio/file_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,11 @@
*/
#pragma once

#include <fcntl.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

#include <cstddef>
#include <cstdlib>
#include <stdexcept>
#include <system_error>
#include <utility>

Expand All @@ -37,96 +34,6 @@
#include <kvikio/utils.hpp>

namespace kvikio {
namespace detail {

/**
* @brief Parse open file flags given as a string and return oflags
*
* @param flags The flags
* @param o_direct Append O_DIRECT to the open flags
* @return oflags
*
* @throw std::invalid_argument if the specified flags are not supported.
* @throw std::invalid_argument if `o_direct` is true, but `O_DIRECT` is not supported.
*/
inline int open_fd_parse_flags(const std::string& flags, bool o_direct)
{
int file_flags = -1;
if (flags.empty()) { throw std::invalid_argument("Unknown file open flag"); }
switch (flags[0]) {
case 'r':
file_flags = O_RDONLY;
if (flags[1] == '+') { file_flags = O_RDWR; }
break;
case 'w':
file_flags = O_WRONLY;
if (flags[1] == '+') { file_flags = O_RDWR; }
file_flags |= O_CREAT | O_TRUNC;
break;
case 'a': throw std::invalid_argument("Open flag 'a' isn't supported");
default: throw std::invalid_argument("Unknown file open flag");
}
file_flags |= O_CLOEXEC;
if (o_direct) {
#if defined(O_DIRECT)
file_flags |= O_DIRECT;
#else
throw std::invalid_argument("'o_direct' flag unsupported on this platform");
#endif
}
return file_flags;
}

/**
* @brief Open file using `open(2)`
*
* @param flags Open flags given as a string
* @param o_direct Append O_DIRECT to `flags`
* @param mode Access modes
* @return File descriptor
*/
inline int open_fd(const std::string& file_path,
const std::string& flags,
bool o_direct,
mode_t mode)
{
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg)
int fd = ::open(file_path.c_str(), open_fd_parse_flags(flags, o_direct), mode);
if (fd == -1) { throw std::system_error(errno, std::generic_category(), "Unable to open file"); }
return fd;
}

/**
* @brief Get the flags of the file descriptor (see `open(2)`)
*
* @return Open flags
*/
[[nodiscard]] inline int open_flags(int fd)
{
int ret = fcntl(fd, F_GETFL); // NOLINT(cppcoreguidelines-pro-type-vararg)
if (ret == -1) {
throw std::system_error(errno, std::generic_category(), "Unable to retrieve open flags");
}
return ret;
}

/**
* @brief Get file size from file descriptor `fstat(3)`
*
* @param file_descriptor Open file descriptor
* @return The number of bytes
*/
[[nodiscard]] inline std::size_t get_file_size(int file_descriptor)
{
struct stat st {};
int ret = fstat(file_descriptor, &st);
if (ret == -1) {
throw std::system_error(errno, std::generic_category(), "Unable to query file size");
}
return static_cast<std::size_t>(st.st_size);
}

} // namespace detail

/**
* @brief Handle of an open file registered with cufile.
Expand Down Expand Up @@ -166,33 +73,7 @@ class FileHandle {
FileHandle(const std::string& file_path,
const std::string& flags = "r",
mode_t mode = m644,
bool compat_mode = defaults::compat_mode())
: _fd_direct_off{detail::open_fd(file_path, flags, false, mode)},
_initialized{true},
_compat_mode{compat_mode}
{
if (_compat_mode) {
return; // Nothing to do in compatibility mode
}

// Try to open the file with the O_DIRECT flag. Fall back to compatibility mode, if it fails.
try {
_fd_direct_on = detail::open_fd(file_path, flags, true, mode);
} catch (const std::system_error&) {
_compat_mode = true;
} catch (const std::invalid_argument&) {
_compat_mode = true;
}

// Create a cuFile handle, if not in compatibility mode
if (!_compat_mode) {
CUfileDescr_t desc{}; // It is important to set to zero!
desc.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-union-access)
desc.handle.fd = _fd_direct_on;
CUFILE_TRY(cuFileAPI::instance().HandleRegister(&_handle, &desc));
}
}
bool compat_mode = defaults::compat_mode());

/**
* @brief FileHandle support move semantic but isn't copyable
Expand Down Expand Up @@ -274,7 +155,7 @@ class FileHandle {
*
* @return File descriptor
*/
[[nodiscard]] int fd_open_flags() const { return detail::open_flags(_fd_direct_off); }
[[nodiscard]] int fd_open_flags() const;

/**
* @brief Get the file size
Expand All @@ -283,12 +164,7 @@ class FileHandle {
*
* @return The number of bytes
*/
[[nodiscard]] std::size_t nbytes() const
{
if (closed()) { return 0; }
if (_nbytes == 0) { _nbytes = detail::get_file_size(_fd_direct_off); }
return _nbytes;
}
[[nodiscard]] std::size_t nbytes() const;

/**
* @brief Reads specified bytes from the file into the device memory.
Expand Down
158 changes: 158 additions & 0 deletions cpp/src/file_handle.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <fcntl.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <cstddef>
#include <cstdlib>
#include <stdexcept>
#include <system_error>

#include <kvikio/file_handle.hpp>

namespace kvikio {

namespace {

/**
* @brief Parse open file flags given as a string and return oflags
*
* @param flags The flags
* @param o_direct Append O_DIRECT to the open flags
* @return oflags
*
* @throw std::invalid_argument if the specified flags are not supported.
* @throw std::invalid_argument if `o_direct` is true, but `O_DIRECT` is not supported.
*/
int open_fd_parse_flags(const std::string& flags, bool o_direct)
{
int file_flags = -1;
if (flags.empty()) { throw std::invalid_argument("Unknown file open flag"); }
switch (flags[0]) {
case 'r':
file_flags = O_RDONLY;
if (flags[1] == '+') { file_flags = O_RDWR; }
break;
case 'w':
file_flags = O_WRONLY;
if (flags[1] == '+') { file_flags = O_RDWR; }
file_flags |= O_CREAT | O_TRUNC;
break;
case 'a': throw std::invalid_argument("Open flag 'a' isn't supported");
default: throw std::invalid_argument("Unknown file open flag");
}
file_flags |= O_CLOEXEC;
if (o_direct) {
#if defined(O_DIRECT)
file_flags |= O_DIRECT;
#else
throw std::invalid_argument("'o_direct' flag unsupported on this platform");
#endif
}
return file_flags;
}

/**
* @brief Open file using `open(2)`
*
* @param flags Open flags given as a string
* @param o_direct Append O_DIRECT to `flags`
* @param mode Access modes
* @return File descriptor
*/
int open_fd(const std::string& file_path, const std::string& flags, bool o_direct, mode_t mode)
{
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg)
int fd = ::open(file_path.c_str(), open_fd_parse_flags(flags, o_direct), mode);
if (fd == -1) { throw std::system_error(errno, std::generic_category(), "Unable to open file"); }
return fd;
}

/**
* @brief Get the flags of the file descriptor (see `open(2)`)
*
* @return Open flags
*/
[[nodiscard]] int open_flags(int fd)
{
int ret = fcntl(fd, F_GETFL); // NOLINT(cppcoreguidelines-pro-type-vararg)
if (ret == -1) {
throw std::system_error(errno, std::generic_category(), "Unable to retrieve open flags");
}
return ret;
}

/**
* @brief Get file size from file descriptor `fstat(3)`
*
* @param file_descriptor Open file descriptor
* @return The number of bytes
*/
[[nodiscard]] std::size_t get_file_size(int file_descriptor)
{
struct stat st {};
int ret = fstat(file_descriptor, &st);
if (ret == -1) {
throw std::system_error(errno, std::generic_category(), "Unable to query file size");
}
return static_cast<std::size_t>(st.st_size);
}

} // namespace

FileHandle::FileHandle(const std::string& file_path,
const std::string& flags,
mode_t mode,
bool compat_mode)
: _fd_direct_off{open_fd(file_path, flags, false, mode)},
_initialized{true},
_compat_mode{compat_mode}
{
if (_compat_mode) {
return; // Nothing to do in compatibility mode
}

// Try to open the file with the O_DIRECT flag. Fall back to compatibility mode, if it fails.
try {
_fd_direct_on = open_fd(file_path, flags, true, mode);
} catch (const std::system_error&) {
_compat_mode = true;
} catch (const std::invalid_argument&) {
_compat_mode = true;
}

// Create a cuFile handle, if not in compatibility mode
if (!_compat_mode) {
CUfileDescr_t desc{}; // It is important to set to zero!
desc.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-union-access)
desc.handle.fd = _fd_direct_on;
CUFILE_TRY(cuFileAPI::instance().HandleRegister(&_handle, &desc));
}
}

[[nodiscard]] int FileHandle::fd_open_flags() const { return open_flags(_fd_direct_off); }

[[nodiscard]] std::size_t FileHandle::nbytes() const
{
if (closed()) { return 0; }
if (_nbytes == 0) { _nbytes = get_file_size(_fd_direct_off); }
return _nbytes;
}

} // namespace kvikio
2 changes: 0 additions & 2 deletions cpp/src/remote_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
#include <cstring>
#include <iostream>
#include <memory>
#include <optional>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>
Expand Down

0 comments on commit 12ca83b

Please sign in to comment.