GAN生成对抗网络

概述

生成对抗网络(Generative Adversarial Network,GAN)是由Ian Goodfellow等人于2014年提出的一种深度学习架构。GAN通过两个神经网络的对抗训练来学习数据分布,能够生成高质量的合成数据。这种"对抗"的思想在机器学习领域引起了革命性的变化,被誉为"机器学习领域近10年来最有趣的想法"。

核心概念

1. 对抗训练机制

GAN的核心思想是通过两个网络的博弈来实现数据生成:

生成器(Generator)

$$ G(z): \mathcal{Z} \rightarrow \mathcal{X} $$

判别器(Discriminator)

$$ D(x): \mathcal{X} \rightarrow [0,1] $$

其中:

  • $z$ 是从先验分布 $p_z(z)$ 采样的噪声向量
  • $x$ 是真实数据样本
  • $G(z)$ 生成假数据样本
  • $D(x)$ 判断样本真假的概率

2. 目标函数

GAN的训练目标是一个极小极大博弈问题:

$$ \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] $$

判别器的目标(最大化):

  • 对真实数据输出接近1:$\max \mathbb{E}_{x \sim p_{data}}[\log D(x)]$
  • 对生成数据输出接近0:$\max \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$

生成器的目标(最小化):

  • 欺骗判别器:$\min \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$

3. 纳什均衡

理论上,当达到纳什均衡时:

$$ p_g(x) = p_{data}(x) $$

此时最优判别器为:

$$ D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)} = \frac{1}{2} $$

GAN的数学原理

1. 信息论基础

GAN的优化过程可以从信息论角度理解:

JS散度(Jensen-Shannon Divergence)

$$ JS(p_{data} \| p_g) = \frac{1}{2}KL(p_{data} \| \frac{p_{data} + p_g}{2}) + \frac{1}{2}KL(p_g \| \frac{p_{data} + p_g}{2}) $$

与目标函数的关系

$$ C(G) = \max_D V(G, D) = -\log(4) + 2 \cdot JS(p_{data} \| p_g) $$

2. 梯度计算

判别器梯度

$$ \nabla_{\theta_d} \frac{1}{m} \sum_{i=1}^{m} [\log D(x^{(i)}) + \log(1 - D(G(z^{(i)})))] $$

生成器梯度(实际训练中的变体):

$$ \nabla_{\theta_g} \frac{1}{m} \sum_{i=1}^{m} \log D(G(z^{(i)})) $$

3. 训练稳定性分析

模式崩塌(Mode Collapse): 当生成器只学会生成少数几种模式时,可以用以下指标衡量:

$$ \text{Mode Score} = \exp(\mathbb{E}_{x \sim p_g}[KL(p(y|x) \| p(y))]) $$

GAN架构变体

1. 深度卷积GAN(DCGAN)

DCGAN为GAN在图像生成领域的成功应用奠定了基础:

生成器架构原则

  • 使用转置卷积进行上采样
  • 使用批标准化(除输出层)
  • 使用ReLU激活函数(输出层使用Tanh)

判别器架构原则

  • 使用步长卷积进行下采样
  • 使用LeakyReLU激活函数
  • 不使用全连接层

2. 条件GAN(cGAN)

条件GAN引入额外信息来控制生成过程:

生成器

$$ G(z, y): \mathcal{Z} \times \mathcal{Y} \rightarrow \mathcal{X} $$

判别器

$$ D(x, y): \mathcal{X} \times \mathcal{Y} \rightarrow [0,1] $$

目标函数

$$ \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x|y)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z|y)))] $$

3. Wasserstein GAN(WGAN)

WGAN使用Wasserstein距离替代JS散度:

Wasserstein-1距离

$$ W(p_r, p_g) = \inf_{\gamma \in \Pi(p_r, p_g)} \mathbb{E}_{(x,y) \sim \gamma}[\|x - y\|] $$

对偶形式

$$ W(p_r, p_g) = \sup_{\|f\|_L \leq 1} \mathbb{E}_{x \sim p_r}[f(x)] - \mathbb{E}_{x \sim p_g}[f(x)] $$

WGAN目标函数

$$ \min_G \max_{D \in \mathcal{D}} \mathbb{E}_{x \sim p_{data}}[D(x)] - \mathbb{E}_{z \sim p_z}[D(G(z))] $$

