Unverified commit 6af09dd9, authored by Burc Eryilmaz, committed by GitHub

Seryilmaz/more cublas lt (#1147)



* support for fused dense layer with cublasLt, fusion in both fprop and bprop

* fix typo causing syntax error

* add fused GEMM+gelu+GEMM module

* fix typo for workspace size

* update cublas check for 11600

* add tests for fused dense layer

* fix CUDA 10.x path
Co-authored-by: Sukru Eryilmaz <seryilmaz@computelab-dgx1v-32.nvidia.com>
Parent: 9d86158d
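The two modules this commit adds can be exercised roughly as follows. This is an illustrative sketch, not part of the diff: the layer sizes and weight initialization are made up here, and explicit initialization matters because (as the fused_dense.py hunk below shows) the modules allocate their parameters uninitialized. fp16 tensors on a CUDA device are assumed, matching the new test.

import torch
from apex import fused_dense

# FusedDense is a drop-in for nn.Linear whose GEMM+bias (and fused backward)
# go through cublasLt
dense = fused_dense.FusedDense(1024, 3072).half().cuda()
torch.nn.init.xavier_uniform_(dense.weight)
torch.nn.init.zeros_(dense.bias)

x = torch.randn(1536, 1024, dtype=torch.float16, device="cuda", requires_grad=True)
y = dense(x)                      # shape: (1536, 3072)
y.backward(torch.ones_like(y))    # fused GEMM + bias-grad in bprop

# FusedDenseGeluDense fuses GEMM -> bias -> GELU -> GEMM -> bias in one module
mlp = fused_dense.FusedDenseGeluDense(1024, 4096, 1024).half().cuda()
for p in mlp.parameters():
    torch.nn.init.normal_(p, std=0.02)
out = mlp(torch.randn(1536, 1024, dtype=torch.float16, device="cuda"))   # (1536, 1024)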
import torch
import unittest
import torch.nn.functional as F
from apex import fused_dense
from torch import nn
from apex import amp
class FusedDenseTest(unittest.TestCase):

    def setUp(self, seed=0):
        torch.manual_seed(seed)
        #torch.cuda.manual_seed_all(seed)
        self.seq_length = 512
        self.sequences = 3
        self.hidden_dim = 1024
        # truncate the random inputs to small integer values so the fused kernel and
        # the fp16 reference matmul operate on exactly representable operands
        self.ref_inputs = torch.randn(self.sequences*self.seq_length, self.hidden_dim,
                                      dtype=torch.float16, device=torch.device("cuda")).int().half().requires_grad_(True)
        self.tst_inputs = self.ref_inputs.clone().detach().requires_grad_(True)
        self.dense = fused_dense.FusedDense(1024, 3072)
        self.dense.half()
        self.dense.cuda()

    def test_fused_dense(self):
        y_tst = self.dense(self.tst_inputs)
        y_ref = torch.matmul(self.ref_inputs, self.dense.weight.t()) + self.dense.bias
        dy = torch.randn_like(y_tst).half()
        y_tst.backward(dy)
        dw_ref = torch.matmul(dy.t(), self.ref_inputs)
        dx_ref = torch.matmul(dy, self.dense.weight.clone())
        db_ref = dy.sum(0, False)
        self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
        self.assertTrue(torch.allclose(y_ref, y_tst, atol=1e-3, rtol=1e-3, equal_nan=True))
        self.assertTrue(torch.allclose(dw_ref, self.dense.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
        self.assertTrue(torch.allclose(dx_ref, self.tst_inputs.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
        self.assertTrue(torch.allclose(db_ref, self.dense.bias.grad, atol=1e-3, rtol=1e-3, equal_nan=True))

if __name__ == '__main__':
    unittest.main()
from .fused_dense import *
import torch
from torch import nn
import fused_dense_cuda
from .. import amp
#implements fused GEMM+bias in the forward pass using the fused_dense_cuda (cublasLt) extension
class FusedDenseFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight)
        output = fused_dense_cuda.linear_bias_forward(input, weight, bias)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_backward(input, weight, grad_output)
        return grad_input, grad_weight, grad_bias
class DenseNoBiasFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        output = torch.matmul(input, weight.t())
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(input)
        return grad_input, grad_weight
class FusedDenseGeluDenseFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight1, bias1, weight2, bias2):
        output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(input, weight1, bias1, weight2, bias2)
        # save everything the fused backward needs: the pre-GELU values and the
        # first layer's output, in addition to the input and both weights
        ctx.save_for_backward(input, weight1, weight2, gelu_in, output1)
        return output2

    @staticmethod
    def backward(ctx, grad_output):
        input, weight1, weight2, gelu_in, output1 = ctx.saved_tensors
        grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_gelu_linear_backward(input, gelu_in, output1, weight1, weight2, grad_output)
        return grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2
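# note (not in the original commit): amp.half_function registers these ops with
# apex.amp so that, when amp is active, their inputs are cast to fp16 before the
# fused cublasLt kernels run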
fused_dense_function = amp.half_function(FusedDenseFunc.apply)
dense_no_bias_function = amp.half_function(DenseNoBiasFunc.apply)
fused_dense_gelu_dense_function = amp.half_function(FusedDenseGeluDenseFunc.apply)
class FusedDense(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(FusedDense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # note: parameters are allocated uninitialized; the caller is expected to
        # initialize them or copy weights in
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            #assert False, "no-bias option not added yet"
            self.register_parameter('bias', None)

    def forward(self, input):
        if self.bias is not None:
            return fused_dense_function(input, self.weight, self.bias)
        else:
            return dense_no_bias_function(input, self.weight)
class FusedDenseGeluDense(nn.Module):
    def __init__(self, in_features, intermediate_features, out_features, bias=True):
        super(FusedDenseGeluDense, self).__init__()
        assert bias, "DenseGeluDense module without bias is currently not supported"
        self.in_features = in_features
        self.intermediate_features = intermediate_features
        self.out_features = out_features
        # as in FusedDense, the parameters are allocated uninitialized
        self.weight1 = nn.Parameter(torch.Tensor(intermediate_features, in_features))
        self.bias1 = nn.Parameter(torch.Tensor(intermediate_features))
        self.weight2 = nn.Parameter(torch.Tensor(out_features, intermediate_features))
        self.bias2 = nn.Parameter(torch.Tensor(out_features))

    def forward(self, input):
        return fused_dense_gelu_dense_function(input, self.weight1, self.bias1, self.weight2, self.bias2)
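The unit test in this commit covers only FusedDense. A sketch of an analogous check for FusedDenseGeluDense follows; it is not part of the commit, it copies weights from an unfused nn.Linear reference because the fused module's parameters start uninitialized, and it assumes the fused kernel's GELU is close enough to torch.nn.functional.gelu for a loose fp16 tolerance.

import torch
import torch.nn.functional as F
from torch import nn
from apex import fused_dense

ref_fc1 = nn.Linear(1024, 4096).half().cuda()
ref_fc2 = nn.Linear(4096, 1024).half().cuda()

mlp = fused_dense.FusedDenseGeluDense(1024, 4096, 1024).half().cuda()
with torch.no_grad():
    # copy the reference weights into the fused module before comparing
    mlp.weight1.copy_(ref_fc1.weight)
    mlp.bias1.copy_(ref_fc1.bias)
    mlp.weight2.copy_(ref_fc2.weight)
    mlp.bias2.copy_(ref_fc2.bias)

x = torch.randn(512, 1024, dtype=torch.float16, device="cuda")
y_ref = ref_fc2(F.gelu(ref_fc1(x)))
y_tst = mlp(x)
# loose tolerance: fp16 accumulation order differs and the fused GELU variant
# may not match F.gelu bit-for-bit
print(torch.allclose(y_ref, y_tst, atol=1e-2, rtol=1e-2))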
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>
#include <stdio.h>
template <typename T>
int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
template <typename T>
int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, void *lt_workspace);
template <typename T>
int linear_gelu_linear_forward_cuda(T *input, T *weight1, T *bias1, T *weight2, T *bias2, int in_features, int hidden_features, int batch_size, int out_features, T *output1, T *output2, T *gelu_in, void *lt_workspace) ;
template <typename T>
int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, void *lt_workspace);
at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor bias) {
  auto batch_size = input.size(0);
  auto in_features = input.size(1);
  int out_features = weight.size(0);
  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
  // create output/workspace tensor
  auto out = at::empty({batch_size, out_features}, input.type());
  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
  // allocate a fixed workspace for cublasLt for now; 1 << 22 elements is at least 4 MB for any dtype dispatched here
  auto lt_workspace = at::empty({1 << 22}, input.type());
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "linear_bias_forward", [&] {
    scalar_t* w_ptr = weight.data_ptr<scalar_t>();
    scalar_t* b_ptr = bias.data_ptr<scalar_t>();
    auto result = linear_bias_forward_cuda<scalar_t>(
        input,
        w_ptr,
        bias,
        in_features,
        batch_size,
        out_features,
        out,
        //out.data_ptr<scalar_t>(),
        // reserved_space.data_ptr<scalar_t>(),
        (void*) (lt_workspace.data_ptr<scalar_t>()));
  });
  return out;
}
std::vector<at::Tensor> linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) {
  auto batch_size = input.size(0);
  auto in_features = input.size(1);
  int out_features = weight.size(0);
  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
  // create output/workspace tensor
  auto d_weight = at::empty({out_features, in_features}, input.type());
#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
  // on CUDA 10.x the bias gradient is not produced by the fused path,
  // so compute it here with a reduction over the batch dimension
  auto d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
  auto d_bias = at::empty({out_features}, input.type());
#endif
  auto d_input = at::empty({batch_size, in_features}, input.type());
  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
  // allocate a fixed workspace for cublasLt for now; 1 << 22 elements is at least 4 MB for any dtype dispatched here
  auto lt_workspace = at::empty({1 << 22}, input.type());
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "linear_bias_backward", [&] {
    scalar_t* w_ptr = weight.data_ptr<scalar_t>();
    scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
    auto result = linear_bias_backward_cuda<scalar_t>(
        input.data_ptr<scalar_t>(),
        w_ptr,
        d_output.data_ptr<scalar_t>(),
        in_features,
        batch_size,
        out_features,
        d_weight.data_ptr<scalar_t>(),
        d_bias.data_ptr<scalar_t>(),
        d_input.data_ptr<scalar_t>(),
        // reserved_space.data_ptr<scalar_t>(),
        (void*) (lt_workspace.data_ptr<scalar_t>()));
  });
  return {d_input, d_weight, d_bias};
}
std::vector<at::Tensor> linear_gelu_linear_forward(at::Tensor input, at::Tensor weight1, at::Tensor bias1, at::Tensor weight2, at::Tensor bias2) {
  auto batch_size = input.size(0);
  auto in_features = input.size(1);
  int hidden_features = weight1.size(0);
  int out_features = weight2.size(0);
  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
  // create output/workspace tensors
  auto output1 = at::empty({batch_size, hidden_features}, input.type());
  auto gelu_in = at::empty({batch_size, hidden_features}, input.type());
  auto output2 = at::empty({batch_size, out_features}, input.type());
  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
  // allocate a fixed workspace for cublasLt for now; 1 << 22 elements is at least 4 MB for any dtype dispatched here
  auto lt_workspace = at::empty({1 << 22}, input.type());
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "linear_gelu_linear_forward", [&] {
    scalar_t* w1_ptr = weight1.data_ptr<scalar_t>();
    scalar_t* b1_ptr = bias1.data_ptr<scalar_t>();
    scalar_t* w2_ptr = weight2.data_ptr<scalar_t>();
    scalar_t* b2_ptr = bias2.data_ptr<scalar_t>();
    auto result = linear_gelu_linear_forward_cuda<scalar_t>(
        input.data_ptr<scalar_t>(),
        w1_ptr,
        b1_ptr,
        w2_ptr,
        b2_ptr,
        in_features,
        hidden_features,
        batch_size,
        out_features,
        output1.data_ptr<scalar_t>(),
        output2.data_ptr<scalar_t>(),
        gelu_in.data_ptr<scalar_t>(),
        // reserved_space.data_ptr<scalar_t>(),
        (void*) (lt_workspace.data_ptr<scalar_t>()));
  });
  return {output1, output2, gelu_in};
}
std::vector<at::Tensor> linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2) {
  auto batch_size = input.size(0);
  auto in_features = input.size(1);
  int hidden_features = weight1.size(0);
  int out_features = weight2.size(0);
  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
  // create output/workspace tensors
  auto d_weight1 = at::empty({hidden_features, in_features}, input.type());
  auto d_weight2 = at::empty({out_features, hidden_features}, input.type());
  auto d_bias1 = at::empty({hidden_features}, input.type());
  auto d_bias2 = at::empty({out_features}, input.type());
  auto d_input = at::empty({batch_size, in_features}, input.type());
  auto d_output1 = at::empty({batch_size, hidden_features}, input.type());
  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
  // allocate a fixed workspace for cublasLt for now; 1 << 22 elements is at least 4 MB for any dtype dispatched here
  auto lt_workspace = at::empty({1 << 22}, input.type());
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "linear_gelu_linear_backward", [&] {
    //scalar_t* w_ptr = weight.data_ptr<scalar_t>();
    //scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
    auto result = linear_gelu_linear_backward_cuda<scalar_t>(
        input.data_ptr<scalar_t>(),
        gelu_in.data_ptr<scalar_t>(),
        output1.data_ptr<scalar_t>(),
        weight1.data_ptr<scalar_t>(),
        weight2.data_ptr<scalar_t>(),
        d_output1.data_ptr<scalar_t>(),
        d_output2.data_ptr<scalar_t>(),
        in_features,
        batch_size,
        hidden_features,
        out_features,
        d_weight1.data_ptr<scalar_t>(),
        d_weight2.data_ptr<scalar_t>(),
        d_bias1.data_ptr<scalar_t>(),
        d_bias2.data_ptr<scalar_t>(),
        d_input.data_ptr<scalar_t>(),
        // reserved_space.data_ptr<scalar_t>(),
        (void*) (lt_workspace.data_ptr<scalar_t>()));
  });
  return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward");
  m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward");
  m.def("linear_gelu_linear_forward", &linear_gelu_linear_forward, "linear gelu linear forward");
  m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward");
}
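For reference, the raw extension functions bound above can also be called directly, bypassing the autograd wrappers. This is an illustrative sketch, not part of the commit; it assumes apex was built with the --cuda_ext path added to setup.py below, and that all tensors are contiguous fp16 CUDA tensors with the 2-D shapes the bindings expect.

import torch
import fused_dense_cuda

batch, d_in, d_hidden, d_out = 8, 64, 256, 64
x  = torch.randn(batch, d_in, dtype=torch.float16, device="cuda")
w1 = torch.randn(d_hidden, d_in, dtype=torch.float16, device="cuda")
b1 = torch.randn(d_hidden, dtype=torch.float16, device="cuda")
w2 = torch.randn(d_out, d_hidden, dtype=torch.float16, device="cuda")
b2 = torch.randn(d_out, dtype=torch.float16, device="cuda")

# per the binding above, the fused forward returns the first layer's output,
# the final output, and the pre-GELU values kept for the backward pass
output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(x, w1, b1, w2, b2)
print(output2.shape)   # torch.Size([8, 64])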
This diff is collapsed.
@@ -210,6 +210,12 @@ if "--cuda_ext" in sys.argv:
                               'csrc/mlp_cuda.cu'],
                      extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                          'nvcc':['-O3'] + version_dependent_macros}))
    ext_modules.append(
        CUDAExtension(name='fused_dense_cuda',
                      sources=['csrc/fused_dense.cpp',
                               'csrc/fused_dense_cuda.cu'],
                      extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                          'nvcc':['-O3'] + version_dependent_macros}))
if "--bnp" in sys.argv:
    from torch.utils.cpp_extension import CUDAExtension