机器学习7:pytorch的逻辑回归
一、说明
逻辑回归模型是处理分类问题的最常见机器学习模型之一。二项式逻辑回归只是逻辑回归模型的一种类型。它指的是两个变量的分类,其中概率用于确定二元结果,因此“二项式”中的“bi”。结果为真或假 — 0 或 1。
二项式逻辑回归的一个例子是预测人群中 COVID-19 的可能性。一个人要么感染了COVID-19,要么没有,必须建立一个阈值以尽可能准确地区分这些结果。
二、sigmoid函数
这些预测不适合一条线,就像线性回归模型一样。相反,逻辑回归模型拟合到右侧所示的 sigmoid 函数。
对于每个 x,生成的 y 值表示结果为 True 的概率。在 COVID-19 示例中,这表示医生对某人感染病毒的信心。在右图中,阴性结果为蓝色,阳性结果为红色。
图片来源:作者
三、过程
要进行二项式逻辑回归,我们需要做各种事情:
- 创建训练数据集。
- 使用 PyTorch 创建我们的模型。
- 将我们的数据拟合到模型中。
逻辑回归问题的第一步是创建训练数据集。首先,我们应该设置一个种子来确保我们的随机数据的可重复性。
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import Lineartorch.manual_seed(42) # set a random seed
我们必须使用 PyTorch 的线性模型,因为我们正在处理一个输入 x 和一个输出 y。因此,我们的模型是线性的。为此,我们将使用 PyTorch 的函数:Linear
model = Linear(in_features=1, out_features=1) # use a linear model
接下来,我们必须生成蓝色 X 和红色 X 数据,确保将它们从行向量重塑为列向量。蓝色的在 0 到 7 之间,红色的在 7 到 10 之间。对于 y 值,蓝点表示 COVID-19 测试阴性,因此它们都将是
- 对于红点,它们代表 COVID-19 测试呈阳性,因此它们将为 1。下面是代码及其输出:
blue_x = (torch.rand(20) * 7).reshape(-1,1) # random floats between 0 and 7
blue_y = torch.zeros(20).reshape(-1,1)red_x = (torch.rand(20) * 7+3).reshape(-1,1) # random floats between 3 and 10
red_y = torch.ones(20).reshape(-1,1)X = torch.vstack([blue_x, red_x]) # matrix of x values
Y = torch.vstack([blue_y, red_y]) # matrix of y values
现在,我们的代码应如下所示:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import Lineartorch.manual_seed(42) # set a random seedmodel = Linear(in_features=1, out_features=1) # use a linear modelblue_x = (torch.rand(20) * 7).reshape(-1,1) # random floats between 0 and 7
blue_y = torch.zeros(20).reshape(-1,1)red_x = (torch.rand(20) * 7+3).reshape(-1,1) # random floats between 3 and 10
red_y = torch.ones(20).reshape(-1,1)X = torch.vstack([blue_x, red_x]) # matrix of x values
Y = torch.vstack([blue_y, red_y]) # matrix of y values
四、优化
我们将使用梯度下降过程来优化 S 形函数的损失。损失是根据函数拟合数据的优度计算的,数据由 S 形曲线的斜率和截距控制。我们需要梯度下降来找到最佳斜率和截距。
我们还将使用二进制交叉熵(BCE)作为我们的损失函数,或对数损失函数。对于一般的逻辑回归,不包含对数的损失函数将不起作用。
为了实现BCE作为我们的损失函数,我们将它设置为我们的标准,并将随机梯度下降作为我们优化它的手段。由于这是我们将要优化的函数,我们需要传入模型参数和学习率。
epochs = 2000 # run 2000 iterations
criterion = nn.BCELoss() # implement binary cross entropy loss functionoptimizer = torch.optim.SGD(model.parameters(), lr = .1) # stochastic gradient descent
现在,我们准备开始梯度下降以优化我们的损失。我们必须将梯度归零,通过将我们的数据插入 sigmoid 函数来找到 y-hat 值,计算损失,并找到损失函数的梯度。然后,我们必须迈出一步,确保存储我们的新斜率并为下一次迭代进行拦截。
optimizer.zero_grad()
Yhat = torch.sigmoid(model(X))
loss = criterion(Yhat,Y)
loss.backward()
optimizer.step()
五、收尾
为了找到最佳斜率和截距,我们本质上是在训练我们的模型。我们必须对多次迭代或纪元应用梯度下降。在此示例中,我们将使用 2,000 个纪元进行演示。
epochs = 2000 # run 2000 iterations
criterion = nn.BCELoss() # implement binary cross entropy loss functionoptimizer = torch.optim.SGD(model.parameters(), lr = .1) # stochastic gradient descentfor i in range(epochs):optimizer.zero_grad()Yhat = torch.sigmoid(model(X))loss = criterion(Yhat,Y)loss.backward()optimizer.step()print(f"epoch: {i+1}")print(f"loss: {loss: .5f}")print(f"slope: {model.weight.item(): .5f}")print(f"intercept: {model.bias.item(): .5f}")print()
将所有代码片段放在一起,我们应该得到以下代码:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import Lineartorch.manual_seed(42) # set a random seedmodel = Linear(in_features=1, out_features=1) # use a linear modelblue_x = (torch.rand(20) * 7).reshape(-1,1) # random floats between 0 and 7
blue_y = torch.zeros(20).reshape(-1,1)red_x = (torch.rand(20) * 7+3).reshape(-1,1) # random floats between 3 and 10
red_y = torch.ones(20).reshape(-1,1)X = torch.vstack([blue_x, red_x]) # matrix of x values
Y = torch.vstack([blue_y, red_y]) # matrix of y valuesepochs = 2000 # run 2000 iterations
criterion = nn.BCELoss() # implement binary cross entropy loss functionoptimizer = torch.optim.SGD(model.parameters(), lr = .1) # stochastic gradient descentfor i in range(epochs):optimizer.zero_grad()Yhat = torch.sigmoid(model(X))loss = criterion(Yhat,Y)loss.backward()optimizer.step()print(f"epoch: {i+1}")print(f"loss: {loss: .5f}")print(f"slope: {model.weight.item(): .5f}")print(f"intercept: {model.bias.item(): .5f}")print()
两千个时期后的最终输出:epoch: 2000
loss: 0.53861
slope: 0.61276
intercept: -3.17314
两千个时期后的最终输出:
epoch: 2000
loss: 0.53861
slope: 0.61276
intercept: -3.17314
六、可视化
最后,我们可以将数据与 sigmoid 函数一起绘制,以获得以下可视化效果:
x = np.arange(0,10,.1)
y = model.weight.item()*x + model.bias.item()plt.plot(x, 1/(1 + np.exp(-y)), color="green")plt.xlim(0,10)
plt.scatter(blue_x, blue_y, color="blue")
plt.scatter(red_x, red_y, color="red")plt.show()
图片来源:作者
七、局限性
二元分类的最大问题之一是需要阈值。在逻辑回归的情况下,此阈值应为 x 值,其中 y 为 50%。我们试图回答的问题是将阈值放在哪里?
在 COVID-19 测试的情况下,原始示例说明了这种困境。如果我们将阈值设置为 x=5,我们可以清楚地看到应该是红色的蓝点和应该是蓝色的红点。
悬垂的红点称为误报,即模型错误地预测正类的区域。悬垂的蓝点称为假阴性 - 模型错误地预测负类的区域。
八、结论
成功的二项式逻辑回归模型将减少假阴性的数量,因为这些假阴性通常会导致最大的危险。患有COVID-19但检测呈阴性对他人的健康和安全构成严重风险。
通过对可用数据使用二项式逻辑回归,我们可以确定放置阈值的最佳位置,从而有助于减少不确定性并做出更明智的决策。
相关文章:

