用TensorFlow实现线性回归
说明
本文采用TensorFlow框架进行讲解,虽然之前的文章都采用mxnet,但是我发现tensorflow提供了免费的gpu可供使用,所以果断开始改为tensorflow,若要实现文章代码,可以使用colaboratory进行运行,当然,如果您已经安装了tensorflow,可以采用python直接运行。
贡献
学习时采取动手学深度学习第二版作为教材,但由于本书通过引入d2l(著者自写库)进行深度学习,我希望将d2l的影响去掉,即不使用d2l,使用tensorflow,这一点通过查询GitHub中d2l库提供的相关函数尝试进行实现。
如果本系列文章具有良好表现,将译为英文版上传至Github。
预备知识
学习本篇文章之前,您最好具有以下基础知识:
- 线性回归的基础知识
- python的基础知识
基本原理
使用一个仿射变换,通过y=wx+b的模型来对数据进行预测(w和x均为矩阵,大小取决于输入规模),反向传播采用随机梯度下降对参数进行更新,参数包括w和b,即权重和偏差。
实现过程
生成数据集
只需要引入tensorflow即可,synthetic_data()函数将初始化X和Y,即通过真实的权重和偏差值生成数据集。
import tensorflow as tfdef synthetic_data(w, b, num_examples):X = tf.zeros((num_examples, w.shape[0]))X += tf.random.normal(shape=X.shape)y = tf.matmul(X, tf.reshape(w, (-1, 1))) + by += tf.random.normal(shape=y.shape, stddev=0.01)y = tf.reshape(y, (-1, 1))return X, ytrue_w = tf.constant([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
读取数据集
加载刚刚生成的数据集,is_train表示是否进行打乱,默认对数据进行打乱处理,使用load_array函数加载数据集。
def load_array(data_arrays, batch_size, is_train=True):dataset = tf.data.Dataset.from_tensor_slices(data_arrays)if is_train:dataset = dataset.shuffle(buffer_size=1000)dataset = dataset.batch(batch_size)return datasetbatch_size = 10
data_iter = load_array((features, labels), batch_size)
定义模型
模型使用keras API实现,keras是tensorflow中机器学习相关的库。先使用Sequential类定义承载容器,之后添加一个单神经元的全连接层。在TensorFlow中,Sequential表示容器相关的类,layer表示层相关的类。线性回归只需要通过keras中的单神经元的全连接层即可实现,神经元的值即为输出结果。
net = tf.keras.Sequential()
net.add(tf.keras.layers.Dense(1))
示例的线性回归仅有一个输入X,实际在其他线性回归过程中,很有可能有多个x及其对应的w,但keras的代码均不会发生改变,因为keras的Dense类可以自动判断输入的个数。
初始化模型参数
stddev表示标准差,initializer生成一个标准差为1,均值为0的正态分布。在构建全连接层时,使用该正态分布进行初始化。
initializer = tf.initializers.RandomNormal(stddev=0.01)
net = tf.keras.Sequential()
net.add(tf.keras.layers.Dense(1, kernel_initializer=initializer))
定义损失函数和优化算法
损失函数使用平方损失函数进行计算,训练时使用小批量随机梯度下降SGD方法进行训练,学习率为0.03。
loss = tf.keras.losses.MeanSquaredError()
trainer = tf.keras.optimizers.SGD(learning_rate=0.03)
训练
运行以下代码可以观察训练结果。运行轮次为3轮,每一轮对所有训练集数据进行学习。计算w和b的梯度值,使用梯度下降更新权重w和偏差b。每一轮输出损失函数的值,最终显示权重和偏差的估计误差。
num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:with tf.GradientTape() as tape:l = loss(net(X, training=True), y)grads = tape.gradient(l, net.trainable_variables)trainer.apply_gradients(zip(grads, net.trainable_variables))l = loss(net(features), labels)print(f'epoch {epoch + 1}, loss {l:f}')
w = net.get_weights()[0]
print('w的估计误差:', true_w - tf.reshape(w, true_w.shape))
b = net.get_weights()[1]
print('b的估计误差:', true_b - b)
运行结果
epoch 1, loss 0.000194
epoch 2, loss 0.000091
epoch 3, loss 0.000091
w的估计误差: tf.Tensor([-0.00026917 0.00094557], shape=(2,), dtype=float32)
b的估计误差: [4.7683716e-06]
改进尝试
- 更改SGD优化算法为Adam
- 更改MeanSquaredError为其他损失函数
对于上述改进,损失均有显著增加,表明原有方法已为最好方法。
相关文章:

用TensorFlow实现线性回归
说明 本文采用TensorFlow框架进行讲解,虽然之前的文章都采用mxnet,但是我发现tensorflow提供了免费的gpu可供使用,所以果断开始改为tensorflow,若要实现文章代码,可以使用colaboratory进行运行,当然&#…...

IT计算机软件系统类毕业论文结构指南:从标题到结论的全景视角
一、背景 在快速发展的IT和人工智能领域,毕业论文不仅是学术研究的重要成果,也展示了学生掌握新技术和应用的能力。随着大数据和智能系统的复杂性增加,毕业设计(毕设)的论文章节安排变得尤为关键。一个结构清晰、内容详…...
leetcode27:移除元素(正解)
移除元素 给你一个数组 nums 和一个值 val,你需要 原地 移除所有数值等于 val 的元素。元素的顺序可能发生改变。然后返回 nums 中与 val 不同的元素的数量。 假设 nums 中不等于 val 的元素数量为 k,要通过此题,您需要执行以下操作…...
docker部署nginx--(部署静态文件和服务)
文档参考 1、http://testingpai.com/article/1649671014266 2、下载nginx docker pull nginx:alpine 然后启动nginx, docker run --rm -it -p 9192:80 nginx:alpine /bin/sh 关闭容器后,自动删除该容器 进入后,启动nginx, nginx进行curl h…...

websocket的介绍及springBoot集成示例
目录 一、什么是Websocket 二、Websocket特点 三、WebSocket与HTTP的区别 四、常见应用场景 五、SpringBoot集成WebSocket 1. 原生注解 2. Spring封装 一、什么是Websocket WebSocket 是一种在单个 TCP 连接上进行 全双工 通信的协议,它可以让客户端和服务器…...

软件测试-自动化测试
自动化测试 测试人员编写自动化测试脚本,维护并解决自动化脚本问题 自动化的主要目的就是用来进行回归测试 回归测试 常见面试题 ⾃动化测试能够取代人工测试吗? ⾃动化测试不⼀定⽐人工测试更能保障系统的可靠性,⾃动化测试是测试⼈员手…...

Linux 安装TELEPORT堡垒机
一、查看官方文档 堡垒机官网地址:走向成功 - Teleport,高效易用的堡垒机 (一)官网资源链接 -》Teleport 在线文档 (二)手动下载安装包 二、压缩包下载和安装 (一)加压下载的安装…...
【14】即时编译器的中间表达形式
中间表达形式(IR) 编译器一般被分为前端和后端。 前端会对输入的程序进行词法分析、语法分析和语义分析,然后生成中间表达形式(IR);后端对IR进行优化,生成目标代码 不考虑解释执行的话…...

Mysql(三)---增删查改(基础)
文章目录 前言1.补充1.修改表名1.2.修改列名1.3.修改列类型1.4.增加新列1.5.删除指定列 2.CRUD3.新增(Create)3.1.单行插入3.2.指定列插入3.3.多行插入 4.数据库的约束4.1.约束的分类4.2.NULL约束4.3.Unique约束4.4.Default 默认值约束4.5.PRIMARY KEY:主键约束4.6.…...
Dialog实现原理分析
在 Android 中,对话框(Dialog)是一种非常常见的用户界面组件,用于向用户提供额外的信息或者请求用户的确认。Android 提供了几种不同类型的对话框,例如简单的消息对话框 (AlertDialog)、进度条对话框 (ProgressDialog)…...

21.1 基于Netty实现聊天
21.1 基于Netty实现聊天 一. 章节概述二. `Netty`介绍三. 阻塞与非阻塞1. 阻塞与非阻塞简介2. BIO同步阻塞3. NIO同步非阻塞4. AIO异步非阻塞IO5. 异步阻塞IO(用的极少)6. 总结四. Netty三种线程模型1. 单线程模型2. 多线程模型3. 主从线程模型五. 构建Netty服务器************…...

尼卡音乐 v1.0.5 — 全新推出的免费音乐听歌软件
尼卡音乐是一款全新推出的免费音乐听歌软件,无需注册登录,打开即拥有全部功能。聚合了六大音源曲库、歌单、排行榜,支持在线试听、无损下载以及高清MV播放。资源全、无广告、更新快,适合寻找高品质音乐体验的用户。 拿走的麻烦评…...
Scratch深潜:解锁递归与分治算法的编程之门
亮眼标题:“Scratch深潜:解锁递归与分治算法的编程之门” 在编程的世界里,递归和分治算法是解决问题的强大工具。Scratch,这款广受儿童和初学者欢迎的图形化编程语言,以其独特的拖拽式编程块,激发了无数年…...
【1.0】vue3的创建
【1.0】vue3的创建 【一】vue3介绍 vue2的所有东西,vue3都兼容 vue3中写js代码由两种,组合式和配置项 配置项api,就是vue2的写法,将数据放进data,方法放进methods等 export default{data(){return {}},methods:…...
刷刷前端手写题
闭包用途 闭包 闭包让你可以在一个内层函数中访问到其外层函数的作用域 防抖 描述 前面所有触发都被取消,最后一次执行,在规定时间之后才会触发,也就是说如果连续快速的触发,用户操作频繁,但只会执行一次 。 常用场…...

论文解读:LONGWRITER: UNLEASHING 10,000+ WORD GENERATION FROM LONG CONTEXT LLMS
摘要 现象:当前的大预言模型可以接受超过100,000个tokens的输入,但是却难以生成超过2000个token的输出。 原因:监督微调过程(SFT)中看到的样本没有足够长的样本。 解决方法: Agent Write,可以将长任务分解为子任务&a…...

一文了解Ansible原理以及常见使用模块
ansible使用手册 1. 简述 Ansible 是一种开源的自动化工具,主要用于配置管理、应用程序部署和任务自动化。 它使用简单的 YAML 语言来定义自动化的任务【playbook】,使得配置和部署变得更加直观和易于管理。 基于SSH协议连接到远程主机来执行指令。 2…...

JavaEE从入门到起飞(九) ~Activiti 工作流
工作流 当一道流程逻辑需要用到多个表单的提交和多个角色的审核共同完成的时候,就可以使用工作流。 工作流一般使用的是第三方技术,也就是说别人帮你创建数据库表和service层、mapper层,你只需要注入工具接口即可使用。 原理:一…...

微服务的保护
一、雪崩问题及解决方案 1.雪崩问题 微服务之间,一个微服务依赖多个其他的微服务。当一个微服务A依赖的一个微服务B出错时,微服务A会被阻塞,但其他不依赖于B的微服务不会受影响。 当有多个微服务依赖于B时,服务器支持的线程和并…...
2024前端面试题-网络篇
1.跨域问题 同源策略:需要协议、域名、端口号相同跨域原因:不符合同源策略便会产生跨域问题解决跨域:JSONP、配置代理、通过CORS解决 2.RPC和HTTP的区别 主要区别是序列化和反序列化,RPC通过二进制高效传输,HTTP是j…...

Qt/C++开发监控GB28181系统/取流协议/同时支持udp/tcp被动/tcp主动
一、前言说明 在2011版本的gb28181协议中,拉取视频流只要求udp方式,从2016开始要求新增支持tcp被动和tcp主动两种方式,udp理论上会丢包的,所以实际使用过程可能会出现画面花屏的情况,而tcp肯定不丢包,起码…...
线程与协程
1. 线程与协程 1.1. “函数调用级别”的切换、上下文切换 1. 函数调用级别的切换 “函数调用级别的切换”是指:像函数调用/返回一样轻量地完成任务切换。 举例说明: 当你在程序中写一个函数调用: funcA() 然后 funcA 执行完后返回&…...

Android15默认授权浮窗权限
我们经常有那种需求,客户需要定制的apk集成在ROM中,并且默认授予其【显示在其他应用的上层】权限,也就是我们常说的浮窗权限,那么我们就可以通过以下方法在wms、ams等系统服务的systemReady()方法中调用即可实现预置应用默认授权浮…...

C# 求圆面积的程序(Program to find area of a circle)
给定半径r,求圆的面积。圆的面积应精确到小数点后5位。 例子: 输入:r 5 输出:78.53982 解释:由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982,因为我们只保留小数点后 5 位数字。 输…...
服务器--宝塔命令
一、宝塔面板安装命令 ⚠️ 必须使用 root 用户 或 sudo 权限执行! sudo su - 1. CentOS 系统: yum install -y wget && wget -O install.sh http://download.bt.cn/install/install_6.0.sh && sh install.sh2. Ubuntu / Debian 系统…...
return this;返回的是谁
一个审批系统的示例来演示责任链模式的实现。假设公司需要处理不同金额的采购申请,不同级别的经理有不同的审批权限: // 抽象处理者:审批者 abstract class Approver {protected Approver successor; // 下一个处理者// 设置下一个处理者pub…...

使用Spring AI和MCP协议构建图片搜索服务
目录 使用Spring AI和MCP协议构建图片搜索服务 引言 技术栈概览 项目架构设计 架构图 服务端开发 1. 创建Spring Boot项目 2. 实现图片搜索工具 3. 配置传输模式 Stdio模式(本地调用) SSE模式(远程调用) 4. 注册工具提…...
SpringAI实战:ChatModel智能对话全解
一、引言:Spring AI 与 Chat Model 的核心价值 🚀 在 Java 生态中集成大模型能力,Spring AI 提供了高效的解决方案 🤖。其中 Chat Model 作为核心交互组件,通过标准化接口简化了与大语言模型(LLM࿰…...

VisualXML全新升级 | 新增数据库编辑功能
VisualXML是一个功能强大的网络总线设计工具,专注于简化汽车电子系统中复杂的网络数据设计操作。它支持多种主流总线网络格式的数据编辑(如DBC、LDF、ARXML、HEX等),并能够基于Excel表格的方式生成和转换多种数据库文件。由此&…...

针对药品仓库的效期管理问题,如何利用WMS系统“破局”
案例: 某医药分销企业,主要经营各类药品的批发与零售。由于药品的特殊性,效期管理至关重要,但该企业一直面临效期问题的困扰。在未使用WMS系统之前,其药品入库、存储、出库等环节的效期管理主要依赖人工记录与检查。库…...