/// Helpers around the CUDA driver API and NVRTC: per-device context
/// initialisation, runtime compilation of CUDA source to PTX, module loading
/// and kernel launch.
module dnv.driver;

import derelict.cuda.driverapi;
import std.stdio : writefln, writeln;
import dnv.error : nvrtcCheck, cuCheck;
static import dnv.error;


/// Loads the CUDA driver library once per process and tracks which device
/// ordinals have had a context created by `deviceInit`.
class DriverBase {
    /// Device ordinals for which `deviceInit(int)` already created a context.
    static bool[int] initialized;

    /// Loads the driver bindings and calls `cuInit` once at program start.
    private shared static this () {
        DerelictCUDADriver.load();
        // The result is intentionally left unchecked here (as before), so
        // that a failing cuInit does not abort program start-up.
        cuInit(0);
    }

    /// Initialises every device reported by `cuDeviceGetCount`.
    ///
    /// Returns: the first non-success result of `deviceInit(int)`, or
    /// `CUDA_SUCCESS` when all devices were initialised (or none exist).
    static auto deviceInit() {
        int deviceCount = 0;
        cuDeviceGetCount(&deviceCount);
        foreach (i; 0 .. deviceCount) {
            auto result = deviceInit(i);
            if (CUDA_SUCCESS != result) {
                return result;
            }
        }
        return cast(dnv.error.CUresult) CUDA_SUCCESS;
    }

    /// Creates a context on device ordinal `id` unless one was already
    /// created through this function, and records the ordinal in
    /// `initialized`.
    ///
    /// Params:
    ///   id = zero-based device ordinal.
    ///
    /// Returns: `CUDA_ERROR_NO_DEVICE` when `id` is negative or not below the
    /// device count, `CUDA_SUCCESS` otherwise. Failures of `cuDeviceGet` and
    /// `cuCtxCreate` are reported through `cuCheck`.
    static auto deviceInit(int id) {
        import std.string : fromStringz;
        if (id !in initialized) {
            int deviceCount = 0;
            cuDeviceGetCount(&deviceCount);
            // Reject negative ordinals as well: previously they slipped past
            // this guard and a context was created on handle 0 while
            // `initialized[id]` was set for the bogus ordinal.
            if (id < 0 || id >= deviceCount) {
                return CUresult(CUDA_ERROR_NO_DEVICE);
            }

            int device;
            CUcontext context;
            enum int namelen = 256;
            // Zero-fill the buffer: `char.init` is 0xFF, so if the name query
            // fails `fromStringz` would otherwise scan past the array.
            char[namelen] name = '\0';
            cuCheck(cuDeviceGet(&device, id));
            // The name is only used for the log line below; best effort.
            cuDeviceGetName(name.ptr, namelen, device);
            // TODO: use logger
            writefln(">>> Using CUDA Device [%d]: %s", id, name.ptr.fromStringz);

            // get compute capabilities and the devicename
            // cudaDeviceProp dev;
            // check(cast(CUresult) cudaGetDeviceProperties(&dev, id));
            // writeln(dev);
            cuCheck(cuCtxCreate(&context, 0, device));
            initialized[id] = true;
        }
        return cast(dnv.error.CUresult) CUDA_SUCCESS;
    }
}

// Every device reported by the driver must end up in `initialized`.
unittest {
    int deviceCount = 0;
    cuDeviceGetCount(&deviceCount);
    DriverBase.deviceInit();
    assert(DriverBase.initialized.length == deviceCount);
}


/// Loads a PTX image into a new module on the given device.
///
/// Params:
///   ptx      = PTX image passed to `cuModuleLoadDataEx`.
///   argc     = unused.
///   argv     = unused.
///   cuDevice = zero-based device ordinal. Previously this value was
///              overwritten and device 0 was always used.
///
/// Returns: the loaded module handle.
///
/// NOTE(review): every call creates a new context with `cuCtxCreate` that is
/// never destroyed in this function; consider reusing the context created by
/// `DriverBase.deviceInit`.
CUmodule loadPTX(char *ptx, int argc, char **argv, int cuDevice)
{
    CUmodule cumodule;
    CUcontext context;
    int device;

    cuCheck(cuInit(0));
    cuCheck(cuDeviceGet(&device, cuDevice));
    cuCheck(cuCtxCreate(&context, 0, device));
    cuCheck(cuModuleLoadDataEx(&cumodule, ptx, 0U, null, null));
    return cumodule;
}


