1 module dnv.driver;
2 
3 import derelict.cuda.driverapi;
4 import std.stdio : writefln, writeln;
5 import dnv.error : nvrtcCheck, cuCheck;
6 static import dnv.error;
7 
8 
/**
 * CUDA driver bootstrap plus per-device context creation.
 *
 * The module constructor loads the CUDA driver library through Derelict and
 * calls `cuInit(0)`; `deviceInit` then creates a CUDA context for each
 * requested device ordinal and records the ordinal in `initialized`.
 */
class DriverBase {
    /// Device ordinals for which `deviceInit` has already created a context.
    /// NOTE(review): plain `static` (not `shared`) makes this thread-local in D,
    /// so each thread tracks its own set — confirm this is intended.
    static bool[int] initialized;

    // `shared static this` runs once per process. The cuInit status is
    // deliberately not checked here (a failure would otherwise abort startup);
    // later driver calls will report the error instead.
    private shared static this () {
        DerelictCUDADriver.load();
        cuInit(0);
    }

    /**
     * Initialize every CUDA device reported by the driver.
     *
     * Returns: CUDA_SUCCESS when all devices were initialized; otherwise the
     *          first error encountered (from the device-count query or from
     *          `deviceInit(int)`). Devices after a failing one are skipped.
     */
    static auto deviceInit() {
        int deviceCount = 0;
        auto countResult = cuDeviceGetCount(&deviceCount);
        if (CUDA_SUCCESS != countResult) {
            return cast(dnv.error.CUresult) countResult;
        }
        foreach (i; 0 .. deviceCount) {
            auto result = deviceInit(i);
            if (CUDA_SUCCESS != result) {
                return result;
            }
        }
        return cast(dnv.error.CUresult) CUDA_SUCCESS;
    }

    /**
     * Create a CUDA context for device `id`, at most once per ordinal.
     *
     * On the first successful call for `id` this prints the device name,
     * creates a context with `cuCtxCreate` (checked via `cuCheck`) and marks
     * `id` in `initialized`. Later calls for the same `id` do nothing.
     *
     * Params:
     *   id = zero-based CUDA device ordinal.
     * Returns: CUDA_SUCCESS; CUDA_ERROR_NO_DEVICE when `id` is negative or not
     *          below the device count; or the error code of a failing
     *          device query (`cuDeviceGetCount`, `cuDeviceGet`, `cuDeviceGetName`).
     */
    static auto deviceInit(int id) {
        import std.string : fromStringz;

        if (id in initialized) {
            return cast(dnv.error.CUresult) CUDA_SUCCESS;
        }

        int deviceCount = 0;
        auto result = cuDeviceGetCount(&deviceCount);
        if (CUDA_SUCCESS != result) {
            return cast(dnv.error.CUresult) result;
        }
        if (id < 0 || id >= deviceCount) {
            return cast(dnv.error.CUresult) CUDA_ERROR_NO_DEVICE;
        }

        int device;
        result = cuDeviceGet(&device, id);
        if (CUDA_SUCCESS != result) {
            return cast(dnv.error.CUresult) result;
        }

        enum int namelen = 256;
        // char.init is 0xFF in D, so NUL-fill explicitly; otherwise fromStringz
        // could run past the buffer if the driver writes nothing.
        char[namelen] name = '\0';
        result = cuDeviceGetName(name.ptr, namelen, device);
        if (CUDA_SUCCESS != result) {
            return cast(dnv.error.CUresult) result;
        }
        // TODO: use logger
        writefln(">>> Using CUDA Device [%d]: %s", id, name.ptr.fromStringz);

        // get compute capabilities and the device name
        // cudaDeviceProp dev;
        // check(cast(CUresult) cudaGetDeviceProperties(&dev, id));
        // writeln(dev);

        // NOTE(review): the created context handle is not stored anywhere, so
        // it cannot be retrieved or destroyed later.
        CUcontext context;
        cuCheck(cuCtxCreate(&context, 0, device));
        initialized[id] = true;
        return cast(dnv.error.CUresult) CUDA_SUCCESS;
    }
}
57 
// Initializing all devices must register exactly one entry per visible device.
// Driver statuses are asserted explicitly so that a failing query cannot make
// the length comparison pass vacuously (e.g. with deviceCount left at 0).
unittest {
    int deviceCount = 0;
    auto countResult = cuDeviceGetCount(&deviceCount);
    assert(CUDA_SUCCESS == countResult, "cuDeviceGetCount failed");

    auto initResult = DriverBase.deviceInit();
    assert(CUDA_SUCCESS == initResult, "DriverBase.deviceInit() failed");

    assert(DriverBase.initialized.length == deviceCount);
}
64 
65 
/**
 * Load a PTX image into a new CUDA context on device ordinal `cuDevice`.
 *
 * Params:
 *   ptx      = PTX text (the CUDA driver expects it NUL-terminated), e.g. the
 *              output of NVRTC.
 *   argc     = unused; kept for interface compatibility.
 *   argv     = unused; kept for interface compatibility.
 *   cuDevice = zero-based ordinal of the device to create the context on.
 * Returns: the loaded module handle. All driver failures go through `cuCheck`.
 *
 * NOTE(review): every call creates a fresh context with `cuCtxCreate` and the
 * handle is not returned, so it is never destroyed here.
 */
