当前位置: 首页 > news >正文

Pytorch 实现图片分类

CNN 网络适用于图片识别,卷积神经网络主要用于图片的处理识别。卷积神经网络,包括一下几部分,输入层、卷积层、池化层、全链接层和输出层。
在这里插入图片描述
使用 CIFAR-10 进行训练, CIFAR-10 中图片尺寸为 32 * 32。卷积层通过卷积核移动进行计算最终生成特征图。

在这里插入图片描述
通过池化层进行降维度
在这里插入图片描述

卷积网络结构从输入到输出, 3* 32*32 --> 10:

类型WeightBIAS
卷积(3, 12, 5)(12, 3, 5, 5)12
卷积(12, 12, 5)(12, 12, 5, 5)12
Norm1212
卷积(12, 24, 5)(24, 12, 5, 5)24
卷积(24 24, 5)(24, 24, 5, 5)24
Norm2424
Linear(10, 2400)10

训练分类模型

准备数据
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms
from torch.utils.data import DataLoader# Loading and normalizing the data.
# Define transformations for the training and test sets
transformations = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# CIFAR10 dataset consists of 50K training images. We define the batch size of 10 to load 5,000 batches of images.
batch_size = 10
number_of_labels = 10 # Create an instance for training. 
# When we run this code for the first time, the CIFAR10 train dataset will be downloaded locally. 
train_set =CIFAR10(root="./data",train=True,transform=transformations,download=True)# Create a loader for the training set which will read the data within batch size and put into memory.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
print("The number of images in a training set is: ", len(train_loader)*batch_size)# Create an instance for testing, note that train is set to False.
# When we run this code for the first time, the CIFAR10 test dataset will be downloaded locally. 
test_set = CIFAR10(root="./data", train=False, transform=transformations, download=True)# Create a loader for the test set which will read the data within batch size and put into memory. 
# Note that each shuffle is set to false for the test loader.
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)
print("The number of images in a test set is: ", len(test_loader)*batch_size)print("The number of batches per epoch is: ", len(train_loader))
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
创建网络
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F# Define a convolution neural network
class Network(nn.Module):def __init__(self):super(Network, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=1)self.bn1 = nn.BatchNorm2d(12)self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=1)self.bn2 = nn.BatchNorm2d(12)self.pool = nn.MaxPool2d(2,2)self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=1)self.bn4 = nn.BatchNorm2d(24)self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=1)self.bn5 = nn.BatchNorm2d(24)self.fc1 = nn.Linear(24*10*10, 10)def forward(self, input):output = F.relu(self.bn1(self.conv1(input)))      output = F.relu(self.bn2(self.conv2(output)))     output = self.pool(output)                        output = F.relu(self.bn4(self.conv4(output)))     output = F.relu(self.bn5(self.conv5(output)))     output = output.view(-1, 24*10*10)output = self.fc1(output)return output# Instantiate a neural network model 
model = Network()

定义损失函数

使用交叉熵函数作为损失函数,交叉熵分为两种

  • 二分类交叉熵函数
    在这里插入图片描述
  • 多分类交叉熵函数
    在这里插入图片描述
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
模型训练
from torch.autograd import Variable# Function to save the model
def saveModel():path = "./myFirstModel.pth"torch.save(model.state_dict(), path)# Function to test the model with the test dataset and print the accuracy for the test images
def testAccuracy():model.eval()accuracy = 0.0total = 0.0device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")with torch.no_grad():for data in test_loader:images, labels = data# run the model on the test set to predict labelsoutputs = model(images.to(device))# the label with the highest energy will be our prediction_, predicted = torch.max(outputs.data, 1)total += labels.size(0)accuracy += (predicted == labels.to(device)).sum().item()# compute the accuracy over all test imagesaccuracy = (100 * accuracy / total)return(accuracy)# Training function. We simply have to loop over our data iterator and feed the inputs to the network and optimize.
def train(num_epochs):best_accuracy = 0.0# Define your execution devicedevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("The model will be running on", device, "device")# Convert model parameters and buffers to CPU or Cudamodel.to(device)for epoch in range(num_epochs):  # loop over the dataset multiple timesrunning_loss = 0.0running_acc = 0.0for i, (images, labels) in enumerate(train_loader, 0):# get the inputsimages = Variable(images.to(device))labels = Variable(labels.to(device))# zero the parameter gradientsoptimizer.zero_grad()# predict classes using images from the training setoutputs = model(images)# compute the loss based on model output and real labelsloss = loss_fn(outputs, labels)# backpropagate the lossloss.backward()# adjust parameters based on the calculated gradientsoptimizer.step()# Let's print statistics for every 1,000 imagesrunning_loss += loss.item()     # extract the loss valueif i % 1000 == 999:    # print every 1000 (twice per epoch) print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 1000))# zero the lossrunning_loss = 0.0# Compute and print the average accuracy fo this epoch when tested over all 10000 test imagesaccuracy = testAccuracy()print('For epoch', epoch+1,'the test accuracy over the whole test set is %d %%' % (accuracy))# we want to save the model if the accuracy is the bestif accuracy > best_accuracy:saveModel()best_accuracy = accuracy
测试模型
import matplotlib.pyplot as plt
import numpy as np# Function to show the images
def imageshow(img):img = img / 2 + 0.5     # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()# Function to test the model with a batch of images and show the labels predictions
def testBatch():# get batch of images from the test DataLoader  images, labels = next(iter(test_loader))# show all images as one image gridimageshow(torchvision.utils.make_grid(images))# Show the real labels on the screen print('Real labels: ', ' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))# Let's see what if the model identifiers the  labels of those exampleoutputs = model(images)# We got the probability for every 10 labels. The highest (max) probability should be correct label_, predicted = torch.max(outputs, 1)# Let's show the predicted labels on the screen to compare with the real onesprint('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(batch_size)))
执行模型
if __name__ == "__main__":# Let's build our modeltrain(5)print('Finished Training')# Test which classes performed welltestAccuracy()# Let's load the model we just created and test the accuracy per labelmodel = Network()path = "myFirstModel.pth"model.load_state_dict(torch.load(path))# Test with batch of imagestestBatch()

