1 Star 0 Fork 0

刘亚翔/Lazypredictor

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
weapp.py 3.48 KB
一键复制 编辑 原始数据 按行查看 历史
刘亚翔 提交于 2024-11-11 14:49 . 1
import gradio as gr
import numpy as np
import joblib
from PIL import Image
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.models import Model
def img_to_array(image):
return np.array(image)
def preprocess_input(image):
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
image = image.resize((224, 224)) # 调整图像大小到224x224
image_array = np.array(image, dtype='float32') / 255.0 # 归一化到0-1之间
return np.expand_dims(image_array, axis=0) # 扩展维度以匹配Keras的输入要求
def extract_features(image):
image_array = preprocess_input(image)
# 加载VGG16模型,不包括顶部的全连接层
base_model = VGG16(weights='imagenet', include_top=False)
# 添加全局平均池化层
x = base_model.output
x = GlobalAveragePooling2D()(x)
model = Model(inputs=base_model.input, outputs=x)
# 提取特征
features = model.predict(image_array)
return features
class CatDogClassifier:
def __init__(self, model_path, feature_method='vgg'):
self.model = joblib.load(model_path)
print(f"加载模型: {model_path},加载成功")
self.feature_method = feature_method
def preprocess_image(self, image):
if self.feature_method == 'flat':
img_resized = image.resize((64, 64))
img_array = np.array(img_resized)
features = img_array.flatten().reshape(1, -1)
else: # 'vgg' 方法
features = extract_features(image)
return features
def predict(self, image):
features = self.preprocess_image(image)
print(f"预测输入特征形状: {features.shape}")
prediction = self.model.predict(features)
if len(prediction.shape) == 1:
probability = prediction
else:
probability = prediction[0]
# 调整概率判断逻辑,确保模型输出的概率值被正确处理
if probability > 0.5: # 假设概率大于0.5为猫
return {"猫": float(probability)}
else:
return {"狗": float(1 - probability)} # 确保输出概率值正确
# Create classifier instance
model_file = "C:/Users/刘亚翔/Desktop/faiss_dog_cat_question-main/best_model_SVC.pkl"
classifier = CatDogClassifier(model_path=model_file, feature_method="vgg")
# Define prediction function
def predict_image(image):
if image is None:
return None
try:
return classifier.predict(image)
except Exception as e:
print(f"预测出错: {str(e)}")
return None
# Create Gradio interface
iface = gr.Interface(
fn=predict_image,
inputs=gr.Image(type="pil", label="上传图片"),
outputs=gr.Label(label="预测结果"),
title="🐱猫狗图片分类器🐶",
description="""\
## 使用说明
1. 点击上传或拖拽一张包含猫或狗的图片
2. 等待AI预测结果
3. 查看预测结果和置信度
*支持的图片格式:JPG、PNG、JPEG*
""",
examples=None,
cache_examples=True
)
# Start the application
if __name__ == "__main__":
print(f"使用模型: {model_file}")
iface.launch(
server_name="127.0.0.1", # 使用本地地址
server_port=7860, # 指定端口号
debug=True, # 开启调试模式
share=True # 创建公共链接
)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/liu-yaxiang/lazypredictor.git
git@gitee.com:liu-yaxiang/lazypredictor.git
liu-yaxiang
lazypredictor
Lazypredictor
master

搜索帮助