GooLeNet模型搭建
一、model
import torch
from torch import nn
from torchsummary import summaryclass Inception(nn.Module):def __init__(self, in_channels, c1, c2 , c3 , c4):super(Inception, self).__init__()self.ReLU = nn.ReLU()#路线1:1x1卷积self.p1_1 = nn.Conv2d(in_channels = in_channels, out_channels = c1, kernel_size = 1)#路线2:1x1卷积->3x3卷积self.p2_1 = nn.Conv2d(in_channels = in_channels, out_channels = c2[0], kernel_size = 1)self.p2_2 = nn.Conv2d(in_channels = c2[0], out_channels = c2[1], kernel_size = 3, padding = 1)#路线3:1x1卷积->5x5卷积self.p3_1 = nn.Conv2d(in_channels = in_channels, out_channels = c3[0], kernel_size = 1)self.p3_2 = nn.Conv2d(in_channels = c3[0], out_channels = c3[1], kernel_size = 5, padding = 2)#路线4:3x3最大池化->1x1卷积self.p4_1 = nn.MaxPool2d(kernel_size = 3, stride = 1, padding = 1)self.p4_2 = nn.Conv2d(in_channels = in_channels, out_channels = c4, kernel_size = 1)#前向传播def forward(self, x):#路线1p1 = self.ReLU(self.p1_1(x))#路线2p2 = self.ReLU(self.p2_2(self.ReLU(self.p2_1(x))))#路线3p3 = self.ReLU(self.p3_2(self.ReLU(self.p3_1(x))))#路线4p4 = self.ReLU(self.p4_2(self.p4_1(x)))#拼接out = torch.cat((p1, p2, p3, p4), dim = 1)return outclass GoogLeNet(nn.Module):def __init__(self,Inception):super(GoogLeNet, self).__init__()self.b1 = nn.Sequential(nn.Conv2d(in_channels = 1,out_channels = 64,kernel_size=7,stride=2,padding = 3),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1))self.b2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),nn.ReLU(),nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))#Inception模块self.b3 = nn.Sequential(Inception(192,64,(96,128),(16,32),32),Inception(256,128,(128,192),(32,96),64),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.b4 = nn.Sequential(Inception(480,192,(96,208),(16,48),64),Inception(512,160,(112,224),(24,64),64),Inception(512,128,(128,256),(24,64),64),Inception(512,112,(144,288),(32,64),64),Inception(528,256,(160,320),(32,128),128),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),Inception(832, 384, (192, 384), (48,128), 128),nn.AdaptiveAvgPool2d((1,1)),#平展层nn.Flatten(),nn.Linear(1024,10))for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)if m.bias is not None:nn.init.constant_(m.bias, 0)def forward(self, x):x = self.b1(x)x = self.b2(x)x = self.b3(x)x = self.b4(x)x = self.b5(x)return xif __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = GoogLeNet(Inception).to(device)print(summary(model,(1,224,224)))
二、model_train
import copy
import time
from torchvision.datasets import FashionMNIST
import numpy as np
from torchvision import transforms
import torch.utils.data as Data
import matplotlib.pyplot as plt
from model import GoogLeNet,Inception
import torch
import pandas as pddef train_val_data_process():# 数据预处理train_data = FashionMNIST(root='./data',train=True,download=True,transform=transforms.Compose([transforms.Resize(size=224),transforms.ToTensor()]))# 数据集划分train_data,val_data = Data.random_split(train_data,[round(0.8*len(train_data)),round(0.2*len(train_data))])train_dataloader = Data.DataLoader(dataset=train_data,batch_size=32,shuffle=True, #数据打乱num_workers=2)val_dataloader = Data.DataLoader(dataset=val_data,batch_size=32,shuffle=True,num_workers=2)#返回数据集return train_dataloader,val_dataloader# 训练模型
def train_model_process(model,train_dataloader,val_dataloader,num_epochs):device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#定义优化器,optimizer = torch.optim.Adam(model.parameters(),lr=0.001)#定义损失函数(交叉熵损失函数)criterion = torch.nn.CrossEntropyLoss()model = model.to(device)#复制当前模型的参数best_model_wts = copy.deepcopy(model.state_dict())#初始化参数best_acc = 0.0#训练集与验证集损失函数列表train_loss_all = []val_loss_all = []#训练集与验证集准确率列表train_acc_all = []val_acc_all = []since = time.time()for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 每个epoch都有训练集训练过程和验证集评估过程train_loss = 0.0train_corrects = 0val_loss = 0.0val_corrects = 0# 训练集与验证集样本数量train_num = 0val_num = 0for step,(b_x,b_y) in enumerate(train_dataloader):#将特征和标签数据放入GPU中b_x = b_x.to(device)b_y = b_y.to(device)#设置模型为训练模式model.train()#前向传播,输入为一个batch的特征数据,输出为预测的标签output = model(b_x)#查看每一行中最大值所在的位置,即预测的类别pre_lab = torch.argmax(output,dim=1)# 计算损失函数loss = criterion(output,b_y)#将梯度置零,因为PyTorch中梯度是累加的optimizer.zero_grad()#反向传播,计算梯度loss.backward()#根据网络反向传播的梯度更新网络参数,达到降低loss的目的optimizer.step()#计算训练集的损失train_loss = train_loss + loss.item() * b_x.size(0)train_corrects = train_corrects + torch.sum(pre_lab == b_y.data)train_num = train_num + b_x.size(0)for step,(b_x,b_y) in enumerate(val_dataloader):# 将特征和标签数据放入验证设备中b_x = b_x.to(device)b_y = b_y.to(device)# 设置模型为评估模式model.eval()# 前向传播,输入为一个batch的特征数据,输出为预测的标签output = model(b_x)#查看每一行中最大值所在的位置,即预测的类别pre_lab = torch.argmax(output,dim=1)# 计算损失函数loss = criterion(output,b_y)# 计算验证集的损失val_loss = val_loss + loss.item() * b_x.size(0)val_corrects = val_corrects + torch.sum(pre_lab == b_y.data)val_num = val_num + b_x.size(0)# 计算训练集与测试集每个epoch的损失和准确率train_loss_all.append(train_loss / train_num)train_acc_all.append(train_corrects.double() / train_num)val_loss_all.append(val_loss / val_num)val_acc_all.append(val_corrects.double() / val_num)# 打印训练集和验证集准确率print('Train Loss: {:.4f} Acc: {:.4f}'.format(train_loss_all[-1], train_acc_all[-1]))print('Val Loss: {:.4f} Acc: {:.4f}'.format(val_loss_all[-1], val_acc_all[-1]))if val_acc_all[-1] > best_acc: # 保存准确率最高的模型参数best_acc = val_acc_all[-1]#保存当前的最高准确度best_model_wts = copy.deepcopy(model.state_dict())#训练耗时time_use = time.time() - sinceprint("训练和验证耗费的时间{:.0f}m{:.0f}s".format(time_use / 60,time_use % 60))#选择最优的模型参数#加载最高准确率的模型参数# model.load_state_dict(best_model_wts)# torch.save(model.load_state_dict(best_model_wts), 'best_model.pth')torch.save(best_model_wts, 'best_model.pth')train_porcess = pd.DataFrame(data = {'epoch':range(num_epochs),'train_loss_all':train_loss_all,'train_acc_all':train_acc_all,'val_loss_all':val_loss_all,'val_acc_all':val_acc_all})return train_porcessdef matplot_acc_loss(train_process):plt.figure(figsize=(12, 4))plt.subplot(1,2,1)plt.plot(train_process["epoch"],train_process["train_loss_all"],'ro-',label="train_loss")plt.plot(train_process["epoch"],train_process["val_loss_all"],'bo-',label="val_loss")plt.legend()plt.xlabel("epoch")plt.ylabel("loss")plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(train_process["epoch"], train_process["train_acc_all"], 'ro-', label="train_loss")plt.plot(train_process["epoch"], train_process["val_acc_all"], 'bo-', label="val_loss")plt.legend()plt.xlabel("epoch")plt.ylabel("acc")plt.show()if __name__=='__main__':#模型实例化GoogLeNet = GoogLeNet(Inception)train_dataloader,val_dataloader = train_val_data_process()train_porcess = train_model_process(GoogLeNet, train_dataloader,val_dataloader,num_epochs=10)matplot_acc_loss(train_porcess)
三、model_test
import torch
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import GoogLeNet,Inception
import matplotlib.pyplot as plt
import numpy as np
import os
import timedef test_data_process():# 数据预处理test_data = FashionMNIST(root='./data',train=False,download=True,transform=transforms.Compose([transforms.Resize(size=224), transforms.ToTensor()]))test_dataloader = Data.DataLoader(dataset=test_data,batch_size=1,shuffle=True, # 数据打乱num_workers=2)return test_dataloader#测试过程
def test_model_process(model,test_dataloader):device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#将模型放入设备中model = model.to(device)#初始化参数test_corrects = 0test_num = 0#只进行前向传播,不计算梯度,节省内存,提高速度with torch.no_grad():for test_data_x,test_data_y in test_dataloader:test_data_x = test_data_x.to(device)test_data_y = test_data_y.to(device)#设置模型为评估模式model.eval()#前向传播,输入为测试数据集,输出为每个样本的预测值test_output = model(test_data_x)#获取预测值中最大值对应的索引,即预测的类别pre_lab = torch.argmax(test_output, 1) #预测标签test_corrects += torch.sum(pre_lab == test_data_y.data) #预测正确的数量test_num = test_num + test_data_y.size(0) #测试样本数量#计算测试集的准确率test_acc = test_corrects.double().item() / test_num#打印测试集的准确率print('Test Accuracy: {:.4f}'.format(test_acc))if __name__ == '__main__':model = GoogLeNet(Inception)#加载训练好的模型model.load_state_dict(torch.load('best_model.pth'))test_dataloader = test_data_process() #加载测试数据# test_model_process(model,test_dataloader) #测试模型#模型推理device = 'cuda' if torch.cuda.is_available() else 'cpu'model = model.to(device)#将模型设置为评估模式with torch.no_grad():for b_x,b_y in test_dataloader:b_x = b_x.to(device)b_y = b_y.to(device)#模型设为验证模型model.eval()output = model(b_x)#获取预测值中最大值对应的索引,即预测的类别pre_lab = torch.argmax(output, dim = 1) #预测标签result = pre_lab.item()label = b_y.item()#打印预测结果print('预测结果:',result,"--------","真实值:",label)
相关文章:
GooLeNet模型搭建
一、model import torch from torch import nn from torchsummary import summaryclass Inception(nn.Module):def __init__(self, in_channels, c1, c2 , c3 , c4):super(Inception, self).__init__()self.ReLU nn.ReLU()#路线1:1x1卷积self.p1_1 nn.Conv2d(in_channels i…...

