"""
手撕 Transformer - 第1步：缩放点积注意力 (Scaled Dot-Product Attention)

核心公式: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
"""

import torch
import math


def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)

    # 第1步: Q * K^T  → 衡量每个 query 与每个 key 的相似度
    scores = torch.matmul(Q, K.transpose(-2, -1))
    # 第2步: 缩放
    scores = scores / math.sqrt(d_k)
    # 第3步: 掩码 (可选)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # 第4步: Softmax → 注意力权重
    weights = torch.softmax(scores, dim=-1)
    # 第5步: 加权求和
    output = torch.matmul(weights, V)
    return output, weights


# ============ 详细计算过程演示 ============
if __name__ == "__main__":
    torch.manual_seed(42)

    # ---------- 设置 ----------
    batch_size = 1   # 为演示清晰, batch=1
    seq_len = 3      # 3 个 token
    d_model = 4      # 每个 token 4 维
    d_k = 4          # Q/K 投影维度
    d_v = 4          # V 投影维度

    # 输入矩阵 X: 3 个 token, 每个 4 维
    X = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                       [0.0, 2.0, 0.0, 2.0],
                       [1.0, 1.0, 1.0, 1.0]])  # (3, 4)

    # 投影矩阵 (随机初始化, 这里用固定值方便复现)
    W_Q = torch.randn(d_model, d_k)
    W_K = torch.randn(d_model, d_k)
    W_V = torch.randn(d_model, d_v)

    print("=" * 60)
    print("Step 0: 输入与投影矩阵")
    print("=" * 60)
    print(f"X   shape: {tuple(X.shape)}  (3 个 token, 每个 4 维)")
    print(f"W_Q shape: {tuple(W_Q.shape)}  (将 d_model={d_model} 投影到 d_k={d_k})")
    print(f"W_K shape: {tuple(W_K.shape)}")
    print(f"W_V shape: {tuple(W_V.shape)}")
    print(f"\nX =\n{X.numpy()}")

    # ========== Step 1: 计算 Q, K, V ==========
    print("\n" + "=" * 60)
    print("Step 1: 计算 Q, K, V — 线性投影")
    print("=" * 60)

    Q = torch.matmul(X, W_Q)  # (3, 4) @ (4, 4) = (3, 4)
    K = torch.matmul(X, W_K)  # (3, 4) @ (4, 4) = (3, 4)
    V = torch.matmul(X, W_V)  # (3, 4) @ (4, 4) = (3, 4)

    print(f"Q = X @ W_Q  →  ({seq_len}, {d_model}) @ ({d_model}, {d_k}) = ({seq_len}, {d_k})")
    print(f"K = X @ W_K  →  ({seq_len}, {d_model}) @ ({d_model}, {d_k}) = ({seq_len}, {d_k})")
    print(f"V = X @ W_V  →  ({seq_len}, {d_model}) @ ({d_model}, {d_v}) = ({seq_len}, {d_v})")
    print(f"\nQ =\n{Q.detach().numpy().round(3)}")
    print(f"\nK =\n{K.detach().numpy().round(3)}")
    print(f"\nV =\n{V.detach().numpy().round(3)}")

    # ========== Step 2: 计算注意力分数 Q @ K^T ==========
    print("\n" + "=" * 60)
    print("Step 2: 计算注意力分数  scores = Q @ K^T")
    print("=" * 60)

    scores_raw = torch.matmul(Q, K.transpose(-2, -1))  # (3, 4) @ (4, 3) = (3, 3)
    print(f"scores = Q @ K^T  →  ({seq_len}, {d_k}) @ ({d_k}, {seq_len}) = ({seq_len}, {seq_len})")
    print(f"\nscores (原始点积) =\n{scores_raw.detach().numpy().round(3)}")
    print(f"\n含义: scores[i][j] = token_i 的 query 与 token_j 的 key 的点积")
    print(f"  例: scores[0][1] = Q[0] · K[1] = {scores_raw[0, 1].item():.3f}")

    # ========== Step 3: 缩放 ==========
    print("\n" + "=" * 60)
    print(f"Step 3: 缩放  scores = scores / sqrt(d_k) = scores / sqrt({d_k}) = scores / {math.sqrt(d_k):.3f}")
    print("=" * 60)

    scores_scaled = scores_raw / math.sqrt(d_k)
    print(f"\nscores (缩放后) =\n{scores_scaled.detach().numpy().round(3)}")
    print(f"\n为什么缩放? d_k={d_k}, 点积方差为 {d_k}, 除以 sqrt({d_k}) 后方差归一为 1,")
    print(f"防止 softmax 输入值过大导致梯度饱和。")

    # ========== Step 4: Softmax ==========
    print("\n" + "=" * 60)
    print("Step 4: Softmax 归一化  → 注意力权重")
    print("=" * 60)

    weights = torch.softmax(scores_scaled, dim=-1)  # (3, 3)
    print(f"weights = softmax(scores)  →  ({seq_len}, {seq_len})")
    print(f"\n注意力权重 =\n{weights.detach().numpy().round(3)}")
    print(f"\n每行之和 (应为 1.0): {weights.sum(dim=-1).detach().numpy().round(5)}")
    print(f"\n含义: weights[i][j] = token_i 分配给 token_j 的注意力比例")
    print(f"  例: token 0 对各 token 的注意力: {weights[0].detach().numpy().round(3)}")

    # ========== Step 5: 加权求和 ==========
    print("\n" + "=" * 60)
    print("Step 5: 加权求和  output = weights @ V")
    print("=" * 60)

    output = torch.matmul(weights, V)  # (3, 3) @ (3, 4) = (3, 4)
    print(f"output = weights @ V  →  ({seq_len}, {seq_len}) @ ({seq_len}, {d_v}) = ({seq_len}, {d_v})")
    print(f"\n输出 =\n{output.detach().numpy().round(3)}")
    print(f"\n含义: output[i] = 所有 token 的 value 的加权和, 权重由 token_i 的注意力决定")
    print(f"  例: output[0] = weights[0][0]*V[0] + weights[0][1]*V[1] + weights[0][2]*V[2]")

    # ========== 因果掩码演示 ==========
    print("\n" + "=" * 60)
    print("附: 因果掩码 (Causal Mask) — 解码器用")
    print("=" * 60)

    causal_mask = torch.tril(torch.ones(seq_len, seq_len))
    print(f"因果掩码 (下三角矩阵):\n{causal_mask.int().numpy()}")
    print(f"含义: token_i 只能看到 token_0 ... token_i, 不能看到未来")

    scores_masked = scores_scaled.masked_fill(causal_mask == 0, float('-inf'))
    print(f"\n掩码后的 scores:\n{scores_masked.detach().numpy()}")
    print(f"  (被遮住的位置填 -inf, softmax 后变为 0)")

    weights_masked = torch.softmax(scores_masked, dim=-1)
    print(f"\n掩码后的注意力权重:\n{weights_masked.detach().numpy().round(3)}")
    print(f"每行之和: {weights_masked.sum(dim=-1).detach().numpy().round(5)}")
