当前位置: 首页 > news >正文

基于ResNet34的花朵分类

一.数据集准备

新建一个项目文件夹ResNet,并在里面建立data_set文件夹用来保存数据集,在data_set文件夹下创建新文件夹"flower_data",点击链接下载花分类数据集https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz,会下载一个压缩包,将它解压到flower_data文件夹下,执行"split_data.py"脚本自动将数据集划分成训练集train和验证集val。

 split.py如下:

import os
from shutil import copy, rmtree
import randomdef mk_file(file_path: str):if os.path.exists(file_path):# 如果文件夹存在,则先删除原文件夹在重新创建rmtree(file_path)os.makedirs(file_path)def main():# 保证随机可复现random.seed(0)# 将数据集中10%的数据划分到验证集中split_rate = 0.1# 指向你解压后的flower_photos文件夹cwd = os.getcwd()data_root = os.path.join(cwd, "flower_data")origin_flower_path = os.path.join(data_root, "flower_photos")assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)flower_class = [cla for cla in os.listdir(origin_flower_path)if os.path.isdir(os.path.join(origin_flower_path, cla))]# 建立保存训练集的文件夹train_root = os.path.join(data_root, "train")mk_file(train_root)for cla in flower_class:# 建立每个类别对应的文件夹mk_file(os.path.join(train_root, cla))# 建立保存验证集的文件夹val_root = os.path.join(data_root, "val")mk_file(val_root)for cla in flower_class:# 建立每个类别对应的文件夹mk_file(os.path.join(val_root, cla))for cla in flower_class:cla_path = os.path.join(origin_flower_path, cla)images = os.listdir(cla_path)num = len(images)# 随机采样验证集的索引eval_index = random.sample(images, k=int(num*split_rate))for index, image in enumerate(images):if image in eval_index:# 将分配至验证集中的文件复制到相应目录image_path = os.path.join(cla_path, image)new_path = os.path.join(val_root, cla)copy(image_path, new_path)else:# 将分配至训练集中的文件复制到相应目录image_path = os.path.join(cla_path, image)new_path = os.path.join(train_root, cla)copy(image_path, new_path)print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing barprint()print("processing done!")if __name__ == '__main__':main()

之后会在文件夹下生成train和val数据集,到此,完成了数据集的准备。

 二.定义网络

新建model.py,参照ResNet的网络结构和pytorch官方给出的代码,对代码进行略微的修改即可,首先定义了两个类BasicBlock和Bottleneck,分别对应着ResNet18、34和ResNet50、101、152,从下面这个图就可以区别开来。

可见,18和34层的网络,他们的conv2_x,conv3_x,conv4_x,conv5_x是相同的,不同的是每一个block的数量([2 2 2 2]和[3 4 6 3]),50和101和152层的网络,多了1*1卷积核,block数量也不尽相同。

接着定义了ResNet类,进行前向传播。对于34层的网络(这里借用了知乎牧酱老哥的图,18和34的block相同,所以用18的进行讲解),conv2_x和conv3_x对应的残差块对应的残差快在右侧展示出来(可以注意一下stride),当计算特征图尺寸时,要特别注意。在下方代码计算尺寸的部分我都进行了注释。

pytorch官方ResNet代码

修改后的train.py:

