Skip to content

Commit

Permalink
Add comments to explain how to pass alpha value
Browse files Browse the repository at this point in the history
  • Loading branch information
CongMa13 committed Dec 8, 2023
1 parent 43f33ee commit fec9065
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
11 changes: 10 additions & 1 deletion library/src/contraction/contraction_selection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,16 @@ namespace hiptensor
auto sizeE = elementSpaceFromLengthsAndStrides(e_ms_ns_lengths, e_ms_ns_strides)
* hipDataTypeSize(typeE);

void * A_d, *B_d, *D_d, *E_d, *wspace;
void *A_d, *B_d, *D_d, *E_d, *wspace;

/*
* `alpha` and `beta` are void pointer. hiptensor uses readVal to load the value of alpha.
* ```
* alphaF = hiptensor::readVal<float>(
* alpha, convertToComputeType(HipDataType_v<typename Traits::ComputeDataT>));
* ```
* Hence, the `alpha` and `bete` need to point to a ComputeData value
*/
double alpha = 0.0;
double beta = 0.0;
writeVal(&alpha, computeType, 1.02);
Expand Down
15 changes: 12 additions & 3 deletions test/01_contraction/contraction_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,18 @@ namespace hiptensor
auto CDataType = testType[2];
auto DDataType = testType[3];

auto computeType = convertToComputeType(testType[4]);
double alphaBuf = 0.;
double betaBuf = 0.;
auto computeType = convertToComputeType(testType[4]);

/*
* `alpha` and `beta` are void pointer. hiptensor uses readVal to load the value of alpha.
* ```
* alphaF = hiptensor::readVal<float>(
* alpha, convertToComputeType(HipDataType_v<typename Traits::ComputeDataT>));
* ```
* Hence, the `alpha` and `bete` need to point to a ComputeData value
*/
double alphaBuf = 0.;
double betaBuf = 0.;
writeVal(&alphaBuf, computeType, alpha);
writeVal(&betaBuf, computeType, beta);

Expand Down

0 comments on commit fec9065

Please sign in to comment.