"""
手撕 Transformer - 第2步：多头注意力 (Multi-Head Attention)

核心思想: 将 Q, K, V 分别线性投影到 h 个 d_k = d_model / h 维的子空间, 每个头独立计算缩放点积注意力, 再将各头输出拼接并经 W_O 投影回 d_model 维
"""

import torch
import torch.nn as nn
import math


def scaled_dot_product_attention(Q, K, V, mask=None):
    """Compute softmax(Q @ K^T / sqrt(d_k)) @ V.

    Args:
        Q: query tensor, (..., seq_q, d_k).
        K: key tensor, (..., seq_k, d_k).
        V: value tensor, (..., seq_k, d_v).
        mask: optional tensor broadcastable to the score shape
            (..., seq_q, seq_k). Key positions where ``mask == 0`` are
            excluded from attention.

    Returns:
        (output, weights): ``weights`` are the attention probabilities over
        the key axis, (..., seq_q, seq_k), and ``output = weights @ V``.
        A query row whose keys are all masked gets all-zero weights (hence a
        zero output row) instead of NaN.
    """
    d_k = Q.size(-1)
    # Scale by sqrt(d_k) so the dot products do not grow with the head size.
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        blocked = mask == 0
        scores = scores.masked_fill(blocked, float('-inf'))
    weights = torch.softmax(scores, dim=-1)
    if mask is not None:
        # Softmax over a row that is entirely -inf is 0/0 = NaN. Such a query
        # has nothing to attend to, so give it zero weight rather than letting
        # NaN leak into the output. Partially masked rows are unchanged, since
        # their masked entries are already exactly 0.
        weights = weights.masked_fill(blocked, 0.0)
    return torch.matmul(weights, V), weights


