fasttext源码学习(2)--模型压缩

2024-06-04 17:38

本文主要是介绍fasttext源码学习(2)--模型压缩,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

fasttext源码学习(2)–模型压缩

前言

fasttext模型压缩的很明显,精度却降低不多,其网站上提供的语种识别模型,压缩前后的对比就是例证,压缩前126M,压缩后917K。太震惊了,必须学习一下。看文档介绍用到权重量化(weight quantization)和特征选择(feature selection),下面结合代码学习下。

说明:文章中代码皆为简化版,为突出重点,简化了逻辑,原版代码需到官方网页下载。

一 特征选择

一开始以为fasttext会用到比较复杂的特征选择算法,直到看到代码才差点闪了腰。。。fasttext用的就是kbest,剩下的全砍掉,就是这么简单直接。

void FastText::quantize(const Args& qargs, const TrainCallback& callback) {if (qargs.cutoff > 0 && qargs.cutoff < input->size(0)) {auto idx = selectEmbeddings(qargs.cutoff);dict_->prune(idx);  // 剪枝(词典重新计算)if (qargs.retrain) { // 重新训练startThreads(callback);}}
}std::vector<int32_t> FastText::selectEmbeddings(int32_t cutoff) const {std::shared_ptr<DenseMatrix> input =std::dynamic_pointer_cast<DenseMatrix>(input_);Vector norms(input->size(0));input->l2NormRow(norms); // [1] 正则化std::vector<int32_t> idx(input->size(0), 0);std::iota(idx.begin(), idx.end(), 0);std::sort(idx.begin(), idx.end(), [&norms, eosid](size_t i1, size_t i2) {return (eosid != i2 && norms[i1] > norms[i2]);}); // [2] 按正则化值排序idx.erase(idx.begin() + cutoff, idx.end()); // [3] 保留指定数目的id return idx;
}

从上面代码可以看出,selectEmbeddings主要干的就是按正则化值排序,然后只保留指定数目的行。默认保留是5万,结合上一章的DensMatrix的行数是200万+词语个数,可以看出这一步的特征选择最少能压缩到原来的1/40,压缩比很可观。

值得注意的是dict_->prune(idx),这一步必不可少,因为fasttext是基于hash映射来计算矩阵下标的,特征被筛选后,相应的词典也需要清理压缩,对应关系也需要刷新。

void Dictionary::prune(std::vector<int32_t>& idx) {std::vector<int32_t> words, ngrams;// 按id选取对应的word、ngramfor (auto it = idx.cbegin(); it != idx.cend(); ++it) {if (*it < nwords_) {words.push_back(*it);} else {ngrams.push_back(*it);}}// 按id排序std::sort(words.begin(), words.end());idx = words;// 计算被筛选的ngram的hash与id的对应关系// ngram的对应关系原本不需存储,筛选后,由于对应关系的变化,导致需要存储pruneidx_if (ngrams.size() != 0) {int32_t j = 0;for (const auto ngram : ngrams) {pruneidx_[ngram - nwords_] = j;j++;}idx.insert(idx.end(), ngrams.begin(), ngrams.end());}pruneidx_size_ = pruneidx_.size();int32_t j = 0;// 筛选过的word往前移,为后续的清除做准备// 同时,重新计算hash与id的对应关系for (int32_t i = 0; i < words_.size(); i++) {if (getType(i) == entry_type::label ||(j < words.size() && words[j] == i)) {words_[j] = words_[i];word2int_[find(words_[j].word)] = j;j++;}}nwords_ = words.size();size_ = nwords_ + nlabels_;// 移除多余的wordwords_.erase(words_.begin() + size_, words_.end());// 重新初始化各word的ngrams,重新计算ngram的下标initNgrams();
}

要理解这段代码,必须先理解fasttext的数据存储即Dictionary和DenseMatrix那一部分,否则会非常晕。

二 权重量化

量化一般是将大的数值表示变为小的数值表示,比如从float变为byte。而fasttext采用了另外一种方法,product quantization。简单来说,就是将向量分割为更小的子向量,再使用kmeans算法,将子向量映射到中心点下标。这样, 假设子向量长度为2,则n*8的float类型的矩阵,被映射为n*4的byte矩阵,模型可以减小到原来的1/8.

void ProductQuantizer::train(int32_t n, const real* x) {std::vector<int32_t> perm(n, 0);std::iota(perm.begin(), perm.end(), 0);auto d = dsub_;auto np = std::min(n, max_points_);auto xslice = std::vector<real>(np * dsub_);// 划分为nsubq_个子向量,子向量长度为d. for (auto m = 0; m < nsubq_; m++) {std::shuffle(perm.begin(), perm.end(), rng);// 随机选取np行数据,每行长度为dfor (auto j = 0; j < np; j++) {memcpy(xslice.data() + j * d, x + perm[j] * dim_ + m * dsub_,			               d*sizeof(real));}// kmeans计算该子向量对应的中心点kmeans(xslice.data(), get_centroids(m, 0), np, d);}
}
void ProductQuantizer::kmeans(const real* x, real* c, int32_t n, int32_t d) {std::vector<int32_t> perm(n, 0);std::iota(perm.begin(), perm.end(), 0);std::shuffle(perm.begin(), perm.end(), rng);// 随机初始化中心点for (auto i = 0; i < ksub_; i++) {memcpy(&c[i * d], x + perm[i] * d, d * sizeof(real));}auto codes = std::vector<uint8_t>(n);// kmeans标准算法,具体可参考另一篇介绍kmeans的文章for (auto i = 0; i < niter_; i++) {Estep(x, c, codes.data(), d, n);MStep(x, c, codes.data(), d, n);}
}
// MStep中有一部分与Kmeans算法中不太一样的部分
// 对于中心点计数为0的部分做了修正,对中心点数值进行了调整
std::uniform_real_distribution<> runiform(0, 1);for (auto k = 0; k < ksub_; k++) {if (nelts[k] == 0) {int32_t m = 0;while (runiform(rng) * (n - ksub_) >= nelts[m] - 1) {m = (m + 1) % ksub_;}memcpy(centroids + k * d, centroids + m * d, sizeof(real) * d);for (auto j = 0; j < d; j++) {int32_t sign = (j % 2) * 2 - 1;centroids[k * d + j] += sign * eps_;centroids[m * d + j] -= sign * eps_;}nelts[k] = nelts[m] / 2;nelts[m] -= nelts[k];}}

源码中默认子矩阵长度是2,kmeans簇大小为256(不超过一个byte),默认会压缩至1/8大小。

总结

原本以为会写比较多内容,因为这部分代码确实花了点时间去看,尤其权重量化中kmeans算法前面那部分(子矩阵划分),不太明白,但是看懂代码逻辑之后,又回头看了文档,恍然大悟,一句话就能把逻辑说的很清楚。但是并没办法绕过看代码,因为正是看文档看的不明白才去看代码的。。。-_-||

所以,最终整篇文章变成了代码注释,以防细节部分忘掉。。。

话说特征选择,个人觉得对于大部分场景,压缩比非常可观。但是也需要看到,DenseMatrix矩阵起步就是200万行,所以对于小数据集,fasttext也会训出比较大的模型,这个是不足的一个方面。

附录

  1. fasttext language identification
  2. fast源码
  3. K-means学习总结
  4. fasttext源码学习(1)–dictionary

这篇关于fasttext源码学习(2)--模型压缩的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1030672

相关文章

uni-app小程序项目中实现前端图片压缩实现方式(附详细代码)

《uni-app小程序项目中实现前端图片压缩实现方式(附详细代码)》在uni-app开发中,文件上传和图片处理是很常见的需求,但也经常会遇到各种问题,下面:本文主要介绍uni-app小程序项目中实... 目录方式一:使用<canvas>实现图片压缩(推荐,兼容性好)示例代码(小程序平台):方式二:使用uni

Linux五种IO模型的使用解读

《Linux五种IO模型的使用解读》文章系统解析了Linux的五种IO模型(阻塞、非阻塞、IO复用、信号驱动、异步),重点区分同步与异步IO的本质差异,强调同步由用户发起,异步由内核触发,通过对比各模... 目录1.IO模型简介2.五种IO模型2.1 IO模型分析方法2.2 阻塞IO2.3 非阻塞IO2.4

java 恺撒加密/解密实现原理(附带源码)

《java恺撒加密/解密实现原理(附带源码)》本文介绍Java实现恺撒加密与解密,通过固定位移量对字母进行循环替换,保留大小写及非字母字符,由于其实现简单、易于理解,恺撒加密常被用作学习加密算法的入... 目录Java 恺撒加密/解密实现1. 项目背景与介绍2. 相关知识2.1 恺撒加密算法原理2.2 Ja

Nginx屏蔽服务器名称与版本信息方式(源码级修改)

《Nginx屏蔽服务器名称与版本信息方式(源码级修改)》本文详解如何通过源码修改Nginx1.25.4,移除Server响应头中的服务类型和版本信息,以增强安全性,需重新配置、编译、安装,升级时需重复... 目录一、背景与目的二、适用版本三、操作步骤修改源码文件四、后续操作提示五、注意事项六、总结一、背景与

Android实现图片浏览功能的示例详解(附带源码)

《Android实现图片浏览功能的示例详解(附带源码)》在许多应用中,都需要展示图片并支持用户进行浏览,本文主要为大家介绍了如何通过Android实现图片浏览功能,感兴趣的小伙伴可以跟随小编一起学习一... 目录一、项目背景详细介绍二、项目需求详细介绍三、相关技术详细介绍四、实现思路详细介绍五、完整实现代码

Unity新手入门学习殿堂级知识详细讲解(图文)

《Unity新手入门学习殿堂级知识详细讲解(图文)》Unity是一款跨平台游戏引擎,支持2D/3D及VR/AR开发,核心功能模块包括图形、音频、物理等,通过可视化编辑器与脚本扩展实现开发,项目结构含A... 目录入门概述什么是 UnityUnity引擎基础认知编辑器核心操作Unity 编辑器项目模式分类工程

Python学习笔记之getattr和hasattr用法示例详解

《Python学习笔记之getattr和hasattr用法示例详解》在Python中,hasattr()、getattr()和setattr()是一组内置函数,用于对对象的属性进行操作和查询,这篇文章... 目录1.getattr用法详解1.1 基本作用1.2 示例1.3 原理2.hasattr用法详解2.

Java使用Thumbnailator库实现图片处理与压缩功能

《Java使用Thumbnailator库实现图片处理与压缩功能》Thumbnailator是高性能Java图像处理库,支持缩放、旋转、水印添加、裁剪及格式转换,提供易用API和性能优化,适合Web应... 目录1. 图片处理库Thumbnailator介绍2. 基本和指定大小图片缩放功能2.1 图片缩放的

使用zip4j实现Java中的ZIP文件加密压缩的操作方法

《使用zip4j实现Java中的ZIP文件加密压缩的操作方法》本文介绍如何通过Maven集成zip4j1.3.2库创建带密码保护的ZIP文件,涵盖依赖配置、代码示例及加密原理,确保数据安全性,感兴趣的... 目录1. zip4j库介绍和版本1.1 zip4j库概述1.2 zip4j的版本演变1.3 zip4

Linux中压缩、网络传输与系统监控工具的使用完整指南

《Linux中压缩、网络传输与系统监控工具的使用完整指南》在Linux系统管理中,压缩与传输工具是数据备份和远程协作的桥梁,而系统监控工具则是保障服务器稳定运行的眼睛,下面小编就来和大家详细介绍一下它... 目录引言一、压缩与解压:数据存储与传输的优化核心1. zip/unzip:通用压缩格式的便捷操作2.