transformers 阅读:Llama 模型

正文

学习一下 transformers 库中,Llama 模型的代码,学习过程中写下这篇笔记,一来加深印象,二来可以多次回顾。

笔者小白,里面错误之处请不吝指出。

层归一化 LlamaRMSNorm

transformers 中对于 LlamaRMSNorm 类的定义如下:

class LlamaRMSNorm(nn.Module):  
    def __init__(self, hidden_size, eps=1e-6):  
    """  
    LlamaRMSNorm is equivalent to T5LayerNorm  
    """  
    super().__init__()  
    self.weight = nn.Parameter(torch.ones(hidden_size))  
    self.variance_epsilon = eps  
  
def forward(self, hidden_states):  
    input_dtype = hidden_states.dtype  
    hidden_states = hidden_states.to(torch.float32)  
    variance = hidden_states.pow(2).mean(-1, keepdim=True)  
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)  
    return self.weight * hidden_states.to(input_dtype)

这里采用了 RMS(Root Mean Square) 归一化,其中 RMS 计算公式为:

RMS(x)=1n∑xi2RMS(x)=\sqrt{\frac{1}{n}\sum{x_i^2}}RMS(x)=n1​∑xi2​​

则 RMSNorm 归一化的计算公式为:

RMS(x)=xRMS(x)+ϵ∗WRMS(x)=\frac{x}{\sqrt{RMS(x)+\epsilon}} * WRMS(x)=RMS(x)+ϵ​x​∗W

加上一个小常数,确保分母不为零,保持数据稳定性。

旋转位置编码 LlamaRotaryEmbedding

  • 绝对位置编码:计算高效,效果欠佳
  • 相对位置编码:满足 NLP 领域在序列长度方向上具有平移不变性,计算效率低。
  • 旋转位置编码:采用绝对位置编码达到相位置编码的效果

transformers 中对于 LlamaRotaryEmbedding 类的定义如下,它用于实现旋转位置嵌入:

class LlamaRotaryEmbedding(nn.Module):  
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):  
        super().__init__()  

        self.dim = dim  
        self.max_position_embeddings = max_position_embeddings  
        self.base = base  
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))  
        self.register_buffer("inv_freq", inv_freq, persistent=False)  

        # Build here to make `torch.jit.trace` work.  
        self._set_cos_sin_cache(  
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()  
        )  
  
    def _set_cos_sin_cache(self, seq_len, device, dtype):  
        self.max_seq_len_cached = seq_len  
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)  

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  
        # Different from paper, but it uses a different permutation in order to obtain the same calculation  
        emb = torch.cat((freqs, freqs), dim=-1)  
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)  
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)  

    def forward(self, x, seq_len=None):  
        # x: [bs, num_attention_heads, seq_len, head_size]  
        if seq_len > self.max_seq_len_cached:  
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)  

        return (  
            self.cos_cached[:seq_len].to(dtype=x.dtype),  
            self.sin_cached[:seq_len].to(dtype=x.dtype),  
        )

其中定义的变量意义如下:

  • dim:表示模型输出维度
  • max_position_embeddings:最大编码长度,默认为2048
  • base:基数,默认为10000

inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 实现公式为:

inv_freq=1base2i/diminv\_freq=\frac{1}{base^{2i/dim}}inv_freq=base2i/dim1​

在上面代码中,t 的维度为[max_position_embeddings], inv_freq 的维度为[dim/2]。

经过 torch.einsum("i,j->ij", t, self.inv_freq) 之后维度为 [max_position_embeddings, dim/2]。

然后经过 emb = torch.cat((freqs, freqs), dim=-1) 操作,维度变为 [max_position_embeddings, dim]。

二维情况下旋转矩阵如下:

R(k)=(coskθ−sinkθsinkθcoskθ)R(k)=\begin{pmatrix} cosk\theta & -sink\theta \\ sink\theta & cosk\theta \\ \end{pmatrix}R(k)=(coskθsinkθ​−sinkθcoskθ​)

旋转位置编码计算公式如下:

