当前位置: 首页 > news >正文

3 Tensorflow构建模型详解

上一篇:2 用TensorFlow构建一个简单的神经网络-CSDN博客

本篇目标是介绍如何构建一个简单的线性回归模型,要点如下:

  • 了解神经网络原理
  • 构建模型的一般步骤
  • 模型重要参数介绍


1、神经网络概念

接上一篇,用tensorflow写了一个猜测西瓜价格的简单模型,理解代码前先了解下什么是神经网络。

下面是百度AI对神经网络的解释:

神经网络是一种运算模型,由大量的节点(或称神经元)之间相互联接构成,每个节点代表一种特定的输出函数,称为激励函数(activation function)。每两个节点间的连接都代表一个对于通过该连接信号的加权值,称之为权重,这相当于人工神经网络的记忆。网络的输出则依网络的连接方式,权重值和激励函数的不同而不同。而网络自身通常都是对自然界某种算法或者函数的逼近,也可能是对一种逻辑策略的表达。
神经网络是一种广泛并行互连的网络,它的组织能够模拟生物神经系统对真实世界物体所做出的交互反应。

首先我们要了解下密集层(也叫全连接层),密集层是一个深度连接的神经网络层,在神经网络中指的是每个神经元都与前一层的所有神经元相连的层。

在上一篇我们创建了预测价格模型,代码为:

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

其中Sequential是顺序的意思,Dense就是密集层。

看文字有点抽象,举个例子,如下图所示:神经元a1与所有输入层数据相连(X1,X2,X3),其他神经元也一样都与上一层神经元相连,这样形成的神经网络就是密集层。

它们之间的数学关系为:

某个神经元是由连接的上一层神经元分别乘上权重(w),再加上偏差(b)得到,例如计算a1:

权重w的数字下标可以按照顺序命名,比如第一个神经元计算的权重可以为w11、w12……,第二个神经元计算的权重可以为w21、w22……

a2、a3计算以此类推。

了解这些基本的原理后,我们就开始创建一个简单的费用预测模型。

2、西瓜费用预测模型详解

代码如下:

import numpy as np
import tensorflow as tf# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1])
])model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')history = model.fit(weight, total_cost, epochs=500)# 训练完成后,预测10斤西瓜的总费用
print(model.predict([10]))

上一篇西瓜费用计算公式 :费用=1.2元/斤*重量+0.5元

即:y=1.2x+0.5

这是一个一元线性回归问题,只有一个自变量x和一个因变量y,机器学习要推算出权重w=1.2, 偏差b=0.5,才能准确预测费用。

具体流程如下:

(1)训练数据准备

西瓜重量 weight=[1, 3, 4, 5, 6, 8]

对应的费用 total_cost=[1.7, 4.1, 5.3, 6.5, 7.7, 10.1]

(2)构建模型

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

  • tf.keras.layers.Dense(1, input_shape=[1]),参数1表示1个神经元,我们只要预测费用y,所以输出层只要一个神经元就可以了(注意:神经元不用包含输入层)。
  • input_shape=[1],表示输入数据的形状为单元素列表,即每个输入数据只有一个值。因为只有一个变量x(西瓜的重量),所以此处输入形状是[1]

该模型的示意图:

可以用model.summary()查看模型摘要,代码如下:

import numpy as np
import tensorflow as tf# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1])
])# 查看模型摘要
model.summary()

运行结果:

可以看到可训练参数有2个,即公式中的w1和b1。

(3)设置损失函数和优化器
model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')
  • mean_squared_error是均方误差,指的是预测值与真实值差值的平方然后求和再平均。公式为:

                    MSE=1/n Σ(P-G)^2 (P为预测值,G为真实值)

  • SGD即随机梯度下降(Stochastic Gradient Descent),是一种迭代优化算法。

(4)训练模型
history = model.fit(weight, total_cost, epochs=500)
  • 设置训练数据的特征和标签,在上述代码中分别是西瓜的重量和费用:weight、total_cost
  • 设置训练轮次epochs=500,1个epochs是指使用所有样本训练一次。

(5) 查看训练结果

看下面的训练过程,第8个epoch的时候损失值loss已经很小了,训练轮次不需要设置到500就可以有很好的预测效果了。

刚开始loss很高,使用优化算法慢慢调整了权重,loss值可以很好地衡量我们的模型有多好。

