当前位置: 首页 > news >正文

Pytorch 实现图片分类

CNN 网络适用于图片识别,卷积神经网络主要用于图片的处理识别。卷积神经网络,包括一下几部分,输入层、卷积层、池化层、全链接层和输出层。
在这里插入图片描述
使用 CIFAR-10 进行训练, CIFAR-10 中图片尺寸为 32 * 32。卷积层通过卷积核移动进行计算最终生成特征图。

在这里插入图片描述
通过池化层进行降维度
在这里插入图片描述

卷积网络结构从输入到输出, 3* 32*32 --> 10:

类型WeightBIAS
卷积(3, 12, 5)(12, 3, 5, 5)12
卷积(12, 12, 5)(12, 12, 5, 5)12
Norm1212
卷积(12, 24, 5)(24, 12, 5, 5)24
卷积(24 24, 5)(24, 24, 5, 5)24
Norm2424
Linear(10, 2400)10

训练分类模型

准备数据
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms
from torch.utils.data import DataLoader# Loading and normalizing the data.
# Define transformations for the training and test sets
transformations = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# CIFAR10 dataset consists of 50K training images. We define the batch size of 10 to load 5,000 batches of images.
batch_size = 10
number_of_labels = 10 # Create an instance for training. 
# When we run this code for the first time, the CIFAR10 train dataset will be downloaded locally. 
train_set =CIFAR10(root="./data",train=True,transform=transformations,download=True)# Create a loader for the training set which will read the data within batch size and put into memory.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
print("The number of images in a training set is: ", len(train_loader)*batch_size)# Create an instance for testing, note that train is set to False.
# When we run this code for the first time, the CIFAR10 test dataset will be downloaded locally. 
test_set = CIFAR10(root="./data", train=False, transform=transformations, download=True)# Create a loader for the test set which will read the data within batch size and put into memory. 
# Note that each shuffle is set to false for the test loader.
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)
print("The number of images in a test set is: ", len(test_loader)*batch_size)print("The number of batches per epoch is: ", len(train_loader))
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
创建网络
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F# Define a convolution neural network
class Network(nn.Module):def __init__(self):super(Network, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=1)self.bn1 = nn.BatchNorm2d(12)self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=1)self.bn2 = nn.BatchNorm2d(12)self.pool = nn.MaxPool2d(2,2)self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=1)self.bn4 = nn.BatchNorm2d(24)self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=1)self.bn5 = nn.BatchNorm2d(24)self.fc1 = nn.Linear(24*10*10, 10)def forward(self, input):output = F.relu(self.bn1(self.conv1(input)))      output = F.relu(self.bn2(self.conv2(output)))     output = self.pool(output)                        output = F.relu(self.bn4(self.conv4(output)))     output = F.relu(self.bn5(self.conv5(output)))     output = output.view(-1, 24*10*10)output = self.fc1(output)return output# Instantiate a neural network model 
model = Network()

定义损失函数

使用交叉熵函数作为损失函数,交叉熵分为两种

  • 二分类交叉熵函数
    在这里插入图片描述
  • 多分类交叉熵函数
    在这里插入图片描述
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
模型训练
from torch.autograd import Variable# Function to save the model
def saveModel():path = "./myFirstModel.pth"torch.save(model.state_dict(), path)# Function to test the model with the test dataset and print the accuracy for the test images
def testAccuracy():model.eval()accuracy = 0.0total = 0.0device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")with torch.no_grad():for data in test_loader:images, labels = data# run the model on the test set to predict labelsoutputs = model(images.to(device))# the label with the highest energy will be our prediction_, predicted = torch.max(outputs.data, 1)total += labels.size(0)accuracy += (predicted == labels.to(device)).sum().item()# compute the accuracy over all test imagesaccuracy = (100 * accuracy / total)return(accuracy)# Training function. We simply have to loop over our data iterator and feed the inputs to the network and optimize.
def train(num_epochs):best_accuracy = 0.0# Define your execution devicedevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("The model will be running on", device, "device")# Convert model parameters and buffers to CPU or Cudamodel.to(device)for epoch in range(num_epochs):  # loop over the dataset multiple timesrunning_loss = 0.0running_acc = 0.0for i, (images, labels) in enumerate(train_loader, 0):# get the inputsimages = Variable(images.to(device))labels = Variable(labels.to(device))# zero the parameter gradientsoptimizer.zero_grad()# predict classes using images from the training setoutputs = model(images)# compute the loss based on model output and real labelsloss = loss_fn(outputs, labels)# backpropagate the lossloss.backward()# adjust parameters based on the calculated gradientsoptimizer.step()# Let's print statistics for every 1,000 imagesrunning_loss += loss.item()     # extract the loss valueif i % 1000 == 999:    # print every 1000 (twice per epoch) print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 1000))# zero the lossrunning_loss = 0.0# Compute and print the average accuracy fo this epoch when tested over all 10000 test imagesaccuracy = testAccuracy()print('For epoch', epoch+1,'the test accuracy over the whole test set is %d %%' % (accuracy))# we want to save the model if the accuracy is the bestif accuracy > best_accuracy:saveModel()best_accuracy = accuracy
测试模型
import matplotlib.pyplot as plt
import numpy as np# Function to show the images
def imageshow(img):img = img / 2 + 0.5     # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()# Function to test the model with a batch of images and show the labels predictions
def testBatch():# get batch of images from the test DataLoader  images, labels = next(iter(test_loader))# show all images as one image gridimageshow(torchvision.utils.make_grid(images))# Show the real labels on the screen print('Real labels: ', ' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))# Let's see what if the model identifiers the  labels of those exampleoutputs = model(images)# We got the probability for every 10 labels. The highest (max) probability should be correct label_, predicted = torch.max(outputs, 1)# Let's show the predicted labels on the screen to compare with the real onesprint('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(batch_size)))
执行模型
if __name__ == "__main__":# Let's build our modeltrain(5)print('Finished Training')# Test which classes performed welltestAccuracy()# Let's load the model we just created and test the accuracy per labelmodel = Network()path = "myFirstModel.pth"model.load_state_dict(torch.load(path))# Test with batch of imagestestBatch()