R(k)x=(coskθ0coskθ0coskθ1coskθ1…coskθd/2−1coskθd/2−1)∘(x0x1x2x3…xd−2xd−1)+(sinkθ0sinkθ0sinkθ1sinkθ1…sinkθd/2−1sinkθd/2−1)∘(−x1x0−x3x2…−xd−1xd−2)R(k)x= \begin{pmatrix} cos{k\theta_0} \\ cos{k\theta_0} \\ cos{k\theta_1} \\ cos{k\theta_1} \\ … \\ cos{k\theta_{d/2-1}} \\ cos{k\theta_{d/2-1}} \end{pmatrix} \circ \begin{pmatrix} x_0 \\ x_1 \\ x_2 \\ x_3 \\ … \\ x_{d-2} \\ x_{d-1} \end{pmatrix} + \begin{pmatrix} sin{k\theta_0} \\ sin{k\theta_0} \\ sin{k\theta_1} \\ sin{k\theta_1} \\ … \\ sin{k\theta_{d/2-1}} \\ sin{k\theta_{d/2-1}} \end{pmatrix} \circ \begin{pmatrix} -x_1 \\ x_0 \\ -x_3 \\ x_2 \\ … \\ -x_{d-1} \\ x_{d-2} \end{pmatrix} R(k)x=⎝⎛​coskθ0​coskθ0​coskθ1​coskθ1​…coskθd/2−1​coskθd/2−1​​⎠⎞​∘⎝⎛​x0​x1​x2​x3​…xd−2​xd−1​​⎠⎞​+⎝⎛​sinkθ0​sinkθ0​sinkθ1​sinkθ1​…sinkθd/2−1​sinkθd/2−1​​⎠⎞​∘⎝⎛​−x1​x0​−x3​x2​…−xd−1​xd−2​​⎠⎞​

在使用 LLM 时,我们希望对上下文长度进行拓展,以便能进行多轮对话,由此有下面几种方法:

外推法:直接沿用当前公式计算计算更长位置的编码。

这种方法比较简单,但是存在相关性衰减问题,如果模型训练语料在 2k 长度左右,模型能够学习到 2k 长度左右的 token 之间相关性关系的规律。

如果直接将此规律沿用到 5k 上下文,可能导致在中间某个位置相关性衰减到零,从而无法捕捉两个 token 之间的相关性。

线性内插:线性内插会改变编码公式,将 token 之间的距离等比例缩小。

例如在 2k 上下文情况下,两个 token 之间距离为 16,那么在 32k 上下文下,这两个 token 之间距离缩短到 1。

对于短距离的衰减规律,可能造成非常大的变化,但是线性内插没有改变模型学习到的衰减规律的应用范围,不考虑微调的话,其效果一般好于直接外推方案。

transformers 中对于线性内插的实现如下:

class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):  
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""  

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):  
        self.scaling_factor = scaling_factor  
        super().__init__(dim, max_position_embeddings, base, device)  

    def _set_cos_sin_cache(self, seq_len, device, dtype):  
        self.max_seq_len_cached = seq_len  
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)  
        t = t / self.scaling_factor  

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  
        # Different from paper, but it uses a different permutation in order to obtain the same calculation  
        emb = torch.cat((freqs, freqs), dim=-1)  
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)  
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

可以看到,在 t = t / self.scaling_factor 这行代码中,除以一个缩放因子,从而达到线性缩放的效果。

动态 NTK 扩展:外推法对于长距离的 token 不能很好计算相关性,线性内插对于短距离 token 计算相关性会产生很大变化,因此可以综合两者进行扩展。

为了在短距离情况下具有外推特性,长距离情况下具有内插特性,可以设置一个与位置序号有关频率因子,动态调整。

transformers 中对于动态 NTK 扩展的实现如下:

class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):  
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""  

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):  
        self.scaling_factor = scaling_factor  
        super().__init__(dim, max_position_embeddings, base, device)  

    def _set_cos_sin_cache(self, seq_len, device, dtype):  
        self.max_seq_len_cached = seq_len  

        if seq_len > self.max_position_embeddings:  
            base = self.base * (  
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)  
            ) ** (self.dim / (self.dim - 2))  
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))  
            self.register_buffer("inv_freq", inv_freq, persistent=False)  

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)  

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  
        # Different from paper, but it uses a different permutation in order to obtain the same calculation  
        emb = torch.cat((freqs, freqs), dim=-1)  
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)  
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

可以看到,如果长度超过 max_position_embeddings,对于 base 做出了如下公式操作:

base=base∗(factor∗seq_lenmax_len−(factor−1))dimdim−2base=base*(factor*\frac{seq\_len}{max\_len}-(factor-1))^{\frac{dim}{dim-2}}base=base∗(factor∗max_lenseq_len​−(factor−1))dim−2dim​

如果 seq_len > max_position_embeddings,在 factor = 1 的情况下,base 变大。

显然 base > 1,则 inv_freq 值变小,这样将短距离的规律扩展到了长距离。

具体计算位置编码的代码如下:

def rotate_half(x):  
    """Rotates half the hidden dims of the input."""  
    x1 = x[..., : x.shape[-1] // 2]  
    x2 = x[..., x.shape[-1] // 2 :]  
    return torch.cat((-x2, x1), dim=-1)  
  
  
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb  
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):  
    cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]  
    sin = sin[position_ids].unsqueeze(1)  
    q_embed = (q * cos) + (rotate_half(q) * sin)  
    k_embed = (k * cos) + (rotate_half(k) * sin)  
    return q_embed, k_embed

rotate_half() 中,将输入 x 沿着最后一维分隔成两部分,然后进行拼接。

Llama 中对 Q 的旋转位置编码按照如下方式计算:

R(k)Q=(coskθ0coskθ1…coskθd/2−1coskθ0coskθ1…coskθd/2−1)∘(q0q1…qd/2−1qd/2qd/2+1…qd−1)+(sinkθ0sinkθ1…sinkθd/2−1sinkθ0sinkθ1…sinkθd/2−1)∘(−qd/2−qd/2+1…−qd−1q0q1…qd−1)R(k)Q= \begin{pmatrix} cos{k\theta_0} \\ cos{k\theta_1} \\ … \\ cos{k\theta_{d/2-1}} \\ cos{k\theta_0} \\ cos{k\theta_1} \\ … \\ cos{k\theta_{d/2-1}} \end{pmatrix} \circ \begin{pmatrix} q_0 \\ q_1 \\ … \\ q_{d/2-1} \\ q_{d/2} \\ q_{d/2+1} \\ … \\ q_{d-1} \end{pmatrix} + \begin{pmatrix} sin{k\theta_0} \\ sin{k\theta_1} \\ … \\ sin{k\theta_{d/2-1}} \\ sin{k\theta_0} \\ sin{k\theta_1} \\ … \\ sin{k\theta_{d/2-1}} \end{pmatrix} \circ \begin{pmatrix} -q_{d/2} \\ -q_{d/2+1} \\ … \\ -q_{d-1} \\ q_0 \\ q_1 \\ … \\ q_{d-1} \end{pmatrix} R(k)Q=⎝⎛​coskθ0​coskθ1​…coskθd/2−1​coskθ0​coskθ1​…coskθd/2−1​​⎠⎞​∘⎝⎛​q0​q1​…qd/2−1​qd/2​qd/2+1​…qd−1​​⎠⎞​+⎝⎛​sinkθ0​sinkθ1​…sinkθd/2−1​sinkθ0​sinkθ1​…sinkθd/2−1​​⎠⎞​∘⎝⎛​−qd/2​−qd/2+1​…−qd−1​q0​q1​…qd−1​​⎠⎞​

这里只对 Q 和 K 加入位置编码信息。

前馈网络 LlamaMLP

transformers 中对于前馈网络的定义如下:

class LlamaMLP(nn.Module):  
    def __init__(self, config):  
        super().__init__()  
        self.config = config  
        self.hidden_size = config.hidden_size  
        self.intermediate_size = config.intermediate_size  
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)  
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)  
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)  
        self.act_fn = ACT2FN[config.hidden_act]  
  
    def forward(self, x):  
        if self.config.pretraining_tp > 1:  
            slice = self.intermediate_size // self.config.pretraining_tp  
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)  
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)  
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)  

            gate_proj = torch.cat(  
                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1  
            )  
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)  

            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)  
            down_proj = [  
                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)  
            ]  
            down_proj = sum(down_proj)  
        else:  
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))  

        return down_proj

__init__() 函数中,定义了 hidden_sizeintermediate_size 控制模型尺寸。

