代码拉取完成,页面将自动刷新
import asyncio
import concurrent
import os
import subprocess
import cv2
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from starlette.responses import FileResponse
from video_inspection.SnowflakeGenerator import SnowflakeGenerator
from models.experimental import attempt_load
from utils.datasets import letterbox
from utils.general import non_max_suppression, scale_coords
from utils.plots import plot_one_box
app = FastAPI()
# 设置本地 HLS 输出路径
hls_output_dir = r"/video_inspection/uploads"
if not os.path.exists(hls_output_dir):
os.makedirs(hls_output_dir)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 选择设备
# 加载多个 YOLOv7 模型
weights_paths = ['F:/RGZN/speed/weights/yolov7.pt'] # 不同模型权重文件路径
models = []
for weights in weights_paths:
model = attempt_load(weights, map_location=device) # 加载模型
model.eval() # 设置模型为评估模式
models.append(model)
names_list = [model.names if hasattr(model, 'names') else ['class_{}'.format(i) for i in range(1000)] for model in
models]
# 初始化雪花算法生成器
generator = SnowflakeGenerator(datacenter_id=1, worker_id=1)
exit_flags = {} # 存储各线程的退出标志
# 初始化线程池
executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
class FileUrlInput(BaseModel):
file_url: str
names_dict: str
# 定义请求数据结构
class M3U8Input(BaseModel):
m3u8_url: str # 输入视频流的URL
names_dict: str #
region_points: list
class Exlfag(BaseModel):
thread_id: str
# '-c:v', 'h264_nvenc', # 使用 NVIDIA NVENC 编码器
async def start_streaming(input_stream_url: str, unique_id: str, names_dict: str,region_points: list):
# 设置 OpenCV 的 FFMPEG 读取尝试次数
os.environ['OPENCV_FFMPEG_READ_ATTEMPTS'] = '50000'
m3u8_output_path = os.path.join(hls_output_dir, f"{unique_id}.m3u8")
# 打开摄像头
cap = cv2.VideoCapture(input_stream_url)
if not cap.isOpened():
print("无法打开摄像头")
return
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(cap.get(cv2.CAP_PROP_FPS))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) # 获取总帧数
# 判断是否为 MP4 类型的视频(假定 MP4 视频帧数较多)
is_mp4_like = total_frames > fps # 如果帧数大于帧率,认为是长视频
frame_count = 0
ffmpeg_command = [
r'D:\linux软件\ffmpeg\bin\ffmpeg.exe',
'-y',
'-f', 'rawvideo',
'-pix_fmt', 'bgr24',
'-s', f"{frame_width}x{frame_height}",
'-r', str(15),
'-i', '-',
'-c:v', 'h264_qsv',
'-bufsize', '2000k',
'-preset', 'veryfast', # 优化编码速度
'-g', '30', # GOP 大小
'-sc_threshold', '0', # 场景切换阈值
'-f', 'hls',
'-hls_time', '2',
'-hls_list_size', '5',
'-hls_flags', 'delete_segments+append_list',
'-hls_base_url', f"http://127.0.0.1:8005/stream/{unique_id}/",
m3u8_output_path
]
# 启动 FFmpeg 推流进程
process = subprocess.Popen(ffmpeg_command, stdin=subprocess.PIPE)
exit_flags[unique_id] = process
try:
# 读取摄像头视频流并推送到 HLS
while True:
ret, frame = cap.read()
if not ret:
print("无法读取视频帧")
break
# 如果是长视频(MP4),每一帧都存入队列
if is_mp4_like:
# 处理每一帧
frame = await process_frame(unique_id, frame, names_dict)
else:
# 如果不是长视频,每 0.5 秒采样一帧
if frame_count % int(fps * 0.5) == 0:
# 处理每一帧
frame = await process_frame(unique_id, frame, names_dict)
# 写入帧数据到 FFmpeg 输入流
process.stdin.write(frame.tobytes())
process.stdin.flush()
frame_count += 1
# 计算实际处理时间并调整等待时间
asyncio.sleep(0.01) # 稍微增加等待时间
except Exception as e:
print(f"推流中出现错误: {e}")
finally:
# 关闭摄像头和 FFmpeg
if cap is not None:
cap.release()
if process is not None:
process.stdin.close()
process.wait()
def push_frame_to_ffmpeg(frame, process):
# 写入帧数据到 FFmpeg 输入流
process.stdin.write(frame.tobytes())
process.stdin.flush()
"""
模型对帧进行检测
"""
async def process_frame(unique_id, frame, names_dict, region_points=None):
"""
处理帧并在指定区域内进行目标检测
:param unique_id: 唯一标识符
:param frame: 输入图像帧
:param names_dict: 目标名称字典
:param region_points: 检测区域的点坐标列表 [(x1,y1), (x2,y2), ...]
详细解释这个区域坐标参数的含义:
"region_points": [
[100, 100], // 左上角点 (x1, y1)
[500, 100], // 右上角点 (x2, y1)
[600, 400], // 右下角点 (x3, y2)
[50, 400] // 左下角点 (x4, y2)
]
图形示意:
(100,100) ●——————————————● (500,100)
| \
| \
| \
| \
| \
| \
(50,400) ●——————————————--● (600,400)
参数说明:
1. 每个点都是 `[x, y]` 格式:
- x: 表示距离图像左边界的像素距离
- y: 表示距离图像上边界的像素距离
2. 坐标系统:
- 原点 (0,0) 在图像的左上角
- x 轴向右增加
- y 轴向下增加
- 单位是像素
3. 点的顺序:
- 建议按顺时针或逆时针顺序排列点
- 点的连接将形成一个封闭的多边形区域
- 这个区域内的目标才会被检测
4. 示例中的区域:
- 形成了一个不规则四边形
- 左边比右边低,形成了一个倾斜的形状
- 适合监控斜坡或特定角度的场景
注意事项:
- 坐标不能超出图像边界
- 确保坐标形成一个有效的封闭区域
- 可以根据实际场景需要调整点的位置
- 点的数量可以更多,形成更复杂的多边形
使用建议:
1. 根据实际监控区域设置合适的点
2. 可以通过可视化工具帮助确定准确坐标
3. 避免设置过于复杂的形状
4. 确保区域覆盖所有需要监控的范围
"""
# 创建掩码图层和叠加层
mask = np.zeros(frame.shape[:2], dtype=np.uint8)
overlay = frame.copy()
# 如果提供了区域点,则创建检测区域
if region_points:
points = np.array(region_points, np.int32)
points = points.reshape((-1, 1, 2))
# 绘制区域边界
cv2.polylines(overlay, [points], True, (0, 255, 0), 2)
# 填充掩码
cv2.fillPoly(mask, [points], 255)
else:
# 如果没有提供区域点,则整个画面都是检测区域
mask.fill(255)
# 叠加透明边框
alpha = 0.3
frame = cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0)
# 目标检测处理
img = letterbox(frame, 640, stride=32)[0]
img = img[:, :, ::-1].transpose(2, 0, 1)
img = np.ascontiguousarray(img)
img = torch.from_numpy(img).to(device)
img = img.float() / 255.0
if img.ndimension() == 3:
img = img.unsqueeze(0)
# 对每个模型进行推理
for model, names in zip(models, names_list):
with torch.no_grad():
pred = model(img)[0]
pred = non_max_suppression(pred, 0.25, 0.40, agnostic=False)
for det in pred:
if len(det):
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], frame.shape).round()
for *xyxy, conf, cls in det:
# 检查检测框中心点是否在掩码区域内
center_x = int((xyxy[0] + xyxy[2]) / 2)
center_y = int((xyxy[1] + xyxy[3]) / 2)
if mask[center_y, center_x] == 255: # 在检测区域内
class_id = int(cls)
label_name = names[class_id]
chinese_label_name = Label_Names.get(label_name, label_name)
color = Label_Color.get(label_name, (0, 255, 0))
label = f'{chinese_label_name} {conf:.2f}'
match_found = False
for target in names_dict.split(','):
if label_name == target:
match_found = True
break
if match_found:
frame = put_chinese_text(color, frame, label, (int(xyxy[0]), int(xyxy[1] - 30)), 21)
plot_one_box(xyxy, frame, label="", color=color, line_thickness=2)
return frame
"""
进行PIL图片文字写入
"""
def put_chinese_text(color, image, text, position, font_size):
# 创建 PIL 图像
pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(pil_img)
# 选择一个支持中文的字体
font = ImageFont.truetype("/Alimama_ShuHeiTi_Bold.ttf", font_size)
draw.text(position, text, font=font, fill=color)
# 转换回 OpenCV 格式
return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
@app.post("/generate-link")
async def generate_link(m3u8_input: M3U8Input):
if not m3u8_input:
raise HTTPException(status_code=400, detail="Missing input_stream_url parameter")
input_stream_url = m3u8_input.m3u8_url
names_dict = m3u8_input.names_dict
region_points = m3u8_input.region_points
unique_id = str(generator.next_id())
# 在主线程中使用事件循环来运行协程
loop = asyncio.get_event_loop()
loop.run_in_executor(executor, lambda: asyncio.run(start_streaming(input_stream_url, unique_id, names_dict,region_points)))
# 立即返回生成的 m3u8 URL
m3u8_url = f"http://127.0.0.1:8005/stream/{unique_id}.m3u8"
return {"m3u8_url": m3u8_url}
@app.get("/stream/{unique_id}.m3u8")
async def get_m3u8(unique_id: str):
m3u8_file_path = os.path.join(hls_output_dir, f"{unique_id}.m3u8")
if os.path.exists(m3u8_file_path):
return FileResponse(m3u8_file_path)
raise HTTPException(status_code=404, detail="m3u8 file not found")
@app.get("/stream/{unique_id}/{ts_filename}")
async def get_ts(ts_filename: str):
ts_file_path = os.path.join(hls_output_dir, f"{ts_filename}") # 假设 TS 文件直接存放在 hls_output_dir 中
if os.path.exists(ts_file_path):
return FileResponse(ts_file_path)
raise HTTPException(status_code=404, detail="TS file not found")
Label_Names = {
'person': '人',
'smoke': '吸烟',
'D00': '纵向裂缝',
'D10': '横向裂缝',
'D20': '鳄鱼裂缝',
'D40': '坑洼',
'bicycle': '自行车',
'car': '汽车',
'motorcycle': '摩托车',
'airplane': '飞机',
'bus': '公交车',
'train': '火车',
'truck': '卡车',
'boat': '船',
'traffic light': '红绿灯',
'fire hydrant': '消防栓',
'stop sign': '停车标志',
'parking meter': '停车计时器',
'bench': '长椅',
'bird': '鸟',
'cat': '猫',
'dog': '狗',
'horse': '马',
'sheep': '羊',
'cow': '奶牛',
'elephant': '大象',
'bear': '熊',
'zebra': '斑马',
'giraffe': '长颈鹿',
'backpack': '背包',
'umbrella': '雨伞',
'handbag': '手提包',
'tie': '领带',
'suitcase': '行李箱',
'frisbee': '飞盘',
'skis': '滑雪板',
'snowboard': '滑雪板',
'sports ball': '运动球',
'kite': '风筝',
'baseball bat': '棒球棒',
'baseball glove': '棒球手套',
'skateboard': '滑板',
'surfboard': '冲浪板',
'tennis racket': '网球拍',
'bottle': '瓶子',
'wine glass': '酒杯',
'cup': '杯子',
'fork': '叉子',
'knife': '刀',
'spoon': '勺子',
'bowl': '碗',
'banana': '香蕉',
'apple': '苹果',
'sandwich': '三明治',
'orange': '橙子',
'broccoli': '西兰花',
'carrot': '胡萝卜',
'hot dog': '热狗',
'pizza': '披萨',
'donut': '甜甜圈',
'cake': '蛋糕',
'chair': '椅子',
'couch': '沙发',
'potted plant': '盆栽植物',
'bed': '床',
'dining table': '餐桌',
'toilet': '厕所',
'tv': '电视',
'laptop': '笔记本电脑',
'mouse': '鼠标',
'remote': '遥控器',
'keyboard': '键盘',
'cell phone': '手机',
'microwave': '微波炉',
'oven': '烤箱',
'toaster': '烤面包机',
'sink': '水槽',
'refrigerator': '冰箱',
'book': '书',
'clock': '时钟',
'vase': '花瓶',
'scissors': '剪刀',
'teddy bear': '泰迪熊',
'hair drier': '吹风机',
'toothbrush': '牙刷'
}
Label_Color = {
'car': (30, 144, 255), # 红色
'person': (0, 255, 0), # 绿色
'smoke': (220, 20, 60),
'D00': (123, 104, 238),
'D10': (65, 105, 225),
'D20': (0, 191, 255),
'D40': (176, 224, 230),
'bicycle': (0, 206, 209),
'motorcycle': (255, 192, 203),
'airplane': (0, 128, 128),
'bus': (255, 165, 0),
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="127.0.0.1", port=8005)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。