Pytorch深度学习实践(5)逻辑回归
逻辑回归
逻辑回归主要是解决分类问题
- 回归任务:结果是一个连续的实数
- 分类任务:结果是一个离散的值
分类任务不能直接使用回归去预测,比如在手写识别中(识别手写 0 − − 9 0 -- 9 0−−9),因为各个类别之间没有大小之差。
因此,对于分类问题,我们最终的输出是个概率,即属于某个类别的概率是多少,然后从概率集合里找最大值,作为当前预测的结果
下载MNIST
数据集
import torchvision
train_set = torchvision.dataset.MNIST(root="../dataset/mnist", train=True, download=True)
test_set = torchvision.dataset.MNIST(root="../dataset/mnist", train=False, dowload=True)
- 通过
train
参数来指定训练集和测试集
逻辑回归
将之前的学习时长—考试分数转化为二分类任务,即学习时长—是否通过考试
x(hours) | y(pass/fail) |
---|---|
1 | 0(fail) |
2 | 0(fail) |
3 | 1(pass) |
4 | ? |
其中, P ( y ^ = 1 ) + P ( y ^ = 0 ) = 1 P(\hat y = 1) + P(\hat y = 0) = 1 P(y^=1)+P(y^=0)=1
当输出的概率在 0.5 0.5 0.5附近时,即模型不确定,因此通常会输出一个不确定
的值
对于二分类任务,逻辑回归会先使用回归,生成一个得分值,即落在实数集区间内,然后再使用 s i g m o i d sigmoid sigmoid函数,将得分值映射到 [ 0 , 1 ] [0, 1] [0,1]区间内,得到预测概率
s i g m o i d sigmoid sigmoid函数
σ ( x ) = 1 1 + e − x \sigma (x) = \frac{1}{1+e^{-x}} σ(x)=1+e−x1
S i g m o i d Sigmoid Sigmoid常用来做二分类任务,常具备三个特征:
- 饱和函数
- 单调递增
- 有极限
当我们使用线性回归来得到逻辑回归的得分值时,逻辑回归模型的函数定义就如下所示:
y ^ = σ ( x ∗ ω + b ) \hat y = \sigma (x*\omega + b) y^=σ(x∗ω+b)
损失函数
线性回归使用的损失函数是计算预测值和真实值之差
而对于逻辑回归,由于我们得到的是概率,是一个 0 − 1 0-1 0−1分布,因此需要修改损失函数
l o s s = − ( y l o g y ^ + ( 1 − y ) l o g ( 1 − y ^ ) ) loss = -(ylog\hat y + (1-y)log(1-\hat y)) loss=−(ylogy^+(1−y)log(1−y^))
即我们比较的是分布之间的差异
交叉熵 c r o s s − e n t r o p y cross-entropy cross−entropy
存在两个分布 P D 1 ( x ) P_{D1}(x) PD1(x)和 P D 2 ( x ) P_{D2}(x) PD2(x)
两个分布的差异程度使用公式: ∑ i = 1 n P D 1 ( x i ) l n P D 2 ( x i ) \sum_{i=1}^{n}P_{D1}(x_i)lnP_{D2}(x_i) ∑i=1nPD1(xi)lnPD2(xi) 来衡量
上述公式越大时,两个分布的差异越小
模型的改变
模型构造的改变
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1)# 由于逻辑回归中Sigmoid函数不需要传参 所以在forward中直接计算即可# 在这里不需要实例化def forward(self, x):y_pred = F.sigmoid(self.linear(x))return y_pred
需要先将输入写入到linear()
线性模型中,再使用Sigmoid()
函数
模型损失函数的改变
使用交叉熵函数BCELoss
criterion = torch.nn.BCELoss(size_average=False)
整体代码
import torch
import matplotlib.pyplot as plt
import numpy as np########## 数据集准备 ##########
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])########## 模型定义 ##########
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1)# 由于逻辑回归中Sigmoid函数不需要传参 所以在forward中直接计算即可# 在这里不需要实例化def forward(self, x):y_pred = torch.sigmoid(self.linear(x))return y_predmodel = LogisticRegressionModel()########## 损失函数和优化器的设置 ##########
criterion = torch.nn.BCELoss(size_average=False) # BCELoss -- 交叉熵函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)########## 模型训练 ##########
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data)print(epoch, loss.item())optimizer.zero_grad()loss.backward()optimizer.step()########## 模型测试 ##########
x = np.linspace(0, 10, 200)
x_test = torch.Tensor(x).view((200, 1)) # view()相当于reshape
y_test = model(x_test)
y = y_test.data.numpy() # 转化为np类型
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], 'r--')
plt.xlabel("Hours")
plt.ylabel("Probability of Pass")
plt.grid()
plt.show()
相关文章:

Pytorch深度学习实践(5)逻辑回归
逻辑回归 逻辑回归主要是解决分类问题 回归任务:结果是一个连续的实数分类任务:结果是一个离散的值 分类任务不能直接使用回归去预测,比如在手写识别中(识别手写 0 − − 9 0 -- 9 0−−9),因为各个类别…...