import torch.nn as nn
import torchclass BasicBlock(nn.Module):    #18 34层残差结构, 残差块expansion = 1def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channel)self.relu = nn.ReLU()self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channel)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)    # 不为none,对应虚线残差结构(下需要1*1卷积调整维度),为none,对应实线残差结构(不需要1*1卷积)out = self.conv1(x)        out = self.bn1(out)out = self.relu(out)out = self.conv2(out)     out = self.bn2(out)out += identityout = self.relu(out)return outclass Bottleneck(nn.Module):    #50 101 152层残差结构"""注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,这么做的好处是能够在top1上提升大概0.5%的准确率。可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch"""expansion = 4def __init__(self, in_channel, out_channel, stride=1, downsample=None, groups=1, width_per_group=64):super(Bottleneck, self).__init__()width = int(out_channel * (width_per_group / 64.)) * groups# squeeze channelsself.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width, kernel_size=1, stride=1, bias=False)  self.bn1 = nn.BatchNorm2d(width)# -----------------------------------------self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,kernel_size=3, stride=stride, bias=False, padding=1)self.bn2 = nn.BatchNorm2d(width)# -----------------------------------------# unsqueeze channelsself.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion, kernel_size=1, stride=1, bias=False)  self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:  # 不为none,对应虚线残差结构(下需要1*1卷积调整维度),为none,对应实线残差结构(不需要1*1卷积)identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += identityout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, blocks_num, num_classes=1000, include_top=True, groups=1, width_per_group=64):super(ResNet, self).__init__()self.include_top = include_topself.in_channel = 64self.groups = groupsself.width_per_group = width_per_group# (channel height width)self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, padding=3, bias=False) # (3 224 224) -> (64 112 112)self.bn1 = nn.BatchNorm2d(self.in_channel)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)         # (64 112 112) -> (64 56 56)# 对于每一个block,第一次的两个卷积层stride=1和1,第二次stride=1和1self.layer1 = self._make_layer(block, 64, blocks_num[0])                # (64 56 56) -> (64 56 56)# 对于每一个block,第一次的两个卷积层stride=2和1,第二次stride=1和1self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)     # (64 56 56) -> (128 28 28)self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)     # (128 28 28) -> (256 14 14)self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)     # (256 28 28) -> (512 14 14)if self.include_top:self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')def _make_layer(self, block, channel, block_num, stride=1):           # channel为当前block所使用的卷积核个数downsample = None if stride != 1 or self.in_channel != channel * block.expansion:   # 18和32不满足判断条件,会跳过;50 101 152会执行这部分downsample = nn.Sequential(nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(channel * block.expansion))layers = []layers.append(block(self.in_channel,channel,downsample=downsample,stride=stride,groups=self.groups,width_per_group=self.width_per_group))self.in_channel = channel * block.expansionfor _ in range(1, block_num):layers.append(block(self.in_channel,channel,groups=self.groups,width_per_group=self.width_per_group))return nn.Sequential(*layers)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)if self.include_top:x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return xdef resnet34(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet34-333f7ec4.pthreturn ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet50(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet50-19c8e357.pthreturn ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet101(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet101-5d3b4d8f.pthreturn ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)def resnext50_32x4d(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pthgroups = 32width_per_group = 4return ResNet(Bottleneck, [3, 4, 6, 3],num_classes=num_classes,include_top=include_top,groups=groups,width_per_group=width_per_group)def resnext101_32x8d(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pthgroups = 32width_per_group = 8return ResNet(Bottleneck, [3, 4, 23, 3],num_classes=num_classes,include_top=include_top,groups=groups,width_per_group=width_per_group)if __name__ == "__main__":resnet = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=5)in_data = torch.randn(1, 3, 224, 224)out = resnet(in_data)print(out)

完成网络的定义之后,可以单独执行一下这个文件,用来验证网络定义的是否正确。如果可以正确输出,就没问题。

在这里输出为

tensor([[-0.4490,  0.5792, -0.5026, -0.6024,  0.1399]],
grad_fn=<AddmmBackward0>)

说明网络定义正确。

三.开始训练

 加载数据集

首先定义一个字典,用于用于对train和val进行预处理,包括裁剪成224*224大小,训练集随机水平翻转(一般验证集不需要此操作),转换成张量,图像归一化。

然后利用DataLoader模块加载数据集,并设置batch_size为16,同时,设置数据加载器的工作进程数nw,加快速度。

