梯度下降(Gradient Descent)原理以及Python代码

2024-05-26 08:48

本文主要是介绍梯度下降(Gradient Descent)原理以及Python代码,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

给定一个函数f(x),我们想知道当x是值是多少的时候使这个函数达到最小值。为了实现这个目标,我们可以使用梯度下降(Gradient Descent)进行近似求解。

梯度下降是一个迭代算法,具体地,下一次迭代令

x_{n+1} = x_{n} - \eta {f}'(x_{n})

{f}'(x)是梯度,其中\eta是学习率(learning rate),代表这一轮迭代使用多少负梯度进行更新。梯度下降非常简单有效,但是其中的原理是怎么样呢?

原理

为什么每次使用负梯度进行更新呢?这要从泰勒公式(Taylor's formula)说起:

f(x) = f(x_{0}) + \frac{​{f}'(x_{0})}{1!}(x-x_{0}) + \frac{f{}''(x_{0})}{2!}(x-x_{0}) + ...

泰勒公式的目的是使用x-x_{0}的多项式去逼近函数f(x),这里可以理解泰勒公式在x-x_{0}的展开是原函数的一个近似函数。

那泰勒公式跟梯度下降有什么关系呢?

我们的目标是使f(x_{n+1})\leq f(x_{n}),我们对f(x_{n+1})x_{n}处进行一阶泰勒展开:

f(x_{n+1}) \approx f(x_{n}) + {f}'(x_{n})(x_{n+1}-x_{n})

由此可知,我们只需令x_{n+1}-x_{n} = -{f}'(x_{n}),就会使f(x_{n+1})\leq f(x_{n})

所以迭代公式可以为x_{n+1}= x_{n} -\eta {f}'(x_{n})

案例

下面我们看具体例子,假设我们有以下函数

f(x) = \frac{1}{2}\left \| Ax-b \right \|^2

矩阵和A向量b已知,我们想知道当x取值为多少的时候,函数f(x)的值最小。

根据梯度下降法,我们只需计算出负梯度,然给定一个初始值x_{0},不断迭代就能找到一个近似解了。负梯度计算如下:

{f}'(x) =A ^{T}(Ax-b)=A ^{T}Ax-A ^{T}b

接下来让我写一段代码解决这个问题

定义梯度下降函数

首先,定义cal_gradient函数用来计算梯度,然后使用gradient_decent进行迭代,其中learning_rate就是公式中的\eta,这个值需要合理设置,过大的话会导致震荡,过下的话又会导致迭代时间过长。step代表迭代的次数,理想情况下找到满意的解就停止。

我们会在代码中调整这两个参数查看它们对求解过程的影响。

import numpy as np
import time#calculate gradient
def cal_gradient(A, b, x):left = np.dot(np.dot(A.T, A), x)right = np.dot(A.T, b)gradient = left - rightreturn gradient# iteration
def gradient_decent(x, A, b, learning_rate, step):start = time.time()for i in range(step):gradient = cal_gradient(A, b, x)delta = learning_rate * gradientx = x - deltaend = time.time()time_cost = round(end - start, 4)print('done! x = {a}, time cost = {b}s'.format(a=x, b=time_cost))

求解过程

我们给了矩阵和A向量b的值以及标准答案 [29, 16, 3],然后我们随机初始化一个x_{0},让学习率\eta =0.01,迭代次数step=1000000

A = np.array([[1.0, -2.0, 1.0], [0.0, 2.0, -8.0], [-4.0, 5.0, 9.0]])
b = np.array([0.0, 8.0, -9.0])
# Giveb A and b,the solution x is [29, 16, 3]x0 = np.array([1.0, 1.0, 1.0])
learning_rate = 0.01
step = 1000000gradient_decent(x0, A, b, learning_rate, step)

结果

以下为结果,可以看出求得的近似解和标准答案 [29, 16, 3]还是非常接近的。

done! x = [28.98272933 15.99042465  2.99763054], time cost = 4.6037s

调整学习率

其他参数都一样,我们让学习率变小,运行相同的步数,从以下结果看到求得的近似解跟标准答案还有一定差距。这意味着小的过小学习率需要学习更久的时间。

learning_rate = 0.001# result
# done! x = [15.8048349   8.68422815  1.18968306], time cost = 4.5997s

调整初始值

我们只调整初始值,学习相同的步数,发现求得的近似解尽管与标准答案相似,但是不如第一个方法求得解。这说明梯度下降方法也会受到初始值得影响。

x0 = np.array([1000, 1000, 1000])# result
# done! x = [29.78036839 16.43265826  3.10706301], time cost = 4.5528s

总结

梯度下降方法是一种非常有效的优化方法,它的效果会受到初始值、学习率、步数的影响。如果要说缺点的话,就是它容易找到局部最优解,有时候会发生震荡现象。

 

参考

https://sm1les.com/2019/03/01/gradient-descent-and-newton-method/

这篇关于梯度下降(Gradient Descent)原理以及Python代码的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1003996

相关文章

从原理到实战深入理解Java 断言assert

《从原理到实战深入理解Java断言assert》本文深入解析Java断言机制,涵盖语法、工作原理、启用方式及与异常的区别,推荐用于开发阶段的条件检查与状态验证,并强调生产环境应使用参数验证工具类替代... 目录深入理解 Java 断言(assert):从原理到实战引言:为什么需要断言?一、断言基础1.1 语

使用Python实现可恢复式多线程下载器

《使用Python实现可恢复式多线程下载器》在数字时代,大文件下载已成为日常操作,本文将手把手教你用Python打造专业级下载器,实现断点续传,多线程加速,速度限制等功能,感兴趣的小伙伴可以了解下... 目录一、智能续传:从崩溃边缘抢救进度二、多线程加速:榨干网络带宽三、速度控制:做网络的好邻居四、终端交互

Python中注释使用方法举例详解

《Python中注释使用方法举例详解》在Python编程语言中注释是必不可少的一部分,它有助于提高代码的可读性和维护性,:本文主要介绍Python中注释使用方法的相关资料,需要的朋友可以参考下... 目录一、前言二、什么是注释?示例:三、单行注释语法:以 China编程# 开头,后面的内容为注释内容示例:示例:四

Python中win32包的安装及常见用途介绍

《Python中win32包的安装及常见用途介绍》在Windows环境下,PythonWin32模块通常随Python安装包一起安装,:本文主要介绍Python中win32包的安装及常见用途的相关... 目录前言主要组件安装方法常见用途1. 操作Windows注册表2. 操作Windows服务3. 窗口操作

Python中re模块结合正则表达式的实际应用案例

《Python中re模块结合正则表达式的实际应用案例》Python中的re模块是用于处理正则表达式的强大工具,正则表达式是一种用来匹配字符串的模式,它可以在文本中搜索和匹配特定的字符串模式,这篇文章主... 目录前言re模块常用函数一、查看文本中是否包含 A 或 B 字符串二、替换多个关键词为统一格式三、提

Java中调用数据库存储过程的示例代码

《Java中调用数据库存储过程的示例代码》本文介绍Java通过JDBC调用数据库存储过程的方法,涵盖参数类型、执行步骤及数据库差异,需注意异常处理与资源管理,以优化性能并实现复杂业务逻辑,感兴趣的朋友... 目录一、存储过程概述二、Java调用存储过程的基本javascript步骤三、Java调用存储过程示

Visual Studio 2022 编译C++20代码的图文步骤

《VisualStudio2022编译C++20代码的图文步骤》在VisualStudio中启用C++20import功能,需设置语言标准为ISOC++20,开启扫描源查找模块依赖及实验性标... 默认创建Visual Studio桌面控制台项目代码包含C++20的import方法。右键项目的属性:

python常用的正则表达式及作用

《python常用的正则表达式及作用》正则表达式是处理字符串的强大工具,Python通过re模块提供正则表达式支持,本文给大家介绍python常用的正则表达式及作用详解,感兴趣的朋友跟随小编一起看看吧... 目录python常用正则表达式及作用基本匹配模式常用正则表达式示例常用量词边界匹配分组和捕获常用re

python实现对数据公钥加密与私钥解密

《python实现对数据公钥加密与私钥解密》这篇文章主要为大家详细介绍了如何使用python实现对数据公钥加密与私钥解密,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录公钥私钥的生成使用公钥加密使用私钥解密公钥私钥的生成这一部分,使用python生成公钥与私钥,然后保存在两个文

python删除xml中的w:ascii属性的步骤

《python删除xml中的w:ascii属性的步骤》使用xml.etree.ElementTree删除WordXML中w:ascii属性,需注册命名空间并定位rFonts元素,通过del操作删除属... 可以使用python的XML.etree.ElementTree模块通过以下步骤删除XML中的w:as