Skip to content

Commit

Permalink
Remove unneeded copy from LK ctor
Browse files Browse the repository at this point in the history
  • Loading branch information
mlxd committed Jan 12, 2024
1 parent 6b1c229 commit b6f8c10
Showing 1 changed file with 54 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,60 +115,6 @@ class StateVectorKokkos final
init_generators_indices_();
};

/**
* @brief Init zeros for the state-vector on device.
*/
void initZeros() { Kokkos::deep_copy(getView(), ComplexT{0.0, 0.0}); }

/**
* @brief Set value for a single element of the state-vector on device.
*
* @param index Index of the target element.
*/
void setBasisState(const size_t index) {
KokkosVector sv_view =
getView(); // circumvent error capturing this with KOKKOS_LAMBDA
Kokkos::parallel_for(
sv_view.size(), KOKKOS_LAMBDA(const size_t i) {
sv_view(i) =
(i == index) ? ComplexT{1.0, 0.0} : ComplexT{0.0, 0.0};
});
}

/**
* @brief Set values for a batch of elements of the state-vector.
*
* @param values Values to be set for the target elements.
* @param indices Indices of the target elements.
*/
void setStateVector(const std::vector<std::size_t> &indices,
const std::vector<ComplexT> &values) {
initZeros();
KokkosSizeTVector d_indices("d_indices", indices.size());
KokkosVector d_values("d_values", values.size());
Kokkos::deep_copy(d_indices, UnmanagedConstSizeTHostView(
indices.data(), indices.size()));
Kokkos::deep_copy(d_values, UnmanagedConstComplexHostView(
values.data(), values.size()));
KokkosVector sv_view =
getView(); // circumvent error capturing this with KOKKOS_LAMBDA
Kokkos::parallel_for(
indices.size(), KOKKOS_LAMBDA(const std::size_t i) {
sv_view(d_indices[i]) = d_values[i];
});
}

/**
* @brief Reset the data back to the \f$\ket{0}\f$ state.
*
* @param num_qubits Number of qubits
*/
void resetStateVector() {
if (this->getLength() > 0) {
setBasisState(0U);
}
}

/**
* @brief Create a new state vector from data on the host.
*
Expand Down Expand Up @@ -200,7 +146,6 @@ class StateVectorKokkos final
: StateVectorKokkos(log2(length), kokkos_args) {
PL_ABORT_IF_NOT(isPerfectPowerOf2(length),
"The size of provided data must be a power of 2.");
std::vector<ComplexT> hostdata_copy(hostdata_, hostdata_ + length);
HostToDevice(hostdata_copy.data(), length);
}

Expand Down Expand Up @@ -244,6 +189,60 @@ class StateVectorKokkos final
}
}

/**
* @brief Init zeros for the state-vector on device.
*/
void initZeros() { Kokkos::deep_copy(getView(), ComplexT{0.0, 0.0}); }

/**
* @brief Set value for a single element of the state-vector on device.
*
* @param index Index of the target element.
*/
void setBasisState(const size_t index) {
KokkosVector sv_view =
getView(); // circumvent error capturing this with KOKKOS_LAMBDA
Kokkos::parallel_for(
sv_view.size(), KOKKOS_LAMBDA(const size_t i) {
sv_view(i) =
(i == index) ? ComplexT{1.0, 0.0} : ComplexT{0.0, 0.0};
});
}

/**
* @brief Set values for a batch of elements of the state-vector.
*
* @param values Values to be set for the target elements.
* @param indices Indices of the target elements.
*/
void setStateVector(const std::vector<std::size_t> &indices,
const std::vector<ComplexT> &values) {
initZeros();
KokkosSizeTVector d_indices("d_indices", indices.size());
KokkosVector d_values("d_values", values.size());
Kokkos::deep_copy(d_indices, UnmanagedConstSizeTHostView(
indices.data(), indices.size()));
Kokkos::deep_copy(d_values, UnmanagedConstComplexHostView(
values.data(), values.size()));
KokkosVector sv_view =
getView(); // circumvent error capturing this with KOKKOS_LAMBDA
Kokkos::parallel_for(
indices.size(), KOKKOS_LAMBDA(const std::size_t i) {
sv_view(d_indices[i]) = d_values[i];
});
}

/**
* @brief Reset the data back to the \f$\ket{0}\f$ state.
*
* @param num_qubits Number of qubits
*/
void resetStateVector() {
if (this->getLength() > 0) {
setBasisState(0U);
}
}

/**
* @brief Apply a single gate to the state vector.
*
Expand Down

0 comments on commit b6f8c10

Please sign in to comment.