当前位置: 首页 > news >正文

Pytorch深度学习实践(5)逻辑回归

逻辑回归

逻辑回归主要是解决分类问题

  • 回归任务:结果是一个连续的实数
  • 分类任务:结果是一个离散的值

分类任务不能直接使用回归去预测,比如在手写识别中(识别手写 0 − − 9 0 -- 9 09),因为各个类别之间没有大小之差。

因此,对于分类问题,我们最终的输出是个概率,即属于某个类别的概率是多少,然后从概率集合里找最大值,作为当前预测的结果

下载MNIST数据集

import torchvision
train_set = torchvision.dataset.MNIST(root="../dataset/mnist", train=True, download=True)
test_set = torchvision.dataset.MNIST(root="../dataset/mnist", train=False, dowload=True)
  • 通过train参数来指定训练集和测试集

逻辑回归

将之前的学习时长—考试分数转化为二分类任务,即学习时长—是否通过考试

x(hours)y(pass/fail)
10(fail)
20(fail)
31(pass)
4?

其中, P ( y ^ = 1 ) + P ( y ^ = 0 ) = 1 P(\hat y = 1) + P(\hat y = 0) = 1 P(y^=1)+P(y^=0)=1

当输出的概率在 0.5 0.5 0.5附近时,即模型不确定,因此通常会输出一个不确定的值

对于二分类任务,逻辑回归会先使用回归,生成一个得分值,即落在实数集区间内,然后再使用 s i g m o i d sigmoid sigmoid函数,将得分值映射到 [ 0 , 1 ] [0, 1] [0,1]区间内,得到预测概率

s i g m o i d sigmoid sigmoid函数
σ ( x ) = 1 1 + e − x \sigma (x) = \frac{1}{1+e^{-x}} σ(x)=1+ex1
在这里插入图片描述

S i g m o i d Sigmoid Sigmoid常用来做二分类任务,常具备三个特征:

  • 饱和函数
  • 单调递增
  • 有极限

当我们使用线性回归来得到逻辑回归的得分值时,逻辑回归模型的函数定义就如下所示:
y ^ = σ ( x ∗ ω + b ) \hat y = \sigma (x*\omega + b) y^=σ(xω+b)

损失函数

线性回归使用的损失函数是计算预测值和真实值之差

而对于逻辑回归,由于我们得到的是概率,是一个 0 − 1 0-1 01分布,因此需要修改损失函数
l o s s = − ( y l o g y ^ + ( 1 − y ) l o g ( 1 − y ^ ) ) loss = -(ylog\hat y + (1-y)log(1-\hat y)) loss=(ylogy^+(1y)log(1y^))
即我们比较的是分布之间的差异

交叉熵 c r o s s − e n t r o p y cross-entropy crossentropy

存在两个分布 P D 1 ( x ) P_{D1}(x) PD1(x) P D 2 ( x ) P_{D2}(x) PD2(x)

两个分布的差异程度使用公式: ∑ i = 1 n P D 1 ( x i ) l n P D 2 ( x i ) \sum_{i=1}^{n}P_{D1}(x_i)lnP_{D2}(x_i) i=1nPD1(xi)lnPD2(xi) 来衡量

上述公式越大时,两个分布的差异越小

模型的改变

模型构造的改变
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1)# 由于逻辑回归中Sigmoid函数不需要传参 所以在forward中直接计算即可# 在这里不需要实例化def forward(self, x):y_pred = F.sigmoid(self.linear(x))return y_pred

需要先将输入写入到linear()线性模型中,再使用Sigmoid()函数

模型损失函数的改变

使用交叉熵函数BCELoss

criterion = torch.nn.BCELoss(size_average=False)

整体代码