class MultiHeadAttention(nn.Module):
    """Multi-head attention: project, split into heads, attend, concatenate, project.

    Queries, keys and values are each projected by a bias-free
    d_model x d_model linear layer. Each projection is reshaped into
    ``num_heads`` heads of size ``d_k = d_model // num_heads``, and scaled
    dot-product attention runs in every head in parallel. The head outputs
    are then concatenated back to d_model features and passed through the
    output projection ``W_o``.

    Args:
        d_model: feature size of the inputs and of the output.
        num_heads: number of heads; must divide ``d_model`` exactly.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Explicit check rather than `assert`, so it still runs under `python -O`.
        if d_model % num_heads != 0:
            raise ValueError("d_model 必须能被 num_heads 整除")
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head feature size
        # One full-width projection per role; the heads come from reshaping its
        # output (see _split_heads), not from separate per-head layers.
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def _split_heads(self, x, batch_size):
        """Reshape (B, seq, d_model) -> (B, seq, h, d_k) -> (B, h, seq, d_k)."""
        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        """Apply multi-head attention.

        Args:
            Q: query input, (B, seq_q, d_model).
            K: key input, (B, seq_k, d_model).
            V: value input, (B, seq_k, d_model).
            mask: optional; positions where ``mask == 0`` are not attended to.
                A 4-D mask must broadcast to (B, h, seq_q, seq_k). A 2-D mask
                (seq_q, seq_k) is shared by every batch element and head. A
                3-D mask is treated as per batch element, e.g.
                (B, seq_q, seq_k) or (B, 1, seq_k), and is shared across heads.

        Returns:
            (output, weights): ``output`` is (B, seq_q, d_model); ``weights``
            are the per-head attention probabilities, (B, h, seq_q, seq_k).
        """
        batch_size = Q.size(0)
        # 1. Linear projections, then 2. split into heads.
        Q = self._split_heads(self.W_q(Q), batch_size)
        K = self._split_heads(self.W_k(K), batch_size)
        V = self._split_heads(self.W_v(V), batch_size)
        # A per-batch 3-D mask must gain a head axis: (B, ., seq_k) -> (B, 1, ., seq_k).
        # Without it, broadcasting against (B, h, seq_q, seq_k) lines B up with h.
        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)
        # 3. Attention in every head independently (batched over the head axis).
        attn_output, weights = scaled_dot_product_attention(Q, K, V, mask)
        # 4. Concatenate heads: (B, h, seq, d_k) -> (B, seq, h*d_k) = (B, seq, d_model).
        #    transpose() makes the tensor non-contiguous, so copy before view().
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # 5. Output projection.
        output = self.W_o(attn_output)
        return output, weights


# ============ 详细计算过程演示 ============
def _banner(title, *, first=False):
    """Print a section header framed by two 60-character '=' rules.

    Every header except the first is preceded by a blank line.
    """
    rule = "=" * 60
    print(rule if first else "\n" + rule)
    print(title)
    print(rule)


def main():
    """Step-by-step walkthrough of multi-head attention on a tiny random input.

    Reproduces each stage of ``MultiHeadAttention.forward`` by hand, printing
    the intermediate tensors and shapes. It then checks that the manual result
    matches the module's own ``forward``. Head-related output is generated from
    ``num_heads``, so any ``num_heads`` that divides ``d_model`` works.
    """
    torch.manual_seed(42)

    batch_size = 1
    seq_len = 3
    d_model = 8
    num_heads = 2
    d_k = d_model // num_heads  # = 4

    X = torch.randn(batch_size, seq_len, d_model)

    _banner("Step 0: 输入", first=True)
    print(f"X   shape: ({batch_size}, {seq_len}, {d_model})")
    print(f"参数: d_model={d_model}, num_heads={num_heads}, d_k={d_k}")
    print(f"\nX =\n{X[0].detach().numpy().round(3)}")

    # Build the module after sampling X, so the seeded RNG draws happen in
    # the same order (and give the same printed numbers) on every run.
    mha = MultiHeadAttention(d_model, num_heads)

    # ========== Step 1: linear projections ==========
    _banner("Step 1: 线性投影  Q = X @ W_Q,  K = X @ W_K,  V = X @ W_V")
    Q_proj = mha.W_q(X)
    K_proj = mha.W_k(X)
    V_proj = mha.W_v(X)
    # nn.Linear stores its weight as (out_features, in_features) and computes
    # X @ weight.T, so W_Q in the printed formula corresponds to mha.W_q.weight.T.
    print(f"W_Q shape: ({d_model}, {d_model})")
    print(f"Q = X @ W_Q  →  ({batch_size}, {seq_len}, {d_model}) @ ({d_model}, {d_model}) = ({batch_size}, {seq_len}, {d_model})")
    print(f"K = X @ W_K  →  ({batch_size}, {seq_len}, {d_model})")
    print(f"V = X @ W_V  →  ({batch_size}, {seq_len}, {d_model})")
    print(f"\nQ =\n{Q_proj[0].detach().numpy().round(3)}")

    # ========== Step 2: split into heads ==========
    _banner(f"Step 2: 拆分成 {num_heads} 个头  (B, seq, d_model) -> (B, seq, h, d_k) -> (B, h, seq, d_k)")
    Q_heads = Q_proj.view(batch_size, -1, num_heads, d_k).transpose(1, 2)
    K_heads = K_proj.view(batch_size, -1, num_heads, d_k).transpose(1, 2)
    V_heads = V_proj.view(batch_size, -1, num_heads, d_k).transpose(1, 2)
    print(f"Q_heads shape: {tuple(Q_heads.shape)}  (batch, heads, seq, d_k)")
    print(f"  即 {num_heads} 个头, 每个头有 {seq_len} 个 query 向量, 每个 {d_k} 维")
    for h in range(num_heads):
        print(f"\nHead {h} 的 Q:\n{Q_heads[0, h].detach().numpy().round(3)}")

    # ========== Step 3: attention in every head independently ==========
    _banner("Step 3: 每个头独立计算缩放点积注意力")
    attn_output, weights = scaled_dot_product_attention(Q_heads, K_heads, V_heads)
    print("每个头: scores = Q @ K^T  →  (B, h, seq, d_k) @ (B, h, d_k, seq) = (B, h, seq, seq)")
    print(f"attn_output shape: {tuple(attn_output.shape)}")
    for h in range(num_heads):
        print(f"\nHead {h} 注意力权重:\n{weights[0, h].detach().numpy().round(3)}")
    # Nothing has been trained: the projections are random initialisations, and
    # the heads differ only because each reads its own d_k-wide slice.
    print(f"\n{num_heads} 个头作用在不同的子空间上, 因此得到不同的注意力模式 (权重为随机初始化, 并非训练所得)")

    # ========== Step 4: concatenate all heads ==========
    _banner("Step 4: 拼接所有头  (B, h, seq, d_k) -> (B, seq, h*d_k) = (B, seq, d_model)")
    concat = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, d_model)
    print(f"拼接后 shape: {tuple(concat.shape)}")
    for h in range(num_heads):
        print(f"  Head {h} 输出: ({batch_size}, {seq_len}, {d_k})")
    head_dims = "+".join([str(d_k)] * num_heads)
    print(f"  拼接后:      ({batch_size}, {seq_len}, {head_dims}) = ({batch_size}, {seq_len}, {d_model})")
    print(f"\n拼接结果:\n{concat[0].detach().numpy().round(3)}")

    # ========== Step 5: output projection ==========
    _banner(f"Step 5: 输出投影  output = concat @ W_O  ({batch_size}, {seq_len}, {d_model}) @ ({d_model}, {d_model}) = ({batch_size}, {seq_len}, {d_model})")
    output = mha.W_o(concat)
    print(f"输出 shape: {tuple(output.shape)}")
    print(f"\n最终输出:\n{output[0].detach().numpy().round(3)}")

    # ========== Check against the module's forward ==========
    _banner("验证: 使用模块的 forward 方法")
    output_fwd, _ = mha(X, X, X)  # self-attention: Q, K and V all come from X
    diff = (output - output_fwd).abs().max().item()
    print(f"手动逐步 vs forward 最大差异: {diff:.2e}  (应为 0)")

    total_params = sum(p.numel() for p in mha.parameters())
    print(f"总参数量: {total_params}")
    # Read the sizes from the actual layers (W_q, W_k, W_v, W_o in registration
    # order) instead of assuming d_model x d_model.
    for name, layer in mha.named_children():
        rows, cols = layer.weight.shape
        print(f"  {name}: {rows}x{cols} = {layer.weight.numel()}")


if __name__ == "__main__":
    main()
