当前位置: 首页 > news >正文

机器学习7:pytorch的逻辑回归

一、说明

        逻辑回归模型是处理分类问题的最常见机器学习模型之一。二项式逻辑回归只是逻辑回归模型的一种类型。它指的是两个变量的分类,其中概率用于确定二元结果,因此“二项式”中的“bi”。结果为真或假 — 0 或 1。

        二项式逻辑回归的一个例子是预测人群中 COVID-19 的可能性。一个人要么感染了COVID-19,要么没有,必须建立一个阈值以尽可能准确地区分这些结果。

二、sigmoid函数

        这些预测不适合一条线,就像线性回归模型一样。相反,逻辑回归模型拟合到右侧所示的 sigmoid 函数。

        对于每个 x,生成的 y 值表示结果为 True 的概率。在 COVID-19 示例中,这表示医生对某人感染病毒的信心。在右图中,阴性结果为蓝色,阳性结果为红色。

图片来源:作者

三、过程

        要进行二项式逻辑回归,我们需要做各种事情:

  1. 创建训练数据集。
  2. 使用 PyTorch 创建我们的模型。
  3. 将我们的数据拟合到模型中。

        逻辑回归问题的第一步是创建训练数据集。首先,我们应该设置一个种子来确保我们的随机数据的可重复性。

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import Lineartorch.manual_seed(42)   # set a random seed

我们必须使用 PyTorch 的线性模型,因为我们正在处理一个输入 x 和一个输出 y。因此,我们的模型是线性的。为此,我们将使用 PyTorch 的函数:Linear

model = Linear(in_features=1, out_features=1) # use a linear model

接下来,我们必须生成蓝色 X 和红色 X 数据,确保将它们从行向量重塑为列向量。蓝色的在 0 到 7 之间,红色的在 7 到 10 之间。对于 y 值,蓝点表示 COVID-19 测试阴性,因此它们都将是

  1. 对于红点,它们代表 COVID-19 测试呈阳性,因此它们将为 1。下面是代码及其输出:
blue_x = (torch.rand(20) * 7).reshape(-1,1)   # random floats between 0 and 7
blue_y = torch.zeros(20).reshape(-1,1)red_x = (torch.rand(20) * 7+3).reshape(-1,1)  # random floats between 3 and 10
red_y = torch.ones(20).reshape(-1,1)X = torch.vstack([blue_x, red_x])   # matrix of x values
Y = torch.vstack([blue_y, red_y])   # matrix of y values

现在,我们的代码应如下所示:

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import Lineartorch.manual_seed(42)   # set a random seedmodel = Linear(in_features=1, out_features=1) # use a linear modelblue_x = (torch.rand(20) * 7).reshape(-1,1)   # random floats between 0 and 7
blue_y = torch.zeros(20).reshape(-1,1)red_x = (torch.rand(20) * 7+3).reshape(-1,1)  # random floats between 3 and 10
red_y = torch.ones(20).reshape(-1,1)X = torch.vstack([blue_x, red_x])   # matrix of x values
Y = torch.vstack([blue_y, red_y])   # matrix of y values

四、优化

        我们将使用梯度下降过程来优化 S 形函数的损失。损失是根据函数拟合数据的优度计算的,数据由 S 形曲线的斜率和截距控制。我们需要梯度下降来找到最佳斜率和截距。

        我们还将使用二进制交叉熵(BCE)作为我们的损失函数,或对数损失函数。对于一般的逻辑回归,不包含对数的损失函数将不起作用。

        为了实现BCE作为我们的损失函数,我们将它设置为我们的标准,并将随机梯度下降作为我们优化它的手段。由于这是我们将要优化的函数,我们需要传入模型参数和学习率。

epochs = 2000   # run 2000 iterations
criterion = nn.BCELoss()    # implement binary cross entropy loss functionoptimizer = torch.optim.SGD(model.parameters(), lr = .1) # stochastic gradient descent

        现在,我们准备开始梯度下降以优化我们的损失。我们必须将梯度归零,通过将我们的数据插入 sigmoid 函数来找到 y-hat 值,计算损失,并找到损失函数的梯度。然后,我们必须迈出一步,确保存储我们的新斜率并为下一次迭代进行拦截。

optimizer.zero_grad()
Yhat = torch.sigmoid(model(X)) 
loss = criterion(Yhat,Y)
loss.backward()
optimizer.step() 

五、收尾

        为了找到最佳斜率和截距,我们本质上是在训练我们的模型。我们必须对多次迭代或纪元应用梯度下降。在此示例中,我们将使用 2,000 个纪元进行演示。

