当前位置: 首页 > news >正文

Pytorch 实现图片分类

CNN 网络适用于图片识别,卷积神经网络主要用于图片的处理识别。卷积神经网络,包括一下几部分,输入层、卷积层、池化层、全链接层和输出层。
在这里插入图片描述
使用 CIFAR-10 进行训练, CIFAR-10 中图片尺寸为 32 * 32。卷积层通过卷积核移动进行计算最终生成特征图。

在这里插入图片描述
通过池化层进行降维度
在这里插入图片描述

卷积网络结构从输入到输出, 3* 32*32 --> 10:

类型WeightBIAS
卷积(3, 12, 5)(12, 3, 5, 5)12
卷积(12, 12, 5)(12, 12, 5, 5)12
Norm1212
卷积(12, 24, 5)(24, 12, 5, 5)24
卷积(24 24, 5)(24, 24, 5, 5)24
Norm2424
Linear(10, 2400)10

训练分类模型

准备数据
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms
from torch.utils.data import DataLoader# Loading and normalizing the data.
# Define transformations for the training and test sets
transformations = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# CIFAR10 dataset consists of 50K training images. We define the batch size of 10 to load 5,000 batches of images.
batch_size = 10
number_of_labels = 10 # Create an instance for training. 
# When we run this code for the first time, the CIFAR10 train dataset will be downloaded locally. 
train_set =CIFAR10(root="./data",train=True,transform=transformations,download=True)# Create a loader for the training set which will read the data within batch size and put into memory.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
print("The number of images in a training set is: ", len(train_loader)*batch_size)# Create an instance for testing, note that train is set to False.
# When we run this code for the first time, the CIFAR10 test dataset will be downloaded locally. 
test_set = CIFAR10(root="./data", train=False, transform=transformations, download=True)# Create a loader for the test set which will read the data within batch size and put into memory. 
# Note that each shuffle is set to false for the test loader.
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)
print("The number of images in a test set is: ", len(test_loader)*batch_size)print("The number of batches per epoch is: ", len(train_loader))
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
创建网络
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F# Define a convolution neural network
class Network(nn.Module):def __init__(self):super(Network, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=1)self.bn1 = nn.BatchNorm2d(12)self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=1)self.bn2 = nn.BatchNorm2d(12)self.pool = nn.MaxPool2d(2,2)self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=1)self.bn4 = nn.BatchNorm2d(24)self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=1)self.bn5 = nn.BatchNorm2d(24)self.fc1 = nn.Linear(24*10*10, 10)def forward(self, input):output = F.relu(self.bn1(self.conv1(input)))      output = F.relu(self.bn2(self.conv2(output)))     output = self.pool(output)                        output = F.relu(self.bn4(self.conv4(output)))     output = F.relu(self.bn5(self.conv5(output)))     output = output.view(-1, 24*10*10)output = self.fc1(output)return output# Instantiate a neural network model 
model = Network()

定义损失函数

