KV cache
LLM inference proceeds in two phases: Prefill and Decode. In the Prefill phase, all tokens in the prompt are processed in parallel, which produces the KV cache for every prompt token as well as the first output token. The KV cache computed during Prefill is kept and reused by the Decode phase. Decode is autoregressive: each newly decoded token attends over all previously computed KV cache entries to compute the attention for the current query token. As the output grows or the context gets long, the KV cache therefore occupies a large amount of GPU memory, and reducing this footprint has long been one of the core topics in LLM inference optimization.
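To make the footprint concrete, here is a rough back-of-the-envelope estimate (a minimal sketch with illustrative 7B-class numbers, not measured values): per token, every layer stores one key and one value vector of size n_kv_heads * head_dim.

# Rough KV cache size estimate (illustrative only).
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_elem=2):
    # factor of 2: one K tensor and one V tensor per layer
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem

# e.g. a 7B-class model: 32 layers, 32 KV heads, head_dim 128, fp16
print(kv_cache_bytes(32, 32, 128, seq_len=4096) / 2**30, "GiB")  # ~2 GiB for one 4k-token sequence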
One thing that puzzled me for a while: with a KV cache, only the newly generated token is fed in at each step, so how does the model know the token's position? Digging through the LLaMA source code gives the answer:
def forward(self, tokens: torch.Tensor, start_pos: int):
    """
    Perform a forward pass through the Transformer model.

    Args:
        tokens (torch.Tensor): Input token indices.
        start_pos (int): Starting position for attention caching.

    Returns:
        torch.Tensor: Output logits after applying the Transformer model.
    """
    _bsz, seqlen = tokens.shape
    h = self.tok_embeddings(tokens)
    self.freqs_cis = self.freqs_cis.to(h.device)
    freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

    mask = None
    if seqlen > 1:
        mask = torch.full(
            (seqlen, seqlen), float("-inf"), device=tokens.device
        )
        mask = torch.triu(mask, diagonal=1)

        # When performing key-value caching, we compute the attention scores
        # only for the new sequence. Thus, the matrix of scores is of size
        # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
        # j > cache_len + i, since row i corresponds to token cache_len + i.
        mask = torch.hstack([
            torch.zeros((seqlen, start_pos), device=tokens.device),
            mask
        ]).type_as(h)

    for layer in self.layers:
        h = layer(h, start_pos, freqs_cis, mask)
    h = self.norm(h)
    output = self.output(h).float()
    return output
Note the line freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]. This single line is what makes RoPE positional encoding work together with the KV cache: even though only the new token is fed in, start_pos selects the rotary frequencies for that token's absolute position, so the relative positions against the cached keys remain correct.
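To see why this works, here is a minimal sketch of the slicing logic. The precompute shown inline mirrors LLaMA's precompute_freqs_cis, and the sizes are illustrative: during prefill the slice covers positions 0..seqlen-1, and at each decode step seqlen is 1 and start_pos picks out the rotation for the new token's absolute position.

import torch

def precompute_freqs_cis(dim, max_len, theta=10000.0):
    # complex rotations e^{i * pos * freq}, shape (max_len, dim // 2)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_len).float()
    return torch.polar(torch.ones(max_len, dim // 2), torch.outer(t, freqs))

freqs_cis = precompute_freqs_cis(dim=128, max_len=4096)

# Prefill: a 7-token prompt uses the rotations for positions 0..6
prefill = freqs_cis[0 : 0 + 7]

# Decode step: only the new token is fed, but start_pos = 7 selects the
# rotation for absolute position 7, keeping RoPE consistent with the cache.
decode = freqs_cis[7 : 7 + 1]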
KV cache code
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
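For example, with grouped-query attention where 32 query heads share 8 KV heads, each cached KV head has to be repeated 4 times so the matmul shapes line up with the queries (illustrative shapes, reusing the repeat_kv defined above):

import torch

bs, cache_plus_seq, n_kv_heads, head_dim = 1, 16, 8, 128
keys = torch.randn(bs, cache_plus_seq, n_kv_heads, head_dim)

# 32 query heads / 8 KV heads -> each KV head is repeated 4 times
expanded = repeat_kv(keys, n_rep=4)
print(expanded.shape)  # torch.Size([1, 16, 32, 128])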
class Attention(nn.Module):
    """Multi-head attention module."""

    def __init__(self, args: ModelArgs):
        """
        Initialize the Attention module.

        Args:
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_kv_heads (int): Number of key and value heads.
            n_local_heads (int): Number of local query heads.
            n_local_kv_heads (int): Number of local key and value heads.
            n_rep (int): Number of repetitions for local heads.
            head_dim (int): Dimension size of each attention head.
            wq (ColumnParallelLinear): Linear transformation for queries.
            wk (ColumnParallelLinear): Linear transformation for keys.
            wv (ColumnParallelLinear): Linear transformation for values.
            wo (RowParallelLinear): Linear transformation for output.
            cache_k (torch.Tensor): Cached keys for attention.
            cache_v (torch.Tensor): Cached values for attention.
        """
        # ColumnParallelLinear is a layer used in large-scale model-parallel training,
        # e.g. for Transformer models. In model parallelism, a large weight matrix
        # (such as a neural network's weight matrix) is split by columns and
        # distributed across devices (e.g. GPUs).
        #
        # With ColumnParallelLinear, each device stores only its slice of the columns
        # rather than the whole matrix. Each device computes its part of the forward
        # pass and the partial outputs are gathered for further processing. For the
        # backward pass, each device computes gradients for its own columns and may
        # exchange information with the other devices to update the weights.
        #
        # This significantly reduces per-device memory and allows training larger
        # models, since different parts of the model live on different devices.
        # ColumnParallelLinear and RowParallelLinear (which splits the weight matrix
        # by rows) are the two common strategies for model parallelism; see the
        # single-process sketch after this code block.
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        # cache_k / cache_v hold the cached keys and values. They are preallocated
        # for up to max_batch_size sequences of max_seq_len tokens; during inference
        # only the positions generated so far are filled in and reused.
        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
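For intuition about ColumnParallelLinear versus RowParallelLinear, here is a single-process sketch of the splitting idea using plain torch.nn.Linear weights standing in for the fairscale layers. It only mimics the math; the real layers place each shard on a different device and use gather / all-reduce to combine results.

import torch
import torch.nn as nn

dim, out, world_size = 8, 16, 2
x = torch.randn(1, dim)
full = nn.Linear(dim, out, bias=False)

# Column parallel: split the output dimension; each rank computes a slice of
# the output, and the slices are concatenated (gathered) afterwards.
w_cols = full.weight.chunk(world_size, dim=0)             # each (out/2, dim)
col_out = torch.cat([x @ w.t() for w in w_cols], dim=-1)

# Row parallel: split the input dimension; each rank sees part of the input,
# and the partial results are summed (all-reduced).
w_rows = full.weight.chunk(world_size, dim=1)              # each (out, dim/2)
x_parts = x.chunk(world_size, dim=-1)
row_out = sum(xp @ w.t() for xp, w in zip(x_parts, w_rows))

assert torch.allclose(col_out, full(x), atol=1e-6)
assert torch.allclose(row_out, full(x), atol=1e-6)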
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Forward pass of the attention module.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for caching.
            freqs_cis (torch.Tensor): Precomputed frequency tensor.
            mask (torch.Tensor, optional): Attention mask tensor.

        Returns:
            torch.Tensor: Output tensor after attention.
        """
        # During decoding x is (1, 1, dim), i.e. just the previously predicted token;
        # the self-attention input has the standard shape (bs, seqlen, hidden_dim).
        bsz, seqlen, _ = x.shape

        # Project the current tokens to q, k, v. Note that key and value are first
        # computed from the current input and only afterwards concatenated with the
        # keys and values already stored in the cache.
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # Apply RoPE to the current query and key; the keys in the cache have
        # already been rotated.
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # Cache the current tokens' k and v.
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        # Read back all cached keys and values up to start_pos + seqlen
        # (everything cached earlier plus the current input) as the final k and v.
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # Repeat k/v heads if n_kv_heads < n_heads so k/v have as many heads as q.
        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        # Compute the attention scores for the current tokens; the mask must be
        # added and the dimensions have to line up.
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
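Finally, a standalone single-head toy (not the LLaMA module itself; no RoPE, no GQA) that checks the property the cache relies on: attending with cached K/V for the prefix plus the new token's projections gives exactly the same output as recomputing attention over the full sequence.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 16
wq, wk, wv = (torch.randn(d, d) for _ in range(3))
x = torch.randn(1, 10, d)                      # 9 already-seen tokens + 1 new token

def attend(q, k, v):
    scores = (q @ k.transpose(-2, -1)) / d**0.5
    return F.softmax(scores, dim=-1) @ v

# Full recomputation: attention of the last token over all 10 positions
full = attend((x @ wq)[:, -1:], x @ wk, x @ wv)

# Cached path: K/V for the first 9 tokens are reused; only the new token's
# projections are computed at this step.
cache_k, cache_v = x[:, :9] @ wk, x[:, :9] @ wv
new = x[:, 9:]
k = torch.cat([cache_k, new @ wk], dim=1)
v = torch.cat([cache_v, new @ wv], dim=1)
cached = attend(new @ wq, k, v)

assert torch.allclose(full, cached, atol=1e-5)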