1 Star 0 Fork 0

Admin/julia_dispat

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
core.jl 13.18 KB
一键复制 编辑 原始数据 按行查看 历史
Admin 提交于 2022-04-15 15:34 . first commit
using Images, ImageBinarization;
using ImageView, Revise;
using Distances;
using LsqFit, Plots;
using MultivariateStats, Clustering;
using Statistics;
using Random;
"""
_GetPattern(im::T, position::Vector{Int}=[2, 2], pat::Vector{Int}=[1, 1])
im:需要传入训练图像矩阵
position:则是中心节点的坐标,行是y坐标,列是x坐标
pat:传入模板的半径(不计中心节点的)
"""
function _GetPattern(im::T, position::Vector{Int}=[2, 2], pat::Vector{Int}=[1, 1]) where {T <: Array}
return im[position[1]-pat[1]:position[1]+pat[1],position[2]-pat[2]:position[2]+pat[2]];
end
"""
_GetArray(pattern::T) where {T <: Matrix}
pattern:矩阵形式的模板,待向量坐标化
"""
function _GetArray(pattern::T) where {T <: Matrix}
return reshape(pattern, 1, length(pattern));
end
"""
_GetPattern(im::T, position::Vector{Int}=[2, 2], pat::Vector{Int}=[1, 1])
im:需要传入训练图像矩阵
position:则是中心节点的坐标,行是y坐标,列是x坐标
pat:传入模板的半径(不计中心节点的)
很函数不使用向量化语言实现 _GetArray(_GetPattern()),经过比较,本方法慢三倍
"""
function _GetArray2(im::T, position::Vector{Int}=[2,2], pat::Vector{Int}=[1,1]) where {T<:Array}
row = (position[1] - pat[1]) : (position[1] + pat[1]);
col = (position[2] - pat[2]) : (position[2] + pat[2]);
res = -ones(length(row) * length(col));
index = 0;
for i = row
for j = col
index += 1;
res[index] = im[i, j];
end
end
return res
end
"""
_PatternDataBase(im::T, pat::Vector{Int}=[1, 1], skip::Vector{Int}=[1, 1])
im:需要传入训练图像矩阵
pat:传入模板的半径(不计中心节点的)
skip:是模式跳过的级别
"""
function _PatternDataBase(im::T, pat::Vector{Int}=[1, 1], skip::Vector{Int}=[1, 1]) where {T <: Array}
n, m = size(im);
row, col = (1+pat[1]):skip[1]:(n-pat[1]), (1+pat[2]):skip[2]:(m-pat[2]);
database = zeros(*(length(row), length(col)), prod((pat*2).+1));
# 先左右,后上下,依次提取模式,所以外层循环行,里层循环列
count = 0;
for i = row
for j = col
count += 1
database[count,:] = _GetArray(_GetPattern(im, [i, j], pat));
end
end
return database;
end
"""
_entropy(im::T)
im:待计算信息熵的像素集群,组织形式可以是二、三维矩阵,也可以是一维向量
"""
function _entropy(im::T) where {T <: Array}
l = reshape(im, 1, length(im));
n = length(l);
bin = -ones(n);
count = zeros(n);
point = 0;
for i = l
if i in bin
index = (bin.==i);
count[index].+=1;
else
point+=1;
bin[point] = i;
count[point] = 1;
end
end
p = count[1:point] ./ n;
# println(p)
# println(count)
# println(bin)
return -sum(p .* log2.(p))
end
"""
_entropy_ti(db::T)
db:训练图像在某一个模板下的模式数据库,每一行表示一个模式的高维坐标
"""
function _entropy_ti(db::T) where {T <: Array}
n, m = size(db);
sum = zero(0.0);
for i = 1:n
sum+=_entropy(db[i,:]);
end
return sum/n;
end
"""
_cal_entropy(im::T, rat::Float64=1.0)
im:训练图像
rat:扫描训练图像模板最大尺度与训练图像最小变尺度的比,取值范围 0.0-1.0
"""
function _cal_entropy(im::T, rat::Float64=1.0) where {T <: Array}
n, m = size(im);
radii = ceil(Int,min(n, m)*rat / 2 - 1);
mean_ents = zeros(radii, 1)
for i = 1:radii
db = _PatternDataBase(im, [i, i]);
mean_ents[i] = _entropy_ti(db);
end
return mean_ents
end
@. exp_model(x, p) = p[1] * (1 - exp(-x/p[2])) # 指数模型
@. guass_model(x, p) = p[1] * (1 - exp(-(x^2)/(p[2]^2))) # 高斯模型
@. sphere_model(x, p) = (0<x<p[2]) * p[1] * (3x / (2p[2]) - x^3 / (2 * (p[2]^3))) + (x>p[2]) * p[1] # 球状模型
function exp_fit(xdata, ydata)
p0 = [0.5, 0.5];
return curve_fit(exp_model, xdata, ydata, p0);
end
function guass_fit(xdata, ydata)
p0 = [0.5, 0.5];
return curve_fit(guass_model, xdata, ydata, p0);
end
function sphere_fit(xdata, ydata)
p0 = [0.5, 0.5];
return curve_fit(sphere_model, xdata, ydata, p0);
end
"""
"""
function _show_fit(fit, model, xdata, ydata)
x = 1:0.01:xdata[end];
scatter(ydata, label=false);
plot!(x, model(x, fit.param), label=false);
end
"""
_information_scales(im)
im:训练图像
计算信息临界尺度
"""
function _information_scales(im)
ydata = _cal_entropy(im);
# println(ydata,size(ydata))
ydata = vec(ydata);
xdata = 1:length(ydata);
fitres = exp_fit(xdata, ydata);
c, a = fitres.param
return ceil(Int, 3 * a)
end
"""
_recal_centers(db, id, k, pat)
db:模式数据库
id:模式标签
K:聚类数
pat:模板尺度
计算原本维度的聚类中心
"""
function _recal_centers(db, id, k, pat)
# n = length(Set(id));
m = (2pat[1]+1)*(2pat[2]+1)
centers = zeros(k, m);
for i =1:k
index = (id.==i)
dbi = db[index, :];
centers[i, :] = mean(dbi, dims=1);
end
return centers
end
"""
_sampling(id, i)
id:分类标签
i:寻找的类别
返回该类别随机的一个模式在db里的行数
"""
function _sampling(id, i)
index = 1 : length(id);
ind = id.==i;
index = index[ind];
count = sum(ind);
sample = rand(1:count);
return index[sample]
end
"""
_cluster_a_node(node, sim_grid::Matrix, pattern::Vector{Int}, centers)
node:节点坐标
sim_grid:是否满足条件的示性向量
pattern:模式尺度
centers:聚类中心的坐标库
找到距离节点最近的聚类中心,有多个则返回第一个中心
"""
function _cluster_a_node(node, sim_grid::Matrix, pat::Vector{Int}, centers)
# 首先针对这一节点提取数据事件
node = node + pat
data_event = _GetArray(_GetPattern(sim_grid, node, pat))
# 计算这一数据事件与所有聚类中心的距离含-0.5的节点不参与计算
# index = data_event.!=0.5
# data_event = data_event[index]
# dim = 1:prod(2*pat.+1)
# dims = _find_index(dim, index)
# centers_need = centers[:, dim]
centers_need = centers
# 计算距离
distances = colwise(euclidean, data_event, centers_need')
index = distances.==minimum(distances)
dim = 1:length(distances)
return _find_index(dim, index)[1]
end
"""
_find_index(arr::UnitRange{Int64}, index::BitMatrix)
arr:所有的下标
index:是否满足条件的示性向量
找到满足条件的下标
"""
function _find_index(arr::UnitRange{Int64}, index)
dims = zeros(Int, sum(index))
count=1
for i=arr
if index[i]
dims[count]=i
count+=1
end
end
return dims
end
"""
_sim_a_node(node, sim_grid::Matrix, pattern::Vector{Int}, centers, db)
node:节点坐标
sim_grid:是否满足条件的示性向量
pattern:模式尺度
centers:聚类中心的坐标库
db:模式数据库
找到距离节点最近的聚类中心,有多个则返回第一个中心
"""
function _sim_a_node(node, sim_grid::Matrix, pat::Vector{Int}, centers, db, id, inner=[0, 0])
i, j = node
# println(i, ' ',j)
cluster = _cluster_a_node([i, j] ,sim_grid, pat, centers)
pattern = db[_sampling(id, cluster),:]
p1 = 2*pat[1]+1;p2 = 2*pat[2]+1
pattern = reshape(pattern, p1, p2)
row1 = i+pat[1]-inner[1]:i+pat[1]+inner[1];col1 = j+pat[2]-inner[2]:j+pat[2]+inner[2]
row2 = pat[1]+1-inner[1]:pat[1]+1+inner[1];col2 = pat[2]+1-inner[2]:pat[2]+1+inner[2]
# sim_grid[row1, col1] = pattern[row2 , col2]
for r = 1:length(row1)
for c = 1:length(row2)
if sim_grid[row1[r], col1[c]] == -0.5
sim_grid[row1[r], col1[c]] = pattern[row2[r], col2[c]]
end
end
end
return sim_grid
end
"""
_sim(db, sim_grid, pat, centers, sim_path, id, inner=[0, 0])
db:模式数据库
sim_grid:模拟网格
pat:模板尺度
sim_path:模拟路径
id:聚类标签
inner:补丁尺度
找到距离节点最近的聚类中心,有多个则返回第一个中心
"""
function _sim(db, sim_grid, pat, centers, sim_path, id, inner=[0, 0])
n,m = size(sim_grid)
n = n-2pat[1];m=m-2pat[2]
node_num = n*m
# println(n,' ', m,' ',pat)
for i=1:node_num
node = sim_path[i]
j = floor(Int, node/n)
j == node/n ? j = j : j = j + 1
i = node - (j-1) * n
# println(i, ' ',j)
# if sim_grid[i+pat[1], j+pat[2]] != -0.5
if sim_grid[i+pat[1], j+pat[2]] == 0.5
sim_grid = _sim_a_node([i, j], sim_grid, pat, centers, db, id, inner)
end
end
return sim_grid[pat[1]+1:pat[1]+n , pat[2]+1:pat[2]+m]
end
"""
_simulate(im, pat=[1,1], skip=[1,1], mds=2, cluster_num=30, inner=[0, 0])
im:训练图像
pat:模板尺度
skip:模式跳过级别
mds:降维数
cluster_num:聚类数
inner:补丁尺度
普通dispat
"""
function _simulate(im, pat=[1,1], skip=[1,1], mds=2, cluster_num=30, inner=[0, 0])
n, m = size(im);
sim_grid = fill(0.5, n+2pat[1], m+2pat[2])
return _simulate_singleresolution(im, sim_grid, pat, skip, mds, cluster_num, inner)
end
"""
_simulate_multi(im, sim_grid, pat=[1,1], skip=[1,1], mds=2, cluster_num=30, inner=[0, 0])
im:训练图像
sim_grid:初始模拟网格
pat:模板尺度
skip:模式跳过级别
mds:降维数
cluster_num:聚类数
inner:补丁尺度
需要传入模拟网格的dispat
"""
function _simulate_singleresolution(im, sim_grid, pat=[1,1], skip=[1,1], mds=2, cluster_num=30, inner=[0, 0])
n, m = size(im);
# 生成模式数据库
db = _PatternDataBase(im, pat, skip)
# 降维模式数据库并分类
M=fit(MDS,transpose(db);maxoutdim=mds, distances=false);
Y = predict(M)
kmres = kmeans(Y, cluster_num);
id = kmres.assignments;
# centers = kmres.centers;
# 计算完整的聚类中心
centers = _recal_centers(db, id, cluster_num, pat)
# # 生成模拟网格
# n, m = size(im);
# sim_grid = fill(-0.5, n+2pat[1], m+2pat[2])
# println(size(sim_grid))
# 生成模拟路径
node_num = n*m
sim_path = randperm(node_num)
# 进行模拟
realization = _sim(db, sim_grid, pat, centers, sim_path, id, inner)
return float.(realization)
end
"""
_simulate_multiresolution(im, pat=[1,1], skip=[1,1], mds=2, cluster_num=30, inner=[0, 0], resolution=1)
im:训练图像
pat:模板尺度
skip:模式跳过级别
mds:降维数
cluster_num:聚类数
inner:补丁尺度
resolution:分辨率级别
传统的多分辨率dispat
"""
function _simulate_multiresolution(im, pat=[1,1], skip=[1,1], mds=2, cluster_num=30, inner=[0, 0], resolution=1)
if resolution == 1
return _simulate(im, pat, skip, mds, cluster_num, inner)
end
out = im
n, m = size(im);
realization = 0.5
for i=resolution:-1:1
imre = imresize(out, (ceil(Int, n/i), ceil(Int, m/i)))
imre = binarize(imre, HistogramThresholding.Otsu())
nn, mm = size(imre);
sim_grid = fill(0.5, nn+2pat[1], mm+2pat[2])
if i < resolution
sim_grid[1+pat[1]:nn+pat[1],1+pat[2]:mm+pat[2]] = realization
end
realization = _simulate_singleresolution(imre, sim_grid, pat, skip, mds, cluster_num, inner)
imshow(realization)
if i > 1
realization = imresize(realization, (ceil(Int, n/(i-1)), ceil(Int, m/(i-1))))
realization = binarize(realization, HistogramThresholding.Otsu())
imshow(realization)
end
end
return realization
end
"""
_simulate_multiresolution(im, pat=[1,1], skip=[1,1], mds=2, cluster_num=30, inner=[0, 0], resolution=1)
im:训练图像
pat:模板尺度
skip:模式跳过级别
mds:降维数
cluster_num:聚类数
inner:补丁尺度
resolution:分辨率级别
尺度效应改进的多分辨率dispat
"""
function _simulate_multiresolution_scaling(im, pat=[1,1], skip=[1,1], mds=2, cluster_num=30, inner=[0, 0], resolution=1)
if resolution == 1
return _simulate(im, pat, skip, mds, cluster_num, inner)
end
out = im
n, m = size(im);
realization = -0.5
for i=resolution:-1:1 # 从粗层次开始模拟
imre = imresize(out, (ceil(Int, n/i), ceil(Int, m/i)))
imre = binarize(imre, HistogramThresholding.Otsu())
nn, mm = size(imre);
pat1 = ceil.(Int, pat./i);inner1 = ceil.(Int, inner./i)
sim_grid = fill(0.5, nn+2pat1[1], mm+2pat1[2])
if i < resolution # 不是最粗层次都需要更新模拟结果
sim_grid[1+pat1[1]:nn+pat1[1],1+pat1[2]:mm+pat1[2]] = realization
end
realization = _simulate_singleresolution(imre, sim_grid, pat1, skip, mds, cluster_num, inner1)
imshow(realization)
if i > 1 # 不是第一个层次都需要重塑模拟结果
realization = imresize(realization, (ceil(Int, n/(i-1)), ceil(Int, m/(i-1))))
realization = binarize(realization, HistogramThresholding.Otsu())
imshow(realization)
end
end
return realization
end
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Julia
1
https://gitee.com/marx_1_1307066363/julia_dispat.git
git@gitee.com:marx_1_1307066363/julia_dispat.git
marx_1_1307066363
julia_dispat
julia_dispat
master

搜索帮助