同时定义了三个全连接层:

  • gate_proj:将 hidden_size 投影到 intermediate_size
  • up_proj:将 hidden_size 投影到 intermediate_size
  • down_proj:将 intermediate_size 投影到 hidden_size

这里会将输入通过 up_proj 先从 hidden_size 维度转换到 intermediate_size 维度,然后通过 down_proj 从 intermediate_size 维度转换到 hidden_size 维度。

同时里面采用 gate_proj 配合激活函数,实现了一个门控作用。

forward() 函数中会根据 config.pretraining_tp 选择不同的执行策略。这里是将三个全连接层切分为若干块,分别与输入 x 进行映射操作,得到多个子投影,然后将多个子投影拼接起来。

多头注意力 LlamaAttention

transformers 中对于多头注意力的定义如下:

class LlamaAttention(nn.Module):  
    """Multi-headed attention from 'Attention Is All You Need' paper"""  

    def __init__(self, config: LlamaConfig):  
        super().__init__()  
        self.config = config  
        self.hidden_size = config.hidden_size  
        self.num_heads = config.num_attention_heads  
        self.head_dim = self.hidden_size // self.num_heads  
        self.num_key_value_heads = config.num_key_value_heads  
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads  
        self.max_position_embeddings = config.max_position_embeddings  
        self.rope_theta = config.rope_theta  

        if (self.head_dim * self.num_heads) != self.hidden_size:  
        raise ValueError(  
            f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"  
            f" and `num_heads`: {self.num_heads})."  
        )  
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)  
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)  
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)  
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)  
        self._init_rope()  

    def _init_rope(self):  
        if self.config.rope_scaling is None:  
            self.rotary_emb = LlamaRotaryEmbedding(  
                self.head_dim,  
                max_position_embeddings=self.max_position_embeddings,  
                base=self.rope_theta,  
            )  
        else:  
            scaling_type = self.config.rope_scaling["type"]  
            scaling_factor = self.config.rope_scaling["factor"]  
        if scaling_type == "linear":  
            self.rotary_emb = LlamaLinearScalingRotaryEmbedding(  
                self.head_dim,  
                max_position_embeddings=self.max_position_embeddings,  
                scaling_factor=scaling_factor,  
                base=self.rope_theta,  
            )  
        elif scaling_type == "dynamic":  
            self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(  
                self.head_dim,  
                max_position_embeddings=self.max_position_embeddings,  
                scaling_factor=scaling_factor,  
                base=self.rope_theta,  
            )  
        else:  
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

这里主要定义了下面几种属性:

  • hidden_size:隐藏层的大小
  • num_heads:注意力头的数量
  • head_dim:每个注意力头的维度,它通过 hidden_size // num_heads 得到
  • num_key_value_heads:键值注意力头的数量
  • num_key_value_groups:键值注意力头分组数量,它通过 num_heads // num_key_value_heads 得到
  • rope_theta:即前面 RoPE 的 base

此外还定义了四个线性变换的全连接层,分别用于计算查询(Q)、键(K)、值(V)和输出(O)。

注意到键值注意力头的数量与查询注意力头的数量不同。

键值注意力头数量可以是查询注意力头数量的几分之一,这样可以减少参数规模。

多头注意力的计算代码如下:

def forward(  
    self,  
    hidden_states: torch.Tensor,  
    attention_mask: Optional[torch.Tensor] = None,  
    position_ids: Optional[torch.LongTensor] = None,  
    past_key_value: Optional[Tuple[torch.Tensor]] = None,  
    output_attentions: bool = False,  
    use_cache: bool = False,  
    padding_mask: Optional[torch.LongTensor] = None,  
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:  
    bsz, q_len, _ = hidden_states.size()  

    if self.config.pretraining_tp > 1:  
        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp  
        query_slices = self.q_proj.weight.split(  
            (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0  
        )  
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)  
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)  

        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]  
        query_states = torch.cat(query_states, dim=-1)  

        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]  
        key_states = torch.cat(key_states, dim=-1)  

        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]  
        value_states = torch.cat(value_states, dim=-1)  

    else:  
        query_states = self.q_proj(hidden_states)  
        key_states = self.k_proj(hidden_states)  
        value_states = self.v_proj(hidden_states)  

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)  
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)  
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)  

        kv_seq_len = key_states.shape[-2]  
    if past_key_value is not None:  
        kv_seq_len += past_key_value[0].shape[-2]  
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)  
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)  

    if past_key_value is not None:  
        # reuse k, v, self_attention  
        key_states = torch.cat([past_key_value[0], key_states], dim=2)  
        value_states = torch.cat([past_key_value[1], value_states], dim=2)  

        past_key_value = (key_states, value_states) if use_cache else None  

        key_states = repeat_kv(key_states, self.num_key_value_groups)  
        value_states = repeat_kv(value_states, self.num_key_value_groups)  

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)  

    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):  
        raise ValueError(  
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"  
            f" {attn_weights.size()}"  
        )  

    if attention_mask is not None:  
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):  
            raise ValueError(  
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"  
            )  
        attn_weights = attn_weights + attention_mask  

    # upcast attention to fp32  
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)  
    attn_output = torch.matmul(attn_weights, value_states)  

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):  
    raise ValueError(  
    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"  
    f" {attn_output.size()}"  
    )  

    attn_output = attn_output.transpose(1, 2).contiguous()  

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)  

    if self.config.pretraining_tp > 1:  
        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)  
        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)  
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])  
    else:  
        attn_output = self.o_proj(attn_output)  

    if not output_attentions:  
        attn_weights = None  

    return attn_output, attn_weights, past_key_value

多头注意力基本与《Attention Is All You Need》中一致,计算公式如下:

Attention(Q,K,V)=softmax(QKTdk)VAttention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})VAttention(Q,K,V)=softmax(dk​​QKT​)V

在 llama 中每进行一次注意力计算,都会对 Q 和 K 计算一次位置编码(RoPE)。

因为 K 和 V 注意力头数是 Q 的几分之一,所以在计算前首先进行 repeat 操作,对应代码如下:

key_states = repeat_kv(key_states, self.num_key_value_groups)  
value_states = repeat_kv(value_states, self.num_key_value_groups)

计算注意力的代码如下:

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

attn_weights = attn_weights + attention_mask # 可选操作

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)  
attn_output = torch.matmul(attn_weights, value_states)

最终 attn_output 经过 o_proj 的线性变换之后输出。

与前馈网络类似,如果 config 中设置 pretraining_tp,会对输入进行切片后分块操作。

解码层 LlamaDecoderLayer

transfromers 中对解码层的定义如下:

class LlamaDecoderLayer(nn.Module):  
    def __init__(self, config: LlamaConfig):  
    super().__init__()  
    self.hidden_size = config.hidden_size  
    self.self_attn = (  
        LlamaAttention(config=config)  
        if not getattr(config, "_flash_attn_2_enabled", False)  
        else LlamaFlashAttention2(config=config)  
    )  
    self.mlp = LlamaMLP(config)  
    self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)  
    self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

解码层由 AttentionLayerMLP 和两个 LayerNorm 组成。

前向计算代码如下:

def forward(  
    self,  
    hidden_states: torch.Tensor,  
    attention_mask: Optional[torch.Tensor] = None,  
    position_ids: Optional[torch.LongTensor] = None,  
    past_key_value: Optional[Tuple[torch.Tensor]] = None,  
    output_attentions: Optional[bool] = False,  
    use_cache: Optional[bool] = False,  
    padding_mask: Optional[torch.LongTensor] = None,  
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:  
    """  
    Args:  
        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`  
        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size  
        `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.  
        output_attentions (`bool`, *optional*):  
        Whether or not to return the attentions tensors of all attention layers. See `attentions` under  
        returned tensors for more detail.  
        use_cache (`bool`, *optional*):  
        If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding  
        (see `past_key_values`).  
        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states  
    """  

    residual = hidden_states  

    hidden_states = self.input_layernorm(hidden_states)  

    # Self Attention  
    hidden_states, self_attn_weights, present_key_value = self.self_attn(  
        hidden_states=hidden_states,  
        attention_mask=attention_mask,  
        position_ids=position_ids,  
        past_key_value=past_key_value,  
        output_attentions=output_attentions,  
        use_cache=use_cache,  
        padding_mask=padding_mask,  
    )  
    hidden_states = residual + hidden_states  

    # Fully Connected  
    residual = hidden_states  
    hidden_states = self.post_attention_layernorm(hidden_states)  
    hidden_states = self.mlp(hidden_states)  
    hidden_states = residual + hidden_states  

    outputs = (hidden_states,)  

    if output_attentions:  
        outputs += (self_attn_weights,)  

    if use_cache:  
        outputs += (present_key_value,)  

    return outputs

在解码器层中,输入 hidden_states 依次经历如下计算:

  1. 经过 input_layernorm 进行层归一化。
  2. 计算一次自注意力。
  3. 做一次残差连接。
  4. 经过 post_attention_layernorm 进行层归一化。
  5. 经过 mlp,并将结果与步骤3结果做一次残差连接。

模型 LlamaModel

transformers 中对模型定义如下:

class LlamaModel(LlamaPreTrainedModel):  
    """  
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]  

    Args:  
        config: LlamaConfig  
    """  

    def __init__(self, config: LlamaConfig):  
        super().__init__(config)  
        self.padding_idx = config.pad_token_id  
        self.vocab_size = config.vocab_size  

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)  
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])  
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)  

        self.gradient_checkpointing = False  
        # Initialize weights and apply final processing  
        self.post_init()

