1 Star 0 Fork 0

snow-tyan/classifier

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
temp.py 3.68 KB
一键复制 编辑 原始数据 按行查看 历史
snow-tyan 提交于 2021-11-20 14:53 . add all
import cv2
import os
import albumentations as A
from albumentations import pytorch as AT
import pandas as pd
from PIL import Image
import numpy as np
from utils.dataset import TrainDataset
from tqdm import tqdm
import warnings
# Final crop size fed to the model.
input_size = 224
# Images are first resized to 256/224 of the crop size (256 for input_size=224),
# then cropped down to input_size (randomly for 'train', centered for 'val').
resize_size = int(input_size * (256 / 224))

# Albumentations pipelines keyed by split. Both end with the same Normalize
# statistics and ToTensorV2, which turns the HWC numpy image into a CHW tensor.
albu_transform = {
    'train': A.Compose([
        # Aspect-preserving alternative to the plain resize, kept for reference:
        # A.LongestMaxSize(resize_size),
        # A.PadIfNeeded(resize_size, resize_size),
        # change scale
        A.Resize(resize_size, resize_size),
        A.RandomCrop(input_size, input_size),
        # Apply exactly 2 of these geometric flips/rotations.
        A.SomeOf([
            A.RandomRotate90(),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.Flip(),
        ], 2),
        A.ShiftScaleRotate(),
        # At most one blur, 30% of the time.
        A.OneOf([
            A.GaussianBlur(blur_limit=(3, 5)),
            A.MedianBlur(blur_limit=3),
            # A.MotionBlur(),  # motion blur
        ], p=0.3),
        # Apply exactly 2 of these color transforms.
        A.SomeOf([
            A.RandomBrightnessContrast(),
            A.HueSaturationValue(),
            A.RGBShift(),
            A.ChannelShuffle(),
        ], 2),
        A.OneOf([
            A.CoarseDropout(),
            A.GridDropout(),
        ]),
        # Custom mean/std (presumably computed on this dataset) instead of
        # A.Normalize()'s default ImageNet statistics.
        A.Normalize(mean=(0.638, 0.568, 0.570),
                    std=(0.245, 0.255, 0.255)),
        AT.ToTensorV2(p=1.0),  # HWC -> CHW
    ]),
    'val': A.Compose([
        # A.LongestMaxSize(resize_size),
        # # default padding is reflection; for zero padding use border_mode=cv2.BORDER_CONSTANT
        # A.PadIfNeeded(resize_size, resize_size),
        A.Resize(resize_size, resize_size),
        A.CenterCrop(input_size, input_size),
        # Same custom statistics as 'train' (A.Normalize() would use ImageNet's).
        A.Normalize(mean=(0.638, 0.568, 0.570),
                    std=(0.245, 0.255, 0.255)),
        AT.ToTensorV2(p=1.0),  # HWC -> CHW
    ]),
}

# Escalate every warning to an exception so the image scan below can catch
# problems with `except Warning`. Installed only after the pipelines are built,
# so warnings emitted while constructing the transforms (e.g. library
# deprecation notices) do not abort the script.
warnings.filterwarnings('error')
dataset_path = '../Dataset-fu'
train_path = '../Dataset-fu/test'
# The CSV's 'image' column holds file names/paths relative to train_path.
df = pd.read_csv(os.path.join(dataset_path, 'test.csv'))
train_error_list = df['image'].tolist()
# train_error_list = ['a1627.jpg']  # single-image debugging override
# Presumably images already removed from the dataset in an earlier pass:
# deled_list = ['1/14907.jpg', '1/42990.jpg', '1/53097.jpg',
#               '1/54576.jpg', '1/8094.jpg', '1/82662.jpg',
#               '1/87982.jpg', '108/48843.jpg', '116/34316.jpg',
#               '2/18300.jpg', '2/28808.jpg', '2/47862.jpg',
#               '2/52915.jpg', '2/61739.jpg', '2/72924.jpg']


def _find_load_error(image_path, transform):
    """Try to load *image_path* and run it through *transform*.

    The image is decoded with PIL and with OpenCV. The OpenCV copy is converted
    BGR -> RGB and passed to ``transform(image=...)``. Every warning raised
    along the way is escalated to an exception.

    Args:
        image_path: Path of the image file to check.
        transform: An albumentations-style callable taking ``image=`` kwarg.

    Returns:
        ``None`` if the image went through cleanly. Otherwise the problem found:
        the raised ``Warning`` or ``OSError``, or a ``ValueError`` when
        ``cv2.imread`` could not read the file.
    """
    with warnings.catch_warnings():
        # Scoped escalation: the check behaves the same regardless of the
        # global warnings filter.
        warnings.simplefilter('error')
        try:
            # PIL decoding sits inside the try so that a PIL warning or decode
            # error flags this image instead of aborting the whole scan. The
            # context manager closes the file handle.
            with Image.open(image_path) as pil_image:
                pil_image.load()
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread reports unreadable/undecodable files by returning
                # None rather than raising.
                return ValueError(f'cv2.imread could not read {image_path}')
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            transform(image=image)
        except (Warning, OSError) as err:
            return err
    return None


# Uses the (randomized) 'train' pipeline, so each run exercises one random
# augmentation path per image.
check_transform = albu_transform['train']
del_list = []
for img_name in tqdm(train_error_list):
    error = _find_load_error(os.path.join(train_path, img_name), check_transform)
    if error is not None:
        # tqdm.write prints without breaking the progress bar.
        tqdm.write(f'{img_name}: {error}')
        del_list.append(img_name)
print(del_list)
# ---------------------------------------------------------------------------
# Scratch snippets kept for reference; none of this is executed.
#
# Compare OpenCV (BGR -> RGB) and PIL decodes of the same file:
#   image = cv2.imread(os.path.join(train_path, '1/14907.jpg'))
#   image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
#   img = Image.open(os.path.join(train_path, '1/14907.jpg')).convert('RGB')
#   img = np.array(img)
#   print(img == image)
#
# Build a Series from a hand-picked list of image paths:
#   import pandas as pd
#   l = ['1/10774.jpg', '103/64158.jpg', '103/6697.jpg', '128/28305.jpg',
#        '128/30478.jpg', '27/15745.jpg', '27/50899.jpg', '27/89127.jpg']
#   df = pd.Series(l)
#
# List timm models that ship pretrained weights:
#   import timm
#   print(timm.list_models(pretrained=True))
# ---------------------------------------------------------------------------

# Leftover debug marker; always printed (a constant 0.5 guard is always truthy).
print('hi')
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/snow-tyan/classifier.git
git@gitee.com:snow-tyan/classifier.git
snow-tyan
classifier
classifier
master

搜索帮助