当前位置: 首页 > article >正文

python3+TensorFlow 2.x(二) 回归模型

目录

回归算法

1、线性回归 (Linear Regression)

一元线性回归举例

 2、非线性回归

3、回归分类 


回归算法

回归算法用于预测连续的数值输出。回归分析的目标是建立一个模型,以便根据输入特征预测目标变量,在使用 TensorFlow 2.x 实现线性回归模型时,通常的步骤包括数据预处理、模型构建、训练和评估。

1、线性回归 (Linear Regression)

概述:线性回归是最基本的回归算法之一,假设目标变量与输入特征之间存在线性关系。

模型形式:y=\beta _{0}+\beta _{1}x_{1}+...+\beta _{n}x_{n}+ϵ,其中 y 是目标变量,x​ 是特征,βi是权重,ϵ 是误差项。

一元线性回归举例

实现步骤

导入必要的库
生成或加载数据预处理:使用生成的线性数据集。生成了一个简单的线性关系 y = 2x + 1,并加上了一些噪声来模拟实际的观测数据。np.linspace 生成 100 个从 0 到 10 的点,np.random.normal 用于生成随机噪声。数据处理:使用 X.reshape(-1, 1) 将 X 变成二维数组,以适应 TensorFlow 的输入要求。

构建线性回归模型:使用 tf.keras.Sequential 创建一个简单的线性模型。只使用一个 Dense 层来表示线性回归,其中 input_dim=1 指明输入特征的维度为 1,output_dim=1 表示输出只有一个预测值。        

编译模型:设置损失函数和优化器。使用了 adam 优化器,这是一个常用且效果不错的优化器。损失函数选择 mean_squared_error,这是回归问题中常见的损失函数。

训练模型:使用训练数据来训练模型。model.fit 方法用于训练模型。设置了 200 个 epoch 和10 的批次大小。

评估模型:通过测试数据评估模型性能。model.evaluate 会返回训练集的损失值,用来评估训练过程中的效果.

预测结果:使用训练好的模型进行预测。使用 matplotlib 绘制训练过程中每个 epoch 的损失变化情况,以便观察模型训练的收敛过程,使用 model.predict 来预测训练集上的输出,然后将预测结果与真实数据一起绘制出来,查看模型的拟合效果

import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt# 1. 生成数据:y = 2x + 1
np.random.seed(42)
X = np.linspace(0, 10, 100)  # 生成100个点,范围是[0, 10]
Y = 2 * X + 1 + np.random.normal(0, 1, X.shape[0])  # y = 2x + 1,加上一些噪声# 2. 数据预处理:将数据转化为TensorFlow的张量(也可以直接使用NumPy数组)
X_train = X.reshape(-1, 1)  # 特征,转换成二维数组
Y_train = Y.reshape(-1, 1)  # 标签,转换成二维数组# 3. 构建线性回归模型
model = keras.Sequential([keras.layers.Dense(1, input_dim=1)  # 只有一个输入特征,输出一个值
])# 4. 编译模型:选择损失函数和优化器
model.compile(optimizer='adam', loss='mean_squared_error')# 5. 训练模型
history = model.fit(X_train, Y_train, epochs=200, batch_size=10, verbose=0)# 6. 评估模型
loss = model.evaluate(X_train, Y_train)
print(f"Final training loss: {loss}")# 7. 绘制训练过程中的损失变化
plt.plot(history.history['loss'])
plt.title('Training Loss Over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()# 8. 预测结果
Y_pred = model.predict(X_train)# 9. 可视化真实数据和预测结果
plt.scatter(X_train, Y_train, color='blue', label='Actual Data')
plt.plot(X_train, Y_pred, color='red', label='Predicted Line')
plt.title('Linear Regression with TensorFlow 2')
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.show()

 2、非线性回归

创建合成数据集:使用 NumPy 生成从 -3 到 3 的 100 个点,并计算对应的 y 值为sin(x) 加上一些噪声。

划分训练集和测试集:使用 train_test_split 将数据集划分为训练集和测试集,比例为 80% 训练,20% 测试。

构建曲线拟合模型:使用 tf.keras.Sequential 创建一个简单的神经网络模型,包含两个隐藏层,每层有 64 个神经元,激活函数为 ReLU,最后一层为输出层。

编译模型:使用 Adam 优化器和均方误差损失函数编译模型。

训练模型:使用 fit 方法训练模型,设置训练轮数为 200,批次大小为 10。

进行预测:使用 predict 方法对测试集进行预测。可视化预测结果,使用 Matplotlib 绘制实际值和预测值的散点图。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split# 1. 创建合成数据集
np.random.seed(0)
X = np.linspace(-3, 3, 100)  # 生成从-3到3的100个点
y = np.sin(X) + np.random.normal(0, 0.1, X.shape)  # y = sin(x) + 噪声# 2. 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 3. 构建曲线拟合模型
# 将输入数据转换为二维数组
X_train = X_train.reshape(-1, 1)
X_test = X_test.reshape(-1, 1)model = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu', input_shape=(1,)),  # 隐藏层tf.keras.layers.Dense(64, activation='relu'),  # 隐藏层tf.keras.layers.Dense(1)  # 输出层
])# 4. 编译模型
model.compile(optimizer='adam', loss='mean_squared_error')# 5. 训练模型
model.fit(X_train, y_train, epochs=200, batch_size=10, verbose=0)# 6. 进行预测
predictions = model.predict(X_test)# 7. 可视化预测结果
plt.figure(figsize=(10, 6))
plt.scatter(X_test, y_test, color='blue', label='Actual Values')  # 实际值
plt.scatter(X_test, predictions, color='red', label='Predicted Values')  # 预测值
plt.title('Curve Fitting Regression')
plt.xlabel('X')
plt.ylabel('y')
plt.legend()
plt.grid()
plt.show()

