深度学习与大模型第3课:线性回归模型的构建与训练

2024-09-07 22:36

本文主要是介绍深度学习与大模型第3课:线性回归模型的构建与训练,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 使用Python实现线性回归:从基础到scikit-learn
    • 1. 环境准备
    • 2. 数据准备和可视化
    • 3. 使用numpy实现线性回归
    • 4. 使用模型进行预测
    • 5. 可视化预测结果
    • 6. 使用scikit-learn实现线性回归
    • 7. 梯度下降法
    • 8. 随机梯度下降和小批量梯度下降
    • 9. 比较不同的梯度下降方法
    • 总结

使用Python实现线性回归:从基础到scikit-learn

线性回归是机器学习中最基础也是最重要的算法之一。本文将带领读者从基础的numpy实现,到使用成熟的scikit-learn库,全面了解线性回归的实现过程。我们将通过实际的代码示例和可视化来深入理解这个算法。

1. 环境准备

首先,让我们导入所需的库并设置环境:

from __future__ import division, print_function, unicode_literals
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import warnings
np.random.seed(42)
%matplotlib inline
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)
warnings.filterwarnings(action="ignore", message="^internal gelsd")

这段代码导入了必要的库,设置了随机种子以确保结果可重现,并配置了matplotlib的一些参数。

2. 数据准备和可视化

假设我们已经有了训练数据X和y。让我们先来可视化这些数据:

plt.plot(X, y, "b.")
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.axis([0, 2, 0, 15])
plt.show()

这将绘制一个散点图,展示我们的数据分布。

3. 使用numpy实现线性回归

现在,让我们使用numpy来手动实现线性回归:

X_b = np.c_[np.ones((100, 1)), X]  # 添加x0 = 1到每个实例
theta_best = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y)

这里,我们首先添加了一列1到X矩阵,然后使用正规方程计算最优的theta值。

4. 使用模型进行预测

有了theta_best,我们就可以进行预测了:

X_new = np.array([[0], [2]])
X_new_b = np.c_[np.ones((2, 1)), X_new]
y_predict = X_new_b.dot(theta_best)

5. 可视化预测结果

让我们把原始数据和预测结果可视化:

plt.plot(X_new, y_predict, "r-", linewidth=2, label="Predictions")
plt.plot(X, y, "b.")
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.legend(loc="upper left", fontsize=14)
plt.axis([0, 2, 0, 15])
plt.show()

这将绘制一个图,显示原始数据点和我们的预测线。

6. 使用scikit-learn实现线性回归

最后,让我们看看如何使用scikit-learn来实现相同的功能:

from sklearn.linear_model import LinearRegression
lin_reg = LinearRegression()
lin_reg.fit(X, y)
print("截距:", lin_reg.intercept_)
print("系数:", lin_reg.coef_)
# 预测
print("预测结果:", lin_reg.predict(X_new))

使用scikit-learn,我们只需要几行代码就可以完成模型的训练和预测。

7. 梯度下降法

除了使用正规方程,我们还可以使用梯度下降法来训练线性回归模型。以下是批量梯度下降的实现:

eta = 0.1  # 学习率
n_iterations = 1000
m = 100theta = np.random.randn(2,1)  # 随机初始化
for iteration in range(n_iterations):gradients = 2/m * X_b.T.dot(X_b.dot(theta) - y)theta = theta - eta * gradientsprint("梯度下降法得到的theta:", theta)

我们还可以可视化梯度下降的过程:

theta_path_bgd = []def plot_gradient_descent(theta, eta, theta_path=None):m = len(X_b)plt.plot(X, y, "b.")n_iterations = 1000for iteration in range(n_iterations):if iteration < 10:y_predict = X_new_b.dot(theta)style = "b-" if iteration > 0 else "r--"plt.plot(X_new, y_predict, style)gradients = 2/m * X_b.T.dot(X_b.dot(theta) - y)theta = theta - eta * gradientsif theta_path is not None:theta_path.append(theta)plt.xlabel("$x_1$", fontsize=18)plt.axis([0, 2, 0, 15])plt.title(r"$\eta = {}$".format(eta), fontsize=16)np.random.seed(42)
theta = np.random.randn(2,1)  # 随机初始化plt.figure(figsize=(10,4))
plt.subplot(131); plot_gradient_descent(theta, eta=0.02)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.subplot(132); plot_gradient_descent(theta, eta=0.1, theta_path=theta_path_bgd)
plt.subplot(133); plot_gradient_descent(theta, eta=0.5)
plt.show()