epochs = 2000   # run 2000 iterations
criterion = nn.BCELoss()    # implement binary cross entropy loss functionoptimizer = torch.optim.SGD(model.parameters(), lr = .1) # stochastic gradient descentfor i in range(epochs):optimizer.zero_grad()Yhat = torch.sigmoid(model(X))loss = criterion(Yhat,Y)loss.backward()optimizer.step()print(f"epoch: {i+1}")print(f"loss: {loss: .5f}")print(f"slope: {model.weight.item(): .5f}")print(f"intercept: {model.bias.item(): .5f}")print()

将所有代码片段放在一起,我们应该得到以下代码:

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import Lineartorch.manual_seed(42)   # set a random seedmodel = Linear(in_features=1, out_features=1) # use a linear modelblue_x = (torch.rand(20) * 7).reshape(-1,1)   # random floats between 0 and 7
blue_y = torch.zeros(20).reshape(-1,1)red_x = (torch.rand(20) * 7+3).reshape(-1,1)  # random floats between 3 and 10
red_y = torch.ones(20).reshape(-1,1)X = torch.vstack([blue_x, red_x])   # matrix of x values
Y = torch.vstack([blue_y, red_y])   # matrix of y valuesepochs = 2000   # run 2000 iterations
criterion = nn.BCELoss()    # implement binary cross entropy loss functionoptimizer = torch.optim.SGD(model.parameters(), lr = .1) # stochastic gradient descentfor i in range(epochs):optimizer.zero_grad()Yhat = torch.sigmoid(model(X))loss = criterion(Yhat,Y)loss.backward()optimizer.step()print(f"epoch: {i+1}")print(f"loss: {loss: .5f}")print(f"slope: {model.weight.item(): .5f}")print(f"intercept: {model.bias.item(): .5f}")print()
两千个时期后的最终输出:epoch: 2000
loss:  0.53861
slope:  0.61276
intercept: -3.17314

两千个时期后的最终输出:

epoch: 2000
loss:  0.53861
slope:  0.61276
intercept: -3.17314 

六、可视化

        最后,我们可以将数据与 sigmoid 函数一起绘制,以获得以下可视化效果:

x = np.arange(0,10,.1)
y = model.weight.item()*x + model.bias.item()plt.plot(x, 1/(1 + np.exp(-y)), color="green")plt.xlim(0,10)
plt.scatter(blue_x, blue_y, color="blue")
plt.scatter(red_x, red_y, color="red")plt.show()

图片来源:作者

七、局限性

        二元分类的最大问题之一是需要阈值。在逻辑回归的情况下,此阈值应为 x 值,其中 y 为 50%。我们试图回答的问题是将阈值放在哪里?

        在 COVID-19 测试的情况下,原始示例说明了这种困境。如果我们将阈值设置为 x=5,我们可以清楚地看到应该是红色的蓝点和应该是蓝色的红点。

        悬垂的红点称为误报,即模型错误地预测正类的区域。悬垂的蓝点称为假阴性 - 模型错误地预测负类的区域。

 八、结论

        成功的二项式逻辑回归模型将减少假阴性的数量,因为这些假阴性通常会导致最大的危险。患有COVID-19但检测呈阴性对他人的健康和安全构成严重风险。

        通过对可用数据使用二项式逻辑回归,我们可以确定放置阈值的最佳位置,从而有助于减少不确定性并做出更明智的决策。

相关文章:

机器学习7:pytorch的逻辑回归

一、说明 逻辑回归模型是处理分类问题的最常见机器学习模型之一。二项式逻辑回归只是逻辑回归模型的一种类型。它指的是两个变量的分类,其中概率用于确定二元结果,因此“二项式”中的“bi”。结果为真或假 — 0 或 1。 二项式逻辑回归的一个例子是预测人…...

Java应用程序中如何实现FTP功能 | 代码示例和教程

原为地址:https://www.toymoban.com/diary/java/363.html 在Java应用程序中实现FTP功能需要使用FTPClient类和相关方法。下面是实现三个主要功能的示例代码: 1)显示FTP服务器上的文件: void ftpList_actionPerformed(ActionEv…...

kotlin:list的for循环

代码: var list { "a", "b", "c" } for (i in list.indices) {print("app"i""list[i]) }...

asp.net电影院选座系统VS开发sqlserver数据库web结构c#编程Microsoft Visual Studio

一、源码特点 asp.net电影院选座系统 是一套完善的web设计管理系统,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环境为vs2010,数据库为sqlserver2008,使用c#语言开发 asp.net电影院选座系统1 二、功能介…...

CSS鼠标指针表

