feat(pipeline): 添加逆行处理器
- 新增 DrawDirectionProcessor 用于绘制车辆行驶方向和轨迹 - 新增 DrawGraffitiProcessor 用于绘制热力图和网格方向箭头 - 新增 DrawObjectBoxProcessor 用于绘制检测框、ID、标签和置信度 - 新增 GraffitiVisualizer 用于涂鸦可视化处理 - 新增 GraffitiProcessor 用于计算车辆轨迹并更新车道涂鸦
This commit is contained in:
@@ -0,0 +1,184 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Set
|
||||
from yolo_gs.pipeline.base_processor import BaseProcessor
|
||||
from yolo_gs.pipeline.pipeline_data import PipelineData
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectionClass:
|
||||
"""检测类别数据类"""
|
||||
class_id: int # 类别ID(对应COCO数据集)
|
||||
class_name: str # 类别名称
|
||||
color: tuple # 绘制颜色 (B, G, R)
|
||||
|
||||
def __hash__(self):
|
||||
"""用于支持集合操作"""
|
||||
return hash(self.class_id)
|
||||
|
||||
def __eq__(self, other):
|
||||
"""用于支持集合操作"""
|
||||
if not isinstance(other, DetectionClass):
|
||||
return False
|
||||
return self.class_id == other.class_id
|
||||
|
||||
|
||||
class DrawObjectBoxProcessor(BaseProcessor):
|
||||
"""对象框绘制处理器:绘制检测框、ID、标签、置信度"""
|
||||
|
||||
def __init__(self,
|
||||
name: str = "对象框绘制处理器",
|
||||
show_id: bool = True,
|
||||
show_label: bool = True,
|
||||
show_conf: bool = False, # 默认不显示置信度
|
||||
box_thickness: int = 2,
|
||||
font_scale: float = 0.5,
|
||||
vehicle_classes: Optional[Set[DetectionClass]] = None):
|
||||
super().__init__(name)
|
||||
self.show_id = show_id
|
||||
self.show_label = show_label
|
||||
self.show_conf = show_conf
|
||||
self.box_thickness = box_thickness
|
||||
self.font_scale = font_scale
|
||||
|
||||
# 初始化车辆类别集合(使用预定义或用户自定义)
|
||||
self.vehicle_classes: Set[DetectionClass] = vehicle_classes or [
|
||||
# coco 数据集 预设类
|
||||
DetectionClass(class_id=0, class_name="person", color=(255, 0, 0)), # 蓝色
|
||||
DetectionClass(class_id=1, class_name="bicycle", color=(60, 100, 180)), # 深蓝
|
||||
DetectionClass(class_id=2, class_name="car", color=(180, 160, 40)), # 深黄
|
||||
DetectionClass(class_id=3, class_name="motorcycle", color=(140, 80, 140)), # 深紫
|
||||
DetectionClass(class_id=5, class_name="bus", color=(40, 160, 160)), # 深青
|
||||
DetectionClass(class_id=7, class_name="truck", color=(180, 120, 40)), # 深橙
|
||||
DetectionClass(class_id=8, class_name="boat", color=(60, 120, 60)), # 深绿
|
||||
]
|
||||
|
||||
# 构建快速查询映射(提高查找效率)
|
||||
self._class_id_to_info = {
|
||||
cls.class_id: (cls.class_name, cls.color)
|
||||
for cls in self.vehicle_classes
|
||||
}
|
||||
# 车辆类别ID集合(用于快速判断)
|
||||
self._vehicle_class_ids = {cls.class_id for cls in self.vehicle_classes}
|
||||
|
||||
def process(self, data: PipelineData) -> PipelineData:
|
||||
"""绘制检测框和相关信息"""
|
||||
if data.frame is None or data.current_result is None:
|
||||
return data
|
||||
|
||||
vis = data.frame.copy()
|
||||
boxes = getattr(data.current_result, "boxes", None)
|
||||
|
||||
if boxes is None:
|
||||
data.frame = vis
|
||||
return data
|
||||
|
||||
# 绘制每个检测框
|
||||
for box in boxes:
|
||||
try:
|
||||
self._draw_single_box(vis, box)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"绘制检测框失败: {e}")
|
||||
continue
|
||||
|
||||
data.frame = vis
|
||||
return data
|
||||
|
||||
def _draw_single_box(self, vis: np.ndarray, box):
|
||||
"""绘制单个检测框"""
|
||||
# 获取坐标
|
||||
xyxy = getattr(box, "xyxy", None)
|
||||
if xyxy is None:
|
||||
return
|
||||
|
||||
x1, y1, x2, y2 = map(int, map(float, xyxy[0].tolist()))
|
||||
|
||||
# 获取类别和置信度
|
||||
cls = int(getattr(box, "cls", -1))
|
||||
conf = float(getattr(box, "conf", 0.0))
|
||||
|
||||
# 获取track_id
|
||||
track_id = int(getattr(box, "id", -1))
|
||||
|
||||
# 检查是否是需要绘制的车辆类别
|
||||
if cls not in self._vehicle_class_ids:
|
||||
return
|
||||
|
||||
# 获取类别名称和颜色
|
||||
class_name, color = self._class_id_to_info.get(
|
||||
cls, ("unknown", (0, 255, 255)) # 默认白色文字,黄色框
|
||||
)
|
||||
|
||||
# 绘制检测框
|
||||
cv2.rectangle(vis, (x1, y1), (x2, y2), color, self.box_thickness)
|
||||
|
||||
# 构建显示文本
|
||||
text_parts = []
|
||||
|
||||
if self.show_id and track_id != -1:
|
||||
text_parts.append(f"ID:{track_id}")
|
||||
|
||||
if self.show_label:
|
||||
text_parts.append(class_name)
|
||||
|
||||
if self.show_conf:
|
||||
text_parts.append(f"{conf:.2f}")
|
||||
|
||||
if text_parts:
|
||||
text = " ".join(text_parts)
|
||||
|
||||
# 计算文本大小
|
||||
text_size = cv2.getTextSize(
|
||||
text, cv2.FONT_HERSHEY_SIMPLEX, self.font_scale, 1
|
||||
)[0]
|
||||
|
||||
# 确保文本位置在图像内
|
||||
text_y = max(y1 - 10, 20) # 至少离顶部20像素
|
||||
text_x = max(x1, 10) # 至少离左侧10像素
|
||||
|
||||
# 绘制文本背景(带内边距)
|
||||
bg_top_left = (text_x - 2, text_y - text_size[1] - 4)
|
||||
bg_bottom_right = (text_x + text_size[0] + 2, text_y + 2)
|
||||
cv2.rectangle(vis, bg_top_left, bg_bottom_right, color, -1)
|
||||
|
||||
# 绘制文本(白色)
|
||||
cv2.putText(
|
||||
vis, text,
|
||||
(text_x, text_y - 2),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
self.font_scale,
|
||||
(255, 255, 255), # 白色文字
|
||||
thickness=1,
|
||||
lineType=cv2.LINE_AA # 抗锯齿
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_custom_classes(cls, class_configs: List[dict]) -> Set[DetectionClass]:
|
||||
"""
|
||||
工厂方法:创建自定义的类别集合
|
||||
|
||||
Args:
|
||||
class_configs: 类别配置列表,每个元素包含 class_id, class_name, color
|
||||
示例: [
|
||||
{"class_id": 2, "class_name": "轿车", "color": (0, 255, 0)},
|
||||
{"class_id": 3, "class_name": "摩托车", "color": (255, 0, 0)}
|
||||
]
|
||||
|
||||
Returns:
|
||||
自定义的DetectionClass集合
|
||||
"""
|
||||
custom_classes = set()
|
||||
for config in class_configs:
|
||||
try:
|
||||
detection_class = DetectionClass(
|
||||
class_id=config["class_id"],
|
||||
class_name=config["class_name"],
|
||||
color=tuple(config["color"])
|
||||
)
|
||||
custom_classes.add(detection_class)
|
||||
except KeyError as e:
|
||||
raise ValueError(f"类别配置缺少必要字段: {e}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"创建自定义类别失败: {e}")
|
||||
return custom_classes
|
||||
Reference in New Issue
Block a user