当前位置: 首页 > news >正文

用PyTorch轻松实现二分类:逻辑回归入门

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢迎在文章下方留下你的评论和反馈。我期待着与你分享知识、互相学习和建立一个积极的社区。谢谢你的光临,让我们一起踏上这个知识之旅!
请添加图片描述

文章目录

  • 🥦引言
  • 🥦什么是逻辑回归?
  • 🥦分类问题
  • 🥦交叉熵
  • 🥦代码实现
  • 🥦总结

🥦引言

当谈到机器学习和深度学习时,逻辑回归是一个非常重要的算法,它通常用于二分类问题。在这篇博客中,我们将使用PyTorch来实现逻辑回归。PyTorch是一个流行的深度学习框架,它提供了强大的工具来构建和训练神经网络,适用于各种机器学习任务。

在机器学习中已经使用了sklearn库介绍过逻辑回归,这里重点使用pytorch这个深度学习框架

🥦什么是逻辑回归?

我们首先来回顾一下什么是逻辑回归?

逻辑回归是一种用于二分类问题的监督学习算法。它的主要思想是通过一个S形曲线(通常是Sigmoid函数)将输入特征映射到0和1之间的概率值,然后根据这些概率值进行分类决策。在逻辑回归中,我们使用一个线性模型和一个激活函数来实现这个映射。

🥦分类问题

这里以MINIST Dataset手写数字集为例
在这里插入图片描述

这个数据集中包含了6w个训练集1w个测试集,类别10个
这里我们不再向之前线性回归那样,根据属于判断具体的数值大小;而是根据输入的值判断从0-9每个数字的概率大小记为p(0)、p(1)…而且十个概率值和为1,我们的目标就是根据输入得到这十个分类对于输入的每一个的概率值,哪个大就是我们需要的。

这里介绍一下与torch相关联的库—torchvision
torchvision:

  • “torchvision” 是一个PyTorch的附加库,专门用于处理图像和视觉任务。
    它包含了一系列用于数据加载、数据增强、计算机视觉任务(如图像分类、目标检测等)的工具和数据集。
  • “torchvision” 提供了许多预训练的视觉模型(例如,ResNet、VGG、AlexNet等),可以用于迁移学习或作为基准模型。
    此外,它还包括了用于图像预处理、转换和可视化的函数。

上图已经清楚的显示了,这个库包含了一些自带的数据集,但是并不是我们安装完这个库就有了,而且需要进行调用的,类似在线下载,root指定下载的路径,train表示你需要训练集还是测试集,通常情况下就是两个一个训练,一个测试,download就是判断你下没下载,下载了就是摆设,没下载就给你下载了

我们再来看一个数据集(CIFAR-10)
在这里插入图片描述
包含了5w训练样本,1w测试样本,10类。调用方式与上一个类似。

接下来我们从一张图更加直观的查看分类和回归
在这里插入图片描述

左边的是回归,右边的是分类


在这里插入图片描述

过去我们使用回归例如 y ^ \hat{y} y^=wx+b∈R,这是属于一个实数的;但是在分类问题, y ^ \hat{y} y^∈[0,1]
这说明我们需要寻找一个函数,将原本实数的值经过函数的映射转化为[0,1]之间。这里我们引入Logistic函数,使用极限很清楚的得出x趋向于正无穷的时候函数为1,x趋向于负无穷的时候,函数为0,x=0的时候,函数为0.5,当我们计算的时候将 y ^ \hat{y} y^带入这样就会出现一个0到1的概率了。

下图展示一些其他的Sigmoid函数
在这里插入图片描述

🥦交叉熵

过去我们所使用的损失函数普遍都是MSE,这里引入一个新的损失函数—交叉熵

==交叉熵(Cross-Entropy)==是一种用于衡量两个概率分布之间差异的数学方法,常用于机器学习和深度学习中,特别是在分类问题中。它是一个非常重要的损失函数,用于衡量模型的预测与真实标签之间的差异,从而帮助优化模型参数。

在交叉熵的上下文中,通常有两个概率分布:

  • 真实分布(True Distribution): 这是指问题的实际概率分布,表示样本的真实标签分布。通常用 p ( x ) p(x) p(x)表示,其中 x x x表示样本或类别。

  • 预测分布(Predicted Distribution): 这是指模型的预测概率分布,表示模型对每个类别的预测概率。通常用 q ( x ) q(x) q(x)表示,其中 x x x表示样本或类别。

交叉熵的一般定义如下:
在这里插入图片描述其中, H ( p , q ) H(p, q) H(p,q) 表示真实分布 p p p 和预测分布 q q q 之间的交叉熵。

