"""
手撕 Transformer - 第5步：Transformer Block

编码器块: Self-Attention -> Add&Norm -> FFN -> Add&Norm
解码器块: Masked Self-Attention -> Add&Norm -> Cross-Attention -> Add&Norm -> FFN -> Add&Norm
"""

import torch
import torch.nn as nn
import math


# ---- 组件 ----

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Compute scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V.

    Args:
        Q: query tensor whose last two dims are (len_q, d_k).
        K: key tensor whose last two dims are (len_k, d_k).
        V: value tensor whose last two dims are (len_k, d_v).
        mask: optional tensor broadcastable to the score shape
            (..., len_q, len_k); positions where ``mask == 0`` are excluded
            from attention.

    Returns:
        (output, weights): ``weights`` has shape (..., len_q, len_k) with each
        row summing to 1, and ``output = weights @ V``.
    """
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Use the most negative finite value rather than -inf. A query row
        # that is masked entirely would otherwise softmax over all -inf and
        # yield NaN, which then propagates through every later layer. For rows
        # with at least one visible key the result is unchanged, because
        # exp(min - row_max) underflows to exactly 0.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, V), weights


class MultiHeadAttention(nn.Module):
    """Multi-head attention with bias-free Q/K/V and output projections.

    Args:
        d_model: model (embedding) dimension; must be divisible by ``num_heads``.
        num_heads: number of attention heads; each head works on
            ``d_k = d_model // num_heads`` features.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, (
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, Q, K, V, mask=None):
        """Attend from the query sequence to the key/value sequence.

        Args:
            Q: queries, shape (batch, len_q, d_model).
            K: keys, shape (batch, len_k, d_model).
            V: values, shape (batch, len_k, d_model).
            mask: optional mask where 0 marks positions that must not be
                attended to. A 3-D mask of shape (batch or 1, len_q, len_k)
                is shared across all heads; 2-D (len_q, len_k) and 4-D
                (batch, heads, len_q, len_k) masks are used as given.

        Returns:
            (output, weights): output of shape (batch, len_q, d_model) and
            attention weights of shape (batch, num_heads, len_q, len_k).
        """
        B = Q.size(0)
        # Project, split d_model into (num_heads, d_k), then move the head axis
        # in front of the sequence axis: (B, L, d_model) -> (B, H, L, d_k).
        Q = self.W_q(Q).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        if mask is not None and mask.dim() == 3:
            # Insert the head axis. Otherwise a per-sample mask
            # (B, len_q, len_k) right-aligns against the scores
            # (B, H, len_q, len_k) and is broadcast over heads instead of
            # over the batch.
            mask = mask.unsqueeze(1)
        attn_output, weights = scaled_dot_product_attention(Q, K, V, mask)
        # Merge the heads back: (B, H, len_q, d_k) -> (B, len_q, d_model).
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, -1, self.d_model)
        return self.W_o(attn_output), weights


