Pytorch深度学习实践(5)逻辑回归
逻辑回归
逻辑回归主要是解决分类问题
- 回归任务:结果是一个连续的实数
- 分类任务:结果是一个离散的值
分类任务不能直接使用回归去预测,比如在手写识别中(识别手写 0 − − 9 0 -- 9 0−−9),因为各个类别之间没有大小之差。
因此,对于分类问题,我们最终的输出是个概率,即属于某个类别的概率是多少,然后从概率集合里找最大值,作为当前预测的结果
下载MNIST数据集
import torchvision
train_set = torchvision.dataset.MNIST(root="../dataset/mnist", train=True, download=True)
test_set = torchvision.dataset.MNIST(root="../dataset/mnist", train=False, dowload=True)
- 通过
train参数来指定训练集和测试集
逻辑回归
将之前的学习时长—考试分数转化为二分类任务,即学习时长—是否通过考试
| x(hours) | y(pass/fail) |
|---|---|
| 1 | 0(fail) |
| 2 | 0(fail) |
| 3 | 1(pass) |
| 4 | ? |
其中, P ( y ^ = 1 ) + P ( y ^ = 0 ) = 1 P(\hat y = 1) + P(\hat y = 0) = 1 P(y^=1)+P(y^=0)=1
当输出的概率在 0.5 0.5 0.5附近时,即模型不确定,因此通常会输出一个不确定的值
对于二分类任务,逻辑回归会先使用回归,生成一个得分值,即落在实数集区间内,然后再使用 s i g m o i d sigmoid sigmoid函数,将得分值映射到 [ 0 , 1 ] [0, 1] [0,1]区间内,得到预测概率
s i g m o i d sigmoid sigmoid函数
σ ( x ) = 1 1 + e − x \sigma (x) = \frac{1}{1+e^{-x}} σ(x)=1+e−x1

S i g m o i d Sigmoid Sigmoid常用来做二分类任务,常具备三个特征:
- 饱和函数
- 单调递增
- 有极限
当我们使用线性回归来得到逻辑回归的得分值时,逻辑回归模型的函数定义就如下所示:
y ^ = σ ( x ∗ ω + b ) \hat y = \sigma (x*\omega + b) y^=σ(x∗ω+b)
损失函数
线性回归使用的损失函数是计算预测值和真实值之差
而对于逻辑回归,由于我们得到的是概率,是一个 0 − 1 0-1 0−1分布,因此需要修改损失函数
l o s s = − ( y l o g y ^ + ( 1 − y ) l o g ( 1 − y ^ ) ) loss = -(ylog\hat y + (1-y)log(1-\hat y)) loss=−(ylogy^+(1−y)log(1−y^))
即我们比较的是分布之间的差异
交叉熵 c r o s s − e n t r o p y cross-entropy cross−entropy
存在两个分布 P D 1 ( x ) P_{D1}(x) PD1(x)和 P D 2 ( x ) P_{D2}(x) PD2(x)
两个分布的差异程度使用公式: ∑ i = 1 n P D 1 ( x i ) l n P D 2 ( x i ) \sum_{i=1}^{n}P_{D1}(x_i)lnP_{D2}(x_i) ∑i=1nPD1(xi)lnPD2(xi) 来衡量
上述公式越大时,两个分布的差异越小
模型的改变
模型构造的改变
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1)# 由于逻辑回归中Sigmoid函数不需要传参 所以在forward中直接计算即可# 在这里不需要实例化def forward(self, x):y_pred = F.sigmoid(self.linear(x))return y_pred
需要先将输入写入到linear()线性模型中,再使用Sigmoid()函数
模型损失函数的改变
使用交叉熵函数BCELoss
criterion = torch.nn.BCELoss(size_average=False)
整体代码
import torch
import matplotlib.pyplot as plt
import numpy as np########## 数据集准备 ##########
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])########## 模型定义 ##########
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1)# 由于逻辑回归中Sigmoid函数不需要传参 所以在forward中直接计算即可# 在这里不需要实例化def forward(self, x):y_pred = torch.sigmoid(self.linear(x))return y_predmodel = LogisticRegressionModel()########## 损失函数和优化器的设置 ##########
criterion = torch.nn.BCELoss(size_average=False) # BCELoss -- 交叉熵函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)########## 模型训练 ##########
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data)print(epoch, loss.item())optimizer.zero_grad()loss.backward()optimizer.step()########## 模型测试 ##########
x = np.linspace(0, 10, 200)
x_test = torch.Tensor(x).view((200, 1)) # view()相当于reshape
y_test = model(x_test)
y = y_test.data.numpy() # 转化为np类型
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], 'r--')
plt.xlabel("Hours")
plt.ylabel("Probability of Pass")
plt.grid()
plt.show()