import os
import sys
import jsonimport torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from tqdm import tqdmfrom model import resnet34def main():device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"using {device} device.")data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}# 获取数据集路径image_path = os.path.join(os.getcwd(), "data_set", "flower_data")assert os.path.exists(image_path), f"{image_path} path does not exist."# 加载数据集,准备读取train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"), transform=data_transform["train"])validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"), transform=data_transform["val"])nw = min([os.cpu_count(), 16 if 16 > 1 else 0, 8])  # number of workers,加速图像预处理print(f'Using {nw} dataloader workers every process')# 加载数据集train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=nw)validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=16, shuffle=False, num_workers=nw)train_num = len(train_dataset)val_num = len(validate_dataset)print(f"using {train_num} images for training, {val_num} images for validation.")

生成json文件

将训练数据集的类别标签转换为字典格式,并将其写入名为'class_indices.json'的文件中。

  1. train_dataset中获取类别标签到索引的映射关系,存储在flower_list变量中。
  2. 使用列表推导式将flower_list中的键值对反转,得到一个新的字典cla_dict,其中键是原始类别标签,值是对应的索引。
  3. 使用json.dumps()函数将cla_dict转换为JSON格式的字符串,设置缩进为4个空格。
  4. 使用with open()语句以写入模式打开名为'class_indices.json'的文件,并将JSON字符串写入文件
   # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4} 雏菊 蒲公英 玫瑰 向日葵 郁金香# 从训练集中获取类别标签到索引的映射关系,存储在flower_list变量flower_list = train_dataset.class_to_idx# 使用列表推导式将flower_list中的键值对反转,得到一个新的字典cla_dictcla_dict = dict((val, key) for key, val in flower_list.items())# write dict into json filejson_str = json.dumps(cla_dict, indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)

加载预训练模型开始训练

首先定义网络对象net,在这里我们使用了迁移学习来使网络训练效果更好;使用net.fc = nn.Linear(in_channel, 5)设置输出类别数(这里为5);训练10轮,并使用train_bar = tqdm(train_loader, file=sys.stdout)来可视化训练进度条,之后再进行反向传播和参数更新;同时,每一轮训练完成都要进行学习率更新;之后开始对验证集进行计算精确度,完成后保存模型。

    # load pretrain weights# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pthnet = resnet34()model_weight_path = "./resnet34-pre.pth"assert os.path.exists(model_weight_path), f"file {model_weight_path} does not exist."net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))# change fc layer structurein_channel = net.fc.in_featuresnet.fc = nn.Linear(in_channel, 5)net.to(device)loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam([p for p in net.parameters() if p.requires_grad], lr=0.0001)epochs = 10best_acc = 0.0train_steps = len(train_loader)for epoch in range(epochs):# trainnet.train()running_loss = 0.0train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()logits = net(images.to(device))loss = loss_function(logits, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()train_bar.desc = "train epoch[{epoch + 1}/{epochs}] loss:{loss:.3f}"# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))# loss = loss_function(outputs, test_labels)predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_bar.desc = f"valid epoch[{epoch + 1}/{epochs}]"val_accurate = acc / val_numprint('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net, "./resnet.pth")print('Finished Training')

最后对代码进行整理,完整的train.py如下

import os
import sys
import jsonimport torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from tqdm import tqdmfrom model import resnet34def main():device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"using {device} device.")data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}# 获取数据集路径image_path = os.path.join(os.getcwd(), "data_set", "flower_data")assert os.path.exists(image_path), f"{image_path} path does not exist."# 加载数据集,准备读取train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"), transform=data_transform["train"])validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"), transform=data_transform["val"])nw = min([os.cpu_count(), 16 if 16 > 1 else 0, 8])  # number of workers,加速图像预处理print(f'Using {nw} dataloader workers every process')# 加载数据集train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=nw)validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=16, shuffle=False, num_workers=nw)train_num = len(train_dataset)val_num = len(validate_dataset)print(f"using {train_num} images for training, {val_num} images for validation.")# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4} 雏菊 蒲公英 玫瑰 向日葵 郁金香# 从训练集中获取类别标签到索引的映射关系,存储在flower_list变量flower_list = train_dataset.class_to_idx# 使用列表推导式将flower_list中的键值对反转,得到一个新的字典cla_dictcla_dict = dict((val, key) for key, val in flower_list.items())# write dict into json filejson_str = json.dumps(cla_dict, indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)# load pretrain weights# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pthnet = resnet34()model_weight_path = "./resnet34-pre.pth"assert os.path.exists(model_weight_path), f"file {model_weight_path} does not exist."net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))# change fc layer structurein_channel = net.fc.in_featuresnet.fc = nn.Linear(in_channel, 5)net.to(device)loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam([p for p in net.parameters() if p.requires_grad], lr=0.0001)epochs = 10best_acc = 0.0train_steps = len(train_loader)for epoch in range(epochs):# trainnet.train()running_loss = 0.0train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()logits = net(images.to(device))loss = loss_function(logits, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()train_bar.desc = "train epoch[{epoch + 1}/{epochs}] loss:{loss:.3f}"# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))# loss = loss_function(outputs, test_labels)predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_bar.desc = f"valid epoch[{epoch + 1}/{epochs}]"val_accurate = acc / val_numprint('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net, "./resnet.pth")print('Finished Training')if __name__ == '__main__':main()