使用交叉熵函数作为损失函数,交叉熵分为两种

  • 二分类交叉熵函数
    在这里插入图片描述
  • 多分类交叉熵函数
    在这里插入图片描述
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
模型训练
from torch.autograd import Variable# Function to save the model
def saveModel():path = "./myFirstModel.pth"torch.save(model.state_dict(), path)# Function to test the model with the test dataset and print the accuracy for the test images
def testAccuracy():model.eval()accuracy = 0.0total = 0.0device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")with torch.no_grad():for data in test_loader:images, labels = data# run the model on the test set to predict labelsoutputs = model(images.to(device))# the label with the highest energy will be our prediction_, predicted = torch.max(outputs.data, 1)total += labels.size(0)accuracy += (predicted == labels.to(device)).sum().item()# compute the accuracy over all test imagesaccuracy = (100 * accuracy / total)return(accuracy)# Training function. We simply have to loop over our data iterator and feed the inputs to the network and optimize.
def train(num_epochs):best_accuracy = 0.0# Define your execution devicedevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("The model will be running on", device, "device")# Convert model parameters and buffers to CPU or Cudamodel.to(device)for epoch in range(num_epochs):  # loop over the dataset multiple timesrunning_loss = 0.0running_acc = 0.0for i, (images, labels) in enumerate(train_loader, 0):# get the inputsimages = Variable(images.to(device))labels = Variable(labels.to(device))# zero the parameter gradientsoptimizer.zero_grad()# predict classes using images from the training setoutputs = model(images)# compute the loss based on model output and real labelsloss = loss_fn(outputs, labels)# backpropagate the lossloss.backward()# adjust parameters based on the calculated gradientsoptimizer.step()# Let's print statistics for every 1,000 imagesrunning_loss += loss.item()     # extract the loss valueif i % 1000 == 999:    # print every 1000 (twice per epoch) print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 1000))# zero the lossrunning_loss = 0.0# Compute and print the average accuracy fo this epoch when tested over all 10000 test imagesaccuracy = testAccuracy()print('For epoch', epoch+1,'the test accuracy over the whole test set is %d %%' % (accuracy))# we want to save the model if the accuracy is the bestif accuracy > best_accuracy:saveModel()best_accuracy = accuracy
测试模型
import matplotlib.pyplot as plt
import numpy as np# Function to show the images
def imageshow(img):img = img / 2 + 0.5     # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()# Function to test the model with a batch of images and show the labels predictions
def testBatch():# get batch of images from the test DataLoader  images, labels = next(iter(test_loader))# show all images as one image gridimageshow(torchvision.utils.make_grid(images))# Show the real labels on the screen print('Real labels: ', ' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))# Let's see what if the model identifiers the  labels of those exampleoutputs = model(images)# We got the probability for every 10 labels. The highest (max) probability should be correct label_, predicted = torch.max(outputs, 1)# Let's show the predicted labels on the screen to compare with the real onesprint('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(batch_size)))
执行模型
if __name__ == "__main__":# Let's build our modeltrain(5)print('Finished Training')# Test which classes performed welltestAccuracy()# Let's load the model we just created and test the accuracy per labelmodel = Network()path = "myFirstModel.pth"model.load_state_dict(torch.load(path))# Test with batch of imagestestBatch()

在这里插入图片描述

总结

pytorch 搭建一个 CNN 模型比较简单,5 轮训练之后,效果就可以达到 60%,10 张图片中预测对了 6 张。

相关文章:

Pytorch 实现图片分类

CNN 网络适用于图片识别,卷积神经网络主要用于图片的处理识别。卷积神经网络,包括一下几部分,输入层、卷积层、池化层、全链接层和输出层。 使用 CIFAR-10 进行训练, CIFAR-10 中图片尺寸为 32 * 32。卷积层通过卷积核移动进行计…...

得物App获评新奖项,正品保障夯实供应链创新水平

近日,得物App再度获评新奖项——“2024上海市供应链创新与应用优秀案例”。 本次奖项为上海市供应链领域最高奖项,旨在评选出在供应链创新成效上处于领先地位、拥有成功模式和经验的企业。今年以来,得物App已接连获得“上海市质量金奖”、“科…...

【数据结构-邻项消除】力扣735. 小行星碰撞

给定一个整数数组 asteroids,表示在同一行的小行星。 对于数组中的每一个元素,其绝对值表示小行星的大小,正负表示小行星的移动方向(正表示向右移动,负表示向左移动)。每一颗小行星以相同的速度移动。 找…...

002-Kotlin界面开发之Kotlin旋风之旅

Kotlin旋风之旅 Compose Desktop中哪些Kotlin知识是必须的? 在学习Compose Desktop中,以下Kotlin知识是必须的: 基础语法:包括变量声明、数据类型、条件语句、循环等。面向对象编程:类与对象、继承、接口、抽象类等。…...

VMware Workstation Pro for Personal Use (For Windows)

这是从broadcom.com网下载的个人版本的Vmware 17.6.1,存分享不要分。 VMware-workstation-full-17.6.1-24319023.exe(447.93 MB) Build Number: 24319023 Oct 08, 2024 07.33AM SHA2: f95429e395a583eb5ba91f09b040e2f8c53a5e7aa37c4c6bfcaf82115a8…...

论文 | PROMPTAGATOR : FEW-SHOT DENSE RETRIEVAL FROM 8 EXAMPLES

1. 背景信息 在信息检索领域,传统的方法往往依赖于大量的标注数据来训练模型,以便在各种任务中表现良好。然而,许多实际应用中的监督数据是有限的,尤其是在不同的检索任务中。最近的研究开始关注如何从一个拥有丰富监督数据的任务…...

使用 Github 进行项目管理

GitHub 是一个广泛使用的代码托管和协作平台,它提供了强大的工具来支持项目管理和团队协作。在项目开发和工作中,避免不了 Github 的使用,然鹅我一直没有稍微系统地学习过 github 的整个工作流程,对这些操作都是一知半解的&#x…...

企业SRC挖掘选择与信息收集指南

