1 module dnv.kernel;
2 
3 import std.conv : to;
4 
5 import dnv.storage;
6 import dnv.compiler;
7 import dnv.driver;
8 import dnv.error;
9 
10 import derelict.cuda.driverapi : CUfunction;
11 
12 
/**
  Returns an untyped pointer suitable for a kernel-argument pointer table.

  If `F` exposes a `ptr` member, the address of that member is returned
  (presumably the device pointer held by storage types such as `Array!T` —
  confirm against dnv.storage); otherwise the address of `f` itself is
  returned.

  Params:
    f = the argument, taken by reference so the returned address refers to
        the caller's variable rather than a copy.
*/
void* vptr(F)(ref F f) {
  static if (is(typeof(f.ptr))) {
    auto target = &f.ptr;
  } else {
    auto target = &f;
  }
  return to!(void*)(target);
}
20 
21 
/**
  CUDA kernel wrapper parameterised by a compilation policy and a launch
  policy.

  Params:
    Compiler = policy used via `build`, `assertArgs` and, when available,
               `cargs`/`code` (e.g. Static, Dynamic, Unsafe, or user-defined).
    Launcher = policy used via `setup`, `grids` and `blocks`
               (e.g. Simple, Heavy, Shared, Async, or user-defined).

  FIXME: CUresult.CUDA_ERROR_INVALID_HANDLE
  - initialize the device in the constructor or in opCall?
  - multi-device support
*/
class KernelBase(Compiler, Launcher) {
  CUfunction func;   // handle passed (by address, via `vfunc`) to `compiler.build`
  Launcher launch;   // launch-configuration policy instance
  Compiler compiler; // compilation policy instance

  // FIXME: need better prediction
  // A compiler exposing `cargs` is used with its own `compiler.code`, so the
  // kernel is default-constructible; otherwise the caller supplies the `Code`.
  static if (is(typeof(compiler.cargs))) {
    this() {
      compiler.build(vfunc, compiler.code);
    }
  } else {
    this(Code code) {
      compiler.build(vfunc, code);
    }
  }

  /// Returns the address of `func` as an untyped pointer.
  /// NOTE(review): presumably `compiler.build` stores the loaded function
  /// handle through this pointer — confirm against dnv.compiler.
  void* vfunc() {
    return to!(void*)(&func);
  }

  /// Zero-argument call.
  /// NOTE(review): this non-template overload has an empty body and is chosen
  /// over the variadic `opCall` for `kernel()`, so no launch happens — confirm
  /// this is intended.
  void opCall() {}

  /**
    Validates `targs` with the compiler policy, lets the launcher derive the
    grid/block configuration from them, then launches `func` with one
    parameter pointer per argument. Launch failures are passed to `check`.

    Params:
      targs = kernel arguments, in kernel-parameter order.
  */
  void opCall(Ts...)(Ts targs) {
    compiler.assertArgs(targs);

    // Typed, fixed-size argument pointer table (no per-argument GC allocation).
    // Entries point into this function's parameters, so they are only valid
    // for the duration of this call.
    void*[Ts.length] vargs;
    foreach (i, T; Ts) {
      // Index `targs` directly: a foreach value variable would be a copy, and
      // `vptr` must take the address of the actual argument.
      vargs[i] = vptr(targs[i]);
    }

    launch.setup(targs);
    check(dnv.driver.launch(vptr(func), vargs.ptr, launch.grids.ptr, launch.blocks.ptr));
  }
}
65 
/// Kernel whose source `Code` is supplied at construction time; built with
/// `UnsafeCompiler` and launched with `L`.
template RuntimeKernel(L = SimpleLauncher) {
  alias RuntimeKernel = KernelBase!(UnsafeCompiler, L);
}

/// Kernel whose source `code` is fixed at compile time; built with
/// `StaticCompiler!code` and launched with `L`.
template TypedKernel(Code code, L = SimpleLauncher) {
  alias TypedKernel = KernelBase!(StaticCompiler!code, L);
}
69 
70 
unittest {
  import std.random;
  import std.range;

  // Kernel with no parameters.
  // NOTE(review): `empty()` binds to the non-template `KernelBase.opCall()`
  // overload, whose body is empty, so this only checks that the kernel builds;
  // it does not launch anything — confirm that is intended.
  auto empty = new RuntimeKernel!()
    (Code("empty", "", "int i = blockDim.x * blockIdx.x + threadIdx.x;"));
  empty();

  int n = 10;
  // Builds an Array!float of n uniform random values in [-1, 1).
  auto gen = () => new Array!float(generate!(() => uniform(-1f, 1f)).take(n).array());
  auto a = gen();
  auto b = gen();

  // Copies the arrays back to the host and checks z[i] == x[i] + y[i].
  // NOTE(review): exact float comparison assumes device and host round a
  // single float addition identically.
  void assertSum(typeof(a) x, typeof(a) y, typeof(a) z) {
    foreach (xi, yi, zi; zip(x.to_cpu(), y.to_cpu(), z.to_cpu())) {
      assert(xi + yi == zi);
    }
  }

  // Element-wise vector addition C = A + B (despite the kernel name "saxpy",
  // there is no scale factor).
  auto c = new Array!float(n);
  auto saxpy = new RuntimeKernel!()(
    Code(
      "saxpy", q{float *A, float *B, float *C, int numElements},
      q{
        int i = blockDim.x * blockIdx.x + threadIdx.x;
        if (i < numElements) C[i] = A[i] + B[i];
      })
    );
  saxpy(a, b, c, n);
  assertSum(a, b, c);

  // Same kernel, built through the compile-time (typed) path.
  enum code = Code(
      "saxpy", q{float *A, float *B, float *C, int numElements},
      q{
        int i = blockDim.x * blockIdx.x + threadIdx.x;
        if (i < numElements) C[i] = A[i] + B[i];
      });
  auto tsaxpy = new TypedKernel!(code);
  // Fresh output buffer: reusing `c` would let this check pass on the values
  // the runtime kernel already wrote, even if this launch did nothing.
  auto tc = new Array!float(n);
  tsaxpy(a, b, tc, n);
  assertSum(a, b, tc);
}