四.模型预测

新建一个predict.py文件用于预测,将输入图像处理后转换成张量格式,img = torch.unsqueeze(img, dim=0)是在输入图像张量 img 的第一个维度上增加一个大小为1的维度,因此将图像张量的形状从 [通道数, 高度, 宽度 ] 转换为 [1, 通道数, 高度, 宽度]。然后加载模型进行预测,并打印出结果,同时可视化。

import os
import jsonimport torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as pltfrom model import resnet34def main():device = torch.device("cuda" if torch.cuda.is_available() else "cpu")data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# load imageimg = Image.open("./2536282942_b5ca27577e.jpg")plt.imshow(img)# [N, C, H, W]img = data_transform(img)# expand batch dimension# 在输入图像张量 img 的第一个维度上增加一个大小为1的维度# 将图像张量的形状从 [通道数, 高度, 宽度 ] 转换为 [1, 通道数, 高度, 宽度]img = torch.unsqueeze(img, dim=0)# read class_indictwith open('./class_indices.json', "r") as f:class_indict = json.load(f)# create modelmodel = resnet34(num_classes=5).to(device)model = torch.load("./resnet34.pth")# predictionmodel.eval()with torch.no_grad():# predict classoutput = torch.squeeze(model(img.to(device))).cpu()predict = torch.softmax(output, dim=0)predict_class = torch.argmax(predict).numpy()print_result = f"class: {class_indict[str(predict_class)]}   prob: {predict[predict_class].numpy():.3}"plt.title(print_result)for i in range(len(predict)):print(f"class: {class_indict[str(i)]:10}   prob: {predict[i].numpy():.3}")plt.show()if __name__ == '__main__':main()

预测结果

五.模型可视化

将生成的pth文件导入netron工具,可视化结果为

发现很不清晰,因此将它转换成多用于嵌入式设备部署的onnx格式

编写onnx.py

import torch
import torchvision
from model import resnet34device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = resnet34(num_classes=5).to(device)
model=torch.load("/home/lm/Resnet/resnet34.pth")
model.eval()
example = torch.ones(1, 3, 244, 244)
example = example.to(device)
torch.onnx.export(model, example, "resnet34.onnx", verbose=True, opset_version=11)

  将生成的onnx文件导入,这样的可视化清晰了许多

 六.批量数据预测

现在新建一个dta文件夹,里面放入五类带预测的样本,编写代码完成对整个文件夹下所有样本的预测,即批量预测。

batch_predict.py如下:

