Looking to Listen at the Cocktail Party: A Code Walkthrough

2023-10-24 22:58

This post walks through the code of a reimplementation of Looking to Listen at the Cocktail Party; hopefully it is a useful reference for developers working on the same problem.

This is a reimplementation of the paper *Looking to Listen at the Cocktail Party* by a very capable student at Tsinghua. Code link

The network architecture is shown below:

[Figure: network architecture diagram]
The AVSpeech dataset consists of short video clips, while the network's visual input is the face region within each clip, so the first step is to detect and crop the faces.
This reimplementation does that directly with a pretrained MTCNN package for Python.

```python
import cv2


def face_detect(file, detector, frame_path, cat_train, output_dir):
    # frame files are named '<clip_index>-<frame_index>.jpg'
    name = file.replace('.jpg', '').split('-')
    log = cat_train.iloc[int(name[0])]
    # the AVSpeech csv stores the face center as relative (x, y) coordinates
    x = log[3]
    y = log[4]
    img = cv2.imread('%s%s' % (frame_path, file))
    x = img.shape[1] * x
    y = img.shape[0] * y
    faces = detector.detect_faces(img)
    # check whether any faces were detected
    if len(faces) == 0:
        print('no face detect: ' + file)
        return  # no face
    # bounding_box_check (defined elsewhere in the repo) picks the detection
    # matching the annotated center point
    bounding_box = bounding_box_check(faces, x, y)
    if bounding_box is None:
        print('face is not related to given coord: ' + file)
        return
    print(file, " ", bounding_box)
    print(file, " ", x, y)
    # crop the face region and resize to the 160x160 input expected downstream
    crop_img = img[bounding_box[1]:bounding_box[1] + bounding_box[3],
                   bounding_box[0]:bounding_box[0] + bounding_box[2]]
    crop_img = cv2.resize(crop_img, (160, 160))
    cv2.imwrite('%s/frame_' % output_dir + name[0] + '_' + name[1] + '.jpg', crop_img)
    # crop_img = cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB)
    # plt.imshow(crop_img)
    # plt.show()
```
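The helper `bounding_box_check` is not shown in the snippet above. As a rough sketch of what it plausibly does (my guess, not the repo's actual code): assuming `faces` is mtcnn output where each detection carries a `'box'` of `[left, top, width, height]`, it returns the box that contains the annotated face-center coordinate, or `None` if no detection matches.

```python
def bounding_box_check(faces, x, y):
    """Return the detected box containing the annotated center (x, y), else None.

    `faces` follows the mtcnn output format: a list of dicts whose
    'box' entry is [left, top, width, height]. Hypothetical sketch only.
    """
    for face in faces:
        left, top, width, height = face['box']
        if left <= x <= left + width and top <= y <= top + height:
            return face['box']
    return None


faces = [{'box': [10, 10, 50, 60]}, {'box': [100, 20, 40, 40]}]
print(bounding_box_check(faces, 120, 35))   # [100, 20, 40, 40]
print(bounding_box_check(faces, 300, 300))  # None
```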

Below is the code for the AV model:

```python
from keras.models import Sequential, Model
from keras.layers import Input, Dense, Convolution2D, Bidirectional, concatenate
from keras.layers import Flatten, BatchNormalization, ReLU, Reshape, Lambda, TimeDistributed
from keras.layers.recurrent import LSTM
from keras.initializers import he_normal, glorot_uniform
import tensorflow as tf


def AV_model(people_num=2):
    def UpSampling2DBilinear(size):
        # stretch the 75 video frames to the 298 audio frames via bilinear resize
        return Lambda(lambda x: tf.image.resize(x, size, method=tf.image.ResizeMethod.BILINEAR))

    def sliced(x, index):
        # select one speaker's stack of face embeddings from the joint input
        return x[:, :, :, index]

    # --------------------------- audio stream (AS) ---------------------------
    # 15 dilated convolutions; filters, kernel size and dilation follow the
    # original layer-by-layer code (as_conv1 .. as_conv15)
    as_spec = [
        (96, (1, 7), (1, 1)), (96, (7, 1), (1, 1)), (96, (5, 5), (1, 1)),
        (96, (5, 5), (2, 1)), (96, (5, 5), (4, 1)), (96, (5, 5), (8, 1)),
        (96, (5, 5), (16, 1)), (96, (5, 5), (32, 1)), (96, (5, 5), (1, 1)),
        (96, (5, 5), (2, 2)), (96, (5, 5), (4, 4)), (96, (5, 5), (8, 8)),
        (96, (5, 5), (16, 16)), (96, (5, 5), (32, 32)), (8, (1, 1), (1, 1)),
    ]
    audio_input = Input(shape=(298, 257, 2))  # 298 STFT frames x 257 bins x (real, imag)
    print('as_0:', audio_input.shape)
    x = audio_input
    for i, (filters, kernel, dilation) in enumerate(as_spec, start=1):
        x = Convolution2D(filters, kernel_size=kernel, strides=(1, 1), padding='same',
                          dilation_rate=dilation, name='as_conv%d' % i)(x)
        x = BatchNormalization()(x)
        x = ReLU()(x)
        print('as_%d:' % i, x.shape)
    AS_out = Reshape((298, 8 * 257))(x)
    print('AS_out:', AS_out.shape)

    # --------------------------- visual stream (VS) ---------------------------
    # a shared Sequential tower, applied to every speaker's face embeddings
    VS_model = Sequential()
    vs_spec = [((7, 1), (1, 1)), ((5, 1), (1, 1)), ((5, 1), (2, 1)),
               ((5, 1), (4, 1)), ((5, 1), (8, 1)), ((5, 1), (16, 1))]
    for i, (kernel, dilation) in enumerate(vs_spec, start=1):
        VS_model.add(Convolution2D(256, kernel_size=kernel, strides=(1, 1), padding='same',
                                   dilation_rate=dilation, name='vs_conv%d' % i))
        VS_model.add(BatchNormalization())
        VS_model.add(ReLU())
    VS_model.add(Reshape((75, 256, 1)))
    VS_model.add(UpSampling2DBilinear((298, 256)))
    VS_model.add(Reshape((298, 256)))

    # --------------------------- AV fusion ---------------------------
    video_input = Input(shape=(75, 1, 1792, people_num))
    AVfusion_list = [AS_out]
    for i in range(people_num):
        single_input = Lambda(sliced, arguments={'index': i})(video_input)
        VS_out = VS_model(single_input)
        AVfusion_list.append(VS_out)
    AVfusion = concatenate(AVfusion_list, axis=2)
    AVfusion = TimeDistributed(Flatten())(AVfusion)
    print('AVfusion:', AVfusion.shape)

    lstm = Bidirectional(LSTM(400, input_shape=(298, 8 * 257), return_sequences=True),
                         merge_mode='sum')(AVfusion)
    print('lstm:', lstm.shape)

    fc1 = Dense(600, name="fc1", activation='relu', kernel_initializer=he_normal(seed=27))(lstm)
    print('fc1:', fc1.shape)
    fc2 = Dense(600, name="fc2", activation='relu', kernel_initializer=he_normal(seed=42))(fc1)
    print('fc2:', fc2.shape)
    fc3 = Dense(600, name="fc3", activation='relu', kernel_initializer=he_normal(seed=65))(fc2)
    print('fc3:', fc3.shape)

    # one complex (real, imag) mask of shape (298, 257) per speaker
    complex_mask = Dense(257 * 2 * people_num, name="complex_mask",
                         kernel_initializer=glorot_uniform(seed=87))(fc3)
    print('complex_mask:', complex_mask.shape)
    complex_mask_out = Reshape((298, 257, 2, people_num))(complex_mask)
    print('complex_mask_out:', complex_mask_out.shape)

    AV_model = Model(inputs=[audio_input, video_input], outputs=complex_mask_out)
    # AV_model.compile(optimizer='adam', loss='mse')
    return AV_model
```
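The magic numbers 298, 257 and 75 in the model come from the paper's preprocessing: 3-second clips, audio at 16 kHz with an STFT using a 25 ms window, a 10 ms hop and 512 FFT points, and video at 25 fps. A quick arithmetic check (my own sketch; these constants come from the paper rather than from this repo's code):

```python
# derive the tensor sizes used in AV_model from the paper's preprocessing
sample_rate = 16000              # Hz
duration = 3                     # seconds per training clip
win = int(0.025 * sample_rate)   # 25 ms STFT window -> 400 samples
hop = int(0.010 * sample_rate)   # 10 ms hop -> 160 samples
n_fft = 512
fps = 25                         # video frame rate

audio_frames = (sample_rate * duration - win) // hop + 1
freq_bins = n_fft // 2 + 1
video_frames = fps * duration
print(audio_frames, freq_bins, video_frames)  # 298 257 75
```

This is why the visual stream must upsample its 75 time steps to 298 before fusion with the audio stream.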

The author of this reimplementation is seriously impressive; I feel humbled by comparison.
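One closing note on what the `(298, 257, 2, people_num)` output represents: it is a complex ratio mask per speaker, and in the paper each speaker's spectrogram is recovered by complex-multiplying the mixture spectrogram with that speaker's mask. A toy single-bin sketch (mine, not code from the repo):

```python
def apply_mask(mix_re, mix_im, mask_re, mask_im):
    # element-wise complex multiplication for a single time-frequency bin;
    # the real pipeline does this over the whole (298, 257) grid per speaker
    separated = complex(mix_re, mix_im) * complex(mask_re, mask_im)
    return separated.real, separated.imag


re, im = apply_mask(0.8, 0.3, 0.5, -0.2)
print(re, im)
```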

That wraps up this walkthrough of the Looking to Listen at the Cocktail Party code; I hope it proves helpful.



