java(kotlin) ai框架djl

2024-06-13 09:28
文章标签 java ai 框架 kotlin djl

本文主要是介绍java(kotlin) ai框架djl,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

DJL(Deep Java Library)是一个开源的深度学习框架,由AWS推出,DJL支持多种深度学习后端,包括但不限于:

MXNet:由Apache软件基金会支持的开源深度学习框架。
PyTorch:广泛使用的开源机器学习库,由Facebook的AI研究团队开发。
TensorFlow:由Google开发的另一个流行的开源机器学习框架。
DJL与Java生态系统紧密集成,可以与Spring Boot、Quarkus等Java框架协同工作。

maven

 <!--        djl--><dependency><groupId>ai.djl</groupId><artifactId>api</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl.pytorch</groupId><artifactId>pytorch-engine</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl.pytorch</groupId><artifactId>pytorch-model-zoo</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl</groupId><artifactId>basicdataset</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl</groupId><artifactId>model-zoo</artifactId><version>0.28.0</version></dependency><!--        /djl-->

Java DJL 架构图

┌──────────────────────────────┐
│          ModelZoo            │
├──────────────────────────────┤
│            Model             │
└───────────────┬──────────────┘│┌─────────▼─────────┐│       Engine      │└───────┬─┬─────────┘│ │┌───────▼─▼─────────┐│     NDManager     │└───────┬─┬─────────┘│ │┌─────────▼─▼───────────┐│    Dataset └─────────┬─────────────┘│┌─────────▼─────────────┐│  Trainer / Predictor  │└───────────────────────┘

主要组件详细描述