4. WGAN-GP(Gradient Penalty)

为了满足Lipschitz约束,WGAN-GP引入梯度惩罚:

$$ L = \mathbb{E}_{x \sim p_g}[D(x)] - \mathbb{E}_{x \sim p_r}[D(x)] + \lambda \mathbb{E}_{\hat{x} \sim p_{\hat{x}}}[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2] $$

其中 $\hat{x} = \epsilon x + (1-\epsilon)G(z)$,$\epsilon \sim U[0,1]$。

实现示例

使用TensorFlow/Keras实现基础GAN

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

class GAN:
    def __init__(self, latent_dim=100, img_shape=(28, 28, 1)):
        self.latent_dim = latent_dim
        self.img_shape = img_shape
        
        # 构建判别器
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(
            optimizer=keras.optimizers.Adam(0.0002, 0.5),
            loss='binary_crossentropy',
            metrics=['accuracy']
        )
        
        # 构建生成器
        self.generator = self.build_generator()
        
        # 构建组合模型(用于训练生成器)
        z = layers.Input(shape=(self.latent_dim,))
        img = self.generator(z)
        
        # 训练生成器时,固定判别器
        self.discriminator.trainable = False
        validity = self.discriminator(img)
        
        self.combined = keras.Model(z, validity)
        self.combined.compile(
            optimizer=keras.optimizers.Adam(0.0002, 0.5),
            loss='binary_crossentropy'
        )
    
    def build_generator(self):
        model = keras.Sequential([
            layers.Dense(256, input_dim=self.latent_dim),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(momentum=0.8),
            
            layers.Dense(512),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(momentum=0.8),
            
            layers.Dense(1024),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(momentum=0.8),
            
            layers.Dense(np.prod(self.img_shape), activation='tanh'),
            layers.Reshape(self.img_shape)
        ])
        
        noise = layers.Input(shape=(self.latent_dim,))
        img = model(noise)
        
        return keras.Model(noise, img)
    
    def build_discriminator(self):
        model = keras.Sequential([
            layers.Flatten(input_shape=self.img_shape),
            layers.Dense(512),
            layers.LeakyReLU(alpha=0.2),
            layers.Dense(256),
            layers.LeakyReLU(alpha=0.2),
            layers.Dense(1, activation='sigmoid')
        ])
        
        img = layers.Input(shape=self.img_shape)
        validity = model(img)
        
        return keras.Model(img, validity)
    
    def train(self, X_train, epochs, batch_size=128, save_interval=50):
        # 标签
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        
        for epoch in range(epochs):
            # ---------------------
            #  训练判别器
            # ---------------------
            
            # 选择真实图像的随机批次
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]
            
            # 生成噪声
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            
            # 生成假图像
            gen_imgs = self.generator.predict(noise)
            
            # 训练判别器
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            
            # ---------------------
            #  训练生成器
            # ---------------------
            
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            
            # 生成器想要判别器认为生成的图像是真实的
            g_loss = self.combined.train_on_batch(noise, valid)
            
            # 打印进度
            if epoch % 100 == 0:
                print(f"{epoch} [D loss: {d_loss[0]:.4f}, acc.: {100*d_loss[1]:.2f}%] [G loss: {g_loss:.4f}]")
            
            # 保存生成的图像样本
            if epoch % save_interval == 0:
                self.save_imgs(epoch)
    
    def save_imgs(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)
        
        # 重新缩放图像到 [0, 1]
        gen_imgs = 0.5 * gen_imgs + 0.5
        
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig(f"images/mnist_{epoch}.png")
        plt.close()

# 使用示例
if __name__ == '__main__':
    # 加载和预处理MNIST数据
    (X_train, _), (_, _) = keras.datasets.mnist.load_data()
    
    # 重新缩放到 [-1, 1]
    X_train = X_train / 127.5 - 1.
    X_train = np.expand_dims(X_train, axis=3)
    
    # 创建和训练GAN
    gan = GAN()
    gan.train(X_train, epochs=30000, batch_size=32, save_interval=1000)

使用PyTorch实现DCGAN

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# 自定义权重初始化
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

class Generator(nn.Module):
    def __init__(self, nz=100, ngf=64, nc=3):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # 输入是Z,进入卷积
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

class Discriminator(nn.Module):
    def __init__(self, nc=3, ndf=64):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 输入是(nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input).view(-1, 1).squeeze(1)