在这里插入图片描述

总结

pytorch 搭建一个 CNN 模型比较简单,5 轮训练之后,效果就可以达到 60%,10 张图片中预测对了 6 张。

相关文章:

Pytorch 实现图片分类

CNN 网络适用于图片识别,卷积神经网络主要用于图片的处理识别。卷积神经网络,包括一下几部分,输入层、卷积层、池化层、全链接层和输出层。 使用 CIFAR-10 进行训练, CIFAR-10 中图片尺寸为 32 * 32。卷积层通过卷积核移动进行计…...

得物App获评新奖项,正品保障夯实供应链创新水平

近日,得物App再度获评新奖项——“2024上海市供应链创新与应用优秀案例”。 本次奖项为上海市供应链领域最高奖项,旨在评选出在供应链创新成效上处于领先地位、拥有成功模式和经验的企业。今年以来,得物App已接连获得“上海市质量金奖”、“科…...

【数据结构-邻项消除】力扣735. 小行星碰撞

给定一个整数数组 asteroids,表示在同一行的小行星。 对于数组中的每一个元素,其绝对值表示小行星的大小,正负表示小行星的移动方向(正表示向右移动,负表示向左移动)。每一颗小行星以相同的速度移动。 找…...

002-Kotlin界面开发之Kotlin旋风之旅

Kotlin旋风之旅 Compose Desktop中哪些Kotlin知识是必须的? 在学习Compose Desktop中,以下Kotlin知识是必须的: 基础语法:包括变量声明、数据类型、条件语句、循环等。面向对象编程:类与对象、继承、接口、抽象类等。…...

VMware Workstation Pro for Personal Use (For Windows)

这是从broadcom.com网下载的个人版本的Vmware 17.6.1,存分享不要分。 VMware-workstation-full-17.6.1-24319023.exe(447.93 MB) Build Number: 24319023 Oct 08, 2024 07.33AM SHA2: f95429e395a583eb5ba91f09b040e2f8c53a5e7aa37c4c6bfcaf82115a8…...

论文 | PROMPTAGATOR : FEW-SHOT DENSE RETRIEVAL FROM 8 EXAMPLES

1. 背景信息 在信息检索领域,传统的方法往往依赖于大量的标注数据来训练模型,以便在各种任务中表现良好。然而,许多实际应用中的监督数据是有限的,尤其是在不同的检索任务中。最近的研究开始关注如何从一个拥有丰富监督数据的任务…...

使用 Github 进行项目管理

GitHub 是一个广泛使用的代码托管和协作平台,它提供了强大的工具来支持项目管理和团队协作。在项目开发和工作中,避免不了 Github 的使用,然鹅我一直没有稍微系统地学习过 github 的整个工作流程,对这些操作都是一知半解的&#x…...

企业SRC挖掘选择与信息收集指南

内容预览 ≧∀≦ゞ 企业SRC挖掘选择与信息收集指南导语1. 企业SRC的选择2. 信息收集2.1 集团与子公司2.2 小程序与APP2.3 Web端信息收集 3. 信息收集常用模板总结 企业SRC挖掘选择与信息收集指南 导语 近年来,企业的安全响应中心(SRC)已逐渐…...

