GooLeNet模型搭建
一、model
import torch
from torch import nn
from torchsummary import summaryclass Inception(nn.Module):def __init__(self, in_channels, c1, c2 , c3 , c4):super(Inception, self).__init__()self.ReLU = nn.ReLU()#路线1:1x1卷积self.p1_1 = nn.Conv2d(in_channels = in_channels, out_channels = c1, kernel_size = 1)#路线2:1x1卷积->3x3卷积self.p2_1 = nn.Conv2d(in_channels = in_channels, out_channels = c2[0], kernel_size = 1)self.p2_2 = nn.Conv2d(in_channels = c2[0], out_channels = c2[1], kernel_size = 3, padding = 1)#路线3:1x1卷积->5x5卷积self.p3_1 = nn.Conv2d(in_channels = in_channels, out_channels = c3[0], kernel_size = 1)self.p3_2 = nn.Conv2d(in_channels = c3[0], out_channels = c3[1], kernel_size = 5, padding = 2)#路线4:3x3最大池化->1x1卷积self.p4_1 = nn.MaxPool2d(kernel_size = 3, stride = 1, padding = 1)self.p4_2 = nn.Conv2d(in_channels = in_channels, out_channels = c4, kernel_size = 1)#前向传播def forward(self, x):#路线1p1 = self.ReLU(self.p1_1(x))#路线2p2 = self.ReLU(self.p2_2(self.ReLU(self.p2_1(x))))#路线3p3 = self.ReLU(self.p3_2(self.ReLU(self.p3_1(x))))#路线4p4 = self.ReLU(self.p4_2(self.p4_1(x)))#拼接out = torch.cat((p1, p2, p3, p4), dim = 1)return outclass GoogLeNet(nn.Module):def __init__(self,Inception):super(GoogLeNet, self).__init__()self.b1 = nn.Sequential(nn.Conv2d(in_channels = 1,out_channels = 64,kernel_size=7,stride=2,padding = 3),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1))self.b2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),nn.ReLU(),nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))#Inception模块self.b3 = nn.Sequential(Inception(192,64,(96,128),(16,32),32),Inception(256,128,(128,192),(32,96),64),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.b4 = nn.Sequential(Inception(480,192,(96,208),(16,48),64),Inception(512,160,(112,224),(24,64),64),Inception(512,128,(128,256),(24,64),64),Inception(512,112,(144,288),(32,64),64),Inception(528,256,(160,320),(32,128),128),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),Inception(832, 384, (192, 384), (48,128), 128),nn.AdaptiveAvgPool2d((1,1)),#平展层nn.Flatten(),nn.Linear(1024,10))for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)if m.bias is not None:nn.init.constant_(m.bias, 0)def forward(self, x):x = self.b1(x)x = self.b2(x)x = self.b3(x)x = self.b4(x)x = self.b5(x)return xif __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = GoogLeNet(Inception).to(device)print(summary(model,(1,224,224)))
二、model_train
import copy
import time
from torchvision.datasets import FashionMNIST
import numpy as np
from torchvision import transforms
import torch.utils.data as Data
import matplotlib.pyplot as plt
from model import GoogLeNet,Inception
import torch
import pandas as pddef train_val_data_process():# 数据预处理train_data = FashionMNIST(root='./data',train=True,download=True,transform=transforms.Compose([transforms.Resize(size=224),transforms.ToTensor()]))# 数据集划分train_data,val_data = Data.random_split(train_data,[round(0.8*len(train_data)),round(0.2*len(train_data))])train_dataloader = Data.DataLoader(dataset=train_data,batch_size=32,shuffle=True, #数据打乱num_workers=2)val_dataloader = Data.DataLoader(dataset=val_data,batch_size=32,shuffle=True,num_workers=2)#返回数据集return train_dataloader,val_dataloader# 训练模型
def train_model_process(model,train_dataloader,val_dataloader,num_epochs):device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#定义优化器,optimizer = torch.optim.Adam(model.parameters(),lr=0.001)#定义损失函数(交叉熵损失函数)criterion = torch.nn.CrossEntropyLoss()model = model.to(device)#复制当前模型的参数best_model_wts = copy.deepcopy(model.state_dict())#初始化参数best_acc = 0.0#训练集与验证集损失函数列表train_loss_all = []val_loss_all = []#训练集与验证集准确率列表train_acc_all = []val_acc_all = []since = time.time()for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 每个epoch都有训练集训练过程和验证集评估过程train_loss = 0.0train_corrects = 0val_loss = 0.0val_corrects = 0# 训练集与验证集样本数量train_num = 0val_num = 0for step,(b_x,b_y) in enumerate(train_dataloader):#将特征和标签数据放入GPU中b_x = b_x.to(device)b_y = b_y.to(device)#设置模型为训练模式model.train()#前向传播,输入为一个batch的特征数据,输出为预测的标签output = model(b_x)#查看每一行中最大值所在的位置,即预测的类别pre_lab = torch.argmax(output,dim=1)# 计算损失函数loss = criterion(output,b_y)#将梯度置零,因为PyTorch中梯度是累加的optimizer.zero_grad()#反向传播,计算梯度loss.backward()#根据网络反向传播的梯度更新网络参数,达到降低loss的目的optimizer.step()#计算训练集的损失train_loss = train_loss + loss.item() * b_x.size(0)train_corrects = train_corrects + torch.sum(pre_lab == b_y.data)train_num = train_num + b_x.size(0)for step,(b_x,b_y) in enumerate(val_dataloader):# 将特征和标签数据放入验证设备中b_x = b_x.to(device)b_y = b_y.to(device)# 设置模型为评估模式model.eval()# 前向传播,输入为一个batch的特征数据,输出为预测的标签output = model(b_x)#查看每一行中最大值所在的位置,即预测的类别pre_lab = torch.argmax(output,dim=1)# 计算损失函数loss = criterion(output,b_y)# 计算验证集的损失val_loss = val_loss + loss.item() * b_x.size(0)val_corrects = val_corrects + torch.sum(pre_lab == b_y.data)val_num = val_num + b_x.size(0)# 计算训练集与测试集每个epoch的损失和准确率train_loss_all.append(train_loss / train_num)train_acc_all.append(train_corrects.double() / train_num)val_loss_all.append(val_loss / val_num)val_acc_all.append(val_corrects.double() / val_num)# 打印训练集和验证集准确率print('Train Loss: {:.4f} Acc: {:.4f}'.format(train_loss_all[-1], train_acc_all[-1]))print('Val Loss: {:.4f} Acc: {:.4f}'.format(val_loss_all[-1], val_acc_all[-1]))if val_acc_all[-1] > best_acc: # 保存准确率最高的模型参数best_acc = val_acc_all[-1]#保存当前的最高准确度best_model_wts = copy.deepcopy(model.state_dict())#训练耗时time_use = time.time() - sinceprint("训练和验证耗费的时间{:.0f}m{:.0f}s".format(time_use / 60,time_use % 60))#选择最优的模型参数#加载最高准确率的模型参数# model.load_state_dict(best_model_wts)# torch.save(model.load_state_dict(best_model_wts), 'best_model.pth')torch.save(best_model_wts, 'best_model.pth')train_porcess = pd.DataFrame(data = {'epoch':range(num_epochs),'train_loss_all':train_loss_all,'train_acc_all':train_acc_all,'val_loss_all':val_loss_all,'val_acc_all':val_acc_all})return train_porcessdef matplot_acc_loss(train_process):plt.figure(figsize=(12, 4))plt.subplot(1,2,1)plt.plot(train_process["epoch"],train_process["train_loss_all"],'ro-',label="train_loss")plt.plot(train_process["epoch"],train_process["val_loss_all"],'bo-',label="val_loss")plt.legend()plt.xlabel("epoch")plt.ylabel("loss")plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(train_process["epoch"], train_process["train_acc_all"], 'ro-', label="train_loss")plt.plot(train_process["epoch"], train_process["val_acc_all"], 'bo-', label="val_loss")plt.legend()plt.xlabel("epoch")plt.ylabel("acc")plt.show()if __name__=='__main__':#模型实例化GoogLeNet = GoogLeNet(Inception)train_dataloader,val_dataloader = train_val_data_process()train_porcess = train_model_process(GoogLeNet, train_dataloader,val_dataloader,num_epochs=10)matplot_acc_loss(train_porcess)
三、model_test
import torch
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import GoogLeNet,Inception
import matplotlib.pyplot as plt
import numpy as np
import os
import timedef test_data_process():# 数据预处理test_data = FashionMNIST(root='./data',train=False,download=True,transform=transforms.Compose([transforms.Resize(size=224), transforms.ToTensor()]))test_dataloader = Data.DataLoader(dataset=test_data,batch_size=1,shuffle=True, # 数据打乱num_workers=2)return test_dataloader#测试过程
def test_model_process(model,test_dataloader):device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#将模型放入设备中model = model.to(device)#初始化参数test_corrects = 0test_num = 0#只进行前向传播,不计算梯度,节省内存,提高速度with torch.no_grad():for test_data_x,test_data_y in test_dataloader:test_data_x = test_data_x.to(device)test_data_y = test_data_y.to(device)#设置模型为评估模式model.eval()#前向传播,输入为测试数据集,输出为每个样本的预测值test_output = model(test_data_x)#获取预测值中最大值对应的索引,即预测的类别pre_lab = torch.argmax(test_output, 1) #预测标签test_corrects += torch.sum(pre_lab == test_data_y.data) #预测正确的数量test_num = test_num + test_data_y.size(0) #测试样本数量#计算测试集的准确率test_acc = test_corrects.double().item() / test_num#打印测试集的准确率print('Test Accuracy: {:.4f}'.format(test_acc))if __name__ == '__main__':model = GoogLeNet(Inception)#加载训练好的模型model.load_state_dict(torch.load('best_model.pth'))test_dataloader = test_data_process() #加载测试数据# test_model_process(model,test_dataloader) #测试模型#模型推理device = 'cuda' if torch.cuda.is_available() else 'cpu'model = model.to(device)#将模型设置为评估模式with torch.no_grad():for b_x,b_y in test_dataloader:b_x = b_x.to(device)b_y = b_y.to(device)#模型设为验证模型model.eval()output = model(b_x)#获取预测值中最大值对应的索引,即预测的类别pre_lab = torch.argmax(output, dim = 1) #预测标签result = pre_lab.item()label = b_y.item()#打印预测结果print('预测结果:',result,"--------","真实值:",label)
相关文章:

GooLeNet模型搭建
一、model import torch from torch import nn from torchsummary import summaryclass Inception(nn.Module):def __init__(self, in_channels, c1, c2 , c3 , c4):super(Inception, self).__init__()self.ReLU nn.ReLU()#路线1:1x1卷积self.p1_1 nn.Conv2d(in_channels i…...

使用ThreadLocal来存取单线程内的数据
一.什么是ThreadLocal? ThreadLocal,即线程本地变量。如果你创建了一个 ThreadLocal变量,那么访问这个变量的每个线程都会有这个变量的一个本地拷贝,多个线程操作这个变量的时候,实际是在操作自己本地内存里面的变量&…...

elasticsearch教程
1. 单点部署(rpm): #提前关闭firewalld,否则无法组建集群 #1. 下载ES rpm包 ]# https://www.elastic.co/cn/downloads #2. 安装es ]# rpm -ivh elasticsearch-7.17.5-x86_64.rpm #3. 调整内核参数(太低的话es会启动报错) echo "vm.max_map_count655360 fs.file-max 655…...

Arrays、Lambda表达式、Collection集合
1. Arrays 1.1 操作数组的工具类 方法名说明public static String toString(数组)把数组拼接成一个字符串public static int binarySearch(数组,查找的元素)二分查找法查找元素public static int[] copyOf(原数组,新数组长度)拷贝数组public static int[] copyOfRange(原数组…...

2024年前端趋势:全栈或许是不容错过的选择!
近年来,前端开发的技术不断推陈出新,2024年也不例外。在这个变化迅速的领域,全栈开发逐渐成为一股不容忽视的趋势。无论你是经验丰富的开发者,还是刚刚入门的新手,掌握全栈技术都能让你在竞争中脱颖而出。而在这个过程…...

MySQL 实战 45 讲(01-05)
本文为笔者学习林晓斌老师《MySQL 实战 45 讲》课程的学习笔记,并进行了一定的知识扩充。 sql 查询语句的执行流程 大体来说,MySQL 可以分为 Server 层和存储引擎层两部分。 Server 层包括连接器、查询缓存、分析器、优化器和执行器。 连接器负责接收客…...

仓颉编程语言入门 -- Array数组详解
仓颉编程语言入门 – Array数组详解 一. 如何创建Array数组 我们可以使用 Array 类型来构造单一元素类型,有序序列的数据。 1.仓颉使用 Array 来表示 Array 类型。T 表示 Array 的元素类型,T 可以是任意类型 , 类似于泛型的概念 var arr:Array<St…...

C#初级——简单单例模式使用
单例模式 单例模式是一种常用的软件设计模式,它确保一个类只有一个实例,并提供一个全局访问点来获取这个实例,通过单例模式防止私有成员被多次引用,防止数据被随意纂改。本文使用的是线程不安全的懒汉式单例。 创建单例模式 首…...

2024.07.29 校招 实习 内推 面经
地/球🌍 : neituijunsir 交* 流*裙 ,内推/实习/校招汇总表格 1、校招 | 美/团// 快驴、小象、优/选/事/业/部2024年校/园/招聘(内推) 校招 | 美团快驴、小象、优选事业部2024年校园招聘(内推ÿ…...

速盾:爬虫攻击和cc攻击的区别是什么?
爬虫攻击和CC(Distributed Denial of Service)攻击是网络安全领域两种不同类型的攻击方式。尽管它们都涉及对目标网站或服务器的非法访问,但它们的目的、方法和影响各不相同。在接下来的文章中,我们将详细介绍这两种攻击方式的区别…...

Tomcat与Nginx的区别详解
目录 引言Tomcat概述 Tomcat的历史Tomcat的架构Tomcat的功能Nginx概述 Nginx的历史Nginx的架构Nginx的功能Tomcat与Nginx的区别 架构上的区别...

【大模型从入门到精通5】openAI API高级内容审核-1
这里写目录标题 高级内容审核利用 OpenAI 内容审核 API 的高级内容审核技术整合与实施使用自定义规则增强审核综合示例防止提示注入的策略使用分隔符隔离命令理解分隔符使用分隔符实现命令隔离 高级内容审核 利用 OpenAI 内容审核 API 的高级内容审核技术 OpenAI 内容审核 AP…...

JVM系列 | 对象的消亡3——垃圾收集器的对比与实现细节
垃圾收集器 文章目录 各收集器简单对比收集器启动参数各收集器详细说明JDK 1.3 之前JDK 1.3 | SerialJDK 1.4 | ParNewJDK 1.4 | Parallel ScavengeJDK 5 | CMS 收集器JDK 7 | G1 各收集器简单对比 收集器名称出现时间淘汰时间目标采用技术线程数STW分代备注无名JDK 1.3之前JD…...

C# Unity 面向对象补全计划 七大原则 之 开闭原则(OCP) 难度:☆ 总结:已经写好的就别动它了,多用继承
本文仅作学习笔记与交流,不作任何商业用途,作者能力有限,如有不足还请斧正 本系列作为七大原则和设计模式的进阶知识,看不懂没关系 请看专栏:http://t.csdnimg.cn/mIitr,查漏补缺 1.开闭原则(OC…...

微信防封指南请收好
一、新号与老号的添加限制 建议新注册的微信号主动添加好友的数量不宜过多,推荐每日添加不超过5个好友;对于老号,建议每日添加不超过20个好友。保持适度的添加速度,避免被系统判定为异常操作。 二、避免使用营销性词汇 在发送消…...

选择排序算法改进思路和算法实现
选择排序 在未排序的数组中,用第一个数去和后面的数比较,找出最小的数,和第一个数交换。第一个数已为已排序的数。 相当于0~7 从0~7中找到最小的数放在0 从1~7中找到最小的数放在1 从2~7中找到最小的数放在2 ...以此类推 从6~7中找到最…...

【文件解析漏洞复现】
一.IIS解析漏洞复现 1.IIS6.X 方式一:目录解析 搭建IIS环境 在网站下建立文件夹的名字为.asp/.asa 的文件夹,其目录内的任何扩展名的文件都被IIS当作asp文件来解析并执行。 访问成功被解析 方式一:目录解析 在IIS 6处理文件解…...

【STL】 vector的底层实现
1.vector的模拟代码完整实现(后面会拆分开一个一个细讲) #pragma once #include<assert.h>// 抓重点namespace bit {/*template<class T>class vector{public:typedef T* iterator;private:T* _a;size_t _size;size_t _capacity;};*/templa…...

责任链模式:解耦职责,优化请求处理
在软件设计中,如何有效地处理复杂的请求是一个重要的课题。 责任链模式(Chain of Responsibility Pattern)提供了一种解耦请求发送者和接收者的方法,使得多个对象都有机会处理请求,从而达到灵活和可扩展的设计。 什么…...

【Scene Transformer】scene transformer论文阅读笔记
文章目录 序言(Abstract)(Introduction)(Related Work)(Methods)(Scene-centric Representation for Agents and Road Graphs)(Encoding Transformer)(Predicting Probabilities for Each Futures)(Joint and Marginal Loss Formulation) (Results)(Discussion)(Questions) sce…...

ESP32在ESP-IDF环境下禁用看门狗
最近使用了一款ESP32的开发板。但在调试时发现出现许多看门狗复位事件: E (8296) task_wdt: Task watchdog got triggered. The following tasks/users did not reset the watchdog in time: E (8296) task_wdt: - IDLE (CPU 0) E (8296) task_wdt: Tasks curre…...

基于 uniapp html5plus API,怎么把图片保存到相册
要将图片保存到相册中,可以使用HTML5 API中的plus.gallery.save方法。以下是一个示例代码,展示如何将图片保存到手机相册: // 图片的URL,可以是本地路径或网络路径 var imageUrl path/to/your/image.jpg;// 调用plus.gallery.sa…...

3.特征工程-特征抽取、特征预处理、特征降维
文章目录 环境配置(必看)头文件引用1.数据集: sklearn代码运行结果 2.字典特征抽取: DictVectorizer代码运行结果稀疏矩阵 3.文本特征抽取(英文文本): CountVectorizer()代码运行结果 4.中文文本分词(中文文本特征抽取使用)代码运行结果 5.中文文本特征抽…...

RISC-V (五)上下文切换和协作式多任务
任务(task) 所谓的任务就是寄存器的当前值。 -smp后面的数字指的是hart的个数,qemu模拟器最大可以有8个核,此文围绕一个核来讲。 QEMU qemu-system-riscv32 QFLAG -nographic -smp 1 -machine virt -bios none 协作式多任务 …...

Cornerstone加载本地Dicom文件第二弹 - Blob篇
🍀 引言 当我们刚接触Cornerstone或拿到一组Dicom文件时,如果没有ImageID和后台接口,可能只是想简单测试Cornerstone能否加载这些Dicom文件。在这种情况下,可以使用本地文件加载的方法。之前我们介绍了通过node启动服务器请求文件…...

C语言中整数类型及其类型转换
1.数据的存储和排列 是的,在C语言中,整数类型通常以补码(twos complement)形式存储在内存中。这是因为补码表示法在处理有符号整数的加减运算上更为简便和高效。 2.有符号数和无符号数之间的转换 在C语言中,有符号数和…...

powerjob连接postgresql数据库(支持docker部署)
1.先去pg建一个powerjob-product库 2.首先去拉最新的包,然后找到server模块,把mysql的配置文件信息替换成pg的 spring.datasource.hikari.auto-committrue spring.datasource.remote.hibernate.properties.hibernate.dialecttech.powerjob.server.pers…...

浅谈位运算及其应用(c++)
目录 一、位运算的基础(一)位与(&)(二)位或(|)(三)位异或(^)(四)位取反(~)&#x…...

Git版本管理中下列不适于Git的本地工作区域的是
Git版本管理中下列不适于Git的本地工作区域的是 A. 工作目录 B. 代码区 C. 暂存区 D. 资源库 选择B Git本地有四个工作区域: 工作目录(Working Directory)、 暂存区(Stage/Index)、 资源库(Repository或Git Directory)、 git仓库(Remote Di…...

webGL + WebGIS + 数据可视化
webGL: 解释:用于在浏览器中渲染 2D 和 3D 图形。它是基于 OpenGL ES 的,提供了直接操作 GPU 的能力。 库: Three.jsBabylon.jsPixiJSReglGlMatrixOsgjs WebGIS: 解释:用于在 Web 浏览器中处理和展示地…...