这段代码展示了不同学习率对梯度下降过程的影响。

8. 随机梯度下降和小批量梯度下降

除了批量梯度下降,我们还可以实现随机梯度下降(SGD)和小批量梯度下降:

# 随机梯度下降
theta_path_sgd = []
m = len(X_b)
np.random.seed(42)
n_epochs = 50
t0, t1 = 5, 50  # 学习率调度超参数def learning_schedule(t):return t0 / (t + t1)theta = np.random.randn(2,1)  # 随机初始化
for epoch in range(n_epochs):for i in range(m):random_index = np.random.randint(m)xi = X_b[random_index:random_index+1]yi = y[random_index:random_index+1]gradients = 2 * xi.T.dot(xi.dot(theta) - yi)eta = learning_schedule(epoch * m + i)theta = theta - eta * gradientstheta_path_sgd.append(theta)# 小批量梯度下降
theta_path_mgd = []
n_iterations = 50
minibatch_size = 20
np.random.seed(42)
theta = np.random.randn(2,1)  # 随机初始化
t0, t1 = 200, 1000def learning_schedule(t):return t0 / (t + t1)t = 0
for epoch in range(n_iterations):shuffled_indices = np.random.permutation(m)X_b_shuffled = X_b[shuffled_indices]y_shuffled = y[shuffled_indices]for i in range(0, m, minibatch_size):t += 1xi = X_b_shuffled[i:i+minibatch_size]yi = y_shuffled[i:i+minibatch_size]gradients = 2/minibatch_size * xi.T.dot(xi.dot(theta) - yi)eta = learning_schedule(t)theta = theta - eta * gradientstheta_path_mgd.append(theta)

9. 比较不同的梯度下降方法

最后,我们可以比较不同梯度下降方法的参数路径:

theta_path_bgd = np.array(theta_path_bgd)
theta_path_sgd = np.array(theta_path_sgd)
theta_path_mgd = np.array(theta_path_mgd)plt.figure(figsize=(7,4))
plt.plot(theta_path_sgd[:, 0], theta_path_sgd[:, 1], "r-s", linewidth=1, label="Stochastic")
plt.plot(theta_path_mgd[:, 0], theta_path_mgd[:, 1], "g-+", linewidth=2, label="Mini-batch")
plt.plot(theta_path_bgd[:, 0], theta_path_bgd[:, 1], "b-o", linewidth=3, label="Batch")
plt.legend(loc="upper left", fontsize=16)
plt.xlabel(r"$\theta_0$", fontsize=20)
plt.ylabel(r"$\theta_1$   ", fontsize=20, rotation=0)
plt.axis([2.5, 4.5, 2.3, 3.9])
plt.show()

总结

在这篇博客中,我们学习了如何使用numpy手动实现线性回归,以及如何利用scikit-learn快速实现相同的功能。我们还深入探讨了不同的梯度下降方法,包括批量梯度下降、随机梯度下降和小批量梯度下降,并通过可视化比较了它们的性能。

通过这些实现和比较,我们不仅可以更深入地理解线性回归的原理,还能体会到使用成熟库的便利性,以及不同优化方法的特点。这些知识对于理解更复杂的机器学习算法和深度学习模型都是非常有帮助的。

希望这篇教程对你有所帮助!如果你有任何问题,欢迎在评论区留言。

这篇关于深度学习与大模型第3课:线性回归模型的构建与训练的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1146354

相关文章

深度解析Python中递归下降解析器的原理与实现

