1 Star 4 Fork 2

Gitee 极速下载/anytext

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
此仓库是为了提升国内下载速度的镜像仓库,每日同步一次。 原始仓库: https://github.com/tyxsspa/AnyText
克隆/下载
t3_dataset.py 17.27 KB
一键复制 编辑 原始数据 按行查看 历史
import os
import numpy as np
import cv2
import random
import math
import time
from PIL import Image, ImageDraw, ImageFont
from torch.utils.data import Dataset, DataLoader
from dataset_util import load, show_bbox_on_image
# Connector phrases appended to the original caption. get_caption_pos() picks
# one at random to introduce the list of text placeholders that follows it.
phrase_list = [
    ', content and position of the texts are ',
    ', textual material depicted in the image are ',
    ', texts that says ',
    ', captions shown in the snapshot are ',
    ', with the words of ',
    ', that reads ',
    ', the written materials on the picture: ',
    ', these texts are written on it: ',
    ', captions are ',
    ', content of the text in the graphic is ',
]
def insert_spaces(string, nSpace):
    """Return ``string`` with ``nSpace`` spaces inserted between its characters.

    Args:
        string: Text to spread out.
        nSpace: Number of spaces to put between each pair of adjacent
            characters. Values <= 0 leave the text unchanged.

    Returns:
        The spaced-out text; no spaces are added before the first or after
        the last character.
    """
    if nSpace <= 0:
        # Nothing to insert. A negative count used to slice characters off
        # the result via ``[:-nSpace]``; treat it like zero instead.
        return string
    return (" " * nSpace).join(string)
def draw_glyph(font, text):
    """Render ``text`` as one horizontal line, centered on a 512x80 canvas.

    The font size is chosen so that the text fills at most 90% of the canvas
    width and 90% of its height.

    Args:
        font: A PIL FreeType font object; only ``font_variant`` is used to
            derive resized copies of it.
        text: The string to draw.

    Returns:
        A ``float64`` numpy array of shape ``(80, 512, 1)`` built from a
        1-bit image (text pixels non-zero, background zero).
    """
    g_size = 50
    W, H = (512, 80)
    new_font = font.font_variant(size=g_size)
    img = Image.new(mode='1', size=(W, H), color=0)
    draw = ImageDraw.Draw(img)
    # Measure at a reference size first; clamp to 5 px so empty or tiny text
    # cannot produce a division by zero or an absurd scale factor.
    left, top, right, bottom = new_font.getbbox(text)
    text_width = max(right - left, 5)
    text_height = max(bottom - top, 5)
    ratio = min(W*0.9/text_width, H*0.9/text_height)
    new_font = font.font_variant(size=int(g_size*ratio))
    # FreeTypeFont.getsize()/getoffset() were removed in Pillow 10, so derive
    # the same quantities from getbbox(): the ink width, the height including
    # the top offset, and the top offset itself.
    left, offset_y, right, bottom = new_font.getbbox(text)
    text_width = right - left
    text_height = bottom
    x = (img.width - text_width) // 2
    y = (img.height - text_height) // 2 - offset_y//2
    draw.text((x, y), text, font=new_font, fill='white')
    img = np.expand_dims(np.array(img), axis=2).astype(np.float64)
    return img