CUmodule loadPTX(char *ptx, int argc, char **argv, int cuDevice)
{
    CUmodule cumodule;
    CUcontext context;
    int device;

    cuCheck(cuInit(0));
    // Resolve the caller's ordinal into a device handle instead of always
    // using ordinal 0 (which silently ignored `cuDevice`).
    cuCheck(cuDeviceGet(&device, cuDevice));
    cuCheck(cuCtxCreate(&context, 0, device));
    cuCheck(cuModuleLoadDataEx(&cumodule, ptx, 0U, null, null));
    return cumodule;
}
77 
78 
/**
 * JIT-compile CUDA C++ `code` with NVRTC and resolve the kernel `funcname`.
 *
 * Compiles `code` under the virtual file name "<funcname>.cu" and prints any
 * non-empty NVRTC log. It then loads the PTX via `loadPTX` on `cuDevice` and
 * writes the `CUfunction` for `funcname` into the slot at `kernel_addr`.
 *
 * Params:
 *   kernel_addr = pointer to a `CUfunction` slot that receives the kernel.
 *   funcname    = kernel name looked up with `cuModuleGetFunction`.
 *                 NOTE(review): no NVRTC name expression is registered, so the
 *                 kernel presumably must be `extern "C"` — confirm.
 *   code        = CUDA C++ source text.
 *   cuDevice    = device ordinal forwarded to `loadPTX`.
 * Returns: NVRTC_SUCCESS as a `dnv.error.nvrtcResult`. Failures are reported
 *          through `nvrtcCheck` / `cuCheck`.
 */
auto compile(void* kernel_addr, string funcname, string code, int cuDevice=0)
{
    import derelict.nvrtc;
    DerelictNVRTC.load();

    import std.stdio : writefln;
    import std.algorithm : map;
    import std.array : array, empty;
    import std.string : toStringz, fromStringz, strip;

    // compile
    auto filename = funcname ~ ".cu";
    nvrtcProgram prog;
    nvrtcCheck(nvrtcCreateProgram(&prog, code.toStringz, filename.toStringz, 0, null, null));
    // Release the program on every exit path, including when a later check fails.
    bool progDestroyed = false;
    scope (exit) {
        if (!progDestroyed) {
            nvrtcDestroyProgram(&prog);
        }
    }

    auto opts = ["--use_fast_math", "-arch=compute_30", "--std=c++11"].map!(a => cast(const char*) a.toStringz).array;
    // Keep the status instead of checking it right away: the log below is
    // most valuable precisely when compilation fails.
    auto compileResult = nvrtcCompileProgram(prog, cast(int) opts.length, opts.ptr);

    // dump log (logSize as reported by NVRTC; one extra slot is reserved for
    // an explicit terminator)
    size_t logSize;
    nvrtcCheck(nvrtcGetProgramLogSize(prog, &logSize));
    auto log = new char[logSize + 1];
    nvrtcCheck(nvrtcGetProgramLog(prog, log.ptr));
    log[logSize] = '\0';
    auto slog = log.ptr.fromStringz.strip;
    if (!slog.empty) {
        writefln(">>> NVRTC log: %s\n%s", funcname, slog);
    }
    nvrtcCheck(compileResult);

    // load compiled ptx
    size_t ptxSize;
    nvrtcCheck(nvrtcGetPTXSize(prog, &ptxSize));
    auto ptx = new char[ptxSize];
    nvrtcCheck(nvrtcGetPTX(prog, ptx.ptr));
    progDestroyed = true; // never retry destruction, even if this call fails
    nvrtcCheck(nvrtcDestroyProgram(&prog));

    // FIXME: split these to another function to return
    CUmodule cumodule = loadPTX(ptx.ptr, 0, null, cuDevice);
    cuCheck(cuModuleGetFunction(cast (CUfunction*) kernel_addr, cumodule, funcname.toStringz));
    return cast(dnv.error.nvrtcResult) NVRTC_SUCCESS;
}
121 
122 
/**
 * Launch a kernel and wait for the current context to finish its work.
 *
 * Params:
 *   kernel_addr   = pointer to the `CUfunction` to launch (e.g. the slot
 *                   filled by `compile`).
 *   kernel_args   = passed to `cuLaunchKernel` as `kernelParams` (viewed as
 *                   `void**`, i.e. an array of pointers to each argument).
 *   grids         = three grid dimensions: x, y, z.
 *   blocks        = three block dimensions: x, y, z.
 *   shared_memory = dynamic shared memory per block in bytes (default 0).
 *   stream        = stream to launch on (default null).
 * Returns: CUDA_SUCCESS as a `dnv.error.CUresult`. Launch and synchronization
 *          failures go through `cuCheck`.
 */
auto launch(void* kernel_addr, void* kernel_args, uint* grids, uint* blocks,
            uint shared_memory = 0, CUstream stream = null)
{
    auto func = cast(CUfunction*) kernel_addr;
    auto args = cast(void**) kernel_args;
    cuCheck(cuLaunchKernel(*func,
                           grids[0], grids[1], grids[2],
                           blocks[0], blocks[1], blocks[2],
                           shared_memory,
                           stream,
                           args,
                           // `extra`: an alternative packed-buffer way to pass
                           // kernel parameters; unused since `args` is given.
                           null));
    // Wait for the launched work before reporting success.
    cuCheck(cuCtxSynchronize());
    return cast(dnv.error.CUresult) CUDA_SUCCESS;
}