PyTorch实现逻辑回归
最终效果
先看下最终效果:

这里用一条直线把二维平面上不同的点分开。
生成随机数据
#创建训练数据
x = torch.rand(10,1)*10 #shape(10,1)
y = 2*x + (5 + torch.randn(10,1))#构建线性回归参数
w = torch.randn((1))#随机初始化w,要用到自动梯度求导
b = torch.zeros((1))#使用0初始化b,要用到自动梯度求导n_data = torch.ones(100, 2)
xy0 = torch.normal(2 * n_data, 1.5) # 生成均值为2.标准差为1.5的随机数组成的矩阵
c0 = torch.zeros(100)
xy1 = torch.normal(-2 * n_data, 1.5) # 生成均值为-2.标准差为1.5的随机数组成的矩阵
c1 = torch.ones(100)x,y = torch.cat((xy0,xy1),0).type(torch.FloatTensor).split(1, dim=1)
x = x.squeeze()
y = y.squeeze()
c = torch.cat((c0,c1),0).type(torch.FloatTensor)
数据可视化
def plot(x, y, c):ax = plt.gca()sc = ax.scatter(x, y, color='black')paths = []for i in range(len(x)):if c[i].item() == 0:marker_obj = mmarkers.MarkerStyle('o')else:marker_obj = mmarkers.MarkerStyle('x')path = marker_obj.get_path().transformed(marker_obj.get_transform())paths.append(path)sc.set_paths(paths)return sc
plot(x, y, c)
plt.show()
使用x和o来表示两种不同类别的数据。

