02 TensorFlow 2.0:前向传播之张量实战

2024-06-23 21:48

本文主要是介绍02 TensorFlow 2.0:前向传播之张量实战,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

你是前世未止的心跳
你是来生胸前的记号
未见分晓
怎么把你忘掉
                                                                                                                                《千年》

内容覆盖:

  • convert to tensor
  • reshape
  • slices
  • broadcast (mechanism)
import tensorflow as tf
print(tf.__version__)import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'import warnings
warnings.filterwarnings('ignore')from tensorflow import keras
from tensorflow.keras import datasets
2.0.0-alpha0

1. global constants setting

lr = 1e-3
epochs = 10

2. load data and tensor object 0-1

## load mnist data
# x: [6w, 28, 28]
# y: [6w]
(x,y),_ = datasets.mnist.load_data()
## x: 0-255. => 0-1.
x = tf.convert_to_tensor(x, dtype=tf.float32)/255.
y = tf.convert_to_tensor(y, dtype=tf.int32)
print(x.shape, y.shape)
print(tf.reduce_max(x), tf.reduce_min(x))
print(tf.reduce_max(y), tf.reduce_min(y))
(60000, 28, 28) (60000,)
tf.Tensor(1.0, shape=(), dtype=float32) tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(9, shape=(), dtype=int32) tf.Tensor(0, shape=(), dtype=int32)

3. split batch

## split batches
# x: [128, 28, 28]
# y: [128, 28, 28]
train_db = tf.data.Dataset.from_tensor_slices((x, y)).batch(128)
train_iter_ = iter(train_db)
sample_ = next(train_iter_)
print('first batch & next batch:', sample_[0].shape, len(sample), sample_[1])
first batch & next batch: (96, 784) 2 tf.Tensor( [3 4 5 6 7 8 9 0 1 2 3 4 8 9 0 1 2 3 4 5 6 7 8 9 6 0 3 4 1 4 0 7 8 7 7 9 0 4 9 4 0 5 8 5 9 8 8 4 0 7 1 3 5 3 1 6 5 3 8 7 3 1 6 8 5 9 2 2 0 9 2 4 6 7 3 1 3 6 6 2 1 2 6 0 7 8 9 2 9 5 1 8 3 5 6 8], shape=(96,), dtype=int32)

4. parameters init

## parameters init. in order to adapt below GradientTape(),parameters must to be tf.Variable
w1 = tf.Variable(tf.random.truncated_normal([28*28, 256], stddev=0.1)) # truncated normal init
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]))

5. compute(update) loss&gradient for each epoch&batch

## for each epoch
for epoch in range(epochs):## for each batchfor step, (x, y) in enumerate(train_db): # x: [b, 28, 28] => [b, 28*28]x = tf.reshape(x, [-1, 28*28])## compute forward output for each batchwith tf.GradientTape() as tape: # GradientTape below parameters must be tf.Variable# print(x.shape, w1.shape, b1.shape)h1 = x@w1 + b1 # implicitly,b1 ([256]) broadcast_to [b,256]h1 = tf.nn.relu(h1)h2 = h1@w2 + b2 # like aboveh2 = tf.nn.relu(h2)h3 = h2@w3 + b3 # like aboveout = tf.nn.relu(h3)## copute lossy_onehot = tf.one_hot(y, depth=10)loss = tf.reduce_mean(tf.square(y_onehot - out)) # loss is scalar## compute gradientsgrads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])# update parametersw1.assign_sub(lr*grads[0])b1.assign_sub(lr*grads[1])w2.assign_sub(lr*grads[2])b2.assign_sub(lr*grads[3])w3.assign_sub(lr*grads[4])b3.assign_sub(lr*grads[5])if step%100==0:print('epoch/step:', epoch, step,'loss:', float(loss))
epoch/step: 0 0 loss: 0.18603835999965668
epoch/step: 0 100 loss: 0.13570542633533478
epoch/step: 0 200 loss: 0.11861399561166763
epoch/step: 0 300 loss: 0.11322200298309326
epoch/step: 0 400 loss: 0.10488209873437881
epoch/step: 1 0 loss: 0.10238083451986313
epoch/step: 1 100 loss: 0.10504438728094101
epoch/step: 1 200 loss: 0.10291490703821182
epoch/step: 1 300 loss: 0.10242557525634766
epoch/step: 1 400 loss: 0.09785071760416031
epoch/step: 2 0 loss: 0.09843370318412781
epoch/step: 2 100 loss: 0.10121582448482513
epoch/step: 2 200 loss: 0.0993235856294632
epoch/step: 2 300 loss: 0.09929462522268295
epoch/step: 2 400 loss: 0.09492874145507812
epoch/step: 3 0 loss: 0.09640722721815109
epoch/step: 3 100 loss: 0.09940245747566223
epoch/step: 3 200 loss: 0.0968528538942337
epoch/step: 3 300 loss: 0.09739632904529572
epoch/step: 3 400 loss: 0.09268360584974289
epoch/step: 4 0 loss: 0.09469369798898697
epoch/step: 4 100 loss: 0.09802170842885971
epoch/step: 4 200 loss: 0.09442965686321259
epoch/step: 4 300 loss: 0.09557832777500153
epoch/step: 4 400 loss: 0.09028112888336182
epoch/step: 5 0 loss: 0.09288302809000015
epoch/step: 5 100 loss: 0.09671110659837723
epoch/step: 5 200 loss: 0.09200755506753922
epoch/step: 5 300 loss: 0.09379477798938751
epoch/step: 5 400 loss: 0.0879468247294426
epoch/step: 6 0 loss: 0.09075240045785904
epoch/step: 6 100 loss: 0.09545578807592392
epoch/step: 6 200 loss: 0.08961271494626999
epoch/step: 6 300 loss: 0.09208488464355469
epoch/step: 6 400 loss: 0.08578769862651825
epoch/step: 7 0 loss: 0.08858789503574371
epoch/step: 7 100 loss: 0.09415780007839203
epoch/step: 7 200 loss: 0.08701150119304657
epoch/step: 7 300 loss: 0.09043200314044952
epoch/step: 7 400 loss: 0.08375751972198486
epoch/step: 8 0 loss: 0.08612515032291412
epoch/step: 8 100 loss: 0.09273834526538849
epoch/step: 8 200 loss: 0.08432737737894058
epoch/step: 8 300 loss: 0.08866600692272186
epoch/step: 8 400 loss: 0.08179832994937897
epoch/step: 9 0 loss: 0.08383172750473022
epoch/step: 9 100 loss: 0.09108485281467438
epoch/step: 9 200 loss: 0.08158060908317566
epoch/step: 9 300 loss: 0.08686531335115433
epoch/step: 9 400 loss: 0.0796399861574173

