当前位置: 首页 > news >正文

深度学习中简易FC和CNN搭建

  • TensorFlow是由谷歌开发的
  • PyTorch是由Facebook人工智能研究院(Facebook AI Research)开发的

Torch和cuda版本的对应,手动安装较好

全连接FC(Batch*Num)

搭建建议网络:

from torch import nnclass Mnist_NN(nn.Module):def __init__(self):super().__init__()self.hidden1 = nn.Linear(784, 128)self.hidden2 = nn.Linear(128, 256)self.out  = nn.Linear(256, 10)def forward(self, x):x = F.relu(self.hidden1(x))x = F.relu(self.hidden2(x))x = self.out(x)return x

封装数据

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoadertrain_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)def get_data(train_ds, valid_ds, bs):return (DataLoader(train_ds, batch_size=bs, shuffle=True),DataLoader(valid_ds, batch_size=bs * 2),)

训练模型:

import numpy as npdef fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))

一般在训练模型时加上model.train(),这样会正常使用Batch NormalizationDropout;测试的时候一般选择model.eval(),这样就不会使用Batch NormalizationDropout

批量损失函数

from torch import optim
def get_model():model = Mnist_NN()return model, optim.SGD(model.parameters(), lr=0.001)def loss_batch(model, loss_func, xb, yb, opt=None):loss = loss_func(model(xb), yb)if opt is not None:loss.backward()opt.step()opt.zero_grad()return loss.item(), len(xb)

优化器
SGD是一种简单且易于实现的优化算法,但在大规模数据集和复杂模型上收敛缓慢。
Adam是一种自适应学习率调整的优化算法,能够更快地收敛,但可能会占用更多的内存。
在实践中,根据具体问题和数据集的特点,选择适合的优化算法可以提高训练效果。

卷积神经网络CNN(Batch * C * H * W)

Channel First
引入py库

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms 
import matplotlib.pyplot as plt
import numpy as np

预处理

# 定义超参数 
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片# 训练集
train_dataset = datasets.MNIST(root='./data',  train=True,   transform=transforms.ToTensor(),  download=True) # 测试集
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