在这里插入图片描述

总结

pytorch 搭建一个 CNN 模型比较简单,5 轮训练之后,效果就可以达到 60%,10 张图片中预测对了 6 张。

相关文章:

Pytorch 实现图片分类

CNN 网络适用于图片识别,卷积神经网络主要用于图片的处理识别。卷积神经网络,包括一下几部分,输入层、卷积层、池化层、全链接层和输出层。 使用 CIFAR-10 进行训练, CIFAR-10 中图片尺寸为 32 * 32。卷积层通过卷积核移动进行计…...

得物App获评新奖项,正品保障夯实供应链创新水平

近日,得物App再度获评新奖项——“2024上海市供应链创新与应用优秀案例”。 本次奖项为上海市供应链领域最高奖项,旨在评选出在供应链创新成效上处于领先地位、拥有成功模式和经验的企业。今年以来,得物App已接连获得“上海市质量金奖”、“科…...

【数据结构-邻项消除】力扣735. 小行星碰撞

给定一个整数数组 asteroids,表示在同一行的小行星。 对于数组中的每一个元素,其绝对值表示小行星的大小,正负表示小行星的移动方向(正表示向右移动,负表示向左移动)。每一颗小行星以相同的速度移动。 找…...

002-Kotlin界面开发之Kotlin旋风之旅

Kotlin旋风之旅 Compose Desktop中哪些Kotlin知识是必须的? 在学习Compose Desktop中,以下Kotlin知识是必须的: 基础语法:包括变量声明、数据类型、条件语句、循环等。面向对象编程:类与对象、继承、接口、抽象类等。…...

VMware Workstation Pro for Personal Use (For Windows)

这是从broadcom.com网下载的个人版本的Vmware 17.6.1,存分享不要分。 VMware-workstation-full-17.6.1-24319023.exe(447.93 MB) Build Number: 24319023 Oct 08, 2024 07.33AM SHA2: f95429e395a583eb5ba91f09b040e2f8c53a5e7aa37c4c6bfcaf82115a8…...

论文 | PROMPTAGATOR : FEW-SHOT DENSE RETRIEVAL FROM 8 EXAMPLES

1. 背景信息 在信息检索领域,传统的方法往往依赖于大量的标注数据来训练模型,以便在各种任务中表现良好。然而,许多实际应用中的监督数据是有限的,尤其是在不同的检索任务中。最近的研究开始关注如何从一个拥有丰富监督数据的任务…...

使用 Github 进行项目管理

GitHub 是一个广泛使用的代码托管和协作平台,它提供了强大的工具来支持项目管理和团队协作。在项目开发和工作中,避免不了 Github 的使用,然鹅我一直没有稍微系统地学习过 github 的整个工作流程,对这些操作都是一知半解的&#x…...

企业SRC挖掘选择与信息收集指南

内容预览 ≧∀≦ゞ 企业SRC挖掘选择与信息收集指南导语1. 企业SRC的选择2. 信息收集2.1 集团与子公司2.2 小程序与APP2.3 Web端信息收集 3. 信息收集常用模板总结 企业SRC挖掘选择与信息收集指南 导语 近年来,企业的安全响应中心(SRC)已逐渐…...

