利用clip模型实现text2draw

2024-08-31 16:36
文章标签 实现 模型 clip text2draw

本文主要是介绍利用clip模型实现text2draw,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

参考论文

实践

有数据增强的代码

import math
import collections
import CLIP_.clip as clip
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
import webp
from PIL import Image
import skimage
import torchvision
import pydiffvg
import os
import torch.nn.functional as Fclass GeometrymatchLoss(torch.nn.Module):def __init__(self, device, reference_images_path):super(GeometrymatchLoss, self).__init__()self.device = deviceself.model, clip_preprocess = clip.load('ViT-B/32', self.device, jit=False)self.model.eval()self.preprocess = transforms.Compose([clip_preprocess.transforms[0], clip_preprocess.transforms[-1]])  # clip normalisationself.reference_images_feature = self.reference_images_feature(reference_images_path)self.reference_images_feature =self.reference_images_feature/ self.reference_images_feature.norm(dim=-1, keepdim=True)self.text = clip.tokenize([ "A picture of triangle"]).to(device)self.text_features = self.model.encode_text(self.text)# self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)print("text_features.requires_grad:",self.text_features.requires_grad)self.text_features=self.text_features.detach()self.shape_groups=[pydiffvg.ShapeGroup(shape_ids=torch.tensor([0]), fill_color=torch.tensor([0.0, 0.0, 0.0, 1.0]),stroke_color=torch.tensor([0.0, 0.0, 0.0, 1.0]))]# Image Augmentation Transformationself.augment_trans = transforms.Compose([transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),])def forward(self, t,canvas_width, canvas_height,shapes):scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, self.shape_groups)# 渲染图像render = pydiffvg.RenderFunction.applytarget = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args)if target.shape[-1] == 4:target = self.compose_image_with_white_background(target)if t%100==0:pydiffvg.imwrite(target.cpu(), f'learn/log_augs/output_{t}.png', gamma=2.2)# targets_ = self.preprocess(target.permute(2, 0, 1).unsqueeze(0)).to(self.device)img = target.unsqueeze(0)img = img.permute(0, 3, 1, 2)loss = 0NUM_AUGS = 4img_augs = []for n in range(NUM_AUGS):img_augs.append(self.augment_trans(img))im_batch = torch.cat(img_augs)image_features = self.model.encode_image(im_batch)# logit_scale = self.model.logit_scale.exp()for n in range(NUM_AUGS):loss -= torch.cosine_similarity(self.text_features, image_features[n:n + 1], dim=1)return lossdef compose_image_with_white_background(self, img: torch.tensor) -> torch.tensor:if img.shape[-1] == 3:  # return img if it is already rgbreturn img# Compose img with white backgroundalpha = img[:, :, 3:4]img = alpha * img[:, :, :3] + (1 - alpha) * torch.ones(img.shape[0], img.shape[1], 3, device=self.device)return imgdef read_png_image_from_path(self, path_to_png_image: str) -> torch.tensor:numpy_image = skimage.io.imread(path_to_png_image)normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0resizer = torchvision.transforms.Resize((224, 224))resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)return resized_imagedef reference_images_feature(self, reference_images_path):reference_images_num = len(os.listdir(reference_images_path))reference_images_feature = []for i in range(reference_images_num):i_reference_image = self.read_png_image_from_path(os.path.join(reference_images_path, str(i) + ".png"))if i_reference_image.shape[-1] == 4:i_reference_image = self.compose_image_with_white_background(i_reference_image)# targets_ = self.preprocess(i_reference_image.permute(2, 0, 1).unsqueeze(0)).to(self.device)i_reference_image_features = self.model.encode_image(i_reference_image.permute(2, 0, 1).unsqueeze(0).to(self.device)).detach()reference_images_feature.append(i_reference_image_features)return torch.cat(reference_images_feature)def read_png_image_from_path(path_to_png_image: str) -> torch.tensor:if path_to_png_image.endswith('.webp'):numpy_image = np.array(webp.load_image(path_to_png_image))else:numpy_image = skimage.io.imread(path_to_png_image)normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0resizer = torchvision.transforms.Resize((224, 224))resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)return resized_imageif __name__ == '__main__':torch.autograd.set_detect_anomaly(True)from tqdm import tqdmdef get_bezier_circle(radius: float = 80,segments: int = 4,bias: np.array = np.asarray([100., 100.])):deg = torch.arange(0, segments * 3 + 1) * 2 * np.pi / (segments * 3 + 1)points = torch.stack((torch.cos(deg), torch.sin(deg))).Tpoints = points * radius + torch.tensor(bias).unsqueeze(dim=0)points = points.type(torch.FloatTensor).contiguous()return pointsdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")matchLoss = GeometrymatchLoss(device, "reference_images/")# print(matchLoss.reference_images_feature.shape)# img1 = read_png_image_from_path('learn/output.png')canvas_width, canvas_height = 224, 224num_segments=4points1 = get_bezier_circle()path = pydiffvg.Path(num_control_points=torch.tensor(num_segments * [2] + [0],dtype=torch.int32), points=points1, stroke_width=torch.tensor(2.0),is_closed=True)shapes=[path]path.points.requires_grad = Trueprint(id(path.points))print(id(points1))points_vars = []points_vars.append(path.points)points_optim = torch.optim.Adam(points_vars, lr=1)pbar = tqdm(range(100000))print(points1)for t in pbar:# print(t)points_optim.zero_grad()# print("match_loss:", match_loss)match_loss = matchLoss(t,224, 224, shapes)match_loss.backward()# print(path.points.grad)points_optim.step()pbar.set_postfix({"match_loss": f"{match_loss.item()}"})# print(points_vars[0])pass