def draw_glyph2(font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True):
    """Render ``text`` inside the minimum-area rectangle of ``polygon``.

    The text is drawn on a transparent layer at the rectangle's center, the
    layer is rotated by the rectangle's angle and pasted onto a black canvas.
    Near-axis-aligned boxes that are taller than wide are drawn vertically,
    one character below the other.

    Args:
        font: A PIL FreeType font object (``font_variant`` is used to resize it).
        text: The string to draw.
        polygon: Numpy array of polygon points; it is multiplied by ``scale``
            and passed to ``cv2.minAreaRect``.
        vertAng: Tolerance in degrees around multiples of 90 within which the
            box is treated as axis-aligned for the vertical-text check.
        scale: Multiplier applied to both the polygon and the canvas size.
        width: Unscaled canvas width.
        height: Unscaled canvas height.
        add_space: If true, horizontal text that is shorter than the box is
            spread out with extra spaces between characters.

    Returns:
        A ``float64`` numpy array of shape ``(height*scale, width*scale, 1)``
        built from a 1-bit image.
    """
    enlarge_polygon = polygon*scale
    rect = cv2.minAreaRect(enlarge_polygon)
    box = cv2.boxPoints(rect)
    # np.int0 was an alias of np.intp and no longer exists in NumPy 2.0.
    box = box.astype(np.intp)
    w, h = rect[1]
    angle = rect[2]
    if angle < -45:
        angle += 90
    angle = -angle
    if w < h:
        angle += 90

    vert = False
    if (abs(angle) % 90 < vertAng or abs(90-abs(angle) % 90) % 90 < vertAng):
        _w = max(box[:, 0]) - min(box[:, 0])
        _h = max(box[:, 1]) - min(box[:, 1])
        if _h >= _w:
            vert = True
            angle = 0

    img = np.zeros((height*scale, width*scale, 3), np.uint8)
    img = Image.fromarray(img)

    # infer font size
    image4ratio = Image.new("RGB", img.size, "white")
    draw = ImageDraw.Draw(image4ratio)
    _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font)
    # Guard against a zero-height bbox (e.g. empty text), which would
    # otherwise raise ZeroDivisionError.
    text_w = min(w, h) * (_tw / max(_th, 1))
    if text_w <= max(w, h):
        # add space
        if len(text) > 1 and not vert and add_space:
            # Find the first spacing that overflows the long side of the box,
            # then keep the spacing one step below it.
            for i in range(1, 100):
                text_space = insert_spaces(text, i)
                _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font)
                if min(w, h) * (_tw2 / _th2) > max(w, h):
                    break
            text = insert_spaces(text, i-1)
        font_size = min(w, h)*0.80
    else:
        # Text is too long for the box: shrink the font proportionally.
        shrink = 0.75 if vert else 0.85
        font_size = min(w, h) / (text_w/max(w, h)) * shrink
    new_font = font.font_variant(size=int(font_size))

    left, top, right, bottom = new_font.getbbox(text)
    text_width = right-left
    text_height = bottom - top

    layer = Image.new('RGBA', img.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(layer)
    if not vert:
        draw.text((rect[0][0]-text_width//2, rect[0][1]-text_height//2-top), text, font=new_font, fill=(255, 255, 255, 255))
    else:
        # Vertical layout: stack the characters downward from the box top.
        x_s = min(box[:, 0]) + _w//2 - text_height//2
        y_s = min(box[:, 1])
        for c in text:
            draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255))
            _, _t, _, _b = new_font.getbbox(c)
            y_s += _b

    rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1]))

    # expand=1 may enlarge the rotated layer; center it on the canvas.
    x_offset = int((img.width - rotated_layer.width) / 2)
    y_offset = int((img.height - rotated_layer.height) / 2)
    img.paste(rotated_layer, (x_offset, y_offset), rotated_layer)
    img = np.expand_dims(np.array(img.convert('1')), axis=2).astype(np.float64)
    return img
def get_caption_pos(ori_caption, pos_idxs, prob=1.0, place_holder='*'):
    """Append a random connector phrase and one placeholder per text line.

    Each entry of ``pos_idxs`` yields one ``place_holder``. With probability
    ``prob``, and only when the entry is greater than 0, the placeholder is
    followed by a randomly worded location (e.g. "* located at top right").

    Args:
        ori_caption: Caption to extend.
        pos_idxs: One position index per text line; keys of the 3x3 grid
            table below.
        prob: Probability of attaching the location wording to a placeholder.
        place_holder: Token standing in for each text line.

    Returns:
        The extended caption, terminated by '.'.
    """
    # NOTE(review): the `> 0` test below also skips index 0 (" top left"), so
    # that entry is never emitted — confirm whether `>= 0` was intended.
    center_word = random.choice([" middle", " center"])
    idx2pos = {
        0: " top left",
        1: " top",
        2: " top right",
        3: " left",
        4: center_word,
        5: " right",
        6: " bottom left",
        7: " bottom",
        8: " bottom right"
    }
    new_caption = ori_caption + random.choice(phrase_list)
    pieces = []
    for pos_idx in pos_idxs:
        roll = random.random()
        if roll < prob and pos_idx > 0:
            verb = random.choice([' located', ' placed', ' positioned', ''])
            prep = random.choice([' at', ' in', ' on'])
            pieces.append(place_holder + verb + prep + idx2pos[pos_idx])
        else:
            pieces.append(place_holder + ' ')
    # Entries are separated by ', ' and the whole list ends with a period.
    return new_caption + ', '.join(pieces) + '.'
