"""
手撕 Transformer - 第4步：前馈网络 (Feed-Forward Network)

FFN(x) = ReLU(x W1 + b1) W2 + b2

每个位置独立应用相同的两层全连接网络
"""

import torch
import torch.nn as nn


class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(self.relu(self.linear1(x))))


# ============ 详细计算过程演示 ============
if __name__ == "__main__":
    torch.manual_seed(42)

    batch_size = 1
    seq_len = 3
    d_model = 4
    d_ff = 8  # 通常是 d_model 的 4 倍

    x = torch.randn(batch_size, seq_len, d_model)

    ffn = FeedForward(d_model, d_ff, dropout=0.0)  # dropout=0 方便演示

    print("=" * 60)
    print("FFN 详细计算过程")
    print("=" * 60)
    print(f"输入 x shape: (batch={batch_size}, seq={seq_len}, d_model={d_model})")
    print(f"隐藏层维度: d_ff={d_ff}")
    print(f"\nx =\n{x[0].detach().numpy().round(3)}")

    # ========== Step 1: 第一层线性变换 (升维) ==========
    print("\n" + "=" * 60)
    print(f"Step 1: 第一层线性变换 (升维)  h = x @ W1 + b1")
    print("=" * 60)

    h_linear = ffn.linear1(x)
    print(f"W1 shape: ({d_model}, {d_ff})")
    print(f"b1 shape: ({d_ff},)")
    print(f"h = x @ W1 + b1  →  (1, {seq_len}, {d_model}) @ ({d_model}, {d_ff}) + ({d_ff},) = (1, {seq_len}, {d_ff})")
    print(f"\nh (线性变换后) =\n{h_linear[0].detach().numpy().round(3)}")
    print(f"\n含义: 每个 token 从 {d_model} 维升到 {d_ff} 维, 到更高维空间做特征加工")

    # ========== Step 2: ReLU 激活 ==========
    print("\n" + "=" * 60)
    print("Step 2: ReLU 激活  h = ReLU(h) = max(0, h)")
    print("=" * 60)

    h_relu = ffn.relu(h_linear)
    print(f"h (ReLU 后) =\n{h_relu[0].detach().numpy().round(3)}")

    # 统计激活比例
    total = h_relu.numel()
    active = (h_relu > 0).sum().item()
    print(f"\n激活神经元比例: {active}/{total} = {active/total:.1%}")
    print(f"ReLU 引入稀疏性: 负值被置零, 正值保留 -> 引入非线性")

    # ========== Step 3: 第二层线性变换 (降维) ==========
    print("\n" + "=" * 60)
    print(f"Step 3: 第二层线性变换 (降维)  output = h @ W2 + b2")
    print("=" * 60)

    output = ffn.linear2(h_relu)
    print(f"W2 shape: ({d_ff}, {d_model})")
    print(f"b2 shape: ({d_model},)")
    print(f"output = h @ W2 + b2  →  (1, {seq_len}, {d_ff}) @ ({d_ff}, {d_model}) + ({d_model},) = (1, {seq_len}, {d_model})")
    print(f"\noutput =\n{output[0].detach().numpy().round(3)}")
    print(f"\n含义: 从 {d_ff} 维降回 {d_model} 维, 可以与残差连接相加")

    # ========== 完整流程验证 ==========
    print("\n" + "=" * 60)
    print("验证: 使用模块的 forward 方法 (dropout=0)")
    print("=" * 60)

    output_fwd = ffn(x)
    diff = (output - output_fwd).abs().max().item()
    print(f"手动逐步 vs forward 最大差异: {diff:.2e}  (应为 0)")

    # 参数量统计
    total_params = sum(p.numel() for p in ffn.parameters())
    print(f"\n参数量统计:")
    print(f"  linear1: W1({d_model}x{d_ff}) + b1({d_ff}) = {d_model * d_ff + d_ff}")
    print(f"  linear2: W2({d_ff}x{d_model}) + b2({d_model}) = {d_ff * d_model + d_model}")
    print(f"  总计: {total_params}")
