DataWhale-西瓜书+南瓜书第3章线性模型学习总结-Task02-202110

2023-12-27 21:58

本文主要是介绍DataWhale-西瓜书+南瓜书第3章线性模型学习总结-Task02-202110,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

3.1 基本形式

样本\mathbf{x}=(x_1,x_2,\dots,x_d),其中x_i\mathbf{x}在第i个属性上的取值。线性模型试图学得一个通过属性得线性组合来进行预测得函数,即 

                                              \begin{equation} \begin{aligned} f(\mathbf{x})&=w_1x_1+w_2x_2+\dots+w_dx_d+b\\ &=\mathbf{w}^T\mathbf{x}+b \end{equation} \end{aligned}   

3.2 线性回归 

3.2.1 一元线性回归

均方误差\mathit{E}_{(w,b)}=\sum_{i=1}^m(y_i-wx_i-b)^2最小化,对w和b求导:

                                 \frac{\partial\mathit{E}}{\partial w}=2[w\sum_{i=1}^mx_i^2-\sum_{i=1}^m(y_i-b)x_i]

                                 \frac{\partial\mathit{E}}{\partial b}=2[mb-\sum_{i=1}^m(y_i-wx_i)]

上面两个方程等于0可以得到

                                w=\frac{\sum_{i=1}^m y_i(x_i-\bar{x})}{\sum_{i=1}^mx_i^2-\frac{1}{m}(\sum_{i=1}^mx_i)^2}

                                b=\frac{1}{m}\sum_{i=1}^m(y_i-wx_i)

3.2.2 多元线性回归

m个d个元素得示例,把数据集D表示为一个m\times(d+1)的大小的矩阵\mathbf{X}:

                 

 则均方误差为:

                             \mathit{E}_{\hat{\mathbf{w}}}=(\mathbf{y}-\mathbf{X}\hat{\mathbf{w}})^T(\mathbf{y}-\mathbf{X}\hat{\mathbf{w}})

\hat{\mathbf{w}}求导得到:

                               \frac{\partial\mathit{E_{\hat{\mathbf{w}}}}}{\partial\hat{\mathbf{w}}}=2\mathbf{X}^T(\mathbf{X}\hat{\mathbf{w}}-\hat{\mathbf{y}})

\mathbf{X}^T\mathbf{X}为满秩矩阵或正定矩阵时,上式为0可得:

                              \hat{\mathbf{w}}^*=(\mathbf{X^T}\mathbf{X})^{-1}\mathbf{X}^T\mathbf{y}

代码实现1:

import numpy as np
class LinearRegression:def __init__(self):self._theta = Noneself.intercept_ = Noneself.coef_ = Nonedef fit(self,x_train,y_train):X_b = np.hstack([np.ones((len(x_train),1)), x_train])self._theta = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y_train)self.intercept_ = self._theta[0]self.coef_ = self._theta[1:]return selfdef predict(self,x_predict):X_b = np.hstack([np.ones((len(x_predict),1)), x_predict])return X_b.dot(self._theta)

3.2.3 梯度下降法

因为

                                           \frac{\partial\mathit{E_{\hat{\mathbf{w}}}}}{\partial\hat{\mathbf{w}}}=2\mathbf{X}^T(\mathbf{X}\hat{\mathbf{w}}-\hat{\mathbf{y}})

所以

                                     \hat{\mathbf{w}}_{next}=\hat{\mathbf{w}}_{next}-\frac{\eta}{m}\mathbf{X}^T(\mathbf{X}\hat{\mathbf{w}}-\mathbf{y})

 代码实现如下:

import numpy as np alpha = 0.01def cost_function(theta, X, y):diff = np.dot(X, theta) - yreturn (1./(2*m)) * np.dot(np.transpose(diff), diff)def gradient_function(theta, X, y):diff = np.dot(X, theta) - yreturn (1./m) * np.dot(np.transpose(X), diff)def gradient_descent(X, y, alpha):theta = np.ones((X.shape[1]+1,1))gradient = gradient_function(theta, X, y)while not np.all(np.absolute(gradient) <= 1e-5):theta = theta - alpha * gradientgradient = gradient_function(theta, X, y)return theta

3.3 对数几率回归

对于二分类任务,

 使用对数几率函数可以得到:

                                                            y=\frac{1}{1+e^{-(\mathbf{w}^T\mathbf{x}+b)}}

变换后得到:

                                                           \ln\frac{y}{1-y}=\mathbf{w}^T\mathbf{x}+b

将y视为样本x作为正例的可能性,1-y是反例的可能性,则有

                                            p_1=p(y=1|x)=\frac{e^{w^Tx+b}}{1+e^{w^Tx+b}}

                                            p_0=p(y=0|x)=\frac{1}{1+e^{w^Tx+b}}  