import os
import jsonimport torch
from PIL import Image
from torchvision import transformsfrom model import resnet34def main():device = torch.device("cuda" if torch.cuda.is_available() else "cpu")data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# load image# 指向需要遍历预测的图像文件夹imgs_root = "./data/imgs"# 读取指定文件夹下所有jpg图像路径img_path_list = [os.path.join(imgs_root, i) for i in os.listdir(imgs_root) if i.endswith(".jpg")]# read class_indictjson_file = open('./class_indices.json', "r")class_indict = json.load(json_file)# create modelmodel = resnet34(num_classes=5).to(device)model = torch.load("./resnet34.pth")# predictionmodel.eval()batch_size = 8  # 每次预测时将多少张图片打包成一个batchwith torch.no_grad():for ids in range(0, len(img_path_list) // batch_size):img_list = []for img_path in img_path_list[ids * batch_size: (ids + 1) * batch_size]:img = Image.open(img_path)img = data_transform(img)img_list.append(img)# batch img# 将img_list列表中的所有图像打包成一个batchbatch_img = torch.stack(img_list, dim=0)# predict classoutput = model(batch_img.to(device)).cpu()predict = torch.softmax(output, dim=1)probs, classes = torch.max(predict, dim=1)for idx, (pro, cla) in enumerate(zip(probs, classes)):print(f"image: {img_path_list[ids*batch_size+idx]}  class: {class_indict[str(cla.numpy())]}  prob: {pro.numpy():.3}")if __name__ == '__main__':main()

运行之后,输出

image: ./data/imgs/455728598_c5f3e7fc71_m.jpg  class: dandelion  prob: 0.989
image: ./data/imgs/3464015936_6845f46f64.jpg  class: dandelion  prob: 0.999
image: ./data/imgs/3461986955_29a1abc621.jpg  class: dandelion  prob: 0.996
image: ./data/imgs/8223949_2928d3f6f6_n.jpg  class: dandelion  prob: 0.991
image: ./data/imgs/10919961_0af657c4e8.jpg  class: dandelion  prob: 1.0
image: ./data/imgs/10443973_aeb97513fc_m.jpg  class: dandelion  prob: 0.906
image: ./data/imgs/8475758_4c861ab268_m.jpg  class: dandelion  prob: 0.805
image: ./data/imgs/3857059749_fe8ca621a9.jpg  class: dandelion  prob: 1.0
image: ./data/imgs/2457473644_5242844e52_m.jpg  class: dandelion  prob: 1.0
image: ./data/imgs/146023167_f905574d97_m.jpg  class: dandelion  prob: 0.998
image: ./data/imgs/2502627784_4486978bcf.jpg  class: dandelion  prob: 0.488
image: ./data/imgs/2481428401_bed64dd043.jpg  class: dandelion  prob: 1.0
image: ./data/imgs/13920113_f03e867ea7_m.jpg  class: dandelion  prob: 1.0
image: ./data/imgs/2535769822_513be6bbe9.jpg  class: dandelion  prob: 0.997
image: ./data/imgs/3954167682_128398bf79_m.jpg  class: dandelion  prob: 1.0
image: ./data/imgs/2516714633_87f28f0314.jpg  class: dandelion  prob: 0.998
image: ./data/imgs/2634665077_597910235f_m.jpg  class: dandelion  prob: 0.996
image: ./data/imgs/3502447188_ab4a5055ac_m.jpg  class: dandelion  prob: 0.999
image: ./data/imgs/425800274_27dba84fac_n.jpg  class: dandelion  prob: 0.422
image: ./data/imgs/3365850019_8158a161a8_n.jpg  class: dandelion  prob: 1.0
image: ./data/imgs/674407101_57676c40fb.jpg  class: dandelion  prob: 1.0
image: ./data/imgs/2628514700_b6d5325797_n.jpg  class: dandelion  prob: 0.999
image: ./data/imgs/3688128868_031e7b53e1_n.jpg  class: dandelion  prob: 0.962
image: ./data/imgs/2502613166_2c231b47cb_n.jpg  class: dandelion  prob: 1.0

完成预期功能(这里我的样本都是dandelion,当然混合的也可以)

七.模型改进

当不加载预训练模型,而从头开始训练的话,当epoch为50时,经实际训练,准确率为80%多,但当加载预训练模型时,完成第一次迭代准确率就已达到了90%,这也正说明了迁移学习的好处。

同时,这里采用的是Resnet34,也可以尝试更深的50、101、152层网络。

还有其他方法会在之后进行补充。

相关文章:

