《Pytorch深度学习实践》ch8-多分类
------B站《刘二大人》
1.Softmax Layer
- 在多分类问题中,输出的是每类的概率:

- 计算公式:保证了每类概率大于 0 ,又由保证了概率之和为 1;

-
举例如下:

2.Cross Entropy
- 计算损失:

y = np.array([1, 0, 0]):是目标标签的 one-hot 编码。假设有 3 个类别,这里表示正确的类别是第一个类别;
import numpy as np
y = np.array([1, 0, 0])
z = np.array([0.2, 0.1, -0.1])
y_pred = np.exp(z) / np.exp(z).sum()
loss = (-y * np.log(y_pred)).sum()
print(loss) # 0.9729189131256584
- 交叉熵损失函数:

- y 是一个长度为 1 的长整型张量,是标签类别的 索引,
[0]表示正确的类别是类别 0;
import torch
y = torch.LongTensor([0])
z = torch.Tensor([[0.2, 0.1, -0.1]])
criterion = torch.nn.CrossEntropyLoss()
loss = criterion(z, y)
print(loss) # tensor(0.9729)
- Mini - Batch
import torch
criterion = torch.nn.CrossEntropyLoss()
Y = torch.LongTensor([2, 0, 1])Y_pred1 = torch.Tensor([[0.1, 0.2, 0.9],[1.1, 0.1, 0.2],[0.2, 2.1, 0.1]])
Y_pred2 = torch.Tensor([[0.8, 0.2, 0.3],[0.2, 0.3, 0.5],[0.2, 0.2, 0.5]])loss1 = criterion(Y_pred1, Y) # Batch Loss1 = tensor(0.4966)
loss2 = criterion(Y_pred2, Y) # Batch Loss2 = tensor(1.2389)
print('Batch Loss1 = ', loss1.data, '\nBatch Loss2 = ', loss2.data)
3.MNIST
- 导包
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
- 准备数据集
ToTensor():将图片转换为PyTorch的张量。Normalize(mean, std):使用指定的均值和标准差对图片进行标准化。