3、回归分类 

import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification, make_regression
import matplotlib.pyplot as plt# 生成回归数据集
X_reg, y_reg = make_regression(n_samples=1000, n_features=1, noise=10, random_state=42)# 生成分类数据集
X_class, y_class = make_classification(n_samples=1000, n_features=2, n_informative=2, n_redundant=0, n_clusters_per_class=1, random_state=42)# 划分训练集和测试集
X_reg_train, X_reg_test, y_reg_train, y_reg_test = train_test_split(X_reg, y_reg, test_size=0.2, random_state=42)
X_class_train, X_class_test, y_class_train, y_class_test = train_test_split(X_class, y_class, test_size=0.2, random_state=42)# 创建线性回归模型
model_reg = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))  # 输入特征为 1,输出为 1
])# 编译模型
model_reg.compile(optimizer='adam', loss='mean_squared_error')# 训练模型
model_reg.fit(X_reg_train, y_reg_train, epochs=100, batch_size=32, verbose=1)# 评估模型
loss_reg = model_reg.evaluate(X_reg_test, y_reg_test, verbose=0)
print(f'回归模型测试集损失: {loss_reg:.4f}')# 可视化回归结果
plt.scatter(X_reg, y_reg, color='blue', label='Data points')
plt.scatter(X_reg_test, model_reg.predict(X_reg_test), color='red', label='Predictions')
plt.title('Linear Regression Results')
plt.xlabel('Feature')
plt.ylabel('Target')
plt.legend()
plt.show()# 创建逻辑回归模型
model_class = tf.keras.Sequential([tf.keras.layers.Dense(1, activation='sigmoid', input_shape=(2,))  # 输入特征为 2,输出为 1
])# 编译模型
model_class.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])# 训练模型
model_class.fit(X_class_train, y_class_train, epochs=100, batch_size=32, verbose=1)# 评估模型
loss_class, accuracy_class = model_class.evaluate(X_class_test, y_class_test, verbose=0)
print(f'分类模型测试集损失: {loss_class:.4f}, 测试集准确率: {accuracy_class:.4f}')# 可视化分类数据点
plt.scatter(X_class_train[y_class_train == 0][:, 0], X_class_train[y_class_train == 0][:, 1], color='blue', label='Class 0', alpha=0.5)
plt.scatter(X_class_train[y_class_train == 1][:, 0], X_class_train[y_class_train == 1][:, 1], color='red', label='Class 1', alpha=0.5)# 绘制决策边界
x_min, x_max = X_class[:, 0].min() - 1, X_class[:, 0].max() + 1
y_min, y_max = X_class[:, 1].min() - 1, X_class[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01), np.arange(y_min, y_max, 0.01))
Z = model_class.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)plt.contourf(xx, yy, Z, levels=[0, 0.5, 1], alpha=0.2, colors=['blue', 'red'])
plt.title('Logistic Regression Decision Boundary')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.legend()
plt.show()