class LayerNorm(nn.Module):
    """Hand-written layer normalization over the last dimension.

    Each feature vector is shifted to zero mean and scaled by the square root
    of its (biased) variance plus ``eps``. The learnable ``gamma`` / ``beta``
    then apply an affine transform.

    Args:
        d_model: size of the last dimension being normalized.
        eps: small constant added to the variance for numerical stability.
    """

    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))   # learnable scale
        self.beta = nn.Parameter(torch.zeros(d_model))   # learnable shift
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` along its last axis and apply the affine transform."""
        mu = x.mean(dim=-1, keepdim=True)
        # Population (biased) variance, the same estimator nn.LayerNorm uses.
        sigma2 = x.var(dim=-1, keepdim=True, unbiased=False)
        centered = x - mu
        std = torch.sqrt(sigma2 + self.eps)
        return self.gamma * centered / std + self.beta


class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Args:
        d_model: input and output feature size.
        d_ff: hidden (expanded) feature size.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)   # expand: d_model -> d_ff
        self.linear2 = nn.Linear(d_ff, d_model)   # project back: d_ff -> d_model
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the two-layer MLP to the last dimension of ``x``."""
        hidden = self.linear1(x)
        hidden = self.relu(hidden)
        return self.linear2(hidden)


# ---- Transformer Block ----

class EncoderBlock(nn.Module):
    """One Transformer encoder layer with LayerNorm applied after each residual.

    Layout: x -> self-attention -> Add & Norm -> feed-forward -> Add & Norm.

    Args:
        d_model: model dimension.
        num_heads: number of attention heads.
        d_ff: hidden size of the feed-forward sublayer.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        # Submodules are created in this order on purpose: it determines the
        # order in which random initialisation is drawn and the submodule
        # registration order.
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Run the encoder layer on ``x``; ``mask`` is forwarded to self-attention."""
        # Sublayer 1: self-attention (queries, keys and values all come from x),
        # then residual connection + LayerNorm. Attention weights are discarded.
        h = self.norm1(x + self.self_attn(x, x, x, mask)[0])
        # Sublayer 2: feed-forward network with its own residual + LayerNorm.
        return self.norm2(h + self.ffn(h))


class DecoderBlock(nn.Module):
    """One Transformer decoder layer with LayerNorm applied after each residual.

    Layout: masked self-attention -> Add & Norm -> cross-attention over the
    encoder output -> Add & Norm -> feed-forward -> Add & Norm.

    Args:
        d_model: model dimension.
        num_heads: number of attention heads (shared by both attention sublayers).
        d_ff: hidden size of the feed-forward sublayer.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        # Submodules are created in this order on purpose: it determines the
        # order in which random initialisation is drawn and the submodule
        # registration order.
        self.masked_self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """Run the decoder layer.

        Args:
            x: decoder-side input sequence.
            enc_output: encoder output used as keys/values in cross-attention.
            src_mask: optional mask for cross-attention (0 = not attended).
            tgt_mask: optional mask for decoder self-attention (0 = not
                attended); the demo below passes a causal mask here.
        """
        # 1) Self-attention over the decoder sequence, restricted by tgt_mask.
        self_attended, _ = self.masked_self_attn(x, x, x, tgt_mask)
        h = self.norm1(x + self_attended)
        # 2) Cross-attention: queries from the decoder, keys/values from encoder.
        cross_attended, _ = self.cross_attn(h, enc_output, enc_output, src_mask)
        h = self.norm2(h + cross_attended)
        # 3) Feed-forward network with residual + LayerNorm.
        return self.norm3(h + self.ffn(h))


# ============ 详细计算过程演示 ============
if __name__ == "__main__":
    # Step-by-step demo: run every sublayer of an encoder block and a decoder
    # block by hand, print the intermediate shapes, and check that the manual
    # result matches each block's forward().
    torch.manual_seed(42)

    batch_size = 1
    src_len = 4
    tgt_len = 3
    d_model = 8
    num_heads = 2
    d_ff = 16

    src = torch.randn(batch_size, src_len, d_model)
    tgt = torch.randn(batch_size, tgt_len, d_model)
    # Lower-triangular (1, tgt_len, tgt_len) mask: position i may only see j <= i.
    causal_mask = torch.tril(torch.ones(tgt_len, tgt_len)).unsqueeze(0)

    print("=" * 60)
    print("Encoder Block 详细计算过程")
    print("=" * 60)
    print(f"输入 x shape: (batch={batch_size}, src_len={src_len}, d_model={d_model})")

    enc_block = EncoderBlock(d_model, num_heads, d_ff)

    # --- Sublayer 1: Self-Attention ---
    print("\n--- Sublayer 1: Multi-Head Self-Attention ---")
    print("Q = K = V = x  (自注意力: 查询、键、值都来自同一输入)")
    attn_out, attn_weights = enc_block.self_attn(src, src, src)
    print(f"attn_output shape: {tuple(attn_out.shape)}")
    print(f"注意力权重 shape: {tuple(attn_weights.shape)}  (batch, heads, src_len, src_len)")

    # --- Add & Norm ---
    print("\n--- Add & Norm: norm1(x + attn_output) ---")
    x_after_norm1 = enc_block.norm1(src + attn_out)
    print("  残差: x + attn_output  ->  shape 不变")
    print("  LayerNorm: 归一化到均值~0, 方差~1")
    print(f"  输出 shape: {tuple(x_after_norm1.shape)}")

    # --- Sublayer 2: FFN ---
    print("\n--- Sublayer 2: Feed-Forward Network ---")
    ffn_out = enc_block.ffn(x_after_norm1)
    print(f"  输入: ({batch_size}, {src_len}, {d_model}) -> 升维 -> ReLU -> 降维 -> ({batch_size}, {src_len}, {d_model})")
    print(f"FFN 输出 shape: {tuple(ffn_out.shape)}")

    # --- Add & Norm ---
    print("\n--- Add & Norm: norm2(x + ffn_output) ---")
    x_after_norm2 = enc_block.norm2(x_after_norm1 + ffn_out)
    print(f"  输出 shape: {tuple(x_after_norm2.shape)}")

    # Verify: the manual pipeline must reproduce EncoderBlock.forward.
    enc_out = enc_block(src)
    diff = (enc_out - x_after_norm2).abs().max().item()
    print(f"\n手动逐步 vs forward 最大差异: {diff:.2e}")

    # ========== Decoder ==========
    print("\n" + "=" * 60)
    print("Decoder Block 详细计算过程")
    print("=" * 60)
    print(f"目标输入 shape: (batch={batch_size}, tgt_len={tgt_len}, d_model={d_model})")
    print(f"编码器输出 shape: (batch={batch_size}, src_len={src_len}, d_model={d_model})")
    print(f"\n因果掩码 (下三角):\n{causal_mask[0].int().numpy()}")

    dec_block = DecoderBlock(d_model, num_heads, d_ff)

    # --- Sublayer 1: Masked Self-Attention ---
    print("\n--- Sublayer 1: Masked Self-Attention ---")
    print("Q = K = V = tgt  (自注意力, 但用因果掩码遮住未来位置)")
    msa_out, msa_weights = dec_block.masked_self_attn(tgt, tgt, tgt, causal_mask)
    print(f"输出 shape: {tuple(msa_out.shape)}")
    print(f"Head 0 注意力权重:\n{msa_weights[0, 0].detach().numpy().round(3)}")
    print("  (下三角: 每个 token 只关注自己和前面的 token)")

    x_dec1 = dec_block.norm1(tgt + msa_out)
    print(f"Add&Norm 后 shape: {tuple(x_dec1.shape)}")

    # --- Sublayer 2: Cross-Attention ---
    print("\n--- Sublayer 2: Cross-Attention ---")
    print("Q = x_dec1 (来自解码器)")
    print("K = V = enc_out (来自编码器)")
    print("  解码器的每个 token 去'查询'编码器的输出, 获取源序列信息")
    ca_out, ca_weights = dec_block.cross_attn(x_dec1, enc_out, enc_out)
    print(f"输出 shape: {tuple(ca_out.shape)}")
    print(f"注意力权重 shape: {tuple(ca_weights.shape)}  (batch, heads, tgt_len, src_len)")
    print(f"Head 0 注意力权重:\n{ca_weights[0, 0].detach().numpy().round(3)}")
    print("  (每行: 一个目标 token 对所有源 token 的注意力分布)")

    x_dec2 = dec_block.norm2(x_dec1 + ca_out)
    print(f"Add&Norm 后 shape: {tuple(x_dec2.shape)}")

    # --- Sublayer 3: FFN ---
    print("\n--- Sublayer 3: FFN + Add&Norm ---")
    ffn_dec = dec_block.ffn(x_dec2)
    print(f"FFN 输出 shape: {tuple(ffn_dec.shape)}")
    x_dec3 = dec_block.norm3(x_dec2 + ffn_dec)
    print(f"Add&Norm 后 shape: {tuple(x_dec3.shape)}")

    # Verify: the manual pipeline must reproduce DecoderBlock.forward
    # (same causal mask for self-attention, no source mask for cross-attention).
    dec_out = dec_block(tgt, enc_out, tgt_mask=causal_mask)
    print(f"最终输出 shape: {tuple(dec_out.shape)}")
    dec_diff = (dec_out - x_dec3).abs().max().item()
    print(f"\n手动逐步 vs forward 最大差异: {dec_diff:.2e}")