交叉熵的主要特点和用途包括:

  • 度量差异性: 交叉熵度量了真实分布和预测分布之间的差异。当两个分布相似时,交叉熵较小;当它们之间的差异增大时,交叉熵增大。

  • 损失函数: 在机器学习中,交叉熵通常用作损失函数,用于衡量模型的预测与真实标签之间的差异。在分类任务中,通常使用交叉熵作为模型的损失函数,帮助模型优化参数以提高分类性能。

  • 反向传播: 交叉熵在训练神经网络时非常有用。通过计算交叉熵的梯度,可以使用反向传播算法来调整神经网络的权重,从而使模型的预测更接近真实标签。

在分类问题中,常见的交叉熵损失函数包括二元交叉熵(Binary Cross-Entropy)和多元交叉熵(Categorical Cross-Entropy)。二元交叉熵用于二分类问题,多元交叉熵用于多类别分类问题。

刘二大人的PPT中也介绍了
在这里插入图片描述
右边的表格中每组y与 y ^ \hat{y} y^对应的BCE,BCE越高说明越可能,最后将其求均值

🥦代码实现

在这里插入图片描述

根据上图可知,线性回归和逻辑回归的流程与函数只区别于Sigmoid函数
在这里插入图片描述
这里就是BCEloss的调用,里面的参数代表求不求均值

完整代码如下

import torch.nn.functional as F
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__() self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = F.sigmoid(self.linear(x))return y_pred
model = LogisticRegressionModel() 
criterion = torch.nn.BCELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data)print(epoch, loss.item())optimizer.zero_grad() loss.backward()optimizer.step()

最后绘制一下

import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1))  # 相当于reshape
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], c='r') 
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()

运行结果如下
在这里插入图片描述

🥦总结

这就是使用PyTorch实现逻辑回归的基本步骤。逻辑回归是一个简单但非常有用的算法,可用于各种分类问题。希望这篇博客能帮助你开始使用PyTorch构建自己的逻辑回归模型。如果你想进一步扩展你的知识,可以尝试在更大的数据集上训练模型或探索其他深度学习算法。祝你好运!

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

相关文章:

用PyTorch轻松实现二分类:逻辑回归入门

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢…...

[nltk_data] Error loading stopwords: <urlopen error [WinError 10054]

报错提示&#xff1a; >>> import nltk >>> nltk.download(stopwords) 按照提示执行后 [nltk_data] Error loading stopwords: <urlopen error [WinError 10054] 找到路径C:\\Users\\EDY\\nltk_data&#xff0c;如果没有nltk_data文件夹&#xff0c;在…...

基于Spring Boot的网上租贸系统设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言具体实现截图论文参考详细视频演示为什么选择我自己的网站自己的小程序&#xff08;小蔡coding&#xff09;有保障的售后福利 代码参考源码获取 前言 &#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域优质创作…...

通过IP地址管理提升企业网络安全防御

在今天的数字时代&#xff0c;企业面临着越来越多的网络安全威胁。这些威胁可能来自各种来源&#xff0c;包括恶意软件、网络攻击和数据泄露。为了提高网络安全防御&#xff0c;企业需要采取一系列措施&#xff0c;其中IP地址管理是一个重要的方面 1. IP地址的基础知识 首先&a…...

termius mac版无需登录注册直接永久使用

1. 下载地址&#xff1a;termius下载 2. 解压安装 3. 当出现 “termius”已损坏,无法打开 则输入以下命令即可&#xff1a;sudo xattr -r -d com.apple.quarantine /Applications/Termius.app 最后去 系统设置-> 隐私与安全性-> 仍要打开 4. 删除app-update.yml文件&…...

TPU编程竞赛|Stable Diffusion大模型巅峰对决,第五届全球校园人工智能算法精英赛正式启动!

目录 赛题介绍 赛题背景 赛题任务 赛程安排 评分机制 奖项设置 近日&#xff0c;2023第五届全球校园人工智能算法精英赛正式开启报名。作为赛题合作方&#xff0c;算丰承办了“算法专项赛”赛道&#xff0c;提供赛题「面向Stable Diffusion的图像提示语优化」&#xff0c…...

微信小程序 rpx 转 px

前言 略 rpx 转 px let query wx.createSelectorQuery(); query.selectViewport().boundingClientRect(function(res){let rpx2Px 1 * (res.width/750);console.log("1rpx " rpx2Px "px"); }); query.exec();参考 https://blog.csdn.net/qq_39702…...

机器学习之旅-从Python 开始

导读你想知道如何开始机器学习吗&#xff1f;在这篇文章中&#xff0c;我将简要概括一下使用 Python 来开始机器学习的一些步骤。Python 是一门流行的开源程序设计语言&#xff0c;也是在人工智能及其它相关科学领域中最常用的语言之一。机器学习简称 ML&#xff0c;是人工智能…...

