深度学习DNN实战
导包:
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as Fprint(sys.version_info)
for module in mpl, np, pd, sklearn, torch:print(module.__name__, module.__version__)device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)
加载数据:
from torchvision import datasets
from torchvision.transforms import ToTensor# fashion_mnist图像分类数据集
train_ds = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_ds = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)# torchvision 数据集里没有提供训练集和验证集的划分
# 当然也可以用 torch.utils.data.Dataset 实现人为划分
# 从数据集到dataloader
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=16, shuffle=True)
val_loader = torch.utils.data.DataLoader(test_ds, batch_size=16, shuffle=False)
获得标准化的东西:
from torchvision.transforms import Normalize# 遍历train_ds得到每张图片,计算每个通道的均值和方差
def cal_mean_std(ds):mean = 0.std = 0.for img, _ in ds: # 遍历每张图片,img.shape=[1,28,28]mean += img.mean(dim=(1, 2))std += img.std(dim=(1, 2))mean /= len(ds)std /= len(ds)return mean, stdprint(cal_mean_std(train_ds))
# 0.2860, 0.3205
transforms = nn.Sequential(Normalize([0.2860], [0.3205]) # 这里的均值和标准差是通过train_ds计算得到的
)
定义模型:
Sequential容器中写输入层,在下面的add_moudle中可以添加多个隐藏层,更深层次的网络能够学到更多的特征;
init_weights改变分布,也是一种调参数的方式
class NeuralNetwork(nn.Module):def __init__(self, layers_num=2):super().__init__()self.transforms = transforms # 预处理层,标准化self.flatten = nn.Flatten()# 多加几层self.linear_relu_stack = nn.Sequential(nn.Linear(28 * 28, 100),nn.ReLU(),)# 加19层for i in range(1, layers_num):self.linear_relu_stack.add_module(f"Linear_{i}", nn.Linear(100, 100))self.linear_relu_stack.add_module(f"relu", nn.ReLU())# 输出层self.linear_relu_stack.add_module("Output Layer", nn.Linear(100, 10))# 初始化权重self.init_weights()def init_weights(self):"""使用 xavier 均匀分布来初始化全连接层的权重 W"""# print('''初始化权重''')for m in self.modules():# print(m)# print('-'*50)if isinstance(m, nn.Linear):#判断m是否为全连接层# https://pytorch.org/docs/stable/nn.init.htmlnn.init.xavier_uniform_(m.weight) # xavier 均匀分布初始化权重nn.init.zeros_(m.bias) # 全零初始化偏置项# print('''初始化权重完成''')def forward(self, x):# x.shape [batch size, 1, 28, 28]x = self.transforms(x) #标准化x = self.flatten(x) # 展平后 x.shape [batch size, 28 * 28]logits = self.linear_relu_stack(x)# logits.shape [batch size, 10]return logits
total=0
for idx, (key, value) in enumerate(NeuralNetwork(20).named_parameters()):# print(f"Linear_{idx // 2:>02}\tparamerters num: {np.prod(value.shape)}") #np.prod是计算张量的元素个数# print(f"Linear_{idx // 2:>02}\tshape: {value.shape}")total+=np.prod(value.shape)
total #模型参数数量
训练:
from sklearn.metrics import accuracy_score@torch.no_grad()
def evaluating(model, dataloader, loss_fct):loss_list = []pred_list = []label_list = []for datas, labels in dataloader:datas = datas.to(device)labels = labels.to(device)# 前向计算logits = model(datas)loss = loss_fct(logits, labels) # 验证集损失loss_list.append(loss.item())preds = logits.argmax(axis=-1) # 验证集预测pred_list.extend(preds.cpu().numpy().tolist())label_list.extend(labels.cpu().numpy().tolist())acc = accuracy_score(label_list, pred_list)return np.mean(loss_list), acc#%%
from torch.utils.tensorboard import SummaryWriterclass TensorBoardCallback:def __init__(self, log_dir, flush_secs=10):"""Args:log_dir (str): dir to write log.flush_secs (int, optional): write to dsk each flush_secs seconds. Defaults to 10."""self.writer = SummaryWriter(log_dir=log_dir, flush_secs=flush_secs)def draw_model(self, model, input_shape):self.writer.add_graph(model, input_to_model=torch.randn(input_shape))def add_loss_scalars(self, step, loss, val_loss):self.writer.add_scalars(main_tag="training/loss", tag_scalar_dict={"loss": loss, "val_loss": val_loss},global_step=step,)def add_acc_scalars(self, step, acc, val_acc):self.writer.add_scalars(main_tag="training/accuracy",tag_scalar_dict={"accuracy": acc, "val_accuracy": val_acc},global_step=step,)def add_lr_scalars(self, step, learning_rate):self.writer.add_scalars(main_tag="training/learning_rate",tag_scalar_dict={"learning_rate": learning_rate},global_step=step,)def __call__(self, step, **kwargs):# add lossloss = kwargs.pop("loss", None)val_loss = kwargs.pop("val_loss", None)if loss is not None and val_loss is not None:self.add_loss_scalars(step, loss, val_loss)# add accacc = kwargs.pop("acc", None)val_acc = kwargs.pop("val_acc", None)if acc is not None and val_acc is not None:self.add_acc_scalars(step, acc, val_acc)# add lrlearning_rate = kwargs.pop("lr", None)if learning_rate is not None:self.add_lr_scalars(step, learning_rate)#%%
class SaveCheckpointsCallback:def __init__(self, save_dir, save_step=5000, save_best_only=True):"""Save checkpoints each save_epoch epoch. We save checkpoint by epoch in this implementation.Usually, training scripts with pytorch evaluating model and save checkpoint by step.Args:save_dir (str): dir to save checkpointsave_epoch (int, optional): the frequency to save checkpoint. Defaults to 1.save_best_only (bool, optional): If True, only save the best model or save each model at every epoch."""self.save_dir = save_dirself.save_step = save_stepself.save_best_only = save_best_onlyself.best_metrics = -1# mkdirif not os.path.exists(self.save_dir):os.mkdir(self.save_dir)def __call__(self, step, state_dict, metric=None):if step % self.save_step > 0:returnif self.save_best_only:assert metric is not Noneif metric >= self.best_metrics:# save checkpointstorch.save(state_dict, os.path.join(self.save_dir, "best.ckpt"))# update best metricsself.best_metrics = metricelse:torch.save(state_dict, os.path.join(self.save_dir, f"{step}.ckpt"))#%%
class EarlyStopCallback:def __init__(self, patience=5, min_delta=0.01):"""Args:patience (int, optional): Number of epochs with no improvement after which training will be stopped.. Defaults to 5.min_delta (float, optional): Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta, will count as no improvement. Defaults to 0.01."""self.patience = patienceself.min_delta = min_deltaself.best_metric = -1self.counter = 0def __call__(self, metric):if metric >= self.best_metric + self.min_delta:# update best metricself.best_metric = metric# reset counter self.counter = 0else: self.counter += 1@propertydef early_stop(self):return self.counter >= self.patience#%%
# 训练
def training(model, train_loader, val_loader, epoch, loss_fct, optimizer, tensorboard_callback=None,save_ckpt_callback=None,early_stop_callback=None,eval_step=500,):record_dict = {"train": [],"val": []}global_step = 0model.train()with tqdm(total=epoch * len(train_loader)) as pbar:for epoch_id in range(epoch):# trainingfor datas, labels in train_loader:datas = datas.to(device)labels = labels.to(device)# 梯度清空optimizer.zero_grad()# 模型前向计算logits = model(datas)# 计算损失loss = loss_fct(logits, labels)# 梯度回传loss.backward()# 调整优化器,包括学习率的变动等optimizer.step()preds = logits.argmax(axis=-1)acc = accuracy_score(labels.cpu().numpy(), preds.cpu().numpy()) loss = loss.cpu().item()# recordrecord_dict["train"].append({"loss": loss, "acc": acc, "step": global_step})# evaluatingif global_step % eval_step == 0:model.eval()val_loss, val_acc = evaluating(model, val_loader, loss_fct)record_dict["val"].append({"loss": val_loss, "acc": val_acc, "step": global_step})model.train()# 1. 使用 tensorboard 可视化if tensorboard_callback is not None:tensorboard_callback(global_step, loss=loss, val_loss=val_loss,acc=acc, val_acc=val_acc,lr=optimizer.param_groups[0]["lr"],)# 2. 保存模型权重 save model checkpointif save_ckpt_callback is not None:save_ckpt_callback(global_step, model.state_dict(), metric=val_acc)# 3. 早停 Early Stopif early_stop_callback is not None:early_stop_callback(val_acc)if early_stop_callback.early_stop:print(f"Early stop at epoch {epoch_id} / global_step {global_step}")return record_dict# udate stepglobal_step += 1pbar.update(1)pbar.set_postfix({"epoch": epoch_id})return record_dictepoch = 100model = NeuralNetwork(layers_num=10)#%%
# 1. 定义损失函数 采用交叉熵损失
loss_fct = nn.CrossEntropyLoss()
# 2. 定义优化器 采用SGD
# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# 1. tensorboard 可视化
tensorboard_callback = TensorBoardCallback("runs")
tensorboard_callback.draw_model(model, [1, 28, 28])
# 2. save best
save_ckpt_callback = SaveCheckpointsCallback("checkpoints", save_best_only=True)
# 3. early stop
early_stop_callback = EarlyStopCallback(patience=10, min_delta=0.001)model = model.to(device)
#%%
record = training(model,train_loader,val_loader,epoch,loss_fct,optimizer,tensorboard_callback=None,save_ckpt_callback=save_ckpt_callback,early_stop_callback=early_stop_callback,eval_step=len(train_loader))
画图:
def plot_learning_curves(record_dict, sample_step=500):# build DataFrametrain_df = pd.DataFrame(record_dict["train"]).set_index("step").iloc[::sample_step]val_df = pd.DataFrame(record_dict["val"]).set_index("step")# plotfig_num = len(train_df.columns)fig, axs = plt.subplots(1, fig_num, figsize=(6 * fig_num, 5))for idx, item in enumerate(train_df.columns): axs[idx].plot(train_df.index, train_df[item], label=f"train_{item}")axs[idx].plot(val_df.index, val_df[item], label=f"val_{item}")axs[idx].grid()axs[idx].legend()axs[idx].set_xlabel("step")plt.show()plot_learning_curves(record, sample_step=10000) #横坐标是 steps
相关文章:

深度学习DNN实战
导包: import matplotlib as mpl import matplotlib.pyplot as plt %matplotlib inline import numpy as np import sklearn import pandas as pd import os import sys import time from tqdm.auto import tqdm import torch import torch.nn as nn import torch…...

课程3. 分批训练与数据规范、标准化
课程3. 分批训练与数据规范、标准化 理论神经网络的梯度优化反向传播算法 批量训练网络输入的规范化BatchNorm 验证样本实践加载数据集网络构建训练神经网络 课程计划: 1.理论: 批量训练; 输入数据的规范化; 批量标准化ÿ…...
《机器学习数学基础》补充资料:过渡矩阵和坐标变换推导
尽管《机器学习数学基础》这本书,耗费了比较长的时间和精力,怎奈学识有限,错误难免。因此,除了在专门的网页( 勘误和修订 )中发布勘误和修订内容之外,对于重大错误,我还会以专题的形…...
linux指令学习--sudo apt-get install vim
1. 命令分解 部分含义sudo以管理员权限运行命令(需要输入用户密码)。apt-getUbuntu 的包管理工具,用于安装、更新、卸载软件包。installapt-get 的子命令,表示安装软件包。vim要安装的软件包名称(Vim 文本编辑器&…...

类和对象—多态—案例2—制作饮品
案例描述: 制作饮品的大致流程为:煮水-冲泡-倒入杯中-加入辅料 利用多态技术实现本案例,提供抽象制作产品基类,提供子类制作咖啡和茶叶 思路解析: 1. 定义抽象基类 - 创建 AbstractDrinking 抽象类,该类…...

嵌入式产品级-超小尺寸游戏机(从0到1 硬件-软件-外壳)
Ultra-small size gaming console。 超小尺寸游戏机-Pico This embedded product is mainly based on miniaturization, followed by his game functions are also very complete, for all kinds of games can be played, and there will be relevant illustrations in the fo…...

计算机毕业设计Python+Django+Vue3微博数据舆情分析平台 微博用户画像系统 微博舆情可视化(源码+ 文档+PPT+讲解)
温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 作者简介:Java领…...

前端开发10大框架深度解析
摘要 在现代前端开发中,框架的选择对项目的成功至关重要。本文旨在为开发者提供一份全面的前端框架指南,涵盖 React、Vue.js、Angular、Svelte、Ember.js、Preact、Backbone.js、Next.js、Nuxt.js 和 Gatsby。我们将从 简介、优缺点、适用场景 以及 实际…...

Mybatis 的关联映射(一对一,一对多,多对多)
前言 在前面我们已经了解了,mybatis 的基本用法,动态SQL,学会使用mybatis 来操作数据库。但这些主要操作还是针对 单表实现的。在实际的开发中,对数据库的操作,常常涉及多张表。 因此本篇博客的目标:通过my…...

深度解码!清华大学第六弹《AIGC发展研究3.0版》
在Grok3与GPT-4.5相继发布之际,《AIGC发展研究3.0版》的重磅报告——这份长达200页的行业圣经,不仅预测了2025年AI技术爆发点,更将「天人合一」的东方智慧融入AI伦理建构,堪称数字时代的《道德经》。 文档:清华大学第…...
/dev/console文件详解
/dev/console概览 /dev/console 是 Linux 系统中的一个特殊设备文件,通常用于与系统的控制台进行交互。它的作用和特点如下: 1. 作用 init 进程(PID 1)和某些系统服务在启动时会使用 /dev/console 进行日志输出,以确…...

ProfibusDP主站转ModbusTCP网关如何进行数据互换
ProfibusDP主站转ModbusTCP网关如何进行数据互换 在现代工业自动化领域,通信协议的多样性和复杂性不断增加。Profibus DP作为一种经典的现场总线标准,广泛应用于工业控制网络中;而Modbus TCP作为基于以太网的通信协议,因其简单易…...
springboot3 WebClient
1 介绍 在 Spring 5 之前,如果我们想要调用其他系统提供的 HTTP 服务,通常可以使用 Spring 提供的 RestTemplate 来访问,不过由于 RestTemplate 是 Spring 3 中引入的同步阻塞式 HTTP 客户端,因此存在一定性能瓶颈。根据 Spring 官…...

牛客周赛 Round 83
A.和猫猫一起起舞! 思路:遇到‘U’和‘D’,输出‘R’或者‘L’;遇到‘R’和‘L’,输出‘U’或者‘D’.(这题比较简单) AC代码: void solve() {int n, m, k;char ch;cin >> ch;if (ch U || ch D)…...

硬通货用Deekseek做一个Vue.js组件开发的教程
安装 Node.js 与 Vue CLI npm install -g vue/cli vue create my-vue-project cd my-vue-project npm run serve 通过 Vue CLI 可快速生成项目骨架,默认配置适合新手快速上手 目录结构 src/ ├── components/ # 存放组件文件 │ └── …...

Windows权限维持之利用安全描述符隐藏服务后门进行权限维持(八)
我们先打开cs的服务端 然后我们打开客户端 我们点击连接 然后弹出这个界面 然后我们新建一个监听器 然后我们生成一个beacon 然后把这个复制到目标主机 然后我们双击 运行 然后cs这边就上线了 然后我们把进程结束掉 然后我们再把他删除掉 然后我们创建服务 将后门程序注册…...

Ubuntu20.04双系统安装及软件安装(七):Anaconda3
Ubuntu20.04双系统安装及软件安装(七):Anaconda3 打开Anaconda官网,在右侧处填写邮箱(要真实有效!),然后Submit。会出现如图示的Success界面。 进入填写的邮箱,有一封Ana…...
【极光 Orbit•STC8A-8H】02. STC8 单片机工程模板创建
【极光 Orbit•STC8A-8H】02. STC8 单片机工程模板创建 七绝单片机 小小芯片大乾坤, 集成世界在其中。 初学虽感千重难, 实践方知奥妙通。 今天的讲法和过去不同,直接来一个多文件模块化的工程模板创建,万事开头难,…...
Spring Boot WebFlux 中 WebSocket 生命周期解析
Spring Boot WebFlux 中的 WebSocket 提供了一种高效、异步的方式来处理客户端与服务器之间的双向通信。WebSocket 连接的生命周期包括连接建立、消息传输、连接关闭以及资源清理等过程。此外,为了确保 WebSocket 连接的稳定性和可靠性,我们可以加入重试…...

PostgreSQL中的事务隔离
1. 事务隔离的概念 在数据库管理系统中,事务隔离是一项重要的功能,它能确保在并发访问数据库时事务之间能够独立运行,不会相互干扰。数据库系统通常支持不同级别的事务隔离,用来满足不同应用程序之间的需求。 2. 事务隔离的种类…...

stm32G473的flash模式是单bank还是双bank?
今天突然有人stm32G473的flash模式是单bank还是双bank?由于时间太久,我真忘记了。搜搜发现,还真有人和我一样。见下面的链接:https://shequ.stmicroelectronics.cn/forum.php?modviewthread&tid644563 根据STM32G4系列参考手…...
Spring Boot+Neo4j知识图谱实战:3步搭建智能关系网络!
一、引言 在数据驱动的背景下,知识图谱凭借其高效的信息组织能力,正逐步成为各行业应用的关键技术。本文聚焦 Spring Boot与Neo4j图数据库的技术结合,探讨知识图谱开发的实现细节,帮助读者掌握该技术栈在实际项目中的落地方法。 …...
JDK 17 新特性
#JDK 17 新特性 /**************** 文本块 *****************/ python/scala中早就支持,不稀奇 String json “”" { “name”: “Java”, “version”: 17 } “”"; /**************** Switch 语句 -> 表达式 *****************/ 挺好的ÿ…...

佰力博科技与您探讨热释电测量的几种方法
热释电的测量主要涉及热释电系数的测定,这是表征热释电材料性能的重要参数。热释电系数的测量方法主要包括静态法、动态法和积分电荷法。其中,积分电荷法最为常用,其原理是通过测量在电容器上积累的热释电电荷,从而确定热释电系数…...

AI病理诊断七剑下天山,医疗未来触手可及
一、病理诊断困局:刀尖上的医学艺术 1.1 金标准背后的隐痛 病理诊断被誉为"诊断的诊断",医生需通过显微镜观察组织切片,在细胞迷宫中捕捉癌变信号。某省病理质控报告显示,基层医院误诊率达12%-15%,专家会诊…...
适应性Java用于现代 API:REST、GraphQL 和事件驱动
在快速发展的软件开发领域,REST、GraphQL 和事件驱动架构等新的 API 标准对于构建可扩展、高效的系统至关重要。Java 在现代 API 方面以其在企业应用中的稳定性而闻名,不断适应这些现代范式的需求。随着不断发展的生态系统,Java 在现代 API 方…...

elementUI点击浏览table所选行数据查看文档
项目场景: table按照要求特定的数据变成按钮可以点击 解决方案: <el-table-columnprop"mlname"label"名称"align"center"width"180"><template slot-scope"scope"><el-buttonv-if&qu…...
React核心概念:State是什么?如何用useState管理组件自己的数据?
系列回顾: 在上一篇《React入门第一步》中,我们已经成功创建并运行了第一个React项目。我们学会了用Vite初始化项目,并修改了App.jsx组件,让页面显示出我们想要的文字。但是,那个页面是“死”的,它只是静态…...

Tauri2学习笔记
教程地址:https://www.bilibili.com/video/BV1Ca411N7mF?spm_id_from333.788.player.switch&vd_source707ec8983cc32e6e065d5496a7f79ee6 官方指引:https://tauri.app/zh-cn/start/ 目前Tauri2的教程视频不多,我按照Tauri1的教程来学习&…...
6.计算机网络核心知识点精要手册
计算机网络核心知识点精要手册 1.协议基础篇 网络协议三要素 语法:数据与控制信息的结构或格式,如同语言中的语法规则语义:控制信息的具体含义和响应方式,规定通信双方"说什么"同步:事件执行的顺序与时序…...