class DCGAN:
    def __init__(self, dataloader, device, nz=100, lr=0.0002, beta1=0.5):
        self.dataloader = dataloader
        self.device = device
        self.nz = nz
        
        # 创建生成器和判别器
        self.netG = Generator(nz).to(device)
        self.netD = Discriminator().to(device)
        
        # 应用权重初始化
        self.netG.apply(weights_init)
        self.netD.apply(weights_init)
        
        # 定义损失函数
        self.criterion = nn.BCELoss()
        
        # 创建用于可视化的固定噪声
        self.fixed_noise = torch.randn(64, nz, 1, 1, device=device)
        
        # 设置优化器
        self.optimizerD = optim.Adam(self.netD.parameters(), lr=lr, betas=(beta1, 0.999))
        self.optimizerG = optim.Adam(self.netG.parameters(), lr=lr, betas=(beta1, 0.999))
        
        # 设置标签
        self.real_label = 1.
        self.fake_label = 0.
    
    def train(self, num_epochs):
        img_list = []
        G_losses = []
        D_losses = []
        iters = 0
        
        print("Starting Training Loop...")
        for epoch in range(num_epochs):
            for i, data in enumerate(self.dataloader, 0):
                
                ############################
                # (1) 更新判别器网络: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                ## 用全部真实数据训练
                self.netD.zero_grad()
                # 格式化批次
                real_cpu = data[0].to(self.device)
                b_size = real_cpu.size(0)
                label = torch.full((b_size,), self.real_label, 
                                 dtype=torch.float, device=self.device)
                # 前向传播真实批次
                output = self.netD(real_cpu).view(-1)
                # 计算真实批次的损失
                errD_real = self.criterion(output, label)
                # 反向传播计算梯度
                errD_real.backward()
                D_x = output.mean().item()

                ## 用全部假数据训练
                # 生成噪声向量
                noise = torch.randn(b_size, self.nz, 1, 1, device=self.device)
                # 用生成器生成假图像
                fake = self.netG(noise)
                label.fill_(self.fake_label)
                # 用判别器分类全部假批次
                output = self.netD(fake.detach()).view(-1)
                # 计算判别器在全部假批次上的损失
                errD_fake = self.criterion(output, label)
                # 计算这个批次的梯度
                errD_fake.backward()
                D_G_z1 = output.mean().item()
                # 计算判别器的总误差
                errD = errD_real + errD_fake
                # 更新判别器
                self.optimizerD.step()

                ############################
                # (2) 更新生成器网络: maximize log(D(G(z)))
                ###########################
                self.netG.zero_grad()
                label.fill_(self.real_label)  # 假标签对生成器损失来说是真实的
                # 因为我们刚刚更新了判别器,所以再次执行前向传播
                output = self.netD(fake).view(-1)
                # 基于这个输出计算生成器的损失
                errG = self.criterion(output, label)
                # 计算生成器的梯度
                errG.backward()
                D_G_z2 = output.mean().item()
                # 更新生成器
                self.optimizerG.step()

                # 输出训练状态
                if i % 50 == 0:
                    print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                          % (epoch, num_epochs, i, len(self.dataloader),
                             errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

                # 保存损失以便后续绘制
                G_losses.append(errG.item())
                D_losses.append(errD.item())

                # 检查生成器的输出
                if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(self.dataloader)-1)):
                    with torch.no_grad():
                        fake = self.netG(self.fixed_noise).detach().cpu()
                    img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

                iters += 1
        
        return img_list, G_losses, D_losses

训练技巧与优化

1. 训练稳定性改进

特征匹配(Feature Matching)

$$ \|f(x) - \mathbb{E}_{z \sim p_z}[f(G(z))]\|_2^2 $$

其中 $f(x)$ 是判别器中间层的特征。

历史平均

$$ \|\theta - \frac{1}{t} \sum_{i=1}^{t} \theta[i]\|^2 $$

单侧标签平滑: 将真实标签从1改为0.9,假标签保持0。

2. 学习率调度

# 指数衰减
def lr_schedule(epoch, initial_lr=0.0002, decay_rate=0.95, decay_steps=1000):
    return initial_lr * (decay_rate ** (epoch // decay_steps))

# 余弦退火
def cosine_annealing(epoch, T_max, eta_min=0, eta_max=0.0002):
    return eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * epoch / T_max)) / 2

