深度学习中简易FC和CNN搭建
- TensorFlow是由谷歌开发的
- PyTorch是由Facebook人工智能研究院(Facebook AI Research)开发的
Torch和cuda版本的对应,手动安装较好
全连接FC(Batch*Num)
搭建建议网络:
from torch import nnclass Mnist_NN(nn.Module):def __init__(self):super().__init__()self.hidden1 = nn.Linear(784, 128)self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256, 10)def forward(self, x):x = F.relu(self.hidden1(x))x = F.relu(self.hidden2(x))x = self.out(x)return x
封装数据
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoadertrain_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)def get_data(train_ds, valid_ds, bs):return (DataLoader(train_ds, batch_size=bs, shuffle=True),DataLoader(valid_ds, batch_size=bs * 2),)
训练模型:
import numpy as npdef fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))
一般在训练模型时加上model.train()
,这样会正常使用Batch Normalization
和 Dropout
;测试的时候一般选择model.eval()
,这样就不会使用Batch Normalization
和 Dropout
批量损失函数
from torch import optim
def get_model():model = Mnist_NN()return model, optim.SGD(model.parameters(), lr=0.001)def loss_batch(model, loss_func, xb, yb, opt=None):loss = loss_func(model(xb), yb)if opt is not None:loss.backward()opt.step()opt.zero_grad()return loss.item(), len(xb)
优化器
SGD是一种简单且易于实现的优化算法,但在大规模数据集和复杂模型上收敛缓慢。
Adam是一种自适应学习率调整的优化算法,能够更快地收敛,但可能会占用更多的内存。
在实践中,根据具体问题和数据集的特点,选择适合的优化算法可以提高训练效果。
卷积神经网络CNN(Batch * C * H * W)
Channel First
引入py库
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np
预处理
# 定义超参数
input_size = 28 #图像的总尺寸28*28
num_classes = 10 #标签的种类数
num_epochs = 3 #训练的总循环周期
batch_size = 64 #一个撮(批次)的大小,64张图片# 训练集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True) # 测试集
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)
构建CNN
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential( # 输入大小 (1, 28, 28)nn.Conv2d(in_channels=1, # 灰度图out_channels=16, # 要得到几多少个特征图kernel_size=5, # 卷积核大小stride=1, # 步长padding=2, # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1), # 输出的特征图为 (16, 28, 28)nn.ReLU(), # relu层nn.MaxPool2d(kernel_size=2), # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14))self.conv2 = nn.Sequential( # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2), # 输出 (32, 14, 14)nn.ReLU(), # relu层nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2), # 输出 (32, 7, 7))self.conv3 = nn.Sequential( # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(32, 64, 5, 1, 2), # 输出 (32, 14, 14)nn.ReLU(), # 输出 (32, 7, 7))self.out = nn.Linear(64 * 7 * 7, 10) # 全连接层得到的结果def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1) # flatten操作,结果为:(batch_size, 32 * 7 * 7)output = self.out(x)return output
定义准确率
def accuracy(predictions, labels):pred = torch.max(predictions.data, 1)[1] rights = pred.eq(labels.data.view_as(pred)).sum() return rights, len(labels)
训练网络模型
# 实例化
net = CNN()
#损失函数
criterion = nn.CrossEntropyLoss()
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法#开始训练循环
for epoch in range(num_epochs):#当前epoch的结果保存下来train_rights = [] for batch_idx, (data, target) in enumerate(train_loader): #针对容器中的每一个批进行循环net.train() output = net(data) loss = criterion(output, target) optimizer.zero_grad() loss.backward() optimizer.step() right = accuracy(output, target) train_rights.append(right) if batch_idx % 100 == 0: net.eval() val_rights = [] for (data, target) in test_loader:output = net(data) right = accuracy(output, target) val_rights.append(right)#准确率计算train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.data, 100. * train_r[0].numpy() / train_r[1], 100. * val_r[0].numpy() / val_r[1]))
相关文章:
深度学习中简易FC和CNN搭建
TensorFlow是由谷歌开发的PyTorch是由Facebook人工智能研究院(Facebook AI Research)开发的 Torch和cuda版本的对应,手动安装较好 全连接FC(Batch*Num) 搭建建议网络: from torch import nnclass Mnist_NN(nn.Module):def __i…...

【多模态】20、OVR-CNN | 使用 caption 来实现开放词汇目标检测
文章目录 一、背景二、方法2.1 学习 视觉-语义 空间2.2 学习开放词汇目标检测 三、效果 论文:Open-Vocabulary Object Detection Using Captions 代码:https://github.com/alirezazareian/ovr-cnn 出处:CVPR2021 Oral 一、背景 目标检测数…...

网络编程 IO多路复用 [select版] (TCP网络聊天室)
//head.h 头文件 //TcpGrpSer.c 服务器端 //TcpGrpUsr.c 客户端 select函数 功能:阻塞函数,让内核去监测集合中的文件描述符是否准备就绪,若准备就绪则解除阻塞。 原型: #include <sys/select.…...

数学建模学习(7):单目标和多目标规划
优化问题描述 优化 优化算法是指在满足一定条件下,在众多方案中或者参数中最优方案,或者参数值,以使得某个或者多个功能指标达到最优,或使得系统的某些性能指标达到最大值或者最小值 线性规划 线性规划是指目标函数和约束都是线性的情况 [x,fval]linprog(f,A,b,Aeq,Beq,LB,U…...

Element UI如何自定义样式
简介 Element UI是一套非常完善的前端组件库,但是如何个性化定制其中的组件样式呢?今天我们就来聊一聊这个 举例 就拿最常见的按钮el-button来举例,一般来说默认是蓝底白字。效果图如下 可是我们想个性化定制,让他成为粉底红字应…...
protobuf入门实践2
如何在proto中定义一个rpc服务? syntax "proto3"; //声明protobuf的版本package fixbug; //声明了代码所在的包 (对于C来说就是namespace)//下面的选项,表示生成service服务类和rpc方法描述, 默认是不生成的 option cc_generi…...
adb shell使用总结
文章目录 日志记录系统概览adb 使用方式 adb命令日志过滤按照告警等级进行过滤按照tag进行过滤根据告警等级和tag进行联合过滤屏蔽系统和其他App干扰,仅仅关注App自身日志 查看“当前页面”Activity文件传输截屏和录屏安装、卸载App启动activity其他 日志记录系统概…...
UG NX二次开发(C++)-Tag的含义、Tag类型与其他的转换
文章目录 1、前言2、Tag号的含义3、tag_t转换为int3、TaggedObject与Tag转换3.1 TaggedObject定义3.2 TaggedObject获取Tag3.3 根据Tag获取TaggedObject4.Tag与double类型的转换1、前言 在UG NX中,每个对象对应一个tag号,C++中,其类型是tag_t,一般是5位或者6位的int数字,…...

Informer 论文学习笔记
论文:《Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting》 代码:https://github.com/zhouhaoyi/Informer2020 地址:https://arxiv.org/abs/2012.07436v3 特点: 实现时间与空间复杂度为 O ( …...

c语言位段知识详解
本篇文章带来位段相关知识详细讲解! 如果您觉得文章不错,期待你的一键三连哦,你的鼓励是我创作的动力之源,让我们一起加油,一起奔跑,让我们顶峰相见!!! 目录 一.什么是…...

FFmpeg aresample_swr_opts的解析
ffmpeg option的解析 aresample_swr_opts是AVFilterGraph中的option。 static const AVOption filtergraph_options[] {{ "thread_type", "Allowed thread types", OFFSET(thread_type), AV_OPT_TYPE_FLAGS,{ .i64 AVFILTER_THREAD_SLICE }, 0, INT_MA…...

CAN学习笔记3:STM32 CAN控制器介绍
STM32 CAN控制器 1 概述 STM32 CAN控制器(bxCAN),支持CAN 2.0A 和 CAN 2.0B Active版本协议。CAN 2.0A 只能处理标准数据帧且扩展帧的内容会识别错误,而CAN 2.0B Active 可以处理标准数据帧和扩展数据帧。 2 bxCAN 特性 波特率…...

软工导论知识框架(二)结构化的需求分析
本章节涉及很多重要图表的制作,如ER图、数据流图、状态转换图、数据字典的书写等,对初学者来说比较生僻,本贴只介绍基础的轮廓,后面会有单独的帖子详解各图表如何绘制。 一.结构化的软件开发方法:结构化的分析、设计、…...
[SQL挖掘机] - 算术函数 - abs
介绍: 当谈到 SQL 中的 abs 函数时,它是一个用于计算数值的绝对值的函数。“abs” 代表 “absolute”(绝对),因此 abs 函数的作用是返回一个给定数值的非负值(即该数值的绝对值)。 abs 函数接受一个参数&a…...
vue拼接html点击事件不生效
vue使用ts,拼接html,点击事件不生效或者报 is not defined 点击事件要用onclick 不是click let data{name:测,id:123} let conHtml <div> "名称:" data.name "<br>" <p class"cursor blue&quo…...
【Spring】Spring之依赖注入源码解析
1 Spring注入方式 1.1 手动注入 xml中定义Bean,程序员手动给某个属性赋值。 set方式注入 <bean name"userService" class"com.firechou.service.UserService"><property name"orderService" ref"orderService"…...

【微软知识】微软相关技术知识分享
微软技术领域 一、微软操作系统: 微软的操作系统主要是 Windows 系列,包括 Windows 10、Windows Server 等。了解 Windows 操作系统的基本使用、配置和故障排除是非常重要的。微软操作系统(Microsoft System)是美国微软开发的Wi…...
12.python设计模式【观察者模式】
内容:定义对象间的一种一对多的依赖关系,当一个对象的状态发生改变的时候,所有依赖于它的对象得到通知并被自动更新。观者者模式又称为“发布-订阅”模式。比如天气预报,气象局分发气象数据。 角色: 抽象主题…...

重生之我要学C++第五天
这篇文章主要内容是构造函数的初始化列表以及运算符重载在顺序表中的简单应用,运算符重载实现自定义类型的流插入流提取。希望对大家有所帮助,点赞收藏评论,支持一下吧! 目录 构造函数进阶理解 1.内置类型成员在参数列表中的定义 …...

复习之linux高级存储管理
一、lvm----逻辑卷管理 1.lvm定义 LVM是 Logical Volume Manager(逻辑卷管理)的简写,它是Linux环境下对磁盘分区进行管理的一种机制。 逻辑卷管理器(LogicalVolumeManager)本质上是一个虚拟设备驱动,是在内核中块设备和物理设备…...
Qt Widget类解析与代码注释
#include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this); }Widget::~Widget() {delete ui; }//解释这串代码,写上注释 当然可以!这段代码是 Qt …...

【第二十一章 SDIO接口(SDIO)】
第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...

均衡后的SNRSINR
本文主要摘自参考文献中的前两篇,相关文献中经常会出现MIMO检测后的SINR不过一直没有找到相关数学推到过程,其中文献[1]中给出了相关原理在此仅做记录。 1. 系统模型 复信道模型 n t n_t nt 根发送天线, n r n_r nr 根接收天线的 MIMO 系…...
Typeerror: cannot read properties of undefined (reading ‘XXX‘)
最近需要在离线机器上运行软件,所以得把软件用docker打包起来,大部分功能都没问题,出了一个奇怪的事情。同样的代码,在本机上用vscode可以运行起来,但是打包之后在docker里出现了问题。使用的是dialog组件,…...

[免费]微信小程序问卷调查系统(SpringBoot后端+Vue管理端)【论文+源码+SQL脚本】
大家好,我是java1234_小锋老师,看到一个不错的微信小程序问卷调查系统(SpringBoot后端Vue管理端)【论文源码SQL脚本】,分享下哈。 项目视频演示 【免费】微信小程序问卷调查系统(SpringBoot后端Vue管理端) Java毕业设计_哔哩哔哩_bilibili 项…...
Python Einops库:深度学习中的张量操作革命
Einops(爱因斯坦操作库)就像给张量操作戴上了一副"语义眼镜"——让你用人类能理解的方式告诉计算机如何操作多维数组。这个基于爱因斯坦求和约定的库,用类似自然语言的表达式替代了晦涩的API调用,彻底改变了深度学习工程…...
Caliper 负载(Workload)详细解析
Caliper 负载(Workload)详细解析 负载(Workload)是 Caliper 性能测试的核心部分,它定义了测试期间要执行的具体合约调用行为和交易模式。下面我将全面深入地讲解负载的各个方面。 一、负载模块基本结构 一个典型的负载模块(如 workload.js)包含以下基本结构: use strict;/…...
MySQL 索引底层结构揭秘:B-Tree 与 B+Tree 的区别与应用
文章目录 一、背景知识:什么是 B-Tree 和 BTree? B-Tree(平衡多路查找树) BTree(B-Tree 的变种) 二、结构对比:一张图看懂 三、为什么 MySQL InnoDB 选择 BTree? 1. 范围查询更快 2…...
C语言中提供的第三方库之哈希表实现
一. 简介 前面一篇文章简单学习了C语言中第三方库(uthash库)提供对哈希表的操作,文章如下: C语言中提供的第三方库uthash常用接口-CSDN博客 本文简单学习一下第三方库 uthash库对哈希表的操作。 二. uthash库哈希表操作示例 u…...

Sklearn 机器学习 缺失值处理 获取填充失值的统计值
💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 使用 Scikit-learn 处理缺失值并提取填充统计信息的完整指南 在机器学习项目中,数据清…...