Github项目-CNNResnet9-残差神经网络水果多分类项目
ResNet-论文全文完整翻译+注解 - 知乎
你必须要知道CNN模型:ResNet - 知乎
#!/usr/bin/env python
# coding: utf-8
#https://github.com/SehajS/cnn-resnet-fruit-classification
# # Classifying Fruits from their Images
#
# This project aims at creating a deep learning model which predicts the names of the fruits by looking at their images.
#
# The dataset is taken from kaggle and can be accessed using this link: https://www.kaggle.com/moltean/fruits
#
# A complete walkthrough from downloading the dataset to the creating the CNN-ResNet model with extensive comments has been provided. # ## Import all the requried libraries/modules# In[1]:#import opendatasets as od
import os
import shutil
import torch
from torchvision.datasets import ImageFolder
import torchvision.transforms as tt
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
get_ipython().run_line_magic('matplotlib', 'inline')# ## Downloading the dataset# In[2]:#dataset_url = "https://www.kaggle.com/moltean/fruits"
#od.download(dataset_url)# ## Cleaning the downloaded dataset# In[3]:data_direc = './datadev'
os.listdir(data_direc)# There are some files that one won't be needing in the project. Hence, one should remove them.# In[4]:#shutil.rmtree('./fruits/fruits-360/test-multiple_fruits')# In[5]:#shutil.rmtree('./fruits/fruits-360/papers')# In[6]:train_data_direc = "./datadev/train"
test_data_direc = "./datadev/test"# ## Import the Dataset using PyTorch# In[7]:print(f'The total number of labels is: {len(os.listdir(train_data_direc))}')# In[8]:dataset = ImageFolder(train_data_direc)
len(dataset)# In total, there are 67692 non-test images in our dataset.# Let us peek at one of the elements of the dataset. This gives further insights on the way data is stored.# In[9]:dataset[0]# In[10]:img, label = dataset[0]
plt.imshow(img)# One would now like to convert the images to tensors.# In[11]:dataset = ImageFolder(train_data_direc, tt.ToTensor())# In[12]:image, label = dataset[0]
plt.imshow(image.permute(1,2,0))# ## Training and Validation Sets# In[13]:val_pct = 0.1 # 10% of the images in Train folder will be used as validation set
val_size = int(len(dataset) * 0.1)
train_size = len(dataset) - val_size
val_size, train_size# In[14]:train_ds, val_ds = random_split(dataset, [train_size, val_size])# In[15]:len(train_ds), len(val_ds)# It is time to use Data Loaders to load the dataset in batches.# In[16]:batch_size = 64
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers = 4, pin_memory=True)
val_dl = DataLoader(val_ds, batch_size*2, num_workers = 4, pin_memory=True)# In[17]:def show_batch(dl):for images, labels in dl:fig, ax = plt.subplots(figsize=(12, 6))ax.set_xticks([]); ax.set_yticks([])ax.imshow(make_grid(images, nrow=16).permute(1, 2, 0))break# In[18]:show_batch(train_dl)# ## Utility Functions and Classes
#
# The creation and training of the model is done using GPU. Below are the functions that make sure that tensors and the model is using a GPU as the default device.# In[19]:def get_default_device():"""Pick GPU if available, else CPU"""if torch.cuda.is_available():return torch.device('cuda')else:return torch.device('cpu')def to_device(data, device):"""Move tensor(s) to chosen device"""if isinstance(data, (list,tuple)):return [to_device(x, device) for x in data]return data.to(device, non_blocking=True)class DeviceDataLoader():"""Wrap a dataloader to move data to a device"""def __init__(self, dl, device):self.dl = dlself.device = devicedef __iter__(self):"""Yield a batch of data after moving it to device"""for b in self.dl: yield to_device(b, self.device)def __len__(self):"""Number of batches"""return len(self.dl)# In[20]:device = get_default_device()
device# In[21]:train_dl = DeviceDataLoader(train_dl, device)
val_dl = DeviceDataLoader(val_dl, device)# ## Model and Training Utilities# In[22]:class ImageClassificationBase(nn.Module):def training_step(self, batch):images, labels = batch out = self(images) # Generate predictionsloss = F.cross_entropy(out, labels) # Calculate lossreturn lossdef validation_step(self, batch):images, labels = batch out = self(images) # Generate predictionsloss = F.cross_entropy(out, labels) # Calculate lossacc = accuracy(out, labels) # Calculate accuracyreturn {'val_loss': loss.detach(), 'val_acc': acc}def validation_epoch_end(self, outputs):batch_losses = [x['val_loss'] for x in outputs]epoch_loss = torch.stack(batch_losses).mean() # Combine lossesbatch_accs = [x['val_acc'] for x in outputs]epoch_acc = torch.stack(batch_accs).mean() # Combine accuraciesreturn {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}def epoch_end(self, epoch, result):print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result['train_loss'], result['val_loss'], result['val_acc']))def accuracy(outputs, labels):_, preds = torch.max(outputs, dim=1)return torch.tensor(torch.sum(preds == labels).item() / len(preds))# In[23]:@torch.no_grad()
def evaluate(model, val_loader):model.eval()outputs = [model.validation_step(batch) for batch in val_loader]return model.validation_epoch_end(outputs)def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):history = []optimizer = opt_func(model.parameters(), lr)for epoch in range(epochs):# Training Phase model.train()train_losses = []for batch in train_loader:loss = model.training_step(batch)train_losses.append(loss)loss.backward()optimizer.step()optimizer.zero_grad()# Validation phaseresult = evaluate(model, val_loader)result['train_loss'] = torch.stack(train_losses).mean().item()model.epoch_end(epoch, result)history.append(result)return history# In[24]:def conv_block(in_channels, out_channels, pool=False):layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)]if pool: layers.append(nn.MaxPool2d(2))return nn.Sequential(*layers)# In[25]:class ResNet9(ImageClassificationBase):def __init__(self, in_channels, num_classes):super().__init__()self.conv1 = conv_block(in_channels, 64)self.conv2 = conv_block(64, 128, pool=True)self.res1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128))self.conv3 = conv_block(128, 256, pool=True)self.conv4 = conv_block(256, 512, pool=True)self.res2 = nn.Sequential(conv_block(512, 512), conv_block(512, 512))self.classifier = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Dropout(0.2),nn.Linear(512, num_classes))def forward(self, xb):out = self.conv1(xb)out = self.conv2(out)out = self.res1(out) + outout = self.conv3(out)out = self.conv4(out)out = self.res2(out) + outout = self.classifier(out)return out# In[26]:model = to_device(ResNet9(3, len(dataset.classes)), device)
model#
# Pass one batch of input tensor through the model.
# # In[27]:torch.cuda.empty_cache()for batch in train_dl:images, labels = batchprint('images.shape: ', images.shape)print('images.device: ', images.device)preds = model(images)print('preds.shape: ', preds.shape)break# ## Training the Model# In[28]:history = [evaluate(model, val_dl)]
history# Let us train for 5 epochs with the learning rate of 0.001. Note that we use Adam as the optimizer of choice.# In[29]:history += fit(5, 0.001, model, train_dl, val_dl, torch.optim.Adam)# The accuracy achieved on teh valiation set is very high and close to 100%, therefore, one should not train the model for any more epochs. We end the training at 5 epochs.# In[ ]:def plot_accuracies(history):accuracies = [x['val_acc'] for x in history]plt.plot(accuracies, '-x')plt.xlabel('epoch')plt.ylabel('accuracy')plt.title('Accuracy vs. No. of epochs');# In[ ]:plot_accuracies(history)# In[ ]:def plot_losses(history):train_losses = [x.get('train_loss') for x in history]val_losses = [x['val_loss'] for x in history]plt.plot(train_losses, '-bx')plt.plot(val_losses, '-rx')plt.xlabel('epoch')plt.ylabel('loss')plt.legend(['Training', 'Validation'])plt.title('Loss vs. No. of epochs');# In[ ]:plot_losses(history)# ## Testing with Individual Images
#
# Now, one would like to test outthe model that we have built in previous section on the Test dataset and see how it performs.# In[ ]:def predict_image(img, model):# Convert to a batch of 1xb = to_device(img.unsqueeze(0), device)# Get predictions from modelyb = model(xb)# Pick index with highest probability_, preds = torch.max(yb, dim=1)# Retrieve the class labelreturn dataset.classes[preds[0].item()]# In[ ]:test_dataset = ImageFolder(test_data_direc, tt.ToTensor())# In[ ]:len(test_dataset)# In[ ]:def get_prediction(torch_ds, model):img, label = torch_dsplt.imshow(img.permute(1, 2, 0))print('Label:', dataset.classes[label], ', Predicted:', predict_image(img, model))# In[ ]:get_prediction(test_dataset[0], model)# In[ ]:get_prediction(test_dataset[-1], model)# In[ ]:get_prediction(test_dataset[999], model)# In[ ]:test_loader = DeviceDataLoader(DataLoader(test_dataset, batch_size*2), device)
result = evaluate(model, test_loader)
result# Therefore, the accuracy of the model on the test set is little above 98% which is great.
#
# Naturally, a curious mind would like to know for which items did the model perform the worst.# In[ ]:wrong_preds = []
for test_ds in test_dataset:img, label = test_dsprediction = predict_image(img, model)if dataset.classes[label] != prediction:wrong_preds.append([dataset.classes[label], prediction])# In[ ]:print(f'Therefore, there are in total {len(wrong_preds)} out of {len(test_dataset)} items in the test set for which the model has made a wrong prediction')# In[ ]:#len(wrong_labels)# Let us check what did our model predict for each of the wrongly predicted items. # In[ ]:checked = []
for item in wrong_preds:if item not in checked:checked.append(item)print(f'{item[0]} has been wrongfully predicted as {item[1]}')# ## Saving the Model# In[ ]:torch.save(model.state_dict(), '√SehajS-CNN-ResNet9-fruit-prediction.pth')
相关文章:
Github项目-CNNResnet9-残差神经网络水果多分类项目
ResNet-论文全文完整翻译注解 - 知乎 你必须要知道CNN模型:ResNet - 知乎 #!/usr/bin/env python # coding: utf-8 #https://github.com/SehajS/cnn-resnet-fruit-classification # # Classifying Fruits from their Images # # This project aims at creating a…...
学习感悟一己之言
学习感悟一己之言 学习上克服困难实际上是克服心理上或认识上的障碍的过程。所谓的理解,就是化陌生为熟悉。看不懂,一方面是因为接触的材料太陌生,即远离你当前的背景知识;另一方面是材料或讲述者的描述刻画不准确或晦涩不当。有了…...
【设计模式-2.3】创建型——原型模式
说明:本文介绍设计模式中,创建型中的原型模式; 飞机大战 创建型设计模式关注于对象的创建,原型模式也不例外。如简单工厂和工厂模式中提到过的飞机大战这个例子,游戏中飞机、坦克对象会创建许许多多的实例࿰…...
八大插入算法(有注释)
直接插入排序 //直接插入排序 void InsertSortingDirectly(int* nums,int numsSize){int j0;for(int i1;i<numsSize-1;i){//定义一个中间变量保存当前要插入的值int tempnums[i];//在前面已排好序的序列中,找到合适的位置插入for(ji-1;j>0;j--){if(nums[j]&g…...
【2】基于多设计模式下的同步异步日志系统
6. 相关技术知识补充 6.1 不定参函数 在初学C语⾔的时候,我们都⽤过printf函数进⾏打印。其中printf函数就是⼀个不定参函数,在函数内部可以根据格式化字符串中格式化字符分别获取不同的参数进⾏数据的格式化。 ⽽这种不定参函数在实际的使⽤中也⾮常…...
npm管理发布包-创建与发布
创建与发布 我们可以将自己开发的工具包发布到 npm 服务上,方便自己和其他开发者使用,操作步骤如下 创建文件夹,并创建文件indexjs,在文件中声明函数,使用 module.exports 暴露npm初始化工具包,package.j…...
基于Spring,SpringMVC,MyBatis的校园二手交易网站
文章目录 项目介绍主要功能截图:部分代码展示设计总结项目获取方式🍅 作者主页:超级无敌暴龙战士塔塔开 🍅 简介:Java领域优质创作者🏆、 简历模板、学习资料、面试题库【关注我,都给你】 🍅文末获取源码联系🍅 项目介绍 基于Spring,SpringMVC,MyBatis的校园二…...
酒店 KPI绩效考核指标及应用
“路遥知马力,日久见人心”,目前国内各类型酒店风起云涌,大有在市场竞争中一比高下之势,各路精英受经济型酒店低投入高回报的市场利益驱动,都分分抢占市场,从而使国内经济型酒店的数量不断增加,…...
WordPress两种方法实现上传媒体图片文件自动重命名
我们发布文章时,会上传一些图片、音频之类的文件。但是WordPress没有自动 给新上传文件重命名的功能,逐个文件去重命名那就太麻烦了,那么我们改如何自动给上传的媒体文件图片重命名呢? 我在网站搜索了些上WordPress上传媒体文件自…...
TZOJ 1405 An easy problem
翻译有些出错,但大概是那个意思 答案: #include <stdio.h> #include <ctype.h> //引用库函数isupper的头文件int main() {int T 0, i 0;scanf("%d", &T); //要输入的行数while (T--) //循环T次{char c;int y 0…...
SpringBoot+mysql+vue实现大学生健康档案管理系统前后端分离
一、项目简介 本项目是一套基于SpringBoot实现大学生健康档案管理系统,主要针对计算机相关专业的正在做bishe的学生和需要项目实战练习的Java学习者。 包含:项目源码、数据库脚本等,该项目可以直接作为bishe使用。 项目都经过严格调试&#…...
CCC联盟数字车钥匙(三)——UWB MAC时间网格同步及Hopping
本文继续上一篇UWB MAC时间网格继续介绍UWB MAC中关于时间同步相关内容。 3、MAC时间网格同步 每个测距会话的定义都基于相对的指定时钟参考 U W B t i m e 0 k UWB^k_{time0} UWBtime0k,相对于发起者的内部时钟定义。 时钟参考 U W B t i m e 0 k UWB^k_{time0} …...
一周上手 steam搬砖项目或成2024年最受欢迎副业
蒸汽砖拆除项目,兼职创业两不误,助你轻松赚钱 你是否想要找到一个既可以兼职又可以创业的项目?蒸汽砖拆除项目正逐渐崭露头角,引起了越来越多人的关注。这个项目不仅门槛低,上手快,而且不用担心卖不出去&am…...
java数据结构(哈希表—HashMap)含LeetCode例题讲解
目录 1、HashMap的基本方法 1.1、基础方法(增删改查) 1.2、其他方法 2、HashMap的相关例题 2.1、题目介绍 2.2、解题 2.2.1、解题思路 2.2.2、解题图解 2.3、解题代码 1、HashMap的基本方法 HashMap 是一个散列表,它存储的内容是键…...
快速了解ChatGPT(大语言模型)
目录 GPT原理:文字接龙,输入一个字,后面会接最有可能出现的文字。 GPT4 学会提问:发挥语言模型的最大能力 参考李宏毅老师的课快速了解大语言模型做的笔记: Lee老师幽默的开场: GPT:chat Ge…...
计算机软件的分类
以功能进行分类,计算机软件通常可以分为系统软件和应用软件两大类。 系统软件:系统软件是计算机运行和管理的基本软件,包括操作系统、驱动程序、系统工具和服务程序等。操作系统是系统软件的核心,负责管理计算机的硬件资源、提供用…...
数据库应用:Ubuntu 20.04 安装MongoDB
目录 一、理论 1.MongoDB 二、实验 1.Ubuntu 20.04 安装MongoDB 三、问题 1.Ubuntu Linux的apt 包管理器更新安装软件报错 2.Ubuntu20.04安装vim报错 3.Ubuntu20.04如何更换阿里源 4.Ubuntu22.04如何更换阿里源 一、理论 1.MongoDB (1)概念 …...
服务器配置 jupyter lab,并在本地浏览器免密登陆
一、背景 快速搭建一个jupyter lab 不用每次用ssh登录输入密码 二、步骤 方法1、临时在服务器启动 jupyter lab,并在本地浏览器免密登陆 两句命令解决 pip install jupyterlabnohup jupyter lab --ServerApp.ip"*" --ServerApp.password"" -…...
WebUI自动化学习(Selenium+Python+Pytest框架)002
新建项目 New Project 新建一个python代码文件 file-new-python file 会自动创建一个.py后缀的代码文件 注意:命名规则,包含字母、数字、下划线,不能以数字开头,不能跟python关键字或包名重复。 ********************华丽分割线********************…...
miot-plugin-sdk. npm install安装失败
miot-plugin-sdk-npm install安装失败 最紧公司要开发一台智能设备,经过同事的对比,选中了米家作为云平台,于是,我就负责开发app界面端,根据官方文档教程 下载了miot-plugin-sdk 程序,准备开始开发,结果悲…...
CTF show Web 红包题第六弹
提示 1.不是SQL注入 2.需要找关键源码 思路 进入页面发现是一个登录框,很难让人不联想到SQL注入,但提示都说了不是SQL注入,所以就不往这方面想了 先查看一下网页源码,发现一段JavaScript代码,有一个关键类ctfs…...
Java 8 Stream API 入门到实践详解
一、告别 for 循环! 传统痛点: Java 8 之前,集合操作离不开冗长的 for 循环和匿名类。例如,过滤列表中的偶数: List<Integer> list Arrays.asList(1, 2, 3, 4, 5); List<Integer> evens new ArrayList…...
关于nvm与node.js
1 安装nvm 安装过程中手动修改 nvm的安装路径, 以及修改 通过nvm安装node后正在使用的node的存放目录【这句话可能难以理解,但接着往下看你就了然了】 2 修改nvm中settings.txt文件配置 nvm安装成功后,通常在该文件中会出现以下配置&…...
最新SpringBoot+SpringCloud+Nacos微服务框架分享
文章目录 前言一、服务规划二、架构核心1.cloud的pom2.gateway的异常handler3.gateway的filter4、admin的pom5、admin的登录核心 三、code-helper分享总结 前言 最近有个活蛮赶的,根据Excel列的需求预估的工时直接打骨折,不要问我为什么,主要…...
Spring Boot+Neo4j知识图谱实战:3步搭建智能关系网络!
一、引言 在数据驱动的背景下,知识图谱凭借其高效的信息组织能力,正逐步成为各行业应用的关键技术。本文聚焦 Spring Boot与Neo4j图数据库的技术结合,探讨知识图谱开发的实现细节,帮助读者掌握该技术栈在实际项目中的落地方法。 …...
分布式增量爬虫实现方案
之前我们在讨论的是分布式爬虫如何实现增量爬取。增量爬虫的目标是只爬取新产生或发生变化的页面,避免重复抓取,以节省资源和时间。 在分布式环境下,增量爬虫的实现需要考虑多个爬虫节点之间的协调和去重。 另一种思路:将增量判…...
HDFS分布式存储 zookeeper
hadoop介绍 狭义上hadoop是指apache的一款开源软件 用java语言实现开源框架,允许使用简单的变成模型跨计算机对大型集群进行分布式处理(1.海量的数据存储 2.海量数据的计算)Hadoop核心组件 hdfs(分布式文件存储系统)&a…...
Selenium常用函数介绍
目录 一,元素定位 1.1 cssSeector 1.2 xpath 二,操作测试对象 三,窗口 3.1 案例 3.2 窗口切换 3.3 窗口大小 3.4 屏幕截图 3.5 关闭窗口 四,弹窗 五,等待 六,导航 七,文件上传 …...
jmeter聚合报告中参数详解
sample、average、min、max、90%line、95%line,99%line、Error错误率、吞吐量Thoughput、KB/sec每秒传输的数据量 sample(样本数) 表示测试中发送的请求数量,即测试执行了多少次请求。 单位,以个或者次数表示。 示例:…...
MySQL 索引底层结构揭秘:B-Tree 与 B+Tree 的区别与应用
文章目录 一、背景知识:什么是 B-Tree 和 BTree? B-Tree(平衡多路查找树) BTree(B-Tree 的变种) 二、结构对比:一张图看懂 三、为什么 MySQL InnoDB 选择 BTree? 1. 范围查询更快 2…...