import torch
import matplotlib.pyplot as plt
import numpy as np########## 数据集准备 ##########
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])########## 模型定义 ##########
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1)# 由于逻辑回归中Sigmoid函数不需要传参 所以在forward中直接计算即可# 在这里不需要实例化def forward(self, x):y_pred = torch.sigmoid(self.linear(x))return y_predmodel = LogisticRegressionModel()########## 损失函数和优化器的设置 ##########
criterion = torch.nn.BCELoss(size_average=False) # BCELoss -- 交叉熵函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)########## 模型训练 ##########
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data)print(epoch, loss.item())optimizer.zero_grad()loss.backward()optimizer.step()########## 模型测试 ##########
x = np.linspace(0, 10, 200)
x_test = torch.Tensor(x).view((200, 1)) # view()相当于reshape
y_test = model(x_test)
y = y_test.data.numpy()  # 转化为np类型
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], 'r--')
plt.xlabel("Hours")
plt.ylabel("Probability of Pass")
plt.grid()
plt.show()

在这里插入图片描述

相关文章:

Pytorch深度学习实践(5)逻辑回归

逻辑回归 逻辑回归主要是解决分类问题 回归任务:结果是一个连续的实数分类任务:结果是一个离散的值 分类任务不能直接使用回归去预测,比如在手写识别中(识别手写 0 − − 9 0 -- 9 0−−9),因为各个类别…...

认识漏洞-GitLab 远程命令执行漏洞、致远OA-ajax.do未授权任意文件上传漏洞

