2 Star 1 Fork 0

MM-NUDT/IRRS

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
utils.py 2.50 KB
一键复制 编辑 原始数据 按行查看 历史
hejy47 提交于 2021-04-05 00:56 . 代码修改
import os
import shutil
import random
import numpy as np
import pandas as pd
from PIL import Image
def _tile_origins(length, tile_size=224, stride=112):
    """Return the distinct start offsets of tiles along one image axis.

    A window of ``tile_size`` is placed every ``stride`` pixels.  A window that
    would run past ``length`` is shifted back so it ends exactly at ``length``.
    Several nominal positions near the edge collapse onto the same shifted
    window, so the offsets are de-duplicated (first-seen order kept) to avoid
    saving identical tiles.  When ``length < tile_size`` the single offset is
    negative and the crop box extends outside the image (Pillow fills that
    area; verify the fill value if it matters).
    """
    starts = (min(pos + tile_size, length) - tile_size
              for pos in range(0, length, stride))
    return list(dict.fromkeys(starts))


def split(label_csv='./data/rock_label_1.csv', origin_path='./data/rock/',
          train_path='./data/TrainFolder/', test_path='./data/TestFolder/',
          tile_size=224, stride=112, test_ratio=0.1, seed=None):
    """Cut source rock images into tiles and split them into train/test folders.

    Step 1: every file in ``origin_path`` whose fifth-from-last character is
    ``'1'`` (e.g. ``123_1.jpg``) is looked up in ``label_csv`` by the integer
    id ``name[:-6]``.  Column 0 of the CSV is that id and column 1 is the
    class name.  JPEG sources (names ending in ``jpg``) are first cropped to
    their central region, with a 1/5 margin removed on every side.  Each image
    is then cut into ``tile_size`` x ``tile_size`` tiles placed every
    ``stride`` pixels.  The tiles are saved as
    ``<train_path>/<class>/<id>_<n>.jpg``; identical edge tiles are written
    only once.

    Step 2: for every class folder under ``train_path``, a random
    ``test_ratio`` fraction (rounded down) of its files is moved into the
    same-named class folder under ``test_path``.

    Args:
        label_csv: CSV mapping image id (column 0) to class name (column 1).
        origin_path: directory containing the source images.
        train_path: output directory for training tiles (one folder per class).
        test_path: output directory for test tiles (one folder per class).
        tile_size: side length of each square tile, in pixels.
        stride: step between consecutive tile positions, in pixels.
        test_ratio: fraction of each class's tiles moved to ``test_path``.
        seed: if given, the train/test shuffle uses ``random.Random(seed)`` and
            is reproducible.  Otherwise the global ``random`` state is used.

    Raises:
        KeyError: if a source image's id is missing from ``label_csv``.

    Note:
        Existing output is not cleared.  Re-running on populated folders mixes
        tiles from earlier runs with new ones.
    """
    labels = pd.read_csv(label_csv)
    id_to_class = dict(zip(labels.iloc[:, 0], labels.iloc[:, 1]))

    # Step 1: tile every selected source image into its class folder.
    os.makedirs(train_path, exist_ok=True)
    sources = [name for name in os.listdir(origin_path)
               if len(name) >= 5 and name[-5] == '1']
    for name in sorted(sources):
        stem = name[:-6]  # drop the trailing "_1.xxx"
        class_dir = os.path.join(train_path, str(id_to_class[int(stem)]))
        os.makedirs(class_dir, exist_ok=True)
        src_path = os.path.join(origin_path, name)
        with Image.open(src_path) as src:
            img = src
            if src_path[-3:] == "jpg":
                w, h = img.size
                img = img.crop((w // 5, h // 5, w * 4 // 5, h * 4 // 5))
            width, height = img.size
            tops = _tile_origins(height, tile_size, stride)
            tile_id = 0
            for left in _tile_origins(width, tile_size, stride):
                for top in tops:
                    tile = img.crop((left, top, left + tile_size, top + tile_size))
                    tile.save(os.path.join(class_dir, stem + "_" + str(tile_id) + ".jpg"))
                    tile_id += 1

    # Step 2: move a random fraction of each class's tiles to the test folder.
    rng = random.Random(seed) if seed is not None else random
    for class_name in sorted(os.listdir(train_path)):
        class_train_dir = os.path.join(train_path, class_name)
        if not os.path.isdir(class_train_dir):
            continue  # ignore stray files (e.g. .DS_Store) next to class folders
        class_test_dir = os.path.join(test_path, class_name)
        os.makedirs(class_test_dir, exist_ok=True)
        # Sort first so a fixed seed gives the same split regardless of listdir order.
        files = sorted(os.listdir(class_train_dir))
        rng.shuffle(files)
        for file_name in files[:int(len(files) * test_ratio)]:
            # An explicit destination file path overwrites instead of raising
            # shutil.Error when the tile already exists in the test folder.
            shutil.move(os.path.join(class_train_dir, file_name),
                        os.path.join(class_test_dir, file_name))
def get_mean_std(origin_path='./data/rock/'):
    """Compute and print the per-channel pixel mean and std of the source images.

    Files in ``origin_path`` whose fifth-from-last character is ``'1'`` are
    loaded and converted to RGB.  The statistics are exact over all pixels of
    all selected images.  They are accumulated as per-channel sums and sums of
    squares, so larger images weigh proportionally more.  This replaces the
    previous average of per-image means and stds, which is not the dataset std.

    NOTE(review): statistics are taken over the full source images, whereas
    split() crops JPEG sources to their centre before tiling.  Confirm which
    region the normalisation should describe.

    Args:
        origin_path: directory containing the source images.

    Returns:
        (mean, std): two length-3 float64 arrays in RGB channel order, on the
        0-255 pixel scale.

    Raises:
        ValueError: if no matching image is found in ``origin_path``.
    """
    rocks = [name for name in os.listdir(origin_path)
             if len(name) >= 5 and name[-5] == '1']
    print('Calculate the mean and variance of the data')
    if not rocks:
        raise ValueError('no matching images found in %r' % (origin_path,))
    total = np.zeros(3)
    total_sq = np.zeros(3)
    count = 0
    for rock in rocks:
        with Image.open(os.path.join(origin_path, rock)) as img:
            # convert('RGB') gives 3 channels for grayscale/palette/RGBA input too.
            pixels = np.asarray(img.convert('RGB'), dtype=np.float64).reshape(-1, 3)
        total += pixels.sum(axis=0)
        total_sq += np.square(pixels).sum(axis=0)
        count += pixels.shape[0]
    mean = total / count
    # Clamp tiny negative values caused by floating-point cancellation.
    std = np.sqrt(np.maximum(total_sq / count - mean ** 2, 0.0))
    print(mean)
    print(std)
    return mean, std
if __name__ == "__main__":
    # Executed as a script: build the train/test tile folders from ./data.
    split()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mm-nudt/IRRS.git
git@gitee.com:mm-nudt/IRRS.git
mm-nudt
IRRS
IRRS
main

搜索帮助

0d507c66 1850385 C8b1a773 1850385