tensorflow recommenders 系列2:召回模型介绍

2024-01-29 18:58

本文主要是介绍tensorflow recommenders 系列2:召回模型介绍,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

基础教程部分,参考:Recommending movies: retrieval  |  TensorFlow Recommenders

在召回模型训练阶段涉及到

task = tfrs.tasks.Retrieval(metrics=tfrs.metrics.FactorizedTopK(movies.batch(5).map(movie_model),k=3
)

其作用有两个,一个是返回定义召回效果评估度量标准FactorizedTopK,另一个是定义了损失函数的计算方法,并对每一个批次返回对应的损失计算结果。

首先介绍损失函数部分Retrieval对应的call方法(Retrieval方法只适合双塔模型),

  def call(self,query_embeddings: tf.Tensor,candidate_embeddings: tf.Tensor,sample_weight: Optional[tf.Tensor] = None,candidate_sampling_probability: Optional[tf.Tensor] = None,candidate_ids: Optional[tf.Tensor] = None,compute_metrics: bool = True) -> tf.Tensor:"""Computes the task loss and metrics.The main argument are pairs of query and candidate embeddings: the first rowof query_embeddings denotes a query for which the candidate from the firstrow of candidate embeddings was selected by the user.The task will try to maximize the affinity of these query, candidate pairswhile minimizing the affinity between the query and candidates belongingto other queries in the batch.Args:query_embeddings: [num_queries, embedding_dim] tensor of queryrepresentations.candidate_embeddings: [num_queries, embedding_dim] tensor of candidaterepresentations.sample_weight: [num_queries] tensor of sample weights.candidate_sampling_probability: Optional tensor of candidate samplingprobabilities. When given will be be used to correct the logits toreflect the sampling probability of negative candidates.candidate_ids: Optional tensor containing candidate ids. When givenenables removing accidental hits of examples used as negatives. Anaccidental hit is defined as an candidate that is used as an in-batchnegative but has the same id with the positive candidate.compute_metrics: Whether to compute metrics. Set this to Falseduring training for faster training.Returns:loss: Tensor of loss values."""#此处将两个向量的乘积结果作为用户和实体向量的相似度值,参考https://www.cnblogs.com/daniel-D/p/3244718.htmlscores = tf.linalg.matmul(query_embeddings, candidate_embeddings, transpose_b=True)num_queries = tf.shape(scores)[0]num_candidates = tf.shape(scores)[1]#根据结果构造对角函数,作为预期结果矩阵labels = tf.eye(num_queries, num_candidates)metric_update_ops = []if compute_metrics:if self._factorized_metrics:metric_update_ops.append(self._factorized_metrics.update_state(query_embeddings,candidate_embeddings))if self._batch_metrics:metric_update_ops.extend([batch_metric.update_state(labels, scores)for batch_metric in self._batch_metrics])if self._temperature is not None:#scores = scores / self._temperatureif candidate_sampling_probability is not None:scores = layers.loss.SamplingProbablityCorrection()(scores, candidate_sampling_probability)if candidate_ids is not None:scores = layers.loss.RemoveAccidentalHits()(labels, scores, candidate_ids)if self._num_hard_negatives is not None:scores, labels = layers.loss.HardNegativeMining(self._num_hard_negatives)(scores,labels)loss = self._loss(y_true=labels, y_pred=scores, sample_weight=sample_weight)if not metric_update_ops:return losswith tf.control_dependencies(metric_update_ops):return tf.identity(loss)
  • 关于score和labels的计算已经在代码中进行了注解
  • 关于温度淬火法参数的解释self._temperature参考:深度学习中的temperature parameter是什么 - 知乎
  • 关于损失函数,默认的损失函数为类别交叉熵损失函数(对应ont-hot表示法),如果是integer表示法,可以换为
    SparseCategoricalCrossentropy
tf.keras.losses.CategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.SUM)

 对应的调用样例参考:

>> y_true = [[0, 1, 0], [0, 0, 1]]>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]>>> # Using 'auto'/'sum_over_batch_size' reduction type.>>> cce = tf.keras.losses.CategoricalCrossentropy()>>> cce(y_true, y_pred).numpy()1.177>>> # Calling with 'sample_weight'.>>> cce(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()0.814>>> # Using 'sum' reduction type.>>> cce = tf.keras.losses.CategoricalCrossentropy(...     reduction=tf.keras.losses.Reduction.SUM)>>> cce(y_true, y_pred).numpy()2.354

关于度量标准部分 ,此处重点理解TopKCategoricalAccuracy