def generate_random_rectangles(w, h, box_num):
    """Generate ``box_num`` random rotated rectangles.

    Each rectangle gets a random top-left corner inside the ``w`` x ``h``
    area, a width in [16, 256], a height in [16, 96] and a rotation in
    [-45, 45] degrees about its own center.

    Args:
        w: Width of the area in which top-left corners are sampled.
        h: Height of the area in which top-left corners are sampled.
        box_num: Number of rectangles to generate.

    Returns:
        A list of ``box_num`` rectangles, each a tuple of four ``(x, y)``
        integer corner points.
    """
    rectangles = []
    for _ in range(box_num):
        x = random.randint(0, w)
        y = random.randint(0, h)
        # Use separate names for the box size: reusing `w`/`h` here would
        # overwrite the area size and confine every later box to the
        # previous box's dimensions.
        box_w = random.randint(16, 256)
        box_h = random.randint(16, 96)
        angle = random.randint(-45, 45)
        corners = (
            (x, y),
            (x + box_w, y),
            (x + box_w, y + box_h),
            (x, y + box_h),
        )
        center = ((x + x + box_w) / 2, (y + y + box_h) / 2)
        rectangles.append(tuple(rotate_point(p, center, angle) for p in corners))
    return rectangles
def rotate_point(point, center, angle):
    """Rotate ``point`` about ``center`` by ``angle`` degrees.

    Args:
        point: ``(x, y)`` coordinates of the point to rotate.
        center: ``(x, y)`` coordinates of the rotation center.
        angle: Rotation angle in degrees.

    Returns:
        The rotated point as a tuple of two ints (coordinates are truncated
        with ``int``).
    """
    theta = math.radians(angle)
    cos_t = math.cos(theta)
    sin_t = math.sin(theta)
    # Work in coordinates relative to the center, rotate, then shift back.
    dx = point[0] - center[0]
    dy = point[1] - center[1]
    rx = dx * cos_t - dy * sin_t + center[0]
    ry = dx * sin_t + dy * cos_t + center[1]
    return int(rx), int(ry)