make_classification 是 sklearn.datasets 模块中的一个函数,用于生成用于分类的合成数据集。可以通过不同的参数来控制生成数据的特性。

参数解释:n_samples: 生成的样本数量。
n_features: 特征的总数。设置为 2,表示每个样本有 2 个特征。
n_informative: 有效特征的数量,这些特征对分类任务有贡献。设置为 2,表示所有特征都是有效特征。
n_redundant: 冗余特征的数量,这些特征是通过线性组合生成的有效特征。设置为 0,表示没有冗余特征。
n_clusters_per_class: 每个类别的聚类数量。设置为 1,表示每个类别只有一个聚类。
random_state: 随机种子,用于确保结果的可重复性。设置为 42。

make_regression 是 sklearn.datasets 模块中的一个函数,用于生成用于回归的合成数据集。

参数解释:n_samples: 生成的样本数量。
n_features: 特征的总数。比如设置为 1,表示每个样本有 1 个特征。
n_informative: 有效特征的数量,这些特征对目标变量有贡献。比如设置为 1,表示所有特征都是有效特征。
n_targets: 目标变量的数量。默认值为 1,表示生成一个目标变量。
bias: 截距项,表示模型的偏置。可以设置为一个常数,比如 0。
noise: 添加到输出中的噪声的标准差。可以设置为一个浮点数,如 0.1表示添加一定的随机噪声。
random_state: 随机种子,用于确保结果的可重复性。可以设置为一个整数,比如 42。

相关文章:

python3+TensorFlow 2.x(二) 回归模型

目录 回归算法 1、线性回归 (Linear Regression) 一元线性回归举例 2、非线性回归 3、回归分类 回归算法 回归算法用于预测连续的数值输出。回归分析的目标是建立一个模型,以便根据输入特征预测目标变量,在使用 TensorFlow 2.x 实现线性回归模型时&…...

cpp智能指针

普通指针的不足 new和new[]的内存需要用delete和deletel]释放。 程序员的主观失误,忘了或漏了释放。 程序员也不确定何时释放。 普通指针的释放 类内的指针,在析构函数中释放。 C内置数据类型,如何释放? new出来的类,本身如…...

Android --- CameraX讲解

预备知识 surface surfaceView SurfaceHolder surface 是什么? 一句话来说: surface是一块用于填充图像数据的内存。 surfaceView 是什么? 它是一个显示surface 的View。 在app中仍在 ViewHierachy 中,但在wms 中可以理解为…...

CentOS7非root用户离线安装Docker及常见问题总结、各种操作系统docker桌面程序下载地址

环境说明 1、安装用户有sudo权限 2、本文讲docker组件安装,不是桌面程序安装 3、本文讲离线安装,不是在线安装 4、目标机器是内网机器,与外部网络不连通 下载 1、下载离线安装包,并上传到$HOME/basic-tool 目录 下载地址&am…...

前端面试笔试题目(一)

以下模拟了大厂前端面试流程,并给出了涵盖HTML、CSS、JavaScript等基础和进阶知识的前端笔试题目,以帮助你更好地准备面试。 面试流程模拟 1. 自我介绍(5 - 10分钟):面试官会请你进行简单的自我介绍,包括…...

笔记本搭配显示器

笔记本:2022款拯救者Y9000P,显卡RTX3060,分辨率2560*1600,刷新率:165Hz,无DP1.4口 显示器:2024款R27Q,27存,分辨率2560*1600,刷新率:165Hz &…...

设计转换Apache Hive的HQL语句为Snowflake SQL语句的Python程序方法

首先,根据以下各类HQL语句的基本实例和官方文档记录的这些命令语句各种参数设置,得到各种HQL语句的完整实例,然后在Snowflake的官方文档找到它们对应的Snowflake SQL语句,建立起对应的关系表。在这个过程中要注意HQL语句和Snowfla…...

DeepSeek R1 linux云部署

云平台:AutoDL 模型加载工具:Ollama 参考:https://github.com/ollama/ollama/blob/main/docs/linux.md 下载Ollama 服务器上下载ollama比较慢,因此我使用浏览器先下载到本地电脑上。 https://ollama.com/download/ollama-linux…...

