"""
手撕 Transformer - 第3步：LayerNorm 与残差连接

LayerNorm: 对每个样本在特征维度上归一化
残差连接: y = F(x) + x, 保证梯度直通
"""

import torch
import torch.nn as nn


class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta


# ============ 详细计算过程演示 ============
if __name__ == "__main__":
    torch.manual_seed(42)

    d_model = 4
    x = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                       [10.0, 20.0, 30.0, 40.0]])  # (2, 4)

    print("=" * 60)
    print("LayerNorm 详细计算过程")
    print("=" * 60)
    print(f"输入 x shape: {tuple(x.shape)}  (2 个样本, 每个 {d_model} 维)")
    print(f"\nx =\n{x.numpy()}")

    # ========== Step 1: 计算均值 ==========
    mean = x.mean(dim=-1, keepdim=True)
    print(f"\n--- Step 1: 计算均值 (沿最后一维) ---")
    print(f"mean = x.mean(dim=-1)  →  (2, 1)")
    print(f"  样本 0: mean = (1+2+3+4)/4 = {mean[0].item():.1f}")
    print(f"  样本 1: mean = (10+20+30+40)/4 = {mean[1].item():.1f}")
    print(f"mean =\n{mean.numpy()}")

    # ========== Step 2: 计算方差 ==========
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    print(f"\n--- Step 2: 计算方差 (沿最后一维) ---")
    print(f"var = x.var(dim=-1, unbiased=False)  →  (2, 1)")
    print(f"  样本 0: var = [(1-2.5)^2 + (2-2.5)^2 + (3-2.5)^2 + (4-2.5)^2] / 4 = {var[0].item():.3f}")
    print(f"  样本 1: var = [(10-25)^2 + (20-25)^2 + (30-25)^2 + (40-25)^2] / 4 = {var[1].item():.3f}")
    print(f"var =\n{var.numpy()}")

    # ========== Step 3: 归一化 ==========
    eps = 1e-5
    x_norm = (x - mean) / torch.sqrt(var + eps)
    print(f"\n--- Step 3: 归一化  x_norm = (x - mean) / sqrt(var + eps) ---")
    print(f"  样本 0: x_norm = (x - {mean[0].item():.1f}) / sqrt({var[0].item():.3f} + {eps})")
    print(f"x_norm =\n{x_norm.detach().numpy().round(3)}")
    print(f"\n验证: x_norm 的均值应为 0, 方差应为 1")
    print(f"  均值: {x_norm.mean(dim=-1).detach().numpy().round(5)}")
    print(f"  方差: {x_norm.var(dim=-1, unbiased=False).detach().numpy().round(5)}")

    # ========== Step 4: 仿射变换 ==========
    ln = LayerNorm(d_model)
    output = ln(x)
    print(f"\n--- Step 4: 仿射变换  output = gamma * x_norm + beta ---")
    print(f"gamma (可学习缩放): {ln.gamma.data.numpy()}")
    print(f"beta  (可学习偏移): {ln.beta.data.numpy()}")
    print(f"output =\n{output.detach().numpy().round(3)}")

    # ========== 与 PyTorch 内置对比 ==========
    print(f"\n--- 与 PyTorch nn.LayerNorm 对比 ---")
    ln_torch = nn.LayerNorm(d_model)
    ln_torch.weight.data = ln.gamma.data.clone()
    ln_torch.bias.data = ln.beta.data.clone()
    out_torch = ln_torch(x)
    diff = (output - out_torch).abs().max().item()
    print(f"手写 vs PyTorch 最大差异: {diff:.2e}")

    # ========== 残差连接演示 ==========
    print("\n" + "=" * 60)
    print("残差连接  y = LayerNorm(x + sublayer(x))")
    print("=" * 60)

    sublayer_output = torch.tensor([[0.5, -0.3, 0.1, -0.2],
                                     [1.0, -1.0, 0.5, -0.5]])  # (2, 4)
    print(f"输入 x:\n{x.numpy()}")
    print(f"\n子层输出 (如注意力或 FFN 的输出):\n{sublayer_output.numpy()}")

    x_residual = x + sublayer_output
    print(f"\n--- 残差连接: x + sublayer_output ---")
    print(f"结果:\n{x_residual.numpy()}")
    print(f"  梯度可以通过 '+' 直接流回, 不经过子层权重 -> 缓解梯度消失")

    output_ln = ln(x_residual)
    print(f"\n--- LayerNorm(x + sublayer_output) ---")
    print(f"最终输出:\n{output_ln.detach().numpy().round(3)}")
