001.从0开始实现线性回归(pytorch)
000动手从0实现线性回归
0. 背景介绍
我们构造一个简单的人工训练数据集,它可以使我们能够直观比较学到的参数和真实的模型参数的区别。
设训练数据集样本数为1000,输入个数(特征数)为2。给定随机生成的批量样本特征 X∈R1000×2
X∈R 1000×2 ,我们使用线性回归模型真实权重 w=[2,−3.4]⊤ 和偏差 b=4.2以及一个随机噪声项 ϵϵ 来生成标签

# 需要导入的包
import numpy as np
import torch
import random
from d2l import torch as d2l
from IPython import display
from matplotlib import pyplot as plt
1. 生成数据集合(待拟合)
使用python生成待拟合的数据
num_input = 2
num_example = 1000
w_true = [2,-3.4]
b_true = 4.2
features = torch.randn(num_example,num_input)
print('features.shape = '+ str(features.shape) )
labels = w_true[0] * features[:,0] + w_true[1] * features[:,1] + b_true
labels += torch.tensor(np.random.normal(0,0.01 , size = labels.size() ),dtype = torch.float32)
print(features[0],labels[0])
2.数据的分批量处理
def data_iter(batch_size, features, labels):num_example = len(labels)indices = list(range(num_example))random.shuffle(indices)for i in range(0, num_example, batch_size):j = torch.tensor( indices[i:min(i+ batch_size,num_example)])yield features.index_select(0,j) ,labels.index_select(0,j)
3. 模型构建及训练
3.1 定义模型:
def linreg(X, w, b):return torch.mm(X,w)+b
3.2 定义损失函数
def square_loss(y, y_hat):return (y_hat - y.view(y_hat.size()))**2/2
3.3 定义优化算法
def sgd(params , lr ,batch_size):for param in params:param.data -= lr * param.grad / batch_size
3.4 模型训练
# 设置超参数
lr = 0.03
num_epochs =5
net = linreg
loss = square_loss
batch_size = 10
for epoch in range(num_epochs):for X,y in data_iter(batch_size= batch_size,features=features,labels= labels):l = loss(net(X,w,b),y).sum()l.backward()sgd([w,b],lr,batch_size=batch_size)#梯度清零避免梯度累加w.grad.data.zero_()b.grad.data.zero_()train_l = loss(net(features,w,b),labels)print('epoch %d, loss %f' %(epoch +1 ,train_l.mean().item()))
epoch 1, loss 0.032550
epoch 2, loss 0.000133
epoch 3, loss 0.000053
epoch 4, loss 0.000053
epoch 5, loss 0.000053
基于pytorch的线性模型的实现
- 相关数据和初始化与上面构建相同
- 定义模型
import torch
from torch import nn
class LinearNet(nn.Module):def __init__(self, n_feature):# 调用父类的初始化super(LinearNet,self).__init__()# Linear(输入特征数,输出特征的数量,是否含有偏置项)self.linera = nn.Linear(n_feature,1)def forward(self,x):y = self.linera(x)return y
#打印模型的结构:
net = LinearNet(num_input)
print(net)
# LinearNet( (linera): Linear(in_features=2, out_features=1, bias=True)
)
- 初始化模型的参数
from torch.nn import init
init.normal_(net.linera.weight,mean=0,std= 0.1)
init.constant_(net.linera.bias ,val=0)
- 定义损失函数
loss = nn.MSELoss()
5.定义优化算法
import torch.optim as optim
optimizer = optim.SGD(net.parameters(),lr = 0.03)
print(optimizer)
- 训练模型:
num_epochs = 3
for epoch in range(1,num_epochs+1):for X,y in data_iter(batch_size= batch_size,features=features,labels= labels):output= net(X)l = loss(output,y.view(-1,1))optimizer.zero_grad()l.backward()optimizer.step()print('epoch %d ,loss: %f' %(epoch,l.item()) )
epoch 1 ,loss: 0.000159
epoch 2 ,loss: 0.000089
epoch 3 ,loss: 0.000066
相关文章:
001.从0开始实现线性回归(pytorch)
000动手从0实现线性回归 0. 背景介绍 我们构造一个简单的人工训练数据集,它可以使我们能够直观比较学到的参数和真实的模型参数的区别。 设训练数据集样本数为1000,输入个数(特征数)为2。给定随机生成的批量样本特征 X∈R10002 …...
Relations Prediction for Knowledge Graph Completion using Large Language Models
文章目录 题目摘要简介相关工作方法论实验结论局限性未来工作 题目 使用大型语言模型进行知识图谱补全的关系预测 论文地址:https://arxiv.org/pdf/2405.02738 项目地址: https://github.com/yao8839836/kg-llm 摘要 知识图谱已被广泛用于以结构化格式表…...
2024年中国研究生数学建模竞赛D题思路代码分析——大数据驱动的地理综合问题
地理系统是自然、人文多要素综合作用的复杂巨系统[1-2],地理学家常用地理综合的方式对地理系统进行主导特征的表达[3]。如以三大阶梯概括中国的地形特征,以秦岭—淮河一线和其它地理区划的方式揭示中国气温、降水、植被、土壤及生态环境在水平和垂直方向…...
全国31省对外开放程度、经济发展水平、政府干预程度指标数据(2000-2022年)
旨在分析2000-2022年间中国31个省份的对外开放程度、经济发展水平和政府干预程度,探讨其背后的动因与影响。 2000年-2022年 全国31省对外开放程度、经济发展水平、政府干预程度指标数据https://download.csdn.net/download/2401_84585615/89478612 数据概览 对外…...
计算机网络传输层---课后综合题
线路:TCP报文下放到物理层传输。 TCP报文段中,“序号”长度为32bit,为了让序列号不会循环,则最多能传输2^32B的数据,则最多能传输:2^32/1500B个报文 结果: 吞吐率一个周期内传输的数据/周期时间…...
【homebrew安装】踩坑爬坑教程
homebrew官网,有安装教程提示,但是在实际安装时,由于待下载的包的尺寸过大,本地git缓存尺寸、超时时间的限制,会报如下错误: error: RPC failed; curl 92 HTTP/2 stream 5 was not closed cleanly…...
反游戏学(Reludology):概念、历史、现状与展望?(豆包AI版)
李升伟 以下是关于“反游戏学(Reludology):概念、历史、现状与展望”的综述: 一、概念 反游戏学(Reludology)是一个相对较新且不太常见的概念,目前尚未有统一明确的定义。一般来说…...
【C/C++语言系列】实现单例模式
1.单例模式概念 定义:单例模式是一种常见的设计模式,它可以保证系统中一个类只有一个实例,而且该实例易于外界访问(一个类一个对象,共享这个对象)。 条件: 只有1个对象易于外界访问共享这个对…...
A. Make All Equal
time limit per test 1 second memory limit per test 256 megabytes You are given a cyclic array a1,a2,…,ana1,a2,…,an. You can perform the following operation on aa at most n−1n−1 times: Let mm be the current size of aa, you can choose any two adjac…...
业务安全治理
业务安全治理 1.账号安全撞库账户盗用 2.爬虫与反爬虫3.API网关防护4.钓鱼与反制钓鱼发现钓鱼处置 5.大数据风控风控介绍 1.账号安全 撞库 撞库分为垂直撞库和水平撞库两种,垂直撞库是对一个账号使用多个不同的密码进行尝试,可以理解为暴力破解&#x…...
HelpLook VS GitBook,在线文档管理工具对比
在线文档管理工具在当今时代非常重要。随着数字化时代的到来,人们越来越依赖于电子文档来存储、共享和管理信息。无论是与团队合作还是与客户分享,人们都可以轻松地共享文档链接或通过设置权限来控制访问。在线文档管理工具的出现大大提高了工作效率和协…...
docker面经
docker面经在线链接 docker面经在线链接🔗: (https://h03yz7idw7.feishu.cn/wiki/N3CVwO3kMifLypkJqnic9wNynKh)...
Python 中的 Kombu 类库
Kombu 是一个用于 Python 的消息队列库,提供了高效、灵活的消息传递机制。它是 Celery 的核心组件之一,但也可以单独使用。Kombu 支持多种消息代理(如 RabbitMQ、Redis、Amazon SQS 等),并提供了消息生产者和消费者的功…...
safepoint是什么?有什么用?
在JVM中,safepoint(安全点)是一个非常重要的概念,特别是在垃圾回收(GC)和其他需要暂停所有应用线程的操作中。 什么是safepoint Safepoint是JVM执行过程中一个特定的位置,在这个位置上&#x…...
axios相关知识点
一、基本概念 1、基于Promise:Axios通过Promise实现异步请求,避免了传统回调函数导致的“回调地狱”问题,使得代码更加清晰和易于维护。 2、跨平台:Axios既可以在浏览器中运行,也可以在Node.js环境中使用,为前后端开…...
LeetCode 面试经典150题 67.二进制求和
415.字符串相加 思路一模一样 题目:给你两个二进制字符串 a 和 b ,以二进制字符串的形式返回它们的和。 eg: 输入a“1010” b“1011” 输出“10101” 思路:从右开始遍历两个字符串,因为右边是低位先运算。如果…...
Dell PowerEdge 网络恢复笔记
我有一台Dell的PowerEdge服务器,之前安装了Ubuntu 20 桌面版。突然有一天不能开机了。 故障排查 Disk Error 首先是看一下机器的正面,有一个非常小的液晶显示器,只能显示一排字。 上面显示Disk Error,然后看挂载的硬盘仓&#…...
Java面试——集合篇
1.Java中常用的容器有哪些? 容器主要包括 Collection 和 Map 两种,Collection 存储着对象的集合,而 Map 存储着键值对(两个对象)的映射表。 如图: 面试官追问:说说集合有哪些类及他们各自的区别和特点? S…...
算法【双向广搜】
双向广搜常见用途 1:小优化。bfs的剪枝策略,分两侧展开分支,哪侧数量少就从哪侧展开。 2:用于解决特征很明显的一类问题。特征:全量样本不允许递归完全展开,但是半量样本可以完全展开。过程:把…...
javascript检测数据类型的方法
1. typeof 运算符 typeof是一个用来检测变量的基本数据类型的运算符。它可以返回以下几种类型的字符串:“undefined”、“boolean”、“number”、“string”、“object”、“function” 和 “symbol”(ES6)。需要注意的是,对于 n…...
UE5 学习系列(二)用户操作界面及介绍
这篇博客是 UE5 学习系列博客的第二篇,在第一篇的基础上展开这篇内容。博客参考的 B 站视频资料和第一篇的链接如下: 【Note】:如果你已经完成安装等操作,可以只执行第一篇博客中 2. 新建一个空白游戏项目 章节操作,重…...
使用docker在3台服务器上搭建基于redis 6.x的一主两从三台均是哨兵模式
一、环境及版本说明 如果服务器已经安装了docker,则忽略此步骤,如果没有安装,则可以按照一下方式安装: 1. 在线安装(有互联网环境): 请看我这篇文章 传送阵>> 点我查看 2. 离线安装(内网环境):请看我这篇文章 传送阵>> 点我查看 说明:假设每台服务器已…...
web vue 项目 Docker化部署
Web 项目 Docker 化部署详细教程 目录 Web 项目 Docker 化部署概述Dockerfile 详解 构建阶段生产阶段 构建和运行 Docker 镜像 1. Web 项目 Docker 化部署概述 Docker 化部署的主要步骤分为以下几个阶段: 构建阶段(Build Stage):…...
【JavaEE】-- HTTP
1. HTTP是什么? HTTP(全称为"超文本传输协议")是一种应用非常广泛的应用层协议,HTTP是基于TCP协议的一种应用层协议。 应用层协议:是计算机网络协议栈中最高层的协议,它定义了运行在不同主机上…...
STM32标准库-DMA直接存储器存取
文章目录 一、DMA1.1简介1.2存储器映像1.3DMA框图1.4DMA基本结构1.5DMA请求1.6数据宽度与对齐1.7数据转运DMA1.8ADC扫描模式DMA 二、数据转运DMA2.1接线图2.2代码2.3相关API 一、DMA 1.1简介 DMA(Direct Memory Access)直接存储器存取 DMA可以提供外设…...
linux 错误码总结
1,错误码的概念与作用 在Linux系统中,错误码是系统调用或库函数在执行失败时返回的特定数值,用于指示具体的错误类型。这些错误码通过全局变量errno来存储和传递,errno由操作系统维护,保存最近一次发生的错误信息。值得注意的是,errno的值在每次系统调用或函数调用失败时…...
Axios请求超时重发机制
Axios 超时重新请求实现方案 在 Axios 中实现超时重新请求可以通过以下几种方式: 1. 使用拦截器实现自动重试 import axios from axios;// 创建axios实例 const instance axios.create();// 设置超时时间 instance.defaults.timeout 5000;// 最大重试次数 cons…...
(转)什么是DockerCompose?它有什么作用?
一、什么是DockerCompose? DockerCompose可以基于Compose文件帮我们快速的部署分布式应用,而无需手动一个个创建和运行容器。 Compose文件是一个文本文件,通过指令定义集群中的每个容器如何运行。 DockerCompose就是把DockerFile转换成指令去运行。 …...
OpenLayers 分屏对比(地图联动)
注:当前使用的是 ol 5.3.0 版本,天地图使用的key请到天地图官网申请,并替换为自己的key 地图分屏对比在WebGIS开发中是很常见的功能,和卷帘图层不一样的是,分屏对比是在各个地图中添加相同或者不同的图层进行对比查看。…...
MySQL用户和授权
开放MySQL白名单 可以通过iptables-save命令确认对应客户端ip是否可以访问MySQL服务: test: # iptables-save | grep 3306 -A mp_srv_whitelist -s 172.16.14.102/32 -p tcp -m tcp --dport 3306 -j ACCEPT -A mp_srv_whitelist -s 172.16.4.16/32 -p tcp -m tcp -…...
