Fixed compilation macro name issue
This commit is contained in:
parent
9c66a61288
commit
1248a58265
4
BLAS.h
4
BLAS.h
@ -138,7 +138,7 @@ template <typename T> class Batch {
|
|||||||
|
|
||||||
Array<T> batch = arr.reshaped({mBatchSize, mShape.rows(), mShape.cols()});
|
Array<T> batch = arr.reshaped({mBatchSize, mShape.rows(), mShape.cols()});
|
||||||
for (uint32_t i = 0; i < mBatchSize; ++i) {
|
for (uint32_t i = 0; i < mBatchSize; ++i) {
|
||||||
#ifdef CUDA
|
#ifdef CUDACC
|
||||||
mBatch[i] = batch[i].dataDevice();
|
mBatch[i] = batch[i].dataDevice();
|
||||||
#else
|
#else
|
||||||
mBatch[i] = batch[i].data();
|
mBatch[i] = batch[i].data();
|
||||||
@ -154,7 +154,7 @@ template <typename T> class Batch {
|
|||||||
void add(const Array<T>& arr) {
|
void add(const Array<T>& arr) {
|
||||||
CT_ERROR(not arr.isView(), "Cannot add non-view Arrays");
|
CT_ERROR(not arr.isView(), "Cannot add non-view Arrays");
|
||||||
CT_ERROR_IF(mCount, ==, mBatchSize, "Batch is full, cannot add more arrays");
|
CT_ERROR_IF(mCount, ==, mBatchSize, "Batch is full, cannot add more arrays");
|
||||||
#ifdef CUDA
|
#ifdef CUDACC
|
||||||
mBatch[mCount] = arr.dataDevice();
|
mBatch[mCount] = arr.dataDevice();
|
||||||
#else
|
#else
|
||||||
mBatch[mCount] = arr.data();
|
mBatch[mCount] = arr.data();
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user