6. notice

  • 训练出来loss为nan或者不变等情况
    可能出现梯度爆炸等情况,这里可能需要 change parameter init等,比如这里利用 m u = 0 , s t d = 0.1 mu=0, std=0.1 mu=0,std=0.1截尾normal初始化权重 w w w
    参见一些解释:为什么用tensorflow训练网络,出现了loss=nan,accuracy总是一个固定值?

这篇关于02 TensorFlow 2.0:前向传播之张量实战的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1088366

相关文章

Spring事务传播机制最佳实践

《Spring事务传播机制最佳实践》Spring的事务传播机制为我们提供了优雅的解决方案,本文将带您深入理解这一机制,掌握不同场景下的最佳实践,感兴趣的朋友一起看看吧... 目录1. 什么是事务传播行为2. Spring支持的七种事务传播行为2.1 REQUIRED(默认)2.2 SUPPORTS2

从原理到实战深入理解Java 断言assert

《从原理到实战深入理解Java断言assert》本文深入解析Java断言机制,涵盖语法、工作原理、启用方式及与异常的区别,推荐用于开发阶段的条件检查与状态验证,并强调生产环境应使用参数验证工具类替代... 目录深入理解 Java 断言(assert):从原理到实战引言:为什么需要断言?一、断言基础1.1 语

Java MQTT实战应用

《JavaMQTT实战应用》本文详解MQTT协议,涵盖其发布/订阅机制、低功耗高效特性、三种服务质量等级(QoS0/1/2),以及客户端、代理、主题的核心概念,最后提供Linux部署教程、Sprin... 目录一、MQTT协议二、MQTT优点三、三种服务质量等级四、客户端、代理、主题1. 客户端(Clien

在Spring Boot中集成RabbitMQ的实战记录

《在SpringBoot中集成RabbitMQ的实战记录》本文介绍SpringBoot集成RabbitMQ的步骤,涵盖配置连接、消息发送与接收,并对比两种定义Exchange与队列的方式:手动声明(... 目录前言准备工作1. 安装 RabbitMQ2. 消息发送者(Producer)配置1. 创建 Spr

深度解析Spring Boot拦截器Interceptor与过滤器Filter的区别与实战指南

《深度解析SpringBoot拦截器Interceptor与过滤器Filter的区别与实战指南》本文深度解析SpringBoot中拦截器与过滤器的区别,涵盖执行顺序、依赖关系、异常处理等核心差异,并... 目录Spring Boot拦截器(Interceptor)与过滤器(Filter)深度解析:区别、实现

Python中Tensorflow无法调用GPU问题的解决方法

《Python中Tensorflow无法调用GPU问题的解决方法》文章详解如何解决TensorFlow在Windows无法识别GPU的问题,需降级至2.10版本,安装匹配CUDA11.2和cuDNN... 当用以下代码查看GPU数量时,gpuspython返回的是一个空列表,说明tensorflow没有找到

深度解析Spring AOP @Aspect 原理、实战与最佳实践教程

《深度解析SpringAOP@Aspect原理、实战与最佳实践教程》文章系统讲解了SpringAOP核心概念、实现方式及原理,涵盖横切关注点分离、代理机制(JDK/CGLIB)、切入点类型、性能... 目录1. @ASPect 核心概念1.1 AOP 编程范式1.2 @Aspect 关键特性2. 完整代码实

MySQL中的索引结构和分类实战案例详解

《MySQL中的索引结构和分类实战案例详解》本文详解MySQL索引结构与分类,涵盖B树、B+树、哈希及全文索引,分析其原理与优劣势,并结合实战案例探讨创建、管理及优化技巧,助力提升查询性能,感兴趣的朋... 目录一、索引概述1.1 索引的定义与作用1.2 索引的基本原理二、索引结构详解2.1 B树索引2.2

从入门到精通MySQL 数据库索引(实战案例)

《从入门到精通MySQL数据库索引(实战案例)》索引是数据库的目录,提升查询速度,主要类型包括BTree、Hash、全文、空间索引,需根据场景选择,建议用于高频查询、关联字段、排序等,避免重复率高或... 目录一、索引是什么?能干嘛?核心作用:二、索引的 4 种主要类型(附通俗例子)1. BTree 索引(

Java Web实现类似Excel表格锁定功能实战教程

《JavaWeb实现类似Excel表格锁定功能实战教程》本文将详细介绍通过创建特定div元素并利用CSS布局和JavaScript事件监听来实现类似Excel的锁定行和列效果的方法,感兴趣的朋友跟随... 目录1. 模拟Excel表格锁定功能2. 创建3个div元素实现表格锁定2.1 div元素布局设计2.