1 Star 0 Fork 0

endless/yq_notes_img1

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
西风影像线性拉伸.py 5.25 KB
一键复制 编辑 原始数据 按行查看 历史
endless 提交于 2022-11-12 11:57 . bq
import os
import sys
import glob
from osgeo import gdal
import numpy as np
import cv2
import tqdm
import os
# Runtime configuration for GDAL / PROJ, set unconditionally at import time.
# The PROJ_LIB and GDAL_DATA values are absolute paths into one specific
# conda environment (D:\anaconda3\envs\pytorch_GPU) and override whatever the
# process environment already contained.
# NOTE(review): presumably these point GDAL/PROJ at their data directories;
# adjust the paths when running on another machine.
os.environ.update({
    'CPL_ZIP_ENCODING': 'UTF-8',
    'PROJ_LIB': r'D:\anaconda3\envs\pytorch_GPU\Lib\site-packages\pyproj\proj_dir\share\proj',
    'GDAL_DATA': r'D:\anaconda3\envs\pytorch_GPU\Library\share',
})
def CalHistogram(img):
    """Build a grey-level histogram of ``img`` with both end bins zeroed.

    For ``uint8`` / ``uint16`` / ``uint32`` input every representable value
    gets its own bin (256, 2**16 or 2**32 bins) and the second return value
    is ``np.arange(n_bins)``.  Any other dtype is binned into 2**16
    equal-width bins spanning ``[img.min(), img.max()]`` with
    ``np.histogram``; the second return value is then the ``n_bins + 1`` bin
    edges.

    In both cases the counts of the first and last bin are forced to 0, so
    those values do not contribute to statistics derived from the histogram.

    Args:
        img: numpy array of any shape; it is flattened.

    Returns:
        (hist, s_values): bin counts and the matching bin values / edges.

    Raises:
        ValueError: if ``img`` is empty (raised by the min/max reduction).
    """
    flat = img.reshape(-1)
    lowest, highest = flat.min(), flat.max()

    # Unsigned integer dtypes get one bin per representable value.
    per_value_bins = ((np.uint8, 256), (np.uint16, 2 ** 16), (np.uint32, 2 ** 32))
    n_bins = next((bins for dt, bins in per_value_bins if img.dtype == dt), None)

    if n_bins is None:
        counts, values = np.histogram(flat, bins=2 ** 16, range=(lowest, highest))
    else:
        # NOTE(review): for uint32 this allocates 2**32 counters (tens of GiB).
        counts = np.bincount(flat, minlength=n_bins)
        values = np.arange(n_bins)

    counts[0] = 0
    counts[-1] = 0
    return counts, values
def GetPercentStretchValue(img, left_clip=0.001, right_clip=0.001):
    """Return the (low, high) grey values for a percent-clip linear stretch.

    The histogram from ``CalHistogram`` (whose first and last bins are zeroed)
    is turned into a cumulative fraction.  ``low`` is the value of the first
    bin whose cumulative fraction reaches ``left_clip``; ``high`` is the value
    of the first bin whose cumulative fraction reaches ``1 - right_clip``.

    Args:
        img: numpy array (flattened by ``CalHistogram``); called per band by
            ``percent_stretch_image``.
        left_clip: fraction of pixels to clip at the dark end.
        right_clip: fraction of pixels to clip at the bright end.

    Returns:
        (img_min_clip, img_max_clip) taken from the histogram's bin values
        (left bin edges for non-unsigned-integer input).
    """
    hist, s_values = CalHistogram(img)
    cdf = np.cumsum(hist).astype(np.float64)
    # The small epsilon keeps an all-zero histogram from dividing by zero.
    cdf /= (cdf[-1] + 1.0E-5)
    last = cdf.size - 1
    # First bin whose cumulative fraction reaches the target (cdf is
    # non-decreasing, so searchsorted applies).  A nearest-value search
    # (argmin |cdf - q|) would snap the low bound onto the empty leading bins
    # (value 0) whenever the first occupied bin held at least 2 * left_clip
    # of the pixels.
    left_clip_index = min(int(np.searchsorted(cdf, left_clip, side='left')), last)
    right_clip_index = min(int(np.searchsorted(cdf, 1.0 - right_clip, side='left')), last)
    img_min_clip, img_max_clip = s_values[[left_clip_index, right_clip_index]]
    return img_min_clip, img_max_clip
def percent_stretch_image(input_image_data, left_clip=0.001, right_clip=0.001, left_mask=None,
                          right_mask=None):
    """Linearly stretch every band of an image to 8-bit using percent clipping.

    For each band the clip bounds come from ``GetPercentStretchValue``; the
    band is clipped to ``[low, high]``, scaled to ``[0, 255]`` and truncated
    to ``uint8``.  A band whose bounds coincide (e.g. a constant band) is
    returned as all zeros instead of dividing by zero.

    Args:
        input_image_data: 2-D ``(rows, cols)`` array, or 3-D array whose last
            axis indexes bands.
        left_clip: dark-end clip fraction, forwarded to
            ``GetPercentStretchValue``.
        right_clip: bright-end clip fraction, forwarded to
            ``GetPercentStretchValue``.
        left_mask: accepted for backward compatibility; unused.
        right_mask: accepted for backward compatibility; unused.

    Returns:
        ``uint8`` array with the same shape as the input, or ``None`` when
        ``input_image_data`` is ``None``.
    """
    if input_image_data is None:
        return None

    def _stretch_band(band):
        """Clip one band to its percent bounds and scale it to uint8."""
        low, high = GetPercentStretchValue(band, left_clip=left_clip, right_clip=right_clip)
        if not high > low:
            # Degenerate bounds: (x - low) / (high - low) would be 0/0 -> NaN.
            return np.zeros(band.shape, dtype=np.uint8)
        clipped = np.clip(band, low, high).astype(np.float64)
        return ((clipped - low) / (high - low) * 255).astype(np.uint8)

    n_bands = 1 if input_image_data.ndim == 2 else input_image_data.shape[-1]
    if n_bands == 1:
        # Single band (2-D, or 3-D with one band): stretch the array as a whole.
        return _stretch_band(input_image_data)

    rows, cols = input_image_data.shape[:2]
    out_8bit_data = np.zeros((rows, cols, n_bands), dtype=np.uint8)
    for i_band in range(n_bands):
        out_8bit_data[:, :, i_band] = _stretch_band(input_image_data[:, :, i_band])
    return out_8bit_data
