未验证 提交 54b93919 编辑于 作者: Burc Eryilmaz's avatar Burc Eryilmaz 提交者: GitHub
浏览文件

fix CUBLAS guards (#1162)



* support for fused dense layer with cublasLt, fusion in both fprop and bprop

* fix typo causing syntax error

* add fused GEMM+gelu+GEMM module

* fix typo for workspace size

* update cublas check for 11600

* add tests for fused dense layer

* fix CUDA 10.x path

* safer guard around CUBLAS constants, remove unreferenced variable

* more guard changes

* guard against cublas version instead of cuda
Co-authored-by: Sukru Eryilmaz <seryilmaz@computelab-dgx1v-32.nvidia.com>
上级 ae1cdd64
......@@ -62,7 +62,7 @@ std::vector<at::Tensor> linear_bias_backward(at::Tensor input, at::Tensor weight
// create output/workspace tensor
auto d_weight = at::empty({out_features, in_features}, input.type());
#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
auto d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
auto d_bias = at::empty({out_features}, input.type());
......
......@@ -129,7 +129,7 @@ cublasStatus_t gemm_bias(
}
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
int gemm_bias_lt(
......@@ -1148,7 +1148,7 @@ int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int i
const float beta_zero = 0.0;
const float beta_one = 1.0;
int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
status = gemm_bias_lt(
(cublasLtHandle_t)handle,
CUBLAS_OP_T,
......@@ -1200,7 +1200,6 @@ int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features,
cublasGetStream(handle, &stream);
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
status = gemm_bgradb_lt(
......@@ -1273,7 +1272,7 @@ int linear_gelu_linear_forward_cuda(T *input, T *weight1, T *bias1, T *weight2,
const float alpha = 1.0;
const float beta_zero = 0.0;
int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
status = gemm_bias_gelu_lt(
(cublasLtHandle_t)handle,
CUBLAS_OP_T,
......@@ -1329,9 +1328,8 @@ int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight
cublasGetStream(handle, &stream);
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
//wgrad for first gemm
status = gemm_bgradb_lt(
(cublasLtHandle_t)handle,
......
Supports Markdown
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册