Golang | Leetcode Golang题解之第524题通过删除字母匹配到字典里最长单词

题目: 题解: func findLongestWord(s string, dictionary []string) (ans string) {m : len(s)f : make([][26]int, m1)for i : range f[m] {f[m][i] m}for i : m - 1; i > 0; i-- {f[i] f[i1]f[i][s[i]-a] i}outer:for _, t : range dictionary …...

【DBeaver】连接带kerberos的hive[Apache|HDP]

目录 一、安装配置Kerberos客户端环境 1.1 安装Kerberos客户端 1.2 环境配置 二、基于Cloudera驱动创建连接 三、基于Hive原生驱动创建连接 一、安装配置Kerberos客户端环境 1.1 安装Kerberos客户端 在Kerberos官网下载,地址如下:https://web.mit.edu/kerberos…...

Unity3D 开发教程:从入门到精通

Unity3D 开发教程:从入门到精通 Unity3D 是一款强大的跨平台游戏引擎,广泛应用于游戏开发、虚拟现实、增强现实等领域。本文将详细介绍 Unity3D 的基本概念、开发流程以及一些高级技巧,帮助你从零基础到掌握 Unity3D 开发。 目录 Unity3D…...

文件操作和 IO(一):文件基础知识 文件系统操作 => File类

目录 1. 什么是文件 1.1 概念 1.2 硬盘, 内存, 寄存器之间的区别 1.3 机械硬盘和固态硬盘 2. 文件路径 2.1 绝对路径 2.2 相对路径 3. 文件分类 4. File 类 4.1 属性 4.2 构造方法 4.3 方法 1. 什么是文件 1.1 概念 狭义上的文件: 保存在硬盘上的文件广义的上的文…...

用Pyhon写一款简单的益智类小游戏——2048

文字版——代码及讲解 代码—— import random# 初始化游戏棋盘 def init_board():return [[0] * 4 for _ in range(4)]# 在棋盘上随机生成一个2或4 def add_new_tile(board):empty_cells [(i, j) for i in range(4) for j in range(4) if board[i][j] 0]if empty_cells:i,…...

akshare股票涨跌幅自定义范围查询:A股、港股、美股

参看:https://stock.hexun.com/2024-10-31/215251914.html 涨幅计算公式:(当前价格 - 上一个交易日收盘价) 上一个交易日收盘价 100% 。 跌幅计算公式:(上一个交易日收盘价 - 当前价格) 上一个…...

通过js控制修改css变量

在JavaScript中,你可以通过操作CSS变量(也称为自定义属性)来动态改变样式。CSS变量在CSS中使用 – 前缀定义,例如 --main-color: red;。在JavaScript中,你可以使用 document.documentElement.style.setProperty 方法来…...

<HarmonyOS第一课>HarmonyOS SDK开放能力简介的课后习题

不出户&#xff0c;知天下&#xff1b; 不窥牖&#xff0c;见天道。 其出弥远&#xff0c;其知弥少。 是以圣人不行而知&#xff0c;不见而明&#xff0c;不为而成。 本篇<HarmonyOS第一课>HarmonyOS SDK开放能力简介是简单介绍了HarmonyOS SDK&#xff0c;不需要大家过多…...

深度学习:yolo的使用--图像处理

定义了一个名为 ListDataset 的类&#xff0c;它继承自 PyTorch 的 Dataset 类,这个数据集从一个包含图像文件路径的列表中读取图像和对应的标签文件 class ListDataset(Dataset):def __init__(self, list_path, img_size416, augmentTrue, multiscaleTrue, normalized_labelsT…...

TypeScript实用笔记(一):初始化、类型定义与函数使用

文章目录 一、ts初始化1. 初始化.json文件一2. 启动方式2.1 直接运行.ts文件2.2 转换运行 二、类型1. 参数类型1.1 常规参数1.2 symbol1.3 数组\[]1.4 元组\[]1.5 用字面量定义数据类型 2. Object3. 枚举类型\[Enum]3.1 数字枚举3.2 字符串枚举 三、 类型别名1. 数组别名使用2.…...

【大数据学习 | kafka】producer之拦截器,序列化器与分区器

1. 自定义拦截器 interceptor是拦截器&#xff0c;可以拦截到发送到kafka中的数据进行二次处理&#xff0c;它是producer组成部分的第一个组件。 public static class MyInterceptor implements ProducerInterceptor<String,String>{Overridepublic ProducerRecord<…...

零基础学西班牙语,柯桥专业小语种培训泓畅学校

No te comas el coco, seguro que te ha salido bien la entrevista. Ya te llamarn. 别瞎想了&#xff01;我保证你的面试很顺利。他们会给你打电话的。 这里的椰子是"头"的比喻。在西班牙的口语中&#xff0c;我们也可以听到其他同义表达&#xff0c;比如&#x…...

vJoy虚拟摇杆驱动技术架构深度解析

vJoy虚拟摇杆驱动技术架构深度解析 【免费下载链接】vJoy Virtual Joystick 项目地址: https://gitcode.com/gh_mirrors/vj/vJoy 在Windows游戏开发和输入设备模拟领域&#xff0c;虚拟控制器技术扮演着关键角色。vJoy作为一款开源的虚拟摇杆驱动&#xff0c;通过内核模…...

QT 5.13.0离线安装指南:绕过账号验证的实用技巧

1. QT 5.13.0离线安装的必要性与场景 在企业开发环境中&#xff0c;经常会遇到内网隔离或网络访问受限的情况。这时候传统的在线安装方式就会遇到麻烦——QT从5.12版本开始强制要求用户登录账号才能继续安装。我去年给某银行做系统迁移时就碰到这个问题&#xff0c;他们的开发机…...

Alibaba DASD-4B Thinking 对话工具 MathType 公式编辑技巧与 LaTeX 转换助手

Alibaba DASD-4B Thinking 对话工具&#xff1a;你的智能公式编辑与 LaTeX 转换助手 写论文、做报告&#xff0c;最头疼的是什么&#xff1f;对我而言&#xff0c;除了查文献&#xff0c;就是处理公式了。尤其是当导师要求用 LaTeX 排版&#xff0c;而我却习惯在 Word 里用 Ma…...

SDMatte多模态输入探索:结合文本描述实现指代性抠图

SDMatte多模态输入探索&#xff1a;结合文本描述实现指代性抠图 1. 效果亮点预览 想象一下这样的场景&#xff1a;面对一张复杂的家庭聚会照片&#xff0c;你只需要输入"穿红色衣服的人"&#xff0c;AI就能自动识别并精确抠出目标人物。这正是SDMatte最新探索的多模…...

终极指南:Fay开源项目技术路线图重大调整,全面响应社区反馈

终极指南&#xff1a;Fay开源项目技术路线图重大调整&#xff0c;全面响应社区反馈 【免费下载链接】Fay fay是一个帮助数字人&#xff08;2.5d、3d、移动、pc、网页&#xff09;或大语言模型&#xff08;openai兼容、deepseek&#xff09;连通业务系统的agent框架。 项目地址…...

计算机网络之TCP和UDP的底层机制

文章目录 1. TCP和UDP区别&#xff1f;2.TCP为什么可靠传输3. 怎么用UDP实现HTTP&#xff1f;4. TCP粘包怎么解决5. 滑动窗口6. 拥塞控制 1. TCP和UDP区别&#xff1f; TCP&#xff1a; 报头 TCP发送数据 客户端&#xff1a; #include <iostream> #include <strin…...

Agilent E5100A 高速网络分析仪

10 kHz 至 180 MHz/300M 提供快速测量&#xff08;扫描速度高达 0.04 ms/点&#xff09;、快速波形分析命令和高速处理器&#xff0c;可提高生产线的生产效率 使用波形分析命令和相位跟踪功能更快速地完成滤波器和谐振器评测 使用嵌入式 IBASIC 更轻松地开发自动化程序 使用蒸发…...

3文件搞定AI编程:极简工作流让AI从“拖油瓶“变“得力助手

针对当前AI编程效率低下的痛点&#xff0c;本文提出了一套只需3个文件的极简工作流方案。通过分析AI编程的三个进化阶段&#xff08;氛围编程→规格先行→自主代理&#xff09;&#xff0c;作者发现关键在于为AI提供明确任务指引&#xff08;task.md&#xff09;、标准工作流程…...

如何免费搭建个人数字图书馆:番茄小说下载器终极指南

如何免费搭建个人数字图书馆&#xff1a;番茄小说下载器终极指南 【免费下载链接】fanqienovel-downloader 下载番茄小说 项目地址: https://gitcode.com/gh_mirrors/fa/fanqienovel-downloader 还在为网络小说平台限制、网络不稳定或小说突然下架而烦恼吗&#xff1f;今…...

【GitHub项目推荐--Plane:开源版 JIRA,让项目管理回归“有序”】⭐⭐⭐

GitHub 地址&#xff1a;https://github.com/makeplane/plane 简介 Plane​ 是一个现代化的开源项目管理平台&#xff0c;被广泛认为是 JIRA、Linear 和 Asana 的开源替代品。它专为追求效率的研发和产品团队设计&#xff0c;将问题跟踪、敏捷迭代、文档协作和产品路线图统一在…...