迭代1000轮次后生成的结果
在这里插入图片描述

没有图像增强

import math
import collections
import CLIP_.clip as clip
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
import webp
from PIL import Image
import skimage
import torchvision
import pydiffvg
import os
import torch.nn.functional as Fclass GeometrymatchLoss(torch.nn.Module):def __init__(self, device, reference_images_path):super(GeometrymatchLoss, self).__init__()self.device = deviceself.model, clip_preprocess = clip.load('ViT-B/32', self.device, jit=False)self.model.eval()self.preprocess = transforms.Compose([clip_preprocess.transforms[0], clip_preprocess.transforms[-1]])  # clip normalisation# self.preprocess = transforms.Compose([clip_preprocess.transforms[-1]])  # clip normalisationself.reference_images_feature = self.reference_images_feature(reference_images_path)self.reference_images_feature =self.reference_images_feature/ self.reference_images_feature.norm(dim=-1, keepdim=True)self.text = clip.tokenize([ "A picture of triangle"]).to(device)# self.text = clip.tokenize(["A picture of rectangle", "A picture of triangle", "A picture of circle", "A picture of pentagon","A picture of five-pointed star"]).to(device)self.text_features = self.model.encode_text(self.text)self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)print("text_features.requires_grad:",self.text_features.requires_grad)self.text_features=self.text_features.detach()self.shape_groups=[pydiffvg.ShapeGroup(shape_ids=torch.tensor([0]), fill_color=torch.tensor([0.0, 0.0, 0.0, 1.0]),stroke_color=torch.tensor([0.0, 0.0, 0.0, 1.0]))]# Image Augmentation Transformationself.augment_trans = transforms.Compose([transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),])def forward(self, t,canvas_width, canvas_height,shapes):scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, self.shape_groups)# 渲染图像render = pydiffvg.RenderFunction.applytarget = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args)if target.shape[-1] == 4:target = self.compose_image_with_white_background(target)if t%100==0:pydiffvg.imwrite(target.cpu(), f'learn/log/output_{t}.png', gamma=2.2)# targets_ = self.preprocess(target.permute(2, 0, 1).unsqueeze(0)).to(self.device)img = target.unsqueeze(0)img = img.permute(0, 3, 1, 2)loss = 0NUM_AUGS = 4img_augs = []for n in range(NUM_AUGS):img_augs.append(self.augment_trans(img))im_batch = torch.cat(img_augs)image_features = self.model.encode_image(img)self.targets_features: torch.tensor=image_features[0]self.targets_features = self.targets_features / self.targets_features.norm(dim=-1, keepdim=True)loss -= torch.cosine_similarity(self.text_features, self.targets_features, dim=1)return lossdef compose_image_with_white_background(self, img: torch.tensor) -> torch.tensor:if img.shape[-1] == 3:  # return img if it is already rgbreturn img# Compose img with white backgroundalpha = img[:, :, 3:4]img = alpha * img[:, :, :3] + (1 - alpha) * torch.ones(img.shape[0], img.shape[1], 3, device=self.device)return imgdef read_png_image_from_path(self, path_to_png_image: str) -> torch.tensor:numpy_image = skimage.io.imread(path_to_png_image)normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0resizer = torchvision.transforms.Resize((224, 224))resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)return resized_imagedef reference_images_feature(self, reference_images_path):reference_images_num = len(os.listdir(reference_images_path))reference_images_feature = []for i in range(reference_images_num):i_reference_image = self.read_png_image_from_path(os.path.join(reference_images_path, str(i) + ".png"))if i_reference_image.shape[-1] == 4:i_reference_image = self.compose_image_with_white_background(i_reference_image)# targets_ = self.preprocess(i_reference_image.permute(2, 0, 1).unsqueeze(0)).to(self.device)i_reference_image_features = self.model.encode_image(i_reference_image.permute(2, 0, 1).unsqueeze(0).to(self.device)).detach()reference_images_feature.append(i_reference_image_features)return torch.cat(reference_images_feature)def read_png_image_from_path(path_to_png_image: str) -> torch.tensor:if path_to_png_image.endswith('.webp'):numpy_image = np.array(webp.load_image(path_to_png_image))else:numpy_image = skimage.io.imread(path_to_png_image)normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0resizer = torchvision.transforms.Resize((224, 224))resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)return resized_imageif __name__ == '__main__':torch.autograd.set_detect_anomaly(True)from tqdm import tqdmdef get_bezier_circle(radius: float = 80,segments: int = 4,bias: np.array = np.asarray([100., 100.])):deg = torch.arange(0, segments * 3 + 1) * 2 * np.pi / (segments * 3 + 1)points = torch.stack((torch.cos(deg), torch.sin(deg))).Tpoints = points * radius + torch.tensor(bias).unsqueeze(dim=0)points = points.type(torch.FloatTensor).contiguous()return pointsdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")matchLoss = GeometrymatchLoss(device, "reference_images/")# print(matchLoss.reference_images_feature.shape)# img1 = read_png_image_from_path('learn/output.png')canvas_width, canvas_height = 224, 224num_segments=4points1 = get_bezier_circle()path = pydiffvg.Path(num_control_points=torch.tensor(num_segments * [2] + [0],dtype=torch.int32), points=points1, stroke_width=torch.tensor(2.0),is_closed=True)shapes=[path]path.points.requires_grad = Trueprint(id(path.points))print(id(points1))points_vars = []points_vars.append(path.points)points_optim = torch.optim.Adam(points_vars, lr=1)pbar = tqdm(range(100000))print(points1)for t in pbar:# print(t)points_optim.zero_grad()# print("match_loss:", match_loss)match_loss = matchLoss(t,224, 224, shapes)match_loss.backward()# print(path.points.grad)points_optim.step()pbar.set_postfix({"match_loss": f"{match_loss.item()}"})# print(points_vars[0])pass

迭代1000轮次后生成的结果
在这里插入图片描述
迭代2000轮次后生成的结果
在这里插入图片描述
迭代4000轮次后生成的结果
在这里插入图片描述
迭代8000轮次后生成的结果
在这里插入图片描述

无图像增强效果不好的原因分析

论文CLIPDraw: Exploring Text-to-Drawing Synthesisthrough Language-Image Encoders解释

在这里插入图片描述

论文StyleCLIPDraw: Coupling Content and Style in Text-to-Drawing Translation解释

在这里插入图片描述

个人理解

因为有很多图片可以和一个文本相匹配,对于我们人来说这些图片有一个根本和文本不相关,如果进行图像增强大概率会得到局部最优值。在计算损失函数之前对图片先进行增强,透过透视等变换,相关的图片不论如何变换和文本的相似度基本不会降低,而不相关的图像变换完之后一般会让相似度降低,这样就可以防止不相关图片对实验结果的影响。

这篇关于利用clip模型实现text2draw的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1124419

相关文章

使用Python实现IP地址和端口状态检测与监控

《使用Python实现IP地址和端口状态检测与监控》在网络运维和服务器管理中,IP地址和端口的可用性监控是保障业务连续性的基础需求,本文将带你用Python从零打造一个高可用IP监控系统,感兴趣的小伙... 目录概述:为什么需要IP监控系统使用步骤说明1. 环境准备2. 系统部署3. 核心功能配置系统效果展

Python实现微信自动锁定工具

《Python实现微信自动锁定工具》在数字化办公时代,微信已成为职场沟通的重要工具,但临时离开时忘记锁屏可能导致敏感信息泄露,下面我们就来看看如何使用Python打造一个微信自动锁定工具吧... 目录引言:当微信隐私遇到自动化守护效果展示核心功能全景图技术亮点深度解析1. 无操作检测引擎2. 微信路径智能获

Python中pywin32 常用窗口操作的实现

《Python中pywin32常用窗口操作的实现》本文主要介绍了Python中pywin32常用窗口操作的实现,pywin32主要的作用是供Python开发者快速调用WindowsAPI的一个... 目录获取窗口句柄获取最前端窗口句柄获取指定坐标处的窗口根据窗口的完整标题匹配获取句柄根据窗口的类别匹配获取句

在 Spring Boot 中实现异常处理最佳实践

《在SpringBoot中实现异常处理最佳实践》本文介绍如何在SpringBoot中实现异常处理,涵盖核心概念、实现方法、与先前查询的集成、性能分析、常见问题和最佳实践,感兴趣的朋友一起看看吧... 目录一、Spring Boot 异常处理的背景与核心概念1.1 为什么需要异常处理?1.2 Spring B

Python位移操作和位运算的实现示例

《Python位移操作和位运算的实现示例》本文主要介绍了Python位移操作和位运算的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录1. 位移操作1.1 左移操作 (<<)1.2 右移操作 (>>)注意事项:2. 位运算2.1

如何在 Spring Boot 中实现 FreeMarker 模板

《如何在SpringBoot中实现FreeMarker模板》FreeMarker是一种功能强大、轻量级的模板引擎,用于在Java应用中生成动态文本输出(如HTML、XML、邮件内容等),本文... 目录什么是 FreeMarker 模板?在 Spring Boot 中实现 FreeMarker 模板1. 环

Qt实现网络数据解析的方法总结

《Qt实现网络数据解析的方法总结》在Qt中解析网络数据通常涉及接收原始字节流,并将其转换为有意义的应用层数据,这篇文章为大家介绍了详细步骤和示例,感兴趣的小伙伴可以了解下... 目录1. 网络数据接收2. 缓冲区管理(处理粘包/拆包)3. 常见数据格式解析3.1 jsON解析3.2 XML解析3.3 自定义

SpringMVC 通过ajax 前后端数据交互的实现方法

《SpringMVC通过ajax前后端数据交互的实现方法》:本文主要介绍SpringMVC通过ajax前后端数据交互的实现方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价... 在前端的开发过程中,经常在html页面通过AJAX进行前后端数据的交互,SpringMVC的controll

Spring Security自定义身份认证的实现方法

《SpringSecurity自定义身份认证的实现方法》:本文主要介绍SpringSecurity自定义身份认证的实现方法,下面对SpringSecurity的这三种自定义身份认证进行详细讲解,... 目录1.内存身份认证(1)创建配置类(2)验证内存身份认证2.JDBC身份认证(1)数据准备 (2)配置依

利用python实现对excel文件进行加密

《利用python实现对excel文件进行加密》由于文件内容的私密性,需要对Excel文件进行加密,保护文件以免给第三方看到,本文将以Python语言为例,和大家讲讲如何对Excel文件进行加密,感兴... 目录前言方法一:使用pywin32库(仅限Windows)方法二:使用msoffcrypto-too