Skip to content

Commit

Permalink
Merge pull request #596 from madsbk/bhxx
Browse files Browse the repository at this point in the history
Bhxx fixes
  • Loading branch information
madsbk authored Feb 22, 2019
2 parents 06536af + 25a184f commit cad4f74
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 10 deletions.
6 changes: 6 additions & 0 deletions bridge/cxx/examples/bhxx_add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ If not, see <http://www.gnu.org/licenses/>.
#include <iostream>

#include <bhxx/bhxx.hpp>
#include <bhxx/random.hpp>

using namespace bhxx;

Expand All @@ -41,6 +42,11 @@ void compute()
random123(r, 42, 42);
std::cout << r << std::endl;

a[0] *= -1;
std::cout << "shape: " << a[1].shape() << ", " << bhxx::random.randn<float>({3, 4}).shape() << std::endl;
a[1] *= bhxx::random.randn<float>({3, 4});
std::cout << a << std::endl;

Runtime::instance().flush();
}

Expand Down
13 changes: 7 additions & 6 deletions bridge/cxx/gen_array_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,18 @@ def write_doc_and_decl(op, layout, type_sig, type_map, operator, out_as_operand,
if len(signature) > 1:
decl += ", "
else:
decl = "BhArray<%s> " % out_cpp_type
if compound_assignment:
decl += "&" # compound assignment such as "+=" returns a reference
decl = "void " # compound assignment such as "+=" returns nothing
else:
decl = "BhArray<%s> " % out_cpp_type
decl += "%s(" % func_name

for i, (symbol, t) in signature[1:]:
if symbol == "A":
if not (i == 1 and compound_assignment):
decl += "const "
decl += "BhArray<%s> &in%d" % (type_map[t]['cpp'], i)
if i == 1 and compound_assignment:
decl += "BhArray<%s> in%d" % (type_map[t]['cpp'], i)
else:
decl += "const BhArray<%s> &in%d" % (type_map[t]['cpp'], i)
doc += "* @param in%d Array input.\n" % i
else:
decl += "%s in%d" % (type_map[t]['cpp'], i)
Expand Down Expand Up @@ -210,7 +212,6 @@ def main(args):
for i in range(1, len(type_sig)):
impl += ", in%s" % i
impl += ");\n"
impl += "\treturn in1;\n"
impl += "}\n"
head += "#endif /* DOXYGEN_SHOULD_SKIP_THIS */\n"
impl += "\n\n"
Expand Down
29 changes: 26 additions & 3 deletions bridge/cxx/include/bhxx/array_create.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,42 @@ BhArray <OutType> empty_like(const bhxx::BhArray<InType> &ary) {
return BhArray<OutType>{ary.shape()};
}

/** Return a new array filled with zeros
/** Return a new array filled with `value`
*
* @tparam T The data type of the new array
* @param shape The shape of the new array
* @param value The value to fill the new array with
* @return The new array
*/
template<typename T>
BhArray <T> zeros(Shape shape) {
BhArray <T> full(Shape shape, T value) {
BhArray<T> ret{std::move(shape)};
ret = T{0};
ret = value;
return ret;
}

/** Return a new array filled with zeros
*
* @tparam T The data type of the new array
* @param shape The shape of the new array
* @return The new array
*/
template<typename T>
BhArray <T> zeros(Shape shape) {
return full(std::move(shape), T{0});
}

/** Return a new array filled with ones
*
* @tparam T The data type of the new array
* @param shape The shape of the new array
* @return The new array
*/
template<typename T>
BhArray <T> ones(Shape shape) {
return full(std::move(shape), T{1});
}

/** Return evenly spaced values within a given interval.
*
* @tparam T Data type of the returned array
Expand Down
2 changes: 2 additions & 0 deletions bridge/cxx/include/bhxx/random.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ GNU Lesser General Public License along with Bohrium.
If not, see <http://www.gnu.org/licenses/>.
*/
#pragma once

#include <cstdint>
#include <random>
#include <bhxx/BhArray.hpp>
Expand Down
2 changes: 1 addition & 1 deletion bridge/cxx/src/random.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ template<typename T>
BhArray<T> Random::randn(Shape shape) {
BhArray<T> ary(random.random123(shape.prod()));
T max_value = static_cast<T>(std::numeric_limits<uint64_t>::max());
return ary / max_value;
return (ary / max_value).reshape(shape);
}

// Instantiate API that doesn't support booleans
Expand Down

0 comments on commit cad4f74

Please sign in to comment.