(机翻)搬运自:cursor - CSS: Cascading Style Sheets | MDN (mozilla.org) 类型Keyword演示注释全局autoUA将基于当前上下文来确定要显示的光标。例如,相当于悬停文本时的文本。default 依赖于平台的默认光标。通常是箭头。none不会渲染光标。链接&状态contex…...

树的基本概念及二叉树

目录 一、树的基本概念 (1)树的结点 (2)度 (3)结点层次 (4)树的高度 树的特点: 二、二叉树 (1)满二叉树 (2)完…...

BUUCTF Basic 解题记录--BUU XXE COURSE

1、XXE漏洞 初步学习,可参考链接: 一篇文章带你深入理解漏洞之 XXE 漏洞 - 先知社区 2、了解了XXE漏洞,用burpsuite获取到的url转发给repeater,修改XML的信息,引入外部实体漏洞,修改发送内容,…...

kotlin:LogKit

看到别人的一个代码,觉得有点意思,就复制过来。 package robatimport android.util.Log import java.util.*object LogKit {private val MIN_STACK_OFFSET 3var defaultTag "LogKit"private val lineSeparator System.getProperty("l…...

yolo_tracking中osnet不支持.pth格式,而model_zoo中仅有.pth

yolo_traking-7.0中REID模块用到了osnet,track.py中模型文件不支持.pth,而model_zoo中仅有.pth,改动代码太麻烦了,网上查到的.pth文件转化为.pt文件都需要读取网络架构,不太可能实现。 读取osnet_x0_25_msmt17.pth发现…...

Tailwind CSS浅析与实操

Tailwind CSS 一、Tailwind CSS简介 What is Tailwind CSS Tailwind CSS| TailwindCSS中文文档 | TailwindCSS中文网官方解释:只需书写 HTML 代码,无需书写 CSS,即可快速构建美观的网站。本质上是一个工具集,包含了大量类似 fle…...

Activiti工作流引擎详解与应用

一、简介 Activiti是一个开源的工作流引擎,基于BPMN2.0标准进行流程定义。它可以将业务系统中复杂的业务流程抽取出来,使用专门的建模语言BPMN2.0进行定义,业务流程按照预先定义的流程进行执行,实现了系统的流程由Activiti进行管…...

New Journal of Physics:不同机器学习力场特征的准确性测试

文章信息 作者:Ting Han1, Jie Li1, Liping Liu2, Fengyu Li1, * and Lin-Wang Wang2, * 通信单位:内蒙古大学物理科学与技术学院、中国科学院半导体研究所 DOI:10.1088/1367-2630/acf2bb 研究背景 近年来,基于DFT数据的机器学…...

ubuntu22.04 x11窗口环境手势控制

ubuntu22.04 x11窗口环境手势控制 ubuntu x11窗口环境的手势控制并不优秀,我们可以使用touchegg去代替 这个配置过程非常简单,并且可以很容易在一定范围内达到你想到的效果,类比mac的手势控制 关于安装 首先添加源,并安装 sud…...

【ARM CoreLink 系列 4 -- NIC-400 控制器详细介绍】

文章目录 1.1 ARM NIC-400(Network interconnect)1.1.1 NIC-400 系统框图1.1.2 NIC-400 Network Interconnect1.2 NIC-400 特点1.2.1 QoS-400 Advanced Quality of Service1.2.2 QVN-400 QoS Virtual Networks1.2.3 TLX-400 Thin Links1.3 NIC-400 Top1.4 NIC-400 Terminology1…...

【生成模型】解决生成模型面对长尾类型物体时的问题 RE-IMAGEN: RETRIEVAL-AUGMENTED TEXT-TO-IMAGE GENERATOR

介绍 尽管最先进的模型可以生成常见实体的高质量图像,但它们通常难以生成不常见实体的图像,例如“Chortai(狗)”或“Picarones(食物)”。为了解决这个问题,我们提出了检索增强文本到图像生成器…...

南美巴西市场最全分析开发攻略,收藏一篇就够了

巴西位于南美洲东部,是南美洲资源最丰富,经济活力和经济实力最强的国家。巴西作为拉丁美洲的出口大国,一直是一个比较有潜力的市场,亦是我国外贸公司和独立外贸人集群的地方。中国长期是巴西主要的合作伙伴,2022年占巴…...

c++中操作符->与 . 的使用与区别

在C中,-> 和 . 是两个不同的成员访问操作符,用于访问类、结构体或联合体的成员。 “->” 操作符: 用于通过指针访问指针所指向对象的成员。当有一个指向对象的指针时,可以使用 -> 操作符来访问该指针所指向对象的成员。…...

golang 编译器 汉化

1、找到左上角file选项,点击选中settings进行单机 2、找到settings中找到plugins选中进行点击 3、再框中输入chinese进行搜索,出结果后找到如下图所示,点击进行安装 4、安装完成后进行重启ide,完美解决...

压缩包系列

1、zip伪加密 一个zip文件由三部分组成:压缩源文件数据区压缩源文件目录区压缩源文件目录结束标志。 伪加密原理:zip伪加密是在文件头中加密标志位做修改,然后在打开时误被识别成加密压缩包。 压缩源文件数据区: 50 4B 03 04&a…...

互联网图片安全风控实战训练营开营!

内容安全风控,即针对互联网产生的海量内容的外部、内部风险做宏观到微观的引导和审核,从内容安全领域帮助企业化解监管风险和社会舆论风险,其核心是识别文本、图片、视频、音频中的有害内容。 由于互联网内容类型繁杂、多如牛毛,加…...

《Playwright:微软的自动化测试工具详解》

Playwright 简介:声明内容来自网络,将内容拼接整理出来的文档 Playwright 是微软开发的自动化测试工具,支持 Chrome、Firefox、Safari 等主流浏览器,提供多语言 API(Python、JavaScript、Java、.NET)。它的特点包括&a…...

全球首个30米分辨率湿地数据集(2000—2022)

数据简介 今天我们分享的数据是全球30米分辨率湿地数据集,包含8种湿地亚类,该数据以0.5X0.5的瓦片存储,我们整理了所有属于中国的瓦片名称与其对应省份,方便大家研究使用。 该数据集作为全球首个30米分辨率、覆盖2000–2022年时间…...

CMake控制VS2022项目文件分组

我们可以通过 CMake 控制源文件的组织结构,使它们在 VS 解决方案资源管理器中以“组”(Filter)的形式进行分类展示。 🎯 目标 通过 CMake 脚本将 .cpp、.h 等源文件分组显示在 Visual Studio 2022 的解决方案资源管理器中。 ✅ 支持的方法汇总(共4种) 方法描述是否推荐…...

【碎碎念】宝可梦 Mesh GO : 基于MESH网络的口袋妖怪 宝可梦GO游戏自组网系统

目录 游戏说明《宝可梦 Mesh GO》 —— 局域宝可梦探索Pokmon GO 类游戏核心理念应用场景Mesh 特性 宝可梦玩法融合设计游戏构想要素1. 地图探索(基于物理空间 广播范围)2. 野生宝可梦生成与广播3. 对战系统4. 道具与通信5. 延伸玩法 安全性设计 技术选…...

Mobile ALOHA全身模仿学习

一、题目 Mobile ALOHA:通过低成本全身远程操作学习双手移动操作 传统模仿学习(Imitation Learning)缺点:聚焦与桌面操作,缺乏通用任务所需的移动性和灵活性 本论文优点:(1)在ALOHA…...

保姆级教程:在无网络无显卡的Windows电脑的vscode本地部署deepseek

文章目录 1 前言2 部署流程2.1 准备工作2.2 Ollama2.2.1 使用有网络的电脑下载Ollama2.2.2 安装Ollama(有网络的电脑)2.2.3 安装Ollama(无网络的电脑)2.2.4 安装验证2.2.5 修改大模型安装位置2.2.6 下载Deepseek模型 2.3 将deepse…...

Chromium 136 编译指南 Windows篇:depot_tools 配置与源码获取(二)

引言 工欲善其事,必先利其器。在完成了 Visual Studio 2022 和 Windows SDK 的安装后,我们即将接触到 Chromium 开发生态中最核心的工具——depot_tools。这个由 Google 精心打造的工具集,就像是连接开发者与 Chromium 庞大代码库的智能桥梁…...

Spring AI Chat Memory 实战指南:Local 与 JDBC 存储集成

一个面向 Java 开发者的 Sring-Ai 示例工程项目,该项目是一个 Spring AI 快速入门的样例工程项目,旨在通过一些小的案例展示 Spring AI 框架的核心功能和使用方法。 项目采用模块化设计,每个模块都专注于特定的功能领域,便于学习和…...

mac:大模型系列测试

0 MAC 前几天经过学生优惠以及国补17K入手了mac studio,然后这两天亲自测试其模型行运用能力如何,是否支持微调、推理速度等能力。下面进入正文。 1 mac 与 unsloth 按照下面的进行安装以及测试,是可以跑通文章里面的代码。训练速度也是很快的。 注意…...

怎么开发一个网络协议模块(C语言框架)之(六) ——通用对象池总结(核心)

+---------------------------+ | operEntryTbl[] | ← 操作对象池 (对象数组) +---------------------------+ | 0 | 1 | 2 | ... | N-1 | +---------------------------+↓ 初始化时全部加入 +------------------------+ +-------------------------+ | …...