当前位置: 首页 > news >正文

从0开始深度学习(12)——多层感知机的逐步实现

依然以Fashion-MNIST图像分类数据集为例,手动实现多层感知机和激活函数的编写,大部分代码均在从0开始深度学习(9)——softmax回归的逐步实现中实现过

1 读取数据

import torch
from torchvision import transforms
import torchvision
from torch.utils import data# 读取数据
def load_data_fashion_mnist(batch_size, resize=None):  #@savetrans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="D:/DL_Data/", train=True, transform=trans, download=False)mnist_test = torchvision.datasets.FashionMNIST(root="D:/DL_Data/", train=False, transform=trans, download=False)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=12),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=12))train_iter, test_iter = load_data_fashion_mnist(256, resize=28)

2 初始化模型参数

以单隐藏层的多层感知机为例,选择使用256个隐藏单元

from torch import nn# 初始化模型参数
num_inputs=784      # 28*28
num_outputs=10
num_hiddens=256     # 我们选择使用256个隐藏单元,注意,一般选择使用2的若干次幂,因为内存的特殊性,可以在计算上更高效w1 = nn.Parameter(torch.randn(num_inputs,num_hiddens,requires_grad=True)*0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens,requires_grad=True))w2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))params = [w1, b1, w2, b2]

3 激活函数、损失函数、建立模型

# 激活函数
def relu(x):a=torch.zeros_like(x) # 保证全零张量和x的形状一致,利于广播计算return torch.max(x,a)# 损失函数
loss = nn.CrossEntropyLoss(reduction='none')#建立模型
def net(x):x=x.reshape((-1,num_inputs))#展开H=relu(x@w1+b1)# @表示矩阵乘法return (H@w2+b2)

4 训练模型

优化器使用SGD

#训练,优化器使用sgd
num_epochs=5
lr=00.1
updater=torch.optim.SGD(params,lr=lr)def train_epoch(net, train_iter, loss, updater):if isinstance(net, torch.nn.Module):net.train()  # 将模型设置为训练模式metric = Accumulator(3)  # 训练损失总和、训练准确度总和、样本数for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y).mean()if isinstance(updater, torch.optim.Optimizer):updater.zero_grad()l.backward()updater.step()else:l.backward()updater([w, b], lr, batch_size)metric.add(float(l) * y.numel(), compute_accuracy(y_hat, y), y.numel())return metric[0] / metric[2], metric[1] / metric[2]def train(net, train_iter, test_iter, loss, num_epochs, updater):for epoch in range(num_epochs):train_metrics = train_epoch(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)print(f'Epoch {epoch + 1}: Train Loss {train_metrics[0]:.3f}, Train Acc {train_metrics[1]:.3f}, Test Acc {test_acc:.3f}')class Accumulator:  #@save"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]def compute_accuracy(y_hat, y):  # 预测值、真实值if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)  # 找到一个样本中,对应的最大概率的类别cmp = y_hat.type(y.dtype) == y  # 将预测值 y_hat 与真实标签 y 进行比较,生成一个布尔张量 cmpreturn float(cmp.type(y.dtype).sum())# 计算在指定数据集上模型的准确率
def evaluate_accuracy(net, data_iter):  if isinstance(net, torch.nn.Module):net.eval()  # 将模型设置为评估模式metric = Accumulator(2)  # 累加多个变量的总和。这里初始化了一个包含两个元素的累加器,分别用来存储正确预测的数量和总的预测数量。with torch.no_grad():for X, y in data_iter:metric.add(compute_accuracy(net(X), y), y.numel())return metric[0] / metric[1]train(net, train_iter, test_iter, loss, num_epochs, updater)

在这里插入图片描述

5 预测

import matplotlib.pyplot as plt
# 定义 Fashion-MNIST 标签的文本描述
def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]# 预测并显示结果
def predict(net, test_iter, n=6):for X, y in test_iter:break  # 只取一个批次的数据trues = get_fashion_mnist_labels(y)preds = get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true + '\n' + pred for true, pred in zip(trues, preds)]n = min(n, X.shape[0])fig, axs = plt.subplots(1, n, figsize=(12, 3))for i in range(n):axs[i].imshow(X[i].permute(1, 2, 0).squeeze().numpy(), cmap='gray')axs[i].set_title(titles[i])axs[i].axis('off')plt.show()# 调用预测函数
predict(net, test_iter, n=6)

在这里插入图片描述

相关文章:

从0开始深度学习(12)——多层感知机的逐步实现

依然以Fashion-MNIST图像分类数据集为例,手动实现多层感知机和激活函数的编写,大部分代码均在从0开始深度学习(9)——softmax回归的逐步实现中实现过 1 读取数据 import torch from torchvision import transforms import torchv…...

如何利用OpenCV和yolo实现人脸检测

在之前的blog里面,我们有介绍OpenCV和yolo的区别,本文就人脸检测为例,分别介绍下OpenCV和yolo的实现方式。 OpenCV实现人脸检测 一、安装 OpenCV 首先确保你已经安装了 OpenCV 库。可以通过以下方式安装: 使用包管理工具安装&…...

