Skip to content

Commit

Permalink
Merge pull request #259 from dlangbe/ac-6.2
Browse files Browse the repository at this point in the history
Actor-Critic Support
  • Loading branch information
dlangbe authored Aug 15, 2024
2 parents a0b1a10 + 704d9f7 commit 833ee93
Show file tree
Hide file tree
Showing 33 changed files with 1,036 additions and 178 deletions.
2 changes: 1 addition & 1 deletion library/include/hiptensor/hiptensor_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ typedef enum
//! Log selection messages
HIPTENSOR_LOG_LEVEL_HEURISTICS_TRACE = 8,
//! Log a trace of API calls
HIPTENSOR_LOG_LEVEL_API_TRACE = 16
HIPTENSOR_LOG_LEVEL_API_TRACE = 16,

} hiptensorLogLevel_t;

Expand Down
1,015 changes: 889 additions & 126 deletions library/src/contraction/contraction_selection.cpp

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions library/src/contraction/contraction_selection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,19 @@ namespace hiptensor
hipDataType typeA,
std::vector<std::size_t> const& a_ms_ks_lengths,
std::vector<std::size_t> const& a_ms_ks_strides,
std::vector<int32_t> const& a_ms_ks_modes,
hipDataType typeB,
std::vector<std::size_t> const& b_ns_ks_lengths,
std::vector<std::size_t> const& b_ns_ks_strides,
std::vector<int32_t> const& b_ns_ks_modes,
hipDataType typeD,
std::vector<std::size_t> const& d_ms_ns_lengths,
std::vector<std::size_t> const& d_ms_ns_strides,
std::vector<int32_t> const& d_ms_ns_modes,
hipDataType typeE,
std::vector<std::size_t> const& e_ms_ns_lengths,
std::vector<std::size_t> const& e_ms_ns_strides,
std::vector<int32_t> const& e_ms_ns_modes,
const uint64_t workspaceSize);
};

Expand All @@ -88,15 +92,19 @@ namespace hiptensor
hipDataType typeA,
std::vector<std::size_t> const& a_ms_ks_lengths,
std::vector<std::size_t> const& a_ms_ks_strides,
std::vector<int32_t> const& a_ms_ks_modes,
hipDataType typeB,
std::vector<std::size_t> const& b_ns_ks_lengths,
std::vector<std::size_t> const& b_ns_ks_strides,
std::vector<int32_t> const& b_ns_ks_modes,
hipDataType typeD,
std::vector<std::size_t> const& d_ms_ns_lengths,
std::vector<std::size_t> const& d_ms_ns_strides,
std::vector<int32_t> const& d_ms_ns_modes,
hipDataType typeE,
std::vector<std::size_t> const& e_ms_ns_lengths,
std::vector<std::size_t> const& e_ms_ns_strides,
std::vector<int32_t> const& e_ms_ns_modes,
hiptensorComputeType_t computeType,
const uint64_t workspaceSize);

Expand Down
3 changes: 1 addition & 2 deletions library/src/contraction/contraction_solution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,16 +246,15 @@ namespace hiptensor
e_ms_ns_strides,
e_ms_ns_modes,
workspacePtr))

{
return {HIPTENSOR_STATUS_INTERNAL_ERROR, -1.0f};
}

if(this->workspaceSize() > workspaceSize)
{
resetInvokerArgs();
return {HIPTENSOR_STATUS_INSUFFICIENT_WORKSPACE, -1.0f};
}

auto time = mInvokerPtr->Run(mInvokerArgPtr.get(), streamConfig);
resetInvokerArgs();

Expand Down
12 changes: 11 additions & 1 deletion library/src/contraction/contraction_solution_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,12 @@ namespace hiptensor
// Arg test
Base::mValid = deviceOp->IsSupportedArgument(Base::mInvokerArgPtr.get());

return mValid;
if(!Base::mValid)
{
resetArgs();
}

return Base::mValid;
}
};

Expand Down Expand Up @@ -324,6 +329,11 @@ namespace hiptensor
// Arg test
Base::mValid = deviceOp->IsSupportedArgument(Base::mInvokerArgPtr.get());

if(!Base::mValid)
{
resetArgs();
}

return Base::mValid;
}
};
Expand Down
4 changes: 4 additions & 0 deletions library/src/contraction/hiptensor_contraction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -551,15 +551,19 @@ hiptensorStatus_t hiptensorInitContractionPlan(const hiptensorHandle_t*
ADataType,
desc->mTensorDesc[0].mLengths,
desc->mTensorDesc[0].mStrides,
desc->mTensorMode[0],
BDataType,
desc->mTensorDesc[1].mLengths,
desc->mTensorDesc[1].mStrides,
desc->mTensorMode[1],
DDataType,
desc->mTensorDesc[2].mLengths,
desc->mTensorDesc[2].mStrides,
desc->mTensorMode[2],
EDataType,
desc->mTensorDesc[3].mLengths,
desc->mTensorDesc[3].mStrides,
desc->mTensorMode[2],
desc->mComputeType,
workspaceSize);
}
Expand Down
2 changes: 1 addition & 1 deletion library/src/include/logger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ namespace hiptensor
LOG_LEVEL_PERF_TRACE = 2,
LOG_LEVEL_PERF_HINT = 4,
LOG_LEVEL_HEURISTICS_TRACE = 8,
LOG_LEVEL_API_TRACE = 16
LOG_LEVEL_API_TRACE = 16,
};

// For static initialization
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
1 change: 1 addition & 0 deletions test/01_contraction/configs/scale_test_params_rank1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
1 change: 1 addition & 0 deletions test/01_contraction/configs/scale_test_params_rank2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
1 change: 1 addition & 0 deletions test/01_contraction/configs/scale_test_params_rank3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
1 change: 1 addition & 0 deletions test/01_contraction/configs/scale_test_params_rank4.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
1 change: 1 addition & 0 deletions test/01_contraction/configs/scale_test_params_rank5.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
1 change: 1 addition & 0 deletions test/01_contraction/configs/scale_test_params_rank6.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Tensor Data Types:
Algorithm Types:
- HIPTENSOR_ALGO_DEFAULT
- HIPTENSOR_ALGO_DEFAULT_PATIENT
- HIPTENSOR_ALGO_ACTOR_CRITIC
Operators:
- HIPTENSOR_OP_IDENTITY
Worksize Prefs:
Expand Down
Loading

0 comments on commit 833ee93

Please sign in to comment.