Login
首页 > 精选好文 > AI大模型

Transformer解码器深度解析:掩码自注意力与编码器-解码器注意力

聚客AI 2025-06-13 12:25:56 人看过

本文深入剖析Transformer解码器的核心机制,通过数学原理、可视化图解和完整代码实现,详细讲解掩码自注意力和编码器-解码器注意力的工作原理及其在序列生成任务中的应用。


一、Transformer解码器架构概览

1.1 解码器整体结构

11.png


1.2 解码器与编码器关键差异

image.png


二、掩码自注意力(Masked Self-Attention)

2.1 掩码自注意力原理

掩码自注意力确保解码器在生成位置t时只能访问位置0t的信息,防止未来信息泄露。

数学表示:

image.png

其中$M$为掩码矩阵:

image.png


2.2 掩码矩阵生成与可视化

import torch
import numpy as np
import matplotlib.pyplot as plt
def generate_mask(seq_len):
    """生成下三角掩码矩阵"""
    mask = torch.tril(torch.ones(seq_len, seq_len))
    mask = mask.masked_fill(mask == 0, float('-inf'))
    return mask.masked_fill(mask == 1, float(0.0))
# 可视化掩码矩阵
seq_len = 8
mask = generate_mask(seq_len)
plt.figure(figsize=(8, 8))
plt.imshow(mask.numpy(), cmap='viridis')
plt.title('掩码自注意力矩阵')
plt.xlabel('Key位置')
plt.ylabel('Query位置')
plt.xticks(range(seq_len))
plt.yticks(range(seq_len))
# 添加矩阵值
for i in range(seq_len):
    for j in range(seq_len):
        plt.text(j, i, f"{mask[i,j].item():.0f}", 
                 ha="center", va="center", color="w", fontsize=12)
plt.colorbar()
plt.show()

image.png

2.3 掩码自注意力实现

class MaskedSelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        
        assert self.head_dim * heads == embed_size, "嵌入维度必须是头数的整数倍"
        
        self.values = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.queries = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)
    
    def forward(self, x, mask):
        # x: (batch_size, seq_len, embed_size)
        batch_size, seq_length, _ = x.size()
        
        # 线性变换
        V = self.values(x)  # (batch_size, seq_len, embed_size)
        K = self.keys(x)    # (batch_size, seq_len, embed_size)
        Q = self.queries(x) # (batch_size, seq_len, embed_size)
        
        # 分割多头
        V = V.view(batch_size, seq_length, self.heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, seq_length, self.heads, self.head_dim).permute(0, 2, 1, 3)
        Q = Q.view(batch_size, seq_length, self.heads, self.head_dim).permute(0, 2, 1, 3)
        
        # 计算注意力能量 (QK^T)
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / np.sqrt(self.head_dim)
        
        # 应用掩码
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float('-1e20'))
        
        # 计算注意力权重
        attention = F.softmax(energy, dim=-1)
        
        # 加权求和
        out = torch.matmul(attention, V)
        out = out.permute(0, 2, 1, 3).contiguous()
        out = out.view(batch_size, seq_length, self.embed_size)
        out = self.fc_out(out)
        
        return out, attention
# 测试掩码自注意力
embed_size = 128
heads = 8
seq_len = 6
batch_size = 2
# 创建输入和掩码
x = torch.randn(batch_size, seq_len, embed_size)
mask = generate_mask(seq_len).unsqueeze(0).unsqueeze(0)  # 增加批次和头维度
# 初始化模型
masked_attn = MaskedSelfAttention(embed_size, heads)
output, attn_weights = masked_attn(x, mask)
print("输入形状:", x.shape)
print("输出形状:", output.shape)
print("注意力权重形状:", attn_weights.shape)
# 可视化注意力权重 (第一个批次,第一个头)
plt.figure(figsize=(8, 8))
plt.imshow(attn_weights[0, 0].detach().numpy(), cmap='viridis')
plt.title('掩码自注意力权重 (头1)')
plt.xlabel('Key位置')
plt.ylabel('Query位置')
plt.xticks(range(seq_len))
plt.yticks(range(seq_len))
plt.colorbar()
plt.show()