构建CNN

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(         # 输入大小 (1, 28, 28)nn.Conv2d(in_channels=1,              # 灰度图out_channels=16,            # 要得到几多少个特征图kernel_size=5,              # 卷积核大小stride=1,                   # 步长padding=2,                  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1),                              # 输出的特征图为 (16, 28, 28)nn.ReLU(),                      # relu层nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14))self.conv2 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2),     # 输出 (32, 14, 14)nn.ReLU(),                      # relu层nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),                # 输出 (32, 7, 7))self.conv3 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(32, 64, 5, 1, 2),     # 输出 (32, 14, 14)nn.ReLU(),             # 输出 (32, 7, 7))self.out = nn.Linear(64 * 7 * 7, 10)   # 全连接层得到的结果def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)           # flatten操作,结果为:(batch_size, 32 * 7 * 7)output = self.out(x)return output

定义准确率

def accuracy(predictions, labels):pred = torch.max(predictions.data, 1)[1] rights = pred.eq(labels.data.view_as(pred)).sum() return rights, len(labels) 

训练网络模型

# 实例化
net = CNN() 
#损失函数
criterion = nn.CrossEntropyLoss() 
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法#开始训练循环
for epoch in range(num_epochs):#当前epoch的结果保存下来train_rights = [] for batch_idx, (data, target) in enumerate(train_loader):  #针对容器中的每一个批进行循环net.train()                             output = net(data) loss = criterion(output, target) optimizer.zero_grad() loss.backward() optimizer.step() right = accuracy(output, target) train_rights.append(right) if batch_idx % 100 == 0: net.eval() val_rights = [] for (data, target) in test_loader:output = net(data) right = accuracy(output, target) val_rights.append(right)#准确率计算train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.data, 100. * train_r[0].numpy() / train_r[1], 100. * val_r[0].numpy() / val_r[1]))

相关文章:

深度学习中简易FC和CNN搭建

TensorFlow是由谷歌开发的PyTorch是由Facebook人工智能研究院(Facebook AI Research)开发的 Torch和cuda版本的对应,手动安装较好 全连接FC(Batch*Num) 搭建建议网络: from torch import nnclass Mnist_NN(nn.Module):def __i…...

【多模态】20、OVR-CNN | 使用 caption 来实现开放词汇目标检测

文章目录 一、背景二、方法2.1 学习 视觉-语义 空间2.2 学习开放词汇目标检测 三、效果 论文:Open-Vocabulary Object Detection Using Captions 代码:https://github.com/alirezazareian/ovr-cnn 出处:CVPR2021 Oral 一、背景 目标检测数…...

网络编程 IO多路复用 [select版] (TCP网络聊天室)

//head.h 头文件 //TcpGrpSer.c 服务器端 //TcpGrpUsr.c 客户端 select函数 功能&#xff1a;阻塞函数&#xff0c;让内核去监测集合中的文件描述符是否准备就绪&#xff0c;若准备就绪则解除阻塞。 原型&#xff1a; #include <sys/select.…...

数学建模学习(7):单目标和多目标规划

优化问题描述 优化 优化算法是指在满足一定条件下,在众多方案中或者参数中最优方案,或者参数值,以使得某个或者多个功能指标达到最优,或使得系统的某些性能指标达到最大值或者最小值 线性规划 线性规划是指目标函数和约束都是线性的情况 [x,fval]linprog(f,A,b,Aeq,Beq,LB,U…...

Element UI如何自定义样式

简介 Element UI是一套非常完善的前端组件库&#xff0c;但是如何个性化定制其中的组件样式呢&#xff1f;今天我们就来聊一聊这个 举例 就拿最常见的按钮el-button来举例&#xff0c;一般来说默认是蓝底白字。效果图如下 可是我们想个性化定制&#xff0c;让他成为粉底红字应…...

protobuf入门实践2

如何在proto中定义一个rpc服务? syntax "proto3"; //声明protobuf的版本package fixbug; //声明了代码所在的包 &#xff08;对于C来说就是namespace)//下面的选项&#xff0c;表示生成service服务类和rpc方法描述&#xff0c; 默认是不生成的 option cc_generi…...

adb shell使用总结

文章目录 日志记录系统概览adb 使用方式 adb命令日志过滤按照告警等级进行过滤按照tag进行过滤根据告警等级和tag进行联合过滤屏蔽系统和其他App干扰&#xff0c;仅仅关注App自身日志 查看“当前页面”Activity文件传输截屏和录屏安装、卸载App启动activity其他 日志记录系统概…...

UG NX二次开发(C++)-Tag的含义、Tag类型与其他的转换

文章目录 1、前言2、Tag号的含义3、tag_t转换为int3、TaggedObject与Tag转换3.1 TaggedObject定义3.2 TaggedObject获取Tag3.3 根据Tag获取TaggedObject4.Tag与double类型的转换1、前言 在UG NX中,每个对象对应一个tag号,C++中,其类型是tag_t,一般是5位或者6位的int数字,…...

Informer 论文学习笔记

论文&#xff1a;《Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting》 代码&#xff1a;https://github.com/zhouhaoyi/Informer2020 地址&#xff1a;https://arxiv.org/abs/2012.07436v3 特点&#xff1a; 实现时间与空间复杂度为 O ( …...

c语言位段知识详解

本篇文章带来位段相关知识详细讲解&#xff01; 如果您觉得文章不错&#xff0c;期待你的一键三连哦&#xff0c;你的鼓励是我创作的动力之源&#xff0c;让我们一起加油&#xff0c;一起奔跑&#xff0c;让我们顶峰相见&#xff01;&#xff01;&#xff01; 目录 一.什么是…...

FFmpeg aresample_swr_opts的解析

ffmpeg option的解析 aresample_swr_opts是AVFilterGraph中的option。 static const AVOption filtergraph_options[] {{ "thread_type", "Allowed thread types", OFFSET(thread_type), AV_OPT_TYPE_FLAGS,{ .i64 AVFILTER_THREAD_SLICE }, 0, INT_MA…...

CAN学习笔记3:STM32 CAN控制器介绍

STM32 CAN控制器 1 概述 STM32 CAN控制器&#xff08;bxCAN&#xff09;&#xff0c;支持CAN 2.0A 和 CAN 2.0B Active版本协议。CAN 2.0A 只能处理标准数据帧且扩展帧的内容会识别错误&#xff0c;而CAN 2.0B Active 可以处理标准数据帧和扩展数据帧。 2 bxCAN 特性 波特率…...

软工导论知识框架(二)结构化的需求分析

本章节涉及很多重要图表的制作&#xff0c;如ER图、数据流图、状态转换图、数据字典的书写等&#xff0c;对初学者来说比较生僻&#xff0c;本贴只介绍基础的轮廓&#xff0c;后面会有单独的帖子详解各图表如何绘制。 一.结构化的软件开发方法&#xff1a;结构化的分析、设计、…...

[SQL挖掘机] - 算术函数 - abs

介绍: 当谈到 SQL 中的 abs 函数时&#xff0c;它是一个用于计算数值的绝对值的函数。“abs” 代表 “absolute”&#xff08;绝对&#xff09;&#xff0c;因此 abs 函数的作用是返回一个给定数值的非负值&#xff08;即该数值的绝对值&#xff09;。 abs 函数接受一个参数&a…...

vue拼接html点击事件不生效

vue使用ts&#xff0c;拼接html&#xff0c;点击事件不生效或者报 is not defined 点击事件要用onclick 不是click let data{name:测,id:123} let conHtml <div> "名称&#xff1a;" data.name "<br>" <p class"cursor blue&quo…...

【Spring】Spring之依赖注入源码解析

1 Spring注入方式 1.1 手动注入 xml中定义Bean&#xff0c;程序员手动给某个属性赋值。 set方式注入 <bean name"userService" class"com.firechou.service.UserService"><property name"orderService" ref"orderService"…...

【微软知识】微软相关技术知识分享

微软技术领域 一、微软操作系统&#xff1a; 微软的操作系统主要是 Windows 系列&#xff0c;包括 Windows 10、Windows Server 等。了解 Windows 操作系统的基本使用、配置和故障排除是非常重要的。微软操作系统&#xff08;Microsoft System&#xff09;是美国微软开发的Wi…...

12.python设计模式【观察者模式】

内容&#xff1a;定义对象间的一种一对多的依赖关系&#xff0c;当一个对象的状态发生改变的时候&#xff0c;所有依赖于它的对象得到通知并被自动更新。观者者模式又称为“发布-订阅”模式。比如天气预报&#xff0c;气象局分发气象数据。 角色&#xff1a; 抽象主题&#xf…...

重生之我要学C++第五天

这篇文章主要内容是构造函数的初始化列表以及运算符重载在顺序表中的简单应用&#xff0c;运算符重载实现自定义类型的流插入流提取。希望对大家有所帮助&#xff0c;点赞收藏评论&#xff0c;支持一下吧&#xff01; 目录 构造函数进阶理解 1.内置类型成员在参数列表中的定义 …...

复习之linux高级存储管理

一、lvm----逻辑卷管理 1.lvm定义 LVM是 Logical Volume Manager&#xff08;逻辑卷管理&#xff09;的简写&#xff0c;它是Linux环境下对磁盘分区进行管理的一种机制。 逻辑卷管理器(LogicalVolumeManager)本质上是一个虚拟设备驱动&#xff0c;是在内核中块设备和物理设备…...

IDEA运行Tomcat出现乱码问题解决汇总

最近正值期末周&#xff0c;有很多同学在写期末Java web作业时&#xff0c;运行tomcat出现乱码问题&#xff0c;经过多次解决与研究&#xff0c;我做了如下整理&#xff1a; 原因&#xff1a; IDEA本身编码与tomcat的编码与Windows编码不同导致&#xff0c;Windows 系统控制台…...

JavaSec-RCE

简介 RCE(Remote Code Execution)&#xff0c;可以分为:命令注入(Command Injection)、代码注入(Code Injection) 代码注入 1.漏洞场景&#xff1a;Groovy代码注入 Groovy是一种基于JVM的动态语言&#xff0c;语法简洁&#xff0c;支持闭包、动态类型和Java互操作性&#xff0c…...

[2025CVPR]DeepVideo-R1:基于难度感知回归GRPO的视频强化微调框架详解

突破视频大语言模型推理瓶颈,在多个视频基准上实现SOTA性能 一、核心问题与创新亮点 1.1 GRPO在视频任务中的两大挑战 ​安全措施依赖问题​ GRPO使用min和clip函数限制策略更新幅度,导致: 梯度抑制:当新旧策略差异过大时梯度消失收敛困难:策略无法充分优化# 传统GRPO的梯…...

Flask RESTful 示例

目录 1. 环境准备2. 安装依赖3. 修改main.py4. 运行应用5. API使用示例获取所有任务获取单个任务创建新任务更新任务删除任务 中文乱码问题&#xff1a; 下面创建一个简单的Flask RESTful API示例。首先&#xff0c;我们需要创建环境&#xff0c;安装必要的依赖&#xff0c;然后…...

树莓派超全系列教程文档--(61)树莓派摄像头高级使用方法

树莓派摄像头高级使用方法 配置通过调谐文件来调整相机行为 使用多个摄像头安装 libcam 和 rpicam-apps依赖关系开发包 文章来源&#xff1a; http://raspberry.dns8844.cn/documentation 原文网址 配置 大多数用例自动工作&#xff0c;无需更改相机配置。但是&#xff0c;一…...

从WWDC看苹果产品发展的规律

WWDC 是苹果公司一年一度面向全球开发者的盛会&#xff0c;其主题演讲展现了苹果在产品设计、技术路线、用户体验和生态系统构建上的核心理念与演进脉络。我们借助 ChatGPT Deep Research 工具&#xff0c;对过去十年 WWDC 主题演讲内容进行了系统化分析&#xff0c;形成了这份…...

Mybatis逆向工程,动态创建实体类、条件扩展类、Mapper接口、Mapper.xml映射文件

今天呢&#xff0c;博主的学习进度也是步入了Java Mybatis 框架&#xff0c;目前正在逐步杨帆旗航。 那么接下来就给大家出一期有关 Mybatis 逆向工程的教学&#xff0c;希望能对大家有所帮助&#xff0c;也特别欢迎大家指点不足之处&#xff0c;小生很乐意接受正确的建议&…...

React Native在HarmonyOS 5.0阅读类应用开发中的实践

一、技术选型背景 随着HarmonyOS 5.0对Web兼容层的增强&#xff0c;React Native作为跨平台框架可通过重新编译ArkTS组件实现85%以上的代码复用率。阅读类应用具有UI复杂度低、数据流清晰的特点。 二、核心实现方案 1. 环境配置 &#xff08;1&#xff09;使用React Native…...

OkHttp 中实现断点续传 demo

在 OkHttp 中实现断点续传主要通过以下步骤完成&#xff0c;核心是利用 HTTP 协议的 Range 请求头指定下载范围&#xff1a; 实现原理 Range 请求头&#xff1a;向服务器请求文件的特定字节范围&#xff08;如 Range: bytes1024-&#xff09; 本地文件记录&#xff1a;保存已…...

Android 之 kotlin 语言学习笔记三(Kotlin-Java 互操作)

参考官方文档&#xff1a;https://developer.android.google.cn/kotlin/interop?hlzh-cn 一、Java&#xff08;供 Kotlin 使用&#xff09; 1、不得使用硬关键字 不要使用 Kotlin 的任何硬关键字作为方法的名称 或字段。允许使用 Kotlin 的软关键字、修饰符关键字和特殊标识…...