介绍

神经网络除了最后一部分,其余部分都是在做特征提取,如果目标数据集属于原数据集,那么训练好的模型就可以直接拿来用了.

实际上Chatgpt就是属于这种(一句话总结:让你的模型站在巨人的肩膀上……)

假设使用数据集已经训练出了一个模型,那这个对目标数据集仍有很好的适用性(最后一层因为label的原因,往往发生改变)

图片 1

微调的方式

训练

  • 是一个目标数据集上的正常训练任务,但使用更强的正则化
    • 使用更小的学习率
    • 使用更少的数据迭代
  • 源数据集远复杂于目标数据,通常微调效果更好

重用分类器权重

![](https://cdn.jsdelivr.net/gh/vanillaholic/image-bed@main/img/截屏2025-03-04 16.10.07.png)

固定一些层

  • 神经网络通常学习有层次的特征表示
    • 低层次的特征更加通用
    • 高层次的特征则更跟数据集相关
  • 可以固定底部一些层的参数,不参与
    更新
    • 更强的正则

代码

这里的代码比较复杂,所以没有全部放入文章中建议认真看

https://courses.d2l.ai/zh-v2/assets/notebooks/chapter_computer-vision/fine-tuning.slides.html

python
1
2
3
pretrained_net = torchvision.models.resnet18(pretrained=True)

pretrained_net.fc

Linear(in_features=512, out_features=1000, bias=True)

python
1
2
3
finetune_net = torchvision.models.resnet18(pretrained=True)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2) #输出种类变化,改成2
nn.init.xavier_uniform_(finetune_net.fc.weight);

微调模型:

python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,
param_group=True):
train_iter = torch.utils.data.DataLoader(
torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'),
transform=train_augs),
batch_size=batch_size, shuffle=True)
test_iter = torch.utils.data.DataLoader(
torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'),
transform=test_augs),
batch_size=batch_size)
devices = d2l.try_all_gpus()
loss = nn.CrossEntropyLoss(reduction="none")
if param_group: #此处发生变化:如果这个参数存在
params_1x = [
param for name, param in net.named_parameters()
if name not in ["fc.weight", "fc.bias"]]
trainer = torch.optim.SGD([{ #只有最后一层学习率发生变化
'params': params_1x}, {
'params': net.fc.parameters(),
'lr': learning_rate * 10}], lr=learning_rate,
weight_decay=0.001)
else: #如果没有True,则跟之前一样
trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,
weight_decay=0.001)
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
devices)

是否使用预训练的比较

截屏2025-03-04 16.27.13

总结

  • 微调通过使用在大数据上得到的预训练好的模型来初始化模型权重来完成提升精度
  • 预训练模型质量很重要
  • 微调通常速度更快、精度更