掩码自注意力的关键特性:

因果性保证:确保当前位置只能关注之前位置

自回归支持:实现序列的逐步生成

并行计算:训练时所有位置同时计算(带掩码)

信息流控制:防止未来信息泄露


三、编码器-解码器注意力(Encoder-Decoder Attention)

3.1 交叉注意力原理

编码器-解码器注意力连接源序列和目标序列,使解码器可以动态关注源序列的相关部分。

数学表示:

image.png


3.2 交叉注意力实现

class EncoderDecoderAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        
        self.values = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.queries = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)
    
    def forward(self, x, encoder_output, src_mask=None):
        # x: 解码器输入 (batch_size, tgt_seq_len, embed_size)
        # encoder_output: 编码器输出 (batch_size, src_seq_len, embed_size)
        batch_size, tgt_len, _ = x.size()
        src_len = encoder_output.size(1)
        
        # 线性变换
        V = self.values(encoder_output)
        K = self.keys(encoder_output)
        Q = self.queries(x)
        
        # 分割多头
        V = V.view(batch_size, src_len, self.heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, src_len, self.heads, self.head_dim).permute(0, 2, 1, 3)
        Q = Q.view(batch_size, tgt_len, self.heads, self.head_dim).permute(0, 2, 1, 3)
        
        # 计算能量
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / np.sqrt(self.head_dim)
        
        # 应用源序列掩码(如填充掩码)
        if src_mask is not None:
            energy = energy.masked_fill(src_mask.unsqueeze(1).unsqueeze(2) == 0, float('-1e20'))
        
        # 计算注意力权重
        attention = F.softmax(energy, dim=-1)
        
        # 加权求和
        out = torch.matmul(attention, V)
        out = out.permute(0, 2, 1, 3).contiguous()
        out = out.view(batch_size, tgt_len, self.embed_size)
        out = self.fc_out(out)
        
        return out, attention
# 测试交叉注意力
src_seq_len = 10
tgt_seq_len = 6
# 创建模拟数据
encoder_output = torch.randn(batch_size, src_seq_len, embed_size)
decoder_input = torch.randn(batch_size, tgt_seq_len, embed_size)
# 源序列掩码(模拟填充位置)
src_mask = torch.ones(batch_size, src_seq_len)
src_mask[:, 8:] = 0  # 最后两个位置为填充
# 初始化模型
cross_attn = EncoderDecoderAttention(embed_size, heads)
output, attn_weights = cross_attn(decoder_input, encoder_output, src_mask=src_mask)
print("编码器输出形状:", encoder_output.shape)
print("解码器输入形状:", decoder_input.shape)
print("交叉注意力输出形状:", output.shape)
# 可视化注意力权重 (第一个批次,第一个头)
plt.figure(figsize=(10, 8))
plt.imshow(attn_weights[0, 0].detach().numpy(), cmap='viridis')
plt.title('编码器-解码器注意力权重 (头1)')
plt.xlabel('源序列位置')
plt.ylabel('目标序列位置')
plt.xticks(range(src_seq_len))
plt.yticks(range(tgt_seq_len))
plt.colorbar()
plt.show()

交叉注意力的核心作用:

对齐机制:自动学习源序列与目标序列的对应关系

信息融合:将源序列信息注入解码过程

上下文提取:动态聚焦相关源信息

解耦依赖:分离源序列和目标序列的处理


四、Transformer解码器层完整实现

4.1 解码器层结构

