import cv2 import numpy as np import math from collections import deque from ..base_processor import BaseProcessor from ..pipeline_data import PipelineData class DrawDirectionProcessor(BaseProcessor): """行驶方向绘制处理器:基于YOLO跟踪结果绘制车辆行驶方向""" def __init__(self, name: str = "行驶方向绘制处理器", arrow_length: float = 20.0, arrow_thickness: int = 2, arrow_color: tuple = (0, 0, 255), # 红色箭头 show_trajectory: bool = True, trajectory_length: int = 10, # 轨迹点数量 trajectory_color: tuple = (255, 255, 255), trajectory_thickness: int = 1, min_points_for_direction: int = 2, min_move_distance: float = 3.0): # 最小移动距离,避免抖动 super().__init__(name) self.arrow_length = arrow_length self.arrow_thickness = arrow_thickness self.arrow_color = arrow_color self.show_trajectory = show_trajectory self.trajectory_length = trajectory_length self.trajectory_color = trajectory_color self.trajectory_thickness = trajectory_thickness self.min_points_for_direction = min_points_for_direction self.min_move_distance = min_move_distance # 存储每个track_id的轨迹历史 {track_id: deque([(x1, y1), (x2, y2), ...])} self.track_histories = {} # 只存储车辆中心点,不存储完整轨迹 self.center_histories = {} # 清理超时的track(超过多少帧没出现就清理) self.max_track_lifetime = 30 # 30帧没出现就清理 self.track_last_seen = {} # {track_id: last_frame_idx} def process(self, data: PipelineData) -> PipelineData: """基于YOLO跟踪结果绘制车辆行驶方向""" if data.frame is None or data.current_result is None: return data vis = data.frame.copy() frame_idx = data.frame_idx # 获取YOLO检测结果 boxes = getattr(data.current_result, "boxes", None) if boxes is None: # 没有检测框,清理过期track self._cleanup_old_tracks(frame_idx) data.frame = vis return data # 当前帧存在的track_id current_track_ids = set() # 处理每个检测框 for box in boxes: try: # 解析检测框 xyxy = getattr(box, "xyxy", None) if xyxy is None: continue # 获取坐标 x1, y1, x2, y2 = map(float, xyxy[0].tolist()) # 获取track_id track_id = int(getattr(box, "id", -1)) if track_id == -1: continue # 没有跟踪ID,跳过 current_track_ids.add(track_id) # 计算中心点 center_x = (x1 + x2) / 2.0 center_y = (y1 + y2) / 2.0 # 更新轨迹历史 if track_id not in self.center_histories: self.center_histories[track_id] = deque(maxlen=self.trajectory_length) # 添加当前中心点到历史 self.center_histories[track_id].append((center_x, center_y)) # 更新最后出现帧 self.track_last_seen[track_id] = frame_idx # 绘制该车辆的轨迹和方向 self._draw_track_info(vis, track_id, (center_x, center_y)) except Exception as e: self.logger.error(f"Error processing box: {e}") continue # 清理长时间未出现的track self._cleanup_old_tracks(frame_idx) data.frame = vis return data def _draw_track_info(self, vis: np.ndarray, track_id: int, current_center: tuple): """绘制单个车辆的轨迹和方向""" if track_id not in self.center_histories: return history = self.center_histories[track_id] # 需要至少2个点才能绘制方向 if len(history) < self.min_points_for_direction: return # 绘制轨迹 if self.show_trajectory: self._draw_trajectory(vis, history) # 绘制方向箭头 self._draw_direction_arrow(vis, history) # 绘制当前中心点 cx, cy = current_center cv2.circle(vis, (int(cx), int(cy)), 3, (0, 255, 255), -1) # 黄色点 # 绘制track_id文本 text = f"ID:{track_id}" cv2.putText(vis, text, (int(cx) + 10, int(cy) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1) def _draw_trajectory(self, vis: np.ndarray, history: deque): """绘制轨迹线""" if len(history) < 2: return # 将deque转换为列表以便索引 points = list(history) # 绘制轨迹线 for i in range(1, len(points)): pt1 = points[i - 1] pt2 = points[i] # 确保点是有效的 if (isinstance(pt1, (tuple, list)) and len(pt1) >= 2 and isinstance(pt2, (tuple, list)) and len(pt2) >= 2): x1, y1 = int(pt1[0]), int(pt1[1]) x2, y2 = int(pt2[0]), int(pt2[1]) # 绘制轨迹线 cv2.line(vis, (x1, y1), (x2, y2), self.trajectory_color, self.trajectory_thickness) def _draw_direction_arrow(self, vis: np.ndarray, history: deque): """绘制方向箭头""" if len(history) < 2: return # 获取最近的两个点计算方向 points = list(history) current_point = points[-1] # 使用最近的两个点计算方向(避免抖动) start_idx = max(0, len(points) - 3) # 使用最近3个点 direction_points = points[start_idx:] if len(direction_points) < 2: return # 计算平均方向(首尾点之间的向量) start_point = direction_points[0] end_point = direction_points[-1] dx = end_point[0] - start_point[0] dy = end_point[1] - start_point[1] # 计算移动距离 move_distance = math.hypot(dx, dy) # 如果移动距离太小,不绘制箭头(避免抖动) if move_distance < self.min_move_distance: return # 归一化方向向量 if move_distance > 0: dx_norm = dx / move_distance dy_norm = dy / move_distance else: return # 计算箭头起点(当前点) sx, sy = current_point # 计算箭头终点 ex = int(sx + dx_norm * self.arrow_length) ey = int(sy + dy_norm * self.arrow_length) # 绘制方向箭头 cv2.arrowedLine(vis, (int(sx), int(sy)), (ex, ey), self.arrow_color, self.arrow_thickness, tipLength=0.3) def _cleanup_old_tracks(self, current_frame_idx: int): """清理长时间未出现的track""" tracks_to_remove = [] for track_id, last_seen in list(self.track_last_seen.items()): # 如果超过最大生命周期没出现,则清理 if current_frame_idx - last_seen > self.max_track_lifetime: tracks_to_remove.append(track_id) for track_id in tracks_to_remove: if track_id in self.center_histories: del self.center_histories[track_id] if track_id in self.track_last_seen: del self.track_last_seen[track_id]