Files
HighwayEventDet/pipeline/handler/DrawDirectionProcessor.py
zimoyin cdf228fe56 feat(app): 集成loguru日志系统并优化错误处理
- 在app.py中引入loguru并配置日志轮转、异步输出等功能
- 添加全局日志初始化函数和程序启动/退出日志记录
- 将所有print语句替换为logger.info/error/debug/warning方法
- 在data_source.py中添加模型加载和视频打开的日志记录
- 在各个处理器中集成日志记录器实例并记录处理状态
- 修改处理器模块导入路径以符合相对导入规范
- 在requirements.txt中添加loguru依赖包
- 统一异常处理的日志记录方式,便于调试和监控
2026-01-10 18:03:18 +08:00

225 lines
7.7 KiB
Python

import cv2
import numpy as np
import math
from collections import deque
from ..base_processor import BaseProcessor
from ..pipeline_data import PipelineData
class DrawDirectionProcessor(BaseProcessor):
"""行驶方向绘制处理器:基于YOLO跟踪结果绘制车辆行驶方向"""
def __init__(self,
name: str = "行驶方向绘制处理器",
arrow_length: float = 20.0,
arrow_thickness: int = 2,
arrow_color: tuple = (0, 0, 255), # 红色箭头
show_trajectory: bool = True,
trajectory_length: int = 10, # 轨迹点数量
trajectory_color: tuple = (255, 255, 255),
trajectory_thickness: int = 1,
min_points_for_direction: int = 2,
min_move_distance: float = 3.0): # 最小移动距离,避免抖动
super().__init__(name)
self.arrow_length = arrow_length
self.arrow_thickness = arrow_thickness
self.arrow_color = arrow_color
self.show_trajectory = show_trajectory
self.trajectory_length = trajectory_length
self.trajectory_color = trajectory_color
self.trajectory_thickness = trajectory_thickness
self.min_points_for_direction = min_points_for_direction
self.min_move_distance = min_move_distance
# 存储每个track_id的轨迹历史 {track_id: deque([(x1, y1), (x2, y2), ...])}
self.track_histories = {}
# 只存储车辆中心点,不存储完整轨迹
self.center_histories = {}
# 清理超时的track(超过多少帧没出现就清理)
self.max_track_lifetime = 30 # 30帧没出现就清理
self.track_last_seen = {} # {track_id: last_frame_idx}
def process(self, data: PipelineData) -> PipelineData:
"""基于YOLO跟踪结果绘制车辆行驶方向"""
if data.frame is None or data.current_result is None:
return data
vis = data.frame.copy()
frame_idx = data.frame_idx
# 获取YOLO检测结果
boxes = getattr(data.current_result, "boxes", None)
if boxes is None:
# 没有检测框,清理过期track
self._cleanup_old_tracks(frame_idx)
data.frame = vis
return data
# 当前帧存在的track_id
current_track_ids = set()
# 处理每个检测框
for box in boxes:
try:
# 解析检测框
xyxy = getattr(box, "xyxy", None)
if xyxy is None:
continue
# 获取坐标
x1, y1, x2, y2 = map(float, xyxy[0].tolist())
# 获取track_id
track_id = int(getattr(box, "id", -1))
if track_id == -1:
continue # 没有跟踪ID,跳过
current_track_ids.add(track_id)
# 计算中心点
center_x = (x1 + x2) / 2.0
center_y = (y1 + y2) / 2.0
# 更新轨迹历史
if track_id not in self.center_histories:
self.center_histories[track_id] = deque(maxlen=self.trajectory_length)
# 添加当前中心点到历史
self.center_histories[track_id].append((center_x, center_y))
# 更新最后出现帧
self.track_last_seen[track_id] = frame_idx
# 绘制该车辆的轨迹和方向
self._draw_track_info(vis, track_id, (center_x, center_y))
except Exception as e:
self.logger.error(f"Error processing box: {e}")
continue
# 清理长时间未出现的track
self._cleanup_old_tracks(frame_idx)
data.frame = vis
return data
def _draw_track_info(self, vis: np.ndarray, track_id: int, current_center: tuple):
"""绘制单个车辆的轨迹和方向"""
if track_id not in self.center_histories:
return
history = self.center_histories[track_id]
# 需要至少2个点才能绘制方向
if len(history) < self.min_points_for_direction:
return
# 绘制轨迹
if self.show_trajectory:
self._draw_trajectory(vis, history)
# 绘制方向箭头
self._draw_direction_arrow(vis, history)
# 绘制当前中心点
cx, cy = current_center
cv2.circle(vis, (int(cx), int(cy)), 3, (0, 255, 255), -1) # 黄色点
# 绘制track_id文本
text = f"ID:{track_id}"
cv2.putText(vis, text,
(int(cx) + 10, int(cy) - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.5,
(255, 255, 255), 1)
def _draw_trajectory(self, vis: np.ndarray, history: deque):
"""绘制轨迹线"""
if len(history) < 2:
return
# 将deque转换为列表以便索引
points = list(history)
# 绘制轨迹线
for i in range(1, len(points)):
pt1 = points[i - 1]
pt2 = points[i]
# 确保点是有效的
if (isinstance(pt1, (tuple, list)) and len(pt1) >= 2 and
isinstance(pt2, (tuple, list)) and len(pt2) >= 2):
x1, y1 = int(pt1[0]), int(pt1[1])
x2, y2 = int(pt2[0]), int(pt2[1])
# 绘制轨迹线
cv2.line(vis, (x1, y1), (x2, y2),
self.trajectory_color, self.trajectory_thickness)
def _draw_direction_arrow(self, vis: np.ndarray, history: deque):
"""绘制方向箭头"""
if len(history) < 2:
return
# 获取最近的两个点计算方向
points = list(history)
current_point = points[-1]
# 使用最近的两个点计算方向(避免抖动)
start_idx = max(0, len(points) - 3) # 使用最近3个点
direction_points = points[start_idx:]
if len(direction_points) < 2:
return
# 计算平均方向(首尾点之间的向量)
start_point = direction_points[0]
end_point = direction_points[-1]
dx = end_point[0] - start_point[0]
dy = end_point[1] - start_point[1]
# 计算移动距离
move_distance = math.hypot(dx, dy)
# 如果移动距离太小,不绘制箭头(避免抖动)
if move_distance < self.min_move_distance:
return
# 归一化方向向量
if move_distance > 0:
dx_norm = dx / move_distance
dy_norm = dy / move_distance
else:
return
# 计算箭头起点(当前点)
sx, sy = current_point
# 计算箭头终点
ex = int(sx + dx_norm * self.arrow_length)
ey = int(sy + dy_norm * self.arrow_length)
# 绘制方向箭头
cv2.arrowedLine(vis,
(int(sx), int(sy)),
(ex, ey),
self.arrow_color,
self.arrow_thickness,
tipLength=0.3)
def _cleanup_old_tracks(self, current_frame_idx: int):
"""清理长时间未出现的track"""
tracks_to_remove = []
for track_id, last_seen in list(self.track_last_seen.items()):
# 如果超过最大生命周期没出现,则清理
if current_frame_idx - last_seen > self.max_track_lifetime:
tracks_to_remove.append(track_id)
for track_id in tracks_to_remove:
if track_id in self.center_histories:
del self.center_histories[track_id]
if track_id in self.track_last_seen:
del self.track_last_seen[track_id]