1. ModelZoo 和 Model
2. Dataset
  • 常见的数据集类型:

    1. RandomAccessDataset:
      • RandomAccessDataset 是一种基本的数据集接口,适用于数据可以随机访问的情况,如数组或列表。
      • 它支持批处理(batching)、数据切片(slicing)等操作,适合大多数监督学习任务。
    2. IterableDataset:
      • IterableDataset 适用于数据不能随机访问的情况,如流数据或实时生成的数据。
      • 它通过迭代器(iterator)提供数据,适用于需要动态生成或处理的数据源。
    3. RecordDataset:
      • RecordDataset 是基于记录文件(record file)的数据集格式,常用于大规模数据处理。
      • 它可以高效地加载和处理数据记录,适用于分布式训练和大数据集的处理。

    DJL 的数据集组件提供的功能包括:

    1. 数据加载和预处理:
      • 支持从多种数据源加载数据,如本地文件、远程服务器、数据库等。
      • 提供数据预处理功能,如归一化、数据增强、特征提取等。
    2. 批处理(Batching):
      • 支持将数据分成小批次进行处理,适用于大规模数据集的训练。
      • 提供灵活的批处理策略,可根据需要进行自定义。
    3. 数据变换(Transformations):
      • 提供多种数据变换功能,如图像变换、文本处理、数值处理等。
      • 支持链式调用,将多个变换操作组合在一起,形成数据处理管道。
    4. 数据加载器(DataLoader):
      • DataLoader 负责将数据集打包成批次,并在训练过程中按需提供数据。
      • 支持多线程数据加载,提高数据处理效率。
  • Dataset:定义数据集的抽象类,用户可以继承该类来实现自定义的数据集。

    • import ai.djl.Model;
      import ai.djl.ModelException;
      import ai.djl.inference.Predictor;
      import ai.djl.modality.Classifications;
      import ai.djl.modality.cv.Image;
      import ai.djl.modality.cv.ImageFactory;
      import ai.djl.repository.zoo.Criteria;
      import ai.djl.repository.zoo.ModelZoo;
      import ai.djl.translate.TranslateException;import java.io.IOException;
      import java.nio.file.Paths;public class DjlExample {public static void main(String[] args) throws IOException, ModelException, TranslateException {// 加载模型Criteria<Image, Classifications> criteria = Criteria.builder().optEngine("TensorFlow") // 选择引擎.setTypes(Image.class, Classifications.class).optModelPath(Paths.get("path/to/model")).build();try (Model model = ModelZoo.loadModel(criteria);Predictor<Image, Classifications> predictor = model.newPredictor()) {// 加载图像Image img = ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg"));// 进行推理Classifications result = predictor.predict(img);System.out.println(result);}}
      }
    • import ai.djl.Application;
      import ai.djl.Model;
      import ai.djl.basicdataset.cv.classification.FashionMnist;
      import ai.djl.engine.Engine;
      import ai.djl.metric.Metrics;
      import ai.djl.ndarray.NDArray;
      import ai.djl.ndarray.NDManager;
      import ai.djl.training.DefaultTrainingConfig;
      import ai.djl.training.EasyTrain;
      import ai.djl.training.Trainer;
      import ai.djl.training.dataset.Batch;
      import ai.djl.training.dataset.Dataset;
      import ai.djl.training.listener.TrainingListener;
      import ai.djl.training.loss.Loss;
      import ai.djl.training.optimizer.Optimizer;
      import ai.djl.training.tracker.Tracker;
      import ai.djl.translate.TranslateException;
      import ai.djl.util.Pair;import java.io.IOException;public class DJLDatasetExample {public static void main(String[] args) throws IOException, TranslateException {NDManager manager = NDManager.newBaseManager();FashionMnist fashionMnist = FashionMnist.builder().optUsage(Dataset.Usage.TRAIN).setSampling(32, true) // 32 is the batch size.optLimit(Long.MAX_VALUE) // Use this to limit the number of samples.build();fashionMnist.prepare();Model model = Model.newInstance("fashion-mnist-model");TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()).optOptimizer(Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build()).addTrainingListeners(TrainingListener.Defaults.logging());try (Trainer trainer = model.newTrainer(config)) {trainer.initialize(new long[]{1, 28, 28}); // Example shape for image dataMetrics metrics = new Metrics();trainer.setMetrics(metrics);for (Batch batch : trainer.iterateDataset(fashionMnist)) {EasyTrain.trainBatch(trainer, batch);trainer.step();batch.close();}trainer.notifyListeners(listener -> listener.onTrainingEnd(trainer));}}
      }

3. Engine 和 NDManager
  • Engine:DJL支持多个深度学习引擎,如MXNet、PyTorch、ONNX、TensorFlow,Engine接口提供统一的抽象,方便切换底层引擎。

  • NDManager:管理NDArray,用于处理多维数组,封装了底层的数组操作。

    Using DJL Engine
    
    import ai.djl.Model
    import ai.djl.ModelException
    import ai.djl.ndarray.NDArray
    import ai.djl.ndarray.NDList
    import ai.djl.ndarray.types.Shape
    import ai.djl.translate.Batchifier
    import ai.djl.translate.TranslateException
    import ai.djl.translate.Translator
    import ai.djl.translate.TranslatorContext
    import java.io.IOException
    import java.nio.file.Pathsobject DJLEngineExample {@Throws(ModelException::class, TranslateException::class, IOException::class)@JvmStaticfun main(args: Array<String>) {// Initialize the modelval model = Model.newInstance("model-name", "ai.djl.pytorch") // Assuming "model-name" is valid and using PyTorch engine// Load a pre-trained modelmodel.load(Paths.get("path/to/your/model")) // Ensure the path is correct// Define a translator for data preprocessing and postprocessingval translator: Translator<Array<Float>, Float> = object : Translator<Array<Float>, Float> {override fun processInput(ctx: TranslatorContext, input: Array<Float>): NDList {val manager = ctx.ndManagerval array: NDArray = manager.create(input.toFloatArray()).reshape(Shape(1, input.size.toLong())) // Reshape might be necessaryreturn NDList(array)}override fun processOutput(ctx: TranslatorContext, list: NDList): Float {// Assuming the output is a single scalar valuereturn list[0].getFloat() // Use getFloat() to get the scalar value}override fun getBatchifier(): Batchifier? {return null // Or implement batching if needed}}model.newPredictor(translator).use { predictor ->val input = arrayOf(1.0f, 2.0f, 3.0f) // Input should match the model's expected input shapeval output = predictor.predict(input)println("Prediction: $output")}}
    }
    Overview of NDManager
    Key Features of NDManager:
    1. Memory Management: Automates the process of memory allocation and deallocation for NDArrays.
    2. Resource Scope: NDArrays created by an NDManager are tied to the lifecycle of that manager. When the manager is closed, all associated NDArrays are also released.
    3. Hierarchical Structure: NDManagers can create child managers, which can further manage their own NDArrays. This is useful for managing resources in complex workflows.
    Using NDManager
    
    import ai.djl.ndarray.NDManagerobject NDManagerExample {@JvmStaticfun main(args: Array<String>) {NDManager.newBaseManager().use { manager ->val array = manager.create(floatArrayOf(1.0f, 2.0f, 3.0f))println("Array: $array")// Perform operationsval result = array.add(2.0f)println("Result: $result")}// No need to explicitly free the memory, it's handled by the NDManager}
    }
    
4. Trainer 和 Predictor
  • Trainer 类

    提供训练模型的接口,包含优化器、损失函数和训练循环等功能。用于训练深度学习模型。它封装了训练过程中的一些常见操作,如前向传播、反向传播和参数更新。

    主要功能包括:

    • 模型的训练和验证
    • 管理优化器和损失函数
    • 提供易于使用的训练循环
    代码演示

    以下是使用 DJL 的 Trainer 类训练一个简单神经网络的示例代码:

    
    import ai.djl.Model
    import ai.djl.basicdataset.cv.classification.FashionMnist
    import ai.djl.basicmodelzoo.basic.Mlp
    import ai.djl.ndarray.types.Shape
    import ai.djl.training.DefaultTrainingConfig
    import ai.djl.training.TrainingConfig
    import ai.djl.training.dataset.Dataset
    import ai.djl.training.dataset.RandomAccessDataset
    import ai.djl.training.listener.LoggingTrainingListener
    import ai.djl.training.listener.TrainingListener
    import ai.djl.training.loss.Loss
    import ai.djl.training.optimizer.Optimizer
    import ai.djl.training.tracker.FixedPerVarTracker
    import ai.djl.training.util.ProgressBar
    import ai.djl.translate.TranslateException
    import java.io.IOException
    import java.nio.file.Pathsobject DjlTrainerDemo {@Throws(IOException::class, TranslateException::class)@JvmStaticfun main(args: Array<String>) {// Load datasetval trainDataset: RandomAccessDataset =FashionMnist.builder().optUsage(Dataset.Usage.TRAIN).setSampling(32, true).build()trainDataset.prepare(ProgressBar())// Define modelval model = Model.newInstance("mlp")model.block = Mlp(28 * 28, 10, intArrayOf(128, 64))// Define training configurationval config: TrainingConfig = DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()).optOptimizer(Optimizer.sgd().setLearningRateTracker(FixedPerVarTracker.builder().setDefaultValue(0.01f).build()).build()).addTrainingListeners(LoggingTrainingListener())model.newTrainer(config).use { trainer ->trainer.initialize(Shape(1, (28 * 28).toLong()))for (epoch in 0..9) {for (batch in trainer.iterateDataset(trainDataset)) {trainer.step()batch.close()}trainer.notifyListeners { listener: TrainingListener ->listener.onEpoch(trainer)}}model.save(Paths.get("model"), "mlp")}}
    }
    Predictor 类

    用于模型推理,接收输入数据并返回预测结果。用于对训练好的模型进行推理。它提供了一个简单的接口,用于将输入数据传递给模型并获取预测结果。

    主要功能包括:

    • 加载模型进行推理
    • 处理输入和输出数据的转换
    代码演示
    
    import ai.djl.Model
    import ai.djl.modality.Classifications
    import ai.djl.ndarray.NDArray
    import ai.djl.ndarray.NDList
    import ai.djl.ndarray.NDManager
    import ai.djl.ndarray.types.Shape
    import ai.djl.translate.Batchifier
    import ai.djl.translate.TranslateException
    import ai.djl.translate.Translator
    import ai.djl.translate.TranslatorContext
    import java.io.IOException
    import java.nio.file.Pathsobject DjlPredictorDemo {@Throws(IOException::class, TranslateException::class)@JvmStaticfun main(args: Array<String>) {// Load modelval model = Model.newInstance("mlp")model.load(Paths.get("model"), "mlp")// Define Translatorval translator: Translator<NDArray, Classifications> = object : Translator<NDArray, Classifications> {override fun processInput(ctx: TranslatorContext, input: NDArray): NDList {return NDList(input.reshape(Shape(1, (28 * 28).toLong())))}override fun processOutput(ctx: TranslatorContext, list: NDList): Classifications {// Assuming the output NDArray is the first element in NDListval probabilities = list.singletonOrThrow()return Classifications(listOf("Label1", "Label2"), probabilities) // Example labels}override fun getBatchifier(): Batchifier {return Batchifier.STACK}}model.newPredictor(translator).use { predictor ->val manager = NDManager.newBaseManager()val array = manager.ones(Shape(1, (28 * 28).toLong()))val classifications = predictor.predict(array)println(classifications)}}
    }

这篇关于java(kotlin) ai框架djl的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1056911

相关文章

Java 实用工具类Spring 的 AnnotationUtils详解

《Java实用工具类Spring的AnnotationUtils详解》Spring框架提供了一个强大的注解工具类org.springframework.core.annotation.Annot... 目录前言一、AnnotationUtils 的常用方法二、常见应用场景三、与 JDK 原生注解 API 的

Java controller接口出入参时间序列化转换操作方法(两种)

《Javacontroller接口出入参时间序列化转换操作方法(两种)》:本文主要介绍Javacontroller接口出入参时间序列化转换操作方法,本文给大家列举两种简单方法,感兴趣的朋友一起看... 目录方式一、使用注解方式二、统一配置场景:在controller编写的接口,在前后端交互过程中一般都会涉及

Java中的StringBuilder之如何高效构建字符串

《Java中的StringBuilder之如何高效构建字符串》本文将深入浅出地介绍StringBuilder的使用方法、性能优势以及相关字符串处理技术,结合代码示例帮助读者更好地理解和应用,希望对大家... 目录关键点什么是 StringBuilder?为什么需要 StringBuilder?如何使用 St

使用Java将各种数据写入Excel表格的操作示例

《使用Java将各种数据写入Excel表格的操作示例》在数据处理与管理领域,Excel凭借其强大的功能和广泛的应用,成为了数据存储与展示的重要工具,在Java开发过程中,常常需要将不同类型的数据,本文... 目录前言安装免费Java库1. 写入文本、或数值到 Excel单元格2. 写入数组到 Excel表格

Java并发编程之如何优雅关闭钩子Shutdown Hook

《Java并发编程之如何优雅关闭钩子ShutdownHook》这篇文章主要为大家详细介绍了Java如何实现优雅关闭钩子ShutdownHook,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起... 目录关闭钩子简介关闭钩子应用场景数据库连接实战演示使用关闭钩子的注意事项开源框架中的关闭钩子机制1.

Maven中引入 springboot 相关依赖的方式(最新推荐)

《Maven中引入springboot相关依赖的方式(最新推荐)》:本文主要介绍Maven中引入springboot相关依赖的方式(最新推荐),本文给大家介绍的非常详细,对大家的学习或工作具有... 目录Maven中引入 springboot 相关依赖的方式1. 不使用版本管理(不推荐)2、使用版本管理(推

Java 中的 @SneakyThrows 注解使用方法(简化异常处理的利与弊)

《Java中的@SneakyThrows注解使用方法(简化异常处理的利与弊)》为了简化异常处理,Lombok提供了一个强大的注解@SneakyThrows,本文将详细介绍@SneakyThro... 目录1. @SneakyThrows 简介 1.1 什么是 Lombok?2. @SneakyThrows

在 Spring Boot 中实现异常处理最佳实践

《在SpringBoot中实现异常处理最佳实践》本文介绍如何在SpringBoot中实现异常处理,涵盖核心概念、实现方法、与先前查询的集成、性能分析、常见问题和最佳实践,感兴趣的朋友一起看看吧... 目录一、Spring Boot 异常处理的背景与核心概念1.1 为什么需要异常处理?1.2 Spring B

如何在 Spring Boot 中实现 FreeMarker 模板

《如何在SpringBoot中实现FreeMarker模板》FreeMarker是一种功能强大、轻量级的模板引擎,用于在Java应用中生成动态文本输出(如HTML、XML、邮件内容等),本文... 目录什么是 FreeMarker 模板?在 Spring Boot 中实现 FreeMarker 模板1. 环

SpringMVC 通过ajax 前后端数据交互的实现方法

《SpringMVC通过ajax前后端数据交互的实现方法》:本文主要介绍SpringMVC通过ajax前后端数据交互的实现方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价... 在前端的开发过程中,经常在html页面通过AJAX进行前后端数据的交互,SpringMVC的controll