From fec9065460d2205f9b9478ccd5f69fa51d2a839e Mon Sep 17 00:00:00 2001 From: Cong Ma Date: Wed, 6 Dec 2023 21:19:36 +0000 Subject: [PATCH] Add comments to explain how to pass alpha value --- library/src/contraction/contraction_selection.cpp | 11 ++++++++++- test/01_contraction/contraction_test.cpp | 15 ++++++++++++--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/library/src/contraction/contraction_selection.cpp b/library/src/contraction/contraction_selection.cpp index 68c748b0..9b0cdf9f 100644 --- a/library/src/contraction/contraction_selection.cpp +++ b/library/src/contraction/contraction_selection.cpp @@ -71,7 +71,16 @@ namespace hiptensor auto sizeE = elementSpaceFromLengthsAndStrides(e_ms_ns_lengths, e_ms_ns_strides) * hipDataTypeSize(typeE); - void * A_d, *B_d, *D_d, *E_d, *wspace; + void *A_d, *B_d, *D_d, *E_d, *wspace; + + /* + * `alpha` and `beta` are void pointer. hiptensor uses readVal to load the value of alpha. + * ``` + * alphaF = hiptensor::readVal( + * alpha, convertToComputeType(HipDataType_v)); + * ``` + * Hence, the `alpha` and `bete` need to point to a ComputeData value + */ double alpha = 0.0; double beta = 0.0; writeVal(&alpha, computeType, 1.02); diff --git a/test/01_contraction/contraction_test.cpp b/test/01_contraction/contraction_test.cpp index ce67278f..76cc3033 100644 --- a/test/01_contraction/contraction_test.cpp +++ b/test/01_contraction/contraction_test.cpp @@ -491,9 +491,18 @@ namespace hiptensor auto CDataType = testType[2]; auto DDataType = testType[3]; - auto computeType = convertToComputeType(testType[4]); - double alphaBuf = 0.; - double betaBuf = 0.; + auto computeType = convertToComputeType(testType[4]); + + /* + * `alpha` and `beta` are void pointer. hiptensor uses readVal to load the value of alpha. + * ``` + * alphaF = hiptensor::readVal( + * alpha, convertToComputeType(HipDataType_v)); + * ``` + * Hence, the `alpha` and `bete` need to point to a ComputeData value + */ + double alphaBuf = 0.; + double betaBuf = 0.; writeVal(&alphaBuf, computeType, alpha); writeVal(&betaBuf, computeType, beta);