Gaze Estimation: A PFLD Implementation

2023-10-18 11:20

This article walks through a PFLD-based implementation of gaze estimation; hopefully it offers a useful reference for developers tackling the same problem.


The hardest part of gaze estimation is data annotation, which is extremely labor-intensive. Based on a recently open-sourced dataset, this experiment trains the task by direct regression.
Full project source: https://github.com/ycdhqzhiai/Gaze-PFLD

1. Dataset

The dataset used is TEyeD: Over 20 million real-world eye images with Pupil, Eyelid, and Iris 2D and 3D Segmentations, 2D and 3D Landmarks, 3D Eyeball, Gaze Vector, and Eye Movement Types.

  • Dataset preprocessing
    The annotations are converted to JSON, keeping only the landmarks and the gaze vector; the other annotation fields are not used. The conversion script:
import os
import cv2
import glob
import numpy as np
import argparse
import json
# Note: filenames use 5-digit zero-padded indices, so at most 99999 frames can be
# saved per video (roughly 55 minutes of footage at 30 fps).

def parse_args():
    parser = argparse.ArgumentParser(description="EyeGaze datasets")
    parser.add_argument("--video_path", type=str, default='DIKABLISVIDEOS', help='videos path')
    parser.add_argument("--annotations", type=str, default='ANNOTATIONS',
                        help='videos label path including gaze_vec iris_lm_2D lid_lm_2D pupil_lm_2D')
    parser.add_argument("--images", type=str, default='images', help='save_path')
    parser.add_argument("--draw_img", type=str, default='draw_img', help='save_path')
    parser.add_argument("--blind", type=str, default='blind', help='save_path')
    parser.add_argument("--json", type=str, default='json', help='save_path')
    args = parser.parse_args()
    return args

def mkd(path):
    if not os.path.exists(path):
        os.makedirs(path)

def judge_exists(path):
    # Returns True when the annotation file is MISSING
    if os.path.exists(path):
        return False
    return True

def log(agaze_vec, airis_lm_2D, alid_lm_2D, apupil_lm_2D, aeye_movements):
    b1 = judge_exists(agaze_vec)
    b2 = judge_exists(airis_lm_2D)
    b3 = judge_exists(alid_lm_2D)
    b4 = judge_exists(apupil_lm_2D)
    b5 = judge_exists(aeye_movements)
    if b1:
        print('gaze_vec not found!!! EXIT')
    if b2:
        print('iris_lm_2D not found!!! EXIT')
    if b3:
        print('lid_lm_2D not found!!! EXIT')
    if b4:
        print('pupil_lm_2D not found!!! EXIT')
    if b5:
        print('eye_movements not found!!! EXIT')
    if b1 or b2 or b3 or b4 or b5:
        return False
    return True

def main():
    args = parse_args()
    video_list = glob.glob(os.path.join(args.video_path, '*.mp4'))
    for video in video_list:
        name = os.path.split(video)[1]
        # if not '5_2' in name:
        #     continue
        images_dir = os.path.join(args.images, name)
        draw_img_dir = os.path.join(args.draw_img, name)
        blind_dir = os.path.join(args.blind, name)
        json_dir = os.path.join(args.json, name)
        mkd(images_dir)
        mkd(draw_img_dir)
        mkd(blind_dir)
        mkd(json_dir)
        agaze_vec = os.path.join(args.annotations, name + 'gaze_vec.txt')
        airis_lm_2D = os.path.join(args.annotations, name + 'iris_lm_2D.txt')
        alid_lm_2D = os.path.join(args.annotations, name + 'lid_lm_2D.txt')
        apupil_lm_2D = os.path.join(args.annotations, name + 'pupil_lm_2D.txt')
        aeye_movements = os.path.join(args.annotations, name + 'eye_movements.txt')
        flag = log(agaze_vec, airis_lm_2D, alid_lm_2D, apupil_lm_2D, aeye_movements)
        if not flag:
            exit()
        with open(agaze_vec, 'r') as fgaze_vec:
            lgaze_vec = fgaze_vec.readlines()[1:]
        with open(airis_lm_2D, 'r') as firis_lm_2D:
            liris_lm_2D = firis_lm_2D.readlines()[1:]
        with open(alid_lm_2D, 'r') as flid_lm_2D:
            llid_lm_2D = flid_lm_2D.readlines()[1:]
        with open(apupil_lm_2D, 'r') as fpupil_lm_2D:
            lpupil_lm_2D = fpupil_lm_2D.readlines()[1:]
        with open(aeye_movements, 'r') as feye_movements:
            leye_movements = feye_movements.readlines()[3:]
        cap = cv2.VideoCapture(video)
        num = 0
        while 1:
            ret, frame = cap.read()
            if not ret:
                break
            src = frame.copy()
            save_src   = '{}/{}_{:0>5d}.jpg'.format(images_dir, name[:-4], num)
            save_draw  = '{}/{}_{:0>5d}.jpg'.format(draw_img_dir, name[:-4], num)
            save_blind = '{}/{}_{:0>5d}.jpg'.format(blind_dir, name[:-4], num)
            save_json  = '{}/{}_{:0>5d}.json'.format(json_dir, name[:-4], num)
            eye_movements = leye_movements[num].strip()[2:3]
            gaze_vec    = np.array([float(x) for x in lgaze_vec[num].strip().split(';')[1:3]])
            iris_lm_2D  = np.array([float(x) for x in liris_lm_2D[num].strip().split(';')[2:-1]]).reshape(-1, 2)    # iris, the middle ring
            lid_lm_2D   = np.array([float(x) for x in llid_lm_2D[num].strip().split(';')[2:-1]]).reshape(-1, 2)     # eyelid, the outermost ring
            pupil_lm_2D = np.array([float(x) for x in lpupil_lm_2D[num].strip().split(';')[2:-1]]).reshape(-1, 2)   # pupil, the innermost ring
            num += 1
            if eye_movements == '1':  # skip blink frames
                continue
            eye_c = np.mean(pupil_lm_2D, axis=0).astype(int)
            for index in range(iris_lm_2D.shape[0]):
                x_y = iris_lm_2D[index]
                cv2.circle(frame, (int(x_y[0]), int(x_y[1])), 1, (0, 255, 0), -1)   # green
            for index in range(lid_lm_2D.shape[0]):
                x_y = lid_lm_2D[index]
                cv2.circle(frame, (int(x_y[0]), int(x_y[1])), 1, (255, 0, 0), -1)   # blue
            for index in range(pupil_lm_2D.shape[0]):
                x_y = pupil_lm_2D[index]
                cv2.circle(frame, (int(x_y[0]), int(x_y[1])), 1, (0, 0, 255), -1)   # red
            cv2.circle(frame, tuple(eye_c), 1, (255, 255, 255), -1)
            cv2.line(frame, tuple(eye_c), tuple(eye_c + (gaze_vec * 100).astype(int)), (0, 255, 255), 1)  # yellow
            label_dict = {'gaze_vec': gaze_vec.tolist(), 'iris_lm_2D': iris_lm_2D.tolist(),
                          'lid_lm_2D': lid_lm_2D.tolist(), 'pupil_lm_2D': pupil_lm_2D.tolist()}
            if -1 in gaze_vec:  # invalid gaze label: divert the frame to the blind folder
                cv2.imwrite(save_blind, frame)
                with open(save_json.replace('json\\', 'blind\\'), 'w') as dump_f:  # Windows-style separators from os.path.join
                    json.dump(label_dict, dump_f)
            else:
                if num % 3 == 0:  # keep every third frame
                    cv2.imwrite(save_src, src)
                    with open(save_json, 'w') as dump_f:
                        json.dump(label_dict, dump_f)
                    cv2.imwrite(save_draw, frame)

if __name__ == '__main__':
    main()
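
A minimal invocation of the script above, assuming it is saved as preprocess_teyed.py (the filename here is illustrative) and the Dikablis videos and annotation files sit in the default locations:

python preprocess_teyed.py --video_path DIKABLISVIDEOS --annotations ANNOTATIONS --images images --draw_img draw_img --blind blind --json json

Each video then yields per-frame JPEGs under images/<video>/ with matching JSON labels under json/<video>/; blink frames are skipped, and frames whose gaze vector contains -1 go to the blind folders instead.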

2. Training

PFLD is used to train the gaze estimator: the PFLDInference backbone predicts the landmarks, and the AuxiliaryNet branch predicts the gaze vector from the backbone's intermediate features.

  • dataloader
def preprocess_unityeyes_image(img, json_data, datasets, input_width, input_height):
    ow = 160
    oh = 96
    # Prepare to segment the eye image
    ih, iw = img.shape[:2]
    ih_2, iw_2 = ih / 2.0, iw / 2.0
    heatmap_w = int(ow / 2)
    heatmap_h = int(oh / 2)
    # img = cv2.resize(im, (im.shape[1]*3, im.shape[0]*3))
    # img = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    if datasets == 'B':
        gaze = np.array(json_data['gaze'])
        landmarks = np.array(json_data['landmarks'])
        left_corner = landmarks[0]
        right_corner = landmarks[4]
        eye_width = 1.5 * abs(left_corner[0] - right_corner[0])
        eye_middle = landmarks[24].astype(int)
    elif datasets == 'E':
        gaze = np.array(json_data['gaze_vec'])
        left_corner = np.array(json_data['lid_lm_2D'])[0]
        right_corner = np.array(json_data['lid_lm_2D'])[33]
        eye_width = 1.5 * abs(left_corner[0] - right_corner[0])
        eye_middle = np.mean([np.amin(np.array(json_data['iris_lm_2D']), axis=0),
                              np.amax(np.array(json_data['iris_lm_2D']), axis=0)], axis=0)
        landmarks = np.concatenate((np.array(json_data['lid_lm_2D']),
                                    np.array(json_data['iris_lm_2D']),
                                    np.array(json_data['pupil_lm_2D']),
                                    eye_middle.reshape(1, 2)))
    else:
        print('UnityEyes do not write!!!')  # UnityEyes branch not implemented
        exit()
    crop_img, lad = get_img(img, landmarks)
    crop_img = cv2.resize(crop_img, (input_width, input_height))
    # Debug visualization (disabled): draw `lad` back onto crop_img with
    # cv2.circle and display it with cv2.imshow to sanity-check the labels.
    return crop_img, lad, gaze


class EyesDataset(data.Dataset):
    def __init__(self, datasets, dataroot, transforms=None, input_width=160, input_height=112):
        self.dataroot = dataroot
        self.datasets = datasets
        self.input_width = input_width
        self.input_height = input_height
        self.transforms = transforms
        # '*.jpg' must not start with a slash, or os.path.join would discard
        # the preceding path components
        if datasets == 'U':
            self.img_paths = glob.glob(os.path.join(dataroot, 'UnityEyes/images', '*.jpg'))
        elif datasets == 'E':
            self.img_paths = glob.glob(os.path.join(dataroot, 'Eye200W/images', '*.jpg'))
        elif datasets == 'B':
            self.img_paths = glob.glob(os.path.join(dataroot, 'BL_Eye/images', '*.jpg'))
        self.img_paths = sorted(self.img_paths)
        self.json_paths = []
        for img_path in self.img_paths:
            json_files = img_path.replace('images', 'json').replace('.jpg', '.json')
            self.json_paths.append(json_files)

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()
        full_img = cv2.imread(self.img_paths[index])
        with open(self.json_paths[index]) as f:
            json_data = json.load(f)
        eye, landmarks, gaze = preprocess_unityeyes_image(
            full_img, json_data, self.datasets, self.input_width, self.input_height)
        if self.transforms:
            eye = self.transforms(eye)
        return eye, landmarks, gaze

    def __len__(self):
        return len(self.img_paths)
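
A quick sanity check of the dataset wiring. This is a minimal sketch; the dataroot, batch size, and worker count are placeholder values, and 'E' selects the TEyeD/Eye200W branch above:

import torch.utils.data as data
import torchvision

transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_set = EyesDataset(datasets='E', dataroot='./data', transforms=transform,
                        input_width=160, input_height=112)
train_loader = data.DataLoader(train_set, batch_size=64, shuffle=True,
                               num_workers=4, drop_last=True)
eye, landmarks, gaze = next(iter(train_loader))
print(eye.shape, landmarks.shape, gaze.shape)  # shapes are indicative; gaze is 2-D per sample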
  • model
class Gaze_PFLD(nn.Module):
    def __init__(self):
        super(Gaze_PFLD, self).__init__()
        self.lad = PFLDInference()
        self.gaze = AuxiliaryNet()

    def forward(self, x):
        features, landmark = self.lad(x)
        gaze = self.gaze(features)
        return landmark, gaze
  • loss
class PFLDLoss(nn.Module):
    def __init__(self):
        super(PFLDLoss, self).__init__()
        self.gaze_loss = nn.MSELoss()

    def forward(self, landmark_gt, landmarks, gaze_pred, gaze):
        lad_loss = wing_loss(landmark_gt, landmarks)
        gaze_loss = self.gaze_loss(gaze_pred, gaze)
        return gaze_loss * 1000, lad_loss
def wing_loss(y_true, y_pred, w=10.0, epsilon=2.0, N_LANDMARK=51):
    y_pred = y_pred.reshape(-1, N_LANDMARK, 2)
    y_true = y_true.reshape(-1, N_LANDMARK, 2)
    x = y_true - y_pred
    c = w * (1.0 - math.log(1.0 + w / epsilon))
    absolute_x = torch.abs(x)
    losses = torch.where(w > absolute_x,
                         w * torch.log(1.0 + absolute_x / epsilon),
                         absolute_x - c)
    loss = torch.mean(torch.sum(losses, dim=[1, 2]), dim=0)
    return loss
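
For small residuals (|x| < w) the wing loss follows the log curve w*log(1 + |x|/epsilon), and for large ones it falls back to L1 minus a constant c, so small and medium landmark errors are not washed out the way they would be under plain L2. The repository ships its own training script; the loop below is only a minimal sketch of how the dataloader, model, and loss fit together, reusing train_loader from the sketch above. The optimizer, learning rate, and the equal weighting of the two loss terms are assumptions, not the repository's exact settings.

import torch
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Gaze_PFLD().to(device)
criterion = PFLDLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-6)  # assumed settings

model.train()
for epoch in range(100):
    for eye, landmark_gt, gaze_gt in train_loader:
        eye = eye.to(device)
        landmark_gt = landmark_gt.float().to(device)
        gaze_gt = gaze_gt.float().to(device)
        landmarks, gaze_pred = model(eye)
        # PFLDLoss returns (gaze MSE scaled by 1000, landmark wing loss)
        gaze_loss, lad_loss = criterion(landmark_gt, landmarks, gaze_pred, gaze_gt)
        loss = gaze_loss + lad_loss  # equal weighting is an assumption
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('epoch {}: loss {:.4f}'.format(epoch, loss.item()))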

3. Demo

import argparse
import numpy as np
import cv2
import torch
import torchvision
from models.pfld import PFLDInference, AuxiliaryNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args):
    checkpoint = torch.load(args.model_path, map_location=device)
    print(checkpoint.keys())
    pfld_backbone = PFLDInference().to(device)
    auxiliarynet = AuxiliaryNet().to(device)
    pfld_backbone.load_state_dict(checkpoint['pfld_backbone'])
    auxiliarynet.load_state_dict(checkpoint["auxiliarynet"])
    pfld_backbone.eval()
    auxiliarynet.eval()
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

    img = cv2.imread('5.png')
    height, width = img.shape[:2]
    input = cv2.resize(img, (160, 112))  # network input size (width, height)
    input = transform(input).unsqueeze(0).to(device)

    features, landmarks = pfld_backbone(input)
    gaze = auxiliarynet(features)
    pre_landmark = landmarks[0]
    # landmarks are normalized; scale back to the original image size
    pre_landmark = pre_landmark.cpu().detach().numpy().reshape(-1, 2) * [width, height]
    gaze = gaze.cpu().detach().numpy()[0]
    c_pos = pre_landmark[-1, :]  # last landmark: the eye middle point
    cv2.line(img, tuple(c_pos.astype(int)),
             tuple(c_pos.astype(int) + (gaze * 400).astype(int)), (0, 255, 0), 1)
    for (x, y) in pre_landmark.astype(np.int32):
        cv2.circle(img, (x, y), 1, (0, 0, 255))
    cv2.imshow('gaze estimation', img)
    cv2.imwrite('gaze.jpg', img)
    cv2.waitKey(0)

def parse_args():
    parser = argparse.ArgumentParser(description='Testing')
    parser.add_argument('--model_path',
                        default="./checkpoint/snapshot/checkpoint_epoch_13.pth.tar",
                        type=str)
    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = parse_args()
    main(args)
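
To run the demo, assuming the script above is saved as demo.py (the name is illustrative; the test image 5.png is hard-coded):

python demo.py --model_path ./checkpoint/snapshot/checkpoint_epoch_13.pth.tar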

Result:
[Figure: demo output, with the predicted landmarks and gaze direction drawn on the test image]

4. Export to ONNX

# from __future__ import absolute_import
# from __future__ import division
# from __future__ import print_function
import argparse
import sys
import time
from models.pfld import Gaze_PFLD
import torch
import torch.nn as nn
import models

# def load_model_weight(model, checkpoint):
#     state_dict = checkpoint['model_state_dict']
#     # strip prefix of state_dict
#     if list(state_dict.keys())[0].startswith('module.'):
#         state_dict = {k[7:]: v for k, v in checkpoint['model_state_dict'].items()}
#     model_state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
#     # check loaded parameters and created model parameters
#     for k in state_dict:
#         if k in model_state_dict:
#             if state_dict[k].shape != model_state_dict[k].shape:
#                 print('Skip loading parameter {}, required shape{}, loaded shape{}.'.format(
#                     k, model_state_dict[k].shape, state_dict[k].shape))
#                 state_dict[k] = model_state_dict[k]
#         else:
#             print('Drop parameter {}.'.format(k))
#     for k in model_state_dict:
#         if not (k in state_dict):
#             print('No param {}.'.format(k))
#             state_dict[k] = model_state_dict[k]
#     model.load_state_dict(state_dict, strict=False)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default="./checkpoint/snapshot/checkpoint.pth.tar",
                        help='weights path')
    parser.add_argument('--img-size', nargs='+', type=int, default=[112, 160],
                        help='image size')  # height, width
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    opt = parser.parse_args()
    opt.img_size *= 2 if len(opt.img_size) == 1 else 1  # expand a single value to [size, size]
    device = "cpu"
    print("=====> load pytorch checkpoint...")
    checkpoint = torch.load(opt.weights, map_location=torch.device('cpu'))
    nstack = checkpoint['nstack']
    nfeatures = checkpoint['nfeatures']
    nlandmarks = checkpoint['nlandmarks']
    net = Gaze_PFLD().to(device)
    net.load_state_dict(checkpoint['gaze_pfld'])
    img = torch.zeros(1, 1, *opt.img_size).to(device)  # dummy NCHW input for tracing
    print(img.shape)
    landmarks, gaze = net(img)

    # ONNX export
    try:
        import onnx
        from onnxsim import simplify
        print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
        f = opt.weights.replace('.pth.tar', '.onnx')  # output filename
        torch.onnx.export(net, img, f, export_params=True, verbose=False, opset_version=11,
                          input_names=['images'], output_names=['output'])
        # Checks: load, simplify, and re-validate the exported graph
        onnx_model = onnx.load(f)
        model_simp, check = simplify(onnx_model)
        assert check, "Simplified ONNX model could not be validated"
        onnx.save(model_simp, f)
        print(onnx.helper.printable_graph(onnx_model.graph))  # human-readable graph
        print('ONNX export success, saved as %s' % f)
    except Exception as e:
        print('ONNX export failure: %s' % e)
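
After export it is worth verifying that the ONNX graph reproduces the PyTorch outputs. A minimal sketch with onnxruntime; the .onnx path follows from the --weights default above, the input name 'images' matches input_names in torch.onnx.export, and the tolerances are arbitrary:

import numpy as np
import onnxruntime as ort
import torch

sess = ort.InferenceSession('./checkpoint/snapshot/checkpoint.onnx',
                            providers=['CPUExecutionProvider'])
dummy = np.zeros((1, 1, 112, 160), dtype=np.float32)  # same shape as the export dummy
onnx_outs = sess.run(None, {'images': dummy})

with torch.no_grad():
    torch_outs = net(torch.from_numpy(dummy))  # `net` as loaded in the export script

for o_onnx, o_torch in zip(onnx_outs, torch_outs):
    np.testing.assert_allclose(o_onnx, o_torch.numpy(), rtol=1e-3, atol=1e-5)
print('ONNX outputs match PyTorch')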

That concludes this walkthrough of a PFLD-based gaze estimation implementation; hopefully it is helpful to other developers.