class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, heads, ff_dim, dropout=0.1):
        super().__init__()
        # 掩码自注意力
        self.masked_self_attn = MaskedSelfAttention(d_model, heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        
        # 编码器-解码器注意力
        self.enc_dec_attn = EncoderDecoderAttention(d_model, heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)
        
        # 前馈网络
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, d_model)
        )
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout3 = nn.Dropout(dropout)
    
    def forward(self, x, encoder_output, src_mask, tgt_mask):
        # 掩码自注意力 + 残差连接
        attn1, self_attn_weights = self.masked_self_attn(x, tgt_mask)
        x = x + self.dropout1(attn1)
        x = self.norm1(x)
        
        # 编码器-解码器注意力 + 残差连接
        attn2, cross_attn_weights = self.enc_dec_attn(x, encoder_output, src_mask)
        x = x + self.dropout2(attn2)
        x = self.norm2(x)
        
        # 前馈网络 + 残差连接
        ffn_out = self.ffn(x)
        x = x + self.dropout3(ffn_out)
        x = self.norm3(x)
        
        return x, self_attn_weights, cross_attn_weights
# 测试解码器层
d_model = 128
heads = 8
ff_dim = 512
# 创建输入和掩码
decoder_input = torch.randn(batch_size, tgt_seq_len, d_model)
encoder_output = torch.randn(batch_size, src_seq_len, d_model)
src_mask = torch.ones(batch_size, src_seq_len)  # 假设无填充
tgt_mask = generate_mask(tgt_seq_len).unsqueeze(0).unsqueeze(0)  # 增加批次和头维度
# 初始化解码器层
decoder_layer = TransformerDecoderLayer(d_model, heads, ff_dim)
output, self_attn, cross_attn = decoder_layer(
    decoder_input, 
    encoder_output, 
    src_mask, 
    tgt_mask
)
print("解码器输入形状:", decoder_input.shape)
print("解码器输出形状:", output.shape)

4.2 解码器层数据流

2525.png


五、Transformer解码器完整架构

image.png

5.1 多层解码器实现

class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([decoder_layer for _ in range(num_layers)])
        self.num_layers = num_layers
    
    def forward(self, x, encoder_output, src_mask, tgt_mask):
        all_self_attn = []
        all_cross_attn = []
        
        for layer in self.layers:
            x, self_attn, cross_attn = layer(
                x, 
                encoder_output, 
                src_mask, 
                tgt_mask
            )
            all_self_attn.append(self_attn)
            all_cross_attn.append(cross_attn)
        
        return x, all_self_attn, all_cross_attn
# 构建完整解码器
num_layers = 6
decoder = TransformerDecoder(decoder_layer, num_layers)
# 测试解码器
output, all_self_attn, all_cross_attn = decoder(
    decoder_input,
    encoder_output,
    src_mask,
    tgt_mask
)
print("解码器输出形状:", output.shape)
print("自注意力权重列表长度:", len(all_self_attn))
print("交叉注意力权重列表长度:", len(all_cross_attn))

5.2 解码器与编码器集成

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, ff_dim):
        super().__init__()
        # 编码器部分
        self.encoder_embed = nn.Embedding(src_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, ff_dim)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)
        
        # 解码器部分
        self.decoder_embed = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_decoder = PositionalEncoding(d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, ff_dim)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers)
        
        # 输出层
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
    
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # 编码器
        src_emb = self.pos_encoder(self.encoder_embed(src))
        memory = self.encoder(src_emb, src_key_padding_mask=src_mask)
        
        # 解码器
        tgt_emb = self.pos_decoder(self.decoder_embed(tgt))
        output = self.decoder(
            tgt_emb, 
            memory, 
            tgt_mask=tgt_mask,
            memory_key_padding_mask=src_mask
        )
        
        # 输出预测
        return self.fc_out(output)
# 参数设置
src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
nhead = 8
num_encoder_layers = 6
num_decoder_layers = 6
ff_dim = 2048
# 初始化模型
model = Transformer(
    src_vocab_size, 
    tgt_vocab_size, 
    d_model, 
    nhead, 
    num_encoder_layers, 
    num_decoder_layers, 
    ff_dim
)
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")


六、实战:机器翻译任务

6.1 数据准备(模拟)

# 模拟数据生成
src_data = torch.randint(1, src_vocab_size, (32, 20))  # 批量32, 源序列长20
tgt_data = torch.randint(1, tgt_vocab_size, (32, 15))  # 批量32, 目标序列长15
# 创建掩码
src_mask = (src_data != 0).float()  # 假设0为填充
tgt_mask = generate_mask(tgt_data.size(1))
# 前向传播
output = model(src_data, tgt_data[:, :-1], src_mask=src_mask, tgt_mask=tgt_mask)
print("模型输出形状:", output.shape)  # (batch_size, tgt_seq_len-1, tgt_vocab_size)

