"""
手撕 Transformer - 第6步：完整 Tiny Transformer

从零拼出一个可以跑通前向传播的完整 Transformer 模型
"""

import torch
import torch.nn as nn
import math


# ============ 组件定义 ============

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Compute softmax(Q K^T / sqrt(d_k)) V.

    Args:
        Q: queries, shape (..., T_q, d_k).
        K: keys, shape (..., T_k, d_k).
        V: values, shape (..., T_k, d_v).
        mask: optional tensor broadcastable to (..., T_q, T_k); positions where
            ``mask == 0`` receive zero attention weight.

    Returns:
        (output, weights): output of shape (..., T_q, d_v) and the attention
        weights of shape (..., T_q, T_k).
    """
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Use the most negative finite value instead of -inf. A query row whose
        # keys are ALL masked would otherwise softmax over all -inf and produce
        # NaN, which then propagates into every downstream value and gradient.
        # For rows with at least one visible key nothing changes, because
        # softmax subtracts the row max and exp(min - max) underflows to 0.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, V), weights


class MultiHeadAttention(nn.Module):
    """Multi-head attention with bias-free Q/K/V and output projections.

    ``d_model`` is split evenly into ``num_heads`` heads of width
    ``d_k = d_model // num_heads``. Each head attends independently through
    ``scaled_dot_product_attention``; the heads are then concatenated and
    mixed by ``W_o``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        def _square_proj():
            return nn.Linear(d_model, d_model, bias=False)

        # Creation order W_q, W_k, W_v, W_o fixes the init RNG order.
        self.W_q = _square_proj()
        self.W_k = _square_proj()
        self.W_v = _square_proj()
        self.W_o = _square_proj()

    def _split_heads(self, proj, tensor, batch):
        # Project, then reshape (batch, T, d_model) -> (batch, num_heads, T, d_k).
        return proj(tensor).view(batch, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        """Attend from ``Q`` over ``K``/``V``.

        ``mask`` is passed through unchanged to the attention function.
        Returns ``(output, weights)``; output has the merged ``d_model`` width
        and ``weights`` keeps a per-head axis.
        """
        batch = Q.size(0)
        q_heads = self._split_heads(self.W_q, Q, batch)
        k_heads = self._split_heads(self.W_k, K, batch)
        v_heads = self._split_heads(self.W_v, V, batch)
        context, weights = scaled_dot_product_attention(q_heads, k_heads, v_heads, mask)
        # Undo the head split: (batch, H, T, d_k) -> (batch, T, d_model).
        merged = context.transpose(1, 2).contiguous().view(batch, -1, self.d_model)
        return self.W_o(merged), weights


class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable affine.

    Each vector along the last axis is shifted to zero mean and scaled to unit
    population variance, then transformed as ``gamma * x_hat + beta``.
    """

    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))   # scale, initialised to 1
        self.beta = nn.Parameter(torch.zeros(d_model))   # shift, initialised to 0
        self.eps = eps  # keeps the denominator away from zero

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        # unbiased=False -> divide by N (population variance), not N - 1.
        denom = torch.sqrt(x.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
        return self.gamma * centered / denom + self.beta


class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Widens the last dimension from ``d_model`` to ``d_ff`` and projects it back,
    applied independently at every position.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Registration order (linear1, linear2, relu) matches the init RNG order.
        self.linear1 = nn.Linear(d_model, d_ff)   # expand
        self.linear2 = nn.Linear(d_ff, d_model)   # project back
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)


class EncoderBlock(nn.Module):
    """One Transformer encoder layer.

    Self-attention and the feed-forward network are each followed by a
    residual add and LayerNorm: ``x = norm(x + sublayer(x))``.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        # Registration order is kept as-is: it fixes parameter-init RNG order
        # under a fixed seed and the state_dict layout.
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Apply the layer; ``mask`` is forwarded to self-attention (0 = masked)."""
        attended = self.self_attn(x, x, x, mask)[0]
        x = self.norm1(x + attended)
        return self.norm2(x + self.ffn(x))


class DecoderBlock(nn.Module):
    """One Transformer decoder layer.

    Three sub-layers, each followed by a residual add and LayerNorm:
      1. self-attention over the decoder input, masked by ``tgt_mask``;
      2. cross-attention with queries from the decoder and keys/values from
         ``enc_output``, masked by ``src_mask``;
      3. position-wise feed-forward network.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        # Registration order is significant (init RNG order, state_dict layout).
        self.masked_self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        self_ctx = self.masked_self_attn(x, x, x, tgt_mask)[0]
        x = self.norm1(x + self_ctx)
        cross_ctx = self.cross_attn(x, enc_output, enc_output, src_mask)[0]
        x = self.norm2(x + cross_ctx)
        return self.norm3(x + self.ffn(x))


# ============ 完整 Transformer ============

class TinyTransformer(nn.Module):
    """Minimal encoder-decoder Transformer for demonstration.

    Token embeddings are scaled by sqrt(d_model), summed with a fixed
    sinusoidal positional encoding, passed through ``num_layers`` encoder and
    ``num_layers`` decoder blocks, and projected to target-vocabulary logits.

    Args:
        src_vocab_size: number of source token IDs.
        tgt_vocab_size: number of target token IDs (width of the output logits).
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: attention heads in every attention layer.
        d_ff: hidden width of every feed-forward network.
        num_layers: number of encoder blocks and of decoder blocks.
        max_len: longest sequence the positional-encoding table covers.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=64, num_heads=4,
                 d_ff=256, num_layers=2, max_len=100):
        super().__init__()
        self.d_model = d_model
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        # Fixed (non-learnable) table of shape (1, max_len, d_model); stored as a
        # buffer so it follows .to(device) and is saved in the state_dict.
        self.register_buffer('pos_encoding', self._get_positional_encoding(max_len, d_model))
        self.encoder_layers = nn.ModuleList([
            EncoderBlock(d_model, num_heads, d_ff) for _ in range(num_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderBlock(d_model, num_heads, d_ff) for _ in range(num_layers)
        ])
        self.output_proj = nn.Linear(d_model, tgt_vocab_size)

    def _get_positional_encoding(self, max_len, d_model):
        """Build the sinusoidal positional-encoding table, shape (1, max_len, d_model).

        Column 2i holds sin(pos / 10000^(2i/d_model)) and column 2i+1 holds the
        matching cos. Odd ``d_model`` is supported: the final (even) column then
        has no cosine partner.
        """
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are ceil(d/2) sine columns but only floor(d/2) cosine columns,
        # so trim div_term. This is a no-op for even d_model and avoids a shape
        # mismatch for odd d_model.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe.unsqueeze(0)

    def _embed(self, ids, embedding):
        """Look up ``ids``, scale by sqrt(d_model), and add positional encodings.

        Raises:
            ValueError: if the sequence is longer than the positional table.
        """
        seq_len = ids.size(1)
        max_len = self.pos_encoding.size(1)
        if seq_len > max_len:
            # Without this check the slice below silently returns fewer rows
            # and the addition fails with an opaque broadcasting error.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={max_len} "
                f"of the positional encoding"
            )
        return embedding(ids) * math.sqrt(self.d_model) + self.pos_encoding[:, :seq_len, :]

    def encode(self, src, src_mask=None):
        """Run the encoder stack.

        Args:
            src: source token IDs indexed as (batch, src_len).
            src_mask: optional mask broadcastable to the attention scores
                (batch, num_heads, src_len, src_len); positions where it is 0
                are masked out. A padding mask of shape (batch, 1, 1, src_len)
                works here and in ``decode``. Default: no masking.

        Returns:
            Encoder output of shape (batch, src_len, d_model).
        """
        x = self._embed(src, self.src_embedding)
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return x

    def decode(self, tgt, enc_output, src_mask=None):
        """Run the decoder stack with a causal self-attention mask.

        Args:
            tgt: target token IDs indexed as (batch, tgt_len).
            enc_output: output of ``encode``.
            src_mask: optional mask for cross-attention, broadcastable to
                (batch, num_heads, tgt_len, src_len); 0 = masked. Default: none.

        Returns:
            Decoder output of shape (batch, tgt_len, d_model).
        """
        seq_len = tgt.size(1)
        # Lower-triangular mask: position i may attend only to positions <= i.
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=tgt.device)).unsqueeze(0)
        x = self._embed(tgt, self.tgt_embedding)
        for layer in self.decoder_layers:
            x = layer(x, enc_output, src_mask=src_mask, tgt_mask=causal_mask)
        return x

    def forward(self, src, tgt, src_mask=None):
        """Return target logits of shape (batch, tgt_len, tgt_vocab_size)."""
        enc_output = self.encode(src, src_mask)
        dec_output = self.decode(tgt, enc_output, src_mask)
        return self.output_proj(dec_output)


# ============ 详细计算过程演示 ============
if __name__ == "__main__":
    # Demo: walk through one TinyTransformer forward pass stage by stage,
    # check the manual pipeline against model(...), then run one training step.
    torch.manual_seed(42)

    src_vocab_size = 50
    tgt_vocab_size = 50
    d_model = 16
    num_heads = 4
    d_ff = 64
    num_layers = 2

    print("=" * 60)
    print("Tiny Transformer 完整数据流演示")
    print("=" * 60)

    model = TinyTransformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, d_ff, num_layers)

    # Inputs: token IDs, one sequence per batch row.
    src_ids = torch.tensor([[3, 7, 12, 5, 9]])   # source sequence: 5 tokens
    tgt_ids = torch.tensor([[2, 15, 8, 4]])      # target sequence: 4 tokens
    # Derive sizes from the tensors so the printed shapes always match them.
    batch_size, src_len = src_ids.shape
    tgt_len = tgt_ids.size(1)

    print(f"\n--- 输入 ---")
    print(f"源序列 token IDs:   {src_ids[0].tolist()}  shape: {tuple(src_ids.shape)}")
    print(f"目标序列 token IDs: {tgt_ids[0].tolist()}  shape: {tuple(tgt_ids.shape)}")

    # ========== Step 1: Embedding + Positional Encoding ==========
    print(f"\n{'='*60}")
    print(f"Step 1: Embedding + 位置编码")
    print(f"{'='*60}")

    src_emb = model.src_embedding(src_ids)
    tgt_emb = model.tgt_embedding(tgt_ids)
    print(f"源 Embedding:   {tuple(src_ids.shape)} -> {tuple(src_emb.shape)}")
    print(f"  nn.Embedding({src_vocab_size}, {d_model}): 每个 token ID 查表得到 {d_model} 维向量")
    print(f"目标 Embedding: {tuple(tgt_ids.shape)} -> {tuple(tgt_emb.shape)}")

    src_pos = model.pos_encoding[:, :src_len, :]
    scale = math.sqrt(d_model)
    src_input = src_emb * scale + src_pos
    # Measure per-token norms rather than asserting them. With nn.Embedding's
    # default N(0, 1) init the raw embedding norm is about sqrt(d_model). A
    # sin/cos row has norm sqrt(d_model / 2), because each sin/cos pair
    # contributes sin^2 + cos^2 = 1.
    emb_norm = src_emb.norm(dim=-1).mean().item()
    pos_norm = src_pos.norm(dim=-1).mean().item()
    print(f"\n位置编码 shape: {tuple(src_pos.shape)}  (正弦/余弦, 不可学习)")
    print(f"最终输入 = Embedding * sqrt(d_model) + PositionalEncoding")
    print(f"  乘 sqrt(d_model) 的常见解释: 原论文中嵌入权重与输出层共享且数值较小 (如 std = 1/sqrt(d_model)),")
    print(f"  乘 sqrt({d_model})={scale:.1f} 把嵌入放大到与位置编码 (每维在 [-1, 1]) 相当的量级, 避免位置编码淹没词义")
    print(f"  本例实测平均范数: Embedding {emb_norm:.2f}, Embedding*sqrt(d_model) {emb_norm * scale:.2f}, 位置编码 {pos_norm:.2f}")
    print(f"  (nn.Embedding 默认 N(0,1) 初始化, 所以这里缩放后嵌入明显大于位置编码)")
    print(f"编码器输入 shape: {tuple(src_input.shape)}")

    # ========== Step 2: Encoder ==========
    print(f"\n{'='*60}")
    print(f"Step 2: 编码器 ({num_layers} 层)")
    print(f"{'='*60}")

    enc_out = model.encode(src_ids)
    print(f"编码器输出 shape: {tuple(enc_out.shape)}")
    print(f"  每一层: Self-Attention -> Add&Norm -> FFN -> Add&Norm")
    print(f"  经过 {num_layers} 层后, 每个 token 已经融合了所有源 token 的信息")

    # ========== Step 3: Decoder ==========
    print(f"\n{'='*60}")
    print(f"Step 3: 解码器 ({num_layers} 层)")
    print(f"{'='*60}")

    dec_out = model.decode(tgt_ids, enc_out)
    print(f"解码器输出 shape: {tuple(dec_out.shape)}")
    print(f"  每一层: MaskedSelfAttn -> Add&Norm -> CrossAttn -> Add&Norm -> FFN -> Add&Norm")
    print(f"  掩码自注意力: token 只能看前面的 token (因果性)")
    print(f"  交叉注意力:   Q 来自解码器, K/V 来自编码器输出")

    # ========== Step 4: Output projection ==========
    print(f"\n{'='*60}")
    print(f"Step 4: 输出投影  logits = dec_output @ W_out.T + b_out")
    print(f"{'='*60}")

    logits = model.output_proj(dec_out)
    # nn.Linear stores its weight as (out_features, in_features), so the
    # projection multiplies by the transpose of the stored matrix.
    W_out = model.output_proj.weight
    print(f"W_out shape: {tuple(W_out.shape)}  (nn.Linear 权重形状为 (out_features, in_features))")
    print(f"logits = dec_output @ W_out.T + b_out  ->  ({batch_size}, {tgt_len}, {d_model}) @ ({d_model}, {tgt_vocab_size}) = ({batch_size}, {tgt_len}, {tgt_vocab_size})")
    print(f"logits shape: {tuple(logits.shape)}")
    print(f"  每个位置的 {d_model} 维向量 -> {tgt_vocab_size} 维 (词表大小)")
    print(f"  经过 softmax 即可得到下一个 token 的概率分布")

    # ========== Step 5: Verify against a full forward call ==========
    print(f"\n{'='*60}")
    print(f"Step 5: 完整 forward 调用验证")
    print(f"{'='*60}")

    logits_fwd = model(src_ids, tgt_ids)
    diff = (logits - logits_fwd).abs().max().item()
    print(f"手动逐步 vs forward 最大差异: {diff:.2e}  (应为 0)")

    # ========== Step 6: One training step ==========
    print(f"\n{'='*60}")
    print(f"Step 6: 模拟训练一步")
    print(f"{'='*60}")

    target = torch.tensor([[5, 12, 3, 8]])  # ground-truth next token for each position
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    logits = model(src_ids, tgt_ids)
    # reshape (not view) so this also works if logits is ever non-contiguous.
    loss = criterion(logits.reshape(-1, tgt_vocab_size), target.reshape(-1))
    print(f"交叉熵损失: {loss.item():.4f}")

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"反向传播 + 参数更新: 成功!")

    # Parameter counts
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\n总参数量: {total_params:,}")
    print(f"  Embedding: {src_vocab_size}*{d_model} + {tgt_vocab_size}*{d_model} = {(src_vocab_size+tgt_vocab_size)*d_model}")
    print(f"  每层 Encoder: ~{sum(p.numel() for p in model.encoder_layers[0].parameters()):,}")
    print(f"  每层 Decoder: ~{sum(p.numel() for p in model.decoder_layers[0].parameters()):,}")
    print(f"  输出投影: {d_model}*{tgt_vocab_size} + {tgt_vocab_size} = {d_model*tgt_vocab_size + tgt_vocab_size}")