内容预览 ≧∀≦ゞ 企业SRC挖掘选择与信息收集指南导语1. 企业SRC的选择2. 信息收集2.1 集团与子公司2.2 小程序与APP2.3 Web端信息收集 3. 信息收集常用模板总结 企业SRC挖掘选择与信息收集指南 导语 近年来,企业的安全响应中心(SRC)已逐渐…...

Golang | Leetcode Golang题解之第524题通过删除字母匹配到字典里最长单词

题目: 题解: func findLongestWord(s string, dictionary []string) (ans string) {m : len(s)f : make([][26]int, m1)for i : range f[m] {f[m][i] m}for i : m - 1; i > 0; i-- {f[i] f[i1]f[i][s[i]-a] i}outer:for _, t : range dictionary …...

【DBeaver】连接带kerberos的hive[Apache|HDP]

目录 一、安装配置Kerberos客户端环境 1.1 安装Kerberos客户端 1.2 环境配置 二、基于Cloudera驱动创建连接 三、基于Hive原生驱动创建连接 一、安装配置Kerberos客户端环境 1.1 安装Kerberos客户端 在Kerberos官网下载,地址如下:https://web.mit.edu/kerberos…...

Unity3D 开发教程:从入门到精通

Unity3D 开发教程:从入门到精通 Unity3D 是一款强大的跨平台游戏引擎,广泛应用于游戏开发、虚拟现实、增强现实等领域。本文将详细介绍 Unity3D 的基本概念、开发流程以及一些高级技巧,帮助你从零基础到掌握 Unity3D 开发。 目录 Unity3D…...

文件操作和 IO(一):文件基础知识 文件系统操作 => File类

目录 1. 什么是文件 1.1 概念 1.2 硬盘, 内存, 寄存器之间的区别 1.3 机械硬盘和固态硬盘 2. 文件路径 2.1 绝对路径 2.2 相对路径 3. 文件分类 4. File 类 4.1 属性 4.2 构造方法 4.3 方法 1. 什么是文件 1.1 概念 狭义上的文件: 保存在硬盘上的文件广义的上的文…...

用Pyhon写一款简单的益智类小游戏——2048

文字版——代码及讲解 代码—— import random# 初始化游戏棋盘 def init_board():return [[0] * 4 for _ in range(4)]# 在棋盘上随机生成一个2或4 def add_new_tile(board):empty_cells [(i, j) for i in range(4) for j in range(4) if board[i][j] 0]if empty_cells:i,…...

akshare股票涨跌幅自定义范围查询:A股、港股、美股

参看:https://stock.hexun.com/2024-10-31/215251914.html 涨幅计算公式:(当前价格 - 上一个交易日收盘价) 上一个交易日收盘价 100% 。 跌幅计算公式:(上一个交易日收盘价 - 当前价格) 上一个…...

通过js控制修改css变量

在JavaScript中,你可以通过操作CSS变量(也称为自定义属性)来动态改变样式。CSS变量在CSS中使用 – 前缀定义,例如 --main-color: red;。在JavaScript中,你可以使用 document.documentElement.style.setProperty 方法来…...

<HarmonyOS第一课>HarmonyOS SDK开放能力简介的课后习题

不出户&#xff0c;知天下&#xff1b; 不窥牖&#xff0c;见天道。 其出弥远&#xff0c;其知弥少。 是以圣人不行而知&#xff0c;不见而明&#xff0c;不为而成。 本篇<HarmonyOS第一课>HarmonyOS SDK开放能力简介是简单介绍了HarmonyOS SDK&#xff0c;不需要大家过多…...

深度学习:yolo的使用--图像处理

定义了一个名为 ListDataset 的类&#xff0c;它继承自 PyTorch 的 Dataset 类,这个数据集从一个包含图像文件路径的列表中读取图像和对应的标签文件 class ListDataset(Dataset):def __init__(self, list_path, img_size416, augmentTrue, multiscaleTrue, normalized_labelsT…...

TypeScript实用笔记(一):初始化、类型定义与函数使用

文章目录 一、ts初始化1. 初始化.json文件一2. 启动方式2.1 直接运行.ts文件2.2 转换运行 二、类型1. 参数类型1.1 常规参数1.2 symbol1.3 数组\[]1.4 元组\[]1.5 用字面量定义数据类型 2. Object3. 枚举类型\[Enum]3.1 数字枚举3.2 字符串枚举 三、 类型别名1. 数组别名使用2.…...

【大数据学习 | kafka】producer之拦截器,序列化器与分区器

1. 自定义拦截器 interceptor是拦截器&#xff0c;可以拦截到发送到kafka中的数据进行二次处理&#xff0c;它是producer组成部分的第一个组件。 public static class MyInterceptor implements ProducerInterceptor<String,String>{Overridepublic ProducerRecord<…...

零基础学西班牙语,柯桥专业小语种培训泓畅学校

No te comas el coco, seguro que te ha salido bien la entrevista. Ya te llamarn. 别瞎想了&#xff01;我保证你的面试很顺利。他们会给你打电话的。 这里的椰子是"头"的比喻。在西班牙的口语中&#xff0c;我们也可以听到其他同义表达&#xff0c;比如&#x…...

利用最小二乘法找圆心和半径

#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...

React Native 开发环境搭建(全平台详解)

React Native 开发环境搭建&#xff08;全平台详解&#xff09; 在开始使用 React Native 开发移动应用之前&#xff0c;正确设置开发环境是至关重要的一步。本文将为你提供一份全面的指南&#xff0c;涵盖 macOS 和 Windows 平台的配置步骤&#xff0c;如何在 Android 和 iOS…...

day52 ResNet18 CBAM

在深度学习的旅程中&#xff0c;我们不断探索如何提升模型的性能。今天&#xff0c;我将分享我在 ResNet18 模型中插入 CBAM&#xff08;Convolutional Block Attention Module&#xff09;模块&#xff0c;并采用分阶段微调策略的实践过程。通过这个过程&#xff0c;我不仅提升…...

从零实现富文本编辑器#5-编辑器选区模型的状态结构表达

先前我们总结了浏览器选区模型的交互策略&#xff0c;并且实现了基本的选区操作&#xff0c;还调研了自绘选区的实现。那么相对的&#xff0c;我们还需要设计编辑器的选区表达&#xff0c;也可以称为模型选区。编辑器中应用变更时的操作范围&#xff0c;就是以模型选区为基准来…...

镜像里切换为普通用户

如果你登录远程虚拟机默认就是 root 用户&#xff0c;但你不希望用 root 权限运行 ns-3&#xff08;这是对的&#xff0c;ns3 工具会拒绝 root&#xff09;&#xff0c;你可以按以下方法创建一个 非 root 用户账号 并切换到它运行 ns-3。 一次性解决方案&#xff1a;创建非 roo…...

ServerTrust 并非唯一

NSURLAuthenticationMethodServerTrust 只是 authenticationMethod 的冰山一角 要理解 NSURLAuthenticationMethodServerTrust, 首先要明白它只是 authenticationMethod 的选项之一, 并非唯一 1 先厘清概念 点说明authenticationMethodURLAuthenticationChallenge.protectionS…...

06 Deep learning神经网络编程基础 激活函数 --吴恩达

深度学习激活函数详解 一、核心作用 引入非线性:使神经网络可学习复杂模式控制输出范围:如Sigmoid将输出限制在(0,1)梯度传递:影响反向传播的稳定性二、常见类型及数学表达 Sigmoid σ ( x ) = 1 1 +...

pikachu靶场通关笔记22-1 SQL注入05-1-insert注入(报错法)

目录 一、SQL注入 二、insert注入 三、报错型注入 四、updatexml函数 五、源码审计 六、insert渗透实战 1、渗透准备 2、获取数据库名database 3、获取表名table 4、获取列名column 5、获取字段 本系列为通过《pikachu靶场通关笔记》的SQL注入关卡(共10关&#xff0…...

C++ Visual Studio 2017厂商给的源码没有.sln文件 易兆微芯片下载工具加开机动画下载。

1.先用Visual Studio 2017打开Yichip YC31xx loader.vcxproj&#xff0c;再用Visual Studio 2022打开。再保侟就有.sln文件了。 易兆微芯片下载工具加开机动画下载 ExtraDownloadFile1Info.\logo.bin|0|0|10D2000|0 MFC应用兼容CMD 在BOOL CYichipYC31xxloaderDlg::OnIni…...

鸿蒙DevEco Studio HarmonyOS 5跑酷小游戏实现指南

1. 项目概述 本跑酷小游戏基于鸿蒙HarmonyOS 5开发&#xff0c;使用DevEco Studio作为开发工具&#xff0c;采用Java语言实现&#xff0c;包含角色控制、障碍物生成和分数计算系统。 2. 项目结构 /src/main/java/com/example/runner/├── MainAbilitySlice.java // 主界…...