代码拉取完成,页面将自动刷新
import gradio as gr
import numpy as np
import joblib
from PIL import Image
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.models import Model
def img_to_array(image):
return np.array(image)
def preprocess_input(image):
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
image = image.resize((224, 224)) # 调整图像大小到224x224
image_array = np.array(image, dtype='float32') / 255.0 # 归一化到0-1之间
return np.expand_dims(image_array, axis=0) # 扩展维度以匹配Keras的输入要求
def extract_features(image):
image_array = preprocess_input(image)
# 加载VGG16模型,不包括顶部的全连接层
base_model = VGG16(weights='imagenet', include_top=False)
# 添加全局平均池化层
x = base_model.output
x = GlobalAveragePooling2D()(x)
model = Model(inputs=base_model.input, outputs=x)
# 提取特征
features = model.predict(image_array)
return features
class CatDogClassifier:
def __init__(self, model_path, feature_method='vgg'):
self.model = joblib.load(model_path)
print(f"加载模型: {model_path},加载成功")
self.feature_method = feature_method
def preprocess_image(self, image):
if self.feature_method == 'flat':
img_resized = image.resize((64, 64))
img_array = np.array(img_resized)
features = img_array.flatten().reshape(1, -1)
else: # 'vgg' 方法
features = extract_features(image)
return features
def predict(self, image):
features = self.preprocess_image(image)
print(f"预测输入特征形状: {features.shape}")
prediction = self.model.predict(features)
if len(prediction.shape) == 1:
probability = prediction
else:
probability = prediction[0]
# 调整概率判断逻辑,确保模型输出的概率值被正确处理
if probability > 0.5: # 假设概率大于0.5为猫
return {"猫": float(probability)}
else:
return {"狗": float(1 - probability)} # 确保输出概率值正确
# Create classifier instance
model_file = "C:/Users/刘亚翔/Desktop/faiss_dog_cat_question-main/best_model_SVC.pkl"
classifier = CatDogClassifier(model_path=model_file, feature_method="vgg")
# Define prediction function
def predict_image(image):
if image is None:
return None
try:
return classifier.predict(image)
except Exception as e:
print(f"预测出错: {str(e)}")
return None
# Create Gradio interface
iface = gr.Interface(
fn=predict_image,
inputs=gr.Image(type="pil", label="上传图片"),
outputs=gr.Label(label="预测结果"),
title="🐱猫狗图片分类器🐶",
description="""\
## 使用说明
1. 点击上传或拖拽一张包含猫或狗的图片
2. 等待AI预测结果
3. 查看预测结果和置信度
*支持的图片格式:JPG、PNG、JPEG*
""",
examples=None,
cache_examples=True
)
# Start the application
if __name__ == "__main__":
print(f"使用模型: {model_file}")
iface.launch(
server_name="127.0.0.1", # 使用本地地址
server_port=7860, # 指定端口号
debug=True, # 开启调试模式
share=True # 创建公共链接
)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。