Skip to content

Commit

Permalink
Implement a basic CUDA capable shared_ptr class
Browse files Browse the repository at this point in the history
  • Loading branch information
mborland committed Sep 6, 2024
1 parent cf4c289 commit 7cffc95
Showing 1 changed file with 162 additions and 0 deletions.
162 changes: 162 additions & 0 deletions include/boost/math/tools/shared_ptr.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// Copyright (c) 2024 Matt Borland
// Use, modification and distribution are subject to the
// Boost Software License, Version 1.0. (See accompanying file
// LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)

#ifndef BOOST_MATH_TOOLS_VECTOR_HPP
#define BOOST_MATH_TOOLS_VECTOR_HPP

#include <boost/math/tools/config.hpp>
#include <boost/math/tools/cstdint.hpp>
#include <boost/math/tools/type_traits.hpp>

#ifndef BOOST_MATH_ENABLE_CUDA

#include <memory>

namespace boost {
namespace math {

using std::shared_ptr;
using std::make_shared;

} // namespace math
} // namespace boost

#else // CUDA shared pointer

namespace boost {
namespace math {

template <typename T>
class shared_ptr
{
private:
T* ptr;
int* ref_count;

// If valid increment the reference count
void increment_ref_count()
{
if (ref_count != nullptr)
{
*ref_count++;
}
}

// If valid decrement the reference count
// If we hit 0 references destroy the objects so they are not leaked
void decrement_ref_count()
{
if (ref_count != nullptr)
{
*ref_count--;
if (*ref_count == 0)
{
cudaFree(ptr);
cudaFree(ref_count);
}
}
}

public:
// Constructor
explicit shared_ptr(T* p = nullptr) : ptr(p)
{
cudaMalloc(&ref_count, sizeof(int));
if (ref_count != nullptr)
{
*ref_count = 1;
}
}

// Copy Constructor
shared_ptr(const shared_ptr& other) : ptr(other.ptr), ref_count(other.ref_count)
{
increment_ref_count();
}

// Move Constructor
shared_ptr(shared_ptr&& other) noexcept : ptr(other.ptr), ref_count(other.ref_count)
{
other.ptr = nullptr;
other.ref_count = nullptr;
}

// Copy Assignment
shared_ptr& operator=(const shared_ptr& other)
{
if (this != &other)
{
decrement_ref_count();
ptr = other.ptr;
ref_count = other.ref_count;
increment_ref_count();
}

return *this;
}

// Move assignment
shared_ptr* operator=(shared_ptr&& other)
{
if (this != &other)
{
decrement_ref_count();
ptr = other.ptr;
ref_count = other.ref_count;
other.ptr = nullptr;
other.ref_count = nullptr;
}

return *this;
}

// Destructor
~shared_ptr()
{
decrement_ref_count();
}

T& operator*() const
{
return *ptr;
}

T* operator->() const
{
return ptr;
}

T* get() const
{
return ptr;
}

int use_count() const
{
return ref_count != nullptr ? *ref_count : 0;
}

bool unique() const
{
return use_count() == 1;
}

// Reset to new pointer
// Have to malloc new memory for a new reference counter
void reset(T* p = nullptr)
{
decrement_ref_count();
ptr = p;
cudaMalloc(&ref_count, sizeof(int));
*ref_count = 1;
}
}

} // Namespace math
} // Namespace boost

#endif // CUDA vector

#endif // BOOST_MATH_TOOLS_VECTOR_HPP

0 comments on commit 7cffc95

Please sign in to comment.