Llama 模型是由若干个解码层堆叠而成。

在前向传播时设置 gradient_checkpointing=True 可以节约显存。

但是这个参数不能和 use_cache=True 同时设置,这两个参数不兼容。

if self.gradient_checkpointing and self.training:  
    if use_cache:  
        logger.warning_once(  
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."  
        )  
    use_cache = False

在前向传播中自定义了前向传播函数:

def create_custom_forward(module):  
    def custom_forward(*inputs):  
        # None for past_key_value  
        return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)  
  
    return custom_forward

使用 torch.utils.checkpoint.checkpoint() 函数,它允许将前向传播的一部分分成小块以减小内存占用,并且可以在反向传播时实现显存优化。前提是设置 gradient_checkpointing=True

layer_outputs = torch.utils.checkpoint.checkpoint(  
    create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids  
)

代码中的 decode_layer 为前文中提到的解码器层。

经过多层解码器层后,将输出经过 RMSNorm 层,得到最终结果。

语言模型 LlamaForCausalLM

transformers 中对语言模型的定义如下:

class LlamaForCausalLM(LlamaPreTrainedModel):  
    _tied_weights_keys = ["lm_head.weight"]  

    def __init__(self, config):  
        super().__init__(config)  
        self.model = LlamaModel(config)  
        self.vocab_size = config.vocab_size  
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)  

        # Initialize weights and apply final processing  
        self.post_init()

实质是在前文提到的 LlamaModel 基础上加入一个 llm_head 来生成结果。

前向传播核心计算代码如下:

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)  
outputs = self.model(  
    input_ids=input_ids,  
    attention_mask=attention_mask,  
    position_ids=position_ids,  
    past_key_values=past_key_values,  
    inputs_embeds=inputs_embeds,  
    use_cache=use_cache,  
    output_attentions=output_attentions,  
    output_hidden_states=output_hidden_states,  
    return_dict=return_dict,  
)  
  
hidden_states = outputs[0]  
if self.config.pretraining_tp > 1:  
    lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)  
    logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]  
    logits = torch.cat(logits, dim=-1)  
else:  
    logits = self.lm_head(hidden_states)  
logits = logits.float()

如果输入 labels 会自动计算交叉熵损失。

分类模型 LlamaForSequenceClassification

分类模型也是由 LlamaModel 加上一个 score 的线性层构成。

在计算损失的时候,会根据不同的类型,选择不同的损失函数:

if self.config.problem_type == "regression":  
    loss_fct = MSELoss()  
    if self.num_labels == 1:  
        loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())  
    else:  
        loss = loss_fct(pooled_logits, labels)  
elif self.config.problem_type == "single_label_classification":  
    loss_fct = CrossEntropyLoss()  
    loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))  
elif self.config.problem_type == "multi_label_classification":  
    loss_fct = BCEWithLogitsLoss()  
    loss = loss_fct(pooled_logits, labels)

总结