基于ResNet34的花朵分类

一.数据集准备 新建一个项目文件夹ResNet&#xff0c;并在里面建立data_set文件夹用来保存数据集&#xff0c;在data_set文件夹下创建新文件夹"flower_data"&#xff0c;点击链接下载花分类数据集https://storage.googleapis.com/download.tensorflow.org/example_i…...

[计算机提升] 数据及相关概念

1.9 数据及相关概念 1.9.1 数据、信息 在Windows系统中&#xff0c;数据是指事实或信息的集合&#xff0c;可以是数字、文本、图像、声音等形式的内容。数据是计算机系统中处理和操作的基本元素&#xff0c;是信息的表现形式和载体。 与信息相比&#xff0c;数据的范围更广泛…...

第18章 SpringCloud生态(二)

18.11 说说你了解的负载均衡算法 难度:★★ 重点:★★★★ 白话解析 常用的负载均衡算法有: 1、轮询(Round Robin):说白了就是让服务器排好队,一个个轮着来调用;Ribbon默认采用该算法。 优点:实现起来简单; 缺点:服务器性能不一样的情况下,导致能力强的会经常空闲…...

【Android】BRVAH多布局实现

前言 基于3.0.4版本的BRVAH框架实现的 实现方法 1.创建多个不同类型的布局&#xff08;步骤忽略&#xff09; 2.创建数据实体类 数据类要实现【MultiItemEntity】接口 class MyMultiItemEntity(//获取布局类型override var itemType: Int,var tractorRes: Int? null,va…...

AWS SAP-C02教程9-节省成本

SAP-C01变成SAP-C02的时候,最大的变化就是没有把成本单独列出一个模块,但是成本依然包含在各个其它模块之中,所以成本还是很重要的。本章将列举一些成本优化方案以及一些成本辅助功能。 目录 1 Cost Allocation Tags2 Trusted Advisor2.1 AWS Support Plans2.2 基本特性2.3…...

[CSP-S 2023] 种树 —— 二分+前缀和

This way 题意&#xff1a; 一开始以为是水题&#xff0c;敲了一个二分贪心检查的代码&#xff0c;20分。发现从根往某个节点x走的时候&#xff0c;一路走来的子树上的节点到已栽树的节点的距离会变短&#xff0c;那么并不能按照初始情况贪心。 于是就想着检查时候用线段树…...

【LeetCode周赛】LeetCode第368场周赛

目录 元素和最小的山形三元组 I元素和最小的山形三元组 II合法分组的最少组数 元素和最小的山形三元组 I 给你一个下标从 0 开始的整数数组 nums 。 如果下标三元组 (i, j, k) 满足下述全部条件&#xff0c;则认为它是一个山形三元组 &#xff1a; i < j < k nums[i] &l…...

【智慧工地源码】基于AI视觉技术赋能智慧工地

伴随着技术的不断发展&#xff0c;信息化手段、移动技术、智能穿戴及工具在工程施工阶段的应用不断提升&#xff0c;智慧工地概念应运而生&#xff0c;庞大的建设规模催生着智慧工地的探索和研发。 建筑施工具有周期长、环境复杂、工序繁杂、人员流动性大等特点&#xff0c;所以…...

云服务器搭建Hadoop分布式

文章目录 1.服务器配置2.Java环境3. 安装Hadoop4. 集群配置5. 编写集群的启动脚本 1.服务器配置 服务器主机名配置115.157.197.82s110核115.157.197.84s210核115.157.197.109s310核115.157.197.31s410核115.157.197.60gracal10核 所有的软件安装在/opt/module下&#xff0c;软…...

2678. 老人的数目

给你一个下标从 0 开始的字符串 details 。details 中每个元素都是一位乘客的信息&#xff0c;信息用长度为 15 的字符串表示&#xff0c;表示方式如下&#xff1a; 前十个字符是乘客的手机号码。 接下来的一个字符是乘客的性别。 接下来两个字符是乘客的年龄。 最后两个字符是…...