100天精通Python(可视化篇)——第103天:Pyecharts绘制多种炫酷水球图参数说明+代码实战

文章目录 专栏导读一、水球图介绍1. 水球图是什么?2. 水球图的应用场景二、水球图类配置选项1. 导包2. Liquid类3. add函数三、水球图实战1. 基础水球图2. 矩形水球图3. 圆棱角矩形水球图4. 三角形水球图5. 菱形水球图6. 箭头型水球图7. 修改数据精度8. 设置无边框9. 多个并排…...

好用的文件备份软件推荐!

为什么需要文件备份软件&#xff1f; 在我们使用计算机的日常工作生活中&#xff0c;可能会遇到各种不同类型的文件&#xff0c;例如文档、Word文档、Excel表格、PPT演示文稿、图片等&#xff0c;这些数据中可能有些对我们来说很重要&#xff0c;但是可能会因为一些意外状况…...

1130 - Host ‘192.168.10.10‘ is not allowed to connect to this MysOL server

mysql 远程登录报错误信息&#xff1a;1130 - Host 124.114.155.70 is not allowed to connect to this MysOL server //需要在mysql 数据库目录下修改 use mysql; //更改用户的登录主机为所有主机&#xff0c;%代表所有主机 update user set host% where userroot; //刷新权…...

如何实现 Es 全文检索、高亮文本略缩处理

如何实现 Es 全文检索、高亮文本略缩处理 前言技术选型JAVA 常用语法说明全文检索开发高亮开发Es Map 转对象使用核心代码 Trans 接口&#xff08;支持父类属性的复杂映射&#xff09;Trans 接口的不足真实项目落地效果 前言 最近手上在做 Es 全文检索的需求&#xff0c;类似于…...

Netty(四)NIO-优化与源码

Netty优化与源码 1. 优化 1.1 扩展序列化算法 序列化&#xff0c;反序列化主要用于消息正文的转换。 序列化&#xff1a;将java对象转为要传输对象(byte[]或json&#xff0c;最终都是byte[]) 反序列化&#xff1a;将正文还原成java对象。 //java自带的序列化 // 反序列化 b…...

我的创业之路:我为什么选择 Angular 作为前端的开发框架?

我是一名后端开发人员&#xff0c;在上班时我的主要精力集中在搜索和推荐系统的开发和设计工作上&#xff0c;我比较熟悉的语言包括java、golang和python。对于前端技术中typescript、dom、webpack等流行的框架和工具也懂一些。目前&#xff0c;已成为一名自由职业者&#xff0…...

阿里云服务器ECS是什么?云服务器详细介绍

阿里云服务器ECS英文全程Elastic Compute Service&#xff0c;云服务器ECS是一种安全可靠、弹性可伸缩的云计算服务&#xff0c;阿里云提供多种云服务器ECS实例规格&#xff0c;如经济型e实例、通用算力型u1、ECS计算型c7、通用型g7、GPU实例等&#xff0c;阿里云服务器网分享阿…...

深入了解快速排序:原理、性能分析与 Java 实现

快速排序&#xff08;Quick Sort&#xff09;是一种经典的、高效的排序算法&#xff0c;被广泛应用于计算机科学和软件开发领域。本文将深入探讨快速排序的工作原理、步骤以及其在不同情况下的性能表现。 什么是快速排序&#xff1f; 快速排序是一种基于分治策略的排序算法&am…...

[晕事]今天做了件晕事22;寻找99-sysctl.conf; systemd

这个文件&#xff0c;使用ls命令看不出来是一个链接。 然后满世界的找这个文件怎么来的&#xff0c;后来发现是systemd里的一个文件。 从systemd的源文件里也没找到相关的文件信息。 最后把这个rpm安装包下载下来&#xff0c;才找到这个文件原来是一个链接 #ll /etc/sysctl.d/9…...

2578. 最小和分割

给你一个正整数 num &#xff0c;请你将它分割成两个非负整数 num1 和 num2 &#xff0c;满足&#xff1a; num1 和 num2 直接连起来&#xff0c;得到 num 各数位的一个排列。 换句话说&#xff0c;num1 和 num2 中所有数字出现的次数之和等于 num 中所有数字出现的次数。num1…...

Mybatis mapper报错:Class not found: org.jboss.vfs.VFS

报错 Logging initialized using class org.apache.ibatis.logging.stdout.StdOutImpl adapter. Class not found: org.jboss.vfs.VFS JBoss 6 VFS API is not available in this environment. Class not found: org.jboss.vfs.VirtualFile VFS implementation org.apache.iba…...

ARM作业1

