tensorflow recommenders 系列2:召回模型介绍

2024-01-29 18:58

本文主要是介绍tensorflow recommenders 系列2:召回模型介绍,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

基础教程部分,参考:Recommending movies: retrieval  |  TensorFlow Recommenders

在召回模型训练阶段涉及到

task = tfrs.tasks.Retrieval(metrics=tfrs.metrics.FactorizedTopK(movies.batch(5).map(movie_model),k=3
)

其作用有两个,一个是返回定义召回效果评估度量标准FactorizedTopK,另一个是定义了损失函数的计算方法,并对每一个批次返回对应的损失计算结果。

首先介绍损失函数部分Retrieval对应的call方法(Retrieval方法只适合双塔模型),

  def call(self,query_embeddings: tf.Tensor,candidate_embeddings: tf.Tensor,sample_weight: Optional[tf.Tensor] = None,candidate_sampling_probability: Optional[tf.Tensor] = None,candidate_ids: Optional[tf.Tensor] = None,compute_metrics: bool = True) -> tf.Tensor:"""Computes the task loss and metrics.The main argument are pairs of query and candidate embeddings: the first rowof query_embeddings denotes a query for which the candidate from the firstrow of candidate embeddings was selected by the user.The task will try to maximize the affinity of these query, candidate pairswhile minimizing the affinity between the query and candidates belongingto other queries in the batch.Args:query_embeddings: [num_queries, embedding_dim] tensor of queryrepresentations.candidate_embeddings: [num_queries, embedding_dim] tensor of candidaterepresentations.sample_weight: [num_queries] tensor of sample weights.candidate_sampling_probability: Optional tensor of candidate samplingprobabilities. When given will be be used to correct the logits toreflect the sampling probability of negative candidates.candidate_ids: Optional tensor containing candidate ids. When givenenables removing accidental hits of examples used as negatives. Anaccidental hit is defined as an candidate that is used as an in-batchnegative but has the same id with the positive candidate.compute_metrics: Whether to compute metrics. Set this to Falseduring training for faster training.Returns:loss: Tensor of loss values."""#此处将两个向量的乘积结果作为用户和实体向量的相似度值,参考https://www.cnblogs.com/daniel-D/p/3244718.htmlscores = tf.linalg.matmul(query_embeddings, candidate_embeddings, transpose_b=True)num_queries = tf.shape(scores)[0]num_candidates = tf.shape(scores)[1]#根据结果构造对角函数,作为预期结果矩阵labels = tf.eye(num_queries, num_candidates)metric_update_ops = []if compute_metrics:if self._factorized_metrics:metric_update_ops.append(self._factorized_metrics.update_state(query_embeddings,candidate_embeddings))if self._batch_metrics:metric_update_ops.extend([batch_metric.update_state(labels, scores)for batch_metric in self._batch_metrics])if self._temperature is not None:#scores = scores / self._temperatureif candidate_sampling_probability is not None:scores = layers.loss.SamplingProbablityCorrection()(scores, candidate_sampling_probability)if candidate_ids is not None:scores = layers.loss.RemoveAccidentalHits()(labels, scores, candidate_ids)if self._num_hard_negatives is not None:scores, labels = layers.loss.HardNegativeMining(self._num_hard_negatives)(scores,labels)loss = self._loss(y_true=labels, y_pred=scores, sample_weight=sample_weight)if not metric_update_ops:return losswith tf.control_dependencies(metric_update_ops):return tf.identity(loss)
  • 关于score和labels的计算已经在代码中进行了注解
  • 关于温度淬火法参数的解释self._temperature参考:深度学习中的temperature parameter是什么 - 知乎
  • 关于损失函数,默认的损失函数为类别交叉熵损失函数(对应ont-hot表示法),如果是integer表示法,可以换为
    SparseCategoricalCrossentropy
tf.keras.losses.CategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.SUM)

 对应的调用样例参考:

>> y_true = [[0, 1, 0], [0, 0, 1]]>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]>>> # Using 'auto'/'sum_over_batch_size' reduction type.>>> cce = tf.keras.losses.CategoricalCrossentropy()>>> cce(y_true, y_pred).numpy()1.177>>> # Calling with 'sample_weight'.>>> cce(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()0.814>>> # Using 'sum' reduction type.>>> cce = tf.keras.losses.CategoricalCrossentropy(...     reduction=tf.keras.losses.Reduction.SUM)>>> cce(y_true, y_pred).numpy()2.354

关于度量标准部分 ,此处重点理解TopKCategoricalAccuracy

