import cv2
import numpy as np
from dataclasses import dataclass, field
from typing import List, Optional, Set

from ..base_processor import BaseProcessor
from ..pipeline_data import PipelineData


@dataclass
class DetectionClass:
    """A detection category: numeric id, display name and drawing colour.

    Equality and hashing use ``class_id`` only, so two instances with the
    same id collapse to a single entry when stored in a set.
    """

    class_id: int  # Category id (the built-in defaults use COCO ids)
    class_name: str  # Label drawn next to the box
    color: tuple  # Drawing colour as (B, G, R)

    def __hash__(self):
        """Hash on ``class_id`` so instances can be stored in sets."""
        return hash(self.class_id)

    def __eq__(self, other):
        """Compare by ``class_id``; defer to the other operand for foreign types."""
        if not isinstance(other, DetectionClass):
            # NotImplemented lets Python fall back to its default comparison,
            # which still evaluates to False for unrelated objects.
            return NotImplemented
        return self.class_id == other.class_id


class DrawObjectBoxProcessor(BaseProcessor):
    """Pipeline processor that draws detection boxes with id, label and confidence."""

    def __init__(self,
                 name: str = "对象框绘制处理器",
                 show_id: bool = True,
                 show_label: bool = True,
                 show_conf: bool = False,  # confidence is hidden by default
                 box_thickness: int = 2,
                 font_scale: float = 0.5,
                 vehicle_classes: Optional[Set[DetectionClass]] = None):
        """Configure what is drawn and for which classes.

        Args:
            name: Processor name forwarded to ``BaseProcessor``.
            show_id: Include ``ID:<track id>`` in the label when an id exists.
            show_label: Include the class name in the label.
            show_conf: Include the confidence (two decimals) in the label.
            box_thickness: Line thickness of the box rectangle.
            font_scale: Font scale used for the label text.
            vehicle_classes: Classes to draw. ``None`` or an empty collection
                selects the built-in default classes.
        """
        super().__init__(name)
        self.show_id = show_id
        self.show_label = show_label
        self.show_conf = show_conf
        self.box_thickness = box_thickness
        self.font_scale = font_scale

        # Classes to draw: user supplied, otherwise the built-in defaults.
        self.vehicle_classes: Set[DetectionClass] = (
            vehicle_classes or self._default_classes()
        )

        # class_id -> (class_name, color) lookup table.
        self._class_id_to_info = {
            item.class_id: (item.class_name, item.color)
            for item in self.vehicle_classes
        }

        # Ids of the classes that should be drawn (fast membership test).
        self._vehicle_class_ids = {item.class_id for item in self.vehicle_classes}

    @staticmethod
    def _default_classes() -> Set[DetectionClass]:
        """Return a fresh set with the built-in COCO classes."""
        return {
            DetectionClass(class_id=0, class_name="person", color=(255, 0, 0)),
            DetectionClass(class_id=1, class_name="bicycle", color=(60, 100, 180)),
            DetectionClass(class_id=2, class_name="car", color=(180, 160, 40)),
            DetectionClass(class_id=3, class_name="motorcycle", color=(140, 80, 140)),
            DetectionClass(class_id=5, class_name="bus", color=(40, 160, 160)),
            DetectionClass(class_id=7, class_name="truck", color=(180, 120, 40)),
            DetectionClass(class_id=8, class_name="boat", color=(60, 120, 60)),
        }

    @staticmethod
    def _as_scalar(value, default):
        """Return ``value`` as a plain scalar, or ``default`` when it is ``None``.

        Objects exposing ``.item()`` (single-element tensors / arrays) are
        unwrapped through it; anything else is returned unchanged.
        """
        if value is None:
            return default
        item = getattr(value, "item", None)
        return item() if callable(item) else value

    def process(self, data: PipelineData) -> PipelineData:
        """Draw every detection box of the current result onto a copy of the frame.

        Returns ``data`` untouched when there is no frame or no current result.
        Otherwise ``data.frame`` is replaced by the annotated copy.
        """
        if data.frame is None or data.current_result is None:
            return data

        vis = data.frame.copy()
        boxes = getattr(data.current_result, "boxes", None)
        if boxes is None:
            data.frame = vis
            return data

        for box in boxes:
            try:
                self._draw_single_box(vis, box)
            except Exception as e:
                # Best effort: one malformed box must not abort the whole frame.
                self.logger.warning(f"绘制检测框失败: {e}")
                continue

        data.frame = vis
        return data

    def _draw_single_box(self, vis: np.ndarray, box):
        """Draw one detection box and its label onto ``vis`` in place.

        Boxes without ``xyxy`` or whose class is not configured are skipped.
        Missing or ``None`` values for ``cls``, ``conf`` and ``id`` fall back
        to -1, 0.0 and -1 respectively.
        """
        xyxy = getattr(box, "xyxy", None)
        if xyxy is None:
            return

        # Filter on the class first so boxes that are not drawn cost nothing.
        cls = int(self._as_scalar(getattr(box, "cls", None), -1))
        if cls not in self._vehicle_class_ids:
            return

        x1, y1, x2, y2 = map(int, map(float, xyxy[0].tolist()))
        conf = float(self._as_scalar(getattr(box, "conf", None), 0.0))
        # The id attribute may exist but be None (no tracking); treat as "no id".
        track_id = int(self._as_scalar(getattr(box, "id", None), -1))

        # The fallback is only reachable if the two lookup tables are changed
        # independently after construction; it yields a yellow (B, G, R) box.
        class_name, color = self._class_id_to_info.get(
            cls, ("unknown", (0, 255, 255))
        )

        cv2.rectangle(vis, (x1, y1), (x2, y2), color, self.box_thickness)

        # Assemble the label from the enabled parts.
        text_parts = []
        if self.show_id and track_id != -1:
            text_parts.append(f"ID:{track_id}")
        if self.show_label:
            text_parts.append(class_name)
        if self.show_conf:
            text_parts.append(f"{conf:.2f}")
        if not text_parts:
            return

        text = " ".join(text_parts)
        text_size = cv2.getTextSize(
            text, cv2.FONT_HERSHEY_SIMPLEX, self.font_scale, 1
        )[0]

        # Keep the label anchor away from the top and left image borders.
        text_y = max(y1 - 10, 20)
        text_x = max(x1, 10)

        # Filled background in the box colour, with a small padding.
        bg_top_left = (text_x - 2, text_y - text_size[1] - 4)
        bg_bottom_right = (text_x + text_size[0] + 2, text_y + 2)
        cv2.rectangle(vis, bg_top_left, bg_bottom_right, color, -1)

        # White, anti-aliased label text.
        cv2.putText(
            vis,
            text,
            (text_x, text_y - 2),
            cv2.FONT_HERSHEY_SIMPLEX,
            self.font_scale,
            (255, 255, 255),
            thickness=1,
            lineType=cv2.LINE_AA,
        )

    @classmethod
    def create_custom_classes(cls, class_configs: List[dict]) -> Set[DetectionClass]:
        """Build a set of ``DetectionClass`` objects from plain dict configs.

        Args:
            class_configs: One dict per class with the keys ``class_id``,
                ``class_name`` and ``color``, for example::

                    [
                        {"class_id": 2, "class_name": "sedan", "color": (0, 255, 0)},
                        {"class_id": 3, "class_name": "motorbike", "color": (255, 0, 0)},
                    ]

        Returns:
            The resulting set. Entries sharing a ``class_id`` collapse to the
            first one encountered.

        Raises:
            ValueError: If a required key is missing or an entry cannot be
                converted; the original exception is chained as the cause.
        """
        custom_classes = set()
        for config in class_configs:
            try:
                detection_class = DetectionClass(
                    class_id=config["class_id"],
                    class_name=config["class_name"],
                    color=tuple(config["color"]),
                )
                custom_classes.add(detection_class)
            except KeyError as e:
                raise ValueError(f"类别配置缺少必要字段: {e}") from e
            except Exception as e:
                raise ValueError(f"创建自定义类别失败: {e}") from e
        return custom_classes