class T3DataSet(Dataset):
    """Dataset yielding images with text glyphs, position masks and captions.

    Annotation files are read with ``load`` from ``dataset_util``. Every
    sample is a dict built in ``__getitem__`` from a 512x512 image and its
    text-line annotations (polygon, text, language and optional position
    index).
    """

    def __init__(
            self,
            json_path,
            max_lines=5,
            max_chars=20,
            place_holder='*',
            font_path='./font/Arial_Unicode.ttf',
            caption_pos_prob=1.0,
            mask_pos_prob=1.0,
            mask_img_prob=0.5,
            for_show=False,
            using_dlc=False,
            glyph_scale=1,
            percent=1.0,
            debug=False,
            wm_thresh=1.0,
            ):
        """Load all annotation files and store the sampling options.

        Args:
            json_path: Path, or list of paths, of annotation files.
            max_lines: Maximum number of text lines kept per sample.
            max_chars: Each text is truncated to this many characters.
            place_holder: Token that stands for a text line in the caption.
            font_path: TrueType font used to render glyphs.
            caption_pos_prob: Probability passed to ``get_caption_pos``.
            mask_pos_prob: Probability that a position mask is actually drawn.
            mask_img_prob: Probability of producing a masked image instead of
                an all-zero one.
            for_show: If true, ``__getitem__`` returns early with polygons and
                the image name, without padding.
            using_dlc: If true, the leading '/data/vdb' of the data root is
                replaced by '/mnt/data'.
            glyph_scale: Scale factor of the glyph canvas.
            percent: Fraction of each annotation file to load.
            debug: If true, samples are taken from a fixed list of indices.
            wm_thresh: Images whose 'wm_score' exceeds this are skipped.
        """
        assert isinstance(json_path, (str, list))
        if isinstance(json_path, str):
            json_path = [json_path]
        data_list = []
        self.using_dlc = using_dlc
        self.max_lines = max_lines
        self.max_chars = max_chars
        self.place_holder = place_holder
        self.font = ImageFont.truetype(font_path, size=60)
        # Attribute name (including its spelling) is kept as-is so existing
        # code reading it keeps working.
        self.caption_pos_porb = caption_pos_prob
        self.mask_pos_prob = mask_pos_prob
        self.mask_img_prob = mask_img_prob
        self.for_show = for_show
        self.glyph_scale = glyph_scale
        self.wm_thresh = wm_thresh
        for jp in json_path:
            data_list += self.load_data(jp, percent)
        self.data_list = data_list
        print(f'All dataset loaded, imgs={len(self.data_list)}')
        self.debug = debug
        if self.debug:
            self.tmp_items = [i for i in range(100)]

    def load_data(self, json_path, percent):
        """Read one annotation file and return its per-image info dicts.

        Args:
            json_path: Path of the annotation file.
            percent: Fraction of the file's 'data_list' entries to keep.

        Returns:
            A list of dicts with 'img_path' and 'caption' and, when the entry
            has annotations, 'polygons', 'invalid_polygons', 'texts',
            'language' and 'pos'.
        """
        tic = time.time()
        content = load(json_path)
        d = []
        count = 0
        wm_skip = 0
        max_img = len(content['data_list']) * percent
        for gt in content['data_list']:
            # Stop once the requested share is reached (`>` kept one extra).
            if len(d) >= max_img:
                break
            if 'wm_score' in gt and gt['wm_score'] > self.wm_thresh:  # wm_score > thresh will be skipped as an img with watermark
                wm_skip += 1
                continue
            data_root = content['data_root']
            if self.using_dlc:
                data_root = data_root.replace('/data/vdb', '/mnt/data', 1)
            img_path = os.path.join(data_root, gt['img_name'])
            info = {}
            info['img_path'] = img_path
            info['caption'] = gt['caption'] if 'caption' in gt else ''
            # The placeholder token must not already occur in the caption.
            if self.place_holder in info['caption']:
                count += 1
                info['caption'] = info['caption'].replace(self.place_holder, " ")
            if 'annotations' in gt:
                polygons = []
                invalid_polygons = []
                texts = []
                languages = []
                pos = []
                for annotation in gt['annotations']:
                    if len(annotation['polygon']) == 0:
                        continue
                    if 'valid' in annotation and annotation['valid'] is False:
                        invalid_polygons.append(annotation['polygon'])
                        continue
                    polygons.append(annotation['polygon'])
                    texts.append(annotation['text'])
                    languages.append(annotation['language'])
                    if 'pos' in annotation:
                        pos.append(annotation['pos'])
                info['polygons'] = [np.array(i) for i in polygons]
                info['invalid_polygons'] = [np.array(i) for i in invalid_polygons]
                info['texts'] = texts
                info['language'] = languages
                info['pos'] = pos
            d.append(info)
        print(f'{json_path} loaded, imgs={len(d)}, wm_skip={wm_skip}, time={(time.time()-tic):.2f}s')
        if count > 0:
            print(f"Found {count} image's caption contain placeholder: {self.place_holder}, change to ' '...")
        return d

    def __getitem__(self, item):
        """Build the sample dict for index ``item``.

        Returns:
            A dict with 'img' (float32, 512x512x3, scaled to [-1, 1]),
            'caption', 'glyphs', 'gly_line', 'positions', 'texts',
            'language', 'inv_mask', 'hint' and 'masked_img'. With
            ``for_show`` it also has 'img_name' (and 'polygons' when the
            image has texts) and is returned unpadded; otherwise 'n_lines'
            is added and the per-line lists are padded to ``max_lines``.
        """
        item_dict = {}
        if self.debug:  # sample fixed items
            item = self.tmp_items.pop()
            print(f'item = {item}')
        cur_item = self.data_list[item]
        # img
        target = np.array(Image.open(cur_item['img_path']).convert('RGB'))
        if target.shape[0] != 512 or target.shape[1] != 512:
            target = cv2.resize(target, (512, 512))
        target = (target.astype(np.float32) / 127.5) - 1.0
        item_dict['img'] = target
        # caption
        item_dict['caption'] = cur_item['caption']
        item_dict['glyphs'] = []
        item_dict['gly_line'] = []
        item_dict['positions'] = []
        item_dict['texts'] = []
        item_dict['language'] = []
        item_dict['inv_mask'] = []
        texts = cur_item.get('texts', [])
        unsel_idxs = []
        if len(texts) > 0:
            idxs = list(range(len(texts)))
            if len(texts) > self.max_lines:
                # Too many lines: keep a random subset, mask out the rest.
                sel_idxs = random.sample(idxs, self.max_lines)
                unsel_idxs = [i for i in idxs if i not in sel_idxs]
            else:
                sel_idxs = idxs
            if len(cur_item['pos']) > 0:
                pos_idxs = [cur_item['pos'][i] for i in sel_idxs]
            else:
                pos_idxs = [-1 for _ in sel_idxs]
            item_dict['caption'] = get_caption_pos(item_dict['caption'], pos_idxs, self.caption_pos_porb, self.place_holder)
            item_dict['polygons'] = [cur_item['polygons'][i] for i in sel_idxs]
            item_dict['texts'] = [cur_item['texts'][i][:self.max_chars] for i in sel_idxs]
            item_dict['language'] = [cur_item['language'][i] for i in sel_idxs]
            # glyphs
            for idx, text in enumerate(item_dict['texts']):
                gly_line = draw_glyph(self.font, text)
                glyphs = draw_glyph2(self.font, text, item_dict['polygons'][idx], scale=self.glyph_scale)
                item_dict['glyphs'] += [glyphs]
                item_dict['gly_line'] += [gly_line]
            # mask_pos
            for polygon in item_dict['polygons']:
                item_dict['positions'] += [self.draw_pos(polygon, self.mask_pos_prob)]
        # inv_mask
        # Copy the stored list: extending it in place would append the
        # unselected polygons to self.data_list on every access, so the
        # cached entry would grow each time this sample is fetched.
        invalid_polygons = list(cur_item.get('invalid_polygons', []))
        if len(texts) > 0:
            invalid_polygons += [cur_item['polygons'][i] for i in unsel_idxs]
        item_dict['inv_mask'] = self.draw_inv_mask(invalid_polygons)
        item_dict['hint'] = self.get_hint(item_dict['positions'])
        if random.random() < self.mask_img_prob:
            # randomly generate 0~3 masks
            box_num = random.randint(0, 3)
            boxes = generate_random_rectangles(512, 512, box_num)
            boxes = np.array(boxes)
            pos_list = item_dict['positions'].copy()
            for i in range(box_num):
                pos_list += [self.draw_pos(boxes[i], self.mask_pos_prob)]
            mask = self.get_hint(pos_list)
            masked_img = target*(1-mask)
        else:
            masked_img = np.zeros_like(target)
        item_dict['masked_img'] = masked_img

        if self.for_show:
            item_dict['img_name'] = os.path.split(cur_item['img_path'])[-1]
            return item_dict
        if len(texts) > 0:
            del item_dict['polygons']
        # padding
        n_lines = min(len(texts), self.max_lines)
        item_dict['n_lines'] = n_lines
        n_pad = self.max_lines - n_lines
        if n_pad > 0:
            item_dict['glyphs'] += [np.zeros((512*self.glyph_scale, 512*self.glyph_scale, 1))] * n_pad
            item_dict['gly_line'] += [np.zeros((80, 512, 1))] * n_pad
            item_dict['positions'] += [np.zeros((512, 512, 1))] * n_pad
            item_dict['texts'] += [' '] * n_pad
            item_dict['language'] += [' '] * n_pad

        return item_dict

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.data_list)

    def draw_inv_mask(self, polygons):
        """Return a (512, 512, 1) mask that is 1 inside any of ``polygons``."""
        img = np.zeros((512, 512))
        for p in polygons:
            pts = p.reshape((-1, 1, 2))
            cv2.fillPoly(img, [pts], color=255)
        img = img[..., None]
        return img/255.

    def draw_pos(self, ploygon, prob=1.0):
        """Return a (512, 512, 1) position mask for one polygon.

        With probability ``prob`` the polygon is filled; otherwise the mask
        stays all-zero. A filled mask is then randomly perturbed: no change
        for a draw below 0.7, one dilation in [0.7, 0.8), one erosion in
        [0.8, 0.9), two dilations in [0.9, 0.95), two erosions in [0.95, 1.0).
        Boxes with a side shorter than 20 px are never eroded; for those, a
        draw in [0.8, 0.9) falls through to the double dilation instead.
        """
        img = np.zeros((512, 512))
        rect = cv2.minAreaRect(ploygon)
        w, h = rect[1]
        small = False
        if w < 20 or h < 20:
            small = True
        if random.random() < prob:
            pts = ploygon.reshape((-1, 1, 2))
            cv2.fillPoly(img, [pts], color=255)
            # 10% dilate / 10% erode / 5% dilatex2 5% erodex2
            random_value = random.random()
            kernel = np.ones((3, 3), dtype=np.uint8)
            if random_value < 0.7:
                pass
            elif random_value < 0.8:
                img = cv2.dilate(img.astype(np.uint8), kernel, iterations=1)
            elif random_value < 0.9 and not small:
                img = cv2.erode(img.astype(np.uint8), kernel, iterations=1)
            elif random_value < 0.95:
                img = cv2.dilate(img.astype(np.uint8), kernel, iterations=2)
            elif random_value < 1.0 and not small:
                img = cv2.erode(img.astype(np.uint8), kernel, iterations=2)
        img = img[..., None]
        return img/255.

    def get_hint(self, positions):
        """Merge position masks into one mask clipped to [0, 1].

        Returns an all-zero (512, 512, 1) array when ``positions`` is empty.
        """
        if len(positions) == 0:
            return np.zeros((512, 512, 1))
        return np.sum(positions, axis=0).clip(0, 1)