【刷题-牛客】出栈、入栈的顺序匹配 (代码+动态演示)

【刷题-牛客】出栈、入栈的顺序匹配 (代码动态演示) 文章目录 【刷题-牛客】出栈、入栈的顺序匹配 (代码动态演示) 解题思路 动图演示完整代码多组测试 &#x1f497;题目描述 &#x1f497;: 输入两个整数序列&#xff0c;第一个序列表示栈的压入顺序&#xff0c;请判断第二个…...

vscode类似GitHub Copilot的插件推荐

由于GitHub Copilot前段时间学生认证的账号掉了很多&#xff0c;某宝激活也是价格翻了几倍&#xff0c;而却&#xff0c;拿来用一天就掉线&#xff0c;可以试试同类免费的插件哦。 例如&#xff1a;TabNine&#xff0c;下载插件后&#xff0c;他会提示你登录&#xff0c;直接登…...

Html -- 文字时钟

Html – 文字时钟 文字时钟&#xff0c;之前在Android上实现了相关效果&#xff0c;闲来无事&#xff0c;弄个网页版的玩玩。。。直接上代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><titl…...

快问快答:关于线上流量卡“归属地随机”几个问题!

在网上办过流量卡的朋友应该都知道&#xff0c;资费虽然便宜&#xff0c;但是归属地却是异地&#xff0c;今天小编就给大家聊一聊关于流量卡归属地的问题。 ​ 网上的流量卡都是归属地随机的卡&#xff0c;今天小编以问答的方式给大家普及一下&#xff0c;如果对于归属地有疑问…...

Linux常用命令——clock命令

在线Linux命令查询工具 clock 用于调整 RTC 时间。 补充说明 clock命令用于调整 RTC 时间。 RTC 是电脑内建的硬件时间&#xff0c;执行这项指令可以显示现在时刻&#xff0c;调整硬件时钟的时间&#xff0c;将系统时间设成与硬件时钟之时间一致&#xff0c;或是把系统时间…...

澎湃OS上线:小米告别MIUI,跟小米汽车Say Hi

作者 | Amy 编辑 | 德新 10月17日&#xff0c;雷军发博官宣&#xff0c;「小米将启用全新操作系统&#xff0c;小米澎湃OS&#xff08;Xiaomi HyperOS&#xff09;」。 短短几百字的微博&#xff0c;数次提到了「小米汽车」&#xff1a; 小米向人车家全生态迈进&#xff0c;…...

域名不部署SSL证书有什么影响?

SSL证书是保护网站数据传输安全的重要工具&#xff0c;通过加密用户和服务器之间的通信来确保数据的保密性和完整性。然而&#xff0c;如果一个域名没有部署SSL证书&#xff0c;会对网站和用户产生一系列的负面影响。下文中将介绍域名不部署SSL证书的影响&#xff0c;并提供相应…...

Delphi 编程实现拖动排序并输出到文档

介绍&#xff1a;实现拖动排序功能&#xff0c;并将排序后的内容输出到文档中。我们将使用 Delphi 的组件来创建一个界面&#xff0c;其中包括一个 Memo 控件用于输入内容&#xff0c;一个 ListBox 控件用于显示排序后的内容&#xff0c;并且提供按钮来触发排序和输出操作。 代…...

android利用FFmpeg进行视频转换

大致思路&#xff1a;首先安装FFmpeg库到windows电脑上&#xff0c;先测试命令行工具是否可以使用&#xff08;需要先配置环境&#xff09;&#xff0c;之后再集成到android程序中。 一些命令&#xff1a; 转化为流文件&#xff1a; ffmpeg -i input.mp4 -codec copy -bsf:v …...

Python中不同进制间的转换

Python中不同进制间的转换 一、不同进制在计算机科学、数学和其他领域中具广泛的应用。以下是一些常见的应用&#xff1a;1. 二进制&#xff08;base-2&#xff09;: 在计算机系统中&#xff0c;数据以二进制形式存储和处理。二进制由0和1组成&#xff0c;是数字电子技术的基础…...