以 LlamaModel 为例总结数据流向:

  • 输入的如果是 input_ids,会首先计算 inputs_embeds,然后作为 hidden_states,经过若干个 LlamaDecoderLayer、LlamaRMSNorm 组合后输出。
  • 在 LlamaDecoderLayer 中,经历如下步骤:
    1. 先记录原始输入,然后对于输入的 hidden_states 先做一次 LlamaRMSNorm。
    2. 对步骤1的结果做一次 LlamaAttention。
    3. 将步骤2的结果与原始输入做一次残差连接,并记录这次结果。
    4. 将步骤3的结果做一次 LlamaRMSNorm。
    5. 将步骤4的结果做一次 LlamaMLP。
    6. 将步骤5的结果与步骤3的结果做一次残差连接,将结果输出。
  • 在 LlamaAttention 中,经历如下步骤:
    1. 将输入的 hidden_states 做 Q、K、V 变换。
    2. 计算 Q、K 的旋转位置编码。
    3. 根据公式计算自注意力。
    4. 注意力经过线性变换后输出。
  • 在 LlamaMLP 中,经历如下步骤:
    1. 原始输入经过线性变换,得到上投影。
    2. 原始输入经过门函数和激活函数,得到门控投影。
    3. 将步骤1的上投影和步骤2的门控投影对应元素相乘。
    4. 将步骤3的结果经过线性变换,得到下投影,输出这个结果。

那么,我们该如何学习大模型?

作为一名热心肠的互联网老兵,我决定把宝贵的AI知识分享给大家。 至于能学习到多少就看你的学习毅力和能力了 。我已将重要的AI大模型资料包括AI大模型入门学习思维导图、精品AI大模型学习书籍手册、视频教程、实战学习等录播视频免费分享出来。

一、大模型全套的学习路线

学习大型人工智能模型,如GPT-3、BERT或任何其他先进的神经网络模型,需要系统的方法和持续的努力。既然要系统的学习大模型,那么学习路线是必不可少的,下面的这份路线能帮助你快速梳理知识,形成自己的体系。

L1级别:AI大模型时代的华丽登场

L2级别:AI大模型API应用开发工程

L3级别:大模型应用架构进阶实践

L4级别:大模型微调与私有化部署

一般掌握到第四个级别,市场上大多数岗位都是可以胜任,但要还不是天花板,天花板级别要求更加严格,对于算法和实战是非常苛刻的。建议普通人掌握到L4级别即可。

以上的AI大模型学习路线,不知道为什么发出来就有点糊,高清版可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

二、640套AI大模型报告合集

这套包含640份报告的合集,涵盖了AI大模型的理论研究、技术实现、行业应用等多个方面。无论您是科研人员、工程师,还是对AI大模型感兴趣的爱好者,这套报告合集都将为您提供宝贵的信息和启示。

img

三、大模型经典PDF籍

随着人工智能技术的飞速发展,AI大模型已经成为了当今科技领域的一大热点。这些大型预训练模型,如GPT-3、BERT、XLNet等,以其强大的语言理解和生成能力,正在改变我们对人工智能的认识。 那以下这些PDF籍就是非常不错的学习资源。

img

四、AI大模型商业化落地方案

img

作为普通人,入局大模型时代需要持续学习和实践,不断提高自己的技能和认知水平,同时也需要有责任感和伦理意识,为人工智能的健康发展贡献力量。

相关推荐

  1. Transformer 论文阅读笔记

    2024-06-09 07:16:03       37 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-06-09 07:16:03       5 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-06-09 07:16:03       5 阅读
  3. 在Django里面运行非项目文件

    2024-06-09 07:16:03       4 阅读
  4. Python语言-面向对象

    2024-06-09 07:16:03       7 阅读

热门阅读

  1. Python怎么配置环境变量:深度探索与实战指南

    2024-06-09 07:16:03       20 阅读
  2. Python怎么调用JAR包:揭秘跨语言交互的奥秘

    2024-06-09 07:16:03       19 阅读
  3. Qt富文本查找

    2024-06-09 07:16:03       11 阅读
  4. KerasCV和KerasNLP:视觉和语言的增强

    2024-06-09 07:16:03       21 阅读
  5. 学习分享-声明式的 HTTP 客户端OpenFeign

    2024-06-09 07:16:03       18 阅读
  6. 程序员搞副业一些会用到的工具

    2024-06-09 07:16:03       19 阅读
  7. CSS基础

    2024-06-09 07:16:03       11 阅读