Unverified commit 0c2c6eea authored by Nan Zheng, committed by GitHub

Added more fusion and vectorized kernels for the transducer (#1125)

* Added support for fused ReLU and dropout into transducer joint

* Reorganized code selection path in transducer joint fwd
* Added support for fused ReLU+dropout into transducer joint

* Vectorize transducer loss backward with fused softmax (#3)

* Nanz/transducer loss (#4)

* Vectorize transducer loss backward with fused softmax

* Added a predicate to avoid potential IMA

* Nanz/transducer loss (#5)

* Vectorize transducer loss backward with fused softmax

* Added a predicate to avoid potential IMA

* Added more predicates to avoid IMAs

* Updated documentation for the newly added features.

* Fixed an error in transducer.py
Parent ed719967
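For context, the options this commit adds to TransducerJoint (relu, dropout, dropout_prob, probe_mask) are used as in the minimal sketch below, pieced together from the module docstring and the unit tests in this diff. The import path, tensor shapes, and dtypes are illustrative assumptions, not values taken from the commit.

import torch
from apex.contrib.transducer import TransducerJoint  # import path assumed

# Illustrative shapes: batch B, encoder frames T, label length + 1 = U1, hidden size H.
B, T, U1, H = 4, 50, 20, 512
f = torch.randn(B, T, H, device="cuda", requires_grad=True)   # encoder (transcription) output
g = torch.randn(B, U1, H, device="cuda", requires_grad=True)  # predictor output
f_len = torch.full((B,), T, dtype=torch.int, device="cuda")
g_len = torch.full((B,), U1, dtype=torch.int, device="cuda")

# New in this commit: ReLU and dropout fused into the joint op (requires opt=1, the default).
joint = TransducerJoint(relu=True, dropout=True, dropout_prob=0.5, probe_mask=True)
h = joint(f=f, g=g, f_len=f_len, g_len=g_len)   # unpacked output of shape (B, T, U1, H)
h.sum().backward()
# With probe_mask=True, the tests read back the fused ReLU/dropout mask via joint.mask_probe[0].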
......@@ -5,7 +5,7 @@
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
torch::Tensor transducer_joint_cuda_forward(
std::vector<torch::Tensor> transducer_joint_cuda_forward(
torch::Tensor f,
torch::Tensor g,
torch::Tensor fLen,
......@@ -14,19 +14,23 @@ torch::Tensor transducer_joint_cuda_forward(
int64_t packedBatch,
int opt,
bool packOutput,
bool relu,
bool dropout,
float dropoutProb,
int tileSize);
std::vector<torch::Tensor> transducer_joint_cuda_backward(
torch::Tensor grad,
std::vector<torch::Tensor> in,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int maxFLen,
int maxGLen,
bool packOutput);
bool packOutput,
float scale);
torch::Tensor transducer_joint_forward(
std::vector<torch::Tensor> transducer_joint_forward(
torch::Tensor f,
torch::Tensor g,
torch::Tensor fLen,
......@@ -35,6 +39,9 @@ torch::Tensor transducer_joint_forward(
int64_t packedBatch,
int opt,
bool packOutput,
bool relu,
bool dropout,
float dropoutProb,
int tileSize) {
CHECK_INPUT(f);
CHECK_INPUT(g);
......@@ -51,30 +58,37 @@ torch::Tensor transducer_joint_forward(
packedBatch,
opt,
packOutput,
relu,
dropout,
dropoutProb,
tileSize);
}
std::vector<torch::Tensor> transducer_joint_backward(
torch::Tensor grad,
std::vector<torch::Tensor> in,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int maxFLen,
int maxGLen,
bool packOutput) {
CHECK_INPUT(grad);
bool packOutput,
float scale) {
for (auto t : in){
CHECK_INPUT(t);
}
CHECK_INPUT(fLen);
CHECK_INPUT(gLen);
if (packOutput)
CHECK_INPUT(batchOffset);
return transducer_joint_cuda_backward(
grad,
in,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
packOutput);
packOutput,
scale);
}
......
......@@ -408,7 +408,7 @@ __global__ void transducer_loss_fused_backward(
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
__shared__ acc_t commonFactor, myBetaTU;
__shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;
auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize;
if (t < myFLen and u < myGLen){
......@@ -421,6 +421,9 @@ __global__ void transducer_loss_fused_backward(
if (tid == 0){
commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];
myBetaTU = myBeta[t*maxGLen + u];
myBetaTUp1 = myBeta[t*maxGLen + u + 1];
myBetaTp1U = myBeta[(t+1)*maxGLen + u];
myLabelShared = myLabel[u];
}
__syncthreads();
......@@ -429,14 +432,14 @@ __global__ void transducer_loss_fused_backward(
// Do the update
acc_t grad = commonFactor + myX[h]; // loss = -ln(Pr(y*|x))
acc_t myGrad = std::exp(grad + myBetaTU);
if (u != myGLen - 1 and h == myLabel[u]){
myGrad -= std::exp(grad + myBeta[t*maxGLen + u + 1]);
if (u != myGLen - 1 and h == myLabelShared){
myGrad -= std::exp(grad + myBetaTUp1);
}
else if (h == blankIdx){
if (t == myFLen - 1 and u == myGLen - 1)
myGrad -= std::exp(grad);
else if (t != myFLen - 1)
myGrad -= std::exp(grad + myBeta[(t+1)*maxGLen + u]);
myGrad -= std::exp(grad + myBetaTp1U);
}
myXGrad[h] = myGrad;
}
......@@ -450,6 +453,104 @@ __global__ void transducer_loss_fused_backward(
}
// Vectorized version of the fused transducer loss backward operation.
// Details of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// The backward op of the preceding softmax layer is fused into this kernel.
// Each thread block works on [batch, t, u, :] of the data. Each thread works on a vector of V
// elements of h at a time.
// To support packed input, the starting offset of each batch needs to be specified with
// batchOffset.
template <typename scalar_t, typename acc_t, typename vec_t, int V>
__global__ void transducer_loss_fused_vec_backward(
const scalar_t* x,
const scalar_t* lossGrad,
const int* audLen,
const int* txtLen,
const int* label,
const acc_t* alpha,
const acc_t* beta,
const int64_t* batchOffset,
int64_t dictSize,
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
scalar_t* xGrad) {
const int tid = threadIdx.x;
const int u = blockIdx.x;
const int t = blockIdx.y;
const int batch = blockIdx.z;
const int64_t myFLen = audLen[batch];
const int64_t myGLen = txtLen[batch] + 1;
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
__shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;
auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize;
auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize;
auto myAlpha = alpha + batch*maxFLen*maxGLen;
auto myBeta = beta + batch*maxFLen*maxGLen;
auto myLabel = label + batch*(maxGLen-1);
// Variables for vectorization
scalar_t myXBuffer[V], myXGradBuffer[V];
auto myXVec = reinterpret_cast<vec_t const *>(myX);
auto myXGradVec = reinterpret_cast<vec_t*>(myXGrad);
auto myXBufferVec = reinterpret_cast<vec_t*>(myXBuffer);
auto myXGradBufferVec = reinterpret_cast<vec_t*>(myXGradBuffer);
if (t < myFLen and u < myGLen){
// load and store shared variables in SMEM
if (tid == 0){
commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];
myBetaTU = myBeta[t*maxGLen + u];
if (t != myFLen - 1)
myBetaTp1U = myBeta[(t+1)*maxGLen + u];
if (u != myGLen - 1){
myBetaTUp1 = myBeta[t*maxGLen + u + 1];
myLabelShared = myLabel[u];
}
}
__syncthreads();
#pragma unroll
for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){
// Load myX in a vector form
*myXBufferVec = myXVec[h0/V];
// Do the update for a vector of input
#pragma unroll
for (int i = 0; i < V; ++i){
auto h = h0 + i;
acc_t grad = commonFactor + myXBuffer[i]; // loss = -ln(Pr(y*|x))
acc_t myGrad = std::exp(grad + myBetaTU);
if (u != myGLen - 1 and h == myLabelShared){
myGrad -= std::exp(grad + myBetaTUp1);
}
else if (h == blankIdx){
if (t == myFLen - 1 and u == myGLen - 1)
myGrad -= std::exp(grad);
else if (t != myFLen - 1)
myGrad -= std::exp(grad + myBetaTp1U);
}
myXGradBuffer[i] = myGrad;
}
// Store myXGrad in a vector form
myXGradVec[h0/V] = *myXGradBufferVec;
}
}
else if (!packedInput){
// In non-packed mode, make sure the gradients for the don't-care regions are zero.
for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){
myXGradVec[h0/V] = 0;
}
}
}
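The comment at the top of this kernel describes how packed inputs are addressed through batchOffset. The following is a small Python sketch of the same flat-index computation that the kernel performs with pointer arithmetic; it is illustrative only and simply mirrors the variables above (txtLen, batchOffset, maxFLen, maxGLen).

def flat_offset(batch, t, u, txt_len, batch_offset,
                max_f_len, max_g_len, dict_size, packed_input):
    # Flat element offset of x[batch, t, u, 0], mirroring
    # myXGrad = xGrad + (myBatchOffset + t*myStrideT + u)*dictSize.
    my_g_len = txt_len[batch] + 1
    if packed_input:
        # batchOffset[b] is the cumulative number of (t, u) cells in batches 0..b.
        my_batch_offset = 0 if batch == 0 else batch_offset[batch - 1]
        stride_t = my_g_len            # packed rows are exactly my_g_len wide
    else:
        my_batch_offset = batch * max_f_len * max_g_len
        stride_t = max_g_len           # padded rows are max_g_len wide
    return (my_batch_offset + t * stride_t + u) * dict_size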
std::vector<torch::Tensor> transducer_loss_cuda_forward(
torch::Tensor x,
......@@ -586,23 +687,51 @@ torch::Tensor transducer_loss_cuda_backward(
const dim3 blocks(maxGLen, maxFLen, batchSize);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] {
using vec_t = uint64_t;
using acc_t = at::acc_type<scalar_t, true>;
transducer_loss_fused_backward<<<blocks, threads, 0, stream>>>(
x.data_ptr<scalar_t>(),
lossGrad.data_ptr<scalar_t>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
label.data_ptr<int>(),
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
xGrad.data_ptr<scalar_t>());
constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t);
constexpr int vecAlignment = std::alignment_of<vec_t>::value;
// Check whether all input and output tensors meet the alignment requirement
bool memAlign = reinterpret_cast<uint64_t>(x.data_ptr<scalar_t>()) % vecAlignment == 0
and reinterpret_cast<uint64_t>(xGrad.data_ptr<scalar_t>())
% vecAlignment == 0;
if (vectFactor > 1 and dictSize%vectFactor == 0 and memAlign){
transducer_loss_fused_vec_backward<scalar_t, acc_t, vec_t, vectFactor>
<<<blocks, threads, 0, stream>>>(
x.data_ptr<scalar_t>(),
lossGrad.data_ptr<scalar_t>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
label.data_ptr<int>(),
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
xGrad.data_ptr<scalar_t>());
}
else{
transducer_loss_fused_backward<<<blocks, threads, 0, stream>>>(
x.data_ptr<scalar_t>(),
lossGrad.data_ptr<scalar_t>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
label.data_ptr<int>(),
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
xGrad.data_ptr<scalar_t>());
}
}));
}
else{
......
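In the dispatch above, the vectorized kernel is selected only when the dictionary dimension is divisible by the vector factor and both the input and gradient tensors satisfy the alignment requirement; otherwise the scalar fused kernel is used. A Python sketch of that eligibility test (illustrative only; the authoritative check is the C++ code above):

def can_use_vectorized_backward(x_addr, x_grad_addr, dict_size, elem_bytes):
    # vec_t is uint64_t above, so loads/stores are 8 bytes wide and 8-byte aligned.
    vec_bytes = 8
    vect_factor = vec_bytes // elem_bytes          # 4 for fp16, 2 for fp32
    mem_align = x_addr % vec_bytes == 0 and x_grad_addr % vec_bytes == 0
    return vect_factor > 1 and dict_size % vect_factor == 0 and mem_align

# Example: fp16 with dictSize=16 is eligible (16 % 4 == 0) when both pointers are
# 8-byte aligned; dictSize=14 falls back to the scalar fused kernel.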
......@@ -28,6 +28,7 @@ class TransducerJointTest(unittest.TestCase):
self.g_len = torch.randint(U_min, U_max+1, (self.B,), dtype=torch.int, device=device)
self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max
self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max
self.dropout_prob = 0.5
# Make sure gradients from out-of-bound locations are zero. This should be guaranteed by
# the loss function
......@@ -49,30 +50,38 @@ class TransducerJointTest(unittest.TestCase):
batch_offset = torch.cumsum(f_len * g_len, dim=0)
return x_packed
def _unpack(self, x, f_len, g_len):
batch_offset = torch.cumsum(f_len * g_len, dim=0)
x_unpacked = torch.zeros_like(self.h_grad, dtype=torch.uint8)
B = self.h_grad.size(0)
H = self.h_grad.size(-1)
for b in range(B):
my_batch_offset = 0 if b == 0 else batch_offset[b-1]
my_f_len = f_len[b]
my_g_len = g_len[b]
for t in range(my_f_len):
x_unpacked[b, t, :my_g_len] = x[my_batch_offset + t*my_g_len :
my_batch_offset + t*my_g_len + my_g_len]
return x_unpacked
def run_transducer_joint(self, for_vector_kernel, pack_output):
def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout):
self.gen_input(for_vector_kernel=for_vector_kernel)
# Generate reference
f_ref = self.f_tst.data.clone()
g_ref = self.g_tst.data.clone()
f_ref.requires_grad = True
g_ref.requires_grad = True
h_ref, f_grad_ref, g_grad_ref \
= transducer_ref.transducer_joint_reference(f=f_ref,
g=g_ref,
h_grad=self.h_grad,
f_len=self.f_len,
g_len=self.g_len,
pack_output=pack_output)
my_joint= TransducerJoint(pack_output=pack_output)
my_joint = TransducerJoint(pack_output=pack_output, relu=relu, dropout=dropout,
dropout_prob=self.dropout_prob, probe_mask=True)
if not pack_output:
h_tst = my_joint( f=self.f_tst,
g=self.g_tst,
f_len=self.f_len,
g_len=self.g_len)
h_tst.backward(self.h_grad)
if dropout:
mask = my_joint.mask_probe[0]
else:
batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0)
h_tst = my_joint( f=self.f_tst,
......@@ -82,6 +91,22 @@ class TransducerJointTest(unittest.TestCase):
batch_offset=batch_offset,
packed_batch=batch_offset[-1])
h_tst.backward(self.h_grad_packed)
if dropout:
mask_packed = my_joint.mask_probe[0]
mask = self._unpack(mask_packed, self.f_len, self.g_len)
# reference
h_ref, f_grad_ref, g_grad_ref \
= transducer_ref.transducer_joint_reference(f=f_ref,
g=g_ref,
h_grad=self.h_grad,
f_len=self.f_len,
g_len=self.g_len,
pack_output=pack_output,
relu=relu,
dropout=dropout,
dropout_prob=self.dropout_prob,
mask=mask if dropout else None)
f_grad_tst = self.f_tst.grad
g_grad_tst = self.g_tst.grad
......@@ -91,16 +116,41 @@ class TransducerJointTest(unittest.TestCase):
self.assertTrue(torch.allclose(g_grad_ref, g_grad_tst, atol=1e-4, rtol=1e-4))
def test_transducer_joint(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=False)
self.run_transducer_joint(for_vector_kernel=False, pack_output=False, relu=False, dropout=False)
def test_transducer_joint_vec(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False)
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=False, dropout=False)
def test_transducer_joint_pack(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True)
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=False, dropout=False)
def test_transducer_joint_vec_pack(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True)
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False)
def test_transducer_joint_relu(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)
def test_transducer_joint_vec_relu(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=False)
def test_transducer_joint_pack_relu(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=False)
def test_transducer_joint_vec_pack_relu(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)
def test_transducer_joint_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)
def test_transducer_joint_vec_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=True)
def test_transducer_joint_pack_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=True)
def test_transducer_joint_vec_pack_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)
if __name__ == '__main__':
......
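When pack_output=True, the caller must supply batch_offset and packed_batch, as the packed branch of the test above does. A hedged sketch of that call pattern (import path and shapes are assumptions):

import torch
from apex.contrib.transducer import TransducerJoint  # import path assumed

B, T, U1, H = 4, 50, 20, 512                          # illustrative shapes
f = torch.randn(B, T, H, device="cuda", requires_grad=True)
g = torch.randn(B, U1, H, device="cuda", requires_grad=True)
f_len = torch.randint(T // 2, T + 1, (B,), dtype=torch.int, device="cuda")
g_len = torch.randint(U1 // 2, U1 + 1, (B,), dtype=torch.int, device="cuda")

joint = TransducerJoint(pack_output=True)
batch_offset = torch.cumsum(f_len * g_len, dim=0)     # cumulative (t, u) cell count per batch
h_packed = joint(f=f, g=g, f_len=f_len, g_len=g_len,
                 batch_offset=batch_offset,
                 packed_batch=batch_offset[-1])
# Only the valid (t, u) cells are kept, packed along the leading dimension.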
......@@ -8,13 +8,13 @@ class TransducerLossTest(unittest.TestCase):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def gen_input(self, scalar_t):
def gen_input(self, scalar_t, for_vector_kernel):
self.B = 5
T_min = 23
T_max = 51
U_min = 12
U_max = 25
V = 16
V = 16 if for_vector_kernel else 14
self.blank_idx = V - 1
device = "cuda"
......@@ -61,8 +61,8 @@ class TransducerLossTest(unittest.TestCase):
x_unpacked[b, t, u] = x[my_batch_offset + t*my_g_len + u]
return x_unpacked
def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input):
self.gen_input(scalar_t)
def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input, for_vector_kernel):
self.gen_input(scalar_t, for_vector_kernel)
my_loss = TransducerLoss( fuse_softmax_backward=fuse_softmax_backward,
packed_input=packed_input)
if not packed_input:
......@@ -90,28 +90,40 @@ class TransducerLossTest(unittest.TestCase):
def test_transducer_loss_fp32(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float32,
fuse_softmax_backward=False,
packed_input=False)
packed_input=False,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-5, rtol=1e-5))
def test_transducer_loss_fp16(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=False,
packed_input=False)
packed_input=False,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
def test_transducer_loss_fp16_backward_fusion(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=True,
packed_input=False)
packed_input=False,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
def test_transducer_loss_fp16_backward_fusion_packed(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=True,
packed_input=True)
packed_input=True,
for_vector_kernel=False)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
def test_transducer_loss_fp16_backward_fusion_packed_vec(self):
loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16,
fuse_softmax_backward=True,
packed_input=True,
for_vector_kernel=True)
self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3))
......
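The new for_vector_kernel=True case exercises the vectorized fused-softmax backward added by this commit; it requires the vocabulary size to be divisible by the vector factor, which is why the test switches V from 14 to 16. A hedged usage sketch of TransducerLoss with the fused backward; the import path, forward-call argument order, and the reduction are assumptions based on the tests and the reference implementation, not taken verbatim from the commit:

import torch
from apex.contrib.transducer import TransducerLoss  # import path assumed

B, T, U, V = 4, 50, 20, 16          # V divisible by the fp16 vector factor (8 bytes / 2 bytes = 4)
blank_idx = V - 1
x = torch.randn(B, T, U + 1, V, dtype=torch.float16, device="cuda", requires_grad=True)
label = torch.randint(0, blank_idx, (B, U), dtype=torch.int, device="cuda")
f_len = torch.full((B,), T, dtype=torch.int, device="cuda")
y_len = torch.full((B,), U, dtype=torch.int, device="cuda")

loss_fn = TransducerLoss(fuse_softmax_backward=True, packed_input=False)
loss = loss_fn(x, label, f_len, y_len, blank_idx)   # argument order assumed
loss.sum().backward()                               # reduction assumed to be left to the caller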
......@@ -76,12 +76,21 @@ def transducer_loss_reference(x, label, f_len, y_len, blank_idx, loss_grad):
return alpha, beta, x.grad, loss
def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output):
def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output, relu, dropout,
dropout_prob=0, mask=None):
if dropout and mask is None:
raise NotImplementedError("mask needs to be supplied to test dropout.")
B, T, H = f.size()
U = g.size(1)
f_expand = f.unsqueeze(dim=2)
g_expand = g.unsqueeze(dim=1)
h = f_expand + g_expand
if relu:
h = torch.nn.functional.relu(h)
if dropout:
h *= mask
scale = 1/(1-dropout_prob)
h *= scale
h.backward(h_grad)
if pack_output == False:
......@@ -90,6 +99,7 @@ def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output):
for b in range(B):
h[b, f_len[b]:] = -1
h[b, :, g_len[b]:] = -1
return h, f.grad, g.grad
# packing
......
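The reference applies the dropout mask and then rescales by 1/(1 - dropout_prob) (inverted dropout), so the expected activation matches the no-dropout case. A tiny illustrative check of that identity, not part of the commit:

import torch

p = 0.5
h = torch.randn(1000, 1000)
mask = (torch.rand_like(h) > p).to(h.dtype)   # keep each element with probability 1 - p
h_dropped = h * mask / (1 - p)                # mask then scale, as in the reference
# E[mask / (1 - p)] = 1, so the mean activation is preserved up to sampling noise.
print(h.mean().item(), h_dropped.mean().item())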
......@@ -10,18 +10,34 @@ class TransducerJoint(torch.nn.Module):
Arguments:
pack_output (bool, optional): whether to pack the output in a compact form with don't-care
data being removed. (default: False)
relu (bool, optional): apply ReLU to the output of the joint operation. Requires opt=1
(default: False)
dropout (bool, optional): apply dropout to the output of the joint operation. Requires opt=1
(default: False)
opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a tiled algorithm.
(default: 1)
fwd_tile_size (int, optional): tile size used in forward operation. This argument will be
ignored if opt != 1. (default: 4)
dropout_prob (float, optional): dropout probability. (default: 0.0)
probe_mask (bool, optional): a flag used to probe the mask generated by the ReLU and/or
dropout operation. When this argument is set to True, the mask can be accessed through
self.mask_probe. (default: False)
"""
def __init__(self, pack_output=False, opt=1, fwd_tile_size=4):
def __init__(self, pack_output=False, relu=False, dropout=False, opt=1, fwd_tile_size=4,
dropout_prob=0, probe_mask=False):
super(TransducerJoint, self).__init__()
self.pack_output = pack_output
self.relu = relu
self.dropout = dropout
self.dropout_prob = dropout_prob
self.opt = opt
self.fwd_tile_size = fwd_tile_size
self.dummy_batch_offset = torch.empty(0)
masked = self.relu or self.dropout
self.mask_probe = [] if masked and probe_mask else None
if masked and opt != 1:
raise NotImplementedError("ReLU and dropout fusion is only supported with opt=1")
def forward(self, f, g, f_len, g_len, batch_offset=None, packed_batch=0):
......@@ -43,8 +59,10 @@ class TransducerJoint(torch.nn.Module):
my_batch_offset = batch_offset if self.pack_output else self.dummy_batch_offset
if self.pack_output and (batch_offset is None or packed_batch == 0):
raise Exception("Please specify batch_offset and packed_batch when packing is enabled")
return TransducerJointFunc.apply(f, g, f_len, g_len, self.pack_output, my_batch_offset,
packed_batch, self.opt, self.fwd_tile_size)
dropout = self.dropout and self.training # only apply dropout during training
return TransducerJointFunc.apply(f, g, f_len, g_len, self.pack_output, self.relu, dropout,
my_batch_offset, packed_batch, self.opt,