3. 批次大小和架构选择

批次大小影响

  • 较大批次:更稳定的梯度,但可能降低多样性
  • 较小批次:更多随机性,有助于探索

架构设计原则

  • 使用LeakyReLU而非ReLU
  • 避免最大池化,使用步长卷积
  • 使用批标准化
  • 谨慎使用全连接层

评估指标

1. Inception Score (IS)

$$ IS(G) = \exp(\mathbb{E}_{x \sim p_g} D_{KL}(p(y|x) \| p(y))) $$

其中:

  • $p(y|x)$ 是给定生成图像的条件标签分布
  • $p(y)$ 是边际标签分布

2. Fréchet Inception Distance (FID)

$$ FID = \|\mu_r - \mu_g\|^2 + \text{Tr}(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2}) $$

其中:

  • $\mu_r, \Sigma_r$ 是真实数据在Inception网络特征空间的均值和协方差
  • $\mu_g, \Sigma_g$ 是生成数据的对应统计量

3. Precision and Recall

Precision:生成样本的质量

$$ \text{Precision} = \frac{1}{M} \sum_{j=1}^{M} \mathbf{1}[\text{NN}_k(G_j, X_r) \subseteq B(G_j, \text{NN}_k(G_j, G))] $$

Recall:生成样本的多样性

$$ \text{Recall} = \frac{1}{N} \sum_{i=1}^{N} \mathbf{1}[\text{NN}_k(X_{r,i}, G) \subseteq B(X_{r,i}, \text{NN}_k(X_{r,i}, X_r))] $$

应用领域

1. 图像生成与编辑

超分辨率GAN

class SRGAN_Generator(nn.Module):
    def __init__(self, scale_factor=4):
        super(SRGAN_Generator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.prelu = nn.PReLU()
        
        # 残差块
        self.residual_blocks = nn.Sequential(*[ResidualBlock() for _ in range(16)])
        
        # 上采样层
        self.upsampling = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor//2),
            nn.PReLU(),
            nn.Conv2d(64, 256, kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor//2),
            nn.PReLU()
        )
        
        self.conv_final = nn.Conv2d(64, 3, kernel_size=9, padding=4)
    
    def forward(self, x):
        x = self.prelu(self.conv1(x))
        residual = x
        x = self.residual_blocks(x)
        x = torch.add(x, residual)
        x = self.upsampling(x)
        x = self.conv_final(x)
        return torch.tanh(x)

风格迁移GAN

class StyleGAN_Generator(nn.Module):
    def __init__(self, z_dim=512, w_dim=512, img_resolution=1024, img_channels=3):
        super().__init__()
        self.mapping = MappingNetwork(z_dim, w_dim)
        self.synthesis = SynthesisNetwork(w_dim, img_resolution, img_channels)
    
    def forward(self, z, c=None, truncation_psi=1, truncation_cutoff=None):
        w = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
        img = self.synthesis(w)
        return img

2. 数据增强

医学图像增强

class MedicalGAN:
    def __init__(self):
        self.generator = self.build_medical_generator()
        self.discriminator = self.build_medical_discriminator()
    
    def build_medical_generator(self):
        # 专门为医学图像设计的生成器
        model = Sequential([
            Dense(256 * 8 * 8, input_dim=100),
            Reshape((8, 8, 256)),
            UpSampling2D(),
            Conv2D(128, kernel_size=3, padding="same"),
            BatchNormalization(momentum=0.8),
            Activation("relu"),
            UpSampling2D(),
            Conv2D(64, kernel_size=3, padding="same"),
            BatchNormalization(momentum=0.8),
            Activation("relu"),
            Conv2D(1, kernel_size=3, padding="same"),
            Activation("tanh")
        ])
        return model

3. 文本到图像生成

StackGAN架构: Stage I: 文本 → 低分辨率图像

$$ h_0 = F_0(z, \phi_t) $$

Stage II: 文本 + 低分辨率图像 → 高分辨率图像

$$ h_1 = F_1(h_0, \phi_t) $$

其中 $\phi_t$ 是文本特征,$F_0, F_1$ 是生成器网络。

高级GAN技术

1. Progressive GAN

Progressive GAN通过逐步增加分辨率来训练:

渐进式训练

$$ \alpha \cdot \text{old\_layer} + (1-\alpha) \cdot \text{new\_layer} $$

其中 $\alpha$ 从1逐渐减少到0。

2. Self-Attention GAN (SAGAN)

自注意力机制

$$ \text{Attention}(x) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V $$
class SelfAttention(nn.Module):
    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_dim, in_dim//8, 1)
        self.key_conv = nn.Conv2d(in_dim, in_dim//8, 1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, 1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self, x):
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)
        
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)
        
        out = self.gamma*out + x
        return out