查看源码解释既可明白,其是对每一个用户的topk推荐结果(是数值topk而非位置)是否包含了预期的结果,统计所有用户的对应情况占比。

  >>> m = tf.keras.metrics.TopKCategoricalAccuracy(k=1)>>> m.update_state([[0, 0, 1], [0, 1, 0]],...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])>>> m.result().numpy()
#1/20.5>>> m.reset_state()>>> m.update_state([[0, 0, 1], [0, 1, 0]],...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],...                sample_weight=[0.7, 0.3])>>> m.result().numpy()
#0*0.7+1*0.30.3

在召回模型使用阶段,

 如果候选实体比较少的时候,可以使用暴力求解法:

index = tfrs.layers.factorized_top_k.BruteForce(model.user_model)
index.index_from_dataset(movies.batch(100).map(lambda title: (title, model.movie_model(title))))# Get some recommendations.
_, titles = index(np.array(["42"]))
print(f"Top 3 recommendations for user 42: {titles[0, :3]}")

如果候选实体比较大的时候,可以使用近似ScaNN方法,tfrs.layers.factorized_top_k.ScaNN。关于ScaNN部分,也可以通过类似milvas等向量数据库的查询方法得以实现

这篇关于tensorflow recommenders 系列2:召回模型介绍的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/657876

相关文章

Python中win32包的安装及常见用途介绍

《Python中win32包的安装及常见用途介绍》在Windows环境下,PythonWin32模块通常随Python安装包一起安装,:本文主要介绍Python中win32包的安装及常见用途的相关... 目录前言主要组件安装方法常见用途1. 操作Windows注册表2. 操作Windows服务3. 窗口操作

c++中的set容器介绍及操作大全

《c++中的set容器介绍及操作大全》:本文主要介绍c++中的set容器介绍及操作大全,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录​​一、核心特性​​️ ​​二、基本操作​​​​1. 初始化与赋值​​​​2. 增删查操作​​​​3. 遍历方

Python中Tensorflow无法调用GPU问题的解决方法

《Python中Tensorflow无法调用GPU问题的解决方法》文章详解如何解决TensorFlow在Windows无法识别GPU的问题,需降级至2.10版本,安装匹配CUDA11.2和cuDNN... 当用以下代码查看GPU数量时,gpuspython返回的是一个空列表,说明tensorflow没有找到

HTML img标签和超链接标签详细介绍

《HTMLimg标签和超链接标签详细介绍》:本文主要介绍了HTML中img标签的使用,包括src属性(指定图片路径)、相对/绝对路径区别、alt替代文本、title提示、宽高控制及边框设置等,详细内容请阅读本文,希望能对你有所帮助... 目录img 标签src 属性alt 属性title 属性width/h

MybatisPlus service接口功能介绍

《MybatisPlusservice接口功能介绍》:本文主要介绍MybatisPlusservice接口功能介绍,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友... 目录Service接口基本用法进阶用法总结:Lambda方法Service接口基本用法MyBATisP

MySQL复杂SQL之多表联查/子查询详细介绍(最新整理)

《MySQL复杂SQL之多表联查/子查询详细介绍(最新整理)》掌握多表联查(INNERJOIN,LEFTJOIN,RIGHTJOIN,FULLJOIN)和子查询(标量、列、行、表子查询、相关/非相关、... 目录第一部分:多表联查 (JOIN Operations)1. 连接的类型 (JOIN Types)

java中BigDecimal里面的subtract函数介绍及实现方法

《java中BigDecimal里面的subtract函数介绍及实现方法》在Java中实现减法操作需要根据数据类型选择不同方法,主要分为数值型减法和字符串减法两种场景,本文给大家介绍java中BigD... 目录Java中BigDecimal里面的subtract函数的意思?一、数值型减法(高精度计算)1.

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

详解如何使用Python从零开始构建文本统计模型

《详解如何使用Python从零开始构建文本统计模型》在自然语言处理领域,词汇表构建是文本预处理的关键环节,本文通过Python代码实践,演示如何从原始文本中提取多尺度特征,并通过动态调整机制构建更精确... 目录一、项目背景与核心思想二、核心代码解析1. 数据加载与预处理2. 多尺度字符统计3. 统计结果可

Java实现本地缓存的常用方案介绍

《Java实现本地缓存的常用方案介绍》本地缓存的代表技术主要有HashMap,GuavaCache,Caffeine和Encahche,这篇文章主要来和大家聊聊java利用这些技术分别实现本地缓存的方... 目录本地缓存实现方式HashMapConcurrentHashMapGuava CacheCaffe