Fixed compilation macro name issue

This commit is contained in:
Kenneth Jao 2023-06-02 14:47:00 -05:00 committed by Kenneth Jao
parent 9c66a61288
commit 1248a58265

4
BLAS.h
View File

@ -138,7 +138,7 @@ template <typename T> class Batch {
Array<T> batch = arr.reshaped({mBatchSize, mShape.rows(), mShape.cols()}); Array<T> batch = arr.reshaped({mBatchSize, mShape.rows(), mShape.cols()});
for (uint32_t i = 0; i < mBatchSize; ++i) { for (uint32_t i = 0; i < mBatchSize; ++i) {
#ifdef CUDA #ifdef CUDACC
mBatch[i] = batch[i].dataDevice(); mBatch[i] = batch[i].dataDevice();
#else #else
mBatch[i] = batch[i].data(); mBatch[i] = batch[i].data();
@ -154,7 +154,7 @@ template <typename T> class Batch {
void add(const Array<T>& arr) { void add(const Array<T>& arr) {
CT_ERROR(not arr.isView(), "Cannot add non-view Arrays"); CT_ERROR(not arr.isView(), "Cannot add non-view Arrays");
CT_ERROR_IF(mCount, ==, mBatchSize, "Batch is full, cannot add more arrays"); CT_ERROR_IF(mCount, ==, mBatchSize, "Batch is full, cannot add more arrays");
#ifdef CUDA #ifdef CUDACC
mBatch[mCount] = arr.dataDevice(); mBatch[mCount] = arr.dataDevice();
#else #else
mBatch[mCount] = arr.data(); mBatch[mCount] = arr.data();