
Model Compression and Transfer: A Hands-On Tutorial on Knowledge Distillation

1. Introduction

        Model distillation, also known as knowledge distillation, is a technique for transferring the knowledge of a large, complex model (the teacher model) to a small, simple model (the student model). The sections below introduce model distillation, explain why it emerged, and describe what it is used for:

(1) An overview of model distillation

  1. Basic concepts

    • Teacher model: a large model that has already been trained and performs well.
    • Student model: a smaller, simpler model whose goal is to learn the teacher model's behavior and knowledge.
    • Soft labels: the probability distributions output by the teacher model, rather than plain class labels; these distributions carry rich information about how the teacher interprets the input data.
  2. Training process

    • Train the teacher model until it reaches a high accuracy.
    • Use the teacher model's outputs (soft labels) to train the student model.
    • The student model learns from both the hard labels (the true class labels) and the soft labels, thereby imitating the teacher model's behavior.

(2) Why model distillation emerged

        Model distillation arose mainly to address the following problems:

  1. Model deployment: large models are hard to run on mobile or embedded devices, where compute resources are limited.
  2. Computational efficiency: large models require substantial compute for both training and inference, making them slow and expensive.
  3. Energy consumption: large models consume a great deal of power in data centers, which conflicts with energy-saving goals.

(3) What model distillation is used for

  1. Model compression: distillation compresses a large model into a small one, reducing the parameter count and lowering storage and compute requirements.
  2. Performance retention: the student model stays small while approaching the teacher model's performance as closely as possible.
  3. Faster inference: small models run faster at inference time, which suits latency-sensitive applications.
  4. Lower energy use: small models consume fewer compute resources at runtime, helping reduce energy consumption.
  5. Cross-model transfer: distillation can also carry knowledge from a model in one domain to a model in another, enabling cross-domain learning.
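The concepts above can be condensed into a minimal PyTorch sketch of a single distillation step. The toy linear teacher/student models and random data below are purely illustrative; the full training scripts appear later in this article:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins for a trained teacher and a smaller student (illustrative only).
teacher = nn.Linear(10, 5)
student = nn.Linear(10, 5)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)

x = torch.randn(4, 10)               # a batch of 4 inputs
labels = torch.randint(0, 5, (4,))   # hard labels
T = 5.0                              # distillation temperature

with torch.no_grad():                # the teacher is frozen
    teacher_logits = teacher(x)
student_logits = student(x)

# Soft-label loss: KL divergence between temperature-softened distributions,
# scaled by T**2 to keep gradient magnitudes comparable across temperatures.
soft_loss = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                     F.softmax(teacher_logits / T, dim=1),
                     reduction="batchmean") * (T ** 2)

# Hard-label loss: ordinary cross-entropy against the true labels.
hard_loss = F.cross_entropy(student_logits, labels)

# Weighted combination, then one optimizer step on the student only.
loss = 0.5 * soft_loss + 0.5 * hard_loss
loss.backward()
optimizer.step()
```

Only the student's parameters receive gradients; the teacher serves purely as a source of soft targets.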

2. Preparing the Training Code

(1) Defining the model structure

import torch
import torch.nn as nn


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()
        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel * self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, blocks_num, num_classes=1000, include_top=True,
                 groups=1, width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64
        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample,
                            stride=stride, groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel, channel, groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
        return x


def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

(2) Training code

temperature

  • This parameter controls how much the teacher's and student's output logits are softened. In the code below, temperature is set to 5.0.
  • During distillation, the teacher and student logits are divided by the temperature before the softmax; this helps the student model capture the teacher's probability distribution more faithfully.
  • A higher temperature makes the distribution smoother, which helps the student learn; a lower temperature makes it sharper and closer to hard labels.
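The effect of the temperature is easy to see directly (made-up logits, for illustration):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.1]])

sharp = F.softmax(logits / 1.0, dim=1)  # T = 1: ordinary softmax
soft = F.softmax(logits / 5.0, dim=1)   # T = 5: flatter distribution

print(sharp)  # ≈ tensor([[0.6590, 0.2424, 0.0986]])
print(soft)   # probabilities pulled toward uniform

# A higher temperature shrinks the gap between the largest and smallest
# probabilities, exposing the relative ranking of the non-top classes.
assert sharp.max() > soft.max()
assert soft.min() > sharp.min()
```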

loss_function

  • This function computes the distillation loss. The code uses nn.KLDivLoss, the Kullback-Leibler divergence loss, which measures the difference between two probability distributions.
  • reduction='batchmean' means the summed loss is divided by the batch size.
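Note that nn.KLDivLoss expects its first argument as log-probabilities and its second as probabilities, which is a common source of bugs. A minimal check with made-up distributions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

kl = nn.KLDivLoss(reduction="batchmean")

p = torch.tensor([[0.7, 0.2, 0.1]])         # target distribution (the "teacher")
q_logits = torch.tensor([[2.0, 1.0, 0.5]])  # raw logits (the "student")

# First argument must be log-probabilities, second plain probabilities.
loss = kl(F.log_softmax(q_logits, dim=1), p)
assert loss.item() >= 0.0

# KL(p, p) is zero: identical distributions carry no extra information.
self_loss = kl(torch.log(p), p)
assert abs(self_loss.item()) < 1e-6
```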

student_loss_function

  • This function computes the student model's classification loss against the true labels. The code uses nn.CrossEntropyLoss, the standard loss for multi-class classification.

loss and student_loss

  • loss is the distillation loss, computed by comparing the softened student logits with the softened teacher logits.
  • student_loss is the student model's classification loss on the true labels.
  • The two losses are combined as a weighted average to form the final training loss; here the distillation loss and the classification loss are each weighted 0.5.

optimizer

  • This optimizer updates the student model's parameters. The code uses optim.Adam, an optimization algorithm with adaptive learning rates.
  • params is the list of student-model parameters that require gradients.
import os
import sys
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, models
from tqdm import tqdm

from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406],
                                                          [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406],
                                                        [0.229, 0.224, 0.225])])}

    # data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    # image_path = os.path.join(data_root, "data_set", "flower_data")
    image_path = "/home/trq/data/Test5_resnet/flower_data"
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size,
                                                  shuffle=False, num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    # Load teacher model
    teacher_net = resnet34(num_classes=5)
    teacher_model_weight_path = "resNet34.pth"
    assert os.path.exists(teacher_model_weight_path), \
        f"File '{teacher_model_weight_path}' does not exist."
    teacher_net.load_state_dict(torch.load(teacher_model_weight_path, map_location="cpu"),
                                strict=False)
    teacher_net.to(device)
    teacher_net.eval()  # keep the teacher frozen in inference mode

    # Load student model
    student_net = models.resnet18(pretrained=False)
    student_model_weight_path = "resnet18-f37072fd.pth"
    assert os.path.exists(student_model_weight_path), \
        "file {} does not exist.".format(student_model_weight_path)
    student_net.load_state_dict(torch.load(student_model_weight_path, map_location="cpu"))
    student_net.fc = nn.Linear(student_net.fc.in_features, 5)
    student_net.to(device)

    # Distillation loss function
    loss_function = nn.KLDivLoss(reduction='batchmean')
    student_loss_function = nn.CrossEntropyLoss()

    # Optimizer for the student model
    params = [p for p in student_net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './distilled_ConvNet.pth'
    train_steps = len(train_loader)
    temperature = 5.0  # Temperature for distillation

    for epoch in range(epochs):
        student_net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()

            with torch.no_grad():  # no gradients through the teacher
                teacher_logits = teacher_net(images.to(device))
            student_logits = student_net(images.to(device))

            # Distillation loss on temperature-softened logits, scaled by
            # temperature ** 2 to keep gradient magnitudes comparable
            loss = loss_function(
                torch.nn.functional.log_softmax(student_logits / temperature, dim=1),
                torch.nn.functional.softmax(teacher_logits / temperature, dim=1)) * (temperature ** 2)

            # Classification loss on the original (unsoftened) student logits
            student_loss = student_loss_function(student_logits, labels.to(device))

            # Combine losses
            loss = 0.5 * loss + 0.5 * student_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        # validate
        student_net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = student_net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(student_net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

(3) Download links for the models and dataset

        Includes the resnet18 and resnet34 models, class_indices.json, the images, and related data.

https://pan.baidu.com/s/1ZDCbichDcdaiAH6kxYNsIA

Access code: svv5

3. Training a Self-Built Model with Distillation

(1) Model structure - model_10.py

import torch
from torch import nn


class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        # convolutional feature extractor
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),  # 3 input channels, 32 output channels
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))  # adaptive average pooling
        # fully connected head
        self.fc_layers = nn.Sequential(
            nn.Linear(512 * 1 * 1, 1024),  # 512 channels after pooling to 1x1
            nn.ReLU(),
            nn.Linear(1024, 5)  # output layer, 5 classes
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.adaptive_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)
        return x

(2) Training the self-built model - train-10.py

import os
import sys
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm

from model_10 import ConvNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406],
                                                          [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406],
                                                        [0.229, 0.224, 0.225])])}

    image_path = "/home/trq/data/Test5_resnet/flower_data"
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size,
                                                  shuffle=False, num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    net = ConvNet()
    weights_path = "ConvNet.pth"
    assert os.path.exists(weights_path), f"File '{weights_path}' does not exist."
    # model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    state_dict = torch.load(weights_path, map_location="cpu")
    net.load_state_dict(state_dict, strict=False)
    net.to(device)

    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './ConvNet.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

  (3) Training results

        After 60 epochs of training in total, a val_accuracy of 0.780 was the best the model reached.

train epoch[1/30] loss:0.971: 100%|██████████| 207/207 [00:08<00:00, 24.01it/s]
valid epoch[1/30]: 100%|██████████| 23/23 [00:00<00:00, 31.44it/s]
[epoch 1] train_loss: 0.623  val_accuracy: 0.742
train epoch[2/30] loss:0.368: 100%|██████████| 207/207 [00:07<00:00, 26.76it/s]
valid epoch[2/30]: 100%|██████████| 23/23 [00:00<00:00, 33.18it/s]
[epoch 2] train_loss: 0.604  val_accuracy: 0.736
train epoch[3/30] loss:0.661: 100%|██████████| 207/207 [00:07<00:00, 26.76it/s]
valid epoch[3/30]: 100%|██████████| 23/23 [00:00<00:00, 32.38it/s]
[epoch 3] train_loss: 0.614  val_accuracy: 0.723
train epoch[4/30] loss:0.797: 100%|██████████| 207/207 [00:07<00:00, 26.66it/s]
valid epoch[4/30]: 100%|██████████| 23/23 [00:00<00:00, 31.70it/s]
[epoch 4] train_loss: 0.619  val_accuracy: 0.725
train epoch[5/30] loss:0.809: 100%|██████████| 207/207 [00:07<00:00, 26.87it/s]
valid epoch[5/30]: 100%|██████████| 23/23 [00:00<00:00, 32.26it/s]
[epoch 5] train_loss: 0.594  val_accuracy: 0.698
train epoch[6/30] loss:0.302: 100%|██████████| 207/207 [00:07<00:00, 26.81it/s]
valid epoch[6/30]: 100%|██████████| 23/23 [00:00<00:00, 32.49it/s]
[epoch 6] train_loss: 0.591  val_accuracy: 0.728
train epoch[7/30] loss:0.708: 100%|██████████| 207/207 [00:07<00:00, 26.60it/s]
valid epoch[7/30]: 100%|██████████| 23/23 [00:00<00:00, 33.09it/s]
[epoch 7] train_loss: 0.589  val_accuracy: 0.720
train epoch[8/30] loss:0.709: 100%|██████████| 207/207 [00:07<00:00, 26.73it/s]
valid epoch[8/30]: 100%|██████████| 23/23 [00:00<00:00, 32.55it/s]
[epoch 8] train_loss: 0.575  val_accuracy: 0.734
train epoch[9/30] loss:0.691: 100%|██████████| 207/207 [00:07<00:00, 26.61it/s]
valid epoch[9/30]: 100%|██████████| 23/23 [00:00<00:00, 34.43it/s]
[epoch 9] train_loss: 0.555  val_accuracy: 0.734
train epoch[10/30] loss:0.442: 100%|██████████| 207/207 [00:07<00:00, 26.81it/s]
valid epoch[10/30]: 100%|██████████| 23/23 [00:00<00:00, 32.91it/s]
[epoch 10] train_loss: 0.548  val_accuracy: 0.703
train epoch[11/30] loss:0.363: 100%|██████████| 207/207 [00:07<00:00, 26.46it/s]
valid epoch[11/30]: 100%|██████████| 23/23 [00:00<00:00, 30.53it/s]
[epoch 11] train_loss: 0.550  val_accuracy: 0.728
train epoch[12/30] loss:0.519: 100%|██████████| 207/207 [00:07<00:00, 26.19it/s]
valid epoch[12/30]: 100%|██████████| 23/23 [00:00<00:00, 33.14it/s]
[epoch 12] train_loss: 0.545  val_accuracy: 0.734
train epoch[13/30] loss:0.478: 100%|██████████| 207/207 [00:07<00:00, 26.48it/s]
valid epoch[13/30]: 100%|██████████| 23/23 [00:00<00:00, 32.75it/s]
[epoch 13] train_loss: 0.532  val_accuracy: 0.755
train epoch[14/30] loss:0.573: 100%|██████████| 207/207 [00:07<00:00, 26.68it/s]
valid epoch[14/30]: 100%|██████████| 23/23 [00:00<00:00, 33.40it/s]
[epoch 14] train_loss: 0.542  val_accuracy: 0.747
train epoch[15/30] loss:0.595: 100%|██████████| 207/207 [00:07<00:00, 26.68it/s]
valid epoch[15/30]: 100%|██████████| 23/23 [00:00<00:00, 34.54it/s]
[epoch 15] train_loss: 0.542  val_accuracy: 0.758
train epoch[16/30] loss:0.191: 100%|██████████| 207/207 [00:07<00:00, 26.83it/s]
valid epoch[16/30]: 100%|██████████| 23/23 [00:00<00:00, 32.04it/s]
[epoch 16] train_loss: 0.532  val_accuracy: 0.761
train epoch[17/30] loss:0.566: 100%|██████████| 207/207 [00:07<00:00, 26.60it/s]
valid epoch[17/30]: 100%|██████████| 23/23 [00:00<00:00, 33.56it/s]
[epoch 17] train_loss: 0.523  val_accuracy: 0.739
train epoch[18/30] loss:0.509: 100%|██████████| 207/207 [00:07<00:00, 26.79it/s]
valid epoch[18/30]: 100%|██████████| 23/23 [00:00<00:00, 30.35it/s]
[epoch 18] train_loss: 0.526  val_accuracy: 0.742
train epoch[19/30] loss:0.781: 100%|██████████| 207/207 [00:07<00:00, 26.60it/s]
valid epoch[19/30]: 100%|██████████| 23/23 [00:00<00:00, 31.58it/s]
[epoch 19] train_loss: 0.506  val_accuracy: 0.764
train epoch[20/30] loss:0.336: 100%|██████████| 207/207 [00:07<00:00, 26.64it/s]
valid epoch[20/30]: 100%|██████████| 23/23 [00:00<00:00, 33.95it/s]
[epoch 20] train_loss: 0.537  val_accuracy: 0.764
train epoch[21/30] loss:0.475: 100%|██████████| 207/207 [00:07<00:00, 26.65it/s]
valid epoch[21/30]: 100%|██████████| 23/23 [00:00<00:00, 33.27it/s]
[epoch 21] train_loss: 0.511  val_accuracy: 0.764
train epoch[22/30] loss:0.513: 100%|██████████| 207/207 [00:07<00:00, 26.53it/s]
valid epoch[22/30]: 100%|██████████| 23/23 [00:00<00:00, 32.16it/s]
[epoch 22] train_loss: 0.482  val_accuracy: 0.761
train epoch[23/30] loss:0.172: 100%|██████████| 207/207 [00:07<00:00, 26.62it/s]
valid epoch[23/30]: 100%|██████████| 23/23 [00:00<00:00, 33.02it/s]
[epoch 23] train_loss: 0.501  val_accuracy: 0.761
train epoch[24/30] loss:1.127: 100%|██████████| 207/207 [00:07<00:00, 26.54it/s]
valid epoch[24/30]: 100%|██████████| 23/23 [00:00<00:00, 34.24it/s]
[epoch 24] train_loss: 0.492  val_accuracy: 0.755
train epoch[25/30] loss:0.905: 100%|██████████| 207/207 [00:07<00:00, 26.76it/s]
valid epoch[25/30]: 100%|██████████| 23/23 [00:00<00:00, 30.22it/s]
[epoch 25] train_loss: 0.492  val_accuracy: 0.758
train epoch[26/30] loss:1.044: 100%|██████████| 207/207 [00:07<00:00, 26.75it/s]
valid epoch[26/30]: 100%|██████████| 23/23 [00:00<00:00, 33.86it/s]
[epoch 26] train_loss: 0.476  val_accuracy: 0.777
train epoch[27/30] loss:0.552: 100%|██████████| 207/207 [00:07<00:00, 26.73it/s]
valid epoch[27/30]: 100%|██████████| 23/23 [00:00<00:00, 31.55it/s]
[epoch 27] train_loss: 0.465  val_accuracy: 0.745
train epoch[28/30] loss:0.387: 100%|██████████| 207/207 [00:07<00:00, 26.68it/s]
valid epoch[28/30]: 100%|██████████| 23/23 [00:00<00:00, 32.30it/s]
[epoch 28] train_loss: 0.482  val_accuracy: 0.769
train epoch[29/30] loss:0.251: 100%|██████████| 207/207 [00:07<00:00, 26.69it/s]
valid epoch[29/30]: 100%|██████████| 23/23 [00:00<00:00, 32.98it/s]
[epoch 29] train_loss: 0.466  val_accuracy: 0.777
train epoch[30/30] loss:0.368: 100%|██████████| 207/207 [00:07<00:00, 26.57it/s]
valid epoch[30/30]: 100%|██████████| 23/23 [00:00<00:00, 31.95it/s]
[epoch 30] train_loss: 0.467  val_accuracy: 0.780
Finished Training

(4) Distillation training

import os
import sys
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm

from model import resnet34
from model_10 import ConvNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406],
                                                          [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406],
                                                        [0.229, 0.224, 0.225])])}

    # data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    # image_path = os.path.join(data_root, "data_set", "flower_data")
    image_path = "/home/trq/data/Test5_resnet/flower_data"
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size,
                                                  shuffle=False, num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    # Load teacher model
    teacher_net = resnet34(num_classes=5)
    teacher_model_weight_path = "resNet34.pth"
    assert os.path.exists(teacher_model_weight_path), \
        f"File '{teacher_model_weight_path}' does not exist."
    teacher_net.load_state_dict(torch.load(teacher_model_weight_path, map_location="cpu"),
                                strict=False)
    teacher_net.to(device)
    teacher_net.eval()  # keep the teacher frozen in inference mode

    # Load student model
    student_net = ConvNet()
    student_model_weight_path = "ConvNet.pth"
    assert os.path.exists(student_model_weight_path), \
        "file {} does not exist.".format(student_model_weight_path)
    student_net.load_state_dict(torch.load(student_model_weight_path, map_location="cpu"))
    student_net.to(device)

    # Distillation loss function
    loss_function = nn.KLDivLoss(reduction='batchmean')
    student_loss_function = nn.CrossEntropyLoss()

    # Optimizer for the student model
    params = [p for p in student_net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './distilled_ConvNet.pth'
    train_steps = len(train_loader)
    temperature = 5.0  # Temperature for distillation

    for epoch in range(epochs):
        student_net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()

            with torch.no_grad():  # no gradients through the teacher
                teacher_logits = teacher_net(images.to(device))
            student_logits = student_net(images.to(device))

            # Distillation loss on temperature-softened logits, scaled by
            # temperature ** 2 to keep gradient magnitudes comparable
            loss = loss_function(
                torch.nn.functional.log_softmax(student_logits / temperature, dim=1),
                torch.nn.functional.softmax(teacher_logits / temperature, dim=1)) * (temperature ** 2)

            # Classification loss on the original (unsoftened) student logits
            student_loss = student_loss_function(student_logits, labels.to(device))

            # Combine losses
            loss = 0.5 * loss + 0.5 * student_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        # validate
        student_net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = student_net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(student_net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

        No screenshot was saved, so try it yourself: in my test, training the self-built model for 30 epochs and then continuing with 30 epochs of distillation training raised val_accuracy to 0.81.

(5) Model files

https://pan.baidu.com/s/1gVTJPvAQ3oDEZcGYoJvuLw

Access code: ddk5

4. Summary

        If a model's structure is simple, distillation training can be used to improve its accuracy; of course, a teacher model has to be trained first.
