当前位置: 首页 > article >正文

Python打卡DAY40

知识点回顾:

  1. 彩色和灰度图片测试和训练的规范写法:封装在函数中
  2. 展平操作:除第一个维度batchsize外全部展平
  3. dropout操作:训练阶段随机丢弃神经元,测试阶段eval模式关闭dropout

作业:仔细学习下测试和训练代码的逻辑,这是基础,这个代码框架后续会一直沿用,后续的重点慢慢就是转向模型定义阶段了。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 1. 数据预处理
transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差
])# 2. 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)# 3. 创建数据加载器
batch_size = 64  # 每批处理64个样本
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 4. 定义模型、损失函数和优化器
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.flatten = nn.Flatten()  # 将28x28的图像展平为784维向量self.layer1 = nn.Linear(784, 128)  # 第一层:784个输入,128个神经元self.relu = nn.ReLU()  # 激活函数self.layer2 = nn.Linear(128, 10)  # 第二层:128个输入,10个输出(对应10个数字类别)def forward(self, x):x = self.flatten(x)  # 展平图像x = self.layer1(x)   # 第一层线性变换x = self.relu(x)     # 应用ReLU激活函数x = self.layer2(x)   # 第二层线性变换,输出logitsreturn x# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 初始化模型
model = MLP()
model = model.to(device)  # 将模型移至GPU(如果可用)criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数,适用于多分类问题
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam优化器# 5. 训练模型(记录每个 iteration 的损失)
def train(model, train_loader, test_loader, criterion, optimizer, device, epochs):model.train()  # 设置为训练模式# 新增:记录每个 iteration 的损失all_iter_losses = []  # 存储所有 batch 的损失iter_indices = []     # 存储 iteration 序号(从1开始)for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)  # 移至GPU(如果可用)optimizer.zero_grad()  # 梯度清零output = model(data)  # 前向传播loss = criterion(output, target)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数# 记录当前 iteration 的损失(注意:这里直接使用单 batch 损失,而非累加平均)iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)  # iteration 序号从1开始# 统计准确率和损失(原逻辑保留,用于 epoch 级统计)running_loss += iter_loss_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 每100个批次打印一次训练信息(可选:同时打印单 batch 损失)if (batch_idx + 1) % 100 == 0:print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} 'f'| 单Batch损失: {iter_loss:.4f} | 累计平均损失: {running_loss/(batch_idx+1):.4f}')# 原 epoch 级逻辑(测试、打印 epoch 结果)不变epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct / totalepoch_test_loss, epoch_test_acc = test(model, test_loader, criterion, device)print(f'Epoch {epoch+1}/{epochs} 完成 | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%')# 绘制所有 iteration 的损失曲线plot_iter_losses(all_iter_losses, iter_indices)# 保留原 epoch 级曲线(可选)# plot_metrics(train_losses, test_losses, train_accuracies, test_accuracies, epochs)return epoch_test_acc  # 返回最终测试准确率# 6. 测试模型
def test(model, test_loader, criterion, device):model.eval()  # 设置为评估模式test_loss = 0correct = 0total = 0with torch.no_grad():  # 不计算梯度,节省内存和计算资源for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()avg_loss = test_loss / len(test_loader)accuracy = 100. * correct / totalreturn avg_loss, accuracy  # 返回损失和准确率# 7.绘制每个 iteration 的损失曲线
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')plt.xlabel('Iteration(Batch序号)')plt.ylabel('损失值')plt.title('每个 Iteration 的训练损失')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 8. 执行训练和测试(设置 epochs=2 验证效果)
epochs = 2  
print("开始训练模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, device, epochs)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")

@浙大疏锦行

相关文章:

Python打卡DAY40

知识点回顾: 彩色和灰度图片测试和训练的规范写法:封装在函数中展平操作:除第一个维度batchsize外全部展平dropout操作:训练阶段随机丢弃神经元,测试阶段eval模式关闭dropout 作业:仔细学习下测试和训练代码…...

OPC Client第6讲(wxwidgets):Logger.h日志记录文件(单例模式);登录后的主界面

接上一讲三、2、2>4》,创建logger.h和helper_t.h里的gettime函数 即解决下图的报红 同时,接上一讲二、3、点击“确认”按钮后,进入MainFrame.h对应的下述界面,此讲下图进行实现 一、创建Logger.h:日志记录文件&…...

CesiumInstancedMesh 实例

CesiumInstancedMesh 实例 import * as Cesium from cesium;// Three.js 风格的 InstancedMesh 类, https://threejs.org/docs/#api/en/objects/InstancedMesh export class CesiumInstancedMesh {/*** Creates an instance of InstancedMesh.** param {Cesium.Geometry} geom…...

单细胞注释前沿:CASSIA——无参考、可解释、自动化细胞注释的大语言模型

