一觉醒来!Keras 3.0史诗级更新,大一统深度学习三大后端框架【Tensorflow/PyTorch/Jax】

本文主要是介绍一觉醒来!Keras 3.0史诗级更新,大一统深度学习三大后端框架【Tensorflow/PyTorch/Jax】,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

不知道大家入门上手机器学习项目是首先入坑的哪个深度学习框架,对于我来说,最先看到的听到的就是Tensorflow了,但是实际上手做项目开发的时候却发现了一个很重要的问题,不容易上手,基于原生的tf框架来直接开发模总是有不小的难度,后来发现了Keras,简直就是深度学习的福音,本质上来讲,Keras是对后端深度学习框架的高级封装。

框架介绍

当涉及深度学习和机器学习时,TensorFlow、PyTorch和JAX都是非常流行的框架。以下是它们的详细介绍以及各自的优点和缺点:

【TensorFlow】
TensorFlow 是由 Google Brain 团队开发的开源框架,可以实现深度学习和机器学习模型的构建和训练。
优点:

深度学习生态系统完善,支持常见的深度学习模型开发和部署。
支持分布式计算,适合大规模数据和模型训练。
提供了 TensorFlow Serving 等支持模型部署和服务化的工具。
在生产环境中表现稳定,有着广泛的产业应用。
缺点:

前端API相对较复杂,学习曲线较陡。
相比其他框架,可读性稍低,需要编写更多的代码。
在一些场景下的灵活性和易用性不如 PyTorch。
【PyTorch】
PyTorch 是 Facebook 开发的深度学习框架,也是一个开源项目。它和 TensorFlow 一样,可以用于实现深度学习算法。
优点:

前端API设计简洁直观,易于学习和使用。
动态计算图的特性使得调试更加方便,同时也更具灵活性。
提供了针对自然语言处理和计算机视觉等任务的高层抽象库,如 TorchText 和 TorchVision。
缺点:

缺乏在大规模分布式训练中的一些优化和工具支持。
对生产部署支持相对不足,相较 TensorFlow 在工业界的部署相对不如。
【JAX】
JAX 是 Google 开源的一个库,它提供了用于数值计算(尤其是自动微分和加速机器学习模型)的高性能Python接口。
优点:

采用 XLA (Accelerated Linear Algebra)进行加速,能够在 CPU 或 GPU 上自动并行化。
支持自动微分,可以用于构建自定义的优化器和模型。
完全可扩展,可以在大规模计算中实现高性能计算。
缺点:

相对 TensorFlow 和 PyTorch,社区和生态系统相对较小。
API 相对不够丰富,需要编写更多的自定义代码。
综上所述,TensorFlow、PyTorch 和 JAX 都有各自的优势和劣势,选择哪个框架要根据具体的需求和背景来决定。例如,如果需要处理大规模的生产环境应用,可能更倾向于选择 TensorFlow;而如果需要快速的原型开发和实验,PyTorch 是一个更好的选择。而 JAX 则适合于寻求高性能计算的研究和实验工作。

接下来我们来详细看官方本次发布的Keras3.0的详细介绍,官方介绍在这里,如下所示:

经过五个月的广泛公测,我们很高兴地宣布Keras 3.0的正式发布。Keras 3是对Keras的全面重写,使您能够在JAX、TensorFlow或PyTorch之上运行Keras工作流,并释放全新的大规模模型培训和部署功能。你可以选择最适合你的框架,并根据你当前的目标从一个框架切换到另一个框架。您还可以使用Keras作为一种低级的跨框架语言来开发自定义组件,如层、模型或度量,这些组件可以在JAX、TensorFlow或PyTorch中的本地工作流中使用——只需一个代码库。

欢迎使用多框架机器学习

您已经熟悉了使用Keras的好处——它通过专注于出色的UX、API设计和可调试性来实现高速开发。这也是一个经过战斗测试的框架,已被超过250万开发者选择,为世界上一些最复杂、规模最大的ML系统提供了动力,如Waymo自动驾驶车队和YouTube推荐引擎。但是使用新的多后端Keras 3还有什么额外的好处呢?

始终为您的模型获得最佳性能

在我们的基准测试中,我们发现JAX通常在GPU、TPU和CPU上提供最佳的训练和推理性能,但结果因模型而异,因为非XLA TensorFlow在GPU上偶尔会更快。能够动态选择为您的模型提供最佳性能的后端,而无需对代码进行任何更改,这意味着您可以保证以最高的效率进行培训和服务。

解锁您的模型的生态系统可选性