为方便您的阅读,可点击下方蓝色字体,进行跳转↓↓↓ 01 [GitLab 远程命令执行漏洞复现(CVE-2021-22205)](https://mp.weixin.qq.com/s/4QT-vxKpBn4ppNM9ipt-nQ)02 [致远OA-ajax.do未授权任意文件上传Getshell](https://mp.weixin.qq.com/s/TH2A5J5TXU36Y…...

vue实现电子签名、图片合成、及预览功能

业务功能:电子签名、图片合成、及预览功能 业务背景:需求说想要实现一个电子签名,然后需要提供一个预览的功能,可以查看签完名之后的完整效果。 需求探讨:后端大佬跟我说,文档我返回给你一个PDF的oss链接…...

【flink】之如何消费kafka数据?

为了编写一个使用Apache Flink来读取Apache Kafka消息的示例,我们需要确保我们的环境已经安装了Flink和Kafka,并且它们都能正常运行。此外,我们还需要在项目中引入相应的依赖库。以下是一个详细的步骤指南,包括依赖添加、代码编写…...

科研绘图系列:R语言山脊图(Ridgeline Chart)

介绍 山脊图(Ridge Chart)是一种用于展示数据分布和比较不同类别或组之间差异的数据可视化技术。它通常用于展示多个维度或变量之间的关系,以及它们在不同组中的分布情况。山脊图的特点: 多变量展示:山脊图可以同时展示多个变量的分布情况,允许用户比较不同变量之间的关…...

Boost搜索引擎:如何建立 用户搜索内容 与 网页文件内容 之间的关系

如果想使“用户搜索内容”和“网页文件内容”之间产生联系,就应该将“用户搜索内容”和“网页文件”分为很小的单元 (这个单元就是关键词),寻找用户搜索单元是否出现在这个文档之中,如果出现就证明这个网页文件和用户搜…...

【QT】QT 窗口(菜单栏、工具栏、状态栏、浮动窗口、对话框)

Qt 窗口是通过 QMainWindow类来实现的。 QMainWindow 是一个为用户提供主窗口程序的类,继承自 QWidget 类,并且提供了⼀个预定义的布局。QMainWindow 包含一个菜单栏(Menu Bar)、多个工具栏(Tool Bars)、…...

Golang | Leetcode Golang题解之第283题移动零

题目&#xff1a; 题解&#xff1a; func moveZeroes(nums []int) {left, right, n : 0, 0, len(nums)for right < n {if nums[right] ! 0 {nums[left], nums[right] nums[right], nums[left]left}right} }...

ubuntu22.04 安装 NVIDIA 驱动以及CUDA

目录 1、事前问题解决 2、安装 nvidia 驱动 3、卸载 nvidia 驱动方法 4、安装 CUDA 5、安装 Anaconda 6、安装 PyTorch 1、事前问题解决 在安装完ubuntu之后&#xff0c;如果进入ubuntu出现黑屏情况&#xff0c;一般就是nvidia驱动与linux自带的不兼容&#xff0c;可以通…...

数据结构·AVL树

1. AVL树的概念 二叉搜索树虽可以缩短查找的效率&#xff0c;但如果存数据时接近有序&#xff0c;二叉搜索将退化为单支树&#xff0c;此时查找元素效率相当于在顺序表中查找&#xff0c;效率低下。因此两位俄罗斯数学家 G.M.Adelson-Velskii 和E.M.Landis 在1962年发明了一种解…...

记一次Mycat分库分表实践

接了个活,又搞分库分表。 一、分库分表 在系统的研发过程中,随着数据量的不断增长,单库单表已无法满足数据的存储需求,此时就需要对数据库进行分库分表操作。 分库分表是随着业务的不断发展,单库单表无法承载整体的数据存储时,采取的一种将整体数据分散存储到不同服务…...

数据分析:微生物数据的荟萃分析框架

介绍 Meta-analysis of fecal metagenomes reveals global microbial signatures that are specific for colorectal cancer提供了一种荟萃分析的框架&#xff0c;它主要基于常用的Wilcoxon rank-sum test和Blocked Wilcoxon rank-sum test 方法计算显著性&#xff0c;再使用分…...

Django—admin后台管理

Django官网 https://www.djangoproject.com/ 如果已经有了Django跳过这步 安装Django&#xff1a; 如果你还没有安装Django&#xff0c;可以通过Python的包管理器pip来安装&#xff1a; pip install django 创建项目&#xff1a; 使用Django创建一个新的项目&#xff1a; …...

数字图像处理中的常用特殊矩阵及MATLAB应用

一、前言 Matlab的名称来源于“矩阵实验室&#xff08;Matrix Laboratory&#xff09;”&#xff0c;其对矩阵的操作具有先天性的优势&#xff08;特别是相对于C语言的数组来说&#xff09;。在数字图像处理中&#xff0c;为了提高编程效率&#xff0c;我们可以使用多种方式来创…...

vue侦听器(Watch)精彩案例剖析一

目录 watch介绍 监视普通数据类型 监视对象类型 watch介绍 在 Vue 中,watch主要用于监视数据的变化,并执行相应操作。一旦被监视的属性发生变化,回调函数将自动被触发。当在 Vue 中使用watch来响应数据变化时,首先要清楚,watch本质上是一个对象,且必须以对象的…...

HTTP 协议浅析

HTTP&#xff08;HyperText Transfer Protocol&#xff0c;超文本传输协议&#xff09;是应用层最重要的协议之一。它定义了客户端和服务器之间的数据传输方式&#xff0c;并成为万维网&#xff08;World Wide Web&#xff09;的基石。本文将深入解析 HTTP 协议的基础知识、工作…...

VsCode | 让空文件夹始终展开不折叠

文章目录 1 问题引入2 解决办法3 效果展示 1 问题引入 可能很多小伙伴更新VsCode或者下载新版本时候 &#xff0c;创建的文件 会出现xxx文件夹/xxx文件夹&#xff0c;看着很不舒服&#xff0c;所以该如何展开所有空文件夹呢&#xff1f; 2 解决办法 找到VsCode的设置 &…...

Centos7_Minimal安装Cannot find a valid baseurl for repo: base/7/x86_6

问题 运行yum报此问题 就是没网 解决方法 修改网络信息配置文件&#xff0c;打开配置文件&#xff0c;输入命令&#xff1a; vi /etc/sysconfig/network-scripts/ifcfg-网卡名字把ONBOOTno&#xff0c;改为ONBOOTyes 重启网卡 /etc/init.d/network restart 网路通了...

Spark_Oracle_II_Spark高效处理Oracle时间数据:通过JDBC桥接大数据与数据库的分析之旅

接前文背景&#xff0c; 当需要从关系型数据库&#xff08;如Oracle&#xff09;中读取数据时&#xff0c;Spark提供了JDBC连接功能&#xff0c;允许我们轻松地将数据从Oracle等数据库导入到Spark DataFrame中。然而&#xff0c;在处理时间字段时&#xff0c;可能会遇到一些挑战…...

力扣 459重复的子字符串

思路&#xff1a; KMP算法的核心是求next数组 next数组代表的是当前字符串最大前后缀的长度 而求重复的子字符串就是求字符串的最大前缀与最大后缀之间的子字符串 如果这个子字符串是字符串长度的约数&#xff0c;则true /** lc appleetcode.cn id459 langcpp** [459] 重复…...

利用最小二乘法找圆心和半径

#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...

数据链路层的主要功能是什么

数据链路层&#xff08;OSI模型第2层&#xff09;的核心功能是在相邻网络节点&#xff08;如交换机、主机&#xff09;间提供可靠的数据帧传输服务&#xff0c;主要职责包括&#xff1a; &#x1f511; 核心功能详解&#xff1a; 帧封装与解封装 封装&#xff1a; 将网络层下发…...

ETLCloud可能遇到的问题有哪些?常见坑位解析

数据集成平台ETLCloud&#xff0c;主要用于支持数据的抽取&#xff08;Extract&#xff09;、转换&#xff08;Transform&#xff09;和加载&#xff08;Load&#xff09;过程。提供了一个简洁直观的界面&#xff0c;以便用户可以在不同的数据源之间轻松地进行数据迁移和转换。…...

ArcGIS Pro制作水平横向图例+多级标注

今天介绍下载ArcGIS Pro中如何设置水平横向图例。 之前我们介绍了ArcGIS的横向图例制作&#xff1a;ArcGIS横向、多列图例、顺序重排、符号居中、批量更改图例符号等等&#xff08;ArcGIS出图图例8大技巧&#xff09;&#xff0c;那这次我们看看ArcGIS Pro如何更加快捷的操作。…...

精益数据分析(97/126):邮件营销与用户参与度的关键指标优化指南

精益数据分析&#xff08;97/126&#xff09;&#xff1a;邮件营销与用户参与度的关键指标优化指南 在数字化营销时代&#xff0c;邮件列表效度、用户参与度和网站性能等指标往往决定着创业公司的增长成败。今天&#xff0c;我们将深入解析邮件打开率、网站可用性、页面参与时…...

Maven 概述、安装、配置、仓库、私服详解

目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...

Unity中的transform.up

2025年6月8日&#xff0c;周日下午 在Unity中&#xff0c;transform.up是Transform组件的一个属性&#xff0c;表示游戏对象在世界空间中的“上”方向&#xff08;Y轴正方向&#xff09;&#xff0c;且会随对象旋转动态变化。以下是关键点解析&#xff1a; 基本定义 transfor…...

【Kafka】Kafka从入门到实战:构建高吞吐量分布式消息系统

Kafka从入门到实战:构建高吞吐量分布式消息系统 一、Kafka概述 Apache Kafka是一个分布式流处理平台,最初由LinkedIn开发,后成为Apache顶级项目。它被设计用于高吞吐量、低延迟的消息处理,能够处理来自多个生产者的海量数据,并将这些数据实时传递给消费者。 Kafka核心特…...

第八部分:阶段项目 6:构建 React 前端应用

现在&#xff0c;是时候将你学到的 React 基础知识付诸实践&#xff0c;构建一个简单的前端应用来模拟与后端 API 的交互了。在这个阶段&#xff0c;你可以先使用模拟数据&#xff0c;或者如果你的后端 API&#xff08;阶段项目 5&#xff09;已经搭建好&#xff0c;可以直接连…...

Java数组Arrays操作全攻略

Arrays类的概述 Java中的Arrays类位于java.util包中&#xff0c;提供了一系列静态方法用于操作数组&#xff08;如排序、搜索、填充、比较等&#xff09;。这些方法适用于基本类型数组和对象数组。 常用成员方法及代码示例 排序&#xff08;sort&#xff09; 对数组进行升序…...