细胞类型注释是单细胞RNA-seq分析的重要步骤,目前有许多注释方法。大多数注释方法都需要计算和特定领域专业知识的结合,而且经常产生不一致的结果,难以解释。大语言模型有可能在减少人工输入和提高准确性的同时扩大可访问性,但现有…...

历年武汉大学计算机保研上机真题

2025武汉大学计算机保研上机真题 2024武汉大学计算机保研上机真题 2023武汉大学计算机保研上机真题 在线测评链接:https://pgcode.cn/school 分段函数计算 题目描述 写程序计算如下分段函数: 当 x > 0 x > 0 x>0 时, f ( x ) …...

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(30):みます

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(30):みます 1、前言(1)情况说明(2)工程师的信仰2、知识点(1)ように 復習:1、ように Change12、ように Ideal state(理想(りそう)の状態(じょうたい))3、V辞書・Vない ようにしています いつも気をつけて…...

AR-HUD 光波导方案优化难题待解?OAS 光学软件来破局

波导-HUD系统案例分析 简介 光波导技术凭借其平板超薄结构和强大的二维扩展能力,在解决AR-HUD问题方面展现出显著优势。一方面,其独特的结构特性能够大幅减小对光机体积的需求,成为 HUD 未来发展的重要技术方向;另一方面&#xf…...

火狐安装自动录制表单教程——仙盟自动化运营大衍灵机——仙盟创梦IDE

打开火狐插件页面 安装完成 使用 功能 录制浏览器操作 录入地址 开始操作 录制完成 在当今快速发展的软件开发生态中,自动化测试已从一种新兴技术手段,转变为保障软件质量与开发效率不可或缺的关键环节。其重要性体现在多个维度,同时&#x…...

线程池的详细知识(含有工厂模式)

前言 下午学习了线程池的知识。重点探究了ThreadPoolExecutor里面的各种参数的含义。我详细了解了这部分的知识。其中有一个参数涉及工厂模式,我将这一部分知识分享给大家~ 线程池的详细介绍(含工厂模式) 结语 分享到此结束啦。byebye~...

木愚科技闪亮第63届高博会 全栈式智能教育解决方案助力教学升级

5月23日,第63届高等教育博览会在长春东北亚国际博览中心开幕,木愚科技积极筹备,奔赴展会现场。彼时,木愚科技企业领导及相关职能部门负责人亲临展位指导工作,通过特装展位、资料发放及现场交流等方式,全方位…...

Proteus寻找元器件(常见)

一 元件库 二 找元件 1 主控 32 51 输入 stm32 AT89c51 2 找屏幕 oled 3 找按键button 4 电阻、电容 res cap 5 电机驱动 l298n 6 电机 motor 7 滑动变阻器 pot 8 找电源和 GND 9 找晶振 选择 D 开头的 CRYSTAL 10 网络标签...

RK3566 Android12 HG24C02MM/TR EEPROM适配

一、背景 近期项目中,有一个需求,要使用RK3566 Android12平台适配一款HG24C02MM/TR EEPROM芯片,通过i2c实现主板与EEPROM芯片的数据通讯。废话不多说,来看资料。 二、芯片资料 HG24C02 / HG24C04 / HG24C08 / HG24C16是提供2048…...

IoTDB 集成 DBeaver,简易操作实现时序数据清晰管理

数据结构一目了然,跨库分析轻松实现,方便 IoTDB “内部构造”管理! 随着物联网场景对时序数据处理需求激增,时序数据库与数据库管理工具的集成尤为关键。作为数据资产的 “智能管家”,借助数据库管理工具的可视化操作界…...

sqli-labs第二十八关——Trick with ‘union select‘

