当前位置: 首页 > news >正文

深度学习中简易FC和CNN搭建

  • TensorFlow是由谷歌开发的
  • PyTorch是由Facebook人工智能研究院(Facebook AI Research)开发的

Torch和cuda版本的对应,手动安装较好

全连接FC(Batch*Num)

搭建建议网络:

from torch import nnclass Mnist_NN(nn.Module):def __init__(self):super().__init__()self.hidden1 = nn.Linear(784, 128)self.hidden2 = nn.Linear(128, 256)self.out  = nn.Linear(256, 10)def forward(self, x):x = F.relu(self.hidden1(x))x = F.relu(self.hidden2(x))x = self.out(x)return x

封装数据

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoadertrain_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)def get_data(train_ds, valid_ds, bs):return (DataLoader(train_ds, batch_size=bs, shuffle=True),DataLoader(valid_ds, batch_size=bs * 2),)

训练模型:

import numpy as npdef fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))

一般在训练模型时加上model.train(),这样会正常使用Batch NormalizationDropout;测试的时候一般选择model.eval(),这样就不会使用Batch NormalizationDropout

批量损失函数

from torch import optim
def get_model():model = Mnist_NN()return model, optim.SGD(model.parameters(), lr=0.001)def loss_batch(model, loss_func, xb, yb, opt=None):loss = loss_func(model(xb), yb)if opt is not None:loss.backward()opt.step()opt.zero_grad()return loss.item(), len(xb)

优化器
SGD是一种简单且易于实现的优化算法,但在大规模数据集和复杂模型上收敛缓慢。
Adam是一种自适应学习率调整的优化算法,能够更快地收敛,但可能会占用更多的内存。
在实践中,根据具体问题和数据集的特点,选择适合的优化算法可以提高训练效果。

卷积神经网络CNN(Batch * C * H * W)

Channel First
引入py库

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms 
import matplotlib.pyplot as plt
import numpy as np

预处理

# 定义超参数 
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片# 训练集
train_dataset = datasets.MNIST(root='./data',  train=True,   transform=transforms.ToTensor(),  download=True) # 测试集
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

