1 module dnv.storage;
2 
3 import std.conv : to;
4 
5 import dnv.error;
6 import dnv.driver;
7 
8 import derelict.cuda.driverapi; // : cuMemAlloc, cuMemcpyDtoH, cuMemcpyHtoD, cuMemFree, CUdeviceptr, cuCtxSynchronize;
9 
/**
 * Fixed-size buffer of `T` elements living in CUDA device memory,
 * managed through the CUDA driver API.
 *
 * The device allocation is made in the constructor and released in the
 * destructor. Every driver call result is passed to `cuCheck`
 * (presumably provided by `dnv.error` -- its failure behaviour is not
 * visible from this file).
 */
class Array(T) : DriverBase {
    alias Element = T;
    alias Storage = T*;

    /// Device ordinal this array was created for (passed to `deviceInit`).
    int dev;
    /// Size of the device allocation in bytes (`T.sizeof * length`).
    size_t rawLength;
    /// Number of `T` elements in the array.
    size_t length;
    /// Device address filled in by `cuMemAlloc`.
    CUdeviceptr ptr;
    /// Host-side buffer that `to_cpu` resizes, fills and returns.
    T[] cpu_storage;

    /**
     * Allocates room for `n` elements on device `dev`.
     *
     * Params:
     *   n   = number of elements to allocate
     *   dev = device ordinal handed to `deviceInit`
     */
    this(size_t n, int dev = 0) {
        // The parameter shadows the field, so the field must be named
        // explicitly; a bare `dev = dev` only assigns the parameter to itself.
        this.dev = dev;
        length = n;
        rawLength = T.sizeof * n;
        cuCheck(this.deviceInit(dev));
        cuCheck(cuMemAlloc(&ptr, rawLength));
    }

    /**
     * Allocates an array of `src.length` elements on device `dev` and
     * uploads the contents of `src` into it.
     *
     * Params:
     *   src = host data to upload
     *   dev = device ordinal handed to `deviceInit`
     */
    this(in T[] src, int dev = 0) {
        // Forward the requested device instead of silently using the default.
        this(src.length, dev);
        to_gpu(src);
    }

    /// Releases the device allocation, if one was made.
    ~this() {
        // `ptr` keeps its initial value when the constructor failed before
        // (or during) allocation; there is nothing to free in that case.
        if (ptr != CUdeviceptr.init) {
            auto status = cuMemFree(ptr);
            // Drop the handle before checking the status so it can never
            // be freed a second time.
            ptr = CUdeviceptr.init;
            cuCheck(status);
        }
    }

    /**
     * Copies the first `length` elements of `src` from host to device.
     *
     * Params:
     *   src = host data; must hold at least `length` elements
     * Returns: this array, to allow call chaining.
     * Throws: `Exception` (via `enforce`) when `src` is shorter than the
     *         array, since the copy would read past the end of `src`.
     */
    auto to_gpu(in T[] src) {
        import std.exception : enforce;

        // The copy always transfers `rawLength` bytes, so a shorter source
        // slice would be over-read.
        enforce(src.length >= length,
                "Array.to_gpu: source has fewer elements than the device array");
        // TODO: use cuMemcpyHtoDAsync and CUstream
        cuCheck(cuMemcpyHtoD(ptr, to!(const(void*))(src.ptr), rawLength));
        return this;
    }

    /**
     * Copies the whole device buffer back to the host.
     *
     * Returns: `cpu_storage`, resized to `length` elements. The same buffer
     *          is reused by later calls, so earlier results may be
     *          overwritten.
     */
    T[] to_cpu() {
        cpu_storage.length = rawLength / T.sizeof;
        // check(cudaDeviceInit_(dev));
        cuCheck(cuMemcpyDtoH(to!(void*)(cpu_storage.ptr), ptr, rawLength));
        return cpu_storage;
    }

    /// Returns the device address reinterpreted as a `T*`.
    auto data() {
        return cast(T*) this.ptr;
    }
}
49 
// Round-trips host data through a device Array and checks the aliases.
unittest {
    float[] h1 = [1,2,3];
    auto d = new Array!float(h1.length);
    // `delete` is a deprecated language feature; `destroy` runs the
    // destructor deterministically. `scope(exit)` also releases the device
    // buffer when one of the assertions below fails.
    scope(exit) destroy(d);

    // Upload then download must return the same values.
    auto a1 = d.to_gpu(h1).to_cpu();
    assert(h1 == a1);

    // The same array can be overwritten with new host data.
    float[] h2 = [3,2,1];
    d.to_gpu(h2);
    auto hd = d.to_cpu();
    assert(h2 == hd);
    static assert(is(typeof(d).Storage == float*));
    static assert(is(typeof(d).Element == float));
}