相关文章:
Pytorch深度学习实践(5)逻辑回归
逻辑回归 逻辑回归主要是解决分类问题 回归任务:结果是一个连续的实数分类任务:结果是一个离散的值 分类任务不能直接使用回归去预测,比如在手写识别中(识别手写 0 − − 9 0 -- 9 0−−9),因为各个类别…...
认识漏洞-GitLab 远程命令执行漏洞、致远OA-ajax.do未授权任意文件上传漏洞
为方便您的阅读,可点击下方蓝色字体,进行跳转↓↓↓ 01 [GitLab 远程命令执行漏洞复现(CVE-2021-22205)](https://mp.weixin.qq.com/s/4QT-vxKpBn4ppNM9ipt-nQ)02 [致远OA-ajax.do未授权任意文件上传Getshell](https://mp.weixin.qq.com/s/TH2A5J5TXU36Y…...
vue实现电子签名、图片合成、及预览功能
业务功能:电子签名、图片合成、及预览功能 业务背景:需求说想要实现一个电子签名,然后需要提供一个预览的功能,可以查看签完名之后的完整效果。 需求探讨:后端大佬跟我说,文档我返回给你一个PDF的oss链接…...
【flink】之如何消费kafka数据?
为了编写一个使用Apache Flink来读取Apache Kafka消息的示例,我们需要确保我们的环境已经安装了Flink和Kafka,并且它们都能正常运行。此外,我们还需要在项目中引入相应的依赖库。以下是一个详细的步骤指南,包括依赖添加、代码编写…...
科研绘图系列:R语言山脊图(Ridgeline Chart)
介绍 山脊图(Ridge Chart)是一种用于展示数据分布和比较不同类别或组之间差异的数据可视化技术。它通常用于展示多个维度或变量之间的关系,以及它们在不同组中的分布情况。山脊图的特点: 多变量展示:山脊图可以同时展示多个变量的分布情况,允许用户比较不同变量之间的关…...
Boost搜索引擎:如何建立 用户搜索内容 与 网页文件内容 之间的关系
如果想使“用户搜索内容”和“网页文件内容”之间产生联系,就应该将“用户搜索内容”和“网页文件”分为很小的单元 (这个单元就是关键词),寻找用户搜索单元是否出现在这个文档之中,如果出现就证明这个网页文件和用户搜…...
【QT】QT 窗口(菜单栏、工具栏、状态栏、浮动窗口、对话框)
Qt 窗口是通过 QMainWindow类来实现的。 QMainWindow 是一个为用户提供主窗口程序的类,继承自 QWidget 类,并且提供了⼀个预定义的布局。QMainWindow 包含一个菜单栏(Menu Bar)、多个工具栏(Tool Bars)、…...
Golang | Leetcode Golang题解之第283题移动零
题目: 题解: func moveZeroes(nums []int) {left, right, n : 0, 0, len(nums)for right < n {if nums[right] ! 0 {nums[left], nums[right] nums[right], nums[left]left}right} }...
ubuntu22.04 安装 NVIDIA 驱动以及CUDA
目录 1、事前问题解决 2、安装 nvidia 驱动 3、卸载 nvidia 驱动方法 4、安装 CUDA 5、安装 Anaconda 6、安装 PyTorch 1、事前问题解决 在安装完ubuntu之后,如果进入ubuntu出现黑屏情况,一般就是nvidia驱动与linux自带的不兼容,可以通…...
数据结构·AVL树
1. AVL树的概念 二叉搜索树虽可以缩短查找的效率,但如果存数据时接近有序,二叉搜索将退化为单支树,此时查找元素效率相当于在顺序表中查找,效率低下。因此两位俄罗斯数学家 G.M.Adelson-Velskii 和E.M.Landis 在1962年发明了一种解…...
记一次Mycat分库分表实践
接了个活,又搞分库分表。 一、分库分表 在系统的研发过程中,随着数据量的不断增长,单库单表已无法满足数据的存储需求,此时就需要对数据库进行分库分表操作。 分库分表是随着业务的不断发展,单库单表无法承载整体的数据存储时,采取的一种将整体数据分散存储到不同服务…...
数据分析:微生物数据的荟萃分析框架
介绍 Meta-analysis of fecal metagenomes reveals global microbial signatures that are specific for colorectal cancer提供了一种荟萃分析的框架,它主要基于常用的Wilcoxon rank-sum test和Blocked Wilcoxon rank-sum test 方法计算显著性,再使用分…...
Django—admin后台管理
Django官网 https://www.djangoproject.com/ 如果已经有了Django跳过这步 安装Django: 如果你还没有安装Django,可以通过Python的包管理器pip来安装: pip install django 创建项目: 使用Django创建一个新的项目: …...
数字图像处理中的常用特殊矩阵及MATLAB应用
一、前言 Matlab的名称来源于“矩阵实验室(Matrix Laboratory)”,其对矩阵的操作具有先天性的优势(特别是相对于C语言的数组来说)。在数字图像处理中,为了提高编程效率,我们可以使用多种方式来创…...
vue侦听器(Watch)精彩案例剖析一
目录 watch介绍 监视普通数据类型 监视对象类型 watch介绍 在 Vue 中,watch主要用于监视数据的变化,并执行相应操作。一旦被监视的属性发生变化,回调函数将自动被触发。当在 Vue 中使用watch来响应数据变化时,首先要清楚,watch本质上是一个对象,且必须以对象的…...
HTTP 协议浅析
HTTP(HyperText Transfer Protocol,超文本传输协议)是应用层最重要的协议之一。它定义了客户端和服务器之间的数据传输方式,并成为万维网(World Wide Web)的基石。本文将深入解析 HTTP 协议的基础知识、工作…...
VsCode | 让空文件夹始终展开不折叠
文章目录 1 问题引入2 解决办法3 效果展示 1 问题引入 可能很多小伙伴更新VsCode或者下载新版本时候 ,创建的文件 会出现xxx文件夹/xxx文件夹,看着很不舒服,所以该如何展开所有空文件夹呢? 2 解决办法 找到VsCode的设置 &…...
Centos7_Minimal安装Cannot find a valid baseurl for repo: base/7/x86_6
问题 运行yum报此问题 就是没网 解决方法 修改网络信息配置文件,打开配置文件,输入命令: vi /etc/sysconfig/network-scripts/ifcfg-网卡名字把ONBOOTno,改为ONBOOTyes 重启网卡 /etc/init.d/network restart 网路通了...
Spark_Oracle_II_Spark高效处理Oracle时间数据:通过JDBC桥接大数据与数据库的分析之旅
接前文背景, 当需要从关系型数据库(如Oracle)中读取数据时,Spark提供了JDBC连接功能,允许我们轻松地将数据从Oracle等数据库导入到Spark DataFrame中。然而,在处理时间字段时,可能会遇到一些挑战…...
力扣 459重复的子字符串
思路: KMP算法的核心是求next数组 next数组代表的是当前字符串最大前后缀的长度 而求重复的子字符串就是求字符串的最大前缀与最大后缀之间的子字符串 如果这个子字符串是字符串长度的约数,则true /** lc appleetcode.cn id459 langcpp** [459] 重复…...
网络编程(Modbus进阶)
思维导图 Modbus RTU(先学一点理论) 概念 Modbus RTU 是工业自动化领域 最广泛应用的串行通信协议,由 Modicon 公司(现施耐德电气)于 1979 年推出。它以 高效率、强健性、易实现的特点成为工业控制系统的通信标准。 包…...
【JavaEE】-- HTTP
1. HTTP是什么? HTTP(全称为"超文本传输协议")是一种应用非常广泛的应用层协议,HTTP是基于TCP协议的一种应用层协议。 应用层协议:是计算机网络协议栈中最高层的协议,它定义了运行在不同主机上…...
基于TurtleBot3在Gazebo地图实现机器人远程控制
1. TurtleBot3环境配置 # 下载TurtleBot3核心包 mkdir -p ~/catkin_ws/src cd ~/catkin_ws/src git clone -b noetic-devel https://github.com/ROBOTIS-GIT/turtlebot3.git git clone -b noetic https://github.com/ROBOTIS-GIT/turtlebot3_msgs.git git clone -b noetic-dev…...
LangChain知识库管理后端接口:数据库操作详解—— 构建本地知识库系统的基础《二》
这段 Python 代码是一个完整的 知识库数据库操作模块,用于对本地知识库系统中的知识库进行增删改查(CRUD)操作。它基于 SQLAlchemy ORM 框架 和一个自定义的装饰器 with_session 实现数据库会话管理。 📘 一、整体功能概述 该模块…...
宇树科技,改名了!
提到国内具身智能和机器人领域的代表企业,那宇树科技(Unitree)必须名列其榜。 最近,宇树科技的一项新变动消息在业界引发了不少关注和讨论,即: 宇树向其合作伙伴发布了一封公司名称变更函称,因…...
tomcat入门
1 tomcat 是什么 apache开发的web服务器可以为java web程序提供运行环境tomcat是一款高效,稳定,易于使用的web服务器tomcathttp服务器Servlet服务器 2 tomcat 目录介绍 -bin #存放tomcat的脚本 -conf #存放tomcat的配置文件 ---catalina.policy #to…...
破解路内监管盲区:免布线低位视频桩重塑停车管理新标准
城市路内停车管理常因行道树遮挡、高位设备盲区等问题,导致车牌识别率低、逃费率高,传统模式在复杂路段束手无策。免布线低位视频桩凭借超低视角部署与智能算法,正成为破局关键。该设备安装于车位侧方0.5-0.7米高度,直接规避树枝遮…...
HybridVLA——让单一LLM同时具备扩散和自回归动作预测能力:训练时既扩散也回归,但推理时则扩散
前言 如上一篇文章《dexcap升级版之DexWild》中的前言部分所说,在叠衣服的过程中,我会带着团队对比各种模型、方法、策略,毕竟针对各个场景始终寻找更优的解决方案,是我个人和我司「七月在线」的职责之一 且个人认为,…...
人工智能 - 在Dify、Coze、n8n、FastGPT和RAGFlow之间做出技术选型
在Dify、Coze、n8n、FastGPT和RAGFlow之间做出技术选型。这些平台各有侧重,适用场景差异显著。下面我将从核心功能定位、典型应用场景、真实体验痛点、选型决策关键点进行拆解,并提供具体场景下的推荐方案。 一、核心功能定位速览 平台核心定位技术栈亮…...
基于stm32F10x 系列微控制器的智能电子琴(附完整项目源码、详细接线及讲解视频)
注:文章末尾网盘链接中自取成品使用演示视频、项目源码、项目文档 所用硬件:STM32F103C8T6、无源蜂鸣器、44矩阵键盘、flash存储模块、OLED显示屏、RGB三色灯、面包板、杜邦线、usb转ttl串口 stm32f103c8t6 面包板 …...