Golang | Leetcode Golang题解之第524题通过删除字母匹配到字典里最长单词

题目: 题解: func findLongestWord(s string, dictionary []string) (ans string) {m : len(s)f : make([][26]int, m1)for i : range f[m] {f[m][i] m}for i : m - 1; i > 0; i-- {f[i] f[i1]f[i][s[i]-a] i}outer:for _, t : range dictionary …...

【DBeaver】连接带kerberos的hive[Apache|HDP]

目录 一、安装配置Kerberos客户端环境 1.1 安装Kerberos客户端 1.2 环境配置 二、基于Cloudera驱动创建连接 三、基于Hive原生驱动创建连接 一、安装配置Kerberos客户端环境 1.1 安装Kerberos客户端 在Kerberos官网下载,地址如下:https://web.mit.edu/kerberos…...

Unity3D 开发教程:从入门到精通

Unity3D 开发教程:从入门到精通 Unity3D 是一款强大的跨平台游戏引擎,广泛应用于游戏开发、虚拟现实、增强现实等领域。本文将详细介绍 Unity3D 的基本概念、开发流程以及一些高级技巧,帮助你从零基础到掌握 Unity3D 开发。 目录 Unity3D…...

文件操作和 IO(一):文件基础知识 文件系统操作 => File类

目录 1. 什么是文件 1.1 概念 1.2 硬盘, 内存, 寄存器之间的区别 1.3 机械硬盘和固态硬盘 2. 文件路径 2.1 绝对路径 2.2 相对路径 3. 文件分类 4. File 类 4.1 属性 4.2 构造方法 4.3 方法 1. 什么是文件 1.1 概念 狭义上的文件: 保存在硬盘上的文件广义的上的文…...

用Pyhon写一款简单的益智类小游戏——2048

文字版——代码及讲解 代码—— import random# 初始化游戏棋盘 def init_board():return [[0] * 4 for _ in range(4)]# 在棋盘上随机生成一个2或4 def add_new_tile(board):empty_cells [(i, j) for i in range(4) for j in range(4) if board[i][j] 0]if empty_cells:i,…...

akshare股票涨跌幅自定义范围查询:A股、港股、美股

参看:https://stock.hexun.com/2024-10-31/215251914.html 涨幅计算公式:(当前价格 - 上一个交易日收盘价) 上一个交易日收盘价 100% 。 跌幅计算公式:(上一个交易日收盘价 - 当前价格) 上一个…...

通过js控制修改css变量

在JavaScript中,你可以通过操作CSS变量(也称为自定义属性)来动态改变样式。CSS变量在CSS中使用 – 前缀定义,例如 --main-color: red;。在JavaScript中,你可以使用 document.documentElement.style.setProperty 方法来…...

<HarmonyOS第一课>HarmonyOS SDK开放能力简介的课后习题

不出户&#xff0c;知天下&#xff1b; 不窥牖&#xff0c;见天道。 其出弥远&#xff0c;其知弥少。 是以圣人不行而知&#xff0c;不见而明&#xff0c;不为而成。 本篇<HarmonyOS第一课>HarmonyOS SDK开放能力简介是简单介绍了HarmonyOS SDK&#xff0c;不需要大家过多…...

深度学习:yolo的使用--图像处理

定义了一个名为 ListDataset 的类&#xff0c;它继承自 PyTorch 的 Dataset 类,这个数据集从一个包含图像文件路径的列表中读取图像和对应的标签文件 class ListDataset(Dataset):def __init__(self, list_path, img_size416, augmentTrue, multiscaleTrue, normalized_labelsT…...

TypeScript实用笔记(一):初始化、类型定义与函数使用

文章目录 一、ts初始化1. 初始化.json文件一2. 启动方式2.1 直接运行.ts文件2.2 转换运行 二、类型1. 参数类型1.1 常规参数1.2 symbol1.3 数组\[]1.4 元组\[]1.5 用字面量定义数据类型 2. Object3. 枚举类型\[Enum]3.1 数字枚举3.2 字符串枚举 三、 类型别名1. 数组别名使用2.…...

【大数据学习 | kafka】producer之拦截器,序列化器与分区器

1. 自定义拦截器 interceptor是拦截器&#xff0c;可以拦截到发送到kafka中的数据进行二次处理&#xff0c;它是producer组成部分的第一个组件。 public static class MyInterceptor implements ProducerInterceptor<String,String>{Overridepublic ProducerRecord<…...

零基础学西班牙语,柯桥专业小语种培训泓畅学校

No te comas el coco, seguro que te ha salido bien la entrevista. Ya te llamarn. 别瞎想了&#xff01;我保证你的面试很顺利。他们会给你打电话的。 这里的椰子是"头"的比喻。在西班牙的口语中&#xff0c;我们也可以听到其他同义表达&#xff0c;比如&#x…...

Linux 文件类型,目录与路径,文件与目录管理

文件类型 后面的字符表示文件类型标志 普通文件&#xff1a;-&#xff08;纯文本文件&#xff0c;二进制文件&#xff0c;数据格式文件&#xff09; 如文本文件、图片、程序文件等。 目录文件&#xff1a;d&#xff08;directory&#xff09; 用来存放其他文件或子目录。 设备…...

Vue3 + Element Plus + TypeScript中el-transfer穿梭框组件使用详解及示例

使用详解 Element Plus 的 el-transfer 组件是一个强大的穿梭框组件&#xff0c;常用于在两个集合之间进行数据转移&#xff0c;如权限分配、数据选择等场景。下面我将详细介绍其用法并提供一个完整示例。 核心特性与用法 基本属性 v-model&#xff1a;绑定右侧列表的值&…...

srs linux

下载编译运行 git clone https:///ossrs/srs.git ./configure --h265on make 编译完成后即可启动SRS # 启动 ./objs/srs -c conf/srs.conf # 查看日志 tail -n 30 -f ./objs/srs.log 开放端口 默认RTMP接收推流端口是1935&#xff0c;SRS管理页面端口是8080&#xff0c;可…...

视频字幕质量评估的大规模细粒度基准

大家读完觉得有帮助记得关注和点赞&#xff01;&#xff01;&#xff01; 摘要 视频字幕在文本到视频生成任务中起着至关重要的作用&#xff0c;因为它们的质量直接影响所生成视频的语义连贯性和视觉保真度。尽管大型视觉-语言模型&#xff08;VLMs&#xff09;在字幕生成方面…...

令牌桶 滑动窗口->限流 分布式信号量->限并发的原理 lua脚本分析介绍

文章目录 前言限流限制并发的实际理解限流令牌桶代码实现结果分析令牌桶lua的模拟实现原理总结&#xff1a; 滑动窗口代码实现结果分析lua脚本原理解析 限并发分布式信号量代码实现结果分析lua脚本实现原理 双注解去实现限流 并发结果分析&#xff1a; 实际业务去理解体会统一注…...

学校时钟系统,标准考场时钟系统,AI亮相2025高考,赛思时钟系统为教育公平筑起“精准防线”

2025年#高考 将在近日拉开帷幕&#xff0c;#AI 监考一度冲上热搜。当AI深度融入高考&#xff0c;#时间同步 不再是辅助功能&#xff0c;而是决定AI监考系统成败的“生命线”。 AI亮相2025高考&#xff0c;40种异常行为0.5秒精准识别 2025年高考即将拉开帷幕&#xff0c;江西、…...

关键领域软件测试的突围之路:如何破解安全与效率的平衡难题

在数字化浪潮席卷全球的今天&#xff0c;软件系统已成为国家关键领域的核心战斗力。不同于普通商业软件&#xff0c;这些承载着国家安全使命的软件系统面临着前所未有的质量挑战——如何在确保绝对安全的前提下&#xff0c;实现高效测试与快速迭代&#xff1f;这一命题正考验着…...

从零开始了解数据采集(二十八)——制造业数字孪生

近年来&#xff0c;我国的工业领域正经历一场前所未有的数字化变革&#xff0c;从“双碳目标”到工业互联网平台的推广&#xff0c;国家政策和市场需求共同推动了制造业的升级。在这场变革中&#xff0c;数字孪生技术成为备受关注的关键工具&#xff0c;它不仅让企业“看见”设…...

pgsql:还原数据库后出现重复序列导致“more than one owned sequence found“报错问题的解决

问题&#xff1a; pgsql数据库通过备份数据库文件进行还原时&#xff0c;如果表中有自增序列&#xff0c;还原后可能会出现重复的序列&#xff0c;此时若向表中插入新行时会出现“more than one owned sequence found”的报错提示。 点击菜单“其它”-》“序列”&#xff0c;…...

接口 RESTful 中的超媒体:REST 架构的灵魂驱动

在 RESTful 架构中&#xff0c;** 超媒体&#xff08;Hypermedia&#xff09;** 是一个核心概念&#xff0c;它体现了 REST 的 “表述性状态转移&#xff08;Representational State Transfer&#xff09;” 的本质&#xff0c;也是区分 “真 RESTful API” 与 “伪 RESTful AP…...