tensorflow recommenders 系列2:召回模型介绍

2024-01-29 18:58

本文主要是介绍tensorflow recommenders 系列2:召回模型介绍,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

基础教程部分,参考:Recommending movies: retrieval  |  TensorFlow Recommenders

在召回模型训练阶段涉及到

task = tfrs.tasks.Retrieval(metrics=tfrs.metrics.FactorizedTopK(movies.batch(5).map(movie_model),k=3
)

其作用有两个,一个是返回定义召回效果评估度量标准FactorizedTopK,另一个是定义了损失函数的计算方法,并对每一个批次返回对应的损失计算结果。

首先介绍损失函数部分Retrieval对应的call方法(Retrieval方法只适合双塔模型),

  def call(self,query_embeddings: tf.Tensor,candidate_embeddings: tf.Tensor,sample_weight: Optional[tf.Tensor] = None,candidate_sampling_probability: Optional[tf.Tensor] = None,candidate_ids: Optional[tf.Tensor] = None,compute_metrics: bool = True) -> tf.Tensor:"""Computes the task loss and metrics.The main argument are pairs of query and candidate embeddings: the first rowof query_embeddings denotes a query for which the candidate from the firstrow of candidate embeddings was selected by the user.The task will try to maximize the affinity of these query, candidate pairswhile minimizing the affinity between the query and candidates belongingto other queries in the batch.Args:query_embeddings: [num_queries, embedding_dim] tensor of queryrepresentations.candidate_embeddings: [num_queries, embedding_dim] tensor of candidaterepresentations.sample_weight: [num_queries] tensor of sample weights.candidate_sampling_probability: Optional tensor of candidate samplingprobabilities. When given will be be used to correct the logits toreflect the sampling probability of negative candidates.candidate_ids: Optional tensor containing candidate ids. When givenenables removing accidental hits of examples used as negatives. Anaccidental hit is defined as an candidate that is used as an in-batchnegative but has the same id with the positive candidate.compute_metrics: Whether to compute metrics. Set this to Falseduring training for faster training.Returns:loss: Tensor of loss values."""#此处将两个向量的乘积结果作为用户和实体向量的相似度值,参考https://www.cnblogs.com/daniel-D/p/3244718.htmlscores = tf.linalg.matmul(query_embeddings, candidate_embeddings, transpose_b=True)num_queries = tf.shape(scores)[0]num_candidates = tf.shape(scores)[1]#根据结果构造对角函数,作为预期结果矩阵labels = tf.eye(num_queries, num_candidates)metric_update_ops = []if compute_metrics:if self._factorized_metrics:metric_update_ops.append(self._factorized_metrics.update_state(query_embeddings,candidate_embeddings))if self._batch_metrics:metric_update_ops.extend([batch_metric.update_state(labels, scores)for batch_metric in self._batch_metrics])if self._temperature is not None:#scores = scores / self._temperatureif candidate_sampling_probability is not None:scores = layers.loss.SamplingProbablityCorrection()(scores, candidate_sampling_probability)if candidate_ids is not None:scores = layers.loss.RemoveAccidentalHits()(labels, scores, candidate_ids)if self._num_hard_negatives is not None:scores, labels = layers.loss.HardNegativeMining(self._num_hard_negatives)(scores,labels)loss = self._loss(y_true=labels, y_pred=scores, sample_weight=sample_weight)if not metric_update_ops:return losswith tf.control_dependencies(metric_update_ops):return tf.identity(loss)
  • 关于score和labels的计算已经在代码中进行了注解
  • 关于温度淬火法参数的解释self._temperature参考:深度学习中的temperature parameter是什么 - 知乎
  • 关于损失函数,默认的损失函数为类别交叉熵损失函数(对应ont-hot表示法),如果是integer表示法,可以换为
    SparseCategoricalCrossentropy
tf.keras.losses.CategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.SUM)

 对应的调用样例参考:

>> y_true = [[0, 1, 0], [0, 0, 1]]>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]>>> # Using 'auto'/'sum_over_batch_size' reduction type.>>> cce = tf.keras.losses.CategoricalCrossentropy()>>> cce(y_true, y_pred).numpy()1.177>>> # Calling with 'sample_weight'.>>> cce(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()0.814>>> # Using 'sum' reduction type.>>> cce = tf.keras.losses.CategoricalCrossentropy(...     reduction=tf.keras.losses.Reduction.SUM)>>> cce(y_true, y_pred).numpy()2.354