认识漏洞-GitLab 远程命令执行漏洞、致远OA-ajax.do未授权任意文件上传漏洞
为方便您的阅读,可点击下方蓝色字体,进行跳转↓↓↓ 01 [GitLab 远程命令执行漏洞复现(CVE-2021-22205)](https://mp.weixin.qq.com/s/4QT-vxKpBn4ppNM9ipt-nQ)02 [致远OA-ajax.do未授权任意文件上传Getshell](https://mp.weixin.qq.com/s/TH2A5J5TXU36Y…...

vue实现电子签名、图片合成、及预览功能
业务功能:电子签名、图片合成、及预览功能 业务背景:需求说想要实现一个电子签名,然后需要提供一个预览的功能,可以查看签完名之后的完整效果。 需求探讨:后端大佬跟我说,文档我返回给你一个PDF的oss链接…...
【flink】之如何消费kafka数据?
为了编写一个使用Apache Flink来读取Apache Kafka消息的示例,我们需要确保我们的环境已经安装了Flink和Kafka,并且它们都能正常运行。此外,我们还需要在项目中引入相应的依赖库。以下是一个详细的步骤指南,包括依赖添加、代码编写…...

科研绘图系列:R语言山脊图(Ridgeline Chart)
介绍 山脊图(Ridge Chart)是一种用于展示数据分布和比较不同类别或组之间差异的数据可视化技术。它通常用于展示多个维度或变量之间的关系,以及它们在不同组中的分布情况。山脊图的特点: 多变量展示:山脊图可以同时展示多个变量的分布情况,允许用户比较不同变量之间的关…...
Boost搜索引擎:如何建立 用户搜索内容 与 网页文件内容 之间的关系
如果想使“用户搜索内容”和“网页文件内容”之间产生联系,就应该将“用户搜索内容”和“网页文件”分为很小的单元 (这个单元就是关键词),寻找用户搜索单元是否出现在这个文档之中,如果出现就证明这个网页文件和用户搜…...

【QT】QT 窗口(菜单栏、工具栏、状态栏、浮动窗口、对话框)
Qt 窗口是通过 QMainWindow类来实现的。 QMainWindow 是一个为用户提供主窗口程序的类,继承自 QWidget 类,并且提供了⼀个预定义的布局。QMainWindow 包含一个菜单栏(Menu Bar)、多个工具栏(Tool Bars)、…...

Golang | Leetcode Golang题解之第283题移动零
题目: 题解: func moveZeroes(nums []int) {left, right, n : 0, 0, len(nums)for right < n {if nums[right] ! 0 {nums[left], nums[right] nums[right], nums[left]left}right} }...

ubuntu22.04 安装 NVIDIA 驱动以及CUDA
目录 1、事前问题解决 2、安装 nvidia 驱动 3、卸载 nvidia 驱动方法 4、安装 CUDA 5、安装 Anaconda 6、安装 PyTorch 1、事前问题解决 在安装完ubuntu之后,如果进入ubuntu出现黑屏情况,一般就是nvidia驱动与linux自带的不兼容,可以通…...

数据结构·AVL树
1. AVL树的概念 二叉搜索树虽可以缩短查找的效率,但如果存数据时接近有序,二叉搜索将退化为单支树,此时查找元素效率相当于在顺序表中查找,效率低下。因此两位俄罗斯数学家 G.M.Adelson-Velskii 和E.M.Landis 在1962年发明了一种解…...

记一次Mycat分库分表实践
接了个活,又搞分库分表。 一、分库分表 在系统的研发过程中,随着数据量的不断增长,单库单表已无法满足数据的存储需求,此时就需要对数据库进行分库分表操作。 分库分表是随着业务的不断发展,单库单表无法承载整体的数据存储时,采取的一种将整体数据分散存储到不同服务…...

数据分析:微生物数据的荟萃分析框架
介绍 Meta-analysis of fecal metagenomes reveals global microbial signatures that are specific for colorectal cancer提供了一种荟萃分析的框架,它主要基于常用的Wilcoxon rank-sum test和Blocked Wilcoxon rank-sum test 方法计算显著性,再使用分…...

Django—admin后台管理
Django官网 https://www.djangoproject.com/ 如果已经有了Django跳过这步 安装Django: 如果你还没有安装Django,可以通过Python的包管理器pip来安装: pip install django 创建项目: 使用Django创建一个新的项目: …...

数字图像处理中的常用特殊矩阵及MATLAB应用
一、前言 Matlab的名称来源于“矩阵实验室(Matrix Laboratory)”,其对矩阵的操作具有先天性的优势(特别是相对于C语言的数组来说)。在数字图像处理中,为了提高编程效率,我们可以使用多种方式来创…...
vue侦听器(Watch)精彩案例剖析一
目录 watch介绍 监视普通数据类型 监视对象类型 watch介绍 在 Vue 中,watch主要用于监视数据的变化,并执行相应操作。一旦被监视的属性发生变化,回调函数将自动被触发。当在 Vue 中使用watch来响应数据变化时,首先要清楚,watch本质上是一个对象,且必须以对象的…...
HTTP 协议浅析
HTTP(HyperText Transfer Protocol,超文本传输协议)是应用层最重要的协议之一。它定义了客户端和服务器之间的数据传输方式,并成为万维网(World Wide Web)的基石。本文将深入解析 HTTP 协议的基础知识、工作…...

VsCode | 让空文件夹始终展开不折叠
文章目录 1 问题引入2 解决办法3 效果展示 1 问题引入 可能很多小伙伴更新VsCode或者下载新版本时候 ,创建的文件 会出现xxx文件夹/xxx文件夹,看着很不舒服,所以该如何展开所有空文件夹呢? 2 解决办法 找到VsCode的设置 &…...

Centos7_Minimal安装Cannot find a valid baseurl for repo: base/7/x86_6
问题 运行yum报此问题 就是没网 解决方法 修改网络信息配置文件,打开配置文件,输入命令: vi /etc/sysconfig/network-scripts/ifcfg-网卡名字把ONBOOTno,改为ONBOOTyes 重启网卡 /etc/init.d/network restart 网路通了...

Spark_Oracle_II_Spark高效处理Oracle时间数据:通过JDBC桥接大数据与数据库的分析之旅
接前文背景, 当需要从关系型数据库(如Oracle)中读取数据时,Spark提供了JDBC连接功能,允许我们轻松地将数据从Oracle等数据库导入到Spark DataFrame中。然而,在处理时间字段时,可能会遇到一些挑战…...

力扣 459重复的子字符串
思路: KMP算法的核心是求next数组 next数组代表的是当前字符串最大前后缀的长度 而求重复的子字符串就是求字符串的最大前缀与最大后缀之间的子字符串 如果这个子字符串是字符串长度的约数,则true /** lc appleetcode.cn id459 langcpp** [459] 重复…...
挑战杯推荐项目
“人工智能”创意赛 - 智能艺术创作助手:借助大模型技术,开发能根据用户输入的主题、风格等要求,生成绘画、音乐、文学作品等多种形式艺术创作灵感或初稿的应用,帮助艺术家和创意爱好者激发创意、提高创作效率。 - 个性化梦境…...

黑马Mybatis
Mybatis 表现层:页面展示 业务层:逻辑处理 持久层:持久数据化保存 在这里插入图片描述 Mybatis快速入门 
消失的两个数字(hard) 题⽬描述:解法(位运算):Java 算法代码:更简便代码 题⽬链接:⾯试题 17.19. 消失的两个数字 题⽬描述: 给定⼀个数组,包含从 1 到 N 所有…...
渲染学进阶内容——模型
最近在写模组的时候发现渲染器里面离不开模型的定义,在渲染的第二篇文章中简单的讲解了一下关于模型部分的内容,其实不管是方块还是方块实体,都离不开模型的内容 🧱 一、CubeListBuilder 功能解析 CubeListBuilder 是 Minecraft Java 版模型系统的核心构建器,用于动态创…...

对WWDC 2025 Keynote 内容的预测
借助我们以往对苹果公司发展路径的深入研究经验,以及大语言模型的分析能力,我们系统梳理了多年来苹果 WWDC 主题演讲的规律。在 WWDC 2025 即将揭幕之际,我们让 ChatGPT 对今年的 Keynote 内容进行了一个初步预测,聊作存档。等到明…...
linux 错误码总结
1,错误码的概念与作用 在Linux系统中,错误码是系统调用或库函数在执行失败时返回的特定数值,用于指示具体的错误类型。这些错误码通过全局变量errno来存储和传递,errno由操作系统维护,保存最近一次发生的错误信息。值得注意的是,errno的值在每次系统调用或函数调用失败时…...

【2025年】解决Burpsuite抓不到https包的问题
环境:windows11 burpsuite:2025.5 在抓取https网站时,burpsuite抓取不到https数据包,只显示: 解决该问题只需如下三个步骤: 1、浏览器中访问 http://burp 2、下载 CA certificate 证书 3、在设置--隐私与安全--…...

涂鸦T5AI手搓语音、emoji、otto机器人从入门到实战
“🤖手搓TuyaAI语音指令 😍秒变表情包大师,让萌系Otto机器人🔥玩出智能新花样!开整!” 🤖 Otto机器人 → 直接点明主体 手搓TuyaAI语音 → 强调 自主编程/自定义 语音控制(TuyaAI…...
【RockeMQ】第2节|RocketMQ快速实战以及核⼼概念详解(二)
升级Dledger高可用集群 一、主从架构的不足与Dledger的定位 主从架构缺陷 数据备份依赖Slave节点,但无自动故障转移能力,Master宕机后需人工切换,期间消息可能无法读取。Slave仅存储数据,无法主动升级为Master响应请求ÿ…...