查看源码解释既可明白,其是对每一个用户的topk推荐结果(是数值topk而非位置)是否包含了预期的结果,统计所有用户的对应情况占比。

  >>> m = tf.keras.metrics.TopKCategoricalAccuracy(k=1)>>> m.update_state([[0, 0, 1], [0, 1, 0]],...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])>>> m.result().numpy()
#1/20.5>>> m.reset_state()>>> m.update_state([[0, 0, 1], [0, 1, 0]],...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],...                sample_weight=[0.7, 0.3])>>> m.result().numpy()
#0*0.7+1*0.30.3

在召回模型使用阶段,

 如果候选实体比较少的时候,可以使用暴力求解法:

index = tfrs.layers.factorized_top_k.BruteForce(model.user_model)
index.index_from_dataset(movies.batch(100).map(lambda title: (title, model.movie_model(title))))# Get some recommendations.
_, titles = index(np.array(["42"]))
print(f"Top 3 recommendations for user 42: {titles[0, :3]}")

如果候选实体比较大的时候,可以使用近似ScaNN方法,tfrs.layers.factorized_top_k.ScaNN。关于ScaNN部分,也可以通过类似milvas等向量数据库的查询方法得以实现

这篇关于tensorflow recommenders 系列2:召回模型介绍的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/657876

相关文章

Java中HashMap的用法详细介绍

《Java中HashMap的用法详细介绍》JavaHashMap是一种高效的数据结构,用于存储键值对,它是基于哈希表实现的,提供快速的插入、删除和查找操作,:本文主要介绍Java中HashMap... 目录一.HashMap1.基本概念2.底层数据结构:3.HashCode和equals方法为什么重写Has

Springboot项目构建时各种依赖详细介绍与依赖关系说明详解

《Springboot项目构建时各种依赖详细介绍与依赖关系说明详解》SpringBoot通过spring-boot-dependencies统一依赖版本管理,spring-boot-starter-w... 目录一、spring-boot-dependencies1.简介2. 内容概览3.核心内容结构4.

setsid 命令工作原理和使用案例介绍

《setsid命令工作原理和使用案例介绍》setsid命令在Linux中创建独立会话,使进程脱离终端运行,适用于守护进程和后台任务,通过重定向输出和确保权限,可有效管理长时间运行的进程,本文给大家介... 目录setsid 命令介绍和使用案例基本介绍基本语法主要特点命令参数使用案例1. 在后台运行命令2.

MySQL常用字符串函数示例和场景介绍

《MySQL常用字符串函数示例和场景介绍》MySQL提供了丰富的字符串函数帮助我们高效地对字符串进行处理、转换和分析,本文我将全面且深入地介绍MySQL常用的字符串函数,并结合具体示例和场景,帮你熟练... 目录一、字符串函数概述1.1 字符串函数的作用1.2 字符串函数分类二、字符串长度与统计函数2.1

zookeeper端口说明及介绍

《zookeeper端口说明及介绍》:本文主要介绍zookeeper端口说明,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、zookeeper有三个端口(可以修改)aVNMqvZ二、3个端口的作用三、部署时注意总China编程结一、zookeeper有三个端口(可以

Python中win32包的安装及常见用途介绍

《Python中win32包的安装及常见用途介绍》在Windows环境下,PythonWin32模块通常随Python安装包一起安装,:本文主要介绍Python中win32包的安装及常见用途的相关... 目录前言主要组件安装方法常见用途1. 操作Windows注册表2. 操作Windows服务3. 窗口操作

c++中的set容器介绍及操作大全

《c++中的set容器介绍及操作大全》:本文主要介绍c++中的set容器介绍及操作大全,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录​​一、核心特性​​️ ​​二、基本操作​​​​1. 初始化与赋值​​​​2. 增删查操作​​​​3. 遍历方

Python中Tensorflow无法调用GPU问题的解决方法

《Python中Tensorflow无法调用GPU问题的解决方法》文章详解如何解决TensorFlow在Windows无法识别GPU的问题,需降级至2.10版本,安装匹配CUDA11.2和cuDNN... 当用以下代码查看GPU数量时,gpuspython返回的是一个空列表,说明tensorflow没有找到

HTML img标签和超链接标签详细介绍

《HTMLimg标签和超链接标签详细介绍》:本文主要介绍了HTML中img标签的使用,包括src属性(指定图片路径)、相对/绝对路径区别、alt替代文本、title提示、宽高控制及边框设置等,详细内容请阅读本文,希望能对你有所帮助... 目录img 标签src 属性alt 属性title 属性width/h

MybatisPlus service接口功能介绍

《MybatisPlusservice接口功能介绍》:本文主要介绍MybatisPlusservice接口功能介绍,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友... 目录Service接口基本用法进阶用法总结:Lambda方法Service接口基本用法MyBATisP