机器学习7:pytorch的逻辑回归
一、说明 逻辑回归模型是处理分类问题的最常见机器学习模型之一。二项式逻辑回归只是逻辑回归模型的一种类型。它指的是两个变量的分类,其中概率用于确定二元结果,因此“二项式”中的“bi”。结果为真或假 — 0 或 1。 二项式逻辑回归的一个例子是预测人…...
Java应用程序中如何实现FTP功能 | 代码示例和教程
原为地址:https://www.toymoban.com/diary/java/363.html 在Java应用程序中实现FTP功能需要使用FTPClient类和相关方法。下面是实现三个主要功能的示例代码: 1)显示FTP服务器上的文件: void ftpList_actionPerformed(ActionEv…...
kotlin:list的for循环
代码: var list { "a", "b", "c" } for (i in list.indices) {print("app"i""list[i]) }...

asp.net电影院选座系统VS开发sqlserver数据库web结构c#编程Microsoft Visual Studio
一、源码特点 asp.net电影院选座系统 是一套完善的web设计管理系统,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环境为vs2010,数据库为sqlserver2008,使用c#语言开发 asp.net电影院选座系统1 二、功能介…...

CSS鼠标指针表
(机翻)搬运自:cursor - CSS: Cascading Style Sheets | MDN (mozilla.org) 类型Keyword演示注释全局autoUA将基于当前上下文来确定要显示的光标。例如,相当于悬停文本时的文本。default 依赖于平台的默认光标。通常是箭头。none不会渲染光标。链接&状态contex…...

