wanggh / apex

Commit 0c7d8e3f (unverified)
Authored by Masaki Kozuki, committed by GitHub on Oct 18, 2021 (edited Oct 19, 2021)

remove THC headers/functions (#1192)

Changes include:
- THC headers removal
- TH macros replacement
- fix some typos in comments

Parent: 60821f53
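The same two substitutions recur across every touched file: drop the legacy THC includes and swap the TH-era cuBLAS error-checking macro for the ATen one. A minimal compilable sketch of the pattern (the helper function and handle lookup below are illustrative, not lines from this commit):

// Recurring replacement applied across these files (illustrative sketch, not commit code):
//   dropped:   #include "THC/THC.h"        (and <THC/THCGeneral.h> in dropout.h)
//   replaced:  THCublasCheck(...)     ->    TORCH_CUDABLAS_CHECK(...)
// TORCH_CUDABLAS_CHECK is defined in ATen/cuda/Exceptions.h; the touched .cu files
// already pick it up through their existing ATen/torch includes.
#include <cublas_v2.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

void set_tensor_op_math_sketch() {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  // Formerly: THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
}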
19 changed files
apex/contrib/csrc/fmha/fmha_api.cpp
@@ -163,7 +163,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
                s.data_ptr(),
                p_dropout);
-    // number of times random will be generated per thread, to offset philox counter in thc random
+    // number of times random will be generated per thread, to offset philox counter in the random
     // state
     int64_t counter_offset = elts_per_thread;
     at::PhiloxCudaState rng_engine_inputs;
@@ -319,7 +319,7 @@ std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num
                s.data_ptr(),
                p_dropout);
-    // number of times random will be generated per thread, to offset philox counter in thc random
+    // number of times random will be generated per thread, to offset philox counter in the random
     // state
     int64_t counter_offset = elts_per_thread;
     at::PhiloxCudaState rng_engine_inputs;
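For context, counter_offset above is the number of Philox values each thread will draw. A minimal sketch of how such an offset is typically turned into the at::PhiloxCudaState that the kernel receives (the helper name is illustrative, and the generator header is <ATen/CUDAGeneratorImpl.h> or <ATen/cuda/CUDAGeneratorImpl.h> depending on the PyTorch version):

// Sketch only: reserving a Philox offset for one kernel launch (not verbatim fmha_api.cpp).
#include <mutex>
#include <ATen/CUDAGeneratorImpl.h>   // <ATen/cuda/CUDAGeneratorImpl.h> on newer PyTorch

at::PhiloxCudaState reserve_rng_state(int64_t counter_offset) {
  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
      c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
  // The generator mutex guards the shared Philox counter; philox_cuda_state hands
  // back the current (seed, offset) and advances the counter by counter_offset,
  // so successive launches consume disjoint random subsequences.
  std::lock_guard<std::mutex> lock(gen->mutex_);
  return gen->philox_cuda_state(counter_offset);
}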
apex/contrib/csrc/groupbn/batch_norm.cu
@@ -2,8 +2,6 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDACachingAllocator.h>
-#include "THC/THC.h"
 #include "batch_norm.h"
 #include <cuda.h>
apex/contrib/csrc/groupbn/batch_norm_add_relu.cu
@@ -2,8 +2,6 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDACachingAllocator.h>
-#include "THC/THC.h"
 #include "batch_norm_add_relu.h"
 #include <cuda.h>
apex/contrib/csrc/groupbn/ipc.cu
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
-#include "THC/THC.h"
 #include <cuda.h>
 #include "compat.h"
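The groupbn hunks above only show the THC header being dropped; the usual destination for scratch allocations that used to go through THC (e.g. THCudaMalloc/THCudaFree) is the c10 caching allocator that batch_norm.cu already includes. A minimal sketch of that allocator, stated as an assumption about the replacement rather than as lines from this commit:

// Sketch only: workspace allocation via the c10 caching allocator (size is a placeholder).
#include <c10/cuda/CUDACachingAllocator.h>

void workspace_sketch() {
  void* workspace = c10::cuda::CUDACachingAllocator::raw_alloc(1024 * sizeof(float));
  // ... launch kernels that read/write the workspace on the current stream ...
  c10::cuda::CUDACachingAllocator::raw_delete(workspace);
}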
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu
 #include <vector>
 #include <math.h>
 #include <iostream>
 #include <ATen/ATen.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cuda_profiler_api.h>
-#include "THC/THC.h"
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include <math.h>
 #include "softmax.h"
 #include "dropout.h"
apex/contrib/csrc/multihead_attn/dropout.h
@@ -9,8 +9,6 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <curand_kernel.h>
-#include <THC/THCGeneral.h>
 
 const int UNROLL = 4;
 
 template <
@@ -207,7 +205,7 @@ void apex_fused_dropout_cuda(scalar_t const *inputs,
   unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
   grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
-  //number of times random will be generated per thread, to offset philox counter in thc random state
+  //number of times random will be generated per thread, to offset philox counter in the random state
   int64_t counter_offset = ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
   std::pair<uint64_t, uint64_t> rng_engine_inputs;
   {
@@ -245,7 +243,7 @@ void apex_dropout_add_cuda(scalar_t const *inputs,
   unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
   grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
-  //number of times random will be generated per thread, to offset philox counter in thc random state
+  //number of times random will be generated per thread, to offset philox counter in the random state
   int64_t counter_offset = ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
   std::pair<uint64_t, uint64_t> rng_engine_inputs;
   {
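The hunks above clamp the grid to what the device can keep resident and reserve a (seed, offset) pair per launch. A minimal sketch of how a Philox-based dropout kernel typically consumes that pair with curand (the kernel name, mask layout, and grid-stride loop are illustrative assumptions, not apex's exact kernel):

// Sketch only: consuming the reserved (seed, offset) pair on the device side.
#include <cstdint>
#include <utility>
#include <curand_kernel.h>

__global__ void dropout_kernel_sketch(const float* in, float* out, uint8_t* mask,
                                      size_t n, float p,
                                      std::pair<uint64_t, uint64_t> seeds) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  size_t stride = (size_t)blockDim.x * gridDim.x;
  curandStatePhilox4_32_10_t state;
  // seeds.first is the generator seed, seeds.second the counter offset reserved for
  // this launch; using idx as the subsequence gives every thread its own stream.
  curand_init(seeds.first, idx, seeds.second, &state);
  float scale = 1.0f / (1.0f - p);
  for (size_t i = idx; i < n; i += stride) {
    bool keep = curand_uniform(&state) > p;
    mask[i] = keep;
    out[i] = keep ? in[i] * scale : 0.0f;
  }
}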
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu
 #include <vector>
 #include <math.h>
 #include <iostream>
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cuda_profiler_api.h>
-#include "THC/THC.h"
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include <math.h>
 #include "strided_batched_gemm.h"
 #include "softmax.h"
@@ -86,9 +86,9 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Q Fwd
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_T,
                              CUBLAS_OP_N,
                              output_lin_q_dim,
@@ -109,7 +109,7 @@ std::vector<torch::Tensor> fwd_cuda(
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Input Linear KV Fwd
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_T,
                              CUBLAS_OP_N,
                              output_lin_kv_dim,
@@ -211,7 +211,7 @@ std::vector<torch::Tensor> fwd_cuda(
                              attn_batches);
   // Output Linear
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_T,
                              CUBLAS_OP_N,
                              embed_dim,
@@ -232,7 +232,7 @@ std::vector<torch::Tensor> fwd_cuda(
                              //CUBLAS_GEMM_ALGO1_TENSOR_OP));
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_lin_q_results,
@@ -312,10 +312,10 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_n{'n'};
   char b_layout_t{'t'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Output Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_N,
                              embed_dim,
@@ -336,7 +336,7 @@ std::vector<torch::Tensor> bwd_cuda(
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Output Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_T,
                              embed_dim,
@@ -456,7 +456,7 @@ std::vector<torch::Tensor> bwd_cuda(
                              attn_batches);
   // Input Linear Q Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_N,
                              embed_dim,
@@ -478,7 +478,7 @@ std::vector<torch::Tensor> bwd_cuda(
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Input Linear Q Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_T,
                              embed_dim,
@@ -499,7 +499,7 @@ std::vector<torch::Tensor> bwd_cuda(
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Input Linear KV Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_N,
                              embed_dim,
@@ -521,7 +521,7 @@ std::vector<torch::Tensor> bwd_cuda(
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Input Linear KV Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_T,
                              embed_dim,
@@ -540,7 +540,7 @@ std::vector<torch::Tensor> bwd_cuda(
                              embed_dim,
                              CUDA_R_32F,
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_q_grads,
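Every GEMM in this file follows the same bracketing: switch the handle to Tensor Core math, issue the FP16 GEMMs with FP32 accumulation, then restore the default mode. A minimal sketch of that bracket with the new macro (dimensions, leading dimensions, and pointers are placeholders, not apex's values):

// Sketch only: the TORCH_CUDABLAS_CHECK + math-mode bracket used around cublasGemmEx.
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

void gemm_bracket_sketch(const half* A, const half* B, half* C,
                         int m, int n, int k, float alpha, float beta) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  // C (m x n) = alpha * A^T (m x k) * B (k x n) + beta * C; FP16 data, FP32 accumulate.
  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                                    CUBLAS_OP_T, CUBLAS_OP_N,
                                    m, n, k,
                                    &alpha,
                                    A, CUDA_R_16F, k,
                                    B, CUDA_R_16F, k,
                                    &beta,
                                    C, CUDA_R_16F, m,
                                    CUDA_R_32F,
                                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}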
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu
 #include <vector>
 #include <math.h>
 #include <iostream>
 #include <ATen/ATen.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cuda_profiler_api.h>
-#include "THC/THC.h"
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include <math.h>
 #include "strided_batched_gemm.h"
 #include "softmax.h"
@@ -95,7 +95,7 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Layer Norm
   HostApplyLayerNorm<at::Half, float>(
       static_cast<at::Half *>(lyr_nrm_results.data_ptr()),
@@ -109,7 +109,7 @@ std::vector<torch::Tensor> fwd_cuda(
       static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
   // Input Linear Q Fwd
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_T,
                              CUBLAS_OP_N,
                              output_lin_q_dim,
@@ -131,7 +131,7 @@ std::vector<torch::Tensor> fwd_cuda(
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Input Linear KV Fwd
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_T,
                              CUBLAS_OP_N,
                              output_lin_kv_dim,
@@ -234,7 +234,7 @@ std::vector<torch::Tensor> fwd_cuda(
                              attn_batches);
   // Output Linear
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_T,
                              CUBLAS_OP_N,
                              embed_dim,
@@ -272,7 +272,7 @@ std::vector<torch::Tensor> fwd_cuda(
                              total_tokens_q);
   }
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {lyr_nrm_results,
@@ -367,7 +367,7 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_n{'n'};
   char b_layout_t{'t'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Dropout Add Backward
   apex_masked_scale_cuda<at::Half, float, uint32_t>(
@@ -378,7 +378,7 @@ std::vector<torch::Tensor> bwd_cuda(
       (1.0 / (1.0 - dropout_prob)));
   // Output Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_N,
                              embed_dim,
@@ -399,7 +399,7 @@ std::vector<torch::Tensor> bwd_cuda(
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Output Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_T,
                              embed_dim,
@@ -519,7 +519,7 @@ std::vector<torch::Tensor> bwd_cuda(
                              attn_batches);
   // Input Linear Q Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_N,
                              embed_dim,
@@ -542,7 +542,7 @@ std::vector<torch::Tensor> bwd_cuda(
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Input Linear Q Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_T,
                              embed_dim,
@@ -563,7 +563,7 @@ std::vector<torch::Tensor> bwd_cuda(
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Input Linear KV Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_N,
                              embed_dim,
@@ -585,7 +585,7 @@ std::vector<torch::Tensor> bwd_cuda(
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Input Linear KV Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_T,
                              embed_dim,
@@ -622,7 +622,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<half *>(lyr_nrm_beta_grads.data_ptr()));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_q_grads,
apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu
 #include <vector>
 #include <math.h>
 #include <iostream>
 #include <ATen/ATen.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cuda_profiler_api.h>
-#include "THC/THC.h"
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include <math.h>
 #include "softmax.h"
 #include "dropout.h"
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu
 #include <vector>
 #include <math.h>
 #include <iostream>
 #include <ATen/ATen.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cuda_profiler_api.h>
-#include "THC/THC.h"
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include <math.h>
 #include "strided_batched_gemm.h"
 #include "softmax.h"
@@ -82,10 +82,10 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Fwd
   input_lin_results.copy_(input_biases);
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_T,
                              CUBLAS_OP_N,
                              output_lin_dim,
@@ -173,7 +173,7 @@ std::vector<torch::Tensor> fwd_cuda(
   outputs.copy_(output_biases);
   // Output Linear
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_T,
                              CUBLAS_OP_N,
                              embed_dim,
@@ -194,7 +194,7 @@ std::vector<torch::Tensor> fwd_cuda(
                              //CUBLAS_GEMM_ALGO1_TENSOR_OP));
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_lin_results,
@@ -264,10 +264,10 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_n{'n'};
   char b_layout_t{'t'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Output Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_N,
                              embed_dim,
@@ -287,7 +287,7 @@ std::vector<torch::Tensor> bwd_cuda(
                              CUDA_R_32F,
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Output Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_T,
                              embed_dim,
@@ -403,7 +403,7 @@ std::vector<torch::Tensor> bwd_cuda(
                              batch_stride,
                              attn_batches);
   // Input Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_N,
                              embed_dim,
@@ -426,7 +426,7 @@ std::vector<torch::Tensor> bwd_cuda(
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Input Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_T,
                              embed_dim,
@@ -447,7 +447,7 @@ std::vector<torch::Tensor> bwd_cuda(
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_grads,
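In the bias variants the destination tensors are pre-filled with the broadcast biases (input_lin_results.copy_(input_biases), outputs.copy_(output_biases)) so the following GEMM can accumulate on top of them, and the bias gradient is recovered by summing the output gradients over every token, as in the input_bias_grads line above. A minimal ATen-level sketch of those two steps (tensor names and shapes are placeholders; the beta = 1 accumulation is an assumption consistent with the copy_ calls, not shown in these hunks):

// Sketch only: bias handling around the input/output linear GEMMs.
#include <torch/extension.h>

void bias_handling_sketch(torch::Tensor& gemm_out,          // [tokens, out_dim]
                          const torch::Tensor& bias,        // [out_dim]
                          const torch::Tensor& out_grads) { // [tokens, out_dim]
  // Forward: broadcast the bias into the GEMM destination first; the subsequent
  // cublasGemmEx call then accumulates the matrix product onto it (beta = 1 assumed).
  gemm_out.copy_(bias);

  // Backward: the bias gradient is the column-wise sum of the output gradients,
  // mirroring input_bias_grads = grads.view({-1, out_dim}).sum(0, false).
  torch::Tensor bias_grads = out_grads.view({-1, out_grads.size(-1)}).sum(0, false);
  (void)bias_grads;
}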
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu
 #include <vector>
 #include <math.h>
 #include <iostream>
 #include <ATen/ATen.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cuda_profiler_api.h>
-#include "THC/THC.h"
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include <math.h>
 #include "strided_batched_gemm.h"
 #include "softmax.h"
@@ -81,10 +81,10 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Fwd
   input_lin_results.copy_(input_biases);
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_T,
                              CUBLAS_OP_N,
                              output_lin_dim,
@@ -185,7 +185,7 @@ std::vector<torch::Tensor> fwd_cuda(
   outputs.copy_(output_biases);
   // Output Linear
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_T,
                              CUBLAS_OP_N,
                              embed_dim,
@@ -206,7 +206,7 @@ std::vector<torch::Tensor> fwd_cuda(
                              //CUBLAS_GEMM_ALGO1_TENSOR_OP));
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_lin_results,
@@ -275,10 +275,10 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_n{'n'};
   char b_layout_t{'t'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Output Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_N,
                              embed_dim,
@@ -298,7 +298,7 @@ std::vector<torch::Tensor> bwd_cuda(
                              CUDA_R_32F,
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Output Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_T,
                              embed_dim,
@@ -411,7 +411,7 @@ std::vector<torch::Tensor> bwd_cuda(
                              batch_stride,
                              attn_batches);
   // Input Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_N,
                              embed_dim,
@@ -434,7 +434,7 @@ std::vector<torch::Tensor> bwd_cuda(
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Input Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_T,
                              embed_dim,
@@ -455,7 +455,7 @@ std::vector<torch::Tensor> bwd_cuda(
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_grads,
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu
 #include <vector>
 #include <math.h>
 #include <iostream>
 #include <ATen/ATen.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cuda_profiler_api.h>
-#include "THC/THC.h"
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include <math.h>
 #include "strided_batched_gemm.h"
 #include "softmax.h"
@@ -78,9 +78,9 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Fwd
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_T,
                              CUBLAS_OP_N,
                              output_lin_dim,
@@ -182,7 +182,7 @@ std::vector<torch::Tensor> fwd_cuda(
                              attn_batches);
   // Output Linear
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                              CUBLAS_OP_T,
                              CUBLAS_OP_N,
                              embed_dim,
@@ -202,7 +202,7 @@ std::vector<torch::Tensor> fwd_cuda(
                              CUDA_R_32F,
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));