SpringBoot-17-MyBatis动态SQL标签之常用标签

文章目录 1 代码1.1 实体User.java1.2 接口UserMapper.java1.3 映射UserMapper.xml1.3.1 标签if1.3.2 标签if和where1.3.3 标签choose和when和otherwise1.4 UserController.java2 常用动态SQL标签2.1 标签set2.1.1 UserMapper.java2.1.2 UserMapper.xml2.1.3 UserController.ja…...

以下是对华为 HarmonyOS NETX 5属性动画(ArkTS)文档的结构化整理,通过层级标题、表格和代码块提升可读性:

一、属性动画概述NETX 作用&#xff1a;实现组件通用属性的渐变过渡效果&#xff0c;提升用户体验。支持属性&#xff1a;width、height、backgroundColor、opacity、scale、rotate、translate等。注意事项&#xff1a; 布局类属性&#xff08;如宽高&#xff09;变化时&#…...

Oracle查询表空间大小

1 查询数据库中所有的表空间以及表空间所占空间的大小 SELECTtablespace_name,sum( bytes ) / 1024 / 1024 FROMdba_data_files GROUP BYtablespace_name; 2 Oracle查询表空间大小及每个表所占空间的大小 SELECTtablespace_name,file_id,file_name,round( bytes / ( 1024 …...

8k长序列建模,蛋白质语言模型Prot42仅利用目标蛋白序列即可生成高亲和力结合剂

蛋白质结合剂&#xff08;如抗体、抑制肽&#xff09;在疾病诊断、成像分析及靶向药物递送等关键场景中发挥着不可替代的作用。传统上&#xff0c;高特异性蛋白质结合剂的开发高度依赖噬菌体展示、定向进化等实验技术&#xff0c;但这类方法普遍面临资源消耗巨大、研发周期冗长…...

基于服务器使用 apt 安装、配置 Nginx

&#x1f9fe; 一、查看可安装的 Nginx 版本 首先&#xff0c;你可以运行以下命令查看可用版本&#xff1a; apt-cache madison nginx-core输出示例&#xff1a; nginx-core | 1.18.0-6ubuntu14.6 | http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages ng…...

DAY 47

三、通道注意力 3.1 通道注意力的定义 # 新增&#xff1a;通道注意力模块&#xff08;SE模块&#xff09; class ChannelAttention(nn.Module):"""通道注意力模块(Squeeze-and-Excitation)"""def __init__(self, in_channels, reduction_rat…...

LLM基础1_语言模型如何处理文本

基于GitHub项目&#xff1a;https://github.com/datawhalechina/llms-from-scratch-cn 工具介绍 tiktoken&#xff1a;OpenAI开发的专业"分词器" torch&#xff1a;Facebook开发的强力计算引擎&#xff0c;相当于超级计算器 理解词嵌入&#xff1a;给词语画"…...

服务器--宝塔命令

一、宝塔面板安装命令 ⚠️ 必须使用 root 用户 或 sudo 权限执行&#xff01; sudo su - 1. CentOS 系统&#xff1a; yum install -y wget && wget -O install.sh http://download.bt.cn/install/install_6.0.sh && sh install.sh2. Ubuntu / Debian 系统…...

面向无人机海岸带生态系统监测的语义分割基准数据集

描述&#xff1a;海岸带生态系统的监测是维护生态平衡和可持续发展的重要任务。语义分割技术在遥感影像中的应用为海岸带生态系统的精准监测提供了有效手段。然而&#xff0c;目前该领域仍面临一个挑战&#xff0c;即缺乏公开的专门面向海岸带生态系统的语义分割基准数据集。受…...

免费数学几何作图web平台

光锐软件免费数学工具&#xff0c;maths,数学制图&#xff0c;数学作图&#xff0c;几何作图&#xff0c;几何&#xff0c;AR开发,AR教育,增强现实,软件公司,XR,MR,VR,虚拟仿真,虚拟现实,混合现实,教育科技产品,职业模拟培训,高保真VR场景,结构互动课件,元宇宙http://xaglare.c…...