015集——c# 实现CAD excel交互(CAD—C#二次开发入门)

第一步:添加引用 程序集—>扩展 namespace WindowsFormsApp2 {public partial class Form1 : Form{public Form1(){InitializeComponent();}private void Form1_Load(object sender, EventArgs e){}private void 获取当前excel_Click(object sender, EventArgs e…...

【计网笔记】以太网

经典以太网 总线拓扑 物理层 Manchester编码 数据链路层 MAC子层 MAC帧 DIX格式与IEEE802.3格式 IEEE802.3格式兼容DIX格式 前导码(帧开始定界符SOF) 8字节 前7字节均为0xAA第8字节为0xAB前7字节的Manchester编码将产生稳定方波,用于…...

Java 入门基础篇14 - java面向对象思想以及特性

学习目标: 一、目标 面向对象思想类和对象对象的创建和使用属性和方法封装 开始学习: 二、编程思想 2.1 什么是编程思想 做人有做人的原则,编程也有编程的原则。这些编程的原则,就叫做编程思想。 2.2 面向过程和面向对象 二…...

第15篇:网络架构优化与综合案例分析

目录 引言 15.1 网络性能优化的方法与工具 15.1.1 带宽管理与流量控制 15.1.2 负载均衡 15.1.3 缓存优化 15.2 网络故障的排查与解决 15.2.1 常用的网络故障排查工具 15.2.2 网络故障排查案例 15.3 网络安全架构的综合设计案例 15.3.1 企业网络安全架构的要求 15.3.…...

UI自动化测试实战

补充:Selenium主要用于Web页面的自动化测试,它可以模拟用户的各种操作,如点击、输入、滚动等,来测试网页的功能。而Appium是一个开源的移动端自动化测试工具。 一、自动化测试实战章节 自动化测试流程测试用例编写项目自动化测试…...

东方智者颜廷利:以哲学思想促进世界和谐与无私奉献

【本社讯】在全球化的今天,东方智慧与哲学思想正逐渐成为促进世界和谐与理解的重要力量。近日,祖籍齐鲁大地山东济南的东方智者颜廷利以其深邃的哲学思想和对人类社会的深刻洞察,引起了国际社会的广泛关注。 颜廷利,一位致力于哲学研究与实践的智者,他的思想跨越古今,融合了东…...

基于 springboot vue停车场管理系统 设计与实现

博主介绍:专注于Java(springboot ssm 等开发框架) vue .net php phython node.js uniapp 微信小程序 等诸多技术领域和毕业项目实战、企业信息化系统建设,从业十五余年开发设计教学工作 ☆☆☆ 精彩专栏推荐订阅☆☆☆☆☆不…...

如何验证ssl私钥和证书是否匹配?

从证书(CRT)文件提取公钥 openssl x509 -in server.crt -pubkey -noout | openssl sha256从证书签名请求(CSR)文件提取公钥 openssl req -in server.csr -pubkey -noout | openssl sha256从私钥(KEY)文件…...

MongoDB的基本操作

🌷数据库准备 🎈Mongoshell 1.在指定目录下创建mongodb文件夹、其子文件log和data以及mongodb.log cd /home/ubuntu mkdir -p mongodb/data mkdir -p mongodb/log touch mongodb/log/mongodb.log 执行mongodb命令启动mongdb服务 mongod --dbpath /h…...

spring mvc后端实现过程

文章目录 一、Spring mvc1、controller1.1、LoginController011.2、LoginController 2、service2.1、LoginService2.1、LoginInimplements 3、dao3.1、LoginMapper3.1、LoginMapper.xml 4、实体类 一、Spring mvc 1、controller 控制器层、处理用户的请求和响应, …...

102005

import os os.environ["CUDA_VISIBLE_DEVICES"] "0" # 设定使用的 GPUimport tensorflow as tf from dataset import generate_data import numpy as np from model import enhancednet# 检查 TensorFlow 是否可以识别 GPU gpus tf.config.list_physica…...

Cisco ACI环境给Leaf配置OOB带外管理IP方法

可以通过GUI 或CLI进行配置 通过CLI更简单,和配置传统交换机差不多, ACI中共有3大组件 APIC 控制器 SPINE 核心 LEAF 接入 下面我们将3种角色的带外IP配置方法都列出来 1 APIC配置带外IP This example shows how to configure out-of-band managemen…...

免费送源码:Java+B/S+MySQL springboot电影推荐系统 计算机毕业设计原创定制

摘 要 随着互联网与移动互联网迅速普及,网络上的电影娱乐信息数量相当庞大,人们对获取感兴趣的电影娱乐信息的需求越来越大,个性化的电影推荐系统成为一个热门。然而电影信息的表示相当复杂,己有的相似度计算方法与推荐算法都各有优势&#…...

数据清洗(脚本)

使用脚本清洗数据时,可以根据具体的数据问题选择编程语言,如Shell、Python、SQL等。这里我以 Python(Pandas库) 和 SQL 为例,演示如何通过脚本进行数据清洗。 1. 使用 Python(Pandas库) 进行数…...

jmeter中发送post请求遇到的问题

用jmeter发送post请求,把请求参数放在Body Data处,参数都写得正确,但没想到结果每次都报错,直接响应结果乱七八糟,改成用Parameters,反而不乱报错了。 上图 请求里如下 另外一些请求也是这样 这个响应结果也是错误的…...

Java中使用protobuf

一、简介 Protocal Buffers(简称protobuf)是谷歌的一项技术,用于结构化的数据序列化、反序列化。 Protocol Buffers 是一种语言无关、平台无关、可扩展的序列化结构数据的方法,它可用于(数据)通信协议、数据存储等。 Protocol B…...

2020款Macbook Pro A2251无法充电无法开机定位及修复

问题背景 up主有一台2020年的Macbook Pro,带Touch Bar,16G512G,四核I5,型号A2251 应该是一周没充电了,之前还用的好好的,后来有一天出差想带上 打开没电,手头上有个小米的66W快充头&#xff0c…...

Spring Cloud --- 引入Gateway网关

引入Gateway网关 介绍 Spring Cloud Gateway 组件的核心是一系列的过滤器,通过这些过滤器可以将客户端发送的请求转发(路由)到对应的微服务。 Spring Cloud Gateway 是加在整个微服务最前沿的防火墙和代理器,隐藏微服务结点 IP 端口信息,从…...

idea大量爆红问题解决

问题描述 在学习和工作中,idea是程序员不可缺少的一个工具,但是突然在有些时候就会出现大量爆红的问题,发现无法跳转,无论是关机重启或者是替换root都无法解决 就是如上所展示的问题,但是程序依然可以启动。 问题解决…...

前端倒计时误差!

提示:记录工作中遇到的需求及解决办法 文章目录 前言一、误差从何而来?二、五大解决方案1. 动态校准法(基础版)2. Web Worker 计时3. 服务器时间同步4. Performance API 高精度计时5. 页面可见性API优化三、生产环境最佳实践四、终极解决方案架构前言 前几天听说公司某个项…...

3.3.1_1 检错编码(奇偶校验码)

从这节课开始,我们会探讨数据链路层的差错控制功能,差错控制功能的主要目标是要发现并且解决一个帧内部的位错误,我们需要使用特殊的编码技术去发现帧内部的位错误,当我们发现位错误之后,通常来说有两种解决方案。第一…...

【第二十一章 SDIO接口(SDIO)】

第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...

cf2117E

原题链接&#xff1a;https://codeforces.com/contest/2117/problem/E 题目背景&#xff1a; 给定两个数组a,b&#xff0c;可以执行多次以下操作&#xff1a;选择 i (1 < i < n - 1)&#xff0c;并设置 或&#xff0c;也可以在执行上述操作前执行一次删除任意 和 。求…...

屋顶变身“发电站” ,中天合创屋面分布式光伏发电项目顺利并网!

5月28日&#xff0c;中天合创屋面分布式光伏发电项目顺利并网发电&#xff0c;该项目位于内蒙古自治区鄂尔多斯市乌审旗&#xff0c;项目利用中天合创聚乙烯、聚丙烯仓库屋面作为场地建设光伏电站&#xff0c;总装机容量为9.96MWp。 项目投运后&#xff0c;每年可节约标煤3670…...

JUC笔记(上)-复习 涉及死锁 volatile synchronized CAS 原子操作

一、上下文切换 即使单核CPU也可以进行多线程执行代码&#xff0c;CPU会给每个线程分配CPU时间片来实现这个机制。时间片非常短&#xff0c;所以CPU会不断地切换线程执行&#xff0c;从而让我们感觉多个线程是同时执行的。时间片一般是十几毫秒(ms)。通过时间片分配算法执行。…...

JS设计模式(4):观察者模式

JS设计模式(4):观察者模式 一、引入 在开发中&#xff0c;我们经常会遇到这样的场景&#xff1a;一个对象的状态变化需要自动通知其他对象&#xff0c;比如&#xff1a; 电商平台中&#xff0c;商品库存变化时需要通知所有订阅该商品的用户&#xff1b;新闻网站中&#xff0…...

AirSim/Cosys-AirSim 游戏开发(四)外部固定位置监控相机

这个博客介绍了如何通过 settings.json 文件添加一个无人机外的 固定位置监控相机&#xff0c;因为在使用过程中发现 Airsim 对外部监控相机的描述模糊&#xff0c;而 Cosys-Airsim 在官方文档中没有提供外部监控相机设置&#xff0c;最后在源码示例中找到了&#xff0c;所以感…...

Sklearn 机器学习 缺失值处理 获取填充失值的统计值

💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 使用 Scikit-learn 处理缺失值并提取填充统计信息的完整指南 在机器学习项目中,数据清…...