GooLeNet模型搭建
一、model
import torch
from torch import nn
from torchsummary import summaryclass Inception(nn.Module):def __init__(self, in_channels, c1, c2 , c3 , c4):super(Inception, self).__init__()self.ReLU = nn.ReLU()#路线1:1x1卷积self.p1_1 = nn.Conv2d(in_channels = in_channels, out_channels = c1, kernel_size = 1)#路线2:1x1卷积->3x3卷积self.p2_1 = nn.Conv2d(in_channels = in_channels, out_channels = c2[0], kernel_size = 1)self.p2_2 = nn.Conv2d(in_channels = c2[0], out_channels = c2[1], kernel_size = 3, padding = 1)#路线3:1x1卷积->5x5卷积self.p3_1 = nn.Conv2d(in_channels = in_channels, out_channels = c3[0], kernel_size = 1)self.p3_2 = nn.Conv2d(in_channels = c3[0], out_channels = c3[1], kernel_size = 5, padding = 2)#路线4:3x3最大池化->1x1卷积self.p4_1 = nn.MaxPool2d(kernel_size = 3, stride = 1, padding = 1)self.p4_2 = nn.Conv2d(in_channels = in_channels, out_channels = c4, kernel_size = 1)#前向传播def forward(self, x):#路线1p1 = self.ReLU(self.p1_1(x))#路线2p2 = self.ReLU(self.p2_2(self.ReLU(self.p2_1(x))))#路线3p3 = self.ReLU(self.p3_2(self.ReLU(self.p3_1(x))))#路线4p4 = self.ReLU(self.p4_2(self.p4_1(x)))#拼接out = torch.cat((p1, p2, p3, p4), dim = 1)return outclass GoogLeNet(nn.Module):def __init__(self,Inception):super(GoogLeNet, self).__init__()self.b1 = nn.Sequential(nn.Conv2d(in_channels = 1,out_channels = 64,kernel_size=7,stride=2,padding = 3),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1))self.b2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),nn.ReLU(),nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))#Inception模块self.b3 = nn.Sequential(Inception(192,64,(96,128),(16,32),32),Inception(256,128,(128,192),(32,96),64),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.b4 = nn.Sequential(Inception(480,192,(96,208),(16,48),64),Inception(512,160,(112,224),(24,64),64),Inception(512,128,(128,256),(24,64),64),Inception(512,112,(144,288),(32,64),64),Inception(528,256,(160,320),(32,128),128),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),Inception(832, 384, (192, 384), (48,128), 128),nn.AdaptiveAvgPool2d((1,1)),#平展层nn.Flatten(),nn.Linear(1024,10))for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)if m.bias is not None:nn.init.constant_(m.bias, 0)def forward(self, x):x = self.b1(x)x = self.b2(x)x = self.b3(x)x = self.b4(x)x = self.b5(x)return xif __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = GoogLeNet(Inception).to(device)print(summary(model,(1,224,224)))
二、model_train
import copy
import time
from torchvision.datasets import FashionMNIST
import numpy as np
from torchvision import transforms
import torch.utils.data as Data
import matplotlib.pyplot as plt
from model import GoogLeNet,Inception
import torch
import pandas as pddef train_val_data_process():# 数据预处理train_data = FashionMNIST(root='./data',train=True,download=True,transform=transforms.Compose([transforms.Resize(size=224),transforms.ToTensor()]))# 数据集划分train_data,val_data = Data.random_split(train_data,[round(0.8*len(train_data)),round(0.2*len(train_data))])train_dataloader = Data.DataLoader(dataset=train_data,batch_size=32,shuffle=True, #数据打乱num_workers=2)val_dataloader = Data.DataLoader(dataset=val_data,batch_size=32,shuffle=True,num_workers=2)#返回数据集return train_dataloader,val_dataloader# 训练模型
def train_model_process(model,train_dataloader,val_dataloader,num_epochs):device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#定义优化器,optimizer = torch.optim.Adam(model.parameters(),lr=0.001)#定义损失函数(交叉熵损失函数)criterion = torch.nn.CrossEntropyLoss()model = model.to(device)#复制当前模型的参数best_model_wts = copy.deepcopy(model.state_dict())#初始化参数best_acc = 0.0#训练集与验证集损失函数列表train_loss_all = []val_loss_all = []#训练集与验证集准确率列表train_acc_all = []val_acc_all = []since = time.time()for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 每个epoch都有训练集训练过程和验证集评估过程train_loss = 0.0train_corrects = 0val_loss = 0.0val_corrects = 0# 训练集与验证集样本数量train_num = 0val_num = 0for step,(b_x,b_y) in enumerate(train_dataloader):#将特征和标签数据放入GPU中b_x = b_x.to(device)b_y = b_y.to(device)#设置模型为训练模式model.train()#前向传播,输入为一个batch的特征数据,输出为预测的标签output = model(b_x)#查看每一行中最大值所在的位置,即预测的类别pre_lab = torch.argmax(output,dim=1)# 计算损失函数loss = criterion(output,b_y)#将梯度置零,因为PyTorch中梯度是累加的optimizer.zero_grad()#反向传播,计算梯度loss.backward()#根据网络反向传播的梯度更新网络参数,达到降低loss的目的optimizer.step()#计算训练集的损失train_loss = train_loss + loss.item() * b_x.size(0)train_corrects = train_corrects + torch.sum(pre_lab == b_y.data)train_num = train_num + b_x.size(0)for step,(b_x,b_y) in enumerate(val_dataloader):# 将特征和标签数据放入验证设备中b_x = b_x.to(device)b_y = b_y.to(device)# 设置模型为评估模式model.eval()# 前向传播,输入为一个batch的特征数据,输出为预测的标签output = model(b_x)#查看每一行中最大值所在的位置,即预测的类别pre_lab = torch.argmax(output,dim=1)# 计算损失函数loss = criterion(output,b_y)# 计算验证集的损失val_loss = val_loss + loss.item() * b_x.size(0)val_corrects = val_corrects + torch.sum(pre_lab == b_y.data)val_num = val_num + b_x.size(0)# 计算训练集与测试集每个epoch的损失和准确率train_loss_all.append(train_loss / train_num)train_acc_all.append(train_corrects.double() / train_num)val_loss_all.append(val_loss / val_num)val_acc_all.append(val_corrects.double() / val_num)# 打印训练集和验证集准确率print('Train Loss: {:.4f} Acc: {:.4f}'.format(train_loss_all[-1], train_acc_all[-1]))print('Val Loss: {:.4f} Acc: {:.4f}'.format(val_loss_all[-1], val_acc_all[-1]))if val_acc_all[-1] > best_acc: # 保存准确率最高的模型参数best_acc = val_acc_all[-1]#保存当前的最高准确度best_model_wts = copy.deepcopy(model.state_dict())#训练耗时time_use = time.time() - sinceprint("训练和验证耗费的时间{:.0f}m{:.0f}s".format(time_use / 60,time_use % 60))#选择最优的模型参数#加载最高准确率的模型参数# model.load_state_dict(best_model_wts)# torch.save(model.load_state_dict(best_model_wts), 'best_model.pth')torch.save(best_model_wts, 'best_model.pth')train_porcess = pd.DataFrame(data = {'epoch':range(num_epochs),'train_loss_all':train_loss_all,'train_acc_all':train_acc_all,'val_loss_all':val_loss_all,'val_acc_all':val_acc_all})return train_porcessdef matplot_acc_loss(train_process):plt.figure(figsize=(12, 4))plt.subplot(1,2,1)plt.plot(train_process["epoch"],train_process["train_loss_all"],'ro-',label="train_loss")plt.plot(train_process["epoch"],train_process["val_loss_all"],'bo-',label="val_loss")plt.legend()plt.xlabel("epoch")plt.ylabel("loss")plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(train_process["epoch"], train_process["train_acc_all"], 'ro-', label="train_loss")plt.plot(train_process["epoch"], train_process["val_acc_all"], 'bo-', label="val_loss")plt.legend()plt.xlabel("epoch")plt.ylabel("acc")plt.show()if __name__=='__main__':#模型实例化GoogLeNet = GoogLeNet(Inception)train_dataloader,val_dataloader = train_val_data_process()train_porcess = train_model_process(GoogLeNet, train_dataloader,val_dataloader,num_epochs=10)matplot_acc_loss(train_porcess)
三、model_test
import torch
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import GoogLeNet,Inception
import matplotlib.pyplot as plt
import numpy as np
import os
import timedef test_data_process():# 数据预处理test_data = FashionMNIST(root='./data',train=False,download=True,transform=transforms.Compose([transforms.Resize(size=224), transforms.ToTensor()]))test_dataloader = Data.DataLoader(dataset=test_data,batch_size=1,shuffle=True, # 数据打乱num_workers=2)return test_dataloader#测试过程
def test_model_process(model,test_dataloader):device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#将模型放入设备中model = model.to(device)#初始化参数test_corrects = 0test_num = 0#只进行前向传播,不计算梯度,节省内存,提高速度with torch.no_grad():for test_data_x,test_data_y in test_dataloader:test_data_x = test_data_x.to(device)test_data_y = test_data_y.to(device)#设置模型为评估模式model.eval()#前向传播,输入为测试数据集,输出为每个样本的预测值test_output = model(test_data_x)#获取预测值中最大值对应的索引,即预测的类别pre_lab = torch.argmax(test_output, 1) #预测标签test_corrects += torch.sum(pre_lab == test_data_y.data) #预测正确的数量test_num = test_num + test_data_y.size(0) #测试样本数量#计算测试集的准确率test_acc = test_corrects.double().item() / test_num#打印测试集的准确率print('Test Accuracy: {:.4f}'.format(test_acc))if __name__ == '__main__':model = GoogLeNet(Inception)#加载训练好的模型model.load_state_dict(torch.load('best_model.pth'))test_dataloader = test_data_process() #加载测试数据# test_model_process(model,test_dataloader) #测试模型#模型推理device = 'cuda' if torch.cuda.is_available() else 'cpu'model = model.to(device)#将模型设置为评估模式with torch.no_grad():for b_x,b_y in test_dataloader:b_x = b_x.to(device)b_y = b_y.to(device)#模型设为验证模型model.eval()output = model(b_x)#获取预测值中最大值对应的索引,即预测的类别pre_lab = torch.argmax(output, dim = 1) #预测标签result = pre_lab.item()label = b_y.item()#打印预测结果print('预测结果:',result,"--------","真实值:",label)
相关文章:
GooLeNet模型搭建
一、model import torch from torch import nn from torchsummary import summaryclass Inception(nn.Module):def __init__(self, in_channels, c1, c2 , c3 , c4):super(Inception, self).__init__()self.ReLU nn.ReLU()#路线1:1x1卷积self.p1_1 nn.Conv2d(in_channels i…...
使用ThreadLocal来存取单线程内的数据
一.什么是ThreadLocal? ThreadLocal,即线程本地变量。如果你创建了一个 ThreadLocal变量,那么访问这个变量的每个线程都会有这个变量的一个本地拷贝,多个线程操作这个变量的时候,实际是在操作自己本地内存里面的变量&…...
elasticsearch教程
1. 单点部署(rpm): #提前关闭firewalld,否则无法组建集群 #1. 下载ES rpm包 ]# https://www.elastic.co/cn/downloads #2. 安装es ]# rpm -ivh elasticsearch-7.17.5-x86_64.rpm #3. 调整内核参数(太低的话es会启动报错) echo "vm.max_map_count655360 fs.file-max 655…...
Arrays、Lambda表达式、Collection集合
1. Arrays 1.1 操作数组的工具类 方法名说明public static String toString(数组)把数组拼接成一个字符串public static int binarySearch(数组,查找的元素)二分查找法查找元素public static int[] copyOf(原数组,新数组长度)拷贝数组public static int[] copyOfRange(原数组…...
2024年前端趋势:全栈或许是不容错过的选择!
近年来,前端开发的技术不断推陈出新,2024年也不例外。在这个变化迅速的领域,全栈开发逐渐成为一股不容忽视的趋势。无论你是经验丰富的开发者,还是刚刚入门的新手,掌握全栈技术都能让你在竞争中脱颖而出。而在这个过程…...
MySQL 实战 45 讲(01-05)
本文为笔者学习林晓斌老师《MySQL 实战 45 讲》课程的学习笔记,并进行了一定的知识扩充。 sql 查询语句的执行流程 大体来说,MySQL 可以分为 Server 层和存储引擎层两部分。 Server 层包括连接器、查询缓存、分析器、优化器和执行器。 连接器负责接收客…...
仓颉编程语言入门 -- Array数组详解
仓颉编程语言入门 – Array数组详解 一. 如何创建Array数组 我们可以使用 Array 类型来构造单一元素类型,有序序列的数据。 1.仓颉使用 Array 来表示 Array 类型。T 表示 Array 的元素类型,T 可以是任意类型 , 类似于泛型的概念 var arr:Array<St…...
C#初级——简单单例模式使用
单例模式 单例模式是一种常用的软件设计模式,它确保一个类只有一个实例,并提供一个全局访问点来获取这个实例,通过单例模式防止私有成员被多次引用,防止数据被随意纂改。本文使用的是线程不安全的懒汉式单例。 创建单例模式 首…...
2024.07.29 校招 实习 内推 面经
地/球🌍 : neituijunsir 交* 流*裙 ,内推/实习/校招汇总表格 1、校招 | 美/团// 快驴、小象、优/选/事/业/部2024年校/园/招聘(内推) 校招 | 美团快驴、小象、优选事业部2024年校园招聘(内推ÿ…...
速盾:爬虫攻击和cc攻击的区别是什么?
爬虫攻击和CC(Distributed Denial of Service)攻击是网络安全领域两种不同类型的攻击方式。尽管它们都涉及对目标网站或服务器的非法访问,但它们的目的、方法和影响各不相同。在接下来的文章中,我们将详细介绍这两种攻击方式的区别…...
Tomcat与Nginx的区别详解
目录 引言Tomcat概述 Tomcat的历史Tomcat的架构Tomcat的功能Nginx概述 Nginx的历史Nginx的架构Nginx的功能Tomcat与Nginx的区别 架构上的区别...
【大模型从入门到精通5】openAI API高级内容审核-1
这里写目录标题 高级内容审核利用 OpenAI 内容审核 API 的高级内容审核技术整合与实施使用自定义规则增强审核综合示例防止提示注入的策略使用分隔符隔离命令理解分隔符使用分隔符实现命令隔离 高级内容审核 利用 OpenAI 内容审核 API 的高级内容审核技术 OpenAI 内容审核 AP…...
JVM系列 | 对象的消亡3——垃圾收集器的对比与实现细节
垃圾收集器 文章目录 各收集器简单对比收集器启动参数各收集器详细说明JDK 1.3 之前JDK 1.3 | SerialJDK 1.4 | ParNewJDK 1.4 | Parallel ScavengeJDK 5 | CMS 收集器JDK 7 | G1 各收集器简单对比 收集器名称出现时间淘汰时间目标采用技术线程数STW分代备注无名JDK 1.3之前JD…...
C# Unity 面向对象补全计划 七大原则 之 开闭原则(OCP) 难度:☆ 总结:已经写好的就别动它了,多用继承
本文仅作学习笔记与交流,不作任何商业用途,作者能力有限,如有不足还请斧正 本系列作为七大原则和设计模式的进阶知识,看不懂没关系 请看专栏:http://t.csdnimg.cn/mIitr,查漏补缺 1.开闭原则(OC…...
微信防封指南请收好
一、新号与老号的添加限制 建议新注册的微信号主动添加好友的数量不宜过多,推荐每日添加不超过5个好友;对于老号,建议每日添加不超过20个好友。保持适度的添加速度,避免被系统判定为异常操作。 二、避免使用营销性词汇 在发送消…...
选择排序算法改进思路和算法实现
选择排序 在未排序的数组中,用第一个数去和后面的数比较,找出最小的数,和第一个数交换。第一个数已为已排序的数。 相当于0~7 从0~7中找到最小的数放在0 从1~7中找到最小的数放在1 从2~7中找到最小的数放在2 ...以此类推 从6~7中找到最…...
【文件解析漏洞复现】
一.IIS解析漏洞复现 1.IIS6.X 方式一:目录解析 搭建IIS环境 在网站下建立文件夹的名字为.asp/.asa 的文件夹,其目录内的任何扩展名的文件都被IIS当作asp文件来解析并执行。 访问成功被解析 方式一:目录解析 在IIS 6处理文件解…...
【STL】 vector的底层实现
1.vector的模拟代码完整实现(后面会拆分开一个一个细讲) #pragma once #include<assert.h>// 抓重点namespace bit {/*template<class T>class vector{public:typedef T* iterator;private:T* _a;size_t _size;size_t _capacity;};*/templa…...
责任链模式:解耦职责,优化请求处理
在软件设计中,如何有效地处理复杂的请求是一个重要的课题。 责任链模式(Chain of Responsibility Pattern)提供了一种解耦请求发送者和接收者的方法,使得多个对象都有机会处理请求,从而达到灵活和可扩展的设计。 什么…...
【Scene Transformer】scene transformer论文阅读笔记
文章目录 序言(Abstract)(Introduction)(Related Work)(Methods)(Scene-centric Representation for Agents and Road Graphs)(Encoding Transformer)(Predicting Probabilities for Each Futures)(Joint and Marginal Loss Formulation) (Results)(Discussion)(Questions) sce…...
利用最小二乘法找圆心和半径
#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...
华为云AI开发平台ModelArts
华为云ModelArts:重塑AI开发流程的“智能引擎”与“创新加速器”! 在人工智能浪潮席卷全球的2025年,企业拥抱AI的意愿空前高涨,但技术门槛高、流程复杂、资源投入巨大的现实,却让许多创新构想止步于实验室。数据科学家…...
云原生核心技术 (7/12): K8s 核心概念白话解读(上):Pod 和 Deployment 究竟是什么?
大家好,欢迎来到《云原生核心技术》系列的第七篇! 在上一篇,我们成功地使用 Minikube 或 kind 在自己的电脑上搭建起了一个迷你但功能完备的 Kubernetes 集群。现在,我们就像一个拥有了一块崭新数字土地的农场主,是时…...
7.4.分块查找
一.分块查找的算法思想: 1.实例: 以上述图片的顺序表为例, 该顺序表的数据元素从整体来看是乱序的,但如果把这些数据元素分成一块一块的小区间, 第一个区间[0,1]索引上的数据元素都是小于等于10的, 第二…...
Unity3D中Gfx.WaitForPresent优化方案
前言 在Unity中,Gfx.WaitForPresent占用CPU过高通常表示主线程在等待GPU完成渲染(即CPU被阻塞),这表明存在GPU瓶颈或垂直同步/帧率设置问题。以下是系统的优化方案: 对惹,这里有一个游戏开发交流小组&…...
汽车生产虚拟实训中的技能提升与生产优化
在制造业蓬勃发展的大背景下,虚拟教学实训宛如一颗璀璨的新星,正发挥着不可或缺且日益凸显的关键作用,源源不断地为企业的稳健前行与创新发展注入磅礴强大的动力。就以汽车制造企业这一极具代表性的行业主体为例,汽车生产线上各类…...
Java数值运算常见陷阱与规避方法
整数除法中的舍入问题 问题现象 当开发者预期进行浮点除法却误用整数除法时,会出现小数部分被截断的情况。典型错误模式如下: void process(int value) {double half = value / 2; // 整数除法导致截断// 使用half变量 }此时...
给网站添加live2d看板娘
给网站添加live2d看板娘 参考文献: stevenjoezhang/live2d-widget: 把萌萌哒的看板娘抱回家 (ノ≧∇≦)ノ | Live2D widget for web platformEikanya/Live2d-model: Live2d model collectionzenghongtu/live2d-model-assets 前言 网站环境如下,文章也主…...
适应性Java用于现代 API:REST、GraphQL 和事件驱动
在快速发展的软件开发领域,REST、GraphQL 和事件驱动架构等新的 API 标准对于构建可扩展、高效的系统至关重要。Java 在现代 API 方面以其在企业应用中的稳定性而闻名,不断适应这些现代范式的需求。随着不断发展的生态系统,Java 在现代 API 方…...
DiscuzX3.5发帖json api
参考文章:PHP实现独立Discuz站外发帖(直连操作数据库)_discuz 发帖api-CSDN博客 简单改造了一下,适配我自己的需求 有一个站点存在多个采集站,我想通过主站拿标题,采集站拿内容 使用到的sql如下 CREATE TABLE pre_forum_post_…...