/// Compiles CUDA source with NVRTC, loads the resulting PTX and stores the
/// handle of the kernel named `funcname` at `kernel_addr`.
///
/// Params:
///   kernel_addr = address of a `CUfunction` that receives the kernel handle.
///   funcname    = kernel name; also used as the virtual file name
///                 (`funcname ~ ".cu"`) and in the log header.
///   code        = CUDA source text.
///   cuDevice    = device ordinal forwarded to `loadPTX`.
///
/// Returns: `NVRTC_SUCCESS`; failures are reported through `nvrtcCheck` and
/// `cuCheck`.
auto compile(void* kernel_addr, string funcname, string code, int cuDevice=0)
{
    import derelict.nvrtc; // : DerelictNVRTC, nvrtcProgram, nvrtcCreateProgram, nvrtcCompileProgram;
    DerelictNVRTC.load();

    import std.stdio : writefln;
    import std.algorithm : map;
    import std.array : array, empty;
    import std.string : toStringz, fromStringz, strip;

    // compile
    auto filename = funcname ~ ".cu";
    nvrtcProgram prog;
    nvrtcCheck(nvrtcCreateProgram(&prog, code.toStringz, filename.toStringz, 0, null, null));

    char[] ptx;
    {
        // If any step below throws, still release the NVRTC program; the
        // success path destroys it (with a checked result) after this block.
        scope(failure) nvrtcDestroyProgram(&prog);

        auto opts = ["--use_fast_math", "-arch=compute_30", "--std=c++11"].map!(a => cast(const char*) a.toStringz).array;
        // Keep the result and check it only after the log has been printed:
        // the log is most useful precisely when compilation failed.
        auto compileResult = nvrtcCompileProgram(prog, cast(int) opts.length, opts.ptr);

        // dump log
        size_t logSize;
        nvrtcCheck(nvrtcGetProgramLogSize(prog, &logSize));
        char[] log;
        log.length = logSize + 1;
        nvrtcCheck(nvrtcGetProgramLog(prog, log.ptr));
        log[logSize] = '\0'; // guarantee termination for fromStringz
        auto slog = log.ptr.fromStringz.strip;
        if (logSize > 0 && !slog.empty) {
            writefln(">>> NVRTC log: %s\n%s", funcname, slog);
        }
        nvrtcCheck(compileResult);

        // load compiled ptx
        size_t ptxSize;
        nvrtcCheck(nvrtcGetPTXSize(prog, &ptxSize));
        ptx.length = ptxSize;
        nvrtcCheck(nvrtcGetPTX(prog, ptx.ptr));
    }
    nvrtcCheck(nvrtcDestroyProgram(&prog));

    // FIXME: split these to another function to return
    CUmodule cumodule = loadPTX(ptx.ptr, 0, null, cuDevice);
    cuCheck(cuModuleGetFunction(cast (CUfunction*) kernel_addr, cumodule, funcname.toStringz));
    return cast(dnv.error.nvrtcResult) NVRTC_SUCCESS;
}


/// Launches a kernel and blocks until the current context is idle.
///
/// Params:
///   kernel_addr = address of the `CUfunction` to launch (as filled in by
///                 `compile`).
///   kernel_args = array of pointers to the kernel arguments, passed as
///                 `kernelParams`.
///   grids       = pointer to three grid dimensions (x, y, z).
///   blocks      = pointer to three block dimensions (x, y, z).
///
/// Returns: `CUDA_SUCCESS`; failures are reported through `cuCheck`.
auto launch(void* kernel_addr, void* kernel_args, uint* grids, uint* blocks)
// TODO: size_t shared_memory=0, CUstream stream=null
{
    auto func = cast(CUfunction*) kernel_addr;
    auto args = cast(void**) kernel_args;
    cuCheck(cuLaunchKernel(*func,
                           grids[0], grids[1], grids[2],
                           blocks[0], blocks[1], blocks[2],
                           0U,     // sharedMemBytes  FIXME: make configurable
                           null,   // hStream         FIXME: make configurable
                           args,   // kernelParams
                           null)); // extra: alternative way to pass parameters, unused
    cuCheck(cuCtxSynchronize());
    return cast(dnv.error.CUresult) CUDA_SUCCESS;
}