pytorch-神经网络-手写数字分类任务
Mnist分类任务:
-
网络基本构建与训练方法,常用函数解析
-
torch.nn.functional模块
-
nn.Module模块
-
读取Mnist数据集
- 会自动进行下载
%matplotlib inlinefrom pathlib import Path import requestsDATA_PATH = Path("data") PATH = DATA_PATH / "mnist"PATH.mkdir(parents=True, exist_ok=True)URL = "http://deeplearning.net/data/mnist/" FILENAME = "mnist.pkl.gz"if not (PATH / FILENAME).exists():content = requests.get(URL + FILENAME).content(PATH / FILENAME).open("wb").write(content)import pickle import gzipwith gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")from matplotlib import pyplot import numpy as nppyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray") print(x_train.shape)
-
\

-
注意数据需转换成tensor才能参与后续建模训练
import torchx_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid) ) n, c = x_train.shape x_train, x_train.shape, y_train.min(), y_train.max() print(x_train, y_train) print(x_train.shape) print(y_train.min(), y_train.max())torch.nn.functional 很多层和函数在这里都会见到
torch.nn.functional中有很多功能,后续会常用的。那什么时候使用nn.Module,什么时候使用nn.functional呢?一般情况下,如果模型有可学习的参数,最好用nn.Module,其他情况nn.functional相对更简单一些
import torch.nn.functional as Floss_func = F.cross_entropydef model(xb):return xb.mm(weights) + biasbs = 64 xb = x_train[0:bs] # a mini-batch from x yb = y_train[0:bs] weights = torch.randn([784, 10], dtype = torch.float, requires_grad = True) bs = 64 bias = torch.zeros(10, requires_grad=True)print(loss_func(model(xb), yb))创建一个model来更简化代码
- 必须继承nn.Module且在其构造函数中需调用nn.Module的构造函数
- 无需写反向传播函数,nn.Module能够利用autograd自动实现反向传播
- Module中的可学习参数可以通过named_parameters()或者parameters()返回迭代器
from torch import nnclass Mnist_NN(nn.Module):def __init__(self):super().__init__()self.hidden1 = nn.Linear(784, 128)self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256, 10)def forward(self, x):x = F.relu(self.hidden1(x))x = F.relu(self.hidden2(x))x = self.out(x)return xnet = Mnist_NN() print(net)
for name, parameter in net.named_parameters():print(name, parameter,parameter.size())hidden1.weight Parameter containing: tensor([[ 0.0018, 0.0218, 0.0036, ..., -0.0286, -0.0166, 0.0089],[-0.0349, 0.0268, 0.0328, ..., 0.0263, 0.0200, -0.0137],[ 0.0061, 0.0060, -0.0351, ..., 0.0130, -0.0085, 0.0073],...,[-0.0231, 0.0195, -0.0205, ..., -0.0207, -0.0103, -0.0223],[-0.0299, 0.0305, 0.0098, ..., 0.0184, -0.0247, -0.0207],[-0.0306, -0.0252, -0.0341, ..., 0.0136, -0.0285, 0.0057]],requires_grad=True) torch.Size([128, 784]) hidden1.bias Parameter containing: tensor([ 0.0072, -0.0269, -0.0320, -0.0162, 0.0102, 0.0189, -0.0118, -0.0063,-0.0277, 0.0349, 0.0267, -0.0035, 0.0127, -0.0152, -0.0070, 0.0228,-0.0029, 0.0049, 0.0072, 0.0002, -0.0356, 0.0097, -0.0003, -0.0223,-0.0028, -0.0120, -0.0060, -0.0063, 0.0237, 0.0142, 0.0044, -0.0005,0.0349, -0.0132, 0.0138, -0.0295, -0.0299, 0.0074, 0.0231, 0.0292,-0.0178, 0.0046, 0.0043, -0.0195, 0.0175, -0.0069, 0.0228, 0.0169,0.0339, 0.0245, -0.0326, -0.0260, -0.0029, 0.0028, 0.0322, -0.0209,-0.0287, 0.0195, 0.0188, 0.0261, 0.0148, -0.0195, -0.0094, -0.0294,-0.0209, -0.0142, 0.0131, 0.0273, 0.0017, 0.0219, 0.0187, 0.0161,0.0203, 0.0332, 0.0225, 0.0154, 0.0169, -0.0346, -0.0114, 0.0277,0.0292, -0.0164, 0.0001, -0.0299, -0.0076, -0.0128, -0.0076, -0.0080,-0.0209, -0.0194, -0.0143, 0.0292, -0.0316, -0.0188, -0.0052, 0.0013,-0.0247, 0.0352, -0.0253, -0.0306, 0.0035, -0.0253, 0.0167, -0.0260,-0.0179, -0.0342, 0.0033, -0.0287, -0.0272, 0.0238, 0.0323, 0.0108,0.0097, 0.0219, 0.0111, 0.0208, -0.0279, 0.0324, -0.0325, -0.0166,-0.0010, -0.0007, 0.0298, 0.0329, 0.0012, -0.0073, -0.0010, 0.0057],requires_grad=True) torch.Size([128]) hidden2.weight Parameter containing: tensor([[-0.0383, -0.0649, 0.0665, ..., -0.0312, 0.0394, -0.0801],[-0.0189, -0.0342, 0.0431, ..., -0.0321, 0.0072, 0.0367],[ 0.0289, 0.0780, 0.0496, ..., 0.0018, -0.0604, -0.0156],...,[-0.0360, 0.0394, -0.0615, ..., 0.0233, -0.0536, -0.0266],[ 0.0416, 0.0082, -0.0345, ..., 0.0808, -0.0308, -0.0403],[-0.0477, 0.0136, -0.0408, ..., 0.0180, -0.0316, -0.0782]],requires_grad=True) torch.Size([256, 128]) hidden2.bias Parameter containing: tensor([-0.0694, -0.0363, -0.0178, 0.0206, -0.0875, -0.0876, -0.0369, -0.0386,0.0642, -0.0738, -0.0017, -0.0243, -0.0054, 0.0757, -0.0254, 0.0050,0.0519, -0.0695, 0.0318, -0.0042, -0.0189, -0.0263, -0.0627, -0.0691,0.0713, -0.0696, -0.0672, 0.0297, 0.0102, 0.0040, 0.0830, 0.0214,0.0714, 0.0327, -0.0582, -0.0354, 0.0621, 0.0475, 0.0490, 0.0331,-0.0111, -0.0469, -0.0695, -0.0062, -0.0432, -0.0132, -0.0856, -0.0219,-0.0185, -0.0517, 0.0017, -0.0788, -0.0403, 0.0039, 0.0544, -0.0496,0.0588, -0.0068, 0.0496, 0.0588, -0.0100, 0.0731, 0.0071, -0.0155,-0.0872, -0.0504, 0.0499, 0.0628, -0.0057, 0.0530, -0.0518, -0.0049,0.0767, 0.0743, 0.0748, -0.0438, 0.0235, -0.0809, 0.0140, -0.0374,0.0615, -0.0177, 0.0061, -0.0013, -0.0138, -0.0750, -0.0550, 0.0732,0.0050, 0.0778, 0.0415, 0.0487, 0.0522, 0.0867, -0.0255, -0.0264,0.0829, 0.0599, 0.0194, 0.0831, -0.0562, 0.0487, -0.0411, 0.0237,0.0347, -0.0194, -0.0560, -0.0562, -0.0076, 0.0459, -0.0477, 0.0345,-0.0575, -0.0005, 0.0174, 0.0855, -0.0257, -0.0279, -0.0348, -0.0114,-0.0823, -0.0075, -0.0524, 0.0331, 0.0387, -0.0575, 0.0068, -0.0590,-0.0101, -0.0880, -0.0375, 0.0033, -0.0172, -0.0641, -0.0797, 0.0407,0.0741, -0.0041, -0.0608, 0.0672, -0.0464, -0.0716, -0.0191, -0.0645,0.0397, 0.0013, 0.0063, 0.0370, 0.0475, -0.0535, 0.0721, -0.0431,0.0053, -0.0568, -0.0228, -0.0260, -0.0784, -0.0148, 0.0229, -0.0095,-0.0040, 0.0025, 0.0781, 0.0140, -0.0561, 0.0384, -0.0011, -0.0366,0.0345, 0.0015, 0.0294, -0.0734, -0.0852, -0.0015, -0.0747, -0.0100,0.0801, -0.0739, 0.0611, 0.0536, 0.0298, -0.0097, 0.0017, -0.0398,0.0076, -0.0759, -0.0293, 0.0344, -0.0463, -0.0270, 0.0447, 0.0814,-0.0193, -0.0559, 0.0160, 0.0216, -0.0346, 0.0316, 0.0881, -0.0652,-0.0169, 0.0117, -0.0107, -0.0754, -0.0231, -0.0291, 0.0210, 0.0427,0.0418, 0.0040, 0.0762, 0.0645, -0.0368, -0.0229, -0.0569, -0.0881,-0.0660, 0.0297, 0.0433, -0.0777, 0.0212, -0.0601, 0.0795, -0.0511,-0.0634, 0.0720, 0.0016, 0.0693, -0.0547, -0.0652, -0.0480, 0.0759,0.0194, -0.0328, -0.0211, -0.0025, -0.0055, -0.0157, 0.0817, 0.0030,0.0310, -0.0735, 0.0160, -0.0368, 0.0528, -0.0675, -0.0083, -0.0427,-0.0872, 0.0699, 0.0795, -0.0738, -0.0639, 0.0350, 0.0114, 0.0303],requires_grad=True) torch.Size([256]) out.weight Parameter containing: tensor([[ 0.0232, -0.0571, 0.0439, ..., -0.0417, -0.0237, 0.0183],[ 0.0210, 0.0607, 0.0277, ..., -0.0015, 0.0571, 0.0502],[ 0.0297, -0.0393, 0.0616, ..., 0.0131, -0.0163, -0.0239],...,[ 0.0416, 0.0309, -0.0441, ..., -0.0493, 0.0284, -0.0230],[ 0.0404, -0.0564, 0.0442, ..., -0.0271, -0.0526, -0.0554],[-0.0404, -0.0049, -0.0256, ..., -0.0262, -0.0130, 0.0057]],requires_grad=True) torch.Size([10, 256]) out.bias Parameter containing: tensor([-0.0536, 0.0007, 0.0227, -0.0072, -0.0168, -0.0125, -0.0207, -0.0558,0.0579, -0.0439], requires_grad=True) torch.Size([10])
-
使用TensorDataset和DataLoader来简化
from torch.utils.data import TensorDataset from torch.utils.data import DataLoadertrain_ds = TensorDataset(x_train, y_train) train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)valid_ds = TensorDataset(x_valid, y_valid) valid_dl = DataLoader(valid_ds, batch_size=bs * 2)def get_data(train_ds, valid_ds, bs):return (DataLoader(train_ds, batch_size=bs, shuffle=True),DataLoader(valid_ds, batch_size=bs * 2),) - 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
- 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
import numpy as npdef fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))from torch import optim def get_model():model = Mnist_NN()return model, optim.SGD(model.parameters(), lr=0.001)def loss_batch(model, loss_func, xb, yb, opt=None):loss = loss_func(model(xb), yb)if opt is not None:loss.backward()opt.step()opt.zero_grad()return loss.item(), len(xb)三行搞定!
train_dl, valid_dl = get_data(train_ds, valid_ds, bs) model, opt = get_model() fit(25, model, loss_func, opt, train_dl, valid_dl)当前step:0 验证集损失:2.2796445930480957 当前step:1 验证集损失:2.2440698066711424 当前step:2 验证集损失:2.1889826164245605 当前step:3 验证集损失:2.0985311767578123 当前step:4 验证集损失:1.9517273582458496 当前step:5 验证集损失:1.7341805934906005 当前step:6 验证集损失:1.4719875366210937 当前step:7 验证集损失:1.2273896869659424 当前step:8 验证集损失:1.0362271406173706 当前step:9 验证集损失:0.8963696184158325 当前step:10 验证集损失:0.7927186088562012 当前step:11 验证集损失:0.7141492074012756 当前step:12 验证集损失:0.6529350900650024 当前step:13 验证集损失:0.60417300491333 当前step:14 验证集损失:0.5643046331882476 当前step:15 验证集损失:0.5317994566917419 当前step:16 验证集损失:0.5047958114624024 当前step:17 验证集损失:0.4813900615692139 当前step:18 验证集损失:0.4618900228500366 当前step:19 验证集损失:0.4443243554592133 当前step:20 验证集损失:0.4297310716629028 当前step:21 验证集损失:0.416976597738266 当前step:22 验证集损失:0.406348459148407 当前step:23 验证集损失:0.3963301926612854 当前step:24 验证集损失:0.38733808159828187https://gitee.com/code-wenjiahao/neural-network-practical-classification-and-regression-tasks/tree/master
相关文章:
pytorch-神经网络-手写数字分类任务
Mnist分类任务: 网络基本构建与训练方法,常用函数解析 torch.nn.functional模块 nn.Module模块 读取Mnist数据集 会自动进行下载 %matplotlib inlinefrom pathlib import Path import requestsDATA_PATH Path("data") PATH DATA_PATH / &…...
【群智能算法改进】一种改进的鹈鹕优化算法 IPOA算法[1]【Matlab代码#57】
文章目录 【获取资源请见文章第5节:资源获取】1. 原始POA算法2. 改进后的IPOA算法2.1 Sine映射种群初始化2.2 融合改进的正余弦策略2.3 Levy飞行策略 3. 部分代码展示4. 仿真结果展示5. 资源获取 【获取资源请见文章第5节:资源获取】 1. 原始POA算法 此…...
C++初阶:C++入门
目录 一.iostream文件 二.命名空间 2.1.命名空间的定义 2.2.命名空间的使用 三.C的输入输出 四.缺省参数 4.1.缺省参数概念 4.2.缺省参数分类 4.3.缺省参数注意事项 4.4.缺省参数用途 五.函数重载 5.1.重载函数概念 5.2.C支持函数重载的原理--名字修饰(name Mangl…...
golang操作数据库--gorm框架、redis
目录 1.数据库相关操作(1)非orm框架①引入②初始化③增删改查 (2) io版orm框架 (推荐用这个)①引入②初始化③增删改查④gorm gen的使用 (3) jinzhu版orm框架①引入②初始化③增删改查 2.redis(1)引入(2)初始化①普通初始化②v8初始化③get/set示例 1.数据库相关操作 (1)非orm…...
10 种常用的字符串方法
10 种常用的字符串方法 1.concat() 字符串拼接 const str1 12345678;const str2 abcdefgh;const str3 -【】;‘;console.log(str1.concat(str2,str3))//12345678abcdefgh-【】;‘ 2.includes() 判断字符串中是否包含指定值,返回布尔值…...
CSDN每日一练 |『生命进化书』『订班服』『c++难题-大数加法』2023-09-06
CSDN每日一练 |『生命进化书』『订班服』『c++难题-大数加法』2023-09-06 一、题目名称:生命进化书二、题目名称:订班服三、题目名称:c++难题-大数加法一、题目名称:生命进化书 时间限制:1000ms内存限制:256M 题目描述: 小A有一本生命进化书,以一个树形结构记载了所有生…...
echarts饼图label自定义样式
生成的options {"tooltip": {"trigger": "item","axisPointer": {"type": "shadow"},"backgroundColor": "rgba(9, 24, 48, 0.5)","borderColor": "rgba(255,255,255,0.4)&q…...
Unity汉化一个插件 制作插件汉化工具
我是编程一个菜鸟,英语又不好,有的插件非常牛!我想学一学,页面全是英文,完全不知所措,我该怎么办啊...尝试在Unity中汉化一个插件 效果: 思路: 如何在Unity中把一个自己喜欢的插件…...
从过滤器初识责任链设计模式
下面用的过滤器都是注解方式 可以使用非注解方式,就是去web.xml配置映射关系 上面程序的执行输出是 再加一个过滤器 下面来看一段程序 输出结果 和过滤器是否非常相识 但是上面这段程序存在的问题:在编译阶段已经完全确定了调用关系,如果你想改变他们的调用顺序或者继续添加一…...
Redis7安装配置
✅作者简介:大家好,我是Leo,热爱Java后端开发者,一个想要与大家共同进步的男人😉😉 🍎个人主页:Leo的博客 💞当前专栏: Java从入门到精通 ✨特色专栏…...
切分支解决切不走因为未合并的路径如何解决
改代码的时候改做分支了,本来是在另一个分支上面改代码,结果改到另一个放置上面,然后想着使用git stash进行保存,然后切到另外一个分支再pop,结果不行。 报这个错误,导致切不过去,因为我这边pop…...
自动化运维:Ansible之playbook基于ROLES部署LNMP平台
目录 一、理论 1.playbook剧本 2.ROLES角色 3.关系 4.Roles模块搭建LNMP架构 二、实验 1.Roles模块搭建LNMP架构 三、问题 1.剧本启动php报错语法问题 2.剧本启动mysql报错语法问题 3.剧本启动nginx开启失败 4.剧本安装php失败 5.使用yum时报错 6.rpm -Uvh https…...
SpringBoot整合MQ
1.创建工程并引入依赖 <!-- 添加rocketmq的启动器--><dependency><groupId>org.apache.rocketmq</groupId><artifactId>rocketmq-spring-boot-starter</artifactId><version>2.1.1</version></dependency>2.编写…...
算法训练day37|贪心算法 part06(LeetCode738.单调递增的数字)
文章目录 738.单调递增的数字思路分析代码实现 738.单调递增的数字 题目链接🔥🔥 给定一个非负整数 N,找出小于或等于 N 的最大的整数,同时这个整数需要满足其各个位数上的数字是单调递增。 (当且仅当每个相邻位数上的…...
【C++基础】4. 变量
文章目录 【 1. 变量的定义 】【 2. 变量的声明 】示例 【 3. 左值和右值 】 变量:相当于是程序可操作的数据存储区的名称。在 C 中,有多种变量类型可用于存储不同种类的数据。C 中每个变量都有指定的类型,类型决定了变量存储的大小和布局&am…...
jmeter setUp Thread Group
SetUp Thread Group 是一种特殊类型的线程组,它用于在主测试计划执行之前执行一些初始化任务。 SetUp Thread Group 通常用于以下几种情况: 用户登录:在模拟用户执行实际测试之前,模拟用户登录到系统以获取访问权限。 创建会话&a…...
图神经网络教程之GCN(pyG)
图神经网络-pyG版本的GCN Data(数据) data.x、data.edge_index、data.edge_attr、data.y、data.pos 举个例子 import torch from torch_geometric.data import Data edge_index torch.tensor([[0, 1, 1, 2],[1, 0, 2, 1]], dtypetorch.long) #代表…...
python中的逻辑运算
逻辑运算 逻辑运算符是python用来进行逻辑判断的运算符,虽然运算符只有and、or、not三种,但是理解这三个运算符的原理才是最重要的 python中对false的认定 逻辑运算符是python用来进行逻辑判断的运算符,虽然运算符只有and、or、not三种&…...
TortoiseGit设置作者信息和用户名、密码存储
前言 Git 客户端每次与服务器交互,都需要输入密码,但是我们可以配置保存密码,只需要输入一次,就不再需要输入密码。 操作说明 在任意文件夹下,空白处,鼠标右键点击 在弹出菜单中按照下图点击 依次点击下…...
Fragment.OnPause的事情
我们知道Fragment的生命周期依附于相应Activity的生命周期,如果activity A调用了onPause,则A里面的fragment也会相应收到onPause回调,这里以support27.1.1版本的源码来说明Fragment生命周期onPause的事情。 当activity执行onPause时ÿ…...
浏览器访问 AWS ECS 上部署的 Docker 容器(监听 80 端口)
✅ 一、ECS 服务配置 Dockerfile 确保监听 80 端口 EXPOSE 80 CMD ["nginx", "-g", "daemon off;"]或 EXPOSE 80 CMD ["python3", "-m", "http.server", "80"]任务定义(Task Definition&…...
(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)
题目:3442. 奇偶频次间的最大差值 I 思路 :哈希,时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况,哈希表这里用数组即可实现。 C版本: class Solution { public:int maxDifference(string s) {int a[26]…...
MPNet:旋转机械轻量化故障诊断模型详解python代码复现
目录 一、问题背景与挑战 二、MPNet核心架构 2.1 多分支特征融合模块(MBFM) 2.2 残差注意力金字塔模块(RAPM) 2.2.1 空间金字塔注意力(SPA) 2.2.2 金字塔残差块(PRBlock) 2.3 分类器设计 三、关键技术突破 3.1 多尺度特征融合 3.2 轻量化设计策略 3.3 抗噪声…...
多场景 OkHttpClient 管理器 - Android 网络通信解决方案
下面是一个完整的 Android 实现,展示如何创建和管理多个 OkHttpClient 实例,分别用于长连接、普通 HTTP 请求和文件下载场景。 <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas…...
Day131 | 灵神 | 回溯算法 | 子集型 子集
Day131 | 灵神 | 回溯算法 | 子集型 子集 78.子集 78. 子集 - 力扣(LeetCode) 思路: 笔者写过很多次这道题了,不想写题解了,大家看灵神讲解吧 回溯算法套路①子集型回溯【基础算法精讲 14】_哔哩哔哩_bilibili 完…...
华为OD机试-食堂供餐-二分法
import java.util.Arrays; import java.util.Scanner;public class DemoTest3 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseint a in.nextIn…...
python爬虫:Newspaper3k 的详细使用(好用的新闻网站文章抓取和解析的Python库)
更多内容请见: 爬虫和逆向教程-专栏介绍和目录 文章目录 一、Newspaper3k 概述1.1 Newspaper3k 介绍1.2 主要功能1.3 典型应用场景1.4 安装二、基本用法2.2 提取单篇文章的内容2.2 处理多篇文档三、高级选项3.1 自定义配置3.2 分析文章情感四、实战案例4.1 构建新闻摘要聚合器…...
拉力测试cuda pytorch 把 4070显卡拉满
import torch import timedef stress_test_gpu(matrix_size16384, duration300):"""对GPU进行压力测试,通过持续的矩阵乘法来最大化GPU利用率参数:matrix_size: 矩阵维度大小,增大可提高计算复杂度duration: 测试持续时间(秒&…...
学习STC51单片机32(芯片为STC89C52RCRC)OLED显示屏2
每日一言 今天的每一份坚持,都是在为未来积攒底气。 案例:OLED显示一个A 这边观察到一个点,怎么雪花了就是都是乱七八糟的占满了屏幕。。 解释 : 如果代码里信号切换太快(比如 SDA 刚变,SCL 立刻变&#…...
python报错No module named ‘tensorflow.keras‘
是由于不同版本的tensorflow下的keras所在的路径不同,结合所安装的tensorflow的目录结构修改from语句即可。 原语句: from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense 修改后: from tensorflow.python.keras.lay…...