任何Keras 3模型都可以实例化为PyTorch模块,可以导出为TensorFlow SavedModel,也可以实例化为无状态JAX函数。这意味着您可以将Keras 3模型与PyTorch生态系统包、全套TensorFlow部署和生产工具(如TF Serving、TF.js和TFLite)以及JAX大规模TPU培训基础设施一起使用。使用Keras3API编写一个model.py,并访问ML世界所提供的一切。

利用JAX的大规模模型并行性和数据并行性。Keras3包含一个全新的发行版API,Keras.distribution名称空间,目前为JAX后端实现(即将在TensorFlow和PyTorch后端实现)。它使在任意模型规模和集群规模上进行模型并行、数据并行以及两者的组合变得容易。因为它使模型定义、训练逻辑和分片配置彼此分离,所以它使您的分发工作流易于开发和维护。请参阅我们的入门指南。

最大限度地扩大您的开源模型发布范围

想要发布一个预先训练好的模型吗?想要尽可能多的人能够使用它吗?如果你在纯TensorFlow或PyTorch中实现它,大约一半的社区都可以使用它。如果你在Keras 3中实现它,那么任何人都可以立即使用它,无论他们选择的框架是什么(即使他们自己不是Keras用户)。在不增加开发成本的情况下实现两倍的效果。

使用任何来源的数据管道

Keras 3 fit()/evaluatate()/predict()例程与tf.data.Dataset对象、PyTorch DataLoader对象、NumPy数组、Pandas数据帧兼容,无论您使用的是后端。您可以在PyTorch DataLoader上训练Keras 3+TensorFlow模型,也可以在tf.data.Dataset上训练Keras3+PyTorch模型。

完整的Keras API,可用于JAX、TensorFlow和PyTorch

Keras 3实现了完整的Keras API,并使其与TensorFlow、JAX和PyTorch一起使用-超过一百层、数十个度量、丢失函数、优化器和回调、Keras训练和评估循环以及Keras保存和序列化基础设施。所有你熟悉和喜爱的API都在这里。

任何只使用内置层的Keras模型都将立即与所有支持的后端一起工作。事实上,您现有的仅使用内置层的tf.keras模型可以立即在JAX和PyTorch中运行!没错,您的代码库刚刚获得了一组全新的功能。

编写多框架层、模型、度量

Keras 3使您能够创建在任何框架中都能正常工作的组件(如任意自定义层或预训练模型)。特别是,Keras3允许您访问跨所有后端工作的Keras.ops命名空间。它包含:

NumPy API的完整实现。不是类似NumPy的东西——只是字面上的NumPy API,具有相同的函数和相同的参数。您将获得ops.matmul、ops.sum、ops.stack、ops.einsum等。

NumPy中没有的一组特定于神经网络的函数,如ops.softmax、ops.binary_crossentropy、ops.cov等。

只要您只使用keras.ops中的操作,您的自定义层、自定义损失、自定义度量和自定义优化器将使用JAX、PyTorch和TensorFlow使用相同的代码。这意味着您只能维护一个组件实现(例如,一个model.py和一个检查点文件),并且您可以在所有框架中使用它,使用完全相同的数字。

…与任何JAX、TensorFlow和PyTorch工作流无缝配合

Keras 3不仅仅适用于以Keras为中心的工作流,您可以在其中定义Keras模型、Keras优化器、Keras损失和度量,并调用fit()、evaluate()和predict()。它还意味着可以与底层后端本机工作流无缝协作:您可以采用Keras模型(或任何其他组件,如损失或度量),并开始在JAX训练循环、TensorFlow训练循环或PyTorch训练循环中使用它,或者作为JAX或PyTorc模型的一部分,零摩擦。Keras3在JAX和PyTorch中提供了与tf.Keras之前在TensorFlow中所做的完全相同程度的底层实现灵活性。

您可以:

编写一个低级JAX训练循环,使用optax优化器JAX.grad、JAX.jit、JAX.pmap来训练Keras模型。

编写一个低级TensorFlow训练循环,使用tf.GradientTape和tf.distribute训练Keras模型。

编写一个低级PyTorch训练循环,使用torch.optim优化器、torch损失函数和torch.nn.parallel.DistributtedDataParallel包装器来训练Keras模型。

在PyTorch模块中使用Keras层(因为它们也是模块实例!)

在Keras模型中使用任何PyTorch模块,就好像它是Keras层一样。

等等

一种新的分布式API,用于大规模数据并行和模型并行

我们一直在研究的模型越来越大,所以我们想为多设备模型分片问题提供一个Kerasic解决方案。我们设计的API使模型定义、训练逻辑和分片配置完全分离,这意味着可以将模型编写为在单个设备上运行。然后,当需要对任意模型进行训练时,可以将任意分片配置添加到任意模型中。