定义模型和损失函数
#构建逻辑回归参数
w = torch.tensor([1.,],requires_grad=True) # 随机初始化w
b = torch.zeros((1),requires_grad=True) # 使用0初始化bwx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + b
loss = (0.5*(y-y_pred)**2).mean()
这里使用了平方损失函数来估算模型准确度。
训练模型
最多训练100次,每次都会更新模型参数,当损失值小于0.03时停止训练。
xx = torch.arange(-4, 5)
lr = 0.02 #学习率
for iteration in range(100):#前向传播loss = ((torch.sigmoid(x*w+b-y) - c)**2).mean()#反向传播loss.backward()#更新参数b.data.sub_(lr*b.grad) # b = b - lr*b.gradw.data.sub_(lr*w.grad) # w = w - lr*w.grad#绘图if iteration % 3 == 0:plot(x, y, c)yy = w*xx + bplt.plot(xx.data.numpy(),yy.data.numpy(),'r-',lw=5)plt.text(-4,2,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'black'})plt.xlim(-4,4)plt.ylim(-4,4)plt.title("Iteration:{}\nw:{},b:{}".format(iteration,w.data.numpy(),b.data.numpy()))plt.show()if loss.data.numpy() < 0.03: # 停止条件break
全部代码
import torch
import matplotlib.pyplot as plt
import matplotlib.markers as mmarkers#创建训练数据
x = torch.rand(10,1)*10 #shape(10,1)
y = 2*x + (5 + torch.randn(10,1))#构建线性回归参数
w = torch.randn((1))#随机初始化w,要用到自动梯度求导
b = torch.zeros((1))#使用0初始化b,要用到自动梯度求导wx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + bn_data = torch.ones(100, 2)
xy0 = torch.normal(2 * n_data, 1.5) # 生成均值为2.标准差为1.5的随机数组成的矩阵
c0 = torch.zeros(100)
xy1 = torch.normal(-2 * n_data, 1.5) # 生成均值为-2.标准差为1.5的随机数组成的矩阵
c1 = torch.ones(100)x,y = torch.cat((xy0,xy1),0).type(torch.FloatTensor).split(1, dim=1)
x = x.squeeze()
y = y.squeeze()
c = torch.cat((c0,c1),0).type(torch.FloatTensor)def plot(x, y, c):ax = plt.gca()sc = ax.scatter(x, y, color='black')paths = []for i in range(len(x)):if c[i].item() == 0:marker_obj = mmarkers.MarkerStyle('o')else:marker_obj = mmarkers.MarkerStyle('x')path = marker_obj.get_path().transformed(marker_obj.get_transform())paths.append(path)sc.set_paths(paths)return sc
plot(x, y, c)
plt.show()#构建逻辑回归参数
w = torch.tensor([1.,],requires_grad=True)#随机初始化w
b = torch.zeros((1),requires_grad=True)#使用0初始化bwx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + b
loss = (0.5*(y-y_pred)**2).mean()xx = torch.arange(-4, 5)
lr = 0.02 #学习率
for iteration in range(100):#前向传播loss = ((torch.sigmoid(x*w+b-y) - c)**2).mean()#反向传播loss.backward()#更新参数b.data.sub_(lr*b.grad) # b = b - lr*b.gradw.data.sub_(lr*w.grad) # w = w - lr*w.grad#绘图if iteration % 3 == 0:plot(x, y, c)yy = w*xx + bplt.plot(xx.data.numpy(),yy.data.numpy(),'r-',lw=5)plt.text(-4,2,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'black'})plt.xlim(-4,4)plt.ylim(-4,4)plt.title("Iteration:{}\nw:{},b:{}".format(iteration,w.data.numpy(),b.data.numpy()))plt.show()if loss.data.numpy() < 0.03:#停止条件break
相关文章:
PyTorch实现逻辑回归
最终效果 先看下最终效果: 这里用一条直线把二维平面上不同的点分开。 生成随机数据 #创建训练数据 x torch.rand(10,1)*10 #shape(10,1) y 2*x (5 torch.randn(10,1))#构建线性回归参数 w torch.randn((1))#随机初始化w,要用到自动梯度求导 b …...
什么是FPGA原型验证?
EDA工具的使用主要分为设计、验证和制造三大类。验证工作贯穿整个芯片设计流程,可以说芯片的验证阶段占据了整个芯片开发的大部分时间。从芯片需求定义、功能设计开发到物理实现制造,每个环节都需要进行大量的验证。 现如今验证方法也越来越多ÿ…...
基于VUE3+Layui从头搭建通用后台管理系统(前端篇)十四:系统设置模块相关功能实现
一、本章内容 本章使用已实现的公共组件实现系统管理中的系统设置模块相关功能,包括菜单管理、角色管理、日志管理、用户管理、系统配置、数据字典等。 1. 详细课程地址: 待发布 2. 源码下载地址: 待发布 二、界面预览 三、开发视频 3.1 B站视频地址:...
使用Visual Studio(VS)创建空项目的Win32桌面应用程序【main函数入口变WinMain】
前言 在Visual Studio中直接新建Windows桌面应用程序会有很多多余的代码生成,本文将提供从空项目创建Win32项目的方法,解决新建空项目直接使用WinMain代码编译报错的问题 例如:LNK2019 :无法解析的外部符号 参考博客࿱…...
基于自动化脚本批量上传依赖到nexus内网私服
前言 因为某些原因某些企业希望私服是不能连接外网的,所以需要某些开源依赖需要我们手动导入到nexus中,尽管nexus为我们提供了web页面。但是一个个手动导入显然是一个庞大的工程。 对此我们就不妨基于脚本的方式实现这一过程。 预期效果 笔者本地仓库…...
Linux中ps命令使用指南
目录 1 前言2 ps命令的含义和作用3 ps命令的基本使用4 常用选项参数5 一些常用情景5.1 查看系统中的所有进程(标准语法)5.2 使用 BSD 语法查看系统中的所有进程5.3 打印进程树5.4 获取线程信息5.5 获取安全信息5.6 查看以 root 用户身份(实际…...
PHP开发语言中,网页端常用的标签
在PHP开发语言中,网页端常用的标签包括以下几种: <html>:用于定义整个HTML文档。<head>:用于定义文档的头部,包含元数据、样式表和脚本等。<title>:用于定义文档的标题,显示…...
Java 入门第四篇 集合
Java 入门第四篇 集合 一,什么是集合 在Java中,集合(Collection)是一种用于存储和操作一组对象的容器类。它提供了一系列的方法和功能,用于方便地管理和操作对象的集合。集合框架是Java中非常重要和常用的一部分&…...
VBA技术资料MF93:将多个Excel表插入PowerPoint不同位置
我给VBA的定义:VBA是个人小型自动化处理的有效工具。利用好了,可以大大提高自己的工作效率,而且可以提高数据的准确度。我的教程一共九套,分为初级、中级、高级三大部分。是对VBA的系统讲解,从简单的入门,到…...
STM32 MCU的易坑点收集
IIC配置中的Clock No Stretch Mode Clock Stretch Mode时钟延长模式: 时钟延长是一个术语,某些从设备可以把时钟线拉低,主设备发现自己释放时钟线之后时钟线还没有变成高电平,就会停止发送数据,然后等待从设备释放时钟…...
Vue3项目filter.js组件封装
1、element-plus(el-table)修改table的行样式 export function elTableRowClassName({ row, rowIndex }) {if (rowIndex % 2 ! 0) {return default-row} }2、时间戳转换格式 export function parseTimeFilter(dateTime, dateType) {if (dateTime || dateTime undefined ||…...
Linux: pwd命令查看当前工作目录
pwd 是 Linux 和其他类 Unix 操作系统中的一个命令,用于显示当前工作目录的绝对路径。 语法 pwd 描述 pwd 是 "print working directory" 的缩写,它用于打印当前工作目录的完整路径。这对于确定当前目录位置非常有用,特别是在嵌…...
【深度学习】PHP操作mysql数据库总结
一.PHP数据库的扩展分类 1.MySQL 扩展是针对 MySQL 4.1.3 或更早版本设计的,是 PHP 与 MySQL数据库交互的早期扩展。由于其不支持 MySQL 数据库服务器的新特性,且安全性差,在项目开发中不建议使用,可用 MySQLi 扩展代替。 2.MySQ…...
【送书活动】探究AIGC、AGI、GPT和人工智能大模型
文章目录 前言01 《ChatGPT 驱动软件开发》推荐语 02 《ChatGPT原理与实战》推荐语 03 《神经网络与深度学习》推荐语 04 《AIGC重塑教育》推荐语 05 《通用人工智能》推荐语 后记赠书活动 前言 人工智能技术在过去几年中发展迅猛,得益于大数据、云计算、深度学习等…...
Apple Find My「查找」认证芯片找哪家,认准伦茨科技ST17H6x芯片
深圳市伦茨科技有限公司(以下简称“伦茨科技”)发布ST17H6x Soc平台。成为继Nordic之后全球第二家取得Apple Find My「查找」认证的芯片厂家,该平台提供可通过Apple Find My认证的Apple查找(Find My)功能集成解决方案。…...
java.lang.IllegalArgumentException: Could not resolve placeholder XXX‘ in value
问题描述 使用Springcloudalibaba的nacos作为配置中心,服务启动时报错: java.lang.IllegalArgumentException: Could not resolve placeholder XXX‘ in value java.lang.IllegalArgumentException: Param ‘serviceName’ is illegal, serviceName is …...
自动机器学习是什么?概念及应用
自动机器学习 (Auto Machine Learning) 的应用和方法 随着众多企业在大量场景中开始采用机器学习,前后期处理和优化的数据量及规模指数级增长。企业很难雇用充足的人手来完成与高级机器学习模型相关的所有工作,因此机器学习自动化工具是未来人工智能 (A…...
el-date-picker限制选择7天内禁止内框选择
需求:elementPlus时间段选择框需要满足:①最多选7天时间。②不能手动输入。 <el-date-picker v-model"timeArrange" focus"timeEditable" :editable"false" type"datetimerange" range-separator"至&qu…...
Navicat 技术指引 | 适用于 GaussDB 分布式的调试器
Navicat Premium(16.3.3 Windows 版或以上)正式支持 GaussDB 分布式数据库。GaussDB 分布式模式更适合对系统可用性和数据处理能力要求较高的场景。Navicat 工具不仅提供可视化数据查看和编辑功能,还提供强大的高阶功能(如模型、结…...
人工智能导论习题集(3)
第五章:不确定性推理 题1题2题3题4题5题6题7题8 题1 题2 题3 题4 题5 题6 题7 题8...
在软件开发中正确使用MySQL日期时间类型的深度解析
在日常软件开发场景中,时间信息的存储是底层且核心的需求。从金融交易的精确记账时间、用户操作的行为日志,到供应链系统的物流节点时间戳,时间数据的准确性直接决定业务逻辑的可靠性。MySQL作为主流关系型数据库,其日期时间类型的…...
Admin.Net中的消息通信SignalR解释
定义集线器接口 IOnlineUserHub public interface IOnlineUserHub {/// 在线用户列表Task OnlineUserList(OnlineUserList context);/// 强制下线Task ForceOffline(object context);/// 发布站内消息Task PublicNotice(SysNotice context);/// 接收消息Task ReceiveMessage(…...
【Linux】C语言执行shell指令
在C语言中执行Shell指令 在C语言中,有几种方法可以执行Shell指令: 1. 使用system()函数 这是最简单的方法,包含在stdlib.h头文件中: #include <stdlib.h>int main() {system("ls -l"); // 执行ls -l命令retu…...
Swift 协议扩展精进之路:解决 CoreData 托管实体子类的类型不匹配问题(下)
概述 在 Swift 开发语言中,各位秃头小码农们可以充分利用语法本身所带来的便利去劈荆斩棘。我们还可以恣意利用泛型、协议关联类型和协议扩展来进一步简化和优化我们复杂的代码需求。 不过,在涉及到多个子类派生于基类进行多态模拟的场景下,…...
Web 架构之 CDN 加速原理与落地实践
文章目录 一、思维导图二、正文内容(一)CDN 基础概念1. 定义2. 组成部分 (二)CDN 加速原理1. 请求路由2. 内容缓存3. 内容更新 (三)CDN 落地实践1. 选择 CDN 服务商2. 配置 CDN3. 集成到 Web 架构 …...
React---day11
14.4 react-redux第三方库 提供connect、thunk之类的函数 以获取一个banner数据为例子 store: 我们在使用异步的时候理应是要使用中间件的,但是configureStore 已经自动集成了 redux-thunk,注意action里面要返回函数 import { configureS…...
动态 Web 开发技术入门篇
一、HTTP 协议核心 1.1 HTTP 基础 协议全称 :HyperText Transfer Protocol(超文本传输协议) 默认端口 :HTTP 使用 80 端口,HTTPS 使用 443 端口。 请求方法 : GET :用于获取资源,…...
【无标题】路径问题的革命性重构:基于二维拓扑收缩色动力学模型的零点隧穿理论
路径问题的革命性重构:基于二维拓扑收缩色动力学模型的零点隧穿理论 一、传统路径模型的根本缺陷 在经典正方形路径问题中(图1): mermaid graph LR A((A)) --- B((B)) B --- C((C)) C --- D((D)) D --- A A -.- C[无直接路径] B -…...
python爬虫——气象数据爬取
一、导入库与全局配置 python 运行 import json import datetime import time import requests from sqlalchemy import create_engine import csv import pandas as pd作用: 引入数据解析、网络请求、时间处理、数据库操作等所需库。requests:发送 …...
MFE(微前端) Module Federation:Webpack.config.js文件中每个属性的含义解释
以Module Federation 插件详为例,Webpack.config.js它可能的配置和含义如下: 前言 Module Federation 的Webpack.config.js核心配置包括: name filename(定义应用标识) remotes(引用远程模块࿰…...
