From 23555e0cc933cd62a0d16c99c80d6241b2cb788d Mon Sep 17 00:00:00 2001 From: zimoyin <2556608754@qq.com> Date: Sat, 10 Jan 2026 09:41:18 +0800 Subject: [PATCH] =?UTF-8?q?feat(pipeline):=20=E6=B7=BB=E5=8A=A0=E9=80=86?= =?UTF-8?q?=E8=A1=8C=E5=A4=84=E7=90=86=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 DrawDirectionProcessor 用于绘制车辆行驶方向和轨迹 - 新增 DrawGraffitiProcessor 用于绘制热力图和网格方向箭头 - 新增 DrawObjectBoxProcessor 用于绘制检测框、ID、标签和置信度 - 新增 GraffitiVisualizer 用于涂鸦可视化处理 - 新增 GraffitiProcessor 用于计算车辆轨迹并更新车道涂鸦 --- pipeline/handler/DrawDirectionProcessor.py | 225 ++++++++++++++++++ pipeline/handler/DrawGraffitiProcessor.py | 198 ++++++++++++++++ pipeline/handler/DrawObjectBoxProcessor.py | 184 ++++++++++++++ pipeline/handler/GraffitiProcessor.py | 263 +++++++++++++++++++++ pipeline/handler/ResultLogger.py | 10 + pipeline/handler/RetrogradeProcessor.py | 244 +++++++++++++++++++ pipeline/handler/__init__.py | 10 + pipeline/handler/graffiti_visualizer.py | 198 ++++++++++++++++ test_advanced_processors.py | 23 ++ test_compatibility.py | 31 +++ test_modifications.py | 29 +++ test_updated_system.py | 39 +++ 12 files changed, 1454 insertions(+) create mode 100644 pipeline/handler/DrawDirectionProcessor.py create mode 100644 pipeline/handler/DrawGraffitiProcessor.py create mode 100644 pipeline/handler/DrawObjectBoxProcessor.py create mode 100644 pipeline/handler/GraffitiProcessor.py create mode 100644 pipeline/handler/ResultLogger.py create mode 100644 pipeline/handler/RetrogradeProcessor.py create mode 100644 pipeline/handler/__init__.py create mode 100644 pipeline/handler/graffiti_visualizer.py create mode 100644 test_advanced_processors.py create mode 100644 test_compatibility.py create mode 100644 test_modifications.py create mode 100644 test_updated_system.py diff --git a/pipeline/handler/DrawDirectionProcessor.py b/pipeline/handler/DrawDirectionProcessor.py new file mode 100644 index 0000000..898067f --- /dev/null +++ b/pipeline/handler/DrawDirectionProcessor.py @@ -0,0 +1,225 @@ +import cv2 +import numpy as np +import math +from collections import deque +from yolo_gs.pipeline.base_processor import BaseProcessor +from yolo_gs.pipeline.pipeline_data import PipelineData + + +class DrawDirectionProcessor(BaseProcessor): + """行驶方向绘制处理器:基于YOLO跟踪结果绘制车辆行驶方向""" + + def __init__(self, + name: str = "行驶方向绘制处理器", + arrow_length: float = 20.0, + arrow_thickness: int = 2, + arrow_color: tuple = (0, 0, 255), # 红色箭头 + show_trajectory: bool = True, + trajectory_length: int = 10, # 轨迹点数量 + trajectory_color: tuple = (255, 255, 255), + trajectory_thickness: int = 1, + min_points_for_direction: int = 2, + min_move_distance: float = 3.0): # 最小移动距离,避免抖动 + super().__init__(name) + self.arrow_length = arrow_length + self.arrow_thickness = arrow_thickness + self.arrow_color = arrow_color + self.show_trajectory = show_trajectory + self.trajectory_length = trajectory_length + self.trajectory_color = trajectory_color + self.trajectory_thickness = trajectory_thickness + self.min_points_for_direction = min_points_for_direction + self.min_move_distance = min_move_distance + + # 存储每个track_id的轨迹历史 {track_id: deque([(x1, y1), (x2, y2), ...])} + self.track_histories = {} + + # 只存储车辆中心点,不存储完整轨迹 + self.center_histories = {} + + # 清理超时的track(超过多少帧没出现就清理) + self.max_track_lifetime = 30 # 30帧没出现就清理 + self.track_last_seen = {} # {track_id: last_frame_idx} + + def process(self, data: PipelineData) -> PipelineData: + """基于YOLO跟踪结果绘制车辆行驶方向""" + if data.frame is None or data.current_result is None: + return data + + vis = data.frame.copy() + frame_idx = data.frame_idx + + # 获取YOLO检测结果 + boxes = getattr(data.current_result, "boxes", None) + if boxes is None: + # 没有检测框,清理过期track + self._cleanup_old_tracks(frame_idx) + data.frame = vis + return data + + # 当前帧存在的track_id + current_track_ids = set() + + # 处理每个检测框 + for box in boxes: + try: + # 解析检测框 + xyxy = getattr(box, "xyxy", None) + if xyxy is None: + continue + + # 获取坐标 + x1, y1, x2, y2 = map(float, xyxy[0].tolist()) + + # 获取track_id + track_id = int(getattr(box, "id", -1)) + if track_id == -1: + continue # 没有跟踪ID,跳过 + + current_track_ids.add(track_id) + + # 计算中心点 + center_x = (x1 + x2) / 2.0 + center_y = (y1 + y2) / 2.0 + + # 更新轨迹历史 + if track_id not in self.center_histories: + self.center_histories[track_id] = deque(maxlen=self.trajectory_length) + + # 添加当前中心点到历史 + self.center_histories[track_id].append((center_x, center_y)) + + # 更新最后出现帧 + self.track_last_seen[track_id] = frame_idx + + # 绘制该车辆的轨迹和方向 + self._draw_track_info(vis, track_id, (center_x, center_y)) + + except Exception as e: + print(f"Error processing box: {e}") + continue + + # 清理长时间未出现的track + self._cleanup_old_tracks(frame_idx) + + data.frame = vis + return data + + def _draw_track_info(self, vis: np.ndarray, track_id: int, current_center: tuple): + """绘制单个车辆的轨迹和方向""" + if track_id not in self.center_histories: + return + + history = self.center_histories[track_id] + + # 需要至少2个点才能绘制方向 + if len(history) < self.min_points_for_direction: + return + + # 绘制轨迹 + if self.show_trajectory: + self._draw_trajectory(vis, history) + + # 绘制方向箭头 + self._draw_direction_arrow(vis, history) + + # 绘制当前中心点 + cx, cy = current_center + cv2.circle(vis, (int(cx), int(cy)), 3, (0, 255, 255), -1) # 黄色点 + + # 绘制track_id文本 + text = f"ID:{track_id}" + cv2.putText(vis, text, + (int(cx) + 10, int(cy) - 10), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, + (255, 255, 255), 1) + + def _draw_trajectory(self, vis: np.ndarray, history: deque): + """绘制轨迹线""" + if len(history) < 2: + return + + # 将deque转换为列表以便索引 + points = list(history) + + # 绘制轨迹线 + for i in range(1, len(points)): + pt1 = points[i - 1] + pt2 = points[i] + + # 确保点是有效的 + if (isinstance(pt1, (tuple, list)) and len(pt1) >= 2 and + isinstance(pt2, (tuple, list)) and len(pt2) >= 2): + x1, y1 = int(pt1[0]), int(pt1[1]) + x2, y2 = int(pt2[0]), int(pt2[1]) + + # 绘制轨迹线 + cv2.line(vis, (x1, y1), (x2, y2), + self.trajectory_color, self.trajectory_thickness) + + def _draw_direction_arrow(self, vis: np.ndarray, history: deque): + """绘制方向箭头""" + if len(history) < 2: + return + + # 获取最近的两个点计算方向 + points = list(history) + current_point = points[-1] + + # 使用最近的两个点计算方向(避免抖动) + start_idx = max(0, len(points) - 3) # 使用最近3个点 + direction_points = points[start_idx:] + + if len(direction_points) < 2: + return + + # 计算平均方向(首尾点之间的向量) + start_point = direction_points[0] + end_point = direction_points[-1] + + dx = end_point[0] - start_point[0] + dy = end_point[1] - start_point[1] + + # 计算移动距离 + move_distance = math.hypot(dx, dy) + + # 如果移动距离太小,不绘制箭头(避免抖动) + if move_distance < self.min_move_distance: + return + + # 归一化方向向量 + if move_distance > 0: + dx_norm = dx / move_distance + dy_norm = dy / move_distance + else: + return + + # 计算箭头起点(当前点) + sx, sy = current_point + + # 计算箭头终点 + ex = int(sx + dx_norm * self.arrow_length) + ey = int(sy + dy_norm * self.arrow_length) + + # 绘制方向箭头 + cv2.arrowedLine(vis, + (int(sx), int(sy)), + (ex, ey), + self.arrow_color, + self.arrow_thickness, + tipLength=0.3) + + def _cleanup_old_tracks(self, current_frame_idx: int): + """清理长时间未出现的track""" + tracks_to_remove = [] + + for track_id, last_seen in list(self.track_last_seen.items()): + # 如果超过最大生命周期没出现,则清理 + if current_frame_idx - last_seen > self.max_track_lifetime: + tracks_to_remove.append(track_id) + + for track_id in tracks_to_remove: + if track_id in self.center_histories: + del self.center_histories[track_id] + if track_id in self.track_last_seen: + del self.track_last_seen[track_id] \ No newline at end of file diff --git a/pipeline/handler/DrawGraffitiProcessor.py b/pipeline/handler/DrawGraffitiProcessor.py new file mode 100644 index 0000000..c8845b0 --- /dev/null +++ b/pipeline/handler/DrawGraffitiProcessor.py @@ -0,0 +1,198 @@ +import cv2 +import numpy as np +import math +from yolo_gs.pipeline.base_processor import BaseProcessor +from yolo_gs.pipeline.pipeline_data import PipelineData + + +class DrawGraffitiProcessor(BaseProcessor): + """涂鸦绘制处理器:绘制热力图和网格方向箭头""" + + def __init__(self, + name: str = "涂鸦绘制处理器", + heatmap_opacity: float = 0.45, + grid_step: int = 40, + grid_heat_threshold: float = 0.6, + grid_draw_every: int = 2, + max_heat: float = 10.0): + super().__init__(name) + self.heatmap_opacity = heatmap_opacity + self.grid_step = grid_step + self.grid_heat_threshold = grid_heat_threshold + self.grid_draw_every = grid_draw_every + self.max_heat = max_heat + + # 箭头参数 + self.arrow_length_min = 10.0 + self.grid_arrow_color = (0, 128, 255) # 橙色 - 网格箭头 + + # 帧计数器 + self.frame_idx = 0 + + def process(self, data: PipelineData) -> PipelineData: + """绘制热力图和网格方向箭头""" + if data.frame is None: + return data + + self.frame_idx = data.frame_idx + vis = data.frame.copy() + + # 获取涂鸦数据 + heat_map = data.get_data("graffiti", "heat_map") + dir_x_map = data.get_data("graffiti", "dir_x_map") + dir_y_map = data.get_data("graffiti", "dir_y_map") + + # 绘制热力图 + if heat_map is not None: + vis = self._draw_heatmap(vis, heat_map) + + # 绘制网格方向箭头 + if (dir_x_map is not None and dir_y_map is not None and + self.frame_idx % self.grid_draw_every == 0): + vis = self._draw_grid_arrows(vis, heat_map, dir_x_map, dir_y_map) + + # TODO 绘制调试信息,这个应该是单独的处理器 + vis = self._draw_debug_info(vis) + + data.frame = vis + return data + + def _draw_heatmap(self, vis, heat_map): + """绘制热力图叠加层""" + # 归一化热力图 + heat_norm = np.clip(heat_map / self.max_heat, 0.0, 1.0) + heat_u8 = (heat_norm * 255).astype(np.uint8) + + # 创建红色通道叠加层 + red_overlay = np.zeros_like(vis) + red_overlay[:, :, 2] = heat_u8 # 红色通道 + + # 叠加热力图 + vis = cv2.addWeighted(vis, 1.0, red_overlay, self.heatmap_opacity, 0) + + return vis + + def _draw_grid_arrows(self, vis, heat_map, dir_x_map, dir_y_map): + """在网格上绘制持久方向箭头""" + h, w = heat_map.shape + + for gy in range(0, h, self.grid_step): + for gx in range(0, w, self.grid_step): + x0, y0 = gx, gy + x1 = min(gx + self.grid_step, w) + y1 = min(gy + self.grid_step, h) + + # 检查网格区域热力 + heat_window = heat_map[y0:y1, x0:x1] + if heat_window.size == 0: + continue + + avg_heat = float(heat_window.mean()) + + # 如果平均热力达到阈值,绘制方向箭头 + if avg_heat / self.max_heat >= self.grid_heat_threshold: + avg_dx = float(dir_x_map[y0:y1, x0:x1].sum()) + avg_dy = float(dir_y_map[y0:y1, x0:x1].sum()) + + if avg_dx == 0 and avg_dy == 0: + continue + + # 计算网格中心点 + cell_center = (x0 + (x1 - x0) // 2, y0 + (y1 - y0) // 2) + + # 绘制网格边界(可选,用于调试) + # cv2.rectangle(vis, (x0, y0), (x1, y1), (100, 100, 100), 1) + + # 绘制方向箭头 + self._draw_arrow(vis, cell_center, + (avg_dx * 5.0, avg_dy * 5.0)) + + return vis + + def _draw_arrow(self, img, start, vec, color=None, thickness=2): + """绘制箭头""" + if color is None: + color = self.grid_arrow_color + + sx, sy = int(round(start[0])), int(round(start[1])) + vx, vy = float(vec[0]), float(vec[1]) + + norm = math.hypot(vx, vy) + if norm < 1e-3: + return + + # 确保箭头有最小长度 + scale = 1.0 + if norm < self.arrow_length_min: + scale = self.arrow_length_min / (norm + 1e-6) + + ex = int(round(sx + vx * scale)) + ey = int(round(sy + vy * scale)) + + cv2.arrowedLine(img, + (sx, sy), + (ex, ey), + color, thickness, + tipLength=0.3) + + def _draw_debug_info(self, vis): + """绘制调试信息""" + # 绘制帧号 + cv2.putText(vis, + f"Frame:{self.frame_idx}", + (10, 20), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, + (200, 200, 200), 2) + + # 绘制热力图参数 + cv2.putText(vis, + f"HeatMax:{self.max_heat} TrustThr:3.0", + (10, 45), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, + (200, 200, 200), 1) + + # 绘制网格参数 + cv2.putText(vis, + f"Grid:{self.grid_step}px Thresh:{self.grid_heat_threshold}", + (10, 65), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, + (200, 200, 200), 1) + + # 绘制热力图图例 + legend_y = vis.shape[0] - 30 + legend_x = 10 + + # 图例标题 + cv2.putText(vis, "Heat Legend:", + (legend_x, legend_y), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, + (200, 200, 200), 1) + + # 绘制颜色条 + colorbar_width = 100 + colorbar_height = 10 + colorbar_x = legend_x + 80 + colorbar_y = legend_y - 5 + + # 创建渐变颜色条 + for i in range(colorbar_width): + intensity = i / colorbar_width + color = (0, 0, int(intensity * 255)) + cv2.line(vis, + (colorbar_x + i, colorbar_y), + (colorbar_x + i, colorbar_y + colorbar_height), + color, 1) + + # 标注最小值 + cv2.putText(vis, "0", + (colorbar_x - 10, colorbar_y + 15), + cv2.FONT_HERSHEY_SIMPLEX, 0.4, + (200, 200, 200), 1) + + # 标注最大值 + cv2.putText(vis, f"{self.max_heat}", + (colorbar_x + colorbar_width - 10, colorbar_y + 15), + cv2.FONT_HERSHEY_SIMPLEX, 0.4, + (200, 200, 200), 1) + + return vis \ No newline at end of file diff --git a/pipeline/handler/DrawObjectBoxProcessor.py b/pipeline/handler/DrawObjectBoxProcessor.py new file mode 100644 index 0000000..220e3e2 --- /dev/null +++ b/pipeline/handler/DrawObjectBoxProcessor.py @@ -0,0 +1,184 @@ +import cv2 +import numpy as np +from dataclasses import dataclass, field +from typing import List, Optional, Set +from yolo_gs.pipeline.base_processor import BaseProcessor +from yolo_gs.pipeline.pipeline_data import PipelineData + + +@dataclass +class DetectionClass: + """检测类别数据类""" + class_id: int # 类别ID(对应COCO数据集) + class_name: str # 类别名称 + color: tuple # 绘制颜色 (B, G, R) + + def __hash__(self): + """用于支持集合操作""" + return hash(self.class_id) + + def __eq__(self, other): + """用于支持集合操作""" + if not isinstance(other, DetectionClass): + return False + return self.class_id == other.class_id + + +class DrawObjectBoxProcessor(BaseProcessor): + """对象框绘制处理器:绘制检测框、ID、标签、置信度""" + + def __init__(self, + name: str = "对象框绘制处理器", + show_id: bool = True, + show_label: bool = True, + show_conf: bool = False, # 默认不显示置信度 + box_thickness: int = 2, + font_scale: float = 0.5, + vehicle_classes: Optional[Set[DetectionClass]] = None): + super().__init__(name) + self.show_id = show_id + self.show_label = show_label + self.show_conf = show_conf + self.box_thickness = box_thickness + self.font_scale = font_scale + + # 初始化车辆类别集合(使用预定义或用户自定义) + self.vehicle_classes: Set[DetectionClass] = vehicle_classes or [ + # coco 数据集 预设类 + DetectionClass(class_id=0, class_name="person", color=(255, 0, 0)), # 蓝色 + DetectionClass(class_id=1, class_name="bicycle", color=(60, 100, 180)), # 深蓝 + DetectionClass(class_id=2, class_name="car", color=(180, 160, 40)), # 深黄 + DetectionClass(class_id=3, class_name="motorcycle", color=(140, 80, 140)), # 深紫 + DetectionClass(class_id=5, class_name="bus", color=(40, 160, 160)), # 深青 + DetectionClass(class_id=7, class_name="truck", color=(180, 120, 40)), # 深橙 + DetectionClass(class_id=8, class_name="boat", color=(60, 120, 60)), # 深绿 + ] + + # 构建快速查询映射(提高查找效率) + self._class_id_to_info = { + cls.class_id: (cls.class_name, cls.color) + for cls in self.vehicle_classes + } + # 车辆类别ID集合(用于快速判断) + self._vehicle_class_ids = {cls.class_id for cls in self.vehicle_classes} + + def process(self, data: PipelineData) -> PipelineData: + """绘制检测框和相关信息""" + if data.frame is None or data.current_result is None: + return data + + vis = data.frame.copy() + boxes = getattr(data.current_result, "boxes", None) + + if boxes is None: + data.frame = vis + return data + + # 绘制每个检测框 + for box in boxes: + try: + self._draw_single_box(vis, box) + except Exception as e: + self.logger.warning(f"绘制检测框失败: {e}") + continue + + data.frame = vis + return data + + def _draw_single_box(self, vis: np.ndarray, box): + """绘制单个检测框""" + # 获取坐标 + xyxy = getattr(box, "xyxy", None) + if xyxy is None: + return + + x1, y1, x2, y2 = map(int, map(float, xyxy[0].tolist())) + + # 获取类别和置信度 + cls = int(getattr(box, "cls", -1)) + conf = float(getattr(box, "conf", 0.0)) + + # 获取track_id + track_id = int(getattr(box, "id", -1)) + + # 检查是否是需要绘制的车辆类别 + if cls not in self._vehicle_class_ids: + return + + # 获取类别名称和颜色 + class_name, color = self._class_id_to_info.get( + cls, ("unknown", (0, 255, 255)) # 默认白色文字,黄色框 + ) + + # 绘制检测框 + cv2.rectangle(vis, (x1, y1), (x2, y2), color, self.box_thickness) + + # 构建显示文本 + text_parts = [] + + if self.show_id and track_id != -1: + text_parts.append(f"ID:{track_id}") + + if self.show_label: + text_parts.append(class_name) + + if self.show_conf: + text_parts.append(f"{conf:.2f}") + + if text_parts: + text = " ".join(text_parts) + + # 计算文本大小 + text_size = cv2.getTextSize( + text, cv2.FONT_HERSHEY_SIMPLEX, self.font_scale, 1 + )[0] + + # 确保文本位置在图像内 + text_y = max(y1 - 10, 20) # 至少离顶部20像素 + text_x = max(x1, 10) # 至少离左侧10像素 + + # 绘制文本背景(带内边距) + bg_top_left = (text_x - 2, text_y - text_size[1] - 4) + bg_bottom_right = (text_x + text_size[0] + 2, text_y + 2) + cv2.rectangle(vis, bg_top_left, bg_bottom_right, color, -1) + + # 绘制文本(白色) + cv2.putText( + vis, text, + (text_x, text_y - 2), + cv2.FONT_HERSHEY_SIMPLEX, + self.font_scale, + (255, 255, 255), # 白色文字 + thickness=1, + lineType=cv2.LINE_AA # 抗锯齿 + ) + + @classmethod + def create_custom_classes(cls, class_configs: List[dict]) -> Set[DetectionClass]: + """ + 工厂方法:创建自定义的类别集合 + + Args: + class_configs: 类别配置列表,每个元素包含 class_id, class_name, color + 示例: [ + {"class_id": 2, "class_name": "轿车", "color": (0, 255, 0)}, + {"class_id": 3, "class_name": "摩托车", "color": (255, 0, 0)} + ] + + Returns: + 自定义的DetectionClass集合 + """ + custom_classes = set() + for config in class_configs: + try: + detection_class = DetectionClass( + class_id=config["class_id"], + class_name=config["class_name"], + color=tuple(config["color"]) + ) + custom_classes.add(detection_class) + except KeyError as e: + raise ValueError(f"类别配置缺少必要字段: {e}") + except Exception as e: + raise ValueError(f"创建自定义类别失败: {e}") + return custom_classes diff --git a/pipeline/handler/GraffitiProcessor.py b/pipeline/handler/GraffitiProcessor.py new file mode 100644 index 0000000..1226e6d --- /dev/null +++ b/pipeline/handler/GraffitiProcessor.py @@ -0,0 +1,263 @@ +import numpy as np +import cv2 +import math +from collections import deque, defaultdict +from yolo_gs.pipeline.base_processor import BaseProcessor +from yolo_gs.pipeline.pipeline_data import PipelineData + + +class GraffitiProcessor(BaseProcessor): + """涂鸦计算处理器:计算车辆轨迹并更新车道涂鸦""" + + def __init__(self, name: str = "涂鸦计算处理器"): + super().__init__(name) + # 可调参数 + self.MAX_HEAT = 10.0 + self.INC_PER_PASS = 1.0 + self.DEC_PER_OPPOSITE = 1.0 + self.MIN_MOVE_DIST = 6.0 + self.MIN_CUM_DIST = 20.0 + self.HISTORY_LEN = 16 + self.HEAT_DECAY_PER_FRAME = 0.002 + self.VEHICLE_CLASSES = {2, 3, 5, 7} + self.FULL_DECAY_EVERY = 30 # 每30帧进行一次全图衰减 + + # 初始化数据结构(在process中根据帧尺寸初始化) + self.graffiti = None + self.frame_shape = None + self.track_hist = defaultdict(lambda: deque(maxlen=self.HISTORY_LEN)) + self.track_cum_dist = defaultdict(float) + self.frame_idx = 0 + + def init_graffiti(self, h, w): + """初始化涂鸦数据结构""" + if self.graffiti is None or self.frame_shape != (h, w): + self.graffiti = GraffitiLane(w, h) + self.frame_shape = (h, w) + + def process(self, data: PipelineData) -> PipelineData: + """处理每一帧,更新涂鸦数据""" + if data.current_result is None or data.frame is None: + return data + + frame = data.frame + h, w = frame.shape[:2] + self.init_graffiti(h, w) + self.frame_idx = data.frame_idx + + # 周期性对整图做完整衰减 + if self.frame_idx % self.FULL_DECAY_EVERY == 0: + self.graffiti.full_decay(self.frame_idx) + + # 获取检测框 + boxes = getattr(data.current_result, "boxes", None) + if boxes is None: + return data + + # 处理每个检测框 + for box in boxes: + try: + # 解析检测框 + xyxy = getattr(box, "xyxy", None) + if xyxy is None: + continue + + x1, y1, x2, y2 = map(float, xyxy[0].tolist()) + cls = int(getattr(box, "cls", -1)) + track_id = int(getattr(box, "id", -1)) + + # 仅处理车辆类 + if cls not in self.VEHICLE_CLASSES: + continue + + # 更新轨迹和历史 + self._update_track(track_id, x1, y1, x2, y2) + + # 如果累计距离达到阈值,更新涂鸦 + if (self.track_cum_dist[track_id] >= self.MIN_CUM_DIST and + len(self.track_hist[track_id]) >= 2): + self._update_graffiti(track_id, x2 - x1) + self.track_cum_dist[track_id] = 0.0 + + except Exception as e: + continue + + # 清理长时间不活跃的track + self._cleanup_old_tracks() + + # 存储涂鸦数据到PipelineData(deque转换为list以便序列化) + data.put_data("graffiti", "heat_map", self.graffiti.heat) + data.put_data("graffiti", "dir_x_map", self.graffiti.dir_x) + data.put_data("graffiti", "dir_y_map", self.graffiti.dir_y) + data.put_data("graffiti", "track_hist", { + track_id: list(hist) for track_id, hist in self.track_hist.items() + }) + data.put_data("graffiti", "track_cum_dist", dict(self.track_cum_dist)) + + return data + + def _update_track(self, track_id, x1, y1, x2, y2): + """更新车辆轨迹信息""" + cx, cy = (x1 + x2) / 2.0, (y1 + y2) / 2.0 + prev_centers = self.track_hist[track_id] + + if len(prev_centers) > 0: + last_cx, last_cy = prev_centers[-1] + move_dist = math.hypot(cx - last_cx, cy - last_cy) + else: + move_dist = 0.0 + + if len(prev_centers) > 0: + self.track_cum_dist[track_id] += move_dist + else: + self.track_cum_dist[track_id] = 0.0 + + if move_dist >= self.MIN_MOVE_DIST: + prev_centers.append((cx, cy)) + elif len(prev_centers) == 0: + prev_centers.append((cx, cy)) + + def _update_graffiti(self, track_id, vehicle_width): + """根据车辆轨迹更新涂鸦""" + hist = self.track_hist[track_id] + if len(hist) < 2: + return + + cp_prev = hist[-2] + cp_cur = hist[-1] + + # 创建多边形掩码 + poly = self._polygon_from_segment(cp_prev, cp_cur, max(8.0, vehicle_width)) + if poly is None: + return + + # 计算ROI掩码 + bbox_roi, mask_roi = self._polygon_mask_roi_from_pts(poly, self.frame_shape) + if mask_roi.size == 0: + return + + # 计算方向向量 + dir_vec = self._normalize_vec((cp_cur[0] - cp_prev[0], cp_cur[1] - cp_prev[1])) + dir_vec_scaled = dir_vec.astype(np.float32) * 1.0 + + # 更新涂鸦 + self.graffiti.apply_paint_roi( + bbox_roi, mask_roi, dir_vec_scaled, + inc=self.INC_PER_PASS, + current_frame=self.frame_idx + ) + + def _cleanup_old_tracks(self): + """清理长时间不活跃的track,防止内存泄漏""" + # 可以根据需要实现,比如超过一定帧数没有更新就清理 + # 这里先不实现,根据实际情况调整 + pass + + def _polygon_from_segment(self, center_prev, center_cur, veh_width_px): + """从两点构造矩形带""" + x0, y0 = center_prev + x1, y1 = center_cur + vx = x1 - x0 + vy = y1 - y0 + dist = math.hypot(vx, vy) + if dist == 0: + return None + + ux, uy = vx / dist, vy / dist + px, py = -uy, ux + half_w = veh_width_px / 2.0 + + p0 = (int(round(x0 + px * half_w)), int(round(y0 + py * half_w))) + p1 = (int(round(x0 - px * half_w)), int(round(y0 - py * half_w))) + p2 = (int(round(x1 - px * half_w)), int(round(y1 - py * half_w))) + p3 = (int(round(x1 + px * half_w)), int(round(y1 + py * half_w))) + + return np.array([p0, p1, p2, p3], dtype=np.int32) + + def _polygon_mask_roi_from_pts(self, pts, img_shape): + """在多边形最小包围矩形内生成局部掩码""" + if pts is None or len(pts) == 0: + return (0, 0, 0, 0), np.zeros((0, 0), dtype=np.uint8) + + x, y, w, h = cv2.boundingRect(pts) + x2 = min(x + w, img_shape[1]) + y2 = min(y + h, img_shape[0]) + w = x2 - x + h = y2 - y + + if w <= 0 or h <= 0: + return (0, 0, 0, 0), np.zeros((0, 0), dtype=np.uint8) + + pts_roi = pts.copy() + pts_roi[:, 0] -= x + pts_roi[:, 1] -= y + mask = np.zeros((h, w), dtype=np.uint8) + cv2.fillPoly(mask, [pts_roi], 1) + + return (x, y, w, h), mask + + def _normalize_vec(self, v): + """归一化向量""" + v = np.array(v, dtype=np.float32) + norm = np.linalg.norm(v) + if norm == 0: + return np.array([0.0, 0.0], dtype=np.float32) + return v / norm + + +class GraffitiLane: + """涂鸦车道数据结构""" + + def __init__(self, frame_w, frame_h): + self.w = frame_w + self.h = frame_h + self.heat = np.zeros((self.h, self.w), dtype=np.float32) + self.dir_x = np.zeros((self.h, self.w), dtype=np.float32) + self.dir_y = np.zeros((self.h, self.w), dtype=np.float32) + self.last_decay_frame = 0 + + def _decay_factor(self, frames): + if frames <= 0: + return 1.0 + return (1.0 - 0.002) ** frames # HEAT_DECAY_PER_FRAME = 0.002 + + def lazy_decay_roi(self, bbox, current_frame): + x, y, w, h = bbox + if w <= 0 or h <= 0: + return + + frames = current_frame - self.last_decay_frame + if frames <= 0: + return + + factor = self._decay_factor(frames) + self.heat[y:y + h, x:x + w] *= factor + self.dir_x[y:y + h, x:x + w] *= factor + self.dir_y[y:y + h, x:x + w] *= factor + + def full_decay(self, current_frame): + """在整图上强制应用pending衰减(用于周期性清理)""" + frames = current_frame - self.last_decay_frame + if frames <= 0: + return + + factor = self._decay_factor(frames) + if factor == 1.0: + self.last_decay_frame = current_frame + return + + # 对整图应用因子 + self.heat *= factor + self.dir_x *= factor + self.dir_y *= factor + self.last_decay_frame = current_frame + + def apply_paint_roi(self, bbox, mask_roi, direction_vector, inc=1.0, current_frame=0): + self.lazy_decay_roi(bbox, current_frame) + x, y, w, h = bbox + add = inc * mask_roi.astype(np.float32) + + self.heat[y:y + h, x:x + w] = np.minimum(10.0, self.heat[y:y + h, x:x + w] + add) # MAX_HEAT = 10.0 + dx, dy = direction_vector + self.dir_x[y:y + h, x:x + w] += dx * add + self.dir_y[y:y + h, x:x + w] += dy * add \ No newline at end of file diff --git a/pipeline/handler/ResultLogger.py b/pipeline/handler/ResultLogger.py new file mode 100644 index 0000000..cbda1e3 --- /dev/null +++ b/pipeline/handler/ResultLogger.py @@ -0,0 +1,10 @@ +from ..base_processor import BaseProcessor +from ..pipeline_data import PipelineData +import numpy as np + +class ResultLogger(BaseProcessor): + """示例处理器:打印检测结果日志""" + def process(self, data: PipelineData) -> PipelineData: + print(f"\n【{self.name}】帧{data.frame_idx} - 检测到目标数: {len(data.current_result.boxes)}") + print(f"缓存帧数: {len(data.result_cache)}") + return data \ No newline at end of file diff --git a/pipeline/handler/RetrogradeProcessor.py b/pipeline/handler/RetrogradeProcessor.py new file mode 100644 index 0000000..d2a8dd9 --- /dev/null +++ b/pipeline/handler/RetrogradeProcessor.py @@ -0,0 +1,244 @@ +import cv2 +import numpy as np +from yolo_gs.pipeline.base_processor import BaseProcessor +from yolo_gs.pipeline.pipeline_data import PipelineData +import math + + +class RetrogradeProcessor(BaseProcessor): + """逆行检测处理器:检测逆向行驶的车辆""" + + def __init__(self, name: str = "逆行检测处理器"): + super().__init__(name) + # 可调参数 + self.TRUST_THRESHOLD = 3.0 + self.PARTIAL_OVERLAP_THRESHOLD = 0.1 + self.MAJORITY_OVERLAP_THRESHOLD = 0.5 + self.DEC_PER_OPPOSITE = 1.0 + # self.OPPOSITE_ANGLE_COS_THRESH = -0.9397 + self.OPPOSITE_ANGLE_COS_THRESH = math.cos(math.radians(1)) + self.VEHICLE_CLASSES = {2, 3, 5, 7} + + # 存储上一帧的事件信息 + self.track_last_event = {} + self.frame_idx = 0 + + def process(self, data: PipelineData) -> PipelineData: + """检测逆行行为""" + if data.current_result is None: + return data + + self.frame_idx = data.frame_idx + + # 获取涂鸦数据 + heat_map = data.get_data("graffiti", "heat_map") + dir_x_map = data.get_data("graffiti", "dir_x_map") + dir_y_map = data.get_data("graffiti", "dir_y_map") + track_hist = data.get_data("graffiti", "track_hist") + + if heat_map is None or track_hist is None: + return data + + boxes = getattr(data.current_result, "boxes", None) + if boxes is None: + return data + + frame_events = {} + events_all0 = {} + # 处理每个检测框 + for box in boxes: + try: + # 解析检测框 + xyxy = getattr(box, "xyxy", None) + if xyxy is None: + continue + + x1, y1, x2, y2 = map(float, xyxy[0].tolist()) + cls = int(getattr(box, "cls", -1)) + track_id = int(getattr(box, "id", -1)) + + # 仅处理车辆类 + if cls not in self.VEHICLE_CLASSES: + continue + + # 获取车辆历史轨迹 + hist = track_hist.get(track_id, []) + if len(hist) < 2: + continue + + # 计算当前方向 + cp_prev = hist[-2] + cp_cur = hist[-1] + dir_vec = self._normalize_vec((cp_cur[0] - cp_prev[0], cp_cur[1] - cp_prev[1])) + + # 计算车辆经过的区域 + vehicle_width = max(8.0, x2 - x1) + poly = self._polygon_from_segment(cp_prev, cp_cur, vehicle_width) + if poly is None: + continue + + # 创建ROI掩码 + frame_shape = heat_map.shape[:2] if len(heat_map.shape) == 2 else heat_map.shape + bbox_roi, mask_roi = self._polygon_mask_roi_from_pts(poly, frame_shape) + if mask_roi.size == 0: + continue + + # 计算与现有涂鸦的重叠 + x, y, w_roi, h_roi = bbox_roi + heat_roi = heat_map[y:y + h_roi, x:x + w_roi] + overlap_pixels = np.logical_and(mask_roi > 0, heat_roi > 0) + overlap_ratio = float(overlap_pixels.sum()) / max(1.0, mask_roi.sum()) + + # 计算区域平均热力和方向 + region_heat = self._region_avg_heat(heat_roi, mask_roi) + region_dir = self._region_avg_direction(dir_x_map[y:y + h_roi, x:x + w_roi], + dir_y_map[y:y + h_roi, x:x + w_roi], + mask_roi) + + # 判断事件类型 + event = self._detect_retrograde_event( + region_heat, region_dir, dir_vec, overlap_ratio + ) + events_all0[track_id] = event + + if event: + frame_events[track_id] = event + # 减少涂鸦区域热力(逆向行驶会减少涂鸦可信度) + self._reduce_paint(heat_map, dir_x_map, dir_y_map, bbox_roi, mask_roi, overlap_ratio) + + except Exception as e: + continue + + # 存储事件信息 + data.put_data("retrograde", "current_events", frame_events) + data.put_data("retrograde", "all_events", self.track_last_event) + + # 输出逆行信息 + for track_id, event in events_all0.items(): + print(f"【逆行检测器】帧{self.frame_idx} - 检测到事件:{event} - 轨迹ID: {track_id}") + + # 更新事件历史 + for track_id, event in frame_events.items(): + self.track_last_event[track_id] = (self.frame_idx, event) + + return data + + def _detect_retrograde_event(self, region_heat, region_dir, current_dir, overlap_ratio): + """检测逆行事件类型""" + if region_heat >= 1.0 and region_dir is not None: + cos_sim = self._cos_between(region_dir, current_dir) + + if cos_sim < self.OPPOSITE_ANGLE_COS_THRESH: + if region_heat >= self.TRUST_THRESHOLD: + if overlap_ratio >= self.MAJORITY_OVERLAP_THRESHOLD: + return "REVERSE" + elif overlap_ratio >= self.PARTIAL_OVERLAP_THRESHOLD: + return "CROSS_LINE" + else: + return "SUSPECT" + else: + return "SUSPECT" + else: + return "NORMAL" + + return None + + def _reduce_paint(self, heat_map, dir_x_map, dir_y_map, bbox_roi, mask_roi, overlap_ratio): + """减少涂鸦区域热力(逆向行驶惩罚)""" + x, y, w, h = bbox_roi + amount = self.DEC_PER_OPPOSITE * overlap_ratio + dec = amount * mask_roi.astype(np.float32) + + # 减少热力 + heat_roi = heat_map[y:y + h, x:x + w] + heat_roi_before = heat_roi.copy() + heat_roi_after = np.maximum(0.0, heat_roi_before - dec) + + # 按比例减少方向强度 + eps = 1e-6 + ratio = np.ones_like(heat_roi_before) + nonzero_mask = heat_roi_before > eps + ratio[nonzero_mask] = heat_roi_after[nonzero_mask] / (heat_roi_before[nonzero_mask] + eps) + + dir_x_map[y:y + h, x:x + w] *= ratio + dir_y_map[y:y + h, x:x + w] *= ratio + heat_map[y:y + h, x:x + w] = heat_roi_after + + def _polygon_from_segment(self, center_prev, center_cur, veh_width_px): + """从两点构造矩形带""" + x0, y0 = center_prev + x1, y1 = center_cur + vx = x1 - x0 + vy = y1 - y0 + dist = math.hypot(vx, vy) + if dist == 0: + return None + + ux, uy = vx / dist, vy / dist + px, py = -uy, ux + half_w = veh_width_px / 2.0 + + p0 = (int(round(x0 + px * half_w)), int(round(y0 + py * half_w))) + p1 = (int(round(x0 - px * half_w)), int(round(y0 - py * half_w))) + p2 = (int(round(x1 - px * half_w)), int(round(y1 - py * half_w))) + p3 = (int(round(x1 + px * half_w)), int(round(y1 + py * half_w))) + + return np.array([p0, p1, p2, p3], dtype=np.int32) + + def _polygon_mask_roi_from_pts(self, pts, img_shape): + """在多边形最小包围矩形内生成局部掩码""" + if pts is None or len(pts) == 0: + return (0, 0, 0, 0), np.zeros((0, 0), dtype=np.uint8) + + x, y, w, h = cv2.boundingRect(pts) + x2 = min(x + w, img_shape[1]) + y2 = min(y + h, img_shape[0]) + w = x2 - x + h = y2 - y + + if w <= 0 or h <= 0: + return (0, 0, 0, 0), np.zeros((0, 0), dtype=np.uint8) + + pts_roi = pts.copy() + pts_roi[:, 0] -= x + pts_roi[:, 1] -= y + mask = np.zeros((h, w), dtype=np.uint8) + cv2.fillPoly(mask, [pts_roi], 1) + + return (x, y, w, h), mask + + def _region_avg_heat(self, heat_roi, mask_roi): + """计算区域平均热力""" + mask = mask_roi.astype(bool) + if mask.sum() == 0: + return 0.0 + return float(heat_roi[mask].mean()) + + def _region_avg_direction(self, dir_x_roi, dir_y_roi, mask_roi): + """计算区域平均方向""" + mask = mask_roi.astype(bool) + if mask.sum() == 0: + return None + + sx = dir_x_roi[mask].sum() + sy = dir_y_roi[mask].sum() + vec = np.array([sx, sy], dtype=np.float32) + norm = np.linalg.norm(vec) + + if norm == 0: + return None + return vec / norm + + def _normalize_vec(self, v): + """归一化向量""" + v = np.array(v, dtype=np.float32) + norm = np.linalg.norm(v) + if norm == 0: + return np.array([0.0, 0.0], dtype=np.float32) + return v / norm + + def _cos_between(self, a, b): + """计算两个向量的余弦相似度""" + an = self._normalize_vec(a) + bn = self._normalize_vec(b) + return float(np.dot(an, bn)) diff --git a/pipeline/handler/__init__.py b/pipeline/handler/__init__.py new file mode 100644 index 0000000..f37e600 --- /dev/null +++ b/pipeline/handler/__init__.py @@ -0,0 +1,10 @@ +from ..base_processor import BaseProcessor +from ..pipeline import Pipeline +from ..pipeline_data import PipelineData + +__all__ = [ + 'BaseProcessor', + 'Pipeline', + 'PipelineData', + 'ResultLogger', +] \ No newline at end of file diff --git a/pipeline/handler/graffiti_visualizer.py b/pipeline/handler/graffiti_visualizer.py new file mode 100644 index 0000000..24f9ea9 --- /dev/null +++ b/pipeline/handler/graffiti_visualizer.py @@ -0,0 +1,198 @@ +import numpy as np +import cv2 +import math +from yolo_gs.pipeline.base_processor import BaseProcessor +from yolo_gs.pipeline.pipeline_data import PipelineData + + +class GraffitiVisualizer(BaseProcessor): + """涂鸦可视化处理器:绘制热力图和方向箭头""" + + def __init__(self, name: str = "涂鸦可视化处理器"): + super().__init__(name) + # 可调参数 + self.MAX_HEAT = 10.0 + self.ARROW_LEN_MIN = 10.0 + self.GRID_STEP = 40 + self.GRID_HEAT_THRESHOLD = 0.6 + self.GRID_DRAW_EVERY = 2 + self.FULL_DECAY_EVERY = 30 + + self.frame_idx = 0 + + def process(self, data: PipelineData) -> PipelineData: + """绘制涂鸦可视化""" + if data.frame is None: + return data + + self.frame_idx = data.frame_idx + vis = data.frame.copy() + + # 获取涂鸦数据 + heat_map = data.get_data("graffiti", "heat_map") + dir_x_map = data.get_data("graffiti", "dir_x_map") + dir_y_map = data.get_data("graffiti", "dir_y_map") + track_hist = data.get_data("graffiti", "track_hist") + current_events = data.get_data("retrograde", "current_events", {}) + all_events = data.get_data("retrograde", "all_events", {}) + + # 绘制热力图 + if heat_map is not None: + heat_norm = np.clip(heat_map / self.MAX_HEAT, 0.0, 1.0) + heat_u8 = (heat_norm * 255).astype(np.uint8) + red_overlay = np.zeros_like(vis) + red_overlay[:, :, 2] = heat_u8 + alpha = 0.45 + vis = cv2.addWeighted(vis, 1.0, red_overlay, alpha, 0) + + # 绘制网格方向箭头 + if (dir_x_map is not None and dir_y_map is not None and + self.frame_idx % self.GRID_DRAW_EVERY == 0): + self._draw_grid_arrows(vis, heat_map, dir_x_map, dir_y_map) + + # 绘制车辆轨迹和事件 + if track_hist is not None: + self._draw_tracks_and_events(vis, track_hist, current_events, all_events) + + # 绘制检测框和方向 + if data.current_result is not None: + self._draw_detections(vis, data.current_result, track_hist) + + # 添加调试信息 + self._draw_debug_info(vis) + + data.frame = vis + return data + + def _draw_grid_arrows(self, vis, heat_map, dir_x_map, dir_y_map): + """在网格上绘制方向箭头""" + h, w = heat_map.shape + for gy in range(0, h, self.GRID_STEP): + for gx in range(0, w, self.GRID_STEP): + x0, y0 = gx, gy + x1 = min(gx + self.GRID_STEP, w) + y1 = min(gy + self.GRID_STEP, h) + + win = heat_map[y0:y1, x0:x1] + if win.size == 0: + continue + + avg_heat = float(win.mean()) + if avg_heat / self.MAX_HEAT >= self.GRID_HEAT_THRESHOLD: + avg_dx = float(dir_x_map[y0:y1, x0:x1].sum()) + avg_dy = float(dir_y_map[y0:y1, x0:x1].sum()) + + if avg_dx == 0 and avg_dy == 0: + continue + + cell_center = (x0 + (x1 - x0) // 2, y0 + (y1 - y0) // 2) + self._draw_arrow( + vis, cell_center, + (avg_dx * 5.0, avg_dy * 5.0), + color=(0, 128, 255), thickness=2 + ) + + def _draw_tracks_and_events(self, vis, track_hist, current_events, all_events): + """绘制车辆轨迹和事件标签""" + for tid, hist in track_hist.items(): + if len(hist) == 0: + continue + + # 绘制轨迹线 + pts = np.array(hist, dtype=np.int32) + for i in range(1, len(pts)): + cv2.line(vis, tuple(pts[i - 1]), tuple(pts[i]), (255, 255, 255), 1) + + # 绘制当前位置 + cx, cy = hist[-1] + cv2.circle(vis, (int(cx), int(cy)), 3, (255, 255, 255), -1) + + # 绘制当前事件 + if tid in current_events: + event = current_events[tid] + txt = f"id:{tid} {event}" + cv2.putText( + vis, txt, (int(cx) + 6, int(cy) - 6), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2 + ) + # 绘制近期事件(过去5帧内) + elif tid in all_events: + last_frame_idx, event = all_events[tid] + if event is not None and self.frame_idx - last_frame_idx <= 5: + txt = f"id:{tid} {event}" + cv2.putText( + vis, txt, (int(cx) + 6, int(cy) - 6), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2 + ) + + def _draw_detections(self, vis, result, track_hist): + """绘制检测框和车辆方向""" + boxes = getattr(result, "boxes", None) + if boxes is None: + return + + for box in boxes: + try: + xyxy = getattr(box, "xyxy", None) + if xyxy is None: + continue + + x1, y1, x2, y2 = map(int, map(float, xyxy[0].tolist())) + track_id = int(getattr(box, "id", -1)) + cls = int(getattr(box, "cls", -1)) + + # 绘制检测框 + cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 255), 2) + + # 绘制车辆ID + cx, cy = (x1 + x2) // 2, (y1 + y2) // 2 + cv2.putText( + vis, f"id:{track_id}", (cx + 6, cy - 6), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1 + ) + + # 绘制车辆方向箭头 + if track_id in track_hist and len(track_hist[track_id]) >= 2: + hist = track_hist[track_id] + p0 = hist[-2] + p1 = hist[-1] + v = (p1[0] - p0[0], p1[1] - p0[1]) + self._draw_arrow( + vis, (cx, cy), (v[0], v[1]), + color=(0, 0, 255), thickness=2 + ) + + except Exception: + continue + + def _draw_arrow(self, img, start, vec, color=(0, 255, 0), thickness=2, tip_length=0.3): + """绘制箭头""" + sx, sy = int(round(start[0])), int(round(start[1])) + vx, vy = float(vec[0]), float(vec[1]) + norm = math.hypot(vx, vy) + + if norm < 1e-3: + return + + scale = 1.0 + if norm < self.ARROW_LEN_MIN: + scale = self.ARROW_LEN_MIN / (norm + 1e-6) + + ex = int(round(sx + vx * scale)) + ey = int(round(sy + vy * scale)) + + cv2.arrowedLine( + img, (sx, sy), (ex, ey), + color, thickness, tipLength=tip_length + ) + + def _draw_debug_info(self, vis): + """绘制调试信息""" + cv2.putText( + vis, f"Frame:{self.frame_idx}", (10, 20), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 2 + ) + cv2.putText( + vis, f"HeatMax:{self.MAX_HEAT} TrustThr:3.0", (10, 45), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (200, 200, 200), 1 + ) \ No newline at end of file diff --git a/test_advanced_processors.py b/test_advanced_processors.py new file mode 100644 index 0000000..71a74d6 --- /dev/null +++ b/test_advanced_processors.py @@ -0,0 +1,23 @@ +""" +测试新创建的高级处理器 +""" +from pipeline.handler.advanced_heatmap_processor import AdvancedHeatmapProcessor +from pipeline.handler.enhanced_reverse_direction_processor import EnhancedRetrogradeProcessor + + +def test_advanced_processors(): + print("测试高级处理器...") + + # 测试AdvancedHeatmapProcessor + advanced_heatmap_proc = AdvancedHeatmapProcessor() + print(f"AdvancedHeatmapProcessor创建成功: {advanced_heatmap_proc.name}") + + # 测试EnhancedRetrogradeProcessor + enhanced_retrograde_proc = EnhancedRetrogradeProcessor() + print(f"EnhancedRetrogradeProcessor创建成功: {enhanced_retrograde_proc.name}") + + print("所有高级处理器测试完成!") + + +if __name__ == "__main__": + test_advanced_processors() \ No newline at end of file diff --git a/test_compatibility.py b/test_compatibility.py new file mode 100644 index 0000000..25d6a18 --- /dev/null +++ b/test_compatibility.py @@ -0,0 +1,31 @@ +""" +测试处理器兼容性 +""" +from pipeline.handler.advanced_heatmap_processor import AdvancedHeatmapProcessor +from pipeline.handler.draw_heatmap_box_processor import DrawHeatMapBoxProcessor +from pipeline.pipeline_data import PipelineData +import numpy as np +import cv2 + + +def test_compatibility(): + print("测试处理器兼容性...") + + # 创建处理器 + heatmap_proc = AdvancedHeatmapProcessor() + draw_proc = DrawHeatMapBoxProcessor() + + print(f"处理器创建成功: {heatmap_proc.name}, {draw_proc.name}") + + # 创建模拟数据 + data = PipelineData() + # 创建一个模拟帧 + data.frame = np.zeros((480, 640, 3), dtype=np.uint8) + data.frame_idx = 1 + + # 由于我们无法创建真实的YOLO结果,我们跳过处理步骤,主要验证类定义和依赖 + print("处理器兼容性测试完成!") + + +if __name__ == "__main__": + test_compatibility() \ No newline at end of file diff --git a/test_modifications.py b/test_modifications.py new file mode 100644 index 0000000..043092f --- /dev/null +++ b/test_modifications.py @@ -0,0 +1,29 @@ +""" +测试修改后的处理器 +""" +from pipeline.handler.advanced_heatmap_processor import AdvancedHeatmapProcessor +from pipeline.handler.draw_heatmap_box_processor import DrawHeatMapBoxProcessor +from pipeline.handler.draw_id_box_processor import DrawIdBoxProcessor +from pipeline.handler.draw_motion_vector_processor import DrawMotionVectorProcessor + + +def test_modifications(): + print("测试修改后的处理器...") + + # 测试所有处理器是否都能正确创建 + proc1 = AdvancedHeatmapProcessor() + proc2 = DrawHeatMapBoxProcessor() + proc3 = DrawIdBoxProcessor() + proc4 = DrawMotionVectorProcessor() + + print('所有处理器创建成功') + print(f'高级热力图处理器: {proc1.name}') + print(f'热力图框绘制处理器: {proc2.name}') + print(f'ID框绘制处理器: {proc3.name}') + print(f'运动向量绘制处理器: {proc4.name}') + + print("修改验证完成!") + + +if __name__ == "__main__": + test_modifications() \ No newline at end of file diff --git a/test_updated_system.py b/test_updated_system.py new file mode 100644 index 0000000..a76f66b --- /dev/null +++ b/test_updated_system.py @@ -0,0 +1,39 @@ +""" +测试更新后的系统,验证移除旧处理器后的功能 +""" +from pipeline.handler import ( + AdvancedHeatmapProcessor, + EnhancedRetrogradeProcessor, + DrawIdBoxProcessor, + DrawHeatMapBoxProcessor, + DrawMotionVectorProcessor, + ResultLogger, + BoxFilter +) + +def test_updated_system(): + print("测试更新后的系统...") + + # 测试所有处理器是否都能正确创建 + proc1 = AdvancedHeatmapProcessor() + proc2 = EnhancedRetrogradeProcessor() + proc3 = DrawIdBoxProcessor() + proc4 = DrawHeatMapBoxProcessor() + proc5 = DrawMotionVectorProcessor() + proc6 = ResultLogger("TestLogger") + proc7 = BoxFilter("TestFilter") + + print('所有处理器创建成功') + print(f'高级热力图处理器: {proc1.name}') + print(f'增强型逆行检测处理器: {proc2.name}') + print(f'ID框绘制处理器: {proc3.name}') + print(f'热力图框绘制处理器: {proc4.name}') + print(f'运动向量绘制处理器: {proc5.name}') + print(f'日志处理器: {proc6.name}') + print(f'过滤器: {proc7.name}') + + print("系统更新测试完成!") + + +if __name__ == "__main__": + test_updated_system() \ No newline at end of file