《深度解析Python中递归下降解析器的原理与实现》在编译器设计、配置文件处理和数据转换领域,递归下降解析器是最常用且最直观的解析技术,本文将详细介绍递归下降解析器的原理与实现,感兴趣的小伙伴可以跟随... 目录引言:解析器的核心价值一、递归下降解析器基础1.1 核心概念解析1.2 基本架构二、简单算术表达

Three.js构建一个 3D 商品展示空间完整实战项目

《Three.js构建一个3D商品展示空间完整实战项目》Three.js是一个强大的JavaScript库,专用于在Web浏览器中创建3D图形,:本文主要介绍Three.js构建一个3D商品展... 目录引言项目核心技术1. 项目架构与资源组织2. 多模型切换、交互热点绑定3. 移动端适配与帧率优化4. 可

深度解析Java @Serial 注解及常见错误案例

《深度解析Java@Serial注解及常见错误案例》Java14引入@Serial注解,用于编译时校验序列化成员,替代传统方式解决运行时错误,适用于Serializable类的方法/字段,需注意签... 目录Java @Serial 注解深度解析1. 注解本质2. 核心作用(1) 主要用途(2) 适用位置3

Java MCP 的鉴权深度解析

《JavaMCP的鉴权深度解析》文章介绍JavaMCP鉴权的实现方式,指出客户端可通过queryString、header或env传递鉴权信息,服务器端支持工具单独鉴权、过滤器集中鉴权及启动时鉴权... 目录一、MCP Client 侧(负责传递,比较简单)(1)常见的 mcpServers json 配置

Maven中生命周期深度解析与实战指南

《Maven中生命周期深度解析与实战指南》这篇文章主要为大家详细介绍了Maven生命周期实战指南,包含核心概念、阶段详解、SpringBoot特化场景及企业级实践建议,希望对大家有一定的帮助... 目录一、Maven 生命周期哲学二、default生命周期核心阶段详解(高频使用)三、clean生命周期核心阶

深度剖析SpringBoot日志性能提升的原因与解决

《深度剖析SpringBoot日志性能提升的原因与解决》日志记录本该是辅助工具,却为何成了性能瓶颈,SpringBoot如何用代码彻底破解日志导致的高延迟问题,感兴趣的小伙伴可以跟随小编一起学习一下... 目录前言第一章:日志性能陷阱的底层原理1.1 日志级别的“双刃剑”效应1.2 同步日志的“吞吐量杀手”

Python利用PySpark和Kafka实现流处理引擎构建指南

《Python利用PySpark和Kafka实现流处理引擎构建指南》本文将深入解剖基于Python的实时处理黄金组合:Kafka(分布式消息队列)与PySpark(分布式计算引擎)的化学反应,并构建一... 目录引言:数据洪流时代的生存法则第一章 Kafka:数据世界的中央神经系统消息引擎核心设计哲学高吞吐

Unity新手入门学习殿堂级知识详细讲解(图文)

《Unity新手入门学习殿堂级知识详细讲解(图文)》Unity是一款跨平台游戏引擎,支持2D/3D及VR/AR开发,核心功能模块包括图形、音频、物理等,通过可视化编辑器与脚本扩展实现开发,项目结构含A... 目录入门概述什么是 UnityUnity引擎基础认知编辑器核心操作Unity 编辑器项目模式分类工程

Springboot项目构建时各种依赖详细介绍与依赖关系说明详解

《Springboot项目构建时各种依赖详细介绍与依赖关系说明详解》SpringBoot通过spring-boot-dependencies统一依赖版本管理,spring-boot-starter-w... 目录一、spring-boot-dependencies1.简介2. 内容概览3.核心内容结构4.

深度解析Python yfinance的核心功能和高级用法

《深度解析Pythonyfinance的核心功能和高级用法》yfinance是一个功能强大且易于使用的Python库,用于从YahooFinance获取金融数据,本教程将深入探讨yfinance的核... 目录yfinance 深度解析教程 (python)1. 简介与安装1.1 什么是 yfinance?