1 Star 0 Fork 53

花大虫张/Book5_Essentials-of-Probability-and-Statistics

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
Bk5_Ch05_04.py 3.22 KB
一键复制 编辑 原始数据 按行查看 历史
Visualize-ML 提交于 2022-11-28 23:58 +08:00 . Add files via upload
###############
# Authored by Weisheng Jiang
# Book 5 | From Basic Arithmetic to Machine Learning
# Published and copyrighted by Tsinghua University Press
# Beijing, China, 2022
###############
from scipy.stats import multinomial
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
# Number of trials; the three outcome counts satisfy x1 + x2 + x3 = num.
num = 8

# Each of the first two counts ranges over the integers 0..num.
x1_array = np.arange(0, num + 1)
x2_array = x1_array.copy()

# Square grid holding every (x1, x2) combination.
xx1, xx2 = np.meshgrid(x1_array, x2_array)

# The third count is whatever remains of num; grid cells where x1 + x2
# already exceeds num have no valid third count and are marked NaN.
xx3 = num - (xx1 + xx2)
xx3 = np.where(xx3 < 0.0, np.nan, xx3)
def heatmap_sum(data,i_array,j_array,title,vmin,vmax,cmap,annot = False):
    """Draw ``data`` as a square-celled seaborn heatmap on a new 10x10 figure.

    Parameters
    ----------
    data : 2-D array-like passed straight to ``sns.heatmap``.
    i_array : tick labels used along the y axis.
    j_array : tick labels used along the x axis.
    title : text placed above the plot.
    vmin, vmax : lower and upper limits of the colour scale.
    cmap : colormap handed to seaborn.
    annot : when true, each cell is annotated with its value ('.3f' format).

    Returns
    -------
    None. The figure is created but not shown; the caller decides when to
    call ``plt.show()``.
    """
    _, axes = plt.subplots(figsize=(10, 10))

    # Collect the styling options once so the seaborn call stays readable.
    heatmap_options = {
        'cmap': cmap,
        'cbar_kws': {"orientation": "horizontal"},
        'yticklabels': i_array,
        'xticklabels': j_array,
        'annot': annot,
        'fmt': '.3f',
        'linewidths': 0.25,
        'linecolor': 'grey',
        'vmin': vmin,
        'vmax': vmax,
    }
    axes = sns.heatmap(data, ax=axes, **heatmap_options)

    axes.set_xlabel('$x_1$')
    axes.set_ylabel('$x_2$')

    # Flip the y axis after plotting and keep the cells square.
    axes.invert_yaxis()
    axes.set_aspect("equal")

    plt.title(title)
    plt.yticks(rotation=0)
#%% calculate multinomial probability
# Outcome probabilities for the three categories (must sum to 1).
# Other settings worth trying: [0.6, 0.1, 0.3] and [0.3, 0.4, 0.3].
p_array = [0.1, 0.6, 0.3]

# Only grid cells with x1 + x2 <= num correspond to a valid outcome.
# Evaluate the PMF on those cells alone, with proper integer counts,
# instead of handing NaN placeholders to multinomial.pmf (which expects
# integer counts and would have to cast NaN to int).
feasible = (xx1 + xx2) <= num
counts = np.stack((xx1[feasible],
                   xx2[feasible],
                   num - xx1[feasible] - xx2[feasible]), axis=-1)

# Infeasible cells stay NaN so that plots leave them blank.
PMF_ff = np.full(xx1.shape, np.nan)
PMF_ff[feasible] = multinomial.pmf(x=counts, n=num, p=p_array)

# Any zero probability is also mapped to NaN, as before.
PMF_ff = np.where(PMF_ff > 0.0, PMF_ff, np.nan)
#%% save to excel file
# Write the PMF grid to a spreadsheet with the row order reversed
# (last grid row first); the index column is omitted.
filepath = 'PMF_ff.xlsx'
df = pd.DataFrame(PMF_ff[::-1])
df.to_excel(filepath, index=False)
#%% 3D/2D scatter plot
fig = plt.figure(figsize=(10, 10))
ax = plt.axes(projection="3d")

# One marker per (x1, x2, x3) grid point, coloured by its probability.
ax.scatter3D(xx1.ravel(), xx2.ravel(), xx3.ravel(),
             s=400,
             marker='.',
             c=PMF_ff.ravel(),
             cmap='RdYlBu_r')

ax.set_proj_type('ortho')

ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
ax.set_zlabel('$x_3$')

# White background panes. Use xaxis/yaxis/zaxis directly: the w_xaxis,
# w_yaxis and w_zaxis aliases were deprecated and are gone in
# Matplotlib 3.8+, where they raise AttributeError.
ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 1.0))
ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 1.0))
ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 1.0))

# Only label the two ends of each axis.
ax.set_xticks([0, num])
ax.set_yticks([0, num])
ax.set_zticks([0, num])

ax.set_xlim(0, num)
ax.set_ylim(0, num)
ax.set_zlim3d(0, num)

# View from directly above, which gives the 2D picture of the grid.
# For an oblique 3D view use e.g. azim=-30, elev=20 instead.
ax.view_init(azim=-90, elev=90)

ax.set_box_aspect(aspect=(1, 1, 1))
ax.grid()
plt.show()
#%% heatmap
# PMF_ff is computed with scipy.stats.multinomial (three categories),
# so the title names the multinomial distribution, not the binomial.
# Colour scale runs from 0 to 0.12 and every cell is annotated.
title = 'PMF of multinomial distribution'
heatmap_sum(PMF_ff,x1_array,x2_array,title,0,0.12,'plasma_r',True)
#%% 3D stem chart
fig = plt.figure(figsize = (10,10))
ax = fig.add_subplot(111, projection='3d')

# One vertical stem per (x1, x2) cell, with height equal to the PMF value;
# the baseline is hidden.
ax.stem(xx1.ravel(), xx2.ravel(), PMF_ff.ravel(), basefmt=" ")

ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
ax.set_zlabel('PMF')

# Horizontal limits follow num (rather than a hard-coded 8) so the chart
# stays correct if the number of trials is changed.
ax.set_xlim((0, num))
ax.set_ylim((0, num))
ax.set_zlim((0, 0.12))

ax.view_init(azim=-100, elev=20)
ax.set_proj_type('ortho')
plt.show()
# test only
# print(multinomial.pmf(x=(5,2,1), n=num, p=p_array))
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/hua-da-chong-zhang/Book5_Essentials-of-Probability-and-Statistics.git
git@gitee.com:hua-da-chong-zhang/Book5_Essentials-of-Probability-and-Statistics.git
hua-da-chong-zhang
Book5_Essentials-of-Probability-and-Statistics
Book5_Essentials-of-Probability-and-Statistics
main

搜索帮助