一:分析 这一关的提示和上一关一样,所以我们查看源码,屏蔽了注释符,空格,union,select等关键词 分析这一条源码的几个新增添符号 \s: 匹配任何的空白字符(普通空格,\t&…...

mapbox高阶,PMTiles介绍,MBTiles、PMTiles对比,加载PMTiles文件

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:mapbox 从入门到精通 文章目录 一、🍀前言1.1 ☘️mapboxgl.Map 地图对象1.2 ☘️mapboxgl.Map style属性1.3 ☘️Fill面图层样式1.4 ☘️PMTiles介绍1.5…...

Go语言通道如何实现通信

在Go语言中,通道(channel)是一种内置的数据结构,用于在不同的goroutine之间进行通信和同步。通道提供了一种安全且有效的方式来传递数据,避免了数据竞争和死锁等问题。 要在Go语言中使用通道进行通信,你需…...

投稿 IEEE Transactions on Knowledge and Data Engineering 注意事项

投稿 IEEE Transactions on Knowledge and Data Engineering 注意事项 要IEEE overleaf 模板私信,我直接给我自己论文,便于编辑 已经投稿完成了,有一些小坑 准备工作 注册IEEE账户:若没有IEEE账户,需前往IEEE官网注册。注册成功后,可用于登录投稿系统。现在新的系统,…...

题目 3316: 蓝桥杯2025年第十六届省赛真题-数组翻转

题目 3316: 蓝桥杯2025年第十六届省赛真题-数组翻转 时间限制: 3s 内存限制: 512MB 提交: 101 解决: 24 题目描述 小明生成了一个长度为 n 的正整数数组 a1, a2, . . . , an,他可以选择连续的一 段数 al , al1, ..., ar,如果其中所有数都相等即 al al1 …...

mongodb源码分析session接受客户端find命令过程

mongo/transport/service_state_machine.cpp已经分析startSession创建ASIOSession过程,并且验证connection是否超过限制。 现在继续研究ASIOSession和connection是怎么接受客户端命令的? mongo/transport/service_state_machine.cpp核心方法有&#xf…...

Netty 实战篇:为自研 RPC 框架加入异步调用与 Future 支持

我们在上篇实现了一个轻量级 RPC 框架,现在要进一步优化 —— 加入异步响应支持,让 RPC 通信变得真正高效、非阻塞、支持并发。 一、为什么需要异步调用? 上篇的 RPC 框架是“同步阻塞”的: 每次发送请求后,必须等待服…...

python37天打卡

知识点回顾: 过拟合的判断:测试集和训练集同步打印指标 模型的保存和加载 仅保存权重 保存权重和模型 保存全部信息checkpoint,还包含训练状态 早停策略 作业:对信贷数据集训练后保存权重,加载权重后继续训练50轮&am…...

变焦位移计:机器视觉如何克服人工疲劳与主观影响?精准对结构安全实时监测

变焦视觉位移监测与人工监测的对比 人工监测是依靠目测检查或借助于全站仪,水准仪,RTK等便携式仪器测量得到的信息,但是随着整个行业的发展,传统的人工监测方法已经不能满足监测需求,从人工监测到自动化监测已是必然趋…...

嵌入式硬件篇---Ne555定时器

文章目录 前言1. 基本概述类型功能封装形式2. 引脚功能(DIP-8 封装)内部结构阈值电压两种工作模式4. 主要特性优点:缺点:5. 典型应用场景定时控制脉冲生成检测与触发信号处理6. 关键参数速查表前言 本文简单介绍了Ne555定时器(多谐振荡器/定时器)。DIP与SOP封装。 1. 基…...

【Axure结合Echarts绘制图表】

1.绘制一个矩形,用于之后存放图表,将其命名为test: 2.新建交互 -> 载入时 -> 打开链接: 3.链接到URL或文件路径: 4.点击fx: 5.输入: javascript: var script document.createEleme…...

使用web3工具结合fiscobcos网络部署调用智能合约

借助 web3 工具,在 FISCO BCOS 网络上高效部署与调用智能合约,解锁区块链开发新体验。 搭建的区块链网络需要是最新的fiscobcos3.0,最新的才支持web3调用 现在分享踩坑经验,希望大家点赞 目录 1.搭建fiscobcos节点(3.…...

Oracle/openGauss中,DATE/TIMESTAMP与数字日期/字符日期比较

ORACLE 运行环境 openGauss 运行环境 0、前置知识 ORACLE:DUMP()函数用于返回指定表达式的数据类型、字节长度及内部存储表示的详细信息 SELECT DUMP(123) FROM DUAL; -- Typ2 Len3: 194,2,24 SELECT DUMP(123) FROM DUAL;-- Typ96 Len3: 49,50,51 -- ASCII值&am…...

Datatable和实体集合互转

1.使用已废弃的 JavaScriptSerializer,且反序列化为弱类型 ArrayList。可用但不推荐。 using System; using System.Collections; using System.Collections.Generic; using System.Data; using System.Linq; using System.Reflection; using System.Web; using Sy…...

Win11切换JDK版本批处理脚本

维护的老项目jdk1.8,新项目开发采用jdk21,所以寻找类似nvm的软件,都不太满意,最后还是决定采用写一个脚本算了,先不折腾了。 1、创建switch_jdk.bat文件 2、把如下内容复制进行 echo off chcp 65001 >nul setloc…...

爬虫学习-Scrape Center spa6 超简单 JS 逆向

关卡 spa6 电影数据网站,无反爬,数据通过 Ajax 加载,数据接口参数加密且有时间限制,适合动态页面渲染爬取或 JavaScript 逆向分析。 首先抓包发现get请求的参数token有加密。 offset表示翻页,limit表示每一页有多少…...

对数的运算困惑

难点总结 学生在对数运算中的难点分析: 一、不理解对数,不会用对数公式或错用对数公式 ①对数 l o g 2 3 log_23 log2​3和指数幂 2 3 2^3 23一样,也就是个实数而已,所以其也会有加减乘除乘方开方等运算; 比如 2 2 + l o g 2 3 = 2 2 ⋅ 2 l o g 2 3 = 4 ⋅ 3 = 12 2^{2…...