llama系列
目录
llama系列的主要差别在训练上下文长度、词表大小、训练token数、注意力机制以及对齐方法上,由于强化学习还没深入学习,因此跳过强化学习部分。
llama1
RMSnorm
与 layer Norm 相比,RMS Norm的主要区别在于去掉了减去均值的部分,计算公式为: \[ RMS(a)=\sqrt{\frac{1}{n}\Sigma_{i=1}^{n}a_{i}^{2}} \\ \]
\[ \overline{a}_{i}=\frac{a_{i}}{RMS(a)} \]
此外RMSNorm 还可以引入可学习的缩放因子gi 和偏移参数bi,从而得到
\[\overline{a}_i=\frac{a_i}{|RMS(\boldsymbol{a})}g_i+b_i\]
代码如下:
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps # eps 防止取倒数之后分母为0
def forward(self, hidden_states):
= hidden_states.dtype
input_dtype = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
variance = hidden_states * torch.rsqrt(variance + self.variance_epsilon) # rsqrt 即sqrt后取倒数
hidden_states # weight 是末尾乘的可训练参数, 即g_i
return (self.weight * hidden_states).to(input_dtype)
RoPE
详见:rope
SwiGLU
\[ \begin{aligned} \mathrm{FFN}_{\mathrm{SwiGLU}}(x,W,V,W_{2})&=\mathrm{SwiGLU}(x,W,V)W_{2}\\\mathrm{SwiGLU}(x,W,V)&=\mathrm{Swish}_{\beta}(xW)\otimes xV\\\mathrm{Swish}_{\beta}(x)&=x\sigma(\beta x) \end{aligned} \]

代码如下:
class LlamaMLP(nn.Module):
def __init__(
self,
int,
hidden_size: int,
intermediate_size: str,
hidden_act:
):super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
# config 中 hidden_act = 'silu'
# 'silu' 和 'swish' 对应的激活函数均为:SiLUActivation
# https://github.com/huggingface/transformers/blob/717dadc6f36be9f50abc66adfd918f9b0e6e3502/src/transformers/activations.py#L229
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
# 对应上述公式的 SwiGLU
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
发现多了一个nn.Linear
进行门控。
加速训练
- 使用了xformers库。
- 减少了activation checkpointing 中,重新计算 activation 的计算量。手动实现 transformer 层的反向传递函数,保存了计算成本高的 activations,例如线性层的输出。
- 通过使用 model parallelism 和 sequence parallelism 来减少显存的使用量。
- 尽可能地将 activations 的计算和GPU之间的通讯进行并行。 # llama2
llama2主要引入了GQA,分组注意力机制,详见Attention