数据并行性(在多个设备上相同地复制一个小模型)只需两行即可处理:

模型并行性使您可以沿着多个命名维度为模型变量和中间输出张量指定分片布局。在典型情况下,您可以将可用设备组织为二维网格(称为设备网格),其中第一个维度用于数据并行,第二个维度用于模型并行。然后,您可以将模型配置为沿模型维度进行分片,并沿数据维度进行复制。

API允许您通过正则表达式配置每个变量和每个输出张量的布局。这样可以很容易地为整个变量类别快速指定相同的布局。

新的发行版API旨在成为多后端,但目前仅适用于JAX后端。TensorFlow和PyTorch的支持即将到来。开始使用此指南!

预训练模型

有一系列预先训练好的模型,您今天可以在Keras 3中开始使用。

所有40个Keras应用程序模型(Keras.Applications命名空间)在所有后端都可用。KerasCV和KerasNLP中的大量预训练模型也适用于所有后端。这包括:

支持所有后端的跨框架数据管道

多框架ML也意味着多框架数据的加载和预处理。Keras 3模型可以使用广泛的数据管道进行训练——无论您使用的是JAX、PyTorch还是TensorFlow后端。它只是起作用。

tf.data.Dataset管道:可扩展生产ML的参考。

torch.utils.data.DataLoader对象。

NumPy数组和Pandas数据帧。

Keras自己的Keras.utils.PyDataset对象。

复杂性的逐步披露

复杂性的渐进披露是Keras API核心的设计原则。Keras不会强迫你遵循一种“真正”的方式来构建和训练模型。相反,它支持各种不同的工作流,从非常高级到非常低级,对应于不同的用户配置文件。

这意味着您可以从简单的工作流开始,例如使用Sequential和Functional模型,并使用fit()对它们进行训练。当您需要更大的灵活性时,您可以轻松地自定义不同的组件,同时重用大多数以前的代码。随着你的需求变得更加具体,你不会突然从复杂性的悬崖上跌落,也不需要切换到不同的工具集。

我们把这个原则带到了我们所有的后台。例如,您可以自定义训练循环中发生的事情,同时仍然利用fit()的功能,而不必从头开始编写自己的训练循环——只需重写train_step方法。

以下是它在PyTorch和TensorFlow中的工作方式:

这是JAX版本的链接。

class CustomModel(keras.Model):def compute_loss_and_updates(self,trainable_variables,non_trainable_variables,x,y,training=False,):y_pred, non_trainable_variables = self.stateless_call(trainable_variables,non_trainable_variables,x,training=training,)loss = self.compute_loss(x, y, y_pred)return loss, (y_pred, non_trainable_variables)def train_step(self, state, data):(trainable_variables,non_trainable_variables,optimizer_variables,metrics_variables,) = statex, y = data# Get the gradient function.grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)# Compute the gradients.(loss, (y_pred, non_trainable_variables)), grads = grad_fn(trainable_variables,non_trainable_variables,x,y,training=True,)# Update trainable variables and optimizer variables.(trainable_variables,optimizer_variables,) = self.optimizer.stateless_apply(optimizer_variables, grads, trainable_variables)# Update metrics.new_metrics_vars = []for metric in self.metrics:this_metric_vars = metrics_variables[len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)]if metric.name == "loss":this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)else:this_metric_vars = metric.stateless_update_state(this_metric_vars, y, y_pred)logs = metric.stateless_result(this_metric_vars)new_metrics_vars += this_metric_vars# Return metric logs and updated state variables.state = (trainable_variables,non_trainable_variables,optimizer_variables,new_metrics_vars,)return logs, state

新的无状态API,用于层、模型、度量和优化器

你喜欢函数式编程吗?你会得到款待的。

Keras中的所有有状态对象(即拥有在训练或评估过程中更新的数值变量的对象)现在都有一个无状态的API,从而可以在JAX函数中使用它们(要求完全无状态):

所有层和模型都有一个stateless_call()方法,该方法镜像__call__()。所有优化器都有一个stateless_apply()方法,该方法镜像apply(()。所有度量都有一个镜像update_state()的stateless_update_state)方法和一个镜像result()的stateless_sult()方法。

这些方法没有任何副作用:它们将目标对象的状态变量的当前值作为输入,并将更新值作为输出的一部分返回,例如:

outputs, updated_non_trainable_variables = layer.stateless_call(trainable_variables,non_trainable_variables,inputs,
)