6.2 训练循环

criterion = nn.CrossEntropyLoss(ignore_index=0)  # 忽略填充位置
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
def train_step(src, tgt):
    optimizer.zero_grad()
    
    # 创建目标输入和标签
    tgt_input = tgt[:, :-1]
    tgt_labels = tgt[:, 1:]
    
    # 创建掩码
    src_mask = (src != 0).float()
    tgt_mask = generate_mask(tgt_input.size(1))
    
    # 前向传播
    output = model(src, tgt_input, src_mask=src_mask, tgt_mask=tgt_mask)
    
    # 计算损失
    loss = criterion(output.reshape(-1, output.size(-1)), tgt_labels.reshape(-1))
    
    # 反向传播
    loss.backward()
    optimizer.step()
    
    return loss.item()
# 模拟训练
for step in range(1000):
    src = torch.randint(1, src_vocab_size, (32, 20))
    tgt = torch.randint(1, tgt_vocab_size, (32, 15))
    loss = train_step(src, tgt)
    
    if (step+1) % 100 == 0:
        print(f"Step {step+1}, Loss: {loss:.4f}")

6.3 推理过程(贪婪解码)

def greedy_decode(model, src, max_len=20, start_symbol=1):
    src_mask = (src != 0).float()
    src_emb = model.pos_encoder(model.encoder_embed(src))
    memory = model.encoder(src_emb, src_key_padding_mask=src_mask)
    
    # 初始化目标序列
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src)
    
    for i in range(max_len-1):
        # 创建目标掩码
        tgt_mask = generate_mask(ys.size(1))
        
        # 解码
        tgt_emb = model.pos_decoder(model.decoder_embed(ys))
        out = model.decoder(tgt_emb, memory, tgt_mask=tgt_mask)
        
        # 预测下一个词
        prob = model.fc_out(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()
        
        # 添加到序列
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src).fill_(next_word)], dim=1)
        
        # 遇到结束符停止
        if next_word == 2:  # 假设2是结束符
            break
    
    return ys
# 测试推理
src_sentence = torch.randint(1, src_vocab_size, (1, 10))  # 单个源序列
translated = greedy_decode(model, src_sentence)
print("源序列:", src_sentence)
print("翻译结果:", translated)


关键要点总结

掩码自注意力核心

# 创建因果掩码
mask = torch.tril(torch.ones(seq_len, seq_len))
mask = mask.masked_fill(mask == 0, float('-inf'))
# 应用掩码
energy = QK^T / sqrt(d_k) + mask

编码器-解码器注意力流程

22.png

解码器层结构

输入
│
├─> 掩码自注意力 → 残差连接 → 层归一化
│
├─> 编码器-解码器注意力 → 残差连接 → 层归一化
│
└─> 前馈网络 → 残差连接 → 层归一化
│
输出

训练与推理差异

image.png

超参数设置建议

d_model:512-1024(平衡性能与计算)

  • nhead:8-16(确保d_model可整除)

  • 层数:编码器/解码器各6层(基础模型)

  • ff_dim:通常为4×d_model

  • Dropout:0.1(防止过拟合)


通过掌握Transformer解码器的核心机制,你已经具备了构建现代序列生成模型的基础能力。更多AI大模型应用开发学习视频内容和资料,尽在聚客AI学院




版权声明:倡导尊重与保护知识产权。未经许可,任何人不得复制、转载、或以其他方式使用本站《原创》内容,违者将追究其法律责任。本站文章内容,部分图片来源于网络,如有侵权,请联系我们修改或者删除处理。

编辑推荐

热门文章

大厂标准培训
海量精品课程
汇聚优秀团队
打造完善体系
Copyright © 2023-2025 聚客AI 版权所有
网站备案号:湘ICP备2024094305号-1