if __name__ == '__main__':
    # Run this script to show details of your dataset, such as ocr
    # annotations, glyphs, prompts, etc. Results for the first `show_count`
    # samples are written into `show_imgs_dir`, which is recreated each run.
    from tqdm import tqdm
    from matplotlib import pyplot as plt
    import shutil

    show_imgs_dir = 'show_results'
    show_count = 50
    if os.path.exists(show_imgs_dir):
        shutil.rmtree(show_imgs_dir)
    os.makedirs(show_imgs_dir)
    plt.rcParams['axes.unicode_minus'] = False
    json_paths = [
        '/path/of/your/dataset/data1.json',
        '/path/of/your/dataset/data2.json',
        # ...
    ]

    dataset = T3DataSet(json_paths, for_show=True, max_lines=20, glyph_scale=2, mask_img_prob=1.0, caption_pos_prob=0.0)
    train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    pbar = tqdm(total=show_count)
    for i, data in enumerate(train_loader):
        if i == show_count:
            break
        # Map the image from [-1, 1] back to [0, 255].
        img = ((data['img'][0].numpy() + 1.0) / 2.0 * 255).astype(np.uint8)
        masked_img = ((data['masked_img'][0].numpy() + 1.0) / 2.0 * 255)[..., ::-1].astype(np.uint8)
        cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_masked.jpg'), masked_img)
        if 'texts' in data and len(data['texts']) > 0:
            texts = [x[0] for x in data['texts']]
            img = show_bbox_on_image(Image.fromarray(img), data['polygons'], texts)
        cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}.jpg'), np.array(img)[..., ::-1])
        # Captions may contain non-ASCII text, so do not rely on the
        # platform's default encoding.
        with open(os.path.join(show_imgs_dir, f'plots_{i}.txt'), 'w', encoding='utf-8') as fout:
            fout.write(data['caption'][0])
        all_glyphs = []
        for k, glyphs in enumerate(data['glyphs']):
            cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_glyph_{k}.jpg'), glyphs[0].numpy().astype(np.int32)*255)
            all_glyphs += [glyphs[0].numpy().astype(np.int32)*255]
        # Without any glyph np.sum() yields a scalar, not an image, so there
        # is nothing meaningful to write.
        if all_glyphs:
            cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_allglyphs.jpg'), np.sum(all_glyphs, axis=0))
        for k, gly_line in enumerate(data['gly_line']):
            cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_gly_line_{k}.jpg'), gly_line[0].numpy().astype(np.int32)*255)
        for k, position in enumerate(data['positions']):
            if position is not None:
                cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_pos_{k}.jpg'), position[0].numpy().astype(np.int32)*255)
        cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_hint.jpg'), data['hint'][0].numpy().astype(np.int32)*255)
        cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_inv_mask.jpg'), np.array(img)[..., ::-1]*(1-data['inv_mask'][0].numpy().astype(np.int32)))
        pbar.update(1)
    pbar.close()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/mirrors/anytext.git
git@gitee.com:mirrors/anytext.git
mirrors
anytext
anytext
main

搜索帮助