我们把epoch的值调小,看看程序猜测的权重(w)和偏差(b)是多少,以及loss值的计算。

 

代码改动如下:

  •  epochs=5
  • 用model.get_weights()获取程序猜测的权重数据
import numpy as np
import tensorflow as tf# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1])
])model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')history = model.fit(weight, total_cost, epochs=5)# 获取权重数据
w = model.get_weights()[0]
b = model.get_weights()[1]print('w:')
print(w)
print('b: ')
print(b)# 训练完成后,预测10斤西瓜的总费用
print(model.predict([10]))

运行结果:

训练了5个epoch后,程序猜测w是1.1807659,b为0.33192113

            y=wx+b=1.1807659*10+0.33192113=12.139581

所以预测10斤西瓜的总费用是12.139581

                 

3、创建更复杂一点的模型

现实生活中我们要预测的东西影响因素可能有很多个,如房价预测,房价可能受到房屋面积、房间数量等等因素影响。思考一下,下面的神经网络图创建模型时要如何设置参数呢?

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=[3]),tf.keras.layers.Dense(1)
])
  • 输入层有3个变量,input_shape=[3]
  • 隐藏层有2个神经元,所以 tf.keras.layers.Dense(2, input_shape=[3]) 的units设为2
  • 输出层只有1个神经元,所以 tf.keras.layers.Dense(1) 的units设为1
  • tf.keras.Sequential的‘Sequential’是顺序的意思,添加的这些layers就按顺序堆叠

         

相关文章:

3 Tensorflow构建模型详解

上一篇:2 用TensorFlow构建一个简单的神经网络-CSDN博客 本篇目标是介绍如何构建一个简单的线性回归模型,要点如下: 了解神经网络原理构建模型的一般步骤模型重要参数介绍 1、神经网络概念 接上一篇,用tensorflow写了一个猜测西…...

智慧农场牧场小程序源码 智慧农业认养系统源码

智慧农场牧场小程序源码 智慧农业认养系统源码 要了解源码的,看文末。 随着科技的进步和人们对绿色食品的需求增加,智慧农场正成为未来农业发展的方向。智慧农场是指运用先进的技术手段,如物联网、云计算、智能控制技术、大数据分析等&…...

3D数据过滤为2D数据集并渲染

开发环境: Windows 11 家庭中文版Microsoft Visual Studio Community 2019VTK-9.3.0.rc0vtk-example参考代码 代码逻辑:初始化数据集points -> 添加数据集到polydata -> 通过vtkVertexGlyphFilter过滤(带顶点、单元数据)po…...

第十一章 ObjectScript 系统宏(二)

