23555e0cc9
- 新增 DrawDirectionProcessor 用于绘制车辆行驶方向和轨迹 - 新增 DrawGraffitiProcessor 用于绘制热力图和网格方向箭头 - 新增 DrawObjectBoxProcessor 用于绘制检测框、ID、标签和置信度 - 新增 GraffitiVisualizer 用于涂鸦可视化处理 - 新增 GraffitiProcessor 用于计算车辆轨迹并更新车道涂鸦
185 lines
6.5 KiB
Python
185 lines
6.5 KiB
Python
import cv2
|
|
import numpy as np
|
|
from dataclasses import dataclass, field
|
|
from typing import List, Optional, Set
|
|
from yolo_gs.pipeline.base_processor import BaseProcessor
|
|
from yolo_gs.pipeline.pipeline_data import PipelineData
|
|
|
|
|
|
@dataclass
|
|
class DetectionClass:
|
|
"""检测类别数据类"""
|
|
class_id: int # 类别ID(对应COCO数据集)
|
|
class_name: str # 类别名称
|
|
color: tuple # 绘制颜色 (B, G, R)
|
|
|
|
def __hash__(self):
|
|
"""用于支持集合操作"""
|
|
return hash(self.class_id)
|
|
|
|
def __eq__(self, other):
|
|
"""用于支持集合操作"""
|
|
if not isinstance(other, DetectionClass):
|
|
return False
|
|
return self.class_id == other.class_id
|
|
|
|
|
|
class DrawObjectBoxProcessor(BaseProcessor):
|
|
"""对象框绘制处理器:绘制检测框、ID、标签、置信度"""
|
|
|
|
def __init__(self,
|
|
name: str = "对象框绘制处理器",
|
|
show_id: bool = True,
|
|
show_label: bool = True,
|
|
show_conf: bool = False, # 默认不显示置信度
|
|
box_thickness: int = 2,
|
|
font_scale: float = 0.5,
|
|
vehicle_classes: Optional[Set[DetectionClass]] = None):
|
|
super().__init__(name)
|
|
self.show_id = show_id
|
|
self.show_label = show_label
|
|
self.show_conf = show_conf
|
|
self.box_thickness = box_thickness
|
|
self.font_scale = font_scale
|
|
|
|
# 初始化车辆类别集合(使用预定义或用户自定义)
|
|
self.vehicle_classes: Set[DetectionClass] = vehicle_classes or [
|
|
# coco 数据集 预设类
|
|
DetectionClass(class_id=0, class_name="person", color=(255, 0, 0)), # 蓝色
|
|
DetectionClass(class_id=1, class_name="bicycle", color=(60, 100, 180)), # 深蓝
|
|
DetectionClass(class_id=2, class_name="car", color=(180, 160, 40)), # 深黄
|
|
DetectionClass(class_id=3, class_name="motorcycle", color=(140, 80, 140)), # 深紫
|
|
DetectionClass(class_id=5, class_name="bus", color=(40, 160, 160)), # 深青
|
|
DetectionClass(class_id=7, class_name="truck", color=(180, 120, 40)), # 深橙
|
|
DetectionClass(class_id=8, class_name="boat", color=(60, 120, 60)), # 深绿
|
|
]
|
|
|
|
# 构建快速查询映射(提高查找效率)
|
|
self._class_id_to_info = {
|
|
cls.class_id: (cls.class_name, cls.color)
|
|
for cls in self.vehicle_classes
|
|
}
|
|
# 车辆类别ID集合(用于快速判断)
|
|
self._vehicle_class_ids = {cls.class_id for cls in self.vehicle_classes}
|
|
|
|
def process(self, data: PipelineData) -> PipelineData:
|
|
"""绘制检测框和相关信息"""
|
|
if data.frame is None or data.current_result is None:
|
|
return data
|
|
|
|
vis = data.frame.copy()
|
|
boxes = getattr(data.current_result, "boxes", None)
|
|
|
|
if boxes is None:
|
|
data.frame = vis
|
|
return data
|
|
|
|
# 绘制每个检测框
|
|
for box in boxes:
|
|
try:
|
|
self._draw_single_box(vis, box)
|
|
except Exception as e:
|
|
self.logger.warning(f"绘制检测框失败: {e}")
|
|
continue
|
|
|
|
data.frame = vis
|
|
return data
|
|
|
|
def _draw_single_box(self, vis: np.ndarray, box):
|
|
"""绘制单个检测框"""
|
|
# 获取坐标
|
|
xyxy = getattr(box, "xyxy", None)
|
|
if xyxy is None:
|
|
return
|
|
|
|
x1, y1, x2, y2 = map(int, map(float, xyxy[0].tolist()))
|
|
|
|
# 获取类别和置信度
|
|
cls = int(getattr(box, "cls", -1))
|
|
conf = float(getattr(box, "conf", 0.0))
|
|
|
|
# 获取track_id
|
|
track_id = int(getattr(box, "id", -1))
|
|
|
|
# 检查是否是需要绘制的车辆类别
|
|
if cls not in self._vehicle_class_ids:
|
|
return
|
|
|
|
# 获取类别名称和颜色
|
|
class_name, color = self._class_id_to_info.get(
|
|
cls, ("unknown", (0, 255, 255)) # 默认白色文字,黄色框
|
|
)
|
|
|
|
# 绘制检测框
|
|
cv2.rectangle(vis, (x1, y1), (x2, y2), color, self.box_thickness)
|
|
|
|
# 构建显示文本
|
|
text_parts = []
|
|
|
|
if self.show_id and track_id != -1:
|
|
text_parts.append(f"ID:{track_id}")
|
|
|
|
if self.show_label:
|
|
text_parts.append(class_name)
|
|
|
|
if self.show_conf:
|
|
text_parts.append(f"{conf:.2f}")
|
|
|
|
if text_parts:
|
|
text = " ".join(text_parts)
|
|
|
|
# 计算文本大小
|
|
text_size = cv2.getTextSize(
|
|
text, cv2.FONT_HERSHEY_SIMPLEX, self.font_scale, 1
|
|
)[0]
|
|
|
|
# 确保文本位置在图像内
|
|
text_y = max(y1 - 10, 20) # 至少离顶部20像素
|
|
text_x = max(x1, 10) # 至少离左侧10像素
|
|
|
|
# 绘制文本背景(带内边距)
|
|
bg_top_left = (text_x - 2, text_y - text_size[1] - 4)
|
|
bg_bottom_right = (text_x + text_size[0] + 2, text_y + 2)
|
|
cv2.rectangle(vis, bg_top_left, bg_bottom_right, color, -1)
|
|
|
|
# 绘制文本(白色)
|
|
cv2.putText(
|
|
vis, text,
|
|
(text_x, text_y - 2),
|
|
cv2.FONT_HERSHEY_SIMPLEX,
|
|
self.font_scale,
|
|
(255, 255, 255), # 白色文字
|
|
thickness=1,
|
|
lineType=cv2.LINE_AA # 抗锯齿
|
|
)
|
|
|
|
@classmethod
|
|
def create_custom_classes(cls, class_configs: List[dict]) -> Set[DetectionClass]:
|
|
"""
|
|
工厂方法:创建自定义的类别集合
|
|
|
|
Args:
|
|
class_configs: 类别配置列表,每个元素包含 class_id, class_name, color
|
|
示例: [
|
|
{"class_id": 2, "class_name": "轿车", "color": (0, 255, 0)},
|
|
{"class_id": 3, "class_name": "摩托车", "color": (255, 0, 0)}
|
|
]
|
|
|
|
Returns:
|
|
自定义的DetectionClass集合
|
|
"""
|
|
custom_classes = set()
|
|
for config in class_configs:
|
|
try:
|
|
detection_class = DetectionClass(
|
|
class_id=config["class_id"],
|
|
class_name=config["class_name"],
|
|
color=tuple(config["color"])
|
|
)
|
|
custom_classes.add(detection_class)
|
|
except KeyError as e:
|
|
raise ValueError(f"类别配置缺少必要字段: {e}")
|
|
except Exception as e:
|
|
raise ValueError(f"创建自定义类别失败: {e}")
|
|
return custom_classes
|