Unverified commit 0c7d8e3f, authored by Masaki Kozuki and committed by GitHub

remove THC headers/functions (#1192)

Changes include:
- removal of THC headers
- replacement of TH macros
- fixes for some typos in comments
Parent 60821f53
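All of the hunks below apply the same mechanical migration: the deprecated THC/THC.h header is dropped and the TH-era error-check macros are replaced by their ATen/c10 counterparts. A minimal sketch of the pattern, assuming a recent PyTorch where TORCH_CUDABLAS_CHECK is provided by ATen/cuda/Exceptions.h (the set_tensor_op_math helper is hypothetical, for illustration only):

// Old THC style (removed by this commit):
//   #include "THC/THC.h"
//   THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
//
// ATen/c10 style used instead:
#include <ATen/cuda/CUDAContext.h>   // streams and the current cuBLAS handle
#include <ATen/cuda/Exceptions.h>    // AT_CUDA_CHECK / TORCH_CUDABLAS_CHECK

void set_tensor_op_math() {  // hypothetical helper, not code from this commit
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
}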
......@@ -163,7 +163,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
s.data_ptr(),
p_dropout);
// number of times random will be generated per thread, to offset philox counter in thc random
// number of times random will be generated per thread, to offset philox counter in the random
// state
int64_t counter_offset = elts_per_thread;
at::PhiloxCudaState rng_engine_inputs;
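The at::PhiloxCudaState rng_engine_inputs above replaces the old THC RNG state. In ATen-based extensions the usual pattern, sketched below assuming a recent PyTorch (header paths have moved between releases, and make_rng_inputs is a hypothetical helper), is to reserve counter_offset Philox draws per thread from the default CUDA generator while holding its mutex, then unpack the seed/offset pair inside the kernel:

#include <mutex>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>

at::PhiloxCudaState make_rng_inputs(int64_t counter_offset) {
  auto gen = at::cuda::detail::getDefaultCUDAGenerator();
  at::PhiloxCudaState rng_engine_inputs;
  {
    // The generator mutex must be held while reserving the per-thread offset.
    std::lock_guard<std::mutex> lock(gen.mutex());
    rng_engine_inputs =
        at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_cuda_state(counter_offset);
  }
  return rng_engine_inputs;
}
// Device side: auto seeds = at::cuda::philox::unpack(rng_engine_inputs);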
......@@ -319,7 +319,7 @@ std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num
s.data_ptr(),
p_dropout);
// number of times random will be generated per thread, to offset philox counter in thc random
// number of times random will be generated per thread, to offset philox counter in the random
// state
int64_t counter_offset = elts_per_thread;
at::PhiloxCudaState rng_engine_inputs;
......
......@@ -2,8 +2,6 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include "THC/THC.h"
#include "batch_norm.h"
#include <cuda.h>
......
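With THC/THC.h gone, this file and batch_norm_add_relu below reach CUDA streams, library handles, and temporary workspace only through the two headers that remain above. A sketch of the usual replacements, assuming the standard ATen/c10 APIs (workspace_bytes and the wrapper function are made up for illustration):

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>

void use_aten_context(size_t workspace_bytes) {
  // Current stream and cuBLAS handle, previously reached through THCState.
  cudaStream_t stream   = at::cuda::getCurrentCUDAStream();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  (void)stream; (void)handle;

  // Temporary GPU workspace via the caching allocator instead of THCudaMalloc.
  void* workspace = c10::cuda::CUDACachingAllocator::raw_alloc(workspace_bytes);
  c10::cuda::CUDACachingAllocator::raw_delete(workspace);
}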
......@@ -2,8 +2,6 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include "THC/THC.h"
#include "batch_norm_add_relu.h"
#include <cuda.h>
......
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "THC/THC.h"
#include <cuda.h>
#include "compat.h"
......
#include <vector>
#include <math.h>
#include <iostream>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "softmax.h"
#include "dropout.h"
......
......@@ -9,8 +9,6 @@
#include <ATen/cuda/CUDAContext.h>
#include <curand_kernel.h>
#include <THC/THCGeneral.h>
const int UNROLL = 4;
template <
......@@ -207,7 +205,7 @@ void apex_fused_dropout_cuda(scalar_t const *inputs,
unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
//number of times random will be generated per thread, to offset philox counter in thc random state
//number of times random will be generated per thread, to offset philox counter in the random state
int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
......@@ -245,7 +243,7 @@ void apex_dropout_add_cuda(scalar_t const *inputs,
unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
//number of times random will be generated per thread, to offset philox counter in thc random state
//number of times random will be generated per thread, to offset philox counter in the random state
int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
......
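Both dropout hunks above first cap the grid at multiProcessorCount * blocks_per_sm and then compute counter_offset with a ceil-division, so every thread reserves enough Philox draws to cover the largest number of elements it may process, rounded up to a multiple of UNROLL. A worked sketch of that arithmetic with made-up sizes:

#include <cstdint>
#include <cstdio>

int main() {
  // Assumed example values, not taken from the diff.
  const int64_t UNROLL        = 4;
  const int64_t block_size    = 256;
  const int64_t grid_x        = 1024;     // after the occupancy-based cap above
  const int64_t totalElements = 5000000;

  // ceil(totalElements / (block_size * grid_x * UNROLL)) * UNROLL
  const int64_t counter_offset =
      ((totalElements - 1) / (block_size * grid_x * UNROLL) + 1) * UNROLL;

  // 4999999 / 1048576 = 4, +1 -> 5, *4 -> 20 Philox draws reserved per thread.
  std::printf("counter_offset = %lld\n", (long long)counter_offset);
  return 0;
}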
#include <vector>
#include <math.h>
#include <iostream>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
......@@ -86,9 +86,9 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'};
char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Q Fwd
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_q_dim,
......@@ -109,7 +109,7 @@ std::vector<torch::Tensor> fwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear KV Fwd
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_kv_dim,
......@@ -211,7 +211,7 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches);
// Output Linear
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
......@@ -232,7 +232,7 @@ std::vector<torch::Tensor> fwd_cuda(
//CUBLAS_GEMM_ALGO1_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_lin_q_results,
......@@ -312,10 +312,10 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'};
char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -336,7 +336,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -456,7 +456,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Input Linear Q Dgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -478,7 +478,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear Q Wgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -499,7 +499,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear KV Dgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -521,7 +521,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear KV Wgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -540,7 +540,7 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_q_grads,
......
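Every cuBLAS call in the file above keeps its arguments unchanged; only the wrapping macro moves from THCublasCheck to TORCH_CUDABLAS_CHECK. A self-contained sketch of that checked-GemmEx pattern on the handle ATen provides (gemm_fp16, the sizes, and the leading dimensions are placeholders rather than values from the diff; A, B, and C must be device pointers):

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>   // TORCH_CUDABLAS_CHECK
#include <cublas_v2.h>
#include <cuda_fp16.h>

// Hypothetical helper: C (m x n) = A^T (m x k) * B (k x n) in fp16 with fp32 accumulation.
void gemm_fp16(const half* A, const half* B, half* C, int m, int n, int k) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  const float alpha = 1.0f, beta = 0.0f;

  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
                                    CUBLAS_OP_T, CUBLAS_OP_N,
                                    m, n, k,
                                    &alpha,
                                    A, CUDA_R_16F, k,   // A stored k x m, lda = k
                                    B, CUDA_R_16F, k,   // B stored k x n, ldb = k
                                    &beta,
                                    C, CUDA_R_16F, m,   // C stored m x n, ldc = m
                                    CUDA_R_32F,
                                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}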
#include <vector>
#include <math.h>
#include <iostream>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
......@@ -95,7 +95,7 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'};
char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm
HostApplyLayerNorm<at::Half,float>(
static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
......@@ -109,7 +109,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Q Fwd
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_q_dim,
......@@ -131,7 +131,7 @@ std::vector<torch::Tensor> fwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear KV Fwd
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_kv_dim,
......@@ -234,7 +234,7 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches);
// Output Linear
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
......@@ -272,7 +272,7 @@ std::vector<torch::Tensor> fwd_cuda(
total_tokens_q);
}
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
lyr_nrm_results,
......@@ -367,7 +367,7 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'};
char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Dropout Add Backward
apex_masked_scale_cuda<at::Half,float,uint32_t>(
......@@ -378,7 +378,7 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -399,7 +399,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -519,7 +519,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches);
// Input Linear Q Dgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -542,7 +542,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear Q Wgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -563,7 +563,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear KV Dgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -585,7 +585,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear KV Wgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -622,7 +622,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
);
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_q_grads,
......
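The apex_masked_scale_cuda call earlier in this file scales surviving gradients by 1.0 / (1.0 - dropout_prob), the standard inverted-dropout backward: dropped positions contribute zero and the kept ones are rescaled so the expectation matches the no-dropout case. A toy element-wise CUDA sketch of that operation (hypothetical kernel, not Apex's implementation; mask convention 1 = kept, 0 = dropped):

#include <cstdint>
#include <cuda_fp16.h>

__global__ void masked_scale_kernel(const __half* grad_in, const uint8_t* mask,
                                    __half* grad_out, int64_t n, float scale) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) {
    float g = __half2float(grad_in[i]);
    grad_out[i] = __float2half(mask[i] ? g * scale : 0.0f);
  }
}
// Launched with scale = 1.0f / (1.0f - dropout_prob).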
#include <vector>
#include <math.h>
#include <iostream>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "softmax.h"
#include "dropout.h"
......
#include <vector>
#include <math.h>
#include <iostream>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
......@@ -82,10 +82,10 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'};
char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
input_lin_results.copy_(input_biases);
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
......@@ -173,7 +173,7 @@ std::vector<torch::Tensor> fwd_cuda(
outputs.copy_(output_biases);
// Output Linear
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
......@@ -194,7 +194,7 @@ std::vector<torch::Tensor> fwd_cuda(
//CUBLAS_GEMM_ALGO1_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_lin_results,
......@@ -264,10 +264,10 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'};
char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -287,7 +287,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -403,7 +403,7 @@ std::vector<torch::Tensor> bwd_cuda(
batch_stride,
attn_batches);
// Input Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -426,7 +426,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -447,7 +447,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_grads,
......
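The input_bias_grads line above flattens the output-side gradients to (tokens x output_lin_dim) and sums over the token dimension, which is exactly the bias gradient of a linear layer: every token's gradient contributes once to each bias element. A small libtorch sketch of the same reduction (the function and tensor names are placeholders):

#include <torch/extension.h>

// d_bias[j] = sum over all tokens t of d_output[t][j]
torch::Tensor bias_grad(const torch::Tensor& output_grads, int64_t output_dim) {
  return output_grads.view({-1, output_dim}).sum(0, false);
}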
#include <vector>
#include <math.h>
#include <iostream>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
......@@ -81,10 +81,10 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'};
char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
input_lin_results.copy_(input_biases);
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
......@@ -185,7 +185,7 @@ std::vector<torch::Tensor> fwd_cuda(
outputs.copy_(output_biases);
// Output Linear
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
......@@ -206,7 +206,7 @@ std::vector<torch::Tensor> fwd_cuda(
//CUBLAS_GEMM_ALGO1_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_lin_results,
......@@ -275,10 +275,10 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'};
char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -298,7 +298,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -411,7 +411,7 @@ std::vector<torch::Tensor> bwd_cuda(
batch_stride,
attn_batches);
// Input Linear Dgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -434,7 +434,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear Wgrad
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -455,7 +455,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_grads,
......
#include <vector>
#include <math.h>
#include <iostream>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
......@@ -78,9 +78,9 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'};
char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
......@@ -182,7 +182,7 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches);
// Output Linear
THCublasCheck(cublasGemmEx(handle,
TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
......@@ -202,7 +202,7 @@ std::vector<torch::Tensor> fwd_cuda(
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));