文章目录 第十一章 ObjectScript 系统宏(二) 宏引用FormatText(text, arg1, arg2, ...)FormatTextHTML(text, arg1, arg2, ...)FormatTextJS(text, arg1, arg2, ...)GETERRORCODE(sc)GETERRORMESSAGE(sc,num)ISERR(sc)ISOK(sc)Text(text, domain, langua…...

跨境电商大作战:2023黑色星期五准备指南

黑色星期五,作为全球购物狂欢的象征,已经成为了电商业务的一年一度的重要节点。尤其对于跨境电商来说,这一天意味着巨大的商机和挑战。为了在这个竞争激烈的时刻脱颖而出,跨境电商必须做好充分的准备。Nox聚星在这里给大家分享几个…...

我的天!阿里云服务器居然比腾讯云优惠1元!

2023阿里云服务器优惠活动来了,以前一直是腾讯云比阿里云优惠,阿里云绝地反击,放开老用户购买资格,99元服务器老用户可以买,并且享受99元续费,阿腾云亲测可行,大家抓紧吧,数量不多&a…...

鸡尾酒学习——未命名(芒果口味)

1、材料:冰块、伏特加、芒果汁、元气森林卡曼橘味; 2、口感:芒果味道,酸甜为主,苦为辅。 3、视觉效果:黄色液体; 4、步骤: (1)向杯子中加入适量冰块&#xff…...

modbusTCP【C#】

为了编写一个完整的Modbus TCP库,您需要遵循以下步骤: 1. 安装NModbus4库:NModbus4是一个用于C#的Modbus库,它支持串口和TCP通信。您可以通过NuGet包管理器安装它。 2. 创建Modbus主机:使用ModbusIpMaster.CreateIp方…...

解决Linux Debian12系统中安装VirtualBox虚拟机无法使用USB设备的问题

Debian12系统中安装VirtualBox,再VirtualBox虚拟机中无法使用 USB设备。如下图所示: 解决方法如下: 1.安装 Virtualbox增强功能。如下图所示: 2.添加相关用户、用户组( Virtualbox 装完成后会有 vboxusers 和 vboxs…...

Spring事务失效的几种情况及其解决方案

Spring事务失效的几种情况及其解决方案 方法权限修饰符不是public Transactional 使用的是 Spring AOP 实现的,而 Spring AOP 是通过动态代理实现的,而 Transactional 在生成代理时会判断,如果方法为非 public 修饰的方法,则不生…...

libgdx实现淡入淡出过渡

libgdx实现淡入淡出过渡 libgdx实现淡入淡出过渡&#xff0c;环境jdk17、libgdx 1.12.02023年11月1日11:02:50最新 依赖 <properties><maven.compiler.source>17</maven.compiler.source><maven.compiler.target>17</maven.compiler.target>&…...

linux 出现Access-Your-Private-Data.desktop README.txt

参考:https://blog.csdn.net/h66295112/article/details/81085643 参考:https://askubuntu.com/questions/71708/how-do-i-open-access-your-private-data-desktop 原因应该是通过terminal修改了ubuntu密码&#xff0c;然后重启 THIS DIRECTORY HAS BEEN UNMOUNTED TO PROTECT…...

新生儿积食:原因、科普和注意事项

引言&#xff1a; 新生儿积食&#xff0c;也被称为新生儿喂养问题&#xff0c;是新父母常常面临的挑战之一。尽管它通常是一种暂时的问题&#xff0c;但它可能会引起婴儿的不适&#xff0c;导致家长感到担忧。本文将科普新生儿积食的原因&#xff0c;提供相关信息&#xff0c;…...

看完这个,别说你还找不到免费好用的配音软件

有很多小伙伴还在找配音工具&#xff0c;今天就给大家一次性分享四款免费好用的配音工具&#xff0c;每一个都经过测试&#xff0c;并且是我们自己也在用的免费配音工具 第一款&#xff0c;悦音配音工具 拥有强悍的AI智能配音技术&#xff0c;更专业&#xff0c;完美贴近真人配…...

多种方法解决leetcode经典题目-LCR 155. 将二叉搜索树转化为排序的双向链表, 同时弄透引用变更带来的bug

1 描述 2 解法一: 使用list列表粗出中序遍历的结果&#xff0c;然后再依次处理list中的元素并且双向链接 public Node treeToDoublyList2(Node root) {if(rootnull)return root;Node dummynew Node(-10000);List<Node>ansnew ArrayList<>();dfs2(root,ans);Node p…...

C/C++ 实现UDP发送或接收组播消息,并可指定接收发送网卡

一、发送端代码 #include <iostream> #include <unistd.h> #include <stdio.h> #include <string.h> #include <net/if.h> #include <netinet/in.h> #include <netdb.h> #include <sys/ioctl.h> #include "UDPOperation…...

纬创出售印度子公司给塔塔集团,结束iPhone代工业务 | 百能云芯

纬创&#xff08;Wistron&#xff09;董事会于10月27日通过决议&#xff0c;同意以1.25亿美元的价格出售其印度子公司Wistron InfoComm Manufacturing (India) Private Limited&#xff08;WMMI&#xff09;的100%股权给塔塔集团&#xff0c;交割将尽快完成。此举将意味着纬创退…...

vue手机项目如何控制手电筒打开与关闭

要控制手电筒&#xff0c;您可以使用Vue的Device API&#xff0c;例如cordova-plugin-flashlight或vue-native-flashlight插件。以下是一些基本步骤&#xff1a; 导入手电筒插件或库。在Vue组件中创建一个手电筒对象并初始化它。使用turnOn()和turnOff()方法控制手电筒。 以下…...

电商课堂|5分钟了解电商数据分析完整流程,建议收藏!

账户效果下降&#xff0c;如何能够快速找到问题并优化调整&#xff1f; 相信百分之90%的竞价员都会说&#xff1a;“做数据分析。” 没错&#xff0c;数据分析能够帮助我们快速锁定问题所在&#xff0c;确定优化方向&#xff0c;还可以帮助我们找到流量控制的方向。那么做电商&…...

Redis测试新手入门教程

在测试过程中&#xff0c;我们或多或少会接触到Redis&#xff0c;今天就把在小破站看到的三丰老师课程&#xff0c;把笔记整理了下&#xff0c;用来备忘&#xff0c;也希望能给大家带来亿点点收获。 主要分为两个部分&#xff1a; 一、缓存技术在后端架构中是如何应用的&#…...

浏览器访问 AWS ECS 上部署的 Docker 容器(监听 80 端口)

✅ 一、ECS 服务配置 Dockerfile 确保监听 80 端口 EXPOSE 80 CMD ["nginx", "-g", "daemon off;"]或 EXPOSE 80 CMD ["python3", "-m", "http.server", "80"]任务定义&#xff08;Task Definition&…...

Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以?

Golang 面试经典题&#xff1a;map 的 key 可以是什么类型&#xff1f;哪些不可以&#xff1f; 在 Golang 的面试中&#xff0c;map 类型的使用是一个常见的考点&#xff0c;其中对 key 类型的合法性 是一道常被提及的基础却很容易被忽视的问题。本文将带你深入理解 Golang 中…...

FastAPI 教程:从入门到实践

FastAPI 是一个现代、快速&#xff08;高性能&#xff09;的 Web 框架&#xff0c;用于构建 API&#xff0c;支持 Python 3.6。它基于标准 Python 类型提示&#xff0c;易于学习且功能强大。以下是一个完整的 FastAPI 入门教程&#xff0c;涵盖从环境搭建到创建并运行一个简单的…...

2021-03-15 iview一些问题

1.iview 在使用tree组件时&#xff0c;发现没有set类的方法&#xff0c;只有get&#xff0c;那么要改变tree值&#xff0c;只能遍历treeData&#xff0c;递归修改treeData的checked&#xff0c;发现无法更改&#xff0c;原因在于check模式下&#xff0c;子元素的勾选状态跟父节…...

TRS收益互换:跨境资本流动的金融创新工具与系统化解决方案

一、TRS收益互换的本质与业务逻辑 &#xff08;一&#xff09;概念解析 TRS&#xff08;Total Return Swap&#xff09;收益互换是一种金融衍生工具&#xff0c;指交易双方约定在未来一定期限内&#xff0c;基于特定资产或指数的表现进行现金流交换的协议。其核心特征包括&am…...

从零实现STL哈希容器:unordered_map/unordered_set封装详解

本篇文章是对C学习的STL哈希容器自主实现部分的学习分享 希望也能为你带来些帮助~ 那咱们废话不多说&#xff0c;直接开始吧&#xff01; 一、源码结构分析 1. SGISTL30实现剖析 // hash_set核心结构 template <class Value, class HashFcn, ...> class hash_set {ty…...

智能仓储的未来:自动化、AI与数据分析如何重塑物流中心

当仓库学会“思考”&#xff0c;物流的终极形态正在诞生 想象这样的场景&#xff1a; 凌晨3点&#xff0c;某物流中心灯火通明却空无一人。AGV机器人集群根据实时订单动态规划路径&#xff1b;AI视觉系统在0.1秒内扫描包裹信息&#xff1b;数字孪生平台正模拟次日峰值流量压力…...

全志A40i android7.1 调试信息打印串口由uart0改为uart3

一&#xff0c;概述 1. 目的 将调试信息打印串口由uart0改为uart3。 2. 版本信息 Uboot版本&#xff1a;2014.07&#xff1b; Kernel版本&#xff1a;Linux-3.10&#xff1b; 二&#xff0c;Uboot 1. sys_config.fex改动 使能uart3(TX:PH00 RX:PH01)&#xff0c;并让boo…...

html css js网页制作成品——HTML+CSS榴莲商城网页设计(4页)附源码

目录 一、&#x1f468;‍&#x1f393;网站题目 二、✍️网站描述 三、&#x1f4da;网站介绍 四、&#x1f310;网站效果 五、&#x1fa93; 代码实现 &#x1f9f1;HTML 六、&#x1f947; 如何让学习不再盲目 七、&#x1f381;更多干货 一、&#x1f468;‍&#x1f…...

C++:多态机制详解

目录 一. 多态的概念 1.静态多态&#xff08;编译时多态&#xff09; 二.动态多态的定义及实现 1.多态的构成条件 2.虚函数 3.虚函数的重写/覆盖 4.虚函数重写的一些其他问题 1&#xff09;.协变 2&#xff09;.析构函数的重写 5.override 和 final关键字 1&#…...