/**
 * Minimal hand-written D bindings to the cuBLAS v2 C API.
 *
 * Only the declarations currently needed by this package are provided.
 */
module dnv.cuda.cublas;

// These are plain C functions: they cannot throw D exceptions and do not use
// the D GC, so they are declared `nothrow @nogc` to be callable from code
// carrying those attributes. The braced form (rather than `extern (C):`)
// keeps C linkage and these attributes off the allocating unittest below.
extern (C) nothrow @nogc
{
    /// Status code returned by cuBLAS API calls.
    alias cublasStatus_t = int;
    /// ditto
    enum : cublasStatus_t
    {
        CUBLAS_STATUS_SUCCESS          = 0,
        CUBLAS_STATUS_NOT_INITIALIZED  = 1,
        CUBLAS_STATUS_ALLOC_FAILED     = 3,
        CUBLAS_STATUS_INVALID_VALUE    = 7,
        CUBLAS_STATUS_ARCH_MISMATCH    = 8,
        CUBLAS_STATUS_MAPPING_ERROR    = 11,
        CUBLAS_STATUS_EXECUTION_FAILED = 13,
        CUBLAS_STATUS_INTERNAL_ERROR   = 14,
        CUBLAS_STATUS_NOT_SUPPORTED    = 15,
        CUBLAS_STATUS_LICENSE_ERROR    = 16
    }

    /// Opaque cuBLAS library context; only ever handled through a pointer.
    struct cublasContext;
    /// Handle to a cuBLAS context, obtained from `cublasCreate_v2`.
    alias cublasHandle_t = cublasContext*;

    /// Selects how a matrix operand is used (e.g. `transa`/`transb` of gemm).
    alias cublasOperation_t = int;
    /// ditto
    enum : cublasOperation_t
    {
        CUBLAS_OP_N, // the non-transpose operation is selected
        CUBLAS_OP_T, // the transpose operation is selected
        CUBLAS_OP_C  // the conjugate transpose operation is selected
    }

    // TODO: parse and retrieve cublas_api.h

    /// Creates a cuBLAS context and writes its handle to `*handle`.
    cublasStatus_t cublasCreate_v2(cublasHandle_t* handle);

    /// Releases the cuBLAS context referred to by `handle`.
    cublasStatus_t cublasDestroy_v2(cublasHandle_t handle);

    /**
     * Single-precision general matrix multiply:
     * `C = alpha * op(A) * op(B) + beta * C`, where `op(X)` is selected by
     * `transa` / `transb` (see `CUBLAS_OP_*`).
     *
     * Matrices are column-major: op(A) is m x k, op(B) is k x n and C is
     * m x n, with leading dimensions `lda`, `ldb` and `ldc`.
     * The unittest below passes `alpha`/`beta` as host pointers and
     * `A`/`B`/`C` as device buffers.
     */
    cublasStatus_t cublasSgemm_v2(cublasHandle_t handle,
                                  cublasOperation_t transa, cublasOperation_t transb,
                                  int m, int n, int k,
                                  const(float)* alpha,
                                  const(float)* A, int lda,
                                  const(float)* B, int ldb,
                                  const(float)* beta,
                                  float* C, int ldc);
}

version (LDC)
unittest
{
    import dnv.storage : Array;

    cublasHandle_t handle;
    auto status = cublasCreate_v2(&handle);
    // Check creation first so cleanup is only registered for a valid handle.
    assert(status == CUBLAS_STATUS_SUCCESS);
    scope (exit) cublasDestroy_v2(handle);

    // Column-major storage:
    //   A is M x K (lda = M): columns [1, 2, 3] and [4, 5, 6]
    //   B is N x K (ldb = N): columns [1, 2, 3, 4] and [5, 6, 7, 8]
    float[] A = [1, 2, 3,
                 4, 5, 6];  // M=3 x K=2
    float[] B = [1, 2,
                 3, 4,
                 5, 6,
                 7, 8];     // N=4 x K=2
    auto M = 3;
    auto N = 4;
    auto K = 2;
    float alpha = 1.0f;
    float beta = 0.0f;
    auto d_A = new Array!float(A);
    auto d_B = new Array!float(B);
    auto d_C = new Array!float(M * N);

    // C = A x B.T  (M x N, ldc = M)
    status = cublasSgemm_v2(handle, CUBLAS_OP_N, CUBLAS_OP_T, M, N, K,
                            &alpha, d_A.data, M, d_B.data, N, &beta, d_C.data, M);
    // only LDC 1.4.0 passes, but
    // DMD 2.075.1 and GDC 7.2.0 cause CUBLAS_STATUS_INTERNAL_ERROR
    assert(status == CUBLAS_STATUS_SUCCESS);

    // D = B x A.T  (N x M, ldc = N), which must equal C.T
    auto d_D = new Array!float(N * M);
    status = cublasSgemm_v2(handle, CUBLAS_OP_N, CUBLAS_OP_T, N, M, K,
                            &alpha, d_B.data, N, d_A.data, M, &beta, d_D.data, N);
    assert(status == CUBLAS_STATUS_SUCCESS);

    auto C = d_C.to_cpu(); // C = A x B.T
    auto D = d_D.to_cpu(); // D = B x A.T
    foreach (m; 0 .. M)
    {
        foreach (n; 0 .. N)
        {
            // CPU reference: C(m, n) = sum_k A(m, k) * B(n, k).
            // All operands are small integers, so the float result is exact
            // and independent of summation order.
            float expected = 0;
            foreach (k; 0 .. K)
                expected += A[m + k * M] * B[n + k * N];
            assert(C[m + n * M] == expected); // C = A x B.T
            assert(D[n + m * N] == expected); // D = C.T
        }
    }
}