您永远不必自己实现这些方法——只要您实现了有状态版本(例如call()或update_state()),它们就会自动可用。

从Keras 2移动到Keras 3

Keras 3与Keras 2高度向后兼容:它实现了Keras 2的完整公共API表面,这里列出了有限的例外情况。大多数用户不需要进行任何代码更改就可以在Keras3上开始运行他们的Keras脚本。

较大的代码库可能需要一些代码更改,因为它们更有可能遇到上面列出的异常之一,并且更有可能使用私有API或不推荐使用的API(tf.compat.v1.keras命名空间、实验命名空间、keras.src私有命名空间)。为了帮助您迁移到Keras 3,我们发布了一份完整的迁移指南,其中包含您可能遇到的所有问题的快速修复程序。

您还可以选择忽略Keras 3中的更改,只需将Keras 2与TensorFlow一起使用——这对于那些没有积极开发但需要使用更新的依赖关系继续运行的项目来说是一个很好的选择。您有两种可能性:

1、如果您将keras作为一个独立的包访问,只需切换到使用Python包tf_keras即可,您可以通过pip-install-tf_keras安装该包。代码和API完全没有变化——它是Keras 2.15,具有不同的包名称。我们将继续修复tf_keras中的错误,并定期发布新版本。但是,由于该软件包现在处于维护模式,因此不会添加任何新功能或性能改进。

2、如果您通过tf.keras访问keras,那么在TensorFlow 2.16之前不会立即发生更改。TensorFlow 2.16+默认情况下将使用Keras 3。在TensorFlow 2.16+中,要继续使用Keras 2,可以先安装tf_Keras,然后导出环境变量tf_USE_LEGACY_Keras=1。这将指导TensorFlow 2.16+将tf.keras解析为本地安装的tf_keras包。请注意,这可能影响的不仅仅是您自己的代码:它将影响Python进程中导入tf.keras的任何包。为了确保您的更改只影响您自己的代码,您应该使用tf_keras包。

常见问题解答

Q: Keras 3是否与旧版Keras 2兼容?

使用tf.keras开发的代码通常可以像使用keras 3(使用TensorFlow后端)一样运行。您应该注意的不兼容性数量有限,所有这些都在本迁移指南中介绍。

当涉及到同时使用来自tf.keras和keras 3的API时,这是不可能的——它们是不同的包,运行在完全不同的引擎上。

Q: 在旧版Keras 2中开发的预训练模型是否适用于Keras 3?

一般来说,是的。任何tf.keras模型都应该使用带有TensorFlow后端的keras 3(确保以.keras v3格式保存)。此外,如果模型只使用内置的Keras层,那么它也可以在带有JAX和PyTorch后端的Keras 3中开箱即用。

如果模型包含使用TensorFlow API编写的自定义层,则通常很容易将代码转换为后端无关的。例如,我们只花了几个小时就将keras应用程序中的所有40个遗留tf.keras模型转换为后端不可知的模型。

Q: 我可以在一个后端保存一个Keras 3模型并在另一个后端重新加载它吗?

是的,你可以。在保存的.keras文件中没有后端专门化。您保存的Keras模型与框架无关,可以使用任何后端重新加载。

但是,请注意,重新加载包含具有不同后端的自定义组件的模型需要使用与后端无关的API(例如keras.ops)来实现自定义组件。

Q: 我可以在tf.data管道中使用Keras 3组件吗?

对于TensorFlow后端,Keras 3与tf.data完全兼容(例如,您可以将序列模型映射到tf.data管道中)。

使用不同的后端,Keras 3对tf.data的支持有限。您将无法将任意层或模型映射到tf.data管道中。但是,您可以将特定的Keras 3预处理层与tf.data一起使用,例如IntegerLookup或CategoryEncoding。

当涉及到使用tf.data管道(不使用Keras)来提供对.fit()、.evaluate()或.predict()的调用时,所有后端都是现成的。

Q: Keras 3型号在不同后端运行时表现相同吗?

是的,后端的数字是相同的。但是,请记住以下注意事项:

RNG行为在不同的后端之间是不同的(即使在种子设定之后-您的结果在每个后端都是确定的,但在后端之间是不同的)。所以随机权重初始化值和退出值在后端会有所不同。

由于浮点实现的性质,在float32中,每个函数执行的结果在1e-7精度以内是相同的。因此,当长时间训练一个模型时,微小的数值差异会积累起来,最终可能导致显著的数值差异。

由于PyTorch中缺少对使用非对称填充的平均池的支持,使用padding=“same”的平均池层可能会导致边框行/列上的数字不同。这在实践中并不经常发生——在40个Keras应用程序视觉模型中,只有一个受到影响。

