Files
HighwayEventDet/pipeline/handler/DrawObjectBoxProcessor.py
T
zimoyin 23555e0cc9 feat(pipeline): 添加逆行处理器
- 新增 DrawDirectionProcessor 用于绘制车辆行驶方向和轨迹
- 新增 DrawGraffitiProcessor 用于绘制热力图和网格方向箭头
- 新增 DrawObjectBoxProcessor 用于绘制检测框、ID、标签和置信度
- 新增 GraffitiVisualizer 用于涂鸦可视化处理
- 新增 GraffitiProcessor 用于计算车辆轨迹并更新车道涂鸦
2026-01-10 09:41:18 +08:00

185 lines
6.5 KiB
Python

import cv2
import numpy as np
from dataclasses import dataclass, field
from typing import List, Optional, Set
from yolo_gs.pipeline.base_processor import BaseProcessor
from yolo_gs.pipeline.pipeline_data import PipelineData
@dataclass
class DetectionClass:
"""检测类别数据类"""
class_id: int # 类别ID(对应COCO数据集)
class_name: str # 类别名称
color: tuple # 绘制颜色 (B, G, R)
def __hash__(self):
"""用于支持集合操作"""
return hash(self.class_id)
def __eq__(self, other):
"""用于支持集合操作"""
if not isinstance(other, DetectionClass):
return False
return self.class_id == other.class_id
class DrawObjectBoxProcessor(BaseProcessor):
"""对象框绘制处理器:绘制检测框、ID、标签、置信度"""
def __init__(self,
name: str = "对象框绘制处理器",
show_id: bool = True,
show_label: bool = True,
show_conf: bool = False, # 默认不显示置信度
box_thickness: int = 2,
font_scale: float = 0.5,
vehicle_classes: Optional[Set[DetectionClass]] = None):
super().__init__(name)
self.show_id = show_id
self.show_label = show_label
self.show_conf = show_conf
self.box_thickness = box_thickness
self.font_scale = font_scale
# 初始化车辆类别集合(使用预定义或用户自定义)
self.vehicle_classes: Set[DetectionClass] = vehicle_classes or [
# coco 数据集 预设类
DetectionClass(class_id=0, class_name="person", color=(255, 0, 0)), # 蓝色
DetectionClass(class_id=1, class_name="bicycle", color=(60, 100, 180)), # 深蓝
DetectionClass(class_id=2, class_name="car", color=(180, 160, 40)), # 深黄
DetectionClass(class_id=3, class_name="motorcycle", color=(140, 80, 140)), # 深紫
DetectionClass(class_id=5, class_name="bus", color=(40, 160, 160)), # 深青
DetectionClass(class_id=7, class_name="truck", color=(180, 120, 40)), # 深橙
DetectionClass(class_id=8, class_name="boat", color=(60, 120, 60)), # 深绿
]
# 构建快速查询映射(提高查找效率)
self._class_id_to_info = {
cls.class_id: (cls.class_name, cls.color)
for cls in self.vehicle_classes
}
# 车辆类别ID集合(用于快速判断)
self._vehicle_class_ids = {cls.class_id for cls in self.vehicle_classes}
def process(self, data: PipelineData) -> PipelineData:
"""绘制检测框和相关信息"""
if data.frame is None or data.current_result is None:
return data
vis = data.frame.copy()
boxes = getattr(data.current_result, "boxes", None)
if boxes is None:
data.frame = vis
return data
# 绘制每个检测框
for box in boxes:
try:
self._draw_single_box(vis, box)
except Exception as e:
self.logger.warning(f"绘制检测框失败: {e}")
continue
data.frame = vis
return data
def _draw_single_box(self, vis: np.ndarray, box):
"""绘制单个检测框"""
# 获取坐标
xyxy = getattr(box, "xyxy", None)
if xyxy is None:
return
x1, y1, x2, y2 = map(int, map(float, xyxy[0].tolist()))
# 获取类别和置信度
cls = int(getattr(box, "cls", -1))
conf = float(getattr(box, "conf", 0.0))
# 获取track_id
track_id = int(getattr(box, "id", -1))
# 检查是否是需要绘制的车辆类别
if cls not in self._vehicle_class_ids:
return
# 获取类别名称和颜色
class_name, color = self._class_id_to_info.get(
cls, ("unknown", (0, 255, 255)) # 默认白色文字,黄色框
)
# 绘制检测框
cv2.rectangle(vis, (x1, y1), (x2, y2), color, self.box_thickness)
# 构建显示文本
text_parts = []
if self.show_id and track_id != -1:
text_parts.append(f"ID:{track_id}")
if self.show_label:
text_parts.append(class_name)
if self.show_conf:
text_parts.append(f"{conf:.2f}")
if text_parts:
text = " ".join(text_parts)
# 计算文本大小
text_size = cv2.getTextSize(
text, cv2.FONT_HERSHEY_SIMPLEX, self.font_scale, 1
)[0]
# 确保文本位置在图像内
text_y = max(y1 - 10, 20) # 至少离顶部20像素
text_x = max(x1, 10) # 至少离左侧10像素
# 绘制文本背景(带内边距)
bg_top_left = (text_x - 2, text_y - text_size[1] - 4)
bg_bottom_right = (text_x + text_size[0] + 2, text_y + 2)
cv2.rectangle(vis, bg_top_left, bg_bottom_right, color, -1)
# 绘制文本(白色)
cv2.putText(
vis, text,
(text_x, text_y - 2),
cv2.FONT_HERSHEY_SIMPLEX,
self.font_scale,
(255, 255, 255), # 白色文字
thickness=1,
lineType=cv2.LINE_AA # 抗锯齿
)
@classmethod
def create_custom_classes(cls, class_configs: List[dict]) -> Set[DetectionClass]:
"""
工厂方法:创建自定义的类别集合
Args:
class_configs: 类别配置列表,每个元素包含 class_id, class_name, color
示例: [
{"class_id": 2, "class_name": "轿车", "color": (0, 255, 0)},
{"class_id": 3, "class_name": "摩托车", "color": (255, 0, 0)}
]
Returns:
自定义的DetectionClass集合
"""
custom_classes = set()
for config in class_configs:
try:
detection_class = DetectionClass(
class_id=config["class_id"],
class_name=config["class_name"],
color=tuple(config["color"])
)
custom_classes.add(detection_class)
except KeyError as e:
raise ValueError(f"类别配置缺少必要字段: {e}")
except Exception as e:
raise ValueError(f"创建自定义类别失败: {e}")
return custom_classes