构建CNN

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(         # 输入大小 (1, 28, 28)nn.Conv2d(in_channels=1,              # 灰度图out_channels=16,            # 要得到几多少个特征图kernel_size=5,              # 卷积核大小stride=1,                   # 步长padding=2,                  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1),                              # 输出的特征图为 (16, 28, 28)nn.ReLU(),                      # relu层nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14))self.conv2 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2),     # 输出 (32, 14, 14)nn.ReLU(),                      # relu层nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),                # 输出 (32, 7, 7))self.conv3 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(32, 64, 5, 1, 2),     # 输出 (32, 14, 14)nn.ReLU(),             # 输出 (32, 7, 7))self.out = nn.Linear(64 * 7 * 7, 10)   # 全连接层得到的结果def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)           # flatten操作,结果为:(batch_size, 32 * 7 * 7)output = self.out(x)return output

定义准确率

def accuracy(predictions, labels):pred = torch.max(predictions.data, 1)[1] rights = pred.eq(labels.data.view_as(pred)).sum() return rights, len(labels) 

训练网络模型

# 实例化
net = CNN() 
#损失函数
criterion = nn.CrossEntropyLoss() 
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法#开始训练循环
for epoch in range(num_epochs):#当前epoch的结果保存下来train_rights = [] for batch_idx, (data, target) in enumerate(train_loader):  #针对容器中的每一个批进行循环net.train()                             output = net(data) loss = criterion(output, target) optimizer.zero_grad() loss.backward() optimizer.step() right = accuracy(output, target) train_rights.append(right) if batch_idx % 100 == 0: net.eval() val_rights = [] for (data, target) in test_loader:output = net(data) right = accuracy(output, target) val_rights.append(right)#准确率计算train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.data, 100. * train_r[0].numpy() / train_r[1], 100. * val_r[0].numpy() / val_r[1]))

相关文章:

深度学习中简易FC和CNN搭建

TensorFlow是由谷歌开发的PyTorch是由Facebook人工智能研究院(Facebook AI Research)开发的 Torch和cuda版本的对应,手动安装较好 全连接FC(Batch*Num) 搭建建议网络: from torch import nnclass Mnist_NN(nn.Module):def __i…...

【多模态】20、OVR-CNN | 使用 caption 来实现开放词汇目标检测

文章目录 一、背景二、方法2.1 学习 视觉-语义 空间2.2 学习开放词汇目标检测 三、效果 论文:Open-Vocabulary Object Detection Using Captions 代码:https://github.com/alirezazareian/ovr-cnn 出处:CVPR2021 Oral 一、背景 目标检测数…...

网络编程 IO多路复用 [select版] (TCP网络聊天室)

//head.h 头文件 //TcpGrpSer.c 服务器端 //TcpGrpUsr.c 客户端 select函数 功能&#xff1a;阻塞函数&#xff0c;让内核去监测集合中的文件描述符是否准备就绪&#xff0c;若准备就绪则解除阻塞。 原型&#xff1a; #include <sys/select.…...

数学建模学习(7):单目标和多目标规划

优化问题描述 优化 优化算法是指在满足一定条件下,在众多方案中或者参数中最优方案,或者参数值,以使得某个或者多个功能指标达到最优,或使得系统的某些性能指标达到最大值或者最小值 线性规划 线性规划是指目标函数和约束都是线性的情况 [x,fval]linprog(f,A,b,Aeq,Beq,LB,U…...

Element UI如何自定义样式

简介 Element UI是一套非常完善的前端组件库&#xff0c;但是如何个性化定制其中的组件样式呢&#xff1f;今天我们就来聊一聊这个 举例 就拿最常见的按钮el-button来举例&#xff0c;一般来说默认是蓝底白字。效果图如下 可是我们想个性化定制&#xff0c;让他成为粉底红字应…...

protobuf入门实践2

如何在proto中定义一个rpc服务? syntax "proto3"; //声明protobuf的版本package fixbug; //声明了代码所在的包 &#xff08;对于C来说就是namespace)//下面的选项&#xff0c;表示生成service服务类和rpc方法描述&#xff0c; 默认是不生成的 option cc_generi…...

adb shell使用总结

文章目录 日志记录系统概览adb 使用方式 adb命令日志过滤按照告警等级进行过滤按照tag进行过滤根据告警等级和tag进行联合过滤屏蔽系统和其他App干扰&#xff0c;仅仅关注App自身日志 查看“当前页面”Activity文件传输截屏和录屏安装、卸载App启动activity其他 日志记录系统概…...

UG NX二次开发(C++)-Tag的含义、Tag类型与其他的转换

文章目录 1、前言2、Tag号的含义3、tag_t转换为int3、TaggedObject与Tag转换3.1 TaggedObject定义3.2 TaggedObject获取Tag3.3 根据Tag获取TaggedObject4.Tag与double类型的转换1、前言 在UG NX中,每个对象对应一个tag号,C++中,其类型是tag_t,一般是5位或者6位的int数字,…...

Informer 论文学习笔记

论文&#xff1a;《Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting》 代码&#xff1a;https://github.com/zhouhaoyi/Informer2020 地址&#xff1a;https://arxiv.org/abs/2012.07436v3 特点&#xff1a; 实现时间与空间复杂度为 O ( …...

c语言位段知识详解

本篇文章带来位段相关知识详细讲解&#xff01; 如果您觉得文章不错&#xff0c;期待你的一键三连哦&#xff0c;你的鼓励是我创作的动力之源&#xff0c;让我们一起加油&#xff0c;一起奔跑&#xff0c;让我们顶峰相见&#xff01;&#xff01;&#xff01; 目录 一.什么是…...

FFmpeg aresample_swr_opts的解析

ffmpeg option的解析 aresample_swr_opts是AVFilterGraph中的option。 static const AVOption filtergraph_options[] {{ "thread_type", "Allowed thread types", OFFSET(thread_type), AV_OPT_TYPE_FLAGS,{ .i64 AVFILTER_THREAD_SLICE }, 0, INT_MA…...

CAN学习笔记3:STM32 CAN控制器介绍

STM32 CAN控制器 1 概述 STM32 CAN控制器&#xff08;bxCAN&#xff09;&#xff0c;支持CAN 2.0A 和 CAN 2.0B Active版本协议。CAN 2.0A 只能处理标准数据帧且扩展帧的内容会识别错误&#xff0c;而CAN 2.0B Active 可以处理标准数据帧和扩展数据帧。 2 bxCAN 特性 波特率…...

软工导论知识框架(二)结构化的需求分析

本章节涉及很多重要图表的制作&#xff0c;如ER图、数据流图、状态转换图、数据字典的书写等&#xff0c;对初学者来说比较生僻&#xff0c;本贴只介绍基础的轮廓&#xff0c;后面会有单独的帖子详解各图表如何绘制。 一.结构化的软件开发方法&#xff1a;结构化的分析、设计、…...

[SQL挖掘机] - 算术函数 - abs

介绍: 当谈到 SQL 中的 abs 函数时&#xff0c;它是一个用于计算数值的绝对值的函数。“abs” 代表 “absolute”&#xff08;绝对&#xff09;&#xff0c;因此 abs 函数的作用是返回一个给定数值的非负值&#xff08;即该数值的绝对值&#xff09;。 abs 函数接受一个参数&a…...

vue拼接html点击事件不生效

vue使用ts&#xff0c;拼接html&#xff0c;点击事件不生效或者报 is not defined 点击事件要用onclick 不是click let data{name:测,id:123} let conHtml <div> "名称&#xff1a;" data.name "<br>" <p class"cursor blue&quo…...

【Spring】Spring之依赖注入源码解析

1 Spring注入方式 1.1 手动注入 xml中定义Bean&#xff0c;程序员手动给某个属性赋值。 set方式注入 <bean name"userService" class"com.firechou.service.UserService"><property name"orderService" ref"orderService"…...

【微软知识】微软相关技术知识分享

微软技术领域 一、微软操作系统&#xff1a; 微软的操作系统主要是 Windows 系列&#xff0c;包括 Windows 10、Windows Server 等。了解 Windows 操作系统的基本使用、配置和故障排除是非常重要的。微软操作系统&#xff08;Microsoft System&#xff09;是美国微软开发的Wi…...

12.python设计模式【观察者模式】

内容&#xff1a;定义对象间的一种一对多的依赖关系&#xff0c;当一个对象的状态发生改变的时候&#xff0c;所有依赖于它的对象得到通知并被自动更新。观者者模式又称为“发布-订阅”模式。比如天气预报&#xff0c;气象局分发气象数据。 角色&#xff1a; 抽象主题&#xf…...

重生之我要学C++第五天

这篇文章主要内容是构造函数的初始化列表以及运算符重载在顺序表中的简单应用&#xff0c;运算符重载实现自定义类型的流插入流提取。希望对大家有所帮助&#xff0c;点赞收藏评论&#xff0c;支持一下吧&#xff01; 目录 构造函数进阶理解 1.内置类型成员在参数列表中的定义 …...

复习之linux高级存储管理

一、lvm----逻辑卷管理 1.lvm定义 LVM是 Logical Volume Manager&#xff08;逻辑卷管理&#xff09;的简写&#xff0c;它是Linux环境下对磁盘分区进行管理的一种机制。 逻辑卷管理器(LogicalVolumeManager)本质上是一个虚拟设备驱动&#xff0c;是在内核中块设备和物理设备…...

oracle与MySQL数据库之间数据同步的技术要点

Oracle与MySQL数据库之间的数据同步是一个涉及多个技术要点的复杂任务。由于Oracle和MySQL的架构差异&#xff0c;它们的数据同步要求既要保持数据的准确性和一致性&#xff0c;又要处理好性能问题。以下是一些主要的技术要点&#xff1a; 数据结构差异 数据类型差异&#xff…...

srs linux

下载编译运行 git clone https:///ossrs/srs.git ./configure --h265on make 编译完成后即可启动SRS # 启动 ./objs/srs -c conf/srs.conf # 查看日志 tail -n 30 -f ./objs/srs.log 开放端口 默认RTMP接收推流端口是1935&#xff0c;SRS管理页面端口是8080&#xff0c;可…...

图表类系列各种样式PPT模版分享

图标图表系列PPT模版&#xff0c;柱状图PPT模版&#xff0c;线状图PPT模版&#xff0c;折线图PPT模版&#xff0c;饼状图PPT模版&#xff0c;雷达图PPT模版&#xff0c;树状图PPT模版 图表类系列各种样式PPT模版分享&#xff1a;图表系列PPT模板https://pan.quark.cn/s/20d40aa…...

【Java学习笔记】BigInteger 和 BigDecimal 类

BigInteger 和 BigDecimal 类 二者共有的常见方法 方法功能add加subtract减multiply乘divide除 注意点&#xff1a;传参类型必须是类对象 一、BigInteger 1. 作用&#xff1a;适合保存比较大的整型数 2. 使用说明 创建BigInteger对象 传入字符串 3. 代码示例 import j…...

R语言速释制剂QBD解决方案之三

本文是《Quality by Design for ANDAs: An Example for Immediate-Release Dosage Forms》第一个处方的R语言解决方案。 第一个处方研究评估原料药粒径分布、MCC/Lactose比例、崩解剂用量对制剂CQAs的影响。 第二处方研究用于理解颗粒外加硬脂酸镁和滑石粉对片剂质量和可生产…...

安全突围:重塑内生安全体系:齐向东在2025年BCS大会的演讲

文章目录 前言第一部分&#xff1a;体系力量是突围之钥第一重困境是体系思想落地不畅。第二重困境是大小体系融合瓶颈。第三重困境是“小体系”运营梗阻。 第二部分&#xff1a;体系矛盾是突围之障一是数据孤岛的障碍。二是投入不足的障碍。三是新旧兼容难的障碍。 第三部分&am…...

elementUI点击浏览table所选行数据查看文档

项目场景&#xff1a; table按照要求特定的数据变成按钮可以点击 解决方案&#xff1a; <el-table-columnprop"mlname"label"名称"align"center"width"180"><template slot-scope"scope"><el-buttonv-if&qu…...

Axure 下拉框联动

实现选省、选完省之后选对应省份下的市区...

Python训练营-Day26-函数专题1:函数定义与参数

题目1&#xff1a;计算圆的面积 任务&#xff1a; 编写一个名为 calculate_circle_area 的函数&#xff0c;该函数接收圆的半径 radius 作为参数&#xff0c;并返回圆的面积。圆的面积 π * radius (可以使用 math.pi 作为 π 的值)要求&#xff1a;函数接收一个位置参数 radi…...

基于单片机的宠物屋智能系统设计与实现(论文+源码)

本设计基于单片机的宠物屋智能系统核心是实现对宠物生活环境及状态的智能管理。系统以单片机为中枢&#xff0c;连接红外测温传感器&#xff0c;可实时精准捕捉宠物体温变化&#xff0c;以便及时发现健康异常&#xff1b;水位检测传感器时刻监测饮用水余量&#xff0c;防止宠物…...