def read_tif(file_path):
    """Read every band of a raster into a ``(rows, cols, bands)`` array.

    Args:
        file_path: path of a raster that ``gdal.Open`` can open.

    Returns:
        ``(data_set, pro, geotrans)``: the pixel array, the projection from
        ``GetProjection()`` and the affine geotransform from
        ``GetGeoTransform()``.  ``None`` (after printing an error) if the file
        cannot be opened.  The array dtype is whatever GDAL's
        ``ReadAsArray`` returns for the bands.
    """
    ds = gdal.Open(file_path)
    if ds is None:
        print("Error || Can't open {0} as tif file.".format(file_path))
        return None
    pro = ds.GetProjection()
    # Affine transform (geotransform) of the raster.
    geotrans = ds.GetGeoTransform()
    # Let GDAL pick the numpy dtype per band.  Converting the GDAL type name
    # instead (GetDataTypeName(...).lower()) turned "Byte" into numpy 'byte',
    # which is *signed* int8 and corrupted 8-bit values above 127; it also
    # staged the whole image in a float64 buffer.
    bands = [ds.GetRasterBand(i + 1).ReadAsArray() for i in range(ds.RasterCount)]
    ds = None  # release the GDAL dataset
    data_set = np.stack(bands, axis=-1)
    return data_set, pro, geotrans
def writeTiff(im_data, im_geotrans, im_proj, path):
    """Write an array to a GeoTIFF at ``path``.

    Args:
        im_data: ``(bands, rows, cols)`` array, or ``(rows, cols)`` for a
            single band.  The first axis of a 3-D array is the band axis.
        im_geotrans: affine geotransform passed to ``SetGeoTransform``.
        im_proj: projection string passed to ``SetProjection``.
        path: output file path.

    Raises:
        ValueError: if ``im_data`` is neither 2-D nor 3-D.
    """
    # GDAL pixel type per numpy dtype; any other dtype is written as Float32.
    # The chosen type is now actually passed to Create (it used to be ignored
    # in favour of a hard-coded Byte, truncating 16-bit and float data), and
    # int16 now maps to Int16 rather than UInt16.
    # NOTE(review): int8 still maps to Byte as before (GDT_Int8 needs
    # GDAL >= 3.7), so negative int8 values will not round-trip.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if im_data.ndim == 2:
        im_data = im_data[np.newaxis, :, :]
    elif im_data.ndim != 3:
        raise ValueError("writeTiff expects a 2-D or 3-D array, got shape {0}".format(im_data.shape))
    im_bands, im_height, im_width = im_data.shape

    # Create the file
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if dataset is None:
        print("Error || Can't create {0} as tif file.".format(path))
        return
    dataset.SetGeoTransform(im_geotrans)  # write the affine transform parameters
    dataset.SetProjection(im_proj)  # write the projection
    print(f'波段总和{im_bands}')
    for i in tqdm.tqdm(range(im_bands)):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    del dataset  # drop the last reference so GDAL can close the dataset
if __name__ == '__main__':
    in_file = r"G:\02工作空间\beijing_img\data\GF2_PMS1_E110.7_N21.4_20210223_L1A0005501536_pansharpen.tif"
    result = read_tif(in_file)
    if result is None:
        # read_tif has already printed the reason.
        sys.exit(1)
    img, pro, geotrans = result

    # Diagnostics: raw value range and a plain min-max normalisation for comparison.
    print(img.min(), img.mean(), img.max())
    img_raw_s = (img - img.min()) / (img.max() - img.min()) * 255
    print('img raw s:', img_raw_s.min(), img_raw_s.mean(), img_raw_s.max())
    del img_raw_s  # free the full-size float copy before stretching

    img = percent_stretch_image(img)
    # read_tif / percent_stretch_image use (rows, cols, bands) order, but
    # writeTiff treats the first axis as bands; without this transpose each
    # image row was written out as a separate band.
    if img.ndim == 3:
        img = np.ascontiguousarray(np.transpose(img, (2, 0, 1)))

    path = r"G:\02工作空间\beijing_img\data\GF2_PMS1_E110.7_N21.4_20210223_L1A0005501536_Line.tif"
    writeTiff(img, geotrans, pro, path)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/long_chaohuo/yq_notes_img1.git
git@gitee.com:long_chaohuo/yq_notes_img1.git
long_chaohuo
yq_notes_img1
yq_notes_img1
master

搜索帮助