【multi-agent-system】ubuntu24.04 安装uv python包管理器及安装依赖

uv包管理器是跨平台的 参考sudo apt-get update sudo apt-get install -y build-essential我的开发环境是ubuntu24.04 (base) root@k8s-master-pfsrv:/home/zhangbin/perfwork/01_ai/08_multi-agent-system# uv venv 找不到命令 “uv”,但可以通过以下软件...

UE5.3 C++ CDO的初步理解

一.UObject UObject是所有对象的基类,往上还有UObjectBaseUtility。 注释:所有虚幻引擎对象的基类。对象的类型由基于 UClass 类来定义。 这为创建和使用UObject的对象提供了 函数,并且提供了应在子类中重写的虚函数。 /** * The base cla…...

数学平均数应用

给定一个长度为 n 的数组 a。在一次操作中,你可以从索引 2 到 n−1中选择一个索引i,然后执行以下两个操作之一: 将 a[i−1] 减少 1,同时将 a[i1] 增加 1。 将 a[i1] 减少 1,同时将 a[i−1] 增加 1。 在每次操作后&…...

在排序数组中查找元素的第一个和最后一个位置(力扣)

一.题目介绍 二.题目解析 使用二分进行查找 2.1处理边界情况 如果数组为空,直接返回 [-1, -1],因为无法找到目标值。 int[] ret new int[2]; ret[0] ret[1] -1; if (nums.length 0) return ret; 2.2查找左端点(目标值开始位置&#…...

Native Memory Tracking 与 RSS的差异问题

一 问题现象 前一段时间用nmt查看jvm进程的栈区占用的内存大小。测试代码如下 public class ThreadOOM {public static void main(String[] args) {int i 1;while (i < 3000) {Thread thread new TestThread();thread.start();System.out.println("thread : "…...

完美世界前端面试题及参考答案

如何设置事件捕获和事件冒泡? 在 JavaScript 中,可以通过addEventListener方法来设置事件捕获和事件冒泡。该方法接收三个参数,第一个参数是事件类型,如click、mousedown等;第二个参数是事件处理函数;第三个参数是一个布尔值,用于指定是否使用事件捕获机制。当这个布尔值…...

知识库管理如何推动企业数字化转型与创新发展的深层次探索

内容概要 在当今数字化转型的大背景下&#xff0c;知识库管理日益显现出其作为企业创新发展的核心驱动力的潜力。这种管理方式不仅仅是对信息的存储与检索&#xff0c;更是一种赋能&#xff0c;以提升决策效率和员工创造力。企业能够通过系统地整合和管理各类知识资源&#xf…...

《DeepSeek 网页/API 性能异常(DeepSeek Web/API Degraded Performance):网络安全日志》

DeepSeek 网页/API 性能异常&#xff08;DeepSeek Web/API Degraded Performance&#xff09;订阅 已识别 - 已识别问题&#xff0c;并且正在实施修复。 1月 29&#xff0c; 2025 - 20&#xff1a;57 CST 更新 - 我们将继续监控任何其他问题。 1月 28&#xff0c; 2025 - 22&am…...

【性能优化专题系列】利用CompletableFuture优化多接口调用场景下的性能

背景说明 在实际的软件开发中&#xff0c;我们经常会遇到需要批量调用接口的场景。例如&#xff0c;电商系统在生成商品详情页时&#xff0c;需要同时调用多个服务接口来获取商品的基本信息、库存信息、价格信息、用户评价等。 传统的依次调用方式存在性能问题 面对上述场景…...

DeepSeek-R1本地部署笔记

文章目录 效果概要下载 ollama终端下载模型【可选】浏览器插件 UIQ: 内存占用高&#xff0c;显存占用不高&#xff0c;正常吗 效果 我的配置如下 E5 2666 V3 AMD 590Gme 可以说是慢的一批了&#xff0c;内存和显卡都太垃圾了&#xff0c;回去用我的新设备再试试 概要 安装…...

鸿蒙开发在onPageShow中数据加载不完整的问题分析与解决

API Version 12 1、onPageShow()作什么的 首先说明下几个前端接口的区别&#xff1a; ArkUI-X的aboutToAppear()接口是一个生命周期接口&#xff0c;用于在页面即将显示之前调用。 在ArkUI-X中&#xff0c;aboutToAppear()接口是一个重要的生命周期接口&#xff0c;它会在页…...

Kadane 算法

Kadane 算法 Kadane 算法用于解决最大子数组和问题,即在一个整数数组中找到具有最大和的连续子数组。此算法基于动态规划思想,在一次遍历过程中完成计算。 动态规划思路 核心在于维护两个变量:currentMax 表示当前子数组的最大和;globalMax 保存迄今为止发现的最大子数组…...

在Ubuntu子系统中基于Nginx部署Typecho

下载部署程序 typecho上传文件到子系统 创建文件夹typecho 在目录/var/www/html中创建一个目录typecho cd /var/www/html mkdir typecho将文件typecho.zip上传至新建的目录下&#xff0c;并解压文件 unzip typecho.zip授权文件夹 sudo chown -R www-data:www-data /var/www…...

C++中常用的十大排序方法之1——冒泡排序

成长路上不孤单&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a; 【&#x1f60a;///计算机爱好者&#x1f60a;///持续分享所学&#x1f60a;///如有需要欢迎收藏转发///&#x1f60a;】 今日分享关于C中常用的排序方法之——冒泡排序的相关…...

不只是mini-react第二节:实现最简fiber

省流|总结 首先&#xff0c;我们编写JSX文件&#xff0c;并通过Babel等转换工具将其转化为createElement()函数的调用&#xff0c;最终生成虚拟 DOM&#xff08;Vdom&#xff09;格式。举个例子&#xff1a; // 原始 JSX const App <div>hi-mini-react</div>;//…...

python 使用Whisper模型进行语音翻译

目录 一、Whisper 是什么? 二、Whisper 的基本命令行用法 三、代码实践 四、是否保留Token标记 五、翻译长度问题 六、性能分析 一、Whisper 是什么? Whisper 是由 OpenAI 开源的一个自动语音识别(Automatic Speech Recognition, ASR)系统。它的主要特点是: 多语言…...

priority_queue的创建_结构体类型(重载小于运算符)c++

当优先级队列里面存的是一个自定义&#xff08;结构体&#xff09;类型&#xff0c;我们有两种方式&#xff0c;一个是用内置类型的方式&#xff0c;在priority_queue<>里写三个参数&#xff0c;比如int, vector<int>, less<int>&#xff0c;把int改成结构体…...

数据结构实战之线性表(一)

一.线性表的定义和特点 线性表的定义 线性表是一种数据结构&#xff0c;它包含了一系列具有相同特性的数据元素&#xff0c;数据元素之间存在着顺序关系。例如&#xff0c;26个英文字母的字符表 ( (A, B, C, ....., Z) ) 就是一个线性表&#xff0c;其中每个字母就是一个数据…...

Python学习之旅:进阶阶段(七)数据结构-计数器(collections.Counter)

在 Python 编程的进阶学习中,数据处理是一项重要的任务。collections.Counter作为 Python 标准库collections模块中的一员,为我们提供了一种高效且便捷的方式来统计数据出现的次数。接下来,就让我们一起深入了解这个强大的计数器。 一、什么是计数器 collections.Counter本…...

Spring Boot项目如何使用MyBatis实现分页查询及其相关原理

写在前面&#xff1a;大家好&#xff01;我是晴空๓。如果博客中有不足或者的错误的地方欢迎在评论区或者私信我指正&#xff0c;感谢大家的不吝赐教。我的唯一博客更新地址是&#xff1a;https://ac-fun.blog.csdn.net/。非常感谢大家的支持。一起加油&#xff0c;冲鸭&#x…...

【项目初始化】

项目初始化 使用脚手架创建项目Vite创建项目推荐拓展 使用脚手架创建项目 Vite Vite 是一个现代的前端构建工具&#xff0c;它提供了极速的更新和开发体验&#xff0c;支持多种前端框架&#xff0c;如 Vue、React 等创建项目 pnpm create vuelatest推荐拓展...

LeetCode热题100(八)—— 438.找到字符串中所有字母异位词

LeetCode热题100&#xff08;八&#xff09;—— 438.找到字符串中所有字母异位词 题目描述代码实现思路解析 你好&#xff0c;我是杨十一&#xff0c;一名热爱健身的程序员在Coding的征程中&#xff0c;不断探索与成长LeetCode热题100——刷题记录&#xff08;不定期更新&…...