树的基本概念及二叉树
目录 一、树的基本概念 (1)树的结点 (2)度 (3)结点层次 (4)树的高度 树的特点: 二、二叉树 (1)满二叉树 (2)完…...

BUUCTF Basic 解题记录--BUU XXE COURSE
1、XXE漏洞 初步学习,可参考链接: 一篇文章带你深入理解漏洞之 XXE 漏洞 - 先知社区 2、了解了XXE漏洞,用burpsuite获取到的url转发给repeater,修改XML的信息,引入外部实体漏洞,修改发送内容,…...
kotlin:LogKit
看到别人的一个代码,觉得有点意思,就复制过来。 package robatimport android.util.Log import java.util.*object LogKit {private val MIN_STACK_OFFSET 3var defaultTag "LogKit"private val lineSeparator System.getProperty("l…...

yolo_tracking中osnet不支持.pth格式,而model_zoo中仅有.pth
yolo_traking-7.0中REID模块用到了osnet,track.py中模型文件不支持.pth,而model_zoo中仅有.pth,改动代码太麻烦了,网上查到的.pth文件转化为.pt文件都需要读取网络架构,不太可能实现。 读取osnet_x0_25_msmt17.pth发现…...

Tailwind CSS浅析与实操
Tailwind CSS 一、Tailwind CSS简介 What is Tailwind CSS Tailwind CSS| TailwindCSS中文文档 | TailwindCSS中文网官方解释:只需书写 HTML 代码,无需书写 CSS,即可快速构建美观的网站。本质上是一个工具集,包含了大量类似 fle…...

Activiti工作流引擎详解与应用
一、简介 Activiti是一个开源的工作流引擎,基于BPMN2.0标准进行流程定义。它可以将业务系统中复杂的业务流程抽取出来,使用专门的建模语言BPMN2.0进行定义,业务流程按照预先定义的流程进行执行,实现了系统的流程由Activiti进行管…...

New Journal of Physics:不同机器学习力场特征的准确性测试
文章信息 作者:Ting Han1, Jie Li1, Liping Liu2, Fengyu Li1, * and Lin-Wang Wang2, * 通信单位:内蒙古大学物理科学与技术学院、中国科学院半导体研究所 DOI:10.1088/1367-2630/acf2bb 研究背景 近年来,基于DFT数据的机器学…...

ubuntu22.04 x11窗口环境手势控制
ubuntu22.04 x11窗口环境手势控制 ubuntu x11窗口环境的手势控制并不优秀,我们可以使用touchegg去代替 这个配置过程非常简单,并且可以很容易在一定范围内达到你想到的效果,类比mac的手势控制 关于安装 首先添加源,并安装 sud…...
【ARM CoreLink 系列 4 -- NIC-400 控制器详细介绍】
文章目录 1.1 ARM NIC-400(Network interconnect)1.1.1 NIC-400 系统框图1.1.2 NIC-400 Network Interconnect1.2 NIC-400 特点1.2.1 QoS-400 Advanced Quality of Service1.2.2 QVN-400 QoS Virtual Networks1.2.3 TLX-400 Thin Links1.3 NIC-400 Top1.4 NIC-400 Terminology1…...

【生成模型】解决生成模型面对长尾类型物体时的问题 RE-IMAGEN: RETRIEVAL-AUGMENTED TEXT-TO-IMAGE GENERATOR
介绍 尽管最先进的模型可以生成常见实体的高质量图像,但它们通常难以生成不常见实体的图像,例如“Chortai(狗)”或“Picarones(食物)”。为了解决这个问题,我们提出了检索增强文本到图像生成器…...

南美巴西市场最全分析开发攻略,收藏一篇就够了
巴西位于南美洲东部,是南美洲资源最丰富,经济活力和经济实力最强的国家。巴西作为拉丁美洲的出口大国,一直是一个比较有潜力的市场,亦是我国外贸公司和独立外贸人集群的地方。中国长期是巴西主要的合作伙伴,2022年占巴…...
c++中操作符->与 . 的使用与区别
在C中,-> 和 . 是两个不同的成员访问操作符,用于访问类、结构体或联合体的成员。 “->” 操作符: 用于通过指针访问指针所指向对象的成员。当有一个指向对象的指针时,可以使用 -> 操作符来访问该指针所指向对象的成员。…...

golang 编译器 汉化
1、找到左上角file选项,点击选中settings进行单机 2、找到settings中找到plugins选中进行点击 3、再框中输入chinese进行搜索,出结果后找到如下图所示,点击进行安装 4、安装完成后进行重启ide,完美解决...
压缩包系列
1、zip伪加密 一个zip文件由三部分组成:压缩源文件数据区压缩源文件目录区压缩源文件目录结束标志。 伪加密原理:zip伪加密是在文件头中加密标志位做修改,然后在打开时误被识别成加密压缩包。 压缩源文件数据区: 50 4B 03 04&a…...

互联网图片安全风控实战训练营开营!
内容安全风控,即针对互联网产生的海量内容的外部、内部风险做宏观到微观的引导和审核,从内容安全领域帮助企业化解监管风险和社会舆论风险,其核心是识别文本、图片、视频、音频中的有害内容。 由于互联网内容类型繁杂、多如牛毛,加…...

《Qt C++ 与 OpenCV:解锁视频播放程序设计的奥秘》
引言:探索视频播放程序设计之旅 在当今数字化时代,多媒体应用已渗透到我们生活的方方面面,从日常的视频娱乐到专业的视频监控、视频会议系统,视频播放程序作为多媒体应用的核心组成部分,扮演着至关重要的角色。无论是在个人电脑、移动设备还是智能电视等平台上,用户都期望…...
Java - Mysql数据类型对应
Mysql数据类型java数据类型备注整型INT/INTEGERint / java.lang.Integer–BIGINTlong/java.lang.Long–––浮点型FLOATfloat/java.lang.FloatDOUBLEdouble/java.lang.Double–DECIMAL/NUMERICjava.math.BigDecimal字符串型CHARjava.lang.String固定长度字符串VARCHARjava.lang…...
OkHttp 中实现断点续传 demo
在 OkHttp 中实现断点续传主要通过以下步骤完成,核心是利用 HTTP 协议的 Range 请求头指定下载范围: 实现原理 Range 请求头:向服务器请求文件的特定字节范围(如 Range: bytes1024-) 本地文件记录:保存已…...
镜像里切换为普通用户
如果你登录远程虚拟机默认就是 root 用户,但你不希望用 root 权限运行 ns-3(这是对的,ns3 工具会拒绝 root),你可以按以下方法创建一个 非 root 用户账号 并切换到它运行 ns-3。 一次性解决方案:创建非 roo…...
Spring Boot面试题精选汇总
🤟致敬读者 🟩感谢阅读🟦笑口常开🟪生日快乐⬛早点睡觉 📘博主相关 🟧博主信息🟨博客首页🟫专栏推荐🟥活动信息 文章目录 Spring Boot面试题精选汇总⚙️ **一、核心概…...
WEB3全栈开发——面试专业技能点P2智能合约开发(Solidity)
一、Solidity合约开发 下面是 Solidity 合约开发 的概念、代码示例及讲解,适合用作学习或写简历项目背景说明。 🧠 一、概念简介:Solidity 合约开发 Solidity 是一种专门为 以太坊(Ethereum)平台编写智能合约的高级编…...

学校时钟系统,标准考场时钟系统,AI亮相2025高考,赛思时钟系统为教育公平筑起“精准防线”
2025年#高考 将在近日拉开帷幕,#AI 监考一度冲上热搜。当AI深度融入高考,#时间同步 不再是辅助功能,而是决定AI监考系统成败的“生命线”。 AI亮相2025高考,40种异常行为0.5秒精准识别 2025年高考即将拉开帷幕,江西、…...

短视频矩阵系统文案创作功能开发实践,定制化开发
在短视频行业迅猛发展的当下,企业和个人创作者为了扩大影响力、提升传播效果,纷纷采用短视频矩阵运营策略,同时管理多个平台、多个账号的内容发布。然而,频繁的文案创作需求让运营者疲于应对,如何高效产出高质量文案成…...

R语言速释制剂QBD解决方案之三
本文是《Quality by Design for ANDAs: An Example for Immediate-Release Dosage Forms》第一个处方的R语言解决方案。 第一个处方研究评估原料药粒径分布、MCC/Lactose比例、崩解剂用量对制剂CQAs的影响。 第二处方研究用于理解颗粒外加硬脂酸镁和滑石粉对片剂质量和可生产…...

【网络安全】开源系统getshell漏洞挖掘
审计过程: 在入口文件admin/index.php中: 用户可以通过m,c,a等参数控制加载的文件和方法,在app/system/entrance.php中存在重点代码: 当M_TYPE system并且M_MODULE include时,会设置常量PATH_OWN_FILE为PATH_APP.M_T…...