三盏灯流水 代码 .text .global _start _start: 1.设置GPIOE寄存器的时钟使能 RCC_MP_AHB4ENSETR[4]->1 0x50000a28 LDR R0,0X50000A28 LDR R1,[R0] 从r0为起始地址的4字节数据取出放在R1 ORR R1,R1,#(0x3<<4) 第4位设置为1 STR R1,[R0] 写回2.设置PE10管…...

在HarmonyOS ArkTS ArkUI-X 5.0及以上版本中,手势开发全攻略:

在 HarmonyOS 应用开发中&#xff0c;手势交互是连接用户与设备的核心纽带。ArkTS 框架提供了丰富的手势处理能力&#xff0c;既支持点击、长按、拖拽等基础单一手势的精细控制&#xff0c;也能通过多种绑定策略解决父子组件的手势竞争问题。本文将结合官方开发文档&#xff0c…...

Vue2 第一节_Vue2上手_插值表达式{{}}_访问数据和修改数据_Vue开发者工具

文章目录 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染2. 插值表达式{{}}3. 访问数据和修改数据4. vue响应式5. Vue开发者工具--方便调试 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染 准备容器引包创建Vue实例 new Vue()指定配置项 ->渲染数据 准备一个容器,例如: …...

精益数据分析(97/126):邮件营销与用户参与度的关键指标优化指南

精益数据分析&#xff08;97/126&#xff09;&#xff1a;邮件营销与用户参与度的关键指标优化指南 在数字化营销时代&#xff0c;邮件列表效度、用户参与度和网站性能等指标往往决定着创业公司的增长成败。今天&#xff0c;我们将深入解析邮件打开率、网站可用性、页面参与时…...

Android第十三次面试总结(四大 组件基础)

Activity生命周期和四大启动模式详解 一、Activity 生命周期 Activity 的生命周期由一系列回调方法组成&#xff0c;用于管理其创建、可见性、焦点和销毁过程。以下是核心方法及其调用时机&#xff1a; ​onCreate()​​ ​调用时机​&#xff1a;Activity 首次创建时调用。​…...

【VLNs篇】07:NavRL—在动态环境中学习安全飞行

项目内容论文标题NavRL: 在动态环境中学习安全飞行 (NavRL: Learning Safe Flight in Dynamic Environments)核心问题解决无人机在包含静态和动态障碍物的复杂环境中进行安全、高效自主导航的挑战&#xff0c;克服传统方法和现有强化学习方法的局限性。核心算法基于近端策略优化…...

20个超级好用的 CSS 动画库

分享 20 个最佳 CSS 动画库。 它们中的大多数将生成纯 CSS 代码&#xff0c;而不需要任何外部库。 1.Animate.css 一个开箱即用型的跨浏览器动画库&#xff0c;可供你在项目中使用。 2.Magic Animations CSS3 一组简单的动画&#xff0c;可以包含在你的网页或应用项目中。 3.An…...

力扣热题100 k个一组反转链表题解

题目: 代码: func reverseKGroup(head *ListNode, k int) *ListNode {cur : headfor i : 0; i < k; i {if cur nil {return head}cur cur.Next}newHead : reverse(head, cur)head.Next reverseKGroup(cur, k)return newHead }func reverse(start, end *ListNode) *ListN…...

RSS 2025|从说明书学习复杂机器人操作任务:NUS邵林团队提出全新机器人装配技能学习框架Manual2Skill

视觉语言模型&#xff08;Vision-Language Models, VLMs&#xff09;&#xff0c;为真实环境中的机器人操作任务提供了极具潜力的解决方案。 尽管 VLMs 取得了显著进展&#xff0c;机器人仍难以胜任复杂的长时程任务&#xff08;如家具装配&#xff09;&#xff0c;主要受限于人…...

【学习笔记】erase 删除顺序迭代器后迭代器失效的解决方案

目录 使用 erase 返回值继续迭代使用索引进行遍历 我们知道类似 vector 的顺序迭代器被删除后&#xff0c;迭代器会失效&#xff0c;因为顺序迭代器在内存中是连续存储的&#xff0c;元素删除后&#xff0c;后续元素会前移。 但一些场景中&#xff0c;我们又需要在执行删除操作…...

论文阅读笔记——Muffin: Testing Deep Learning Libraries via Neural Architecture Fuzzing

Muffin 论文 现有方法 CRADLE 和 LEMON&#xff0c;依赖模型推理阶段输出进行差分测试&#xff0c;但在训练阶段是不可行的&#xff0c;因为训练阶段直到最后才有固定输出&#xff0c;中间过程是不断变化的。API 库覆盖低&#xff0c;因为各个 API 都是在各种具体场景下使用。…...