Added complex number support
This commit is contained in:
parent
36b2720fe3
commit
1c439b4944
98
Array.h
98
Array.h
@ -1,11 +1,13 @@
|
||||
#ifndef ARRAY_H
|
||||
#define ARRAY_H
|
||||
#ifndef CUDATOOLS_ARRAY_H
|
||||
#define CUDATOOLS_ARRAY_H
|
||||
|
||||
#include "Complex.h"
|
||||
#include "Core.h"
|
||||
#include "Macros.h"
|
||||
#include <Eigen/Dense>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
#include <iomanip>
|
||||
#include <math.h>
|
||||
#include <random>
|
||||
#include <type_traits>
|
||||
|
||||
@ -17,18 +19,34 @@
|
||||
|
||||
namespace CudaTools {
|
||||
|
||||
/** Type aliases and lots of metaprogramming definitions, primarily dealing with
|
||||
* the different numeric types and overrides. */
|
||||
|
||||
template <typename T>
|
||||
using EigenMat = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>;
|
||||
template <typename T> using EigenMapMat = Eigen::Map<EigenMat<T>>;
|
||||
template <typename T> using ConstEigenMapMat = Eigen::Map<const EigenMat<T>>;
|
||||
|
||||
template <typename T> struct EigenAdaptConst { typedef EigenMapMat<T> type; };
|
||||
template <typename T> struct EigenAdaptConst<const T> { typedef ConstEigenMapMat<T> type; };
|
||||
template <typename T> struct EigenAdaptConst_S { typedef EigenMapMat<T> type; };
|
||||
template <typename T> struct EigenAdaptConst_S<const T> { typedef ConstEigenMapMat<T> type; };
|
||||
template <typename T> using EigenAdaptConst = typename EigenAdaptConst_S<T>::type;
|
||||
|
||||
#define ENABLE_IF(X) std::enable_if_t<X, bool>
|
||||
#define IS_INT(T) std::is_integral<T>::value
|
||||
#define IS_FLOAT(T) std::is_floating_point<T>::value
|
||||
#define IS_NUM(T) IS_INT(T) or IS_FLOAT(T)
|
||||
template <typename T> struct ComplexUnderlying_S { typedef T type; };
|
||||
template <> struct ComplexUnderlying_S<complex64> { typedef float type; };
|
||||
template <> struct ComplexUnderlying_S<complex128> { typedef double type; };
|
||||
template <typename T> using ComplexUnderlying = typename ComplexUnderlying_S<T>::type;
|
||||
|
||||
template <typename T> struct ComplexConversion_S { typedef T type; };
|
||||
template <> struct ComplexConversion_S<complex64> { typedef std::complex<float> type; };
|
||||
template <> struct ComplexConversion_S<complex128> { typedef std::complex<double> type; };
|
||||
template <typename T> using ComplexConversion = typename ComplexConversion_S<T>::type;
|
||||
|
||||
template <typename T> inline constexpr bool is_int = std::is_integral<T>::value;
|
||||
template <typename T> inline constexpr bool is_float = std::is_floating_point<T>::value;
|
||||
template <typename T>
|
||||
inline constexpr bool is_complex =
|
||||
std::is_same<T, complex64>::value or std::is_same<T, complex128>::value;
|
||||
template <typename T> inline constexpr bool is_num = is_int<T> or is_float<T> or is_complex<T>;
|
||||
|
||||
template <typename T> class Array;
|
||||
using Slice = std::pair<uint32_t, uint32_t>;
|
||||
@ -99,11 +117,11 @@ template <typename T> class ArrayIterator {
|
||||
*/
|
||||
HD void advance(const int32_t amount) {
|
||||
if (amount < 0) {
|
||||
for (uint32_t i = 0; i < abs(amount); ++i) {
|
||||
for (uint32_t i = 0; i < std::abs(amount); ++i) {
|
||||
prev();
|
||||
}
|
||||
} else {
|
||||
for (uint32_t i = 0; i < abs(amount); ++i) {
|
||||
for (uint32_t i = 0; i < std::abs(amount); ++i) {
|
||||
next();
|
||||
}
|
||||
}
|
||||
@ -211,7 +229,7 @@ template <typename T> class Array {
|
||||
pHost = new T[shape.items()];
|
||||
calcEnd();
|
||||
if (noDevice) return;
|
||||
pDevice = (T*)CudaTools::malloc(shape.items() * sizeof(T));
|
||||
pDevice = reinterpret_cast<T*>(CudaTools::malloc(shape.items() * sizeof(T)));
|
||||
};
|
||||
|
||||
/**
|
||||
@ -226,7 +244,7 @@ template <typename T> class Array {
|
||||
calcEnd();
|
||||
#ifndef DEVICE
|
||||
if (noDevice) return;
|
||||
pDevice = (T*)CudaTools::malloc(shape.items() * sizeof(T));
|
||||
pDevice = reinterpret_cast<T*>(CudaTools::malloc(shape.items() * sizeof(T)));
|
||||
#endif
|
||||
};
|
||||
|
||||
@ -492,12 +510,13 @@ template <typename T> class Array {
|
||||
/**
|
||||
* Returns the Eigen::Map of this Array.
|
||||
*/
|
||||
typename EigenAdaptConst<T>::type eigenMap() const {
|
||||
EigenAdaptConst<ComplexConversion<T>> eigenMap() const {
|
||||
uint32_t total_dim = mShape.mAxes;
|
||||
CT_ERROR(mIsSlice, "Mapping to an Eigen array cannot occur on slices")
|
||||
CT_ERROR_IF(total_dim, !=, 2,
|
||||
"Mapping to an Eigen array can only occur on two-dimensional arrays");
|
||||
return typename EigenAdaptConst<T>::type(POINTER, mShape.rows(), mShape.cols());
|
||||
return EigenAdaptConst<ComplexConversion<T>>((ComplexConversion<T>*)POINTER, mShape.rows(),
|
||||
mShape.cols());
|
||||
};
|
||||
|
||||
/**
|
||||
@ -508,7 +527,7 @@ template <typename T> class Array {
|
||||
/**
|
||||
* Gets the pointer to this array, depending on host or device.
|
||||
*/
|
||||
HD T* data() const { return POINTER; };
|
||||
HD ComplexConversion<T>* data() const { return (ComplexConversion<T>*)POINTER; };
|
||||
|
||||
/**
|
||||
* Returns the device pointer regardless of host or device.
|
||||
@ -556,7 +575,7 @@ template <typename T> class Array {
|
||||
* Sets the values of the entire Array to a constant. This is restricted to numerical types.
|
||||
*/
|
||||
HD void setConstant(const T value) const {
|
||||
static_assert(IS_NUM(T), "Function only available on numeric types.");
|
||||
static_assert(is_num<T>, "Function only available on numeric types.");
|
||||
for (auto it = begin(); it != end(); ++it) {
|
||||
*it = value;
|
||||
}
|
||||
@ -568,20 +587,33 @@ template <typename T> class Array {
|
||||
* \brief Host only
|
||||
*/
|
||||
void setRandom(const T min, const T max) const {
|
||||
static_assert(IS_NUM(T), "Function only available on numeric types.");
|
||||
CT_ERROR_IF(max, <, min, "Upper bound of range cannot be larger than lower bound");
|
||||
static_assert(is_num<T>, "Function only available on numeric types.");
|
||||
if constexpr (is_complex<T>) {
|
||||
CT_ERROR_IF(max.real(), <, min.real(),
|
||||
"Upper bound of range cannot be larger than lower bound");
|
||||
CT_ERROR_IF(max.imag(), <, min.imag(),
|
||||
"Upper bound of range cannot be larger than lower bound");
|
||||
} else {
|
||||
CT_ERROR_IF(max, <, min, "Upper bound of range cannot be larger than lower bound");
|
||||
}
|
||||
std::random_device rd;
|
||||
std::mt19937 mt(rd());
|
||||
if constexpr (IS_INT(T)) {
|
||||
if constexpr (is_int<T>) {
|
||||
std::uniform_int_distribution<T> dist(min, max);
|
||||
for (auto it = begin(); it != end(); ++it) {
|
||||
*it = dist(mt);
|
||||
}
|
||||
} else if constexpr (IS_FLOAT(T)) {
|
||||
} else if constexpr (is_float<T>) {
|
||||
std::uniform_real_distribution<T> dist(min, max);
|
||||
for (auto it = begin(); it != end(); ++it) {
|
||||
*it = dist(mt);
|
||||
}
|
||||
} else if constexpr (is_complex<T>) {
|
||||
std::uniform_real_distribution<ComplexUnderlying<T>> distr(min.real(), max.real());
|
||||
std::uniform_real_distribution<ComplexUnderlying<T>> disti(min.imag(), max.imag());
|
||||
for (auto it = begin(); it != end(); ++it) {
|
||||
*it = T(distr(mt), disti(mt));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -590,7 +622,7 @@ template <typename T> class Array {
|
||||
* restricted to numerical types.
|
||||
*/
|
||||
HD void setRange(T min, const T step = 1) const {
|
||||
static_assert(IS_NUM(T), "Function only available on numeric types.");
|
||||
static_assert(is_num<T>, "Function only available on numeric types.");
|
||||
for (auto it = begin(); it != end(); ++it) {
|
||||
*it = min;
|
||||
min += step;
|
||||
@ -601,7 +633,7 @@ template <typename T> class Array {
|
||||
* to floating point types.
|
||||
*/
|
||||
HD void setLinspace(const T min, const T max) const {
|
||||
static_assert(IS_FLOAT(T), "Function only available on numeric floating types.");
|
||||
static_assert(is_float<T>, "Function only available on numeric floating types.");
|
||||
CT_ERROR_IF(max, <, min, "Upper bound of range cannot be larger than lower bound");
|
||||
T i = 0;
|
||||
T d = max - min;
|
||||
@ -617,7 +649,7 @@ template <typename T> class Array {
|
||||
* \brief Host only
|
||||
*/
|
||||
static Array constant(const Shape& shape, const T value) {
|
||||
static_assert(IS_NUM(T), "Function only available on numeric types.");
|
||||
static_assert(is_num<T>, "Function only available on numeric types.");
|
||||
Array<T> arr(shape);
|
||||
arr.setConstant(value);
|
||||
return arr;
|
||||
@ -629,7 +661,7 @@ template <typename T> class Array {
|
||||
* \brief Host only
|
||||
*/
|
||||
static Array random(const Shape& shape, const T min, const T max) {
|
||||
static_assert(IS_NUM(T), "Function only available on numeric types.");
|
||||
static_assert(is_num<T>, "Function only available on numeric types.");
|
||||
Array<T> arr(shape);
|
||||
arr.setRandom(min, max);
|
||||
return arr;
|
||||
@ -640,7 +672,7 @@ template <typename T> class Array {
|
||||
* \brief Host only
|
||||
*/
|
||||
static Array range(const T min, const T max, const T step = 1) {
|
||||
static_assert(IS_NUM(T), "Function only available on numeric types.");
|
||||
static_assert(is_num<T>, "Function only available on numeric types.");
|
||||
CT_ERROR_IF(max, <, min, "Upper bound of range cannot be larger than lower bound");
|
||||
Array<T> arr({(uint32_t)((max - min) / step)});
|
||||
arr.setRange(min, step);
|
||||
@ -653,7 +685,7 @@ template <typename T> class Array {
|
||||
* \brief Host only
|
||||
*/
|
||||
static Array linspace(const T min, const T max, const uint32_t size) {
|
||||
static_assert(IS_FLOAT(T), "Function only available on numeric floating types.");
|
||||
static_assert(is_float<T>, "Function only available on numeric floating types.");
|
||||
Array<T> arr({size});
|
||||
arr.setLinspace(min, max);
|
||||
return arr;
|
||||
@ -665,7 +697,7 @@ template <typename T> class Array {
|
||||
* \brief Host only
|
||||
*/
|
||||
Array transposed() const {
|
||||
static_assert(IS_NUM(T), "Function only available on numeric types.");
|
||||
static_assert(is_num<T>, "Function only available on numeric types.");
|
||||
CT_ERROR_IF(shape().axes(), !=, 2, "Tranpose can only occur on two-dimensional arrays");
|
||||
Array<T> new_arr({mShape.rows(), mShape.cols()});
|
||||
new_arr.eigenMap() = this->eigenMap().transpose().eval();
|
||||
@ -678,7 +710,7 @@ template <typename T> class Array {
|
||||
* \brief Host only
|
||||
*/
|
||||
void transpose() {
|
||||
static_assert(IS_NUM(T), "Function only available on numeric types.");
|
||||
static_assert(is_num<T>, "Function only available on numeric types.");
|
||||
CT_ERROR_IF(shape().axes(), !=, 2, "Tranpose can only occur on two-dimensional arrays");
|
||||
Array<T> new_arr(*this, {mShape.cols(), mShape.rows()});
|
||||
new_arr.eigenMap() = this->eigenMap().transpose().eval();
|
||||
@ -686,7 +718,7 @@ template <typename T> class Array {
|
||||
};
|
||||
|
||||
void inverse() const {
|
||||
static_assert(IS_FLOAT(T), "Function only available on floating numeric types.");
|
||||
static_assert(is_float<T>, "Function only available on floating numeric types.");
|
||||
CT_ERROR_IF(shape().axes(), !=, 2, "Inverse can only occur on two-dimensional arrays");
|
||||
CT_ERROR_IF(shape().rows(), !=, shape().cols(),
|
||||
"Inverse can only occur on square matrices");
|
||||
@ -736,7 +768,7 @@ void printAxis(std::ostream& out, const Array<T>& arr, const uint32_t axis, size
|
||||
} else {
|
||||
out << std::setw((i == 0) ? width - 1 : width);
|
||||
}
|
||||
out << (T)arr[i] << ((i == arr.shape().items() - 1) ? "]" : ",");
|
||||
out << static_cast<T>(arr[i]) << ((i == arr.shape().items() - 1) ? "]" : ",");
|
||||
}
|
||||
} else if (arr.shape().axes() == 2) {
|
||||
for (uint32_t i = 0; i < arr.shape().dim(0); ++i) {
|
||||
@ -756,7 +788,7 @@ void printAxis(std::ostream& out, const Array<T>& arr, const uint32_t axis, size
|
||||
|
||||
template <typename T> std::ostream& operator<<(std::ostream& out, const Array<T>& arr) {
|
||||
size_t width = 0;
|
||||
if constexpr (IS_NUM(T)) {
|
||||
if constexpr (is_num<T>) {
|
||||
T max_val = 0;
|
||||
bool negative = false;
|
||||
for (auto it = arr.begin(); it != arr.end(); ++it) {
|
||||
@ -765,7 +797,7 @@ template <typename T> std::ostream& operator<<(std::ostream& out, const Array<T>
|
||||
}
|
||||
width = std::to_string(max_val).size() + 1;
|
||||
width += (negative) ? 1 : 0;
|
||||
} else if constexpr (IS_FLOAT(T)) {
|
||||
} else if constexpr (is_float<T>) {
|
||||
T max_val = 0;
|
||||
bool negative = false;
|
||||
for (auto it = arr.begin(); it != arr.end(); ++it) {
|
||||
|
||||
132
BLAS.h
132
BLAS.h
@ -1,10 +1,15 @@
|
||||
#ifndef BLAS_H
|
||||
#define BLAS_H
|
||||
#ifndef CUDATOOLS_BLAS_H
|
||||
#define CUDATOOLS_BLAS_H
|
||||
|
||||
#include "Array.h"
|
||||
#include "Complex.h"
|
||||
#include "Core.h"
|
||||
#include "Macros.h"
|
||||
|
||||
#ifdef CUDACC
|
||||
#include <cuComplex.h>
|
||||
#endif
|
||||
|
||||
namespace CudaTools {
|
||||
|
||||
namespace BLAS {
|
||||
@ -186,12 +191,29 @@ template <typename T> class Batch {
|
||||
// cuBLAS API //
|
||||
////////////////
|
||||
|
||||
template <typename T, typename F1, typename F2, typename... Args>
|
||||
constexpr void invoke(F1 f1, F2 f2, Args&&... args) {
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
template <typename T> struct CudaComplexConversion_S { typedef T type; };
|
||||
#ifdef CUDACC
|
||||
template <> struct CudaComplexConversion_S<complex64> { typedef cuComplex type; };
|
||||
template <> struct CudaComplexConversion_S<complex128> { typedef cuDoubleComplex type; };
|
||||
#endif
|
||||
|
||||
template <typename T> using CudaComplexConversion = typename CudaComplexConversion_S<T>::type;
|
||||
|
||||
// Shorthands to reduce clutter.
|
||||
|
||||
#define CAST(var) reinterpret_cast<CudaComplexConversion<T>*>(var)
|
||||
#define DCAST(var) reinterpret_cast<CudaComplexConversion<T>**>(var)
|
||||
|
||||
template <typename T, typename F1, typename F2, typename F3, typename F4, typename... Args>
|
||||
constexpr void invoke(F1 f1, F2 f2, F3 f3, F4 f4, Args&&... args) {
|
||||
if constexpr (std::is_same<T, real32>::value) {
|
||||
CUBLAS_CHECK(f1(args...));
|
||||
} else if constexpr (std::is_same<T, double>::value) {
|
||||
} else if constexpr (std::is_same<T, real64>::value) {
|
||||
CUBLAS_CHECK(f2(args...));
|
||||
} else if constexpr (std::is_same<T, complex64>::value) {
|
||||
CUBLAS_CHECK(f3(args...));
|
||||
} else if constexpr (std::is_same<T, complex128>::value) {
|
||||
CUBLAS_CHECK(f4(args...));
|
||||
} else {
|
||||
CT_ERROR(true, "BLAS functions are not callable with that type");
|
||||
}
|
||||
@ -216,14 +238,16 @@ StreamID GEMV(const T alpha, const Array<T>& A, const Array<T>& x, const T beta,
|
||||
CUBLAS_CHECK(
|
||||
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id)));
|
||||
if (bi.size == 1) {
|
||||
invoke<T>(cublasSgemv, cublasDgemv, Manager::get()->cublasHandle(), CUBLAS_OP_N, rows, cols,
|
||||
&a, A.dataDevice(), rows, x.dataDevice(), 1, &b, y.dataDevice(), 1);
|
||||
invoke<T>(cublasSgemv, cublasDgemv, cublasCgemv, cublasZgemv,
|
||||
Manager::get()->cublasHandle(), CUBLAS_OP_N, rows, cols, CAST(&a),
|
||||
CAST(A.dataDevice()), rows, CAST(x.dataDevice()), 1, CAST(&b),
|
||||
CAST(y.dataDevice()), 1);
|
||||
|
||||
} else { // Greater than 2, so broadcast.
|
||||
invoke<T>(cublasSgemvStridedBatched, cublasDgemvStridedBatched,
|
||||
Manager::get()->cublasHandle(), CUBLAS_OP_N, rows, cols, &a, A.dataDevice(), rows,
|
||||
bi.strideA, x.dataDevice(), 1, bi.strideB, &b, y.dataDevice(), 1, bi.strideC,
|
||||
bi.size);
|
||||
invoke<T>(cublasSgemvStridedBatched, cublasDgemvStridedBatched, cublasCgemvStridedBatched,
|
||||
cublasZgemvStridedBatched, Manager::get()->cublasHandle(), CUBLAS_OP_N, rows,
|
||||
cols, CAST(&a), CAST(A.dataDevice()), rows, bi.strideA, CAST(x.dataDevice()), 1,
|
||||
bi.strideB, CAST(&b), CAST(y.dataDevice()), 1, bi.strideC, bi.size);
|
||||
}
|
||||
|
||||
#else
|
||||
@ -261,15 +285,17 @@ StreamID GEMM(const T alpha, const Array<T>& A, const Array<T>& B, const T beta,
|
||||
CUBLAS_CHECK(
|
||||
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id)));
|
||||
if (bi.size == 1) {
|
||||
invoke<T>(cublasSgemm, cublasDgemm, Manager::get()->cublasHandle(), CUBLAS_OP_N,
|
||||
CUBLAS_OP_N, m, n, k, &a, A.dataDevice(), m, B.dataDevice(), k, &b,
|
||||
C.dataDevice(), m);
|
||||
invoke<T>(cublasSgemm, cublasDgemm, cublasCgemm, cublasZgemm,
|
||||
Manager::get()->cublasHandle(), CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, CAST(&a),
|
||||
CAST(A.dataDevice()), m, CAST(B.dataDevice()), k, CAST(&b), CAST(C.dataDevice()),
|
||||
m);
|
||||
|
||||
} else { // Greater than 2, so broadcast.
|
||||
invoke<T>(cublasSgemmStridedBatched, cublasDgemmStridedBatched,
|
||||
Manager::get()->cublasHandle(), CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &a,
|
||||
A.dataDevice(), m, bi.strideA, B.dataDevice(), k, bi.strideB, &b, C.dataDevice(),
|
||||
m, bi.strideC, bi.size);
|
||||
invoke<T>(cublasSgemmStridedBatched, cublasDgemmStridedBatched, cublasCgemmStridedBatched,
|
||||
cublasZgemmStridedBatched, Manager::get()->cublasHandle(), CUBLAS_OP_N,
|
||||
CUBLAS_OP_N, m, n, k, CAST(&a), CAST(A.dataDevice()), m, bi.strideA,
|
||||
CAST(B.dataDevice()), k, bi.strideB, CAST(&b), CAST(C.dataDevice()), m,
|
||||
bi.strideC, bi.size);
|
||||
}
|
||||
|
||||
#else
|
||||
@ -314,8 +340,9 @@ StreamID DGMM(const Array<T>& A, const Array<T>& X, const Array<T>& C, const boo
|
||||
auto mode = (left) ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
|
||||
CUBLAS_CHECK(
|
||||
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id)));
|
||||
invoke<T>(cublasSdgmm, cublasDdgmm, Manager::get()->cublasHandle(), m, n, A.dataDevice(),
|
||||
A.shape().rows(), X.dataDevice(), 1, C.dataDevice(), m);
|
||||
invoke<T>(cublasSdgmm, cublasDdgmm, cublasCdgmm, cublasZdgmm, Manager::get()->cublasHandle(), m,
|
||||
n, CAST(A.dataDevice()), A.shape().rows(), CAST(X.dataDevice()), 1,
|
||||
CAST(C.dataDevice()), m);
|
||||
#else
|
||||
if (left) {
|
||||
C.eigenMap() = X.eigenMap().asDiagonal() * A.eigenMap();
|
||||
@ -341,13 +368,14 @@ template <typename T> static Array<T> empty({1, 1});
|
||||
template <typename T> static EigenMapMat<T> empty_map = empty<T>.eigenMap();
|
||||
}; // namespace internal
|
||||
|
||||
template <typename T, ENABLE_IF(IS_FLOAT(T)) = true> class PLUArray;
|
||||
template <typename T, std::enable_if_t<is_float<T> or is_complex<T>, bool> = true> class PLUArray;
|
||||
// This is a wrapper class for Eigen's class so we have more controlled access to
|
||||
// the underlying data.
|
||||
template <typename T> class PartialPivLU : public Eigen::PartialPivLU<Eigen::Ref<EigenMat<T>>> {
|
||||
private:
|
||||
using Base = Eigen::PartialPivLU<Eigen::Ref<EigenMat<T>>>;
|
||||
template <typename U, ENABLE_IF(IS_FLOAT(U))> friend class PLUArray;
|
||||
template <typename U, std::enable_if_t<is_float<U> or is_complex<U>, bool>>
|
||||
friend class PLUArray;
|
||||
|
||||
EigenMapMat<T> mMapLU;
|
||||
EigenMapMat<int32_t> mMapPivots;
|
||||
@ -382,7 +410,7 @@ template <typename T> static PartialPivLU<T> BlankPPLU = PartialPivLU<T>();
|
||||
/**
|
||||
* Class for storing the PLU decomposition of an Array. This is restricted to floating point and complex types.
|
||||
*/
|
||||
template <typename T, ENABLE_IF(IS_FLOAT(T))> class PLUArray {
|
||||
template <typename T, std::enable_if_t<is_float<T> or is_complex<T>, bool>> class PLUArray {
|
||||
private:
|
||||
Array<T> mLU;
|
||||
Array<int32_t> mPivots;
|
||||
@ -443,7 +471,7 @@ template <typename T, ENABLE_IF(IS_FLOAT(T))> class PLUArray {
|
||||
* This is a batch version of PLUArray, to enable usage of the cuBLAS API. This is restricted to
|
||||
* floating point and complex types.
|
||||
*/
|
||||
template <typename T, std::enable_if_t<std::is_floating_point<T>::value, bool> = true>
|
||||
template <typename T, std::enable_if_t<is_float<T> or is_complex<T>, bool> = true>
|
||||
class PLUBatch : public Batch<T> {
|
||||
private:
|
||||
Array<int32_t> mPivotsBatch;
|
||||
@ -487,9 +515,10 @@ class PLUBatch : public Batch<T> {
|
||||
uint32_t n = this->mShape.rows();
|
||||
CUBLAS_CHECK(
|
||||
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id)));
|
||||
invoke<T>(cublasSgetrfBatched, cublasDgetrfBatched, Manager::get()->cublasHandle(), n,
|
||||
this->mBatch.dataDevice(), n, mPivotsBatch.dataDevice(), mInfoLU.dataDevice(),
|
||||
this->mBatchSize);
|
||||
invoke<T>(cublasSgetrfBatched, cublasDgetrfBatched, cublasCgetrfBatched,
|
||||
cublasZgetrfBatched, Manager::get()->cublasHandle(), n,
|
||||
DCAST(this->mBatch.dataDevice()), n, mPivotsBatch.dataDevice(),
|
||||
mInfoLU.dataDevice(), this->mBatchSize);
|
||||
|
||||
#else
|
||||
#pragma omp parallel for
|
||||
@ -518,9 +547,10 @@ class PLUBatch : public Batch<T> {
|
||||
uint32_t nrhs = b.shape().cols();
|
||||
CUBLAS_CHECK(
|
||||
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id)));
|
||||
invoke<T>(cublasSgetrsBatched, cublasDgetrsBatched, Manager::get()->cublasHandle(),
|
||||
CUBLAS_OP_N, n, nrhs, this->mBatch.dataDevice(), n, mPivotsBatch.dataDevice(),
|
||||
b.batch().dataDevice(), n, &mInfoSolve, this->mBatchSize);
|
||||
invoke<T>(cublasSgetrsBatched, cublasDgetrsBatched, cublasCgetrsBatched,
|
||||
cublasZgetrsBatched, Manager::get()->cublasHandle(), CUBLAS_OP_N, n, nrhs,
|
||||
DCAST(this->mBatch.dataDevice()), n, mPivotsBatch.dataDevice(),
|
||||
DCAST(b.batch().dataDevice()), n, &mInfoSolve, this->mBatchSize);
|
||||
|
||||
#else
|
||||
#pragma omp parallel for
|
||||
@ -554,46 +584,6 @@ class PLUBatch : public Batch<T> {
|
||||
int32_t validSolve() const { return mInfoSolve == 0; }
|
||||
};
|
||||
|
||||
// /**
|
||||
// * Gets the inverse of each A[i], using an already PLU factorized A[i].
|
||||
// * Only available if compiling with CUDA.
|
||||
// */
|
||||
// template <typename T>
|
||||
// void inverseBatch(const Array<T*>& batchA, const Array<T*>& batchC, const Array<int>&
|
||||
// pivots,
|
||||
// const Array<int>& info, const Shape shapeA, const Shape shapeC,
|
||||
// const uint stream = 0) {
|
||||
// #ifdef CUDA
|
||||
// CT_ERROR_IF(shapeA.rows(), !=, shapeA.cols(),
|
||||
// "'A' needs to be square, rows() and column need to match.");
|
||||
// CT_ERROR_IF(shapeA.rows(), !=, shapeC.cols(), "'A' needs to be the same shape as
|
||||
// 'C'."); CT_ERROR_IF(shapeA.rows(), !=, shapeC.rows(), "'A' needs to be the same shape
|
||||
// as 'C'.");
|
||||
|
||||
// CT_ERROR_IF(shapeA.rows(), !=, pivots.shape().rows(),
|
||||
// "Rows()/columns of 'A' and rows() of pivots need to match.");
|
||||
// CT_ERROR_IF(batchA.shape().rows(), !=, pivots.shape().cols(),
|
||||
// "Batch size and columns of pivots need to match.");
|
||||
// CT_ERROR_IF(info.shape().cols(), !=, 1, "Info needs to be a column vector.")
|
||||
// CT_ERROR_IF(batchA.shape().rows(), !=, info.shape().rows(),
|
||||
// "Batch size and length of info need to match.");
|
||||
// CT_ERROR_IF(batchA.shape().rows(), !=, batchC.shape().rows(),
|
||||
// "Batches 'A[i]' and 'C[i]' need to match.");
|
||||
|
||||
// std::string s = "cublas" + std::to_string(stream);
|
||||
// CUBLAS_CHECK(
|
||||
// cublasSetStream(Manager::get()->cublasHandle(),
|
||||
// Manager::get()->stream(s)));
|
||||
// invoke<T>(cublasSgetriBatched, cublasDgetriBatched,
|
||||
// Manager::get()->cublasHandle(),
|
||||
// shapeA.rows(), batchA.dataDevice(), shapeA.rows(), pivots.dataDevice(),
|
||||
// batchC.dataDevice(), shapeC.rows(), info.dataDevice(),
|
||||
// batchA.shape().rows());
|
||||
// #else
|
||||
// CT_ERROR_IF(true, ==, true, "inverseBatch is not callable without CUDA.");
|
||||
// #endif
|
||||
// }
|
||||
|
||||
}; // namespace BLAS
|
||||
}; // namespace CudaTools
|
||||
|
||||
|
||||
125
Complex.h
Normal file
125
Complex.h
Normal file
@ -0,0 +1,125 @@
|
||||
#ifndef CUDATOOLS_COMPLEX_H
|
||||
#define CUDATOOLS_COMPLEX_H
|
||||
|
||||
#include "Macros.h"
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
|
||||
/**
|
||||
* This is directly adapted from cuComplex.h, except placed into a C++ friendly format.
|
||||
*/
|
||||
|
||||
namespace CudaTools {
|
||||
|
||||
/**
 * Minimal complex number type storing a real and an imaginary part of type T.
 *
 * Every member is tagged with the project's HD macro (defined in Macros.h).
 * Arithmetic follows the scaled formulations of cuComplex.h so that
 * intermediate products are less likely to overflow or underflow.
 *
 * @tparam T Underlying floating point type of each component.
 */
template <typename T> class complex {
  private:
    T r = 0; /**< Real part. */
    T i = 0; /**< Imaginary part. */

  public:
    /** Constructs the value 0 + 0i. */
    HD complex() = default;
    /** Constructs real + imag * i. */
    HD complex(T real, T imag) : r(real), i(imag) {}
    /** Constructs a purely real value; intentionally implicit so scalars promote. */
    HD complex(T x) : r(x), i(0) {}

    /** Component-wise sum. */
    HD complex<T> operator+(const complex<T> z) const { return complex(r + z.r, i + z.i); }
    /** Component-wise difference. */
    HD complex<T> operator-(const complex<T> z) const { return complex(r - z.r, i - z.i); }
    /** Scales both components by the real scalar y. */
    HD complex<T> operator*(const T y) const { return complex(r * y, i * y); }
    /** Divides both components by the real scalar y. */
    HD complex<T> operator/(const T y) const { return complex(r / y, i / y); }

    /** Complex product. */
    HD complex<T> operator*(const complex<T> z) const {
        return complex(r * z.r - i * z.i, r * z.i + i * z.r);
    }

    /** Complex quotient; shares its implementation with operator/=. */
    HD complex<T> operator/(const complex<T> z) const {
        complex<T> quotient(*this);
        quotient /= z;
        return quotient;
    }

    /** In-place component-wise sum. */
    HD void operator+=(const complex<T> z) {
        r += z.r;
        i += z.i;
    }

    /** In-place component-wise difference. */
    HD void operator-=(const complex<T> z) {
        r -= z.r;
        i -= z.i;
    }

    /** In-place scaling by the real scalar y. */
    HD void operator*=(const T y) {
        r *= y;
        i *= y;
    }

    /** In-place division by the real scalar y. */
    HD void operator/=(const T y) {
        r /= y;
        i /= y;
    }

    /** In-place complex product. */
    HD void operator*=(const complex<T> z) {
        // Both results depend on the old r and i, so compute them before storing.
        const T re = r * z.r - i * z.i;
        const T im = r * z.i + i * z.r;
        r = re;
        i = im;
    }

    /** In-place complex quotient. */
    HD void operator/=(const complex<T> z) {
        // Pre-scale numerator and denominator by 1 / (|z.r| + |z.i|) so the
        // squared terms below stay in a representable range.
        T s = std::abs(z.r) + std::abs(z.i);
        T oos = T(1) / s;
        const T ars = r * oos, ais = i * oos, brs = z.r * oos, bis = z.i * oos;
        s = (brs * brs) + (bis * bis);
        oos = T(1) / s;
        r = (ars * brs + ais * bis) * oos;
        i = (ais * brs - ars * bis) * oos;
    }

    /**
     * Returns the modulus sqrt(r^2 + i^2).
     *
     * The larger component is factored out before squaring so the intermediate
     * does not overflow. If both components are zero, or either is infinite,
     * the scaled formula would yield NaN, so |r| + |i| is returned instead.
     */
    HD T abs() const {
        const T a = std::abs(r), b = std::abs(i);
        const T v = (a > b) ? a : b; // larger magnitude (b when the comparison is unordered)
        const T w = (a > b) ? b : a; // smaller magnitude
        // Detect infinities through the type's own classification rather than a
        // float-only limit, so large finite doubles take the regular path.
        if ((v == T(0)) || std::isinf(v) || std::isinf(w)) {
            return v + w;
        }
        const T t = w / v;
        return v * std::sqrt(T(1) + t * t);
    }

    /** Returns the complex conjugate r - i * i. */
    HD complex<T> conj() const { return complex(r, -i); }

    /** Returns the real part. */
    HD T real() const { return r; }
    /** Returns the imaginary part. */
    HD T imag() const { return i; }
};
|
||||
|
||||
template class complex<real32>;
|
||||
template class complex<real64>;
|
||||
|
||||
/** Real scalar times complex: multiplication commutes, so forward to complex<T>::operator*(T). */
template <class T> complex<T> operator*(const T y, const complex<T> z) { return z * y; }
/** Real scalar divided by complex: division does not commute, so promote y and divide by z. */
template <class T> complex<T> operator/(const T y, const complex<T> z) { return complex<T>(y) / z; }
|
||||
|
||||
template complex<real32> operator*<real32>(const real32, const complex<real32>);
|
||||
template complex<real64> operator*<real64>(const real64, const complex<real64>);
|
||||
template complex<real32> operator/<real32>(const real32, const complex<real32>);
|
||||
template complex<real64> operator/<real64>(const real64, const complex<real64>);
|
||||
|
||||
}; // namespace CudaTools
|
||||
|
||||
#ifdef CUDA
|
||||
using complex64 = CudaTools::complex<real32>;
|
||||
using complex128 = CudaTools::complex<real64>;
|
||||
#else
|
||||
using complex64 = std::complex<real32>; /**< Type alias for 64-bit complex floating point datatype.
|
||||
* This adapts depending on the CUDA compilation flag, and
|
||||
* will automatically switch to CudaTools::complex<real32>. */
|
||||
using complex128 =
|
||||
std::complex<real64>; /**< Type alias for 128-bit complex floating point datatype. This adapts
|
||||
* depending on the CUDA compilation flag, and will automatically switch to
|
||||
* CudaTools::complex<real64>. */
|
||||
#endif
|
||||
|
||||
#endif
|
||||
3
Macros.h
3
Macros.h
@ -9,6 +9,9 @@
|
||||
#define CUDACC
|
||||
#endif
|
||||
|
||||
using real32 = float; /**< Type alias for 32-bit floating point datatype. */
|
||||
using real64 = double; /**< Type alias for 64-bit floating point datatype. */
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ > 0)
|
||||
#define DEVICE
|
||||
#endif
|
||||
|
||||
95
Makefile.template
Normal file
95
Makefile.template
Normal file
@ -0,0 +1,95 @@
|
||||
CC := g++-10
|
||||
NVCC := nvcc
|
||||
CFLAGS := -Wall -std=c++17 -fopenmp -MMD
|
||||
NVCC_FLAGS := -MMD -w -Xcompiler
|
||||
|
||||
INCLUDE := <<Put extra include directories here, separated by a space>>
|
||||
LIBS_DIR := <<Put library directories here, separated by a space>>
|
||||
LIBS_DIR_GPU := /usr/local/cuda/lib64 <<Put extra include GPU library directories here, separated by a space>>
|
||||
LIBS := <<Put the names of the libraries here, separated by a space>>
|
||||
LIBS_GPU := cuda cudart cublas <<Put extra GPU libraries here, separated by a space>>
|
||||
|
||||
TARGET = <<Put the name of your target here>>
|
||||
SRC_DIR = .
|
||||
BUILD_DIR = build
|
||||
|
||||
# Should not need to modify below.
|
||||
|
||||
CPU_BUILD_DIR = $(BUILD_DIR)/cpu
|
||||
GPU_BUILD_DIR = $(BUILD_DIR)/gpu
|
||||
|
||||
SRC = $(wildcard $(SRC_DIR)/*/*.cpp) $(wildcard $(SRC_DIR)/*.cpp)
|
||||
|
||||
# Get source files and object files.
|
||||
GCC_SRC = $(filter-out %.cu.cpp ,$(SRC))
|
||||
NVCC_SRC = $(filter %.cu.cpp, $(SRC))
|
||||
GCC_OBJ = $(GCC_SRC:$(SRC_DIR)/%.cpp=%.o)
|
||||
NVCC_OBJ = $(NVCC_SRC:$(SRC_DIR)/%.cpp=%.o)
|
||||
|
||||
# If compiling for CPU, all go to GCC. Otherwise, they are split.
|
||||
CPU_OBJ = $(addprefix $(CPU_BUILD_DIR)/,$(GCC_OBJ)) $(addprefix $(CPU_BUILD_DIR)/,$(NVCC_OBJ))
|
||||
GPU_GCC_OBJ = $(addprefix $(GPU_BUILD_DIR)/,$(GCC_OBJ))
|
||||
GPU_NVCC_OBJ = $(addprefix $(GPU_BUILD_DIR)/,$(NVCC_OBJ))
|
||||
|
||||
# $(info $$GCC_SRC is [${GCC_SRC}])
|
||||
# $(info $$NVCC_SRC is [${NVCC_SRC}])
|
||||
# $(info $$GCC_OBJ is [${GCC_OBJ}])
|
||||
# $(info $$NVCC_OBJ is [${NVCC_OBJ}])
|
||||
|
||||
# $(info $$CPU_OBJ is [${CPU_OBJ}])
|
||||
# $(info $$GPU_GCC_OBJ is [${GPU_GCC_OBJ}])
|
||||
# $(info $$GPU_NVCC_OBJ is [${GPU_NVCC_OBJ}])
|
||||
|
||||
HEADER = $(wildcard $(SRC_DIR)/*/*.h) $(wildcard $(SRC_DIR)/*.h)
|
||||
CPU_DEPS = $(wildcard $(CPU_BUILD_DIR)/*.d)
|
||||
GPU_DEPS = $(wildcard $(GPU_BUILD_DIR)/*.d)
|
||||
|
||||
INC := $(INCLUDE:%=-I%)
|
||||
LIB := $(LIBS_DIR:%=-L%)
|
||||
LIB_GPU := $(LIBS_DIR_GPU:%=-L%)
|
||||
LD := $(LIBS:%=-l%)
|
||||
LD_GPU := $(LIBS_GPU:%=-l%)
|
||||
|
||||
# Reminder:
|
||||
# $< = first prerequisite
|
||||
# $@ = the target which matched the rule
|
||||
# $^ = all prerequisites
|
||||
|
||||
# Targets that name actions rather than files; cpu and gpu are aliases for the real executables.
.PHONY: all cpu gpu clean
|
||||
|
||||
all : cpu gpu
|
||||
|
||||
cpu: $(TARGET)CPU
|
||||
gpu: $(TARGET)GPU
|
||||
|
||||
# Link the CPU executable. $(LD) carries the -l flags generated from LIBS;
# $(LDFLAGS) is kept so flags supplied from the environment or command line still apply.
$(TARGET)CPU: $(CPU_OBJ)
	$(CC) $(CFLAGS) $^ -o $@ $(INC) $(LIB) $(LD) $(LDFLAGS)
|
||||
|
||||
$(CPU_BUILD_DIR)/%.o $(CPU_BUILD_DIR)/%.cu.o: $(SRC_DIR)/%.cpp | $(CPU_BUILD_DIR)
|
||||
$(CC) $(CFLAGS) -c -o $@ $< $(INC)
|
||||
|
||||
# For GPU, we need to build the NVCC objects, the NVCC linked object, and the
|
||||
# regular ones. Then, we link them all together.
|
||||
$(TARGET)GPU: $(GPU_BUILD_DIR)/link.o $(GPU_GCC_OBJ) | $(GPU_BUILD_DIR)
|
||||
$(CC) -g -DCUDA $(CFLAGS) $(GPU_NVCC_OBJ) $^ -o $@ $(INC) $(LIB) $(LIB_GPU) $(LD) $(LD_GPU)
|
||||
|
||||
$(GPU_BUILD_DIR)/link.o: $(GPU_NVCC_OBJ) | $(GPU_BUILD_DIR)
|
||||
$(NVCC) --device-link $^ -o $@
|
||||
|
||||
$(GPU_BUILD_DIR)/%.cu.o: $(SRC_DIR)/%.cu.cpp | $(GPU_BUILD_DIR)
|
||||
$(NVCC) $(NVCC_FLAGS) -DCUDA -x cu --device-c -o $@ $< $(INC)
|
||||
|
||||
$(GPU_BUILD_DIR)/%.o: $(SRC_DIR)/%.cpp | $(GPU_BUILD_DIR)
|
||||
$(CC) $(CFLAGS) -g -DCUDA -c -o $@ $< $(INC)
|
||||
|
||||
-include $(CPU_DEPS)
|
||||
-include $(GPU_DEPS)
|
||||
|
||||
$(CPU_BUILD_DIR):
|
||||
mkdir -p $@
|
||||
|
||||
$(GPU_BUILD_DIR):
|
||||
mkdir -p $@
|
||||
|
||||
clean:
|
||||
rm -Rf $(BUILD_DIR) $(TARGET)CPU $(TARGET)GPU
|
||||
@ -2,12 +2,20 @@
|
||||
Core.h
|
||||
======
|
||||
|
||||
The ``Core.h`` header file defines several compiler flags and macros along with
|
||||
The ``Core.h`` header file defines some useful types and some macros along with
|
||||
a few core classes.
|
||||
|
||||
Flags
|
||||
Types
|
||||
=====
|
||||
|
||||
.. doxygentypedef:: real32
|
||||
.. doxygentypedef:: real64
|
||||
.. doxygentypedef:: complex64
|
||||
.. doxygentypedef:: complex128
|
||||
|
||||
Macro Definitions
|
||||
=================
|
||||
|
||||
Device Indicators
|
||||
-----------------
|
||||
.. doxygendefine:: CUDACC
|
||||
@ -22,8 +30,8 @@ Compilation Options
|
||||
-------------------
|
||||
.. doxygendefine:: CUDATOOLS_ARRAY_MAX_AXES
|
||||
|
||||
Macros
|
||||
======
|
||||
Macro Functions
|
||||
===============
|
||||
|
||||
Kernel
|
||||
------
|
||||
|
||||
@ -10,6 +10,7 @@ compilation and linking framework:
|
||||
#. :ref:`Array Examples`
|
||||
#. :ref:`BLAS Examples`
|
||||
#. :ref:`Compilation and Linking`
|
||||
#. :ref:`Notes`
|
||||
|
||||
The ``Core.h`` header contains the necessary macros, flags and objects for interfacing with
|
||||
basic kernel launching and the CUDA Runtime API. The ``Array.h`` header contains the ``CudaTools::Array``
|
||||
@ -47,7 +48,7 @@ kernel. The launch parameters have several items, but for 'embarassingly paralle
|
||||
cases, we can simply generate the settings with the number of threads. More detail with
|
||||
creating launch parameters can be found :ref:`here <CudaTools::Kernel::Settings>`. In the above example,
|
||||
there is only one thread. The rest of the arguments are just the kernel arguments. For more detail,
|
||||
see :ref:`here <Macros>`.
|
||||
see :ref:`here <Macro Functions>`.
|
||||
|
||||
.. warning::
|
||||
These kernel definitions must be in a file that will be compiled by ``nvcc``. Also,
|
||||
@ -297,3 +298,20 @@ file for the first example:
|
||||
The lines above are the first few lines of the ``Makefile``, which are the only
|
||||
lines you should need to modify, consisting of libraries and flags, as well as
|
||||
the name of the target.
|
||||
|
||||
Notes
|
||||
=====
|
||||
|
||||
Complex Numbers
|
||||
---------------
|
||||
Dealing with complex numbers is slightly complicated, trying to enforce compatibility between
|
||||
two systems and several different libraries which may not have the right support. We
|
||||
create a simple barebones host and device compatible complex number class following
|
||||
the same approach as ``cuComplex.h``, but with proper C++ operator overloading and class structure. However,
|
||||
while the underlying data structure is identical to all other complex number structures, there
|
||||
is a lot of type-casting done underneath the hood to get cuBLAS and Eigen to work well
|
||||
together, while maintaining one 'unified' complex type.
|
||||
|
||||
As a result, there could be some issues and lack of functionality with this at the moment.
|
||||
For now, it's recommended to use the given ``complex64`` and ``complex128`` types which
|
||||
should properly adapt and work.
|
||||
|
||||
19
tests.cu.cpp
19
tests.cu.cpp
@ -2,6 +2,7 @@
|
||||
#define CUDATOOLS_ARRAY_MAX_AXES 8
|
||||
#include "Array.h"
|
||||
#include "BLAS.h"
|
||||
#include "Complex.h"
|
||||
#include "Core.h"
|
||||
|
||||
#include <Eigen/Core>
|
||||
@ -47,8 +48,10 @@ template <typename T> struct Type;
|
||||
REGISTER_PARSE_TYPE(uint8_t);
|
||||
REGISTER_PARSE_TYPE(int16_t);
|
||||
REGISTER_PARSE_TYPE(int32_t);
|
||||
REGISTER_PARSE_TYPE(float);
|
||||
REGISTER_PARSE_TYPE(double);
|
||||
REGISTER_PARSE_TYPE(real32);
|
||||
REGISTER_PARSE_TYPE(real64);
|
||||
REGISTER_PARSE_TYPE(complex64);
|
||||
REGISTER_PARSE_TYPE(complex128);
|
||||
|
||||
std::string box(std::string str) {
|
||||
std::string tops(str.size() + 6, '#');
|
||||
@ -433,6 +436,8 @@ template <typename T> struct BLASTests {
|
||||
|
||||
template <> double BLASTests<float>::thres = 10e-1;
|
||||
template <> double BLASTests<double>::thres = 10e-8;
|
||||
template <> double BLASTests<complex64>::thres = 10e-1;
|
||||
template <> double BLASTests<complex128>::thres = 10e-8;
|
||||
|
||||
uint32_t doMacroTests() {
|
||||
uint32_t failed = 0;
|
||||
@ -478,13 +483,15 @@ int main() {
|
||||
failed += doArrayTests<uint8_t>();
|
||||
failed += doArrayTests<int16_t>();
|
||||
failed += doArrayTests<int32_t>();
|
||||
failed += doArrayTests<double>();
|
||||
failed += doArrayTests<real64>();
|
||||
|
||||
std::cout << box("BLAS Tests") << "\n";
|
||||
failed += doBLASTests<float>();
|
||||
failed += doBLASTests<double>();
|
||||
failed += doBLASTests<real32>();
|
||||
failed += doBLASTests<real64>();
|
||||
failed += doBLASTests<complex64>();
|
||||
failed += doBLASTests<complex128>();
|
||||
|
||||
constexpr uint32_t tests = 2 + 4 * 5 + 13 * 2;
|
||||
constexpr uint32_t tests = 2 + 4 * 5 + 13 * 4;
|
||||
std::ostringstream msg;
|
||||
msg << ((failed == 0) ? "\033[1;32mPASS \033[0m(" : "\033[1;31mFAIL \033[0m(")
|
||||
<< (tests - failed) << "/" << tests << ")";
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user