3. BigGAN

BigGAN通过大规模训练和架构创新实现高质量生成:

类别条件批标准化

$$ BN(h, y) = \gamma_y \frac{h - \mu}{\sigma} + \beta_y $$

截断技巧

$$ z' = \text{truncate}(z, \theta) $$

其中 $\theta$ 控制截断程度。

4. StyleGAN

StyleGAN引入风格控制和高质量图像生成:

自适应实例标准化(AdaIN)

$$ \text{AdaIN}(x_i, y) = y_{s,i} \frac{x_i - \mu(x_i)}{\sigma(x_i)} + y_{b,i} $$

风格混合

$$ w' = \begin{cases} w_1 & \text{if layer} < \text{crossover\_point} \\ w_2 & \text{otherwise} \end{cases} $$

常见问题与解决方案

1. 模式崩塌(Mode Collapse)

问题描述:生成器只生成少数几种模式的样本。

解决方案

  • 使用Unrolled GAN
  • 特征匹配
  • 小批次判别(Minibatch Discrimination)

小批次判别实现

class MinibatchDiscrimination(nn.Module):
    def __init__(self, input_features, output_features, intermediate_features):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.intermediate_features = intermediate_features
        
        self.T = nn.Parameter(torch.randn(input_features, output_features, intermediate_features))
    
    def forward(self, x):
        M = torch.mm(x, self.T.view(self.input_features, -1))
        M = M.view(-1, self.output_features, self.intermediate_features)
        
        diffs = M.unsqueeze(0) - M.unsqueeze(1)
        abs_diffs = torch.sum(torch.abs(diffs), dim=3)
        minibatch_features = torch.sum(torch.exp(-abs_diffs), dim=0)
        
        return torch.cat([x, minibatch_features], dim=1)

2. 训练不稳定

梯度监控

def monitor_gradients(model, max_grad_norm=10.0):
    total_norm = 0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    total_norm = total_norm ** (1. / 2)
    
    if total_norm > max_grad_norm:
        print(f"Warning: Large gradient norm: {total_norm}")
    
    return total_norm

学习率平衡

$$ \eta_G = k \cdot \eta_D $$

通常 $k \in [0.5, 2.0]$。

3. 判别器过强

解决策略

  • 降低判别器学习率
  • 增加生成器训练频率
  • 使用噪声注入

噪声注入实现

def add_noise_to_discriminator(real_imgs, fake_imgs, noise_std=0.1):
    real_noise = torch.randn_like(real_imgs) * noise_std
    fake_noise = torch.randn_like(fake_imgs) * noise_std
    
    real_imgs_noisy = real_imgs + real_noise
    fake_imgs_noisy = fake_imgs + fake_noise
    
    return real_imgs_noisy, fake_imgs_noisy

实验设计与调优

1. 超参数搜索

import itertools
from sklearn.model_selection import ParameterGrid

param_grid = {
    'lr_g': [0.0001, 0.0002, 0.0004],
    'lr_d': [0.0001, 0.0002, 0.0004],
    'beta1': [0.0, 0.5, 0.9],
    'beta2': [0.999, 0.99, 0.9],
    'batch_size': [32, 64, 128],
    'z_dim': [100, 128, 256]
}

def hyperparameter_search(param_grid, train_data, num_epochs=100):
    best_fid = float('inf')
    best_params = None
    
    for params in ParameterGrid(param_grid):
        print(f"Testing parameters: {params}")
        
        # 创建模型
        gan = GAN(**params)
        
        # 训练
        gan.train(train_data, num_epochs)
        
        # 评估
        fid_score = calculate_fid(gan, train_data)
        
        if fid_score < best_fid:
            best_fid = fid_score
            best_params = params
            
        print(f"FID: {fid_score:.4f}")
    
    return best_params, best_fid

2. 模型验证

