Added CUDA Graphs support
This commit is contained in:
parent
1c439b4944
commit
167edfea44
11
Array.h
11
Array.h
@ -7,6 +7,7 @@
|
||||
#include <Eigen/Dense>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
#include <cstdlib>
|
||||
#include <iomanip>
|
||||
#include <random>
|
||||
#include <type_traits>
|
||||
@ -788,12 +789,16 @@ void printAxis(std::ostream& out, const Array<T>& arr, const uint32_t axis, size
|
||||
|
||||
template <typename T> std::ostream& operator<<(std::ostream& out, const Array<T>& arr) {
|
||||
size_t width = 0;
|
||||
if constexpr (is_num<T>) {
|
||||
if constexpr (is_int<T>) {
|
||||
T max_val = 0;
|
||||
bool negative = false;
|
||||
for (auto it = arr.begin(); it != arr.end(); ++it) {
|
||||
if (*it < 0) negative = true;
|
||||
max_val = (abs(*it) > max_val) ? abs(*it) : max_val;
|
||||
T val = *it;
|
||||
if (*it < 0) {
|
||||
negative = true;
|
||||
val *= -1;
|
||||
}
|
||||
max_val = (val > max_val) ? val : max_val;
|
||||
}
|
||||
width = std::to_string(max_val).size() + 1;
|
||||
width += (negative) ? 1 : 0;
|
||||
|
||||
13
BLAS.h
13
BLAS.h
@ -235,8 +235,7 @@ StreamID GEMV(const T alpha, const Array<T>& A, const Array<T>& x, const T beta,
|
||||
uint32_t cols = A.shape().cols();
|
||||
T a = alpha, b = beta;
|
||||
#ifdef CUDA
|
||||
CUBLAS_CHECK(
|
||||
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id)));
|
||||
CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
|
||||
if (bi.size == 1) {
|
||||
invoke<T>(cublasSgemv, cublasDgemv, cublasCgemv, cublasZgemv,
|
||||
Manager::get()->cublasHandle(), CUBLAS_OP_N, rows, cols, CAST(&a),
|
||||
@ -282,8 +281,7 @@ StreamID GEMM(const T alpha, const Array<T>& A, const Array<T>& B, const T beta,
|
||||
|
||||
T a = alpha, b = beta;
|
||||
#ifdef CUDA
|
||||
CUBLAS_CHECK(
|
||||
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id)));
|
||||
CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
|
||||
if (bi.size == 1) {
|
||||
invoke<T>(cublasSgemm, cublasDgemm, cublasCgemm, cublasZgemm,
|
||||
Manager::get()->cublasHandle(), CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, CAST(&a),
|
||||
@ -338,8 +336,7 @@ StreamID DGMM(const Array<T>& A, const Array<T>& X, const Array<T>& C, const boo
|
||||
uint32_t m = C.shape().rows();
|
||||
uint32_t n = C.shape().cols();
|
||||
auto mode = (left) ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
|
||||
CUBLAS_CHECK(
|
||||
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id)));
|
||||
CUBLAS_CHECK(cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
|
||||
invoke<T>(cublasSdgmm, cublasDdgmm, cublasCdgmm, cublasZdgmm, Manager::get()->cublasHandle(), m,
|
||||
n, CAST(A.dataDevice()), A.shape().rows(), CAST(X.dataDevice()), 1,
|
||||
CAST(C.dataDevice()), m);
|
||||
@ -514,7 +511,7 @@ class PLUBatch : public Batch<T> {
|
||||
#ifdef CUDA
|
||||
uint32_t n = this->mShape.rows();
|
||||
CUBLAS_CHECK(
|
||||
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id)));
|
||||
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
|
||||
invoke<T>(cublasSgetrfBatched, cublasDgetrfBatched, cublasCgetrfBatched,
|
||||
cublasZgetrfBatched, Manager::get()->cublasHandle(), n,
|
||||
DCAST(this->mBatch.dataDevice()), n, mPivotsBatch.dataDevice(),
|
||||
@ -546,7 +543,7 @@ class PLUBatch : public Batch<T> {
|
||||
uint32_t n = b.shape().rows();
|
||||
uint32_t nrhs = b.shape().cols();
|
||||
CUBLAS_CHECK(
|
||||
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream.id)));
|
||||
cublasSetStream(Manager::get()->cublasHandle(), Manager::get()->stream(stream)));
|
||||
invoke<T>(cublasSgetrsBatched, cublasDgetrsBatched, cublasCgetrsBatched,
|
||||
cublasZgetrsBatched, Manager::get()->cublasHandle(), CUBLAS_OP_N, n, nrhs,
|
||||
DCAST(this->mBatch.dataDevice()), n, mPivotsBatch.dataDevice(),
|
||||
|
||||
235
Core.h
235
Core.h
@ -2,13 +2,16 @@
|
||||
#define CUDATOOLS_H
|
||||
|
||||
#include "Macros.h"
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace CudaTools {
|
||||
|
||||
struct Event;
|
||||
/**
|
||||
* Simple wrapper for the name of a stream. Its purposes is to allow for
|
||||
* 'streams' to be passed on host code, and allowing for simple syntax
|
||||
@ -16,18 +19,19 @@ namespace CudaTools {
|
||||
*/
|
||||
struct StreamID {
|
||||
public:
|
||||
std::string id;
|
||||
StreamID() : id(""){};
|
||||
std::string mId;
|
||||
StreamID() : mId(""){};
|
||||
/**
|
||||
* The constructor for a StreamID.
|
||||
*/
|
||||
StreamID(const std::string& id_) : id(id_){};
|
||||
StreamID(const char* id_) : id(id_){};
|
||||
StreamID(const std::string& id_) : mId(id_){};
|
||||
StreamID(const char* id_) : mId(id_){};
|
||||
|
||||
void wait() const; /**< Makes host wait for this stream. */
|
||||
/**
|
||||
* Waits for the stream with this stream ID.
|
||||
* Makes this stream wait for this event. Does not block the host.
|
||||
*/
|
||||
void wait() const;
|
||||
void wait(const Event& event) const;
|
||||
};
|
||||
|
||||
static const StreamID DEF_MEM_STREAM = StreamID{"defaultMemory"};
|
||||
@ -137,6 +141,20 @@ struct Settings {
|
||||
*/
|
||||
Settings basic(const size_t threads, const StreamID& stream = DEF_KERNEL_STREAM);
|
||||
|
||||
/**
|
||||
* Launches a kernel with the provided function, settings and its arguments.
|
||||
*/
|
||||
|
||||
template <typename F, typename... Args>
|
||||
StreamID launch(F func, const Kernel::Settings& sett, Args... args) {
|
||||
#ifdef CUDA
|
||||
func<<<sett.blockGrid, sett.threadBlock, sett.sharedMemoryBytes,
|
||||
Manager::get()->stream(sett.stream.mId)>>>(args...);
|
||||
#else
|
||||
func(args...);
|
||||
#endif
|
||||
return sett.stream;
|
||||
}
|
||||
}; // namespace Kernel
|
||||
|
||||
template <typename T> class Array;
|
||||
@ -186,29 +204,143 @@ class Shape {
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const Shape& s);
|
||||
|
||||
/**
|
||||
* A simple class that manages a CUDA Event.
|
||||
*/
|
||||
struct Event {
|
||||
#ifdef CUDACC
|
||||
cudaEvent_t mEvent;
|
||||
#endif
|
||||
Event();
|
||||
~Event();
|
||||
void record(const StreamID& stream); /**< Records a event from a stream. */
|
||||
};
|
||||
|
||||
template <typename F, typename... Args> struct FuncHolder {
|
||||
F mFunc;
|
||||
std::tuple<Args...> mArgs;
|
||||
FuncHolder() = delete;
|
||||
FuncHolder(F func, Args... args) : mFunc(func), mArgs(std::make_tuple(args...)){};
|
||||
static void run(void* data) {
|
||||
FuncHolder<F, Args...>* fh = (FuncHolder<F, Args...>*)(data);
|
||||
std::apply([fh](auto&&... args) { fh->mFunc(args...); }, fh->mArgs);
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Accessory struct to deal with host callbacks for CUDA Graphs in a nice fashion.
|
||||
*/
|
||||
struct GraphTools {
|
||||
std::vector<void*> mHostData;
|
||||
std::vector<Event*> mEvents;
|
||||
|
||||
~GraphTools();
|
||||
|
||||
/**
|
||||
* Within a function that is being stream captured, launch a host function that can
|
||||
* be captured into the graph.
|
||||
*/
|
||||
|
||||
template <typename F, typename... Args>
|
||||
void launchHostFunction(const StreamID& stream, F func, Args&&... args) {
|
||||
#ifdef CUDACC
|
||||
FuncHolder<F, Args...>* fh = new FuncHolder<F, Args...>(func, args...);
|
||||
mHostData.push_back((void*)fh);
|
||||
cudaHostFn_t run_func = fh->run;
|
||||
CUDA_CHECK(cudaLaunchHostFunc(Manager::get()->stream(stream), run_func, fh));
|
||||
#else
|
||||
func(args...);
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* Makes a new branch in the graph to be run in parallel by a new stream.
|
||||
* \param orig_stream the original stream to branch from.
|
||||
* \param branch_stream the stream of the new branch.
|
||||
*/
|
||||
void makeBranch(const StreamID& orig_stream, const StreamID& branch_stream);
|
||||
/**
|
||||
* Joins a existing branch in the graph to collapse a parallel block.
|
||||
* \param orig_stream the original stream to join the branch to.
|
||||
* \param branch_stream the stream of the branch to join.
|
||||
*/
|
||||
void joinBranch(const StreamID& orig_stream, const StreamID& branch_stream);
|
||||
};
|
||||
|
||||
/**
|
||||
* A class that manages CUDA Graphs.
|
||||
*/
|
||||
template <typename F, typename... Args> class Graph {
|
||||
private:
|
||||
#ifdef CUDACC
|
||||
cudaGraph_t mGraph;
|
||||
cudaGraphExec_t mInstance;
|
||||
#endif
|
||||
FuncHolder<F, Args...> mFuncHolder;
|
||||
StreamID mStream;
|
||||
|
||||
public:
|
||||
Graph() = delete;
|
||||
/**
|
||||
* The constructor for a Graph, which captures the function.
|
||||
* \param func the function to capture.
|
||||
* \param stream the origin stream to use.
|
||||
* \param args the arguments of the function.
|
||||
*/
|
||||
Graph(const StreamID& stream, F func, Args... args)
|
||||
: mFuncHolder(func, args...), mStream(stream) {
|
||||
#ifdef CUDACC
|
||||
CUDA_CHECK(
|
||||
cudaStreamBeginCapture(Manager::get()->stream(mStream), cudaStreamCaptureModeGlobal));
|
||||
mFuncHolder.run((void*)&mFuncHolder);
|
||||
CUDA_CHECK(cudaStreamEndCapture(Manager::get()->stream(mStream), &mGraph));
|
||||
CUDA_CHECK(cudaGraphInstantiate(&mInstance, mGraph, NULL, NULL, 0));
|
||||
#endif
|
||||
};
|
||||
|
||||
~Graph() {
|
||||
#ifdef CUDACC
|
||||
CUDA_CHECK(cudaGraphDestroy(mGraph));
|
||||
CUDA_CHECK(cudaGraphExecDestroy(mInstance));
|
||||
#endif
|
||||
};
|
||||
|
||||
/**
|
||||
* Executes the instantiated graph, or simply runs the function with provided
|
||||
* arguments if compiling for CPU.
|
||||
*/
|
||||
StreamID execute() const {
|
||||
#ifdef CUDACC
|
||||
cudaGraphLaunch(mInstance, Manager::get()->stream(mStream));
|
||||
#else
|
||||
mFuncHolder.run((void*)&mFuncHolder);
|
||||
#endif
|
||||
return mStream;
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace CudaTools
|
||||
|
||||
#ifdef CUDATOOLS_IMPLEMENTATION
|
||||
|
||||
namespace CudaTools {
|
||||
|
||||
template <typename T, typename... Args>
|
||||
StreamID runKernel(T func, const Kernel::Settings& sett, Args... args) {
|
||||
#ifdef CUDA
|
||||
func<<<sett.blockGrid, sett.threadBlock, sett.sharedMemoryBytes,
|
||||
Manager::get()->stream(sett.stream.id)>>>(args...);
|
||||
#else
|
||||
func(args...);
|
||||
//////////////////////
|
||||
// StreamID Methods //
|
||||
//////////////////////
|
||||
|
||||
void StreamID::wait() const { Manager::get()->waitFor(mId); }
|
||||
|
||||
void StreamID::wait(const Event& event) const {
|
||||
#ifdef CUDACC
|
||||
CUDA_CHECK(cudaStreamWaitEvent(Manager::get()->stream(mId), event.mEvent, 0));
|
||||
#endif
|
||||
return sett.stream;
|
||||
}
|
||||
|
||||
////////////////////
|
||||
// Memory Methods //
|
||||
////////////////////
|
||||
|
||||
void StreamID::wait() const { Manager::get()->waitFor(id); }
|
||||
|
||||
void* malloc(const size_t size) {
|
||||
#ifdef CUDACC
|
||||
void* pDevice;
|
||||
@ -228,7 +360,7 @@ void free(void* const pDevice) {
|
||||
StreamID push(void* const pHost, void* const pDevice, const size_t size, const StreamID& stream) {
|
||||
#ifdef CUDACC
|
||||
CUDA_CHECK(cudaMemcpyAsync(pDevice, pHost, size, cudaMemcpyHostToDevice,
|
||||
Manager::get()->stream(stream.id)));
|
||||
Manager::get()->stream(stream)));
|
||||
#endif
|
||||
return stream;
|
||||
}
|
||||
@ -236,7 +368,7 @@ StreamID push(void* const pHost, void* const pDevice, const size_t size, const S
|
||||
StreamID pull(void* const pHost, void* const pDevice, const size_t size, const StreamID& stream) {
|
||||
#ifdef CUDACC
|
||||
CUDA_CHECK(cudaMemcpyAsync(pHost, pDevice, size, cudaMemcpyDeviceToHost,
|
||||
Manager::get()->stream(stream.id)));
|
||||
Manager::get()->stream(stream)));
|
||||
#endif
|
||||
return stream;
|
||||
}
|
||||
@ -245,7 +377,7 @@ StreamID deviceCopy(void* const pSrc, void* const pDest, const size_t size,
|
||||
const StreamID& stream) {
|
||||
#ifdef CUDACC
|
||||
CUDA_CHECK(cudaMemcpyAsync(pDest, pSrc, size, cudaMemcpyDeviceToDevice,
|
||||
Manager::get()->stream(stream.id)));
|
||||
Manager::get()->stream(stream)));
|
||||
#endif
|
||||
return stream;
|
||||
}
|
||||
@ -289,11 +421,11 @@ Manager::~Manager() {
|
||||
|
||||
void Manager::waitFor(const StreamID& stream) const {
|
||||
#ifdef CUDACC
|
||||
auto it = mStreams.find(stream.id);
|
||||
auto it = mStreams.find(stream.mId);
|
||||
if (it != mStreams.end()) {
|
||||
CUDA_CHECK(cudaStreamSynchronize(it->second));
|
||||
} else {
|
||||
CT_ERROR(true, ("Invalid stream " + stream.id).c_str());
|
||||
CT_ERROR(true, ("Invalid stream " + stream.mId).c_str());
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@ -314,11 +446,11 @@ void Manager::addStream(const std::string& name) {
|
||||
|
||||
#ifdef CUDACC
|
||||
cudaStream_t Manager::stream(const StreamID& stream) const {
|
||||
auto it = mStreams.find(stream.id);
|
||||
auto it = mStreams.find(stream.mId);
|
||||
if (it != mStreams.end()) {
|
||||
return it->second;
|
||||
} else {
|
||||
CT_ERROR(true, ("Invalid stream " + stream.id).c_str());
|
||||
CT_ERROR(true, ("Invalid stream " + stream.mId).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
@ -407,7 +539,7 @@ void Settings::setSharedMemSize(const size_t bytes) {
|
||||
|
||||
void Settings::setStream(const StreamID& stream_) {
|
||||
#ifdef CUDACC
|
||||
stream.id = stream_.id;
|
||||
stream = stream_;
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -425,7 +557,8 @@ Settings basic(const size_t threads, const StreamID& stream) {
|
||||
#endif
|
||||
return sett;
|
||||
}
|
||||
} // namespace Kernel
|
||||
|
||||
}; // namespace Kernel
|
||||
|
||||
/////////////////////
|
||||
// Shape Functions //
|
||||
@ -506,6 +639,57 @@ std::ostream& operator<<(std::ostream& out, const Shape& s) {
|
||||
return out << s.dim(s.axes() - 1) << ")";
|
||||
}
|
||||
|
||||
///////////////////
|
||||
// Event Methods //
|
||||
///////////////////
|
||||
|
||||
Event::Event() {
|
||||
#ifdef CUDACC
|
||||
CUDA_CHECK(cudaEventCreate(&mEvent));
|
||||
#endif
|
||||
}
|
||||
|
||||
Event::~Event() {
|
||||
#ifdef CUDACC
|
||||
CUDA_CHECK(cudaEventDestroy(mEvent));
|
||||
#endif
|
||||
}
|
||||
|
||||
void Event::record(const StreamID& stream) {
|
||||
#ifdef CUDACC
|
||||
CUDA_CHECK(cudaEventRecord(mEvent, Manager::get()->stream(stream)));
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////
|
||||
// GraphTools Methods //
|
||||
////////////////////////
|
||||
|
||||
GraphTools::~GraphTools() {
|
||||
#ifdef CUDACC
|
||||
for (void* func : mHostData) {
|
||||
delete func;
|
||||
}
|
||||
for (Event* event : mEvents) {
|
||||
delete event;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void GraphTools::makeBranch(const StreamID& orig_stream, const StreamID& branch_stream) {
|
||||
Event* event = new Event();
|
||||
event->record(orig_stream);
|
||||
mEvents.push_back(event);
|
||||
branch_stream.wait(*event);
|
||||
}
|
||||
|
||||
void GraphTools::joinBranch(const StreamID& orig_stream, const StreamID& branch_stream) {
|
||||
Event* event = new Event();
|
||||
event->record(branch_stream);
|
||||
mEvents.push_back(event);
|
||||
orig_stream.wait(*event);
|
||||
}
|
||||
|
||||
#ifdef CUDACC
|
||||
const char* cublasGetErrorString(cublasStatus_t error) {
|
||||
switch (error) {
|
||||
@ -537,7 +721,6 @@ const char* cublasGetErrorString(cublasStatus_t error) {
|
||||
return "<unknown>";
|
||||
}
|
||||
#endif
|
||||
|
||||
}; // namespace CudaTools
|
||||
#endif // CUDATOOLS_IMPLEMENTATION
|
||||
|
||||
|
||||
22
Macros.h
22
Macros.h
@ -145,27 +145,17 @@ using real64 = double; /**< Type alias for 64-bit floating point datatype. */
|
||||
#define HD __host__ __device__
|
||||
#define SHARED __shared__
|
||||
|
||||
#define DECLARE_KERNEL(call, ...) __global__ void call(__VA_ARGS__)
|
||||
|
||||
#define DEFINE_KERNEL(call, ...) \
|
||||
template CudaTools::StreamID CudaTools::runKernel( \
|
||||
void (*)(__VA_ARGS__), const CudaTools::Kernel::Settings&, __VA_ARGS__); \
|
||||
__global__ void call(__VA_ARGS__)
|
||||
#define KERNEL(call, ...) __global__ void call(__VA_ARGS__)
|
||||
|
||||
#else
|
||||
#define HD
|
||||
#define SHARED
|
||||
|
||||
#define DECLARE_KERNEL(call, ...) void call(__VA_ARGS__)
|
||||
|
||||
#define DEFINE_KERNEL(call, ...) \
|
||||
template CudaTools::StreamID CudaTools::runKernel( \
|
||||
void (*)(__VA_ARGS__), const CudaTools::Kernel::Settings&, __VA_ARGS__); \
|
||||
void call(__VA_ARGS__)
|
||||
#define KERNEL(call, ...) void call(__VA_ARGS__)
|
||||
|
||||
#endif // CUDACC
|
||||
|
||||
#define KERNEL(call, settings, ...) CudaTools::runKernel(call, settings, __VA_ARGS__)
|
||||
//#define KERNEL(call, settings, ...) CudaTools::runKernel(call, settings, __VA_ARGS__)
|
||||
|
||||
///////////////////
|
||||
// DEVICE MACROS //
|
||||
@ -218,8 +208,10 @@ using real64 = double; /**< Type alias for 64-bit floating point datatype. */
|
||||
#ifndef CUDATOOLS_ARRAY_MAX_AXES
|
||||
/**
|
||||
* \def CUDATOOLS_ARRAY_MAX_AXES
|
||||
* The maximum number of axes/dimensions an CudaTools::Array can have. The default is
|
||||
* set to 4, but can be manully set fit the program needs.
|
||||
* The maximum number of axes/dimensions an
|
||||
* CudaTools::Array can have. The default is set
|
||||
* to 4, but can be manully set fit the program
|
||||
* needs.
|
||||
*/
|
||||
#define CUDATOOLS_ARRAY_MAX_AXES 4
|
||||
#endif
|
||||
|
||||
2
Makefile
2
Makefile
@ -1,7 +1,7 @@
|
||||
CC := g++-10
|
||||
NVCC := nvcc
|
||||
CFLAGS := -Wall -std=c++17 -fopenmp -MMD
|
||||
NVCC_FLAGS := -MMD -w -Xcompiler
|
||||
NVCC_FLAGS := -MMD -std=c++17 -w -Xcompiler
|
||||
|
||||
INCLUDE :=
|
||||
LIBS_DIR :=
|
||||
|
||||
@ -31,7 +31,7 @@ After installing the required Python packages
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ pip install -r requirements
|
||||
$ pip install -r requirements.txt
|
||||
|
||||
you can now run the script
|
||||
|
||||
|
||||
78
tests.cu.cpp
78
tests.cu.cpp
@ -97,18 +97,36 @@ class TestClass {
|
||||
};
|
||||
};
|
||||
|
||||
DEFINE_KERNEL(times, const CT::Array<int> arr) {
|
||||
KERNEL(times, const CT::Array<int> arr) {
|
||||
BASIC_LOOP(arr.shape().length()) { arr[iThread] *= 2; }
|
||||
}
|
||||
|
||||
DEFINE_KERNEL(classTest, TestClass* const test) { test->x = 100; }
|
||||
KERNEL(classTest, TestClass* const test) { test->x = 100; }
|
||||
|
||||
KERNEL(collatz, const CT::Array<uint32_t> arr) {
|
||||
BASIC_LOOP(arr.shape().length()) {
|
||||
if (arr[iThread] % 2) {
|
||||
arr[iThread] = 3 * arr[iThread] + 1;
|
||||
} else {
|
||||
arr[iThread] = arr[iThread] >> 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
KERNEL(plusOne, const CT::Array<uint32_t> arr) {
|
||||
BASIC_LOOP(arr.shape().length()) { arr[iThread] += 1; }
|
||||
}
|
||||
|
||||
KERNEL(addBoth, const CT::Array<uint32_t> a, const CT::Array<uint32_t> b) {
|
||||
BASIC_LOOP(a.shape().length()) { a[iThread] += b[iThread]; }
|
||||
}
|
||||
|
||||
struct MacroTests {
|
||||
static uint32_t Kernel() {
|
||||
uint32_t failed = 0;
|
||||
CT::Array<int> A = CT::Array<int>::constant({10}, 1);
|
||||
A.updateDevice().wait();
|
||||
KERNEL(times, CT::Kernel::basic(A.shape().items()), A.view()).wait();
|
||||
CT::Kernel::launch(times, CT::Kernel::basic(A.shape().items()), A.view()).wait();
|
||||
A.updateHost().wait();
|
||||
|
||||
uint32_t errors = 0;
|
||||
@ -125,7 +143,7 @@ struct MacroTests {
|
||||
static uint32_t Class() {
|
||||
uint32_t failed = 0;
|
||||
TestClass test(1);
|
||||
KERNEL(classTest, CT::Kernel::basic(1), test.that()).wait();
|
||||
CT::Kernel::launch(classTest, CT::Kernel::basic(1), test.that()).wait();
|
||||
test.updateHost().wait();
|
||||
|
||||
TEST(test.x == 100, "Class", "Errors: 0");
|
||||
@ -473,6 +491,53 @@ template <typename T> uint32_t doBLASTests() {
|
||||
return failed;
|
||||
}
|
||||
|
||||
void myHostFunc(const CT::Array<uint32_t> A, uint32_t num) {
|
||||
auto Aeig = A.atLeast2D().eigenMap();
|
||||
Aeig = Aeig.array() + num;
|
||||
}
|
||||
|
||||
void myBasicGraph(CT::GraphTools* tools, CT::Array<uint32_t>* A, CT::Array<uint32_t>* B) {
|
||||
// tools->launchHostFunction("graphStream", myHostFunc, A->view(), 5);
|
||||
A->updateDevice("graphStream");
|
||||
tools->makeBranch("graphStream", "graphStreamBranch");
|
||||
B->updateDevice("graphStreamBranch");
|
||||
for (uint32_t iTimes = 0; iTimes < 30; ++iTimes) {
|
||||
CT::Kernel::launch(collatz, CT::Kernel::basic(A->shape().items(), "graphStream"),
|
||||
A->view());
|
||||
CT::Kernel::launch(plusOne, CT::Kernel::basic(A->shape().items(), "graphStreamBranch"),
|
||||
B->view());
|
||||
}
|
||||
|
||||
tools->joinBranch("graphStream", "graphStreamBranch");
|
||||
CT::Kernel::launch(addBoth, CT::Kernel::basic(A->shape().items(), "graphStream"), A->view(),
|
||||
B->view());
|
||||
A->updateHost("graphStream");
|
||||
B->updateHost("graphStream");
|
||||
tools->launchHostFunction("graphStream", myHostFunc, A->view(), 5);
|
||||
}
|
||||
|
||||
uint32_t doGraphTest() {
|
||||
uint32_t failed = 0;
|
||||
CT::Array<uint32_t> A = CT::Array<uint32_t>::constant({1000000}, 50);
|
||||
CT::Array<uint32_t> B = CT::Array<uint32_t>::constant({1000000}, 0);
|
||||
CT::Manager::get()->addStream("graphStream");
|
||||
CT::Manager::get()->addStream("graphStreamBranch");
|
||||
|
||||
CT::GraphTools tools;
|
||||
CT::Graph graph("graphStream", myBasicGraph, &tools, &A, &B);
|
||||
graph.execute().wait();
|
||||
|
||||
uint32_t errors = 0;
|
||||
for (auto it = A.begin(); it != A.end(); ++it) {
|
||||
if (*it != 36) ++errors;
|
||||
}
|
||||
|
||||
std::ostringstream msg;
|
||||
msg << "Errors: " << errors;
|
||||
TEST(errors == 0, "Graph", msg.str().c_str());
|
||||
return failed;
|
||||
}
|
||||
|
||||
int main() {
|
||||
uint32_t failed = 0;
|
||||
std::cout << box("Macro Tests") << "\n";
|
||||
@ -491,7 +556,10 @@ int main() {
|
||||
failed += doBLASTests<complex64>();
|
||||
failed += doBLASTests<complex128>();
|
||||
|
||||
constexpr uint32_t tests = 2 + 4 * 5 + 13 * 4;
|
||||
std::cout << box("Stream/Graph Tests") << "\n";
|
||||
failed += doGraphTest();
|
||||
|
||||
constexpr uint32_t tests = 2 + 4 * 5 + 13 * 4 + 1;
|
||||
std::ostringstream msg;
|
||||
msg << ((failed == 0) ? "\033[1;32mPASS \033[0m(" : "\033[1;31mFAIL \033[0m(")
|
||||
<< (tests - failed) << "/" << tests << ")";
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user