Looking to Listen at the Cocktail Party: Code Walkthrough

2023-10-24 22:58

This article walks through the code behind Looking to Listen at the Cocktail Party, in the hope that it offers a useful reference for developers working on similar problems.

This is a reproduction of the paper "Looking to Listen at the Cocktail Party" by a developer from Tsinghua (code link).

The network architecture is shown in the figure below:

[Figure: network architecture of the audio-visual speech separation model]
Since the AVSpeech dataset consists of video clips while the network takes the face region of each frame as input, face detection is run first and the faces are cropped out.
This code does the detection directly with a pretrained MTCNN package for Python.

import cv2

def face_detect(file, detector, frame_path, cat_train, output_dir):
    # frame file names look like '<csv_row>-<frame_index>.jpg'
    name = file.replace('.jpg', '').split('-')
    # look up the relative face-center coordinates (x, y) from the AVSpeech csv
    log = cat_train.iloc[int(name[0])]
    x = log[3]
    y = log[4]
    img = cv2.imread('%s%s' % (frame_path, file))
    # convert relative coordinates to pixel coordinates
    x = img.shape[1] * x
    y = img.shape[0] * y
    faces = detector.detect_faces(img)
    # check if any faces were detected
    if len(faces) == 0:
        print('no face detect: ' + file)
        return  # no face
    # bounding_box_check is a helper from the repo that picks the detection
    # matching the coordinates given in the csv
    bounding_box = bounding_box_check(faces, x, y)
    if bounding_box is None:
        print('face is not related to given coord: ' + file)
        return
    print(file, " ", bounding_box)
    print(file, " ", x, y)
    # crop the face region and resize it to 160x160
    crop_img = img[bounding_box[1]:bounding_box[1] + bounding_box[3],
                   bounding_box[0]:bounding_box[0] + bounding_box[2]]
    crop_img = cv2.resize(crop_img, (160, 160))
    cv2.imwrite('%s/frame_' % output_dir + name[0] + '_' + name[1] + '.jpg', crop_img)
    # crop_img = cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB)
    # plt.imshow(crop_img)
    # plt.show()
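To show how this function might be driven, here is a minimal usage sketch of my own. The CSV file name, the frame directory, and the '<csv_row>-<frame_index>.jpg' filename pattern are assumptions for illustration only, and bounding_box_check is a separate helper defined in the repo.

import os
import pandas as pd
from mtcnn import MTCNN  # pip install mtcnn

detector = MTCNN()
# AVSpeech csv: columns 3 and 4 hold the relative x/y of the face center
cat_train = pd.read_csv('avspeech_train.csv', header=None)

frame_path = 'frames/'      # hypothetical directory of extracted frames
output_dir = 'face_crops'   # hypothetical output directory
os.makedirs(output_dir, exist_ok=True)

for file in sorted(os.listdir(frame_path)):
    if file.endswith('.jpg'):  # expects names like '0-1.jpg' (csv row - frame index)
        face_detect(file, detector, frame_path, cat_train, output_dir)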

The AV model code is as follows:

from keras.models import Sequential
from keras.layers import Input, Dense, Convolution2D,Bidirectional, concatenate
from keras.layers import Flatten, BatchNormalization, ReLU, Reshape, Lambda, TimeDistributed
from keras.models import Model
from keras.layers.recurrent import LSTM
from keras.initializers import he_normal, glorot_uniform
import tensorflow as tf


def AV_model(people_num=2):
    def UpSampling2DBilinear(size):
        return Lambda(lambda x: tf.image.resize(x, size, method=tf.image.ResizeMethod.BILINEAR))

    def sliced(x, index):
        return x[:, :, :, index]

    # --------------------------- AS start ---------------------------
    audio_input = Input(shape=(298, 257, 2))
    print('as_0:', audio_input.shape)
    as_conv1 = Convolution2D(96, kernel_size=(1, 7), strides=(1, 1), padding='same', dilation_rate=(1, 1), name='as_conv1')(audio_input)
    as_conv1 = BatchNormalization()(as_conv1)
    as_conv1 = ReLU()(as_conv1)
    print('as_1:', as_conv1.shape)
    as_conv2 = Convolution2D(96, kernel_size=(7, 1), strides=(1, 1), padding='same', dilation_rate=(1, 1), name='as_conv2')(as_conv1)
    as_conv2 = BatchNormalization()(as_conv2)
    as_conv2 = ReLU()(as_conv2)
    print('as_2:', as_conv2.shape)
    as_conv3 = Convolution2D(96, kernel_size=(5, 5), strides=(1, 1), padding='same', dilation_rate=(1, 1), name='as_conv3')(as_conv2)
    as_conv3 = BatchNormalization()(as_conv3)
    as_conv3 = ReLU()(as_conv3)
    print('as_3:', as_conv3.shape)
    as_conv4 = Convolution2D(96, kernel_size=(5, 5), strides=(1, 1), padding='same', dilation_rate=(2, 1), name='as_conv4')(as_conv3)
    as_conv4 = BatchNormalization()(as_conv4)
    as_conv4 = ReLU()(as_conv4)
    print('as_4:', as_conv4.shape)
    as_conv5 = Convolution2D(96, kernel_size=(5, 5), strides=(1, 1), padding='same', dilation_rate=(4, 1), name='as_conv5')(as_conv4)
    as_conv5 = BatchNormalization()(as_conv5)
    as_conv5 = ReLU()(as_conv5)
    print('as_5:', as_conv5.shape)
    as_conv6 = Convolution2D(96, kernel_size=(5, 5), strides=(1, 1), padding='same', dilation_rate=(8, 1), name='as_conv6')(as_conv5)
    as_conv6 = BatchNormalization()(as_conv6)
    as_conv6 = ReLU()(as_conv6)
    print('as_6:', as_conv6.shape)
    as_conv7 = Convolution2D(96, kernel_size=(5, 5), strides=(1, 1), padding='same', dilation_rate=(16, 1), name='as_conv7')(as_conv6)
    as_conv7 = BatchNormalization()(as_conv7)
    as_conv7 = ReLU()(as_conv7)
    print('as_7:', as_conv7.shape)
    as_conv8 = Convolution2D(96, kernel_size=(5, 5), strides=(1, 1), padding='same', dilation_rate=(32, 1), name='as_conv8')(as_conv7)
    as_conv8 = BatchNormalization()(as_conv8)
    as_conv8 = ReLU()(as_conv8)
    print('as_8:', as_conv8.shape)
    as_conv9 = Convolution2D(96, kernel_size=(5, 5), strides=(1, 1), padding='same', dilation_rate=(1, 1), name='as_conv9')(as_conv8)
    as_conv9 = BatchNormalization()(as_conv9)
    as_conv9 = ReLU()(as_conv9)
    print('as_9:', as_conv9.shape)
    as_conv10 = Convolution2D(96, kernel_size=(5, 5), strides=(1, 1), padding='same', dilation_rate=(2, 2), name='as_conv10')(as_conv9)
    as_conv10 = BatchNormalization()(as_conv10)
    as_conv10 = ReLU()(as_conv10)
    print('as_10:', as_conv10.shape)
    as_conv11 = Convolution2D(96, kernel_size=(5, 5), strides=(1, 1), padding='same', dilation_rate=(4, 4), name='as_conv11')(as_conv10)
    as_conv11 = BatchNormalization()(as_conv11)
    as_conv11 = ReLU()(as_conv11)
    print('as_11:', as_conv11.shape)
    as_conv12 = Convolution2D(96, kernel_size=(5, 5), strides=(1, 1), padding='same', dilation_rate=(8, 8), name='as_conv12')(as_conv11)
    as_conv12 = BatchNormalization()(as_conv12)
    as_conv12 = ReLU()(as_conv12)
    print('as_12:', as_conv12.shape)
    as_conv13 = Convolution2D(96, kernel_size=(5, 5), strides=(1, 1), padding='same', dilation_rate=(16, 16), name='as_conv13')(as_conv12)
    as_conv13 = BatchNormalization()(as_conv13)
    as_conv13 = ReLU()(as_conv13)
    print('as_13:', as_conv13.shape)
    as_conv14 = Convolution2D(96, kernel_size=(5, 5), strides=(1, 1), padding='same', dilation_rate=(32, 32), name='as_conv14')(as_conv13)
    as_conv14 = BatchNormalization()(as_conv14)
    as_conv14 = ReLU()(as_conv14)
    print('as_14:', as_conv14.shape)
    as_conv15 = Convolution2D(8, kernel_size=(1, 1), strides=(1, 1), padding='same', dilation_rate=(1, 1), name='as_conv15')(as_conv14)
    as_conv15 = BatchNormalization()(as_conv15)
    as_conv15 = ReLU()(as_conv15)
    print('as_15:', as_conv15.shape)
    AS_out = Reshape((298, 8 * 257))(as_conv15)
    print('AS_out:', AS_out.shape)
    # --------------------------- AS end ---------------------------

    # --------------------------- VS_model start ---------------------------
    VS_model = Sequential()
    VS_model.add(Convolution2D(256, kernel_size=(7, 1), strides=(1, 1), padding='same', dilation_rate=(1, 1), name='vs_conv1'))
    VS_model.add(BatchNormalization())
    VS_model.add(ReLU())
    VS_model.add(Convolution2D(256, kernel_size=(5, 1), strides=(1, 1), padding='same', dilation_rate=(1, 1), name='vs_conv2'))
    VS_model.add(BatchNormalization())
    VS_model.add(ReLU())
    VS_model.add(Convolution2D(256, kernel_size=(5, 1), strides=(1, 1), padding='same', dilation_rate=(2, 1), name='vs_conv3'))
    VS_model.add(BatchNormalization())
    VS_model.add(ReLU())
    VS_model.add(Convolution2D(256, kernel_size=(5, 1), strides=(1, 1), padding='same', dilation_rate=(4, 1), name='vs_conv4'))
    VS_model.add(BatchNormalization())
    VS_model.add(ReLU())
    VS_model.add(Convolution2D(256, kernel_size=(5, 1), strides=(1, 1), padding='same', dilation_rate=(8, 1), name='vs_conv5'))
    VS_model.add(BatchNormalization())
    VS_model.add(ReLU())
    VS_model.add(Convolution2D(256, kernel_size=(5, 1), strides=(1, 1), padding='same', dilation_rate=(16, 1), name='vs_conv6'))
    VS_model.add(BatchNormalization())
    VS_model.add(ReLU())
    VS_model.add(Reshape((75, 256, 1)))
    VS_model.add(UpSampling2DBilinear((298, 256)))
    VS_model.add(Reshape((298, 256)))
    # --------------------------- VS_model end ---------------------------

    video_input = Input(shape=(75, 1, 1792, people_num))
    AVfusion_list = [AS_out]
    for i in range(people_num):
        single_input = Lambda(sliced, arguments={'index': i})(video_input)
        VS_out = VS_model(single_input)
        AVfusion_list.append(VS_out)
    AVfusion = concatenate(AVfusion_list, axis=2)
    AVfusion = TimeDistributed(Flatten())(AVfusion)
    print('AVfusion:', AVfusion.shape)

    lstm = Bidirectional(LSTM(400, input_shape=(298, 8 * 257), return_sequences=True), merge_mode='sum')(AVfusion)
    print('lstm:', lstm.shape)

    fc1 = Dense(600, name="fc1", activation='relu', kernel_initializer=he_normal(seed=27))(lstm)
    print('fc1:', fc1.shape)
    fc2 = Dense(600, name="fc2", activation='relu', kernel_initializer=he_normal(seed=42))(fc1)
    print('fc2:', fc2.shape)
    fc3 = Dense(600, name="fc3", activation='relu', kernel_initializer=he_normal(seed=65))(fc2)
    print('fc3:', fc3.shape)

    complex_mask = Dense(257 * 2 * people_num, name="complex_mask", kernel_initializer=glorot_uniform(seed=87))(fc3)
    print('complex_mask:', complex_mask.shape)
    complex_mask_out = Reshape((298, 257, 2, people_num))(complex_mask)
    print('complex_mask_out:', complex_mask_out.shape)

    AV_model = Model(inputs=[audio_input, video_input], outputs=complex_mask_out)
    # # compile AV_model
    # AV_model.compile(optimizer='adam', loss='mse')
    return AV_model
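To get a feel for the input and output shapes, the model can be built and run on dummy data. The following is a minimal sketch of my own; the batch size and zero-filled arrays are purely illustrative and not part of the repo.

import numpy as np

model = AV_model(people_num=2)
model.summary()

# audio input: mixture spectrogram, 298 time frames x 257 frequency bins,
# with real and imaginary parts stacked as two channels
dummy_audio = np.zeros((1, 298, 257, 2), dtype=np.float32)
# video input: 75 frames of a 1792-dimensional face embedding per speaker
dummy_video = np.zeros((1, 75, 1, 1792, 2), dtype=np.float32)

masks = model.predict([dummy_audio, dummy_video])
print(masks.shape)  # (1, 298, 257, 2, 2): one complex mask per speaker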

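Note that the network's output is a complex ratio mask per speaker rather than the separated spectrogram itself. Below is a rough sketch of my own showing how such a mask could be multiplied with the mixture STFT and inverted back to a waveform; the STFT settings (16 kHz audio, 512-point FFT, 160-sample hop) are assumptions chosen to match the 298 x 257 spectrogram shape, separate is not a function from the repo, and any spectrogram or mask compression used during training is ignored here.

import numpy as np
import librosa

def separate(mix_wav, masks, n_fft=512, hop=160):
    # masks: (298, 257, 2, people_num) array predicted by AV_model for one sample
    stft = librosa.stft(mix_wav, n_fft=n_fft, hop_length=hop).T[:298, :]  # (298, 257), complex
    estimates = []
    for p in range(masks.shape[-1]):
        cmask = masks[:, :, 0, p] + 1j * masks[:, :, 1, p]  # real + imaginary parts -> complex mask
        masked = (stft * cmask).T                           # apply the mask to the mixture spectrogram
        estimates.append(librosa.istft(masked, hop_length=hop))
    return estimates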
The author of this reproduction is seriously impressive; I feel humbled by comparison.

That concludes this code walkthrough of Looking to Listen at the Cocktail Party; hopefully it is of some help to fellow programmers.


