1 module dnv.cuda.cublas;
2 
3 
4 extern (C):
5 
/// Status code type returned by the cuBLAS entry points declared in this module.
alias cublasStatus_t = int;

/// Named `cublasStatus_t` values. The numbering is sparse on purpose: it is
/// meant to mirror the constants of the cuBLAS C header one-to-one, so do not
/// renumber or "fill the gaps".
enum : cublasStatus_t
{
    CUBLAS_STATUS_SUCCESS          = 0,
    CUBLAS_STATUS_NOT_INITIALIZED  = 1,
    CUBLAS_STATUS_ALLOC_FAILED     = 3,
    CUBLAS_STATUS_INVALID_VALUE    = 7,
    CUBLAS_STATUS_ARCH_MISMATCH    = 8,
    CUBLAS_STATUS_MAPPING_ERROR    = 11,
    CUBLAS_STATUS_EXECUTION_FAILED = 13,
    CUBLAS_STATUS_INTERNAL_ERROR   = 14,
    CUBLAS_STATUS_NOT_SUPPORTED    = 15,
    CUBLAS_STATUS_LICENSE_ERROR    = 16,
}
19 
/// Opaque cuBLAS context. Only declared here, never defined: D code can hold a
/// pointer to it but cannot inspect or construct one.
struct cublasContext;

/// Handle passed to cuBLAS calls; a plain pointer to the opaque context.
/// Obtained from `cublasCreate_v2` and released with `cublasDestroy_v2`.
alias cublasHandle_t = cublasContext*;
22 
/// Selector describing how a matrix operand is applied in a cuBLAS call.
alias cublasOperation_t = int;

/// Named `cublasOperation_t` values (numbered consecutively from zero).
enum : cublasOperation_t
{
    /// The non-transpose operation is selected.
    CUBLAS_OP_N = 0,
    /// The transpose operation is selected.
    CUBLAS_OP_T = 1,
    /// The conjugate transpose operation is selected.
    CUBLAS_OP_C = 2,
}
29 
30 
31 // TODO: parse and retrieve cublas_api.h
/// Creates a cuBLAS context and stores its handle through `handle`.
/// Returns a `cublasStatus_t`; callers compare it with `CUBLAS_STATUS_SUCCESS`.
cublasStatus_t cublasCreate_v2(cublasHandle_t* handle);

/// Releases a context previously obtained from `cublasCreate_v2`.
/// Returns a `cublasStatus_t`.
cublasStatus_t cublasDestroy_v2(cublasHandle_t handle);
34 
/// Single-precision general matrix multiply. Per the cuBLAS documentation this
/// computes `C = alpha * op(A) * op(B) + beta * C` on column-major matrices,
/// where `op` is chosen by `transa` / `transb`.
///
/// Params:
///   handle = cuBLAS context from `cublasCreate_v2`
///   transa = operation applied to A (`CUBLAS_OP_N`, `CUBLAS_OP_T`, `CUBLAS_OP_C`)
///   transb = operation applied to B
///   m      = rows of op(A) and of C
///   n      = columns of op(B) and of C
///   k      = columns of op(A) and rows of op(B)
///   alpha  = pointer to the scalar multiplying op(A) * op(B)
///   A      = first input matrix
///   lda    = leading dimension of A
///   B      = second input matrix
///   ldb    = leading dimension of B
///   beta   = pointer to the scalar multiplying the previous contents of C
///   C      = output matrix
///   ldc    = leading dimension of C
///
/// Returns: a `cublasStatus_t`; `CUBLAS_STATUS_SUCCESS` when the call succeeded.
cublasStatus_t cublasSgemm_v2(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float* alpha,
    const float* A,
    int lda,
    const float* B,
    int ldb,
    const float* beta,
    float* C,
    int ldc);
43 
44 
45 
// Smoke test for the SGEMM binding: computes C = A x B.T and D = B x A.T on the
// device, then checks C against a CPU reference product and against D.T.
// Restricted to LDC: only LDC 1.4.0 passed, while DMD 2.075.1 and GDC 7.2.0
// made cublasSgemm_v2 return CUBLAS_STATUS_INTERNAL_ERROR.
version (LDC)
unittest {
    import dnv.storage : Array;
    import dnv.cuda.cublas;

    cublasHandle_t handle;
    auto status = cublasCreate_v2(&handle);
    // Verify creation before registering cleanup, so a failed create never
    // results in destroying a handle that was not created.
    assert(status == CUBLAS_STATUS_SUCCESS);
    scope(exit) cublasDestroy_v2(handle);

    // cuBLAS reads matrices in column-major order: with a leading dimension
    // equal to the row count, each line below is one column (index k).
    float[] A = [1, 2, 3,      // column k=0 of the M=3 x K=2 matrix A
                 4, 5, 6];     // column k=1
    float[] B = [1, 2, 3, 4,   // column k=0 of the N=4 x K=2 matrix B
                 5, 6, 7, 8];  // column k=1
    auto M = 3;
    auto N = 4;
    auto K = 2;
    float alpha = 1.0f;
    float beta = 0.0f;
    auto d_A = new Array!float(A);
    auto d_B = new Array!float(B);
    auto d_C = new Array!float(M * N);

    // C (M x N) = A x B.T
    status = cublasSgemm_v2(handle, CUBLAS_OP_N, CUBLAS_OP_T, M, N, K,
                            &alpha, d_A.data, M, d_B.data, N, &beta, d_C.data, M);
    assert(status == CUBLAS_STATUS_SUCCESS);

    // D (N x M) = B x A.T
    auto d_D = new Array!float(N * M);
    status = cublasSgemm_v2(handle, CUBLAS_OP_N, CUBLAS_OP_T, N, M, K,
                            &alpha, d_B.data, N, d_A.data, M, &beta, d_D.data, N);
    assert(status == CUBLAS_STATUS_SUCCESS);

    auto C = d_C.to_cpu();     // C = A x B.T
    auto D = d_D.to_cpu();     // D = B x A.T
    foreach (m; 0 .. M) {
        foreach (n; 0 .. N) {
            // CPU reference: C[m, n] = sum over k of A[m, k] * B[n, k].
            // All operands are small integers, so the float result is exact
            // and can be compared with ==.
            float expected = 0;
            foreach (k; 0 .. K) {
                expected += A[m + k * M] * B[n + k * N];
            }
            assert(C[m + n * M] == expected);
            // check C = D.T
            assert(C[m + n * M] == D[n + m * N]);
        }
    }
}
94