def evaluate_gan(generator, real_data_loader, device, num_samples=10000):
    """全面评估GAN模型"""
    generator.eval()
    
    # 生成样本
    generated_samples = []
    with torch.no_grad():
        for _ in range(num_samples // 64):  # 假设batch_size=64
            z = torch.randn(64, 100, device=device)
            fake_imgs = generator(z)
            generated_samples.append(fake_imgs.cpu())
    
    generated_samples = torch.cat(generated_samples, dim=0)
    
    # 计算各种指标
    is_score = inception_score(generated_samples)
    fid_score = calculate_fid(generated_samples, real_data_loader)
    precision, recall = precision_recall(generated_samples, real_data_loader)
    
    return {
        'inception_score': is_score,
        'fid_score': fid_score,
        'precision': precision,
        'recall': recall
    }

3. 损失分析

def analyze_training_dynamics(G_losses, D_losses):
    """分析训练动态"""
    import matplotlib.pyplot as plt
    
    # 计算移动平均
    window_size = 100
    G_smooth = np.convolve(G_losses, np.ones(window_size)/window_size, mode='valid')
    D_smooth = np.convolve(D_losses, np.ones(window_size)/window_size, mode='valid')
    
    # 绘制损失曲线
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(G_losses, alpha=0.3, label='Generator Loss')
    plt.plot(G_smooth, label='Generator Loss (Smoothed)')
    plt.legend()
    plt.title('Generator Loss Over Time')
    
    plt.subplot(1, 2, 2)
    plt.plot(D_losses, alpha=0.3, label='Discriminator Loss')
    plt.plot(D_smooth, label='Discriminator Loss (Smoothed)')
    plt.legend()
    plt.title('Discriminator Loss Over Time')
    
    plt.tight_layout()
    plt.show()
    
    # 分析收敛性
    recent_G = np.mean(G_losses[-1000:])
    recent_D = np.mean(D_losses[-1000:])
    
    print(f"Recent Generator Loss: {recent_G:.4f}")
    print(f"Recent Discriminator Loss: {recent_D:.4f}")
    print(f"Loss Ratio (G/D): {recent_G/recent_D:.4f}")
    
    # 检测模式崩塌
    G_variance = np.var(G_losses[-1000:])
    if G_variance < 0.01:
        print("Warning: Low generator loss variance - possible mode collapse")

前沿发展

1. 扩散模型 vs GAN

扩散模型优势

  • 训练更稳定
  • 生成质量更高
  • 理论基础更扎实

GAN的独特价值

  • 推理速度快
  • 实时生成能力
  • 对抗训练的哲学价值

2. 条件生成的新方向

CLIP-guided生成

$$ L_{CLIP} = -\cos(\text{CLIP}_{text}(prompt), \text{CLIP}_{image}(G(z))) $$

潜空间编辑

$$ z_{edited} = z + \alpha \cdot \text{direction} $$

3. 3D和多模态GAN

3D-aware生成

$$ I = \pi(G(z, \theta)) $$

其中 $\pi$ 是渲染函数,$\theta$ 是相机参数。

跨模态生成

$$ G: (z, c_{text}, c_{image}) \rightarrow x_{output} $$

总结

生成对抗网络自2014年诞生以来,已经成为深度学习领域最具影响力的技术之一。通过生成器和判别器的对抗训练,GAN能够学习复杂的数据分布并生成高质量的合成数据。

核心贡献

  1. 对抗训练范式:开创了通过竞争学习的新思路
  2. 无监督生成:无需大量标注数据即可学习生成
  3. 理论框架:建立了生成模型的博弈论基础
  4. 应用广泛:从图像生成到科学计算的广泛应用

技术要点

  1. 数学基础:理解极小极大博弈和纳什均衡
  2. 架构设计:掌握生成器和判别器的设计原则
  3. 训练技巧:学会处理训练不稳定和模式崩塌等问题
  4. 评估方法:熟悉IS、FID等评估指标

发展方向

  1. 理论完善:更深入的收敛性和稳定性分析
  2. 架构创新:结合Transformer、扩散模型等新技术
  3. 应用拓展:3D生成、科学计算、多模态生成
  4. 效率优化:模型压缩、加速推理、绿色AI

GAN的发展历程展现了深度学习领域的创新活力,其对抗训练的思想不仅推动了生成模型的发展,也为其他机器学习任务提供了新的视角和方法。