为简便计算令\mathbf{\beta}=(\mathbf{w},b),\hat{\mathbf{x}}=(\mathbf{x},1),  对数回归模型的最大化似然函数为:

                                         \mathit{l}(\mathbf{\beta})=\sum_{i=1}^m\ln p(y_i|\hat{\mathbf{x}}_i,\mathbf{\beta})

带入p的表达式:

                         p(y_i|\hat{\mathbf{x}}_i,\mathbf{\beta})=y_ip_1(\hat{\mathbf{x}}_i,\mathbf{\beta})+(1-y_i)p_0(\hat{\mathbf{x}}_i,\mathbf{\beta})

可以得到:

                             \mathit{l}(\mathbf{\beta})=\sum_{i=1}^m[-y_i\beta^T\hat{\mathbf{x}}_i+\ln(1+e^{\beta^T\hat{\mathbf{x}}_i})]

利用这个表达式,可以用梯度下降法求解参数。

这篇关于DataWhale-西瓜书+南瓜书第3章线性模型学习总结-Task02-202110的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/544404

相关文章

Python版本与package版本兼容性检查方法总结

《Python版本与package版本兼容性检查方法总结》:本文主要介绍Python版本与package版本兼容性检查方法的相关资料,文中提供四种检查方法,分别是pip查询、conda管理、PyP... 目录引言为什么会出现兼容性问题方法一:用 pip 官方命令查询可用版本方法二:conda 管理包环境方法

pycharm跑python项目易出错的问题总结

《pycharm跑python项目易出错的问题总结》:本文主要介绍pycharm跑python项目易出错问题的相关资料,当你在PyCharm中运行Python程序时遇到报错,可以按照以下步骤进行排... 1. 一定不要在pycharm终端里面创建环境安装别人的项目子模块等,有可能出现的问题就是你不报错都安装

Linux五种IO模型的使用解读

《Linux五种IO模型的使用解读》文章系统解析了Linux的五种IO模型(阻塞、非阻塞、IO复用、信号驱动、异步),重点区分同步与异步IO的本质差异,强调同步由用户发起,异步由内核触发,通过对比各模... 目录1.IO模型简介2.五种IO模型2.1 IO模型分析方法2.2 阻塞IO2.3 非阻塞IO2.4

Python中logging模块用法示例总结

《Python中logging模块用法示例总结》在Python中logging模块是一个强大的日志记录工具,它允许用户将程序运行期间产生的日志信息输出到控制台或者写入到文件中,:本文主要介绍Pyt... 目录前言一. 基本使用1. 五种日志等级2.  设置报告等级3. 自定义格式4. C语言风格的格式化方法

Spring 依赖注入与循环依赖总结

《Spring依赖注入与循环依赖总结》这篇文章给大家介绍Spring依赖注入与循环依赖总结篇,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录1. Spring 三级缓存解决循环依赖1. 创建UserService原始对象2. 将原始对象包装成工

MySQL中查询和展示LONGBLOB类型数据的技巧总结

《MySQL中查询和展示LONGBLOB类型数据的技巧总结》在MySQL中LONGBLOB是一种二进制大对象(BLOB)数据类型,用于存储大量的二进制数据,:本文主要介绍MySQL中查询和展示LO... 目录前言1. 查询 LONGBLOB 数据的大小2. 查询并展示 LONGBLOB 数据2.1 转换为十

Unity新手入门学习殿堂级知识详细讲解(图文)

《Unity新手入门学习殿堂级知识详细讲解(图文)》Unity是一款跨平台游戏引擎,支持2D/3D及VR/AR开发,核心功能模块包括图形、音频、物理等,通过可视化编辑器与脚本扩展实现开发,项目结构含A... 目录入门概述什么是 UnityUnity引擎基础认知编辑器核心操作Unity 编辑器项目模式分类工程

Python学习笔记之getattr和hasattr用法示例详解

《Python学习笔记之getattr和hasattr用法示例详解》在Python中,hasattr()、getattr()和setattr()是一组内置函数,用于对对象的属性进行操作和查询,这篇文章... 目录1.getattr用法详解1.1 基本作用1.2 示例1.3 原理2.hasattr用法详解2.

在Java中实现线程之间的数据共享的几种方式总结

《在Java中实现线程之间的数据共享的几种方式总结》在Java中实现线程间数据共享是并发编程的核心需求,但需要谨慎处理同步问题以避免竞态条件,本文通过代码示例给大家介绍了几种主要实现方式及其最佳实践,... 目录1. 共享变量与同步机制2. 轻量级通信机制3. 线程安全容器4. 线程局部变量(ThreadL

Spring Boot 与微服务入门实战详细总结

《SpringBoot与微服务入门实战详细总结》本文讲解SpringBoot框架的核心特性如快速构建、自动配置、零XML与微服务架构的定义、演进及优缺点,涵盖开发环境准备和HelloWorld实战... 目录一、Spring Boot 核心概述二、微服务架构详解1. 微服务的定义与演进2. 微服务的优缺点三