使用ThreadLocal来存取单线程内的数据
一.什么是ThreadLocal? ThreadLocal,即线程本地变量。如果你创建了一个 ThreadLocal变量,那么访问这个变量的每个线程都会有这个变量的一个本地拷贝,多个线程操作这个变量的时候,实际是在操作自己本地内存里面的变量&…...

elasticsearch教程
1. 单点部署(rpm): #提前关闭firewalld,否则无法组建集群 #1. 下载ES rpm包 ]# https://www.elastic.co/cn/downloads #2. 安装es ]# rpm -ivh elasticsearch-7.17.5-x86_64.rpm #3. 调整内核参数(太低的话es会启动报错) echo "vm.max_map_count655360 fs.file-max 655…...

Arrays、Lambda表达式、Collection集合
1. Arrays 1.1 操作数组的工具类 方法名说明public static String toString(数组)把数组拼接成一个字符串public static int binarySearch(数组,查找的元素)二分查找法查找元素public static int[] copyOf(原数组,新数组长度)拷贝数组public static int[] copyOfRange(原数组…...

2024年前端趋势:全栈或许是不容错过的选择!
近年来,前端开发的技术不断推陈出新,2024年也不例外。在这个变化迅速的领域,全栈开发逐渐成为一股不容忽视的趋势。无论你是经验丰富的开发者,还是刚刚入门的新手,掌握全栈技术都能让你在竞争中脱颖而出。而在这个过程…...

MySQL 实战 45 讲(01-05)
本文为笔者学习林晓斌老师《MySQL 实战 45 讲》课程的学习笔记,并进行了一定的知识扩充。 sql 查询语句的执行流程 大体来说,MySQL 可以分为 Server 层和存储引擎层两部分。 Server 层包括连接器、查询缓存、分析器、优化器和执行器。 连接器负责接收客…...
仓颉编程语言入门 -- Array数组详解
仓颉编程语言入门 – Array数组详解 一. 如何创建Array数组 我们可以使用 Array 类型来构造单一元素类型,有序序列的数据。 1.仓颉使用 Array 来表示 Array 类型。T 表示 Array 的元素类型,T 可以是任意类型 , 类似于泛型的概念 var arr:Array<St…...
C#初级——简单单例模式使用
单例模式 单例模式是一种常用的软件设计模式,它确保一个类只有一个实例,并提供一个全局访问点来获取这个实例,通过单例模式防止私有成员被多次引用,防止数据被随意纂改。本文使用的是线程不安全的懒汉式单例。 创建单例模式 首…...
2024.07.29 校招 实习 内推 面经
地/球🌍 : neituijunsir 交* 流*裙 ,内推/实习/校招汇总表格 1、校招 | 美/团// 快驴、小象、优/选/事/业/部2024年校/园/招聘(内推) 校招 | 美团快驴、小象、优选事业部2024年校园招聘(内推ÿ…...
速盾:爬虫攻击和cc攻击的区别是什么?
爬虫攻击和CC(Distributed Denial of Service)攻击是网络安全领域两种不同类型的攻击方式。尽管它们都涉及对目标网站或服务器的非法访问,但它们的目的、方法和影响各不相同。在接下来的文章中,我们将详细介绍这两种攻击方式的区别…...
Tomcat与Nginx的区别详解
目录 引言Tomcat概述 Tomcat的历史Tomcat的架构Tomcat的功能Nginx概述 Nginx的历史Nginx的架构Nginx的功能Tomcat与Nginx的区别 架构上的区别...

【大模型从入门到精通5】openAI API高级内容审核-1
这里写目录标题 高级内容审核利用 OpenAI 内容审核 API 的高级内容审核技术整合与实施使用自定义规则增强审核综合示例防止提示注入的策略使用分隔符隔离命令理解分隔符使用分隔符实现命令隔离 高级内容审核 利用 OpenAI 内容审核 API 的高级内容审核技术 OpenAI 内容审核 AP…...

JVM系列 | 对象的消亡3——垃圾收集器的对比与实现细节
垃圾收集器 文章目录 各收集器简单对比收集器启动参数各收集器详细说明JDK 1.3 之前JDK 1.3 | SerialJDK 1.4 | ParNewJDK 1.4 | Parallel ScavengeJDK 5 | CMS 收集器JDK 7 | G1 各收集器简单对比 收集器名称出现时间淘汰时间目标采用技术线程数STW分代备注无名JDK 1.3之前JD…...

C# Unity 面向对象补全计划 七大原则 之 开闭原则(OCP) 难度:☆ 总结:已经写好的就别动它了,多用继承
本文仅作学习笔记与交流,不作任何商业用途,作者能力有限,如有不足还请斧正 本系列作为七大原则和设计模式的进阶知识,看不懂没关系 请看专栏:http://t.csdnimg.cn/mIitr,查漏补缺 1.开闭原则(OC…...

微信防封指南请收好
一、新号与老号的添加限制 建议新注册的微信号主动添加好友的数量不宜过多,推荐每日添加不超过5个好友;对于老号,建议每日添加不超过20个好友。保持适度的添加速度,避免被系统判定为异常操作。 二、避免使用营销性词汇 在发送消…...

选择排序算法改进思路和算法实现
选择排序 在未排序的数组中,用第一个数去和后面的数比较,找出最小的数,和第一个数交换。第一个数已为已排序的数。 相当于0~7 从0~7中找到最小的数放在0 从1~7中找到最小的数放在1 从2~7中找到最小的数放在2 ...以此类推 从6~7中找到最…...

【文件解析漏洞复现】
一.IIS解析漏洞复现 1.IIS6.X 方式一:目录解析 搭建IIS环境 在网站下建立文件夹的名字为.asp/.asa 的文件夹,其目录内的任何扩展名的文件都被IIS当作asp文件来解析并执行。 访问成功被解析 方式一:目录解析 在IIS 6处理文件解…...

【STL】 vector的底层实现
1.vector的模拟代码完整实现(后面会拆分开一个一个细讲) #pragma once #include<assert.h>// 抓重点namespace bit {/*template<class T>class vector{public:typedef T* iterator;private:T* _a;size_t _size;size_t _capacity;};*/templa…...
责任链模式:解耦职责,优化请求处理
在软件设计中,如何有效地处理复杂的请求是一个重要的课题。 责任链模式(Chain of Responsibility Pattern)提供了一种解耦请求发送者和接收者的方法,使得多个对象都有机会处理请求,从而达到灵活和可扩展的设计。 什么…...

【Scene Transformer】scene transformer论文阅读笔记
文章目录 序言(Abstract)(Introduction)(Related Work)(Methods)(Scene-centric Representation for Agents and Road Graphs)(Encoding Transformer)(Predicting Probabilities for Each Futures)(Joint and Marginal Loss Formulation) (Results)(Discussion)(Questions) sce…...
HTML 语义化
目录 HTML 语义化HTML5 新特性HTML 语义化的好处语义化标签的使用场景最佳实践 HTML 语义化 HTML5 新特性 标准答案: 语义化标签: <header>:页头<nav>:导航<main>:主要内容<article>&#x…...
Vue记事本应用实现教程
文章目录 1. 项目介绍2. 开发环境准备3. 设计应用界面4. 创建Vue实例和数据模型5. 实现记事本功能5.1 添加新记事项5.2 删除记事项5.3 清空所有记事 6. 添加样式7. 功能扩展:显示创建时间8. 功能扩展:记事项搜索9. 完整代码10. Vue知识点解析10.1 数据绑…...

Xshell远程连接Kali(默认 | 私钥)Note版
前言:xshell远程连接,私钥连接和常规默认连接 任务一 开启ssh服务 service ssh status //查看ssh服务状态 service ssh start //开启ssh服务 update-rc.d ssh enable //开启自启动ssh服务 任务二 修改配置文件 vi /etc/ssh/ssh_config //第一…...

中南大学无人机智能体的全面评估!BEDI:用于评估无人机上具身智能体的综合性基准测试
作者:Mingning Guo, Mengwei Wu, Jiarun He, Shaoxian Li, Haifeng Li, Chao Tao单位:中南大学地球科学与信息物理学院论文标题:BEDI: A Comprehensive Benchmark for Evaluating Embodied Agents on UAVs论文链接:https://arxiv.…...

PL0语法,分析器实现!
简介 PL/0 是一种简单的编程语言,通常用于教学编译原理。它的语法结构清晰,功能包括常量定义、变量声明、过程(子程序)定义以及基本的控制结构(如条件语句和循环语句)。 PL/0 语法规范 PL/0 是一种教学用的小型编程语言,由 Niklaus Wirth 设计,用于展示编译原理的核…...
3403. 从盒子中找出字典序最大的字符串 I
3403. 从盒子中找出字典序最大的字符串 I 题目链接:3403. 从盒子中找出字典序最大的字符串 I 代码如下: class Solution { public:string answerString(string word, int numFriends) {if (numFriends 1) {return word;}string res;for (int i 0;i &…...

【论文阅读28】-CNN-BiLSTM-Attention-(2024)
本文把滑坡位移序列拆开、筛优质因子,再用 CNN-BiLSTM-Attention 来动态预测每个子序列,最后重构出总位移,预测效果超越传统模型。 文章目录 1 引言2 方法2.1 位移时间序列加性模型2.2 变分模态分解 (VMD) 具体步骤2.3.1 样本熵(S…...

DeepSeek源码深度解析 × 华为仓颉语言编程精粹——从MoE架构到全场景开发生态
前言 在人工智能技术飞速发展的今天,深度学习与大模型技术已成为推动行业变革的核心驱动力,而高效、灵活的开发工具与编程语言则为技术创新提供了重要支撑。本书以两大前沿技术领域为核心,系统性地呈现了两部深度技术著作的精华:…...

VisualXML全新升级 | 新增数据库编辑功能
VisualXML是一个功能强大的网络总线设计工具,专注于简化汽车电子系统中复杂的网络数据设计操作。它支持多种主流总线网络格式的数据编辑(如DBC、LDF、ARXML、HEX等),并能够基于Excel表格的方式生成和转换多种数据库文件。由此&…...

图解JavaScript原型:原型链及其分析 | JavaScript图解
忽略该图的细节(如内存地址值没有用二进制) 以下是对该图进一步的理解和总结 1. JS 对象概念的辨析 对象是什么:保存在堆中一块区域,同时在栈中有一块区域保存其在堆中的地址(也就是我们通常说的该变量指向谁&…...