Q: Keras 3是否支持分布式训练?

JAX、TensorFlow和PyTorch支持数据并行分发。JAX通过keras.distribution API支持模型并行分发。

使用TensorFlow:

Keras3与tf.distribute兼容——只需打开一个Distribution策略范围并在其中创建/训练您的模型。

使用PyTorch:

Keras 3与PyTorch的DistributedDataParallel实用程序兼容。这里有一个例子。

使用JAX:

您可以使用keras.distribution API在JAX中进行数据并行和模型并行分发。例如,要进行数据并行分发,只需要以下代码段:

distribution = keras.distribution.DataParallel(devices=keras.distribution.list_devices())
keras.distribution.set_distribution(distribution)

有关模型并行分布,请参见以下指南。

您还可以通过JAX.sharding等JAX API分发培训。

Q: 我的自定义Keras层是否可以用于本地PyTorch模块或亚麻模块?

如果它们只使用kerasapi(例如Keras.ops名称空间)编写,那么是的,您的Keras层将使用本机PyTorch和JAX代码直接工作。在PyTorch中,只需像其他PyTorch模块一样使用Keras层。在JAX中,确保使用无状态层API,即layer.statese\u call()。

Q: 您将来会添加更多后端吗?那么框架XYZ呢?

我们愿意添加新的后端,只要目标框架有一个大的用户群,或者有一些独特的技术优势。然而,添加和维护一个新的后端是一个很大的负担,所以我们将仔细考虑每个新的后端候选人在个案的基础上,我们不太可能添加许多新的后端。我们不会添加任何尚未完善的新框架。我们现在可能会考虑添加一个用Mojo编写的后端。如果这是一些你可能会发现有用的东西,请让Mojo团队知道。

这篇关于一觉醒来!Keras 3.0史诗级更新,大一统深度学习三大后端框架【Tensorflow/PyTorch/Jax】的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/437426

相关文章

Go学习记录之runtime包深入解析

《Go学习记录之runtime包深入解析》Go语言runtime包管理运行时环境,涵盖goroutine调度、内存分配、垃圾回收、类型信息等核心功能,:本文主要介绍Go学习记录之runtime包的... 目录前言:一、runtime包内容学习1、作用:① Goroutine和并发控制:② 垃圾回收:③ 栈和

Python中文件读取操作漏洞深度解析与防护指南

《Python中文件读取操作漏洞深度解析与防护指南》在Web应用开发中,文件操作是最基础也最危险的功能之一,这篇文章将全面剖析Python环境中常见的文件读取漏洞类型,成因及防护方案,感兴趣的小伙伴可... 目录引言一、静态资源处理中的路径穿越漏洞1.1 典型漏洞场景1.2 os.path.join()的陷

Android学习总结之Java和kotlin区别超详细分析

《Android学习总结之Java和kotlin区别超详细分析》Java和Kotlin都是用于Android开发的编程语言,它们各自具有独特的特点和优势,:本文主要介绍Android学习总结之Ja... 目录一、空安全机制真题 1:Kotlin 如何解决 Java 的 NullPointerExceptio

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

C++ HTTP框架推荐(特点及优势)

《C++HTTP框架推荐(特点及优势)》:本文主要介绍C++HTTP框架推荐的相关资料,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录1. Crow2. Drogon3. Pistache4. cpp-httplib5. Beast (Boos

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

SpringBoot基础框架详解

《SpringBoot基础框架详解》SpringBoot开发目的是为了简化Spring应用的创建、运行、调试和部署等,使用SpringBoot可以不用或者只需要很少的Spring配置就可以让企业项目快... 目录SpringBoot基础 – 框架介绍1.SpringBoot介绍1.1 概述1.2 核心功能2

Spring Boot拦截器Interceptor与过滤器Filter深度解析(区别、实现与实战指南)

《SpringBoot拦截器Interceptor与过滤器Filter深度解析(区别、实现与实战指南)》:本文主要介绍SpringBoot拦截器Interceptor与过滤器Filter深度解析... 目录Spring Boot拦截器(Interceptor)与过滤器(Filter)深度解析:区别、实现与实

MyBatis分页插件PageHelper深度解析与实践指南

《MyBatis分页插件PageHelper深度解析与实践指南》在数据库操作中,分页查询是最常见的需求之一,传统的分页方式通常有两种内存分页和SQL分页,MyBatis作为优秀的ORM框架,本身并未提... 目录1. 为什么需要分页插件?2. PageHelper简介3. PageHelper集成与配置3.