关于度量标准部分 ,此处重点理解TopKCategoricalAccuracy

查看源码解释既可明白,其是对每一个用户的topk推荐结果(是数值topk而非位置)是否包含了预期的结果,统计所有用户的对应情况占比。

  >>> m = tf.keras.metrics.TopKCategoricalAccuracy(k=1)>>> m.update_state([[0, 0, 1], [0, 1, 0]],...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])>>> m.result().numpy()
#1/20.5>>> m.reset_state()>>> m.update_state([[0, 0, 1], [0, 1, 0]],...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],...                sample_weight=[0.7, 0.3])>>> m.result().numpy()
#0*0.7+1*0.30.3

在召回模型使用阶段,

 如果候选实体比较少的时候,可以使用暴力求解法:

index = tfrs.layers.factorized_top_k.BruteForce(model.user_model)
index.index_from_dataset(movies.batch(100).map(lambda title: (title, model.movie_model(title))))# Get some recommendations.
_, titles = index(np.array(["42"]))
print(f"Top 3 recommendations for user 42: {titles[0, :3]}")

如果候选实体比较大的时候,可以使用近似ScaNN方法,tfrs.layers.factorized_top_k.ScaNN。关于ScaNN部分,也可以通过类似milvas等向量数据库的查询方法得以实现

这篇关于tensorflow recommenders 系列2:召回模型介绍的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/657876

相关文章

C#使用StackExchange.Redis实现分布式锁的两种方式介绍

《C#使用StackExchange.Redis实现分布式锁的两种方式介绍》分布式锁在集群的架构中发挥着重要的作用,:本文主要介绍C#使用StackExchange.Redis实现分布式锁的... 目录自定义分布式锁获取锁释放锁自动续期StackExchange.Redis分布式锁获取锁释放锁自动续期分布式

redis过期key的删除策略介绍

《redis过期key的删除策略介绍》:本文主要介绍redis过期key的删除策略,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录第一种策略:被动删除第二种策略:定期删除第三种策略:强制删除关于big key的清理UNLINK命令FLUSHALL/FLUSHDB命

Pytest多环境切换的常见方法介绍

《Pytest多环境切换的常见方法介绍》Pytest作为自动化测试的主力框架,如何实现本地、测试、预发、生产环境的灵活切换,本文总结了通过pytest框架实现自由环境切换的几种方法,大家可以根据需要进... 目录1.pytest-base-url2.hooks函数3.yml和fixture结论你是否也遇到过

Spring Security基于数据库的ABAC属性权限模型实战开发教程

《SpringSecurity基于数据库的ABAC属性权限模型实战开发教程》:本文主要介绍SpringSecurity基于数据库的ABAC属性权限模型实战开发教程,本文给大家介绍的非常详细,对大... 目录1. 前言2. 权限决策依据RBACABAC综合对比3. 数据库表结构说明4. 实战开始5. MyBA

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应

MySQL中慢SQL优化的不同方式介绍

《MySQL中慢SQL优化的不同方式介绍》慢SQL的优化,主要从两个方面考虑,SQL语句本身的优化,以及数据库设计的优化,下面小编就来给大家介绍一下有哪些方式可以优化慢SQL吧... 目录避免不必要的列分页优化索引优化JOIN 的优化排序优化UNION 优化慢 SQL 的优化,主要从两个方面考虑,SQL 语

C++中函数模板与类模板的简单使用及区别介绍

《C++中函数模板与类模板的简单使用及区别介绍》这篇文章介绍了C++中的模板机制,包括函数模板和类模板的概念、语法和实际应用,函数模板通过类型参数实现泛型操作,而类模板允许创建可处理多种数据类型的类,... 目录一、函数模板定义语法真实示例二、类模板三、关键区别四、注意事项 ‌在C++中,模板是实现泛型编程

Python实现html转png的完美方案介绍

《Python实现html转png的完美方案介绍》这篇文章主要为大家详细介绍了如何使用Python实现html转png功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 1.增强稳定性与错误处理建议使用三层异常捕获结构:try: with sync_playwright(

Java使用多线程处理未知任务数的方案介绍

《Java使用多线程处理未知任务数的方案介绍》这篇文章主要为大家详细介绍了Java如何使用多线程实现处理未知任务数,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 知道任务个数,你可以定义好线程数规则,生成线程数去跑代码说明:1.虚拟线程池:使用 Executors.newVir