batch_size = 64transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307, ), (0.3081, ))
])train_dataset = datasets.MNIST('data/MNIST/', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)test_dataset = datasets.MNIST('data/MNIST/', train=False, transform=transform, download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
- 构造模型
- 输入层:784个神经元(因为每张图片是28x28,展平后变成784维)。
- 隐藏层:4个全连接层,神经元数量分别为512、256、128和64。
- 输出层:10个神经元,分别对应数字0到9。
- 最后一层不做激活,因为后面调用 torch.nn.CrossEntropyLoss。
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.linear1 = torch.nn.Linear(784, 512)self.linear2 = torch.nn.Linear(512, 256)self.linear3 = torch.nn.Linear(256, 128)self.linear4 = torch.nn.Linear(128, 64)self.linear5 = torch.nn.Linear(64, 10)def forward(self, x):x = x.view(-1, 784)x = F.relu(self.linear1(x))x = F.relu(self.linear2(x))x = F.relu(self.linear3(x))x = F.relu(self.linear4(x))x = self.linear5(x) # 不用激活函数,因为 torch.nn.CrossEntropyLoss = softmax + nlllossreturn xmodel = Net()
- 损失与优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
- 训练与测试
- torch.max:返回最大值和对应的下标。
- dim=1,说明是在行的维度。 0是列,1是行。
# training
def train(epoch):running_loss = 0.0for batch_idx, data in enumerate(train_loader, 0):inputs, target = dataoptimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, target)loss.backward()optimizer.step()running_loss += loss.item()if batch_idx % 300 == 299:print('[%d, %5d] loss: %.3f' % (epoch+1, batch_idx+1, running_loss/300))running_loss = 0.0# test
def test():correct = 0total = 0with torch.no_grad():for data in test_loader:inputs, labels = dataoutputs = model(inputs)_, predicted = torch.max(outputs.data, dim=1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy on test set: %d %%' %(100*correct/total))if __name__ == '__main__':for epoch in range(10):train(epoch)if epoch % 10 == 0:test()相关文章:
《Pytorch深度学习实践》ch8-多分类
------B站《刘二大人》 1.Softmax Layer 在多分类问题中,输出的是每类的概率: 计算公式:保证了每类概率大于 0 ,又由保证了概率之和为 1; 举例如下: 2.Cross Entropy 计算损失: y np.array…...
国产录播一体机:科技赋能智慧教育信息化
在数字化时代,教育正经历着前所未有的变革。国产工控机作为信息化教育的核心载体,正在重新定义学习方式,赋能教师与学生,打造高效、互动、智能的教学环境,让我们一起感受科技与教育的深度融合!高能计算机推…...
关于逻辑回归的见解
逻辑回归通过将线性回归的输出映射到 [ 0 , 1 ] \left[0,1\right] [0,1]区间,来表示某个类别的概率。也就是其本质是先通过线性回归的预测值 y \boldsymbol{y} y输入到映射函数,既将线性回归的输出通过映射函数映射到 [ 0 , 1 ] \left[0,1\right] [0,1].常用的映射函数是sigm…...
Amazon Augmented AI:人类智慧与AI协作,破解机器学习审核难题
在人工智能日益渗透业务核心的今天,你是否遭遇过这样的困境:自动化AI处理海量数据时,面对模糊、复杂或高风险的场景频频“卡壳”?人工审核团队则被低效、重复的任务压得喘不过气?Amazon Augmented AI (A2I) 的诞生&…...
CMake入门:3、变量操作 set 和 list
在 CMake 中,set 和 list 是两个核心命令,用于变量管理和列表操作。理解它们的用法对于编写高效的 CMakeLists.txt 文件至关重要。下面详细介绍这两个命令的功能和常见用法: 一、set 命令:变量定义与赋值 set 命令用于创建、修改…...
聊聊FlaUI:让Windows UI自动化测试优雅起飞!
你还在为手动点点点测试Windows应用而感到膝盖疼?更愁于自动化测试工具价格贵得让钱包瑟瑟发抖?今天,我要给你安利一款“野路子有余,正经事儿也能干”的.NET UI自动化神器——FlaUI!别眨眼,看完你能少加三个…...
VIN码车辆识别码解析接口如何用C#进行调用?
一、什么是VIN码车辆识别码解析接口 输入17位vin码,获取到车辆的品牌、型号、出厂日期、发动机类型、驱动类型、车型、年份等信息。无论是汽车电商平台、二手车商、维修厂,还是保险公司、金融机构,都能通过接入该API实现信息自动化、决策智能…...
[论文阅读] 人工智能 | 用大语言模型解决软件元数据“身份谜题”:科研软件的“认脸”新方案
用大语言模型解决软件元数据“身份谜题”:科研软件的“认脸”新方案 论文信息 作者: Eva Martn del Pico, Josep Llus Gelp, Salvador Capella-Gutirrez 标题: Identity resolution of software metadata using Large Language Models 年份: 2025 来源: arX…...
gorm多租户插件的使用
一、关于gorm多租户插件的使用 1、安装依赖 go get -u github.com/kuangshp/gorm-tenant2、创建一个mysql数据表 DROP TABLE IF EXISTS user; CREATE TABLE user (id int(11) NOT NULL AUTO_INCREMENT primary key COMMENT 主键id,name varchar(50) not null comment 名称,ten…...
Playwright 测试框架 - Java
🚀【Playwright + Java 实战教程】从零到一掌握自动化测试利器! 🔧 本文专为 Java 开发者量身打造,通过详尽示例带你快速掌握 Playwright 自动化测试。涵盖基础操作、表单交互、测试框架集成、高阶功能及常见实战技巧,适用于企业 UI 测试与 CI/CD 场景。 🛠️ 一、环境…...
力扣100题之128. 最长连续序列
方法1 使用了hash 方法思路 使用哈希集合:首先将数组中的所有数字存入一个哈希集合中,这样可以在 O(1) 时间内检查某个数字是否存在。 寻找连续序列:遍历数组中的每一个数字,对于每一个数字, 检查它是否是某个连续序列…...
算法打卡12天
19.链表相交 (力扣面试题 02.07. 链表相交) 给你两个单链表的头节点 headA 和 headB ,请你找出并返回两个单链表相交的起始节点。如果两个链表没有交点,返回 null 。 图示两个链表在节点 c1 开始相交**:** 题目数据…...
OpenCV C++ 学习笔记(四):图像/视频的输入输出(highgui模块 高层GUI和媒体I/O)
文章目录 图片读取创建窗口图片显示图片保存视频输入输出 图片读取 cv::Mat imread( const String& filename, int flags IMREAD_COLOR );enum ImreadModes {IMREAD_UNCHANGED -1, //!< If set, return the loaded image as is (with alpha channel, othe…...
我的创作纪念日——聊聊我想成为一个创作者的动机
2025年6月4日,是我在CSDN写下第一篇技术博客的第1024天。 1024,这个数字对于程序员来说意义非凡,它不仅是内存单位的基础,更是我们这群“码农”的节日符号。而对我来说,它更像是一段旅程的里程碑:从一个曾想…...
蓝桥杯国赛训练 day1 Java大学B组
目录 k倍区间 舞狮 交换瓶子 k倍区间 取模后算组合数就行 import java.util.HashMap; import java.util.Map; import java.util.Scanner;public class Main {static Scanner sc new Scanner(System.in);public static void main(String[] args) {solve();}public static vo…...
PyTorch——非线性激活(5)
非线性激活函数的作用是让神经网络能够理解更复杂的模式和规律。如果没有非线性激活函数,神经网络就只能进行简单的加法和乘法运算,没法处理复杂的问题。 非线性变化的目的就是给我们的网络当中引入一些非线性特征 Relu 激活函数 Relu处理图像 # 导入必…...
OPenCV CUDA模块目标检测----- HOG 特征提取和目标检测类cv::cuda::HOG
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 cv::cuda::HOG 是 OpenCV 的 CUDA 模块中对 HOG 特征提取和目标检测 提供的 GPU 实现。它与 CPU 版本的 cv::HOGDescriptor 类似,但利…...
MATLAB读取文件内容:Excel、CSV和TXT文件解析
MATLAB读取文件内容:Excel、CSV和TXT文件解析 MATLAB 是一款强大的数学与工程计算工具,广泛应用于数据分析、模型构建和图像处理等领域。在处理实际问题时,我们常常需要从文件中读取数据进行分析。本文将介绍如何使用 MATLAB 读取常见的文件…...
Spring MVC 之 异常处理
使用Spring MVC可以很灵活地完成数据的绑定和响应,极大的简化了Java Web的开发。但Spring MVC提供的便利不仅仅如此,使用Spring MVC还可以很便捷地完成项目中的异常处理、自定义拦截器以及文件上传和下载等高级功能。本章将对Spring MVC提供的这些高级功…...
缓存控制HTTP标头设置为“无缓存、无存储、必须重新验证”
文章目录 说明示例核心响应头设置实现原理代码实现1. 原生 Node.js (使用 http 模块)2. Express 框架3. 针对特定路由设置 (Express) 验证方法(使用 cURL)关键注意事项 说明 日期:2025年6月4日。 对于安全内容,请确保缓存控制HT…...
ubuntu24.04 使用apt指令只下载不安装软件
比如我想下载net-tools工具包及其依赖包可以如下指令 apt --download-only install net-tools 自动下载的软件包在/var/cache/apt/archives/目录下...
macOS 上使用 Homebrew 安装redis-cli
在 macOS 上使用 Homebrew 安装 redis-cli(Redis 命令行工具)非常简单,以下是详细步骤: 1. 安装 Redis(包含 redis-cli) 运行以下命令安装 Redis: brew install redis这会安装完整的 Redis 服…...
计算机网络安全问答数据集(1788条) ,AI智能体知识库收集! AI大模型训练数据!
继续收集数据集,话不多说,见下文! 今天分享一个计算机网络安全问答数据集(1788条),适用于AI大模型训练、智能体知识库构建、安全教育系统开发等多种场景! 一、数据特点 结构清晰:共计1788条&…...
WinCC学习系列-高阶应用(WinCC REST通信)
WinCC作为一个经典SCADA系统,它是OT与IT数据无缝集成桥梁,自WinCC7.5版本开始,可以直接提供Rest服务用于其它系统数据访问和操作。 WinCC REST 服务允许外部应用程序访问 WinCC 数据。 外部应用程序可以通过 REST 接口读取和写入 WinCC 组态…...
八、Python模块、包
目录 1. 模块 1.1 什么是模块? 1.2 创建模块 1.3 导入模块 1.4 模块的命名空间 1.5 模块的搜索路径 1.6 模块的重新加载 2. 包 2.1 什么是包? 2.2 创建包 2.3 导入包中的模块 2.4 包的层次结构 3. 模块和包的管理 3.1 安装模块 3.2 卸载模…...
使用交叉编译工具提示stubs-32.h:7:11: fatal error: gnu/stubs-soft.h: 没有那个文件或目录的解决办法
0 前言 使用ST官方SDK提供的交叉编译工具、cmake生成Makefile,使用make命令生成可执行文件提示fatal error: gnu/stubs-soft.h: 没有那个文件或目录的解决办法,如下所示: 根据这一错误提示,按照网上的解决方案逐一尝试均以失败告…...
macOS 连接 Docker 运行 postgres,使用navicat添加并关联数据库
下载 docker注册一个账号,登录 Docker创建 docke r文件 mkdir -p ~/.docker && touch ~/.docker/daemon.json写入配置(全量替换) {"builder": {"gc": {"defaultKeepStorage": "20GB",&quo…...
指针的使用——基本数据类型、数组、结构体
1 引言 对于学习指针要弄清楚如下问题基本可以应付大部分的场景: ① 指针是什么? ② 指针的类型是什么? ③ 指针指向的类型是什么? ④ 指针指向了哪里? 2 如何使用指针 任何东西的学习最好可以总结成一种通用化的…...
TK海外抢单源码/指定卡单
抢单源码,有指定派单,打针,这套二改过充值跳转客服 前端vue 后端php 两端分离 可二开 可以指定卡第几单,金额多少, 前后端开源 PHP7.2 MySQL5.6 前端要www.域名,后端要admin.域名 前端直接静态 伪静…...
Docker MCP 目录和工具包简介:使用 MCP 为 AI 代理提供支持的简单安全方法
目录 Model Context Protocol 势头强劲 — 还需要改进哪些?发现正确的、官方的和/或值得信赖的工具是很困难的复杂的安装和分发身份验证和权限不足Docker 如何帮助解决这些挑战在安全、隔离的容器中轻松发现和运行 MCP 服务器一键式 MCP 客户端集成,内置安全认证企业就绪的 M…...
