Unverified commit ed719967, authored by yjk21, committed by GitHub

Adds small-batch kernels (#1126)

Parent: c1378e6f
......@@ -30,28 +30,6 @@
#include "fmha.h"
void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);
void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);
void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);
void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);
void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);
void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);
void set_params(Fused_multihead_attention_fprop_params &params,
// sizes
const size_t b,
......@@ -61,7 +39,6 @@ void set_params(Fused_multihead_attention_fprop_params &params,
// device pointers
void *qkv_packed_d,
void *cu_seqlens_d,
void *seqlens_d,
void *o_packed_d,
void *s_d,
float p_dropout) {
......@@ -79,7 +56,6 @@ void set_params(Fused_multihead_attention_fprop_params &params,
params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type);
params.cu_seqlens = static_cast<int *>(cu_seqlens_d);
params.seqlens = static_cast<int *>(seqlens_d);
// S = softmax(P)
params.s_ptr = s_d;
......@@ -107,13 +83,9 @@ void set_params(Fused_multihead_attention_fprop_params &params,
set_alpha(params.scale_dropout, params.rp_dropout, data_type);
}
constexpr uint32_t NUM_HEADS_DIM = 2;
constexpr uint32_t THREE_DIM = 1;
std::vector<at::Tensor>
mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens, // b+1
const at::Tensor &seqlens, // b
const float p_dropout,
const int max_seq_len,
const bool is_training,
......@@ -149,17 +121,14 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
TORCH_CHECK(qkv.dtype() == torch::kFloat16);
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
TORCH_CHECK(seqlens.dtype() == torch::kInt32);
TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())
TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous())
TORCH_CHECK(seqlens.is_contiguous())
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
......@@ -167,10 +136,9 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
TORCH_CHECK(seqlens.numel() == batch_size);
const int total = sizes[0];
const int num_heads = sizes[NUM_HEADS_DIM];
const int head_size = sizes[3];
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
auto opts = qkv.options();
......@@ -191,7 +159,6 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
seqlens.data_ptr(),
ctx.data_ptr(),
s.data_ptr(),
p_dropout);
......@@ -217,7 +184,6 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
const at::Tensor &cu_seqlens, // b+1
const at::Tensor &seqlens, // b
const float p_dropout, // probability to drop
const int max_seq_len // max sequence length to choose the kernel
) {
......@@ -247,17 +213,14 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
TORCH_CHECK(dout.dtype() == torch::kFloat16);
TORCH_CHECK(softmax.dtype() == torch::kFloat16);
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
TORCH_CHECK(seqlens.dtype() == torch::kInt32);
TORCH_CHECK(qkv.is_cuda());
TORCH_CHECK(cu_seqlens.is_cuda());
TORCH_CHECK(qkv.is_contiguous());
TORCH_CHECK(cu_seqlens.is_contiguous());
TORCH_CHECK(seqlens.is_contiguous());
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
......@@ -265,9 +228,8 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
TORCH_CHECK(seqlens.numel() == batch_size);
const int num_heads = sizes[NUM_HEADS_DIM];
const int head_size = sizes[3];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
......@@ -282,12 +244,11 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
seqlens.data_ptr(),
dout.data_ptr(), // we set o_ptr to dout
softmax.data_ptr(), // softmax gets overwritten by dP!
p_dropout);
// we're re-using these scales scales
// we're re-using these scales
Data_type acc_type = DATA_TYPE_FP32;
set_alpha(params.scale_bmm1, 1.f, acc_type);
set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
......@@ -298,8 +259,174 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
return { dqkv, softmax };
}
std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens, // b+1
const float p_dropout,
const int max_seq_len,
const bool is_training,
c10::optional<at::Generator> gen_) {
int seq_len = 512;
auto launch = &run_fmha_fp16_512_64_sm80_nl;
TORCH_CHECK(max_seq_len == seq_len);
constexpr int warps_m = 1;
constexpr int warps_n = 4; // this leads to an upper bound
const int mmas_m = seq_len / 16 / warps_m;
const int mmas_n = seq_len / 16 / warps_n;
// static_assert( mmas_m == 32 );
// static_assert( mmas_n == 4 );
const int elts_per_thread = 8 * mmas_m * mmas_n;
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())
TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous())
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
auto opts = qkv.options();
auto ctx = torch::empty({ total, num_heads, head_size }, opts);
auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
Fused_multihead_attention_fprop_params params;
set_params(params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
ctx.data_ptr(),
s.data_ptr(),
p_dropout);
// number of times random will be generated per thread, to offset philox counter in thc random
// state
int64_t counter_offset = elts_per_thread;
at::PhiloxCudaState rng_engine_inputs;
if( is_training ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
}
int num_chunks = 3;
if(batch_size == 3) {
num_chunks = 2;
}
launch(params, is_training, num_chunks, stream);
return { ctx, s };
}
std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num_heads x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
const at::Tensor &cu_seqlens, // b+1
const float p_dropout, // probability to drop
const int max_seq_len // max sequence length to choose the kernel
) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())
TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous())
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
int seq_len = 512;
auto launch = &run_fmha_dgrad_fp16_512_64_sm80_nl;
auto opts = qkv.options();
auto dqkv = torch::empty_like(qkv);
int num_chunks = 2;
if( batch_size == 1 ) {
num_chunks = 4;
}else if( batch_size == 2 ) {
num_chunks = 3;
}
auto dkv = torch::empty({total, num_chunks, 2, num_heads, head_size}, opts);
Fused_multihead_attention_fprop_params params;
set_params(params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
dout.data_ptr(), // o_ptr = dout
softmax.data_ptr(), // softmax gets overwritten by dP!
p_dropout);
params.dkv_ptr = dkv.data_ptr();
Data_type acc_type = DATA_TYPE_FP32;
set_alpha(params.scale_bmm1, 1.f, acc_type);
set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);
params.dqkv_ptr = dqkv.data_ptr();
launch(params, num_chunks, stream);
//SPLIT-K reduction of num_chunks dK, dV parts
// The equivalent of the following Pytorch code:
// using namespace torch::indexing;
// at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)});
// torch::sum_out(view_out, dkv, 1);
const int hidden_size = num_heads * head_size;
fmha_run_noloop_reduce(
dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr<int>(), hidden_size, batch_size, total, num_chunks, stream);
return { dqkv, softmax, dkv };
}
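For clarity, the split-K reduction that fmha_run_noloop_reduce performs above can be written out with plain PyTorch ops. This is a minimal sketch of the semantics only, assuming the (total, 3, num_heads, head_size) packing implied by the TOTAL_DIM/THREE_DIM/H_DIM/D_DIM constants; the sizes are arbitrary, and the CUDA helper additionally takes cu_seqlens so it can presumably restrict the reduction to valid rows:

    import torch

    total, num_chunks, num_heads, head_size = 1024, 3, 16, 64
    dqkv = torch.zeros(total, 3, num_heads, head_size, dtype=torch.float16)
    dkv = torch.randn(total, num_chunks, 2, num_heads, head_size, dtype=torch.float16)

    # Each chunk wrote a partial dK/dV into dkv; the reduction sums the chunk axis
    # into the K and V slots of dqkv (dQ in slot 0 is produced directly by the kernel).
    dqkv[:, 1:] = dkv.sum(dim=1)   # same effect as torch::sum_out(view_out, dkv, 1)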
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "Fused Multi-head Self-attention for BERT";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("fwd_nl", &mha_fwd_nl, "Forward pass (small-batch)");
m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
}
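As a usage reference, here is a minimal sketch of driving the new small-batch entry points from Python. The module name fmha_cuda, the concrete sizes, and the calling code are assumptions for illustration, not part of this commit; shapes follow the dimension constants (TOTAL_DIM, THREE_DIM, H_DIM, D_DIM) and the checks in mha_fwd_nl/mha_bwd_nl above, and the _nl path requires max_seq_len == 512:

    import torch
    import fmha_cuda  # hypothetical name for the extension built from this file

    b, s, h, d = 2, 512, 16, 64
    seqlens = torch.full((b,), s, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(b + 1, dtype=torch.int32, device="cuda")
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)          # b+1 prefix sums
    total = int(cu_seqlens[-1])

    # Packed QKV: (total, 3, num_heads, head_size), fp16, contiguous.
    qkv = torch.randn(total, 3, h, d, dtype=torch.float16, device="cuda")

    p_dropout, is_training = 0.1, True
    ctx, S = fmha_cuda.fwd_nl(qkv, cu_seqlens, p_dropout, s, is_training, None)

    # Backward: S is overwritten with dP; dqkv is the gradient of the packed QKV.
    dout = torch.randn_like(ctx)
    dqkv, dP, dkv_partial = fmha_cuda.bwd_nl(dout, qkv, S, cu_seqlens, p_dropout, s)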
......@@ -35,6 +35,12 @@
#include <fmha_utils.h>
constexpr int TOTAL_DIM = 0;
constexpr int THREE_DIM = 1;
constexpr int H_DIM = 2;
constexpr int D_DIM = 3;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
......@@ -43,6 +49,9 @@ struct Qkv_params {
// The stride between rows of the Q, K and V matrices.
size_t qkv_stride_in_bytes;
// The number of heads.
int h;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -52,6 +61,9 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
// The dQKV matrices.
void *dqkv_ptr;
// Temporary for dKV.
void *dkv_ptr;
// The O matrix (output).
void *o_ptr;
......@@ -64,7 +76,7 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
int64_t s_stride_in_bytes;
// The dimensions.
int b, h, s, d;
int b, s, d;
// The scaling factors for the kernel.
uint32_t scale_bmm1, scale_softmax, scale_bmm2;
......@@ -72,9 +84,6 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
// array of length b+1 holding starting offset of each sequence.
int *cu_seqlens;
// array of length b holding the actual sequence lengths.
int *seqlens;
// The dropout probability (probability of keeping an activation).
float p_dropout;
......@@ -90,3 +99,27 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const bool is_training, const int num_chunks, cudaStream_t stream);
void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const int num_chunks, cudaStream_t stream);
void fmha_run_noloop_reduce(void *out,
const void *in,
const int *cu_seqlens,
const int hidden_size,
const int batch_size,
const int total,
const int num_chunks,
cudaStream_t stream);
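To make the cu_seqlens convention above concrete: for a batch of b sequences with lengths s_i, cu_seqlens holds the b+1 exclusive prefix sums and total is the last entry. A tiny sketch with arbitrary lengths:

    # Sequence lengths for b = 3 sequences (arbitrary values).
    seqlens = [128, 384, 512]
    cu_seqlens = [0]
    for s in seqlens:
        cu_seqlens.append(cu_seqlens[-1] + s)
    # cu_seqlens == [0, 128, 512, 1024]; total == 1024, and rows
    # cu_seqlens[i]:cu_seqlens[i+1] of the packed QKV belong to sequence i.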
......@@ -39,7 +39,9 @@ template<
// The number of rows of Q, K or V loaded by this tile.
int ROWS,
// The number of columns.
int COLS
int COLS,
// The number of matrices.
int NUM_MATS = 3
>
struct Gmem_tile_qkv {
......@@ -74,7 +76,7 @@ struct Gmem_tile_qkv {
// The row offset in the batched GEMM. For each seq element, we store QKV in that order.
int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
// Add the block index.
row_offset += (int64_t)((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
// Assemble the final pointer.
qkv_ptr_ += row_offset + col * BYTES_PER_LDG;
......
......@@ -1217,7 +1217,8 @@ struct Smem_tile_mma_epilogue : public Base {
enum { WARPS_M = Base::WARPS_M };
enum { WARPS_N = Base::WARPS_N };
static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
using Fragment = typename Base::Fragment;
using Acc = fmha::Fragment_accumulator;
inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) {
const int read_row = tidx / THREADS_PER_ROW;
......@@ -1233,6 +1234,40 @@ struct Smem_tile_mma_epilogue : public Base {
}
}
template<int M, int N>
inline __device__ void store(const Acc (&acc)[M][N]){
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// 1st row - 4 elements per row.
float tmp00 = acc[mi][ni].elt(0);
float tmp01 = acc[mi][ni].elt(1);
float tmp02 = acc[mi][ni].elt(4);
float tmp03 = acc[mi][ni].elt(5);
// 2nd row - 4 elements per row.
float tmp10 = acc[mi][ni].elt(2);
float tmp11 = acc[mi][ni].elt(3);
float tmp12 = acc[mi][ni].elt(6);
float tmp13 = acc[mi][ni].elt(7);
uint32_t x = fmha::float2_to_half2(tmp00, tmp01);
uint32_t y = fmha::float2_to_half2(tmp02, tmp03);
uint32_t z = fmha::float2_to_half2(tmp10, tmp11);
uint32_t w = fmha::float2_to_half2(tmp12, tmp13);
size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x);
fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z);
offset ^= 4 * Base::BYTES_PER_STS;
fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y);
fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w);
}
}
}
template<int M, int N>
inline __device__ void store(const uint4 (&regs)[M][N]) {
for( int mi = 0; mi < M; mi++ ) {
......
......@@ -27,6 +27,7 @@
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
#include "fmha_dgrad_kernel_1xN_reload_nl.h"
using Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>;
......@@ -35,6 +36,13 @@ extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_at
fmha::compute_dq_dk_1xN<Kernel_traits>(params);
}
template<int CHUNKS>
__global__
void fmha_dgrad_fp16_512_64_sm80_nl_kernel(Fused_multihead_attention_fprop_params params){
fmha::compute_dv_1xN_nl<CHUNKS, Kernel_traits>(params);
fmha::compute_dq_dk_1xN_nl<CHUNKS, Kernel_traits>(params);
}
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
......@@ -58,3 +66,40 @@ void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_param
dim3 grid(params.h, params.b);
fmha_dgrad_fp16_512_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const int num_chunks, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed<Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * 512 * 2);
static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
auto kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>;
if( num_chunks == 2 ) {
kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>;
}else if( num_chunks == 3 ) {
kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<3>;
} else {
assert(false && "Unsupperted number of chunks");
}
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b, num_chunks);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
......@@ -156,8 +156,10 @@ inline __device__ void compute_dv_1xN(const Params &params) {
fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dv::WARPS_K>::apply(acc_dv);
enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };
// Load over the entire sequence length.
for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {
for( int l = 0; l < STEPS; l++ ) {
const int loop = l * Cta_tile_p::M;
if( loop >= binfo.actual_seqlen )
break;
......@@ -185,6 +187,13 @@ inline __device__ void compute_dv_1xN(const Params &params) {
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
// Trigger the load for the next Q values. We're using double buffering, so reading qt is safe
if( l < STEPS - 1) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack(acc_p);
......@@ -203,8 +212,6 @@ inline __device__ void compute_dv_1xN(const Params &params) {
}
}
float d_s[2 * M][4 * N];
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
......@@ -213,10 +220,11 @@ inline __device__ void compute_dv_1xN(const Params &params) {
for( int ni = 0; ni < N; ni++ ) {
#pragma unroll
for( int jj = 0; jj < 4; jj++ ) {
const float s_dmask = s_mat[2 * mi + ii][4 * ni + jj];
float & s_dmask = s_mat[2 * mi + ii][4 * ni + jj];
const bool drop = reinterpret_cast<const uint32_t &>(s_dmask) & 0x80000000;
d_s[2 * mi + ii][4 * ni + jj] = drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;
softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s[2 * mi + ii][4 * ni + jj] * fabsf(s_dmask);
const float d_s = drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;
s_dmask = fabsf(s_dmask);
softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * fabsf(s_dmask);
}
}
}
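The s_dmask handling above packs the dropout mask into the sign bit of the stored softmax value: a set sign bit marks a dropped element, and fabsf recovers the stored magnitude. A small, purely illustrative Python sketch of that decoding (standard library only):

    import math
    import struct

    def decode(s_dmask: float):
        # Reinterpret the float bits and test the sign bit, as the kernel does with
        # reinterpret_cast<const uint32_t&>(s_dmask) & 0x80000000.
        bits = struct.unpack("<I", struct.pack("<f", s_dmask))[0]
        drop = bool(bits & 0x80000000)
        return drop, math.fabs(s_dmask)

    print(decode(0.73))    # (False, 0.73) -> kept, stored softmax value 0.73
    print(decode(-0.73))   # (True, 0.73)  -> dropped, same magnitude recovered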
......@@ -225,6 +233,7 @@ inline __device__ void compute_dv_1xN(const Params &params) {
float p_sum[2 * M];
softmax.template reduce<fmha::Sum_>(p_sum);
const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {