代码拉取完成,页面将自动刷新
# -*- coding: utf-8 -*-
# @Author : LG
import torch
from PIL import Image
from data.transform import Transforms
import numpy as np
class Transfer:
def __init__(self, checkpoint:str, device='cuda:0'):
self.device = device
self.model = torch.load(checkpoint)
self.model.to(self.device)
self.model.eval()
self.model.requires_grad_(False)
def transfer(self, image:str):
image = Image.open(image).convert('RGB')
image = Transforms((256, 256), (256, 256), is_train=False)(image)
image = image.unsqueeze(0)
image = image.to(self.device)
fake: torch.Tensor = self.model(image)
fake = fake.data
fake_numpy = fake[0].cpu().float().numpy()
fake_numpy = (np.transpose(fake_numpy, (1, 2, 0)) + 1) / 2.0 * 255
return fake_numpy.astype(np.uint8)
if __name__ == '__main__':
# 导入模型,进行风格迁移.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
t = Transfer('checkpoints/pretrained/latest_netG_B.pth', device)
fake = t.transfer('datasets/horse2zebra/testA/n02381460_140.jpg')
fake_image = Image.fromarray(fake)
fake_image.show()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。