feat(pipeline): 添加逆行处理器
- 新增 DrawDirectionProcessor 用于绘制车辆行驶方向和轨迹 - 新增 DrawGraffitiProcessor 用于绘制热力图和网格方向箭头 - 新增 DrawObjectBoxProcessor 用于绘制检测框、ID、标签和置信度 - 新增 GraffitiVisualizer 用于涂鸦可视化处理 - 新增 GraffitiProcessor 用于计算车辆轨迹并更新车道涂鸦
This commit is contained in:
@@ -0,0 +1,225 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import math
|
||||
from collections import deque
|
||||
from yolo_gs.pipeline.base_processor import BaseProcessor
|
||||
from yolo_gs.pipeline.pipeline_data import PipelineData
|
||||
|
||||
|
||||
class DrawDirectionProcessor(BaseProcessor):
|
||||
"""行驶方向绘制处理器:基于YOLO跟踪结果绘制车辆行驶方向"""
|
||||
|
||||
def __init__(self,
|
||||
name: str = "行驶方向绘制处理器",
|
||||
arrow_length: float = 20.0,
|
||||
arrow_thickness: int = 2,
|
||||
arrow_color: tuple = (0, 0, 255), # 红色箭头
|
||||
show_trajectory: bool = True,
|
||||
trajectory_length: int = 10, # 轨迹点数量
|
||||
trajectory_color: tuple = (255, 255, 255),
|
||||
trajectory_thickness: int = 1,
|
||||
min_points_for_direction: int = 2,
|
||||
min_move_distance: float = 3.0): # 最小移动距离,避免抖动
|
||||
super().__init__(name)
|
||||
self.arrow_length = arrow_length
|
||||
self.arrow_thickness = arrow_thickness
|
||||
self.arrow_color = arrow_color
|
||||
self.show_trajectory = show_trajectory
|
||||
self.trajectory_length = trajectory_length
|
||||
self.trajectory_color = trajectory_color
|
||||
self.trajectory_thickness = trajectory_thickness
|
||||
self.min_points_for_direction = min_points_for_direction
|
||||
self.min_move_distance = min_move_distance
|
||||
|
||||
# 存储每个track_id的轨迹历史 {track_id: deque([(x1, y1), (x2, y2), ...])}
|
||||
self.track_histories = {}
|
||||
|
||||
# 只存储车辆中心点,不存储完整轨迹
|
||||
self.center_histories = {}
|
||||
|
||||
# 清理超时的track(超过多少帧没出现就清理)
|
||||
self.max_track_lifetime = 30 # 30帧没出现就清理
|
||||
self.track_last_seen = {} # {track_id: last_frame_idx}
|
||||
|
||||
def process(self, data: PipelineData) -> PipelineData:
|
||||
"""基于YOLO跟踪结果绘制车辆行驶方向"""
|
||||
if data.frame is None or data.current_result is None:
|
||||
return data
|
||||
|
||||
vis = data.frame.copy()
|
||||
frame_idx = data.frame_idx
|
||||
|
||||
# 获取YOLO检测结果
|
||||
boxes = getattr(data.current_result, "boxes", None)
|
||||
if boxes is None:
|
||||
# 没有检测框,清理过期track
|
||||
self._cleanup_old_tracks(frame_idx)
|
||||
data.frame = vis
|
||||
return data
|
||||
|
||||
# 当前帧存在的track_id
|
||||
current_track_ids = set()
|
||||
|
||||
# 处理每个检测框
|
||||
for box in boxes:
|
||||
try:
|
||||
# 解析检测框
|
||||
xyxy = getattr(box, "xyxy", None)
|
||||
if xyxy is None:
|
||||
continue
|
||||
|
||||
# 获取坐标
|
||||
x1, y1, x2, y2 = map(float, xyxy[0].tolist())
|
||||
|
||||
# 获取track_id
|
||||
track_id = int(getattr(box, "id", -1))
|
||||
if track_id == -1:
|
||||
continue # 没有跟踪ID,跳过
|
||||
|
||||
current_track_ids.add(track_id)
|
||||
|
||||
# 计算中心点
|
||||
center_x = (x1 + x2) / 2.0
|
||||
center_y = (y1 + y2) / 2.0
|
||||
|
||||
# 更新轨迹历史
|
||||
if track_id not in self.center_histories:
|
||||
self.center_histories[track_id] = deque(maxlen=self.trajectory_length)
|
||||
|
||||
# 添加当前中心点到历史
|
||||
self.center_histories[track_id].append((center_x, center_y))
|
||||
|
||||
# 更新最后出现帧
|
||||
self.track_last_seen[track_id] = frame_idx
|
||||
|
||||
# 绘制该车辆的轨迹和方向
|
||||
self._draw_track_info(vis, track_id, (center_x, center_y))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing box: {e}")
|
||||
continue
|
||||
|
||||
# 清理长时间未出现的track
|
||||
self._cleanup_old_tracks(frame_idx)
|
||||
|
||||
data.frame = vis
|
||||
return data
|
||||
|
||||
def _draw_track_info(self, vis: np.ndarray, track_id: int, current_center: tuple):
|
||||
"""绘制单个车辆的轨迹和方向"""
|
||||
if track_id not in self.center_histories:
|
||||
return
|
||||
|
||||
history = self.center_histories[track_id]
|
||||
|
||||
# 需要至少2个点才能绘制方向
|
||||
if len(history) < self.min_points_for_direction:
|
||||
return
|
||||
|
||||
# 绘制轨迹
|
||||
if self.show_trajectory:
|
||||
self._draw_trajectory(vis, history)
|
||||
|
||||
# 绘制方向箭头
|
||||
self._draw_direction_arrow(vis, history)
|
||||
|
||||
# 绘制当前中心点
|
||||
cx, cy = current_center
|
||||
cv2.circle(vis, (int(cx), int(cy)), 3, (0, 255, 255), -1) # 黄色点
|
||||
|
||||
# 绘制track_id文本
|
||||
text = f"ID:{track_id}"
|
||||
cv2.putText(vis, text,
|
||||
(int(cx) + 10, int(cy) - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5,
|
||||
(255, 255, 255), 1)
|
||||
|
||||
def _draw_trajectory(self, vis: np.ndarray, history: deque):
|
||||
"""绘制轨迹线"""
|
||||
if len(history) < 2:
|
||||
return
|
||||
|
||||
# 将deque转换为列表以便索引
|
||||
points = list(history)
|
||||
|
||||
# 绘制轨迹线
|
||||
for i in range(1, len(points)):
|
||||
pt1 = points[i - 1]
|
||||
pt2 = points[i]
|
||||
|
||||
# 确保点是有效的
|
||||
if (isinstance(pt1, (tuple, list)) and len(pt1) >= 2 and
|
||||
isinstance(pt2, (tuple, list)) and len(pt2) >= 2):
|
||||
x1, y1 = int(pt1[0]), int(pt1[1])
|
||||
x2, y2 = int(pt2[0]), int(pt2[1])
|
||||
|
||||
# 绘制轨迹线
|
||||
cv2.line(vis, (x1, y1), (x2, y2),
|
||||
self.trajectory_color, self.trajectory_thickness)
|
||||
|
||||
def _draw_direction_arrow(self, vis: np.ndarray, history: deque):
|
||||
"""绘制方向箭头"""
|
||||
if len(history) < 2:
|
||||
return
|
||||
|
||||
# 获取最近的两个点计算方向
|
||||
points = list(history)
|
||||
current_point = points[-1]
|
||||
|
||||
# 使用最近的两个点计算方向(避免抖动)
|
||||
start_idx = max(0, len(points) - 3) # 使用最近3个点
|
||||
direction_points = points[start_idx:]
|
||||
|
||||
if len(direction_points) < 2:
|
||||
return
|
||||
|
||||
# 计算平均方向(首尾点之间的向量)
|
||||
start_point = direction_points[0]
|
||||
end_point = direction_points[-1]
|
||||
|
||||
dx = end_point[0] - start_point[0]
|
||||
dy = end_point[1] - start_point[1]
|
||||
|
||||
# 计算移动距离
|
||||
move_distance = math.hypot(dx, dy)
|
||||
|
||||
# 如果移动距离太小,不绘制箭头(避免抖动)
|
||||
if move_distance < self.min_move_distance:
|
||||
return
|
||||
|
||||
# 归一化方向向量
|
||||
if move_distance > 0:
|
||||
dx_norm = dx / move_distance
|
||||
dy_norm = dy / move_distance
|
||||
else:
|
||||
return
|
||||
|
||||
# 计算箭头起点(当前点)
|
||||
sx, sy = current_point
|
||||
|
||||
# 计算箭头终点
|
||||
ex = int(sx + dx_norm * self.arrow_length)
|
||||
ey = int(sy + dy_norm * self.arrow_length)
|
||||
|
||||
# 绘制方向箭头
|
||||
cv2.arrowedLine(vis,
|
||||
(int(sx), int(sy)),
|
||||
(ex, ey),
|
||||
self.arrow_color,
|
||||
self.arrow_thickness,
|
||||
tipLength=0.3)
|
||||
|
||||
def _cleanup_old_tracks(self, current_frame_idx: int):
|
||||
"""清理长时间未出现的track"""
|
||||
tracks_to_remove = []
|
||||
|
||||
for track_id, last_seen in list(self.track_last_seen.items()):
|
||||
# 如果超过最大生命周期没出现,则清理
|
||||
if current_frame_idx - last_seen > self.max_track_lifetime:
|
||||
tracks_to_remove.append(track_id)
|
||||
|
||||
for track_id in tracks_to_remove:
|
||||
if track_id in self.center_histories:
|
||||
del self.center_histories[track_id]
|
||||
if track_id in self.track_last_seen:
|
||||
del self.track_last_seen[track_id]
|
||||
@@ -0,0 +1,198 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import math
|
||||
from yolo_gs.pipeline.base_processor import BaseProcessor
|
||||
from yolo_gs.pipeline.pipeline_data import PipelineData
|
||||
|
||||
|
||||
class DrawGraffitiProcessor(BaseProcessor):
|
||||
"""涂鸦绘制处理器:绘制热力图和网格方向箭头"""
|
||||
|
||||
def __init__(self,
|
||||
name: str = "涂鸦绘制处理器",
|
||||
heatmap_opacity: float = 0.45,
|
||||
grid_step: int = 40,
|
||||
grid_heat_threshold: float = 0.6,
|
||||
grid_draw_every: int = 2,
|
||||
max_heat: float = 10.0):
|
||||
super().__init__(name)
|
||||
self.heatmap_opacity = heatmap_opacity
|
||||
self.grid_step = grid_step
|
||||
self.grid_heat_threshold = grid_heat_threshold
|
||||
self.grid_draw_every = grid_draw_every
|
||||
self.max_heat = max_heat
|
||||
|
||||
# 箭头参数
|
||||
self.arrow_length_min = 10.0
|
||||
self.grid_arrow_color = (0, 128, 255) # 橙色 - 网格箭头
|
||||
|
||||
# 帧计数器
|
||||
self.frame_idx = 0
|
||||
|
||||
def process(self, data: PipelineData) -> PipelineData:
|
||||
"""绘制热力图和网格方向箭头"""
|
||||
if data.frame is None:
|
||||
return data
|
||||
|
||||
self.frame_idx = data.frame_idx
|
||||
vis = data.frame.copy()
|
||||
|
||||
# 获取涂鸦数据
|
||||
heat_map = data.get_data("graffiti", "heat_map")
|
||||
dir_x_map = data.get_data("graffiti", "dir_x_map")
|
||||
dir_y_map = data.get_data("graffiti", "dir_y_map")
|
||||
|
||||
# 绘制热力图
|
||||
if heat_map is not None:
|
||||
vis = self._draw_heatmap(vis, heat_map)
|
||||
|
||||
# 绘制网格方向箭头
|
||||
if (dir_x_map is not None and dir_y_map is not None and
|
||||
self.frame_idx % self.grid_draw_every == 0):
|
||||
vis = self._draw_grid_arrows(vis, heat_map, dir_x_map, dir_y_map)
|
||||
|
||||
# TODO 绘制调试信息,这个应该是单独的处理器
|
||||
vis = self._draw_debug_info(vis)
|
||||
|
||||
data.frame = vis
|
||||
return data
|
||||
|
||||
def _draw_heatmap(self, vis, heat_map):
|
||||
"""绘制热力图叠加层"""
|
||||
# 归一化热力图
|
||||
heat_norm = np.clip(heat_map / self.max_heat, 0.0, 1.0)
|
||||
heat_u8 = (heat_norm * 255).astype(np.uint8)
|
||||
|
||||
# 创建红色通道叠加层
|
||||
red_overlay = np.zeros_like(vis)
|
||||
red_overlay[:, :, 2] = heat_u8 # 红色通道
|
||||
|
||||
# 叠加热力图
|
||||
vis = cv2.addWeighted(vis, 1.0, red_overlay, self.heatmap_opacity, 0)
|
||||
|
||||
return vis
|
||||
|
||||
def _draw_grid_arrows(self, vis, heat_map, dir_x_map, dir_y_map):
|
||||
"""在网格上绘制持久方向箭头"""
|
||||
h, w = heat_map.shape
|
||||
|
||||
for gy in range(0, h, self.grid_step):
|
||||
for gx in range(0, w, self.grid_step):
|
||||
x0, y0 = gx, gy
|
||||
x1 = min(gx + self.grid_step, w)
|
||||
y1 = min(gy + self.grid_step, h)
|
||||
|
||||
# 检查网格区域热力
|
||||
heat_window = heat_map[y0:y1, x0:x1]
|
||||
if heat_window.size == 0:
|
||||
continue
|
||||
|
||||
avg_heat = float(heat_window.mean())
|
||||
|
||||
# 如果平均热力达到阈值,绘制方向箭头
|
||||
if avg_heat / self.max_heat >= self.grid_heat_threshold:
|
||||
avg_dx = float(dir_x_map[y0:y1, x0:x1].sum())
|
||||
avg_dy = float(dir_y_map[y0:y1, x0:x1].sum())
|
||||
|
||||
if avg_dx == 0 and avg_dy == 0:
|
||||
continue
|
||||
|
||||
# 计算网格中心点
|
||||
cell_center = (x0 + (x1 - x0) // 2, y0 + (y1 - y0) // 2)
|
||||
|
||||
# 绘制网格边界(可选,用于调试)
|
||||
# cv2.rectangle(vis, (x0, y0), (x1, y1), (100, 100, 100), 1)
|
||||
|
||||
# 绘制方向箭头
|
||||
self._draw_arrow(vis, cell_center,
|
||||
(avg_dx * 5.0, avg_dy * 5.0))
|
||||
|
||||
return vis
|
||||
|
||||
def _draw_arrow(self, img, start, vec, color=None, thickness=2):
|
||||
"""绘制箭头"""
|
||||
if color is None:
|
||||
color = self.grid_arrow_color
|
||||
|
||||
sx, sy = int(round(start[0])), int(round(start[1]))
|
||||
vx, vy = float(vec[0]), float(vec[1])
|
||||
|
||||
norm = math.hypot(vx, vy)
|
||||
if norm < 1e-3:
|
||||
return
|
||||
|
||||
# 确保箭头有最小长度
|
||||
scale = 1.0
|
||||
if norm < self.arrow_length_min:
|
||||
scale = self.arrow_length_min / (norm + 1e-6)
|
||||
|
||||
ex = int(round(sx + vx * scale))
|
||||
ey = int(round(sy + vy * scale))
|
||||
|
||||
cv2.arrowedLine(img,
|
||||
(sx, sy),
|
||||
(ex, ey),
|
||||
color, thickness,
|
||||
tipLength=0.3)
|
||||
|
||||
def _draw_debug_info(self, vis):
|
||||
"""绘制调试信息"""
|
||||
# 绘制帧号
|
||||
cv2.putText(vis,
|
||||
f"Frame:{self.frame_idx}",
|
||||
(10, 20),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.6,
|
||||
(200, 200, 200), 2)
|
||||
|
||||
# 绘制热力图参数
|
||||
cv2.putText(vis,
|
||||
f"HeatMax:{self.max_heat} TrustThr:3.0",
|
||||
(10, 45),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5,
|
||||
(200, 200, 200), 1)
|
||||
|
||||
# 绘制网格参数
|
||||
cv2.putText(vis,
|
||||
f"Grid:{self.grid_step}px Thresh:{self.grid_heat_threshold}",
|
||||
(10, 65),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5,
|
||||
(200, 200, 200), 1)
|
||||
|
||||
# 绘制热力图图例
|
||||
legend_y = vis.shape[0] - 30
|
||||
legend_x = 10
|
||||
|
||||
# 图例标题
|
||||
cv2.putText(vis, "Heat Legend:",
|
||||
(legend_x, legend_y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5,
|
||||
(200, 200, 200), 1)
|
||||
|
||||
# 绘制颜色条
|
||||
colorbar_width = 100
|
||||
colorbar_height = 10
|
||||
colorbar_x = legend_x + 80
|
||||
colorbar_y = legend_y - 5
|
||||
|
||||
# 创建渐变颜色条
|
||||
for i in range(colorbar_width):
|
||||
intensity = i / colorbar_width
|
||||
color = (0, 0, int(intensity * 255))
|
||||
cv2.line(vis,
|
||||
(colorbar_x + i, colorbar_y),
|
||||
(colorbar_x + i, colorbar_y + colorbar_height),
|
||||
color, 1)
|
||||
|
||||
# 标注最小值
|
||||
cv2.putText(vis, "0",
|
||||
(colorbar_x - 10, colorbar_y + 15),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.4,
|
||||
(200, 200, 200), 1)
|
||||
|
||||
# 标注最大值
|
||||
cv2.putText(vis, f"{self.max_heat}",
|
||||
(colorbar_x + colorbar_width - 10, colorbar_y + 15),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.4,
|
||||
(200, 200, 200), 1)
|
||||
|
||||
return vis
|
||||
@@ -0,0 +1,184 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Set
|
||||
from yolo_gs.pipeline.base_processor import BaseProcessor
|
||||
from yolo_gs.pipeline.pipeline_data import PipelineData
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectionClass:
|
||||
"""检测类别数据类"""
|
||||
class_id: int # 类别ID(对应COCO数据集)
|
||||
class_name: str # 类别名称
|
||||
color: tuple # 绘制颜色 (B, G, R)
|
||||
|
||||
def __hash__(self):
|
||||
"""用于支持集合操作"""
|
||||
return hash(self.class_id)
|
||||
|
||||
def __eq__(self, other):
|
||||
"""用于支持集合操作"""
|
||||
if not isinstance(other, DetectionClass):
|
||||
return False
|
||||
return self.class_id == other.class_id
|
||||
|
||||
|
||||
class DrawObjectBoxProcessor(BaseProcessor):
|
||||
"""对象框绘制处理器:绘制检测框、ID、标签、置信度"""
|
||||
|
||||
def __init__(self,
|
||||
name: str = "对象框绘制处理器",
|
||||
show_id: bool = True,
|
||||
show_label: bool = True,
|
||||
show_conf: bool = False, # 默认不显示置信度
|
||||
box_thickness: int = 2,
|
||||
font_scale: float = 0.5,
|
||||
vehicle_classes: Optional[Set[DetectionClass]] = None):
|
||||
super().__init__(name)
|
||||
self.show_id = show_id
|
||||
self.show_label = show_label
|
||||
self.show_conf = show_conf
|
||||
self.box_thickness = box_thickness
|
||||
self.font_scale = font_scale
|
||||
|
||||
# 初始化车辆类别集合(使用预定义或用户自定义)
|
||||
self.vehicle_classes: Set[DetectionClass] = vehicle_classes or [
|
||||
# coco 数据集 预设类
|
||||
DetectionClass(class_id=0, class_name="person", color=(255, 0, 0)), # 蓝色
|
||||
DetectionClass(class_id=1, class_name="bicycle", color=(60, 100, 180)), # 深蓝
|
||||
DetectionClass(class_id=2, class_name="car", color=(180, 160, 40)), # 深黄
|
||||
DetectionClass(class_id=3, class_name="motorcycle", color=(140, 80, 140)), # 深紫
|
||||
DetectionClass(class_id=5, class_name="bus", color=(40, 160, 160)), # 深青
|
||||
DetectionClass(class_id=7, class_name="truck", color=(180, 120, 40)), # 深橙
|
||||
DetectionClass(class_id=8, class_name="boat", color=(60, 120, 60)), # 深绿
|
||||
]
|
||||
|
||||
# 构建快速查询映射(提高查找效率)
|
||||
self._class_id_to_info = {
|
||||
cls.class_id: (cls.class_name, cls.color)
|
||||
for cls in self.vehicle_classes
|
||||
}
|
||||
# 车辆类别ID集合(用于快速判断)
|
||||
self._vehicle_class_ids = {cls.class_id for cls in self.vehicle_classes}
|
||||
|
||||
def process(self, data: PipelineData) -> PipelineData:
|
||||
"""绘制检测框和相关信息"""
|
||||
if data.frame is None or data.current_result is None:
|
||||
return data
|
||||
|
||||
vis = data.frame.copy()
|
||||
boxes = getattr(data.current_result, "boxes", None)
|
||||
|
||||
if boxes is None:
|
||||
data.frame = vis
|
||||
return data
|
||||
|
||||
# 绘制每个检测框
|
||||
for box in boxes:
|
||||
try:
|
||||
self._draw_single_box(vis, box)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"绘制检测框失败: {e}")
|
||||
continue
|
||||
|
||||
data.frame = vis
|
||||
return data
|
||||
|
||||
def _draw_single_box(self, vis: np.ndarray, box):
|
||||
"""绘制单个检测框"""
|
||||
# 获取坐标
|
||||
xyxy = getattr(box, "xyxy", None)
|
||||
if xyxy is None:
|
||||
return
|
||||
|
||||
x1, y1, x2, y2 = map(int, map(float, xyxy[0].tolist()))
|
||||
|
||||
# 获取类别和置信度
|
||||
cls = int(getattr(box, "cls", -1))
|
||||
conf = float(getattr(box, "conf", 0.0))
|
||||
|
||||
# 获取track_id
|
||||
track_id = int(getattr(box, "id", -1))
|
||||
|
||||
# 检查是否是需要绘制的车辆类别
|
||||
if cls not in self._vehicle_class_ids:
|
||||
return
|
||||
|
||||
# 获取类别名称和颜色
|
||||
class_name, color = self._class_id_to_info.get(
|
||||
cls, ("unknown", (0, 255, 255)) # 默认白色文字,黄色框
|
||||
)
|
||||
|
||||
# 绘制检测框
|
||||
cv2.rectangle(vis, (x1, y1), (x2, y2), color, self.box_thickness)
|
||||
|
||||
# 构建显示文本
|
||||
text_parts = []
|
||||
|
||||
if self.show_id and track_id != -1:
|
||||
text_parts.append(f"ID:{track_id}")
|
||||
|
||||
if self.show_label:
|
||||
text_parts.append(class_name)
|
||||
|
||||
if self.show_conf:
|
||||
text_parts.append(f"{conf:.2f}")
|
||||
|
||||
if text_parts:
|
||||
text = " ".join(text_parts)
|
||||
|
||||
# 计算文本大小
|
||||
text_size = cv2.getTextSize(
|
||||
text, cv2.FONT_HERSHEY_SIMPLEX, self.font_scale, 1
|
||||
)[0]
|
||||
|
||||
# 确保文本位置在图像内
|
||||
text_y = max(y1 - 10, 20) # 至少离顶部20像素
|
||||
text_x = max(x1, 10) # 至少离左侧10像素
|
||||
|
||||
# 绘制文本背景(带内边距)
|
||||
bg_top_left = (text_x - 2, text_y - text_size[1] - 4)
|
||||
bg_bottom_right = (text_x + text_size[0] + 2, text_y + 2)
|
||||
cv2.rectangle(vis, bg_top_left, bg_bottom_right, color, -1)
|
||||
|
||||
# 绘制文本(白色)
|
||||
cv2.putText(
|
||||
vis, text,
|
||||
(text_x, text_y - 2),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
self.font_scale,
|
||||
(255, 255, 255), # 白色文字
|
||||
thickness=1,
|
||||
lineType=cv2.LINE_AA # 抗锯齿
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_custom_classes(cls, class_configs: List[dict]) -> Set[DetectionClass]:
|
||||
"""
|
||||
工厂方法:创建自定义的类别集合
|
||||
|
||||
Args:
|
||||
class_configs: 类别配置列表,每个元素包含 class_id, class_name, color
|
||||
示例: [
|
||||
{"class_id": 2, "class_name": "轿车", "color": (0, 255, 0)},
|
||||
{"class_id": 3, "class_name": "摩托车", "color": (255, 0, 0)}
|
||||
]
|
||||
|
||||
Returns:
|
||||
自定义的DetectionClass集合
|
||||
"""
|
||||
custom_classes = set()
|
||||
for config in class_configs:
|
||||
try:
|
||||
detection_class = DetectionClass(
|
||||
class_id=config["class_id"],
|
||||
class_name=config["class_name"],
|
||||
color=tuple(config["color"])
|
||||
)
|
||||
custom_classes.add(detection_class)
|
||||
except KeyError as e:
|
||||
raise ValueError(f"类别配置缺少必要字段: {e}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"创建自定义类别失败: {e}")
|
||||
return custom_classes
|
||||
@@ -0,0 +1,263 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
import math
|
||||
from collections import deque, defaultdict
|
||||
from yolo_gs.pipeline.base_processor import BaseProcessor
|
||||
from yolo_gs.pipeline.pipeline_data import PipelineData
|
||||
|
||||
|
||||
class GraffitiProcessor(BaseProcessor):
|
||||
"""涂鸦计算处理器:计算车辆轨迹并更新车道涂鸦"""
|
||||
|
||||
def __init__(self, name: str = "涂鸦计算处理器"):
|
||||
super().__init__(name)
|
||||
# 可调参数
|
||||
self.MAX_HEAT = 10.0
|
||||
self.INC_PER_PASS = 1.0
|
||||
self.DEC_PER_OPPOSITE = 1.0
|
||||
self.MIN_MOVE_DIST = 6.0
|
||||
self.MIN_CUM_DIST = 20.0
|
||||
self.HISTORY_LEN = 16
|
||||
self.HEAT_DECAY_PER_FRAME = 0.002
|
||||
self.VEHICLE_CLASSES = {2, 3, 5, 7}
|
||||
self.FULL_DECAY_EVERY = 30 # 每30帧进行一次全图衰减
|
||||
|
||||
# 初始化数据结构(在process中根据帧尺寸初始化)
|
||||
self.graffiti = None
|
||||
self.frame_shape = None
|
||||
self.track_hist = defaultdict(lambda: deque(maxlen=self.HISTORY_LEN))
|
||||
self.track_cum_dist = defaultdict(float)
|
||||
self.frame_idx = 0
|
||||
|
||||
def init_graffiti(self, h, w):
|
||||
"""初始化涂鸦数据结构"""
|
||||
if self.graffiti is None or self.frame_shape != (h, w):
|
||||
self.graffiti = GraffitiLane(w, h)
|
||||
self.frame_shape = (h, w)
|
||||
|
||||
def process(self, data: PipelineData) -> PipelineData:
|
||||
"""处理每一帧,更新涂鸦数据"""
|
||||
if data.current_result is None or data.frame is None:
|
||||
return data
|
||||
|
||||
frame = data.frame
|
||||
h, w = frame.shape[:2]
|
||||
self.init_graffiti(h, w)
|
||||
self.frame_idx = data.frame_idx
|
||||
|
||||
# 周期性对整图做完整衰减
|
||||
if self.frame_idx % self.FULL_DECAY_EVERY == 0:
|
||||
self.graffiti.full_decay(self.frame_idx)
|
||||
|
||||
# 获取检测框
|
||||
boxes = getattr(data.current_result, "boxes", None)
|
||||
if boxes is None:
|
||||
return data
|
||||
|
||||
# 处理每个检测框
|
||||
for box in boxes:
|
||||
try:
|
||||
# 解析检测框
|
||||
xyxy = getattr(box, "xyxy", None)
|
||||
if xyxy is None:
|
||||
continue
|
||||
|
||||
x1, y1, x2, y2 = map(float, xyxy[0].tolist())
|
||||
cls = int(getattr(box, "cls", -1))
|
||||
track_id = int(getattr(box, "id", -1))
|
||||
|
||||
# 仅处理车辆类
|
||||
if cls not in self.VEHICLE_CLASSES:
|
||||
continue
|
||||
|
||||
# 更新轨迹和历史
|
||||
self._update_track(track_id, x1, y1, x2, y2)
|
||||
|
||||
# 如果累计距离达到阈值,更新涂鸦
|
||||
if (self.track_cum_dist[track_id] >= self.MIN_CUM_DIST and
|
||||
len(self.track_hist[track_id]) >= 2):
|
||||
self._update_graffiti(track_id, x2 - x1)
|
||||
self.track_cum_dist[track_id] = 0.0
|
||||
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
# 清理长时间不活跃的track
|
||||
self._cleanup_old_tracks()
|
||||
|
||||
# 存储涂鸦数据到PipelineData(deque转换为list以便序列化)
|
||||
data.put_data("graffiti", "heat_map", self.graffiti.heat)
|
||||
data.put_data("graffiti", "dir_x_map", self.graffiti.dir_x)
|
||||
data.put_data("graffiti", "dir_y_map", self.graffiti.dir_y)
|
||||
data.put_data("graffiti", "track_hist", {
|
||||
track_id: list(hist) for track_id, hist in self.track_hist.items()
|
||||
})
|
||||
data.put_data("graffiti", "track_cum_dist", dict(self.track_cum_dist))
|
||||
|
||||
return data
|
||||
|
||||
def _update_track(self, track_id, x1, y1, x2, y2):
|
||||
"""更新车辆轨迹信息"""
|
||||
cx, cy = (x1 + x2) / 2.0, (y1 + y2) / 2.0
|
||||
prev_centers = self.track_hist[track_id]
|
||||
|
||||
if len(prev_centers) > 0:
|
||||
last_cx, last_cy = prev_centers[-1]
|
||||
move_dist = math.hypot(cx - last_cx, cy - last_cy)
|
||||
else:
|
||||
move_dist = 0.0
|
||||
|
||||
if len(prev_centers) > 0:
|
||||
self.track_cum_dist[track_id] += move_dist
|
||||
else:
|
||||
self.track_cum_dist[track_id] = 0.0
|
||||
|
||||
if move_dist >= self.MIN_MOVE_DIST:
|
||||
prev_centers.append((cx, cy))
|
||||
elif len(prev_centers) == 0:
|
||||
prev_centers.append((cx, cy))
|
||||
|
||||
def _update_graffiti(self, track_id, vehicle_width):
|
||||
"""根据车辆轨迹更新涂鸦"""
|
||||
hist = self.track_hist[track_id]
|
||||
if len(hist) < 2:
|
||||
return
|
||||
|
||||
cp_prev = hist[-2]
|
||||
cp_cur = hist[-1]
|
||||
|
||||
# 创建多边形掩码
|
||||
poly = self._polygon_from_segment(cp_prev, cp_cur, max(8.0, vehicle_width))
|
||||
if poly is None:
|
||||
return
|
||||
|
||||
# 计算ROI掩码
|
||||
bbox_roi, mask_roi = self._polygon_mask_roi_from_pts(poly, self.frame_shape)
|
||||
if mask_roi.size == 0:
|
||||
return
|
||||
|
||||
# 计算方向向量
|
||||
dir_vec = self._normalize_vec((cp_cur[0] - cp_prev[0], cp_cur[1] - cp_prev[1]))
|
||||
dir_vec_scaled = dir_vec.astype(np.float32) * 1.0
|
||||
|
||||
# 更新涂鸦
|
||||
self.graffiti.apply_paint_roi(
|
||||
bbox_roi, mask_roi, dir_vec_scaled,
|
||||
inc=self.INC_PER_PASS,
|
||||
current_frame=self.frame_idx
|
||||
)
|
||||
|
||||
def _cleanup_old_tracks(self):
|
||||
"""清理长时间不活跃的track,防止内存泄漏"""
|
||||
# 可以根据需要实现,比如超过一定帧数没有更新就清理
|
||||
# 这里先不实现,根据实际情况调整
|
||||
pass
|
||||
|
||||
def _polygon_from_segment(self, center_prev, center_cur, veh_width_px):
|
||||
"""从两点构造矩形带"""
|
||||
x0, y0 = center_prev
|
||||
x1, y1 = center_cur
|
||||
vx = x1 - x0
|
||||
vy = y1 - y0
|
||||
dist = math.hypot(vx, vy)
|
||||
if dist == 0:
|
||||
return None
|
||||
|
||||
ux, uy = vx / dist, vy / dist
|
||||
px, py = -uy, ux
|
||||
half_w = veh_width_px / 2.0
|
||||
|
||||
p0 = (int(round(x0 + px * half_w)), int(round(y0 + py * half_w)))
|
||||
p1 = (int(round(x0 - px * half_w)), int(round(y0 - py * half_w)))
|
||||
p2 = (int(round(x1 - px * half_w)), int(round(y1 - py * half_w)))
|
||||
p3 = (int(round(x1 + px * half_w)), int(round(y1 + py * half_w)))
|
||||
|
||||
return np.array([p0, p1, p2, p3], dtype=np.int32)
|
||||
|
||||
def _polygon_mask_roi_from_pts(self, pts, img_shape):
|
||||
"""在多边形最小包围矩形内生成局部掩码"""
|
||||
if pts is None or len(pts) == 0:
|
||||
return (0, 0, 0, 0), np.zeros((0, 0), dtype=np.uint8)
|
||||
|
||||
x, y, w, h = cv2.boundingRect(pts)
|
||||
x2 = min(x + w, img_shape[1])
|
||||
y2 = min(y + h, img_shape[0])
|
||||
w = x2 - x
|
||||
h = y2 - y
|
||||
|
||||
if w <= 0 or h <= 0:
|
||||
return (0, 0, 0, 0), np.zeros((0, 0), dtype=np.uint8)
|
||||
|
||||
pts_roi = pts.copy()
|
||||
pts_roi[:, 0] -= x
|
||||
pts_roi[:, 1] -= y
|
||||
mask = np.zeros((h, w), dtype=np.uint8)
|
||||
cv2.fillPoly(mask, [pts_roi], 1)
|
||||
|
||||
return (x, y, w, h), mask
|
||||
|
||||
def _normalize_vec(self, v):
|
||||
"""归一化向量"""
|
||||
v = np.array(v, dtype=np.float32)
|
||||
norm = np.linalg.norm(v)
|
||||
if norm == 0:
|
||||
return np.array([0.0, 0.0], dtype=np.float32)
|
||||
return v / norm
|
||||
|
||||
|
||||
class GraffitiLane:
|
||||
"""涂鸦车道数据结构"""
|
||||
|
||||
def __init__(self, frame_w, frame_h):
|
||||
self.w = frame_w
|
||||
self.h = frame_h
|
||||
self.heat = np.zeros((self.h, self.w), dtype=np.float32)
|
||||
self.dir_x = np.zeros((self.h, self.w), dtype=np.float32)
|
||||
self.dir_y = np.zeros((self.h, self.w), dtype=np.float32)
|
||||
self.last_decay_frame = 0
|
||||
|
||||
def _decay_factor(self, frames):
|
||||
if frames <= 0:
|
||||
return 1.0
|
||||
return (1.0 - 0.002) ** frames # HEAT_DECAY_PER_FRAME = 0.002
|
||||
|
||||
def lazy_decay_roi(self, bbox, current_frame):
|
||||
x, y, w, h = bbox
|
||||
if w <= 0 or h <= 0:
|
||||
return
|
||||
|
||||
frames = current_frame - self.last_decay_frame
|
||||
if frames <= 0:
|
||||
return
|
||||
|
||||
factor = self._decay_factor(frames)
|
||||
self.heat[y:y + h, x:x + w] *= factor
|
||||
self.dir_x[y:y + h, x:x + w] *= factor
|
||||
self.dir_y[y:y + h, x:x + w] *= factor
|
||||
|
||||
def full_decay(self, current_frame):
|
||||
"""在整图上强制应用pending衰减(用于周期性清理)"""
|
||||
frames = current_frame - self.last_decay_frame
|
||||
if frames <= 0:
|
||||
return
|
||||
|
||||
factor = self._decay_factor(frames)
|
||||
if factor == 1.0:
|
||||
self.last_decay_frame = current_frame
|
||||
return
|
||||
|
||||
# 对整图应用因子
|
||||
self.heat *= factor
|
||||
self.dir_x *= factor
|
||||
self.dir_y *= factor
|
||||
self.last_decay_frame = current_frame
|
||||
|
||||
def apply_paint_roi(self, bbox, mask_roi, direction_vector, inc=1.0, current_frame=0):
|
||||
self.lazy_decay_roi(bbox, current_frame)
|
||||
x, y, w, h = bbox
|
||||
add = inc * mask_roi.astype(np.float32)
|
||||
|
||||
self.heat[y:y + h, x:x + w] = np.minimum(10.0, self.heat[y:y + h, x:x + w] + add) # MAX_HEAT = 10.0
|
||||
dx, dy = direction_vector
|
||||
self.dir_x[y:y + h, x:x + w] += dx * add
|
||||
self.dir_y[y:y + h, x:x + w] += dy * add
|
||||
@@ -0,0 +1,10 @@
|
||||
from ..base_processor import BaseProcessor
|
||||
from ..pipeline_data import PipelineData
|
||||
import numpy as np
|
||||
|
||||
class ResultLogger(BaseProcessor):
|
||||
"""示例处理器:打印检测结果日志"""
|
||||
def process(self, data: PipelineData) -> PipelineData:
|
||||
print(f"\n【{self.name}】帧{data.frame_idx} - 检测到目标数: {len(data.current_result.boxes)}")
|
||||
print(f"缓存帧数: {len(data.result_cache)}")
|
||||
return data
|
||||
@@ -0,0 +1,244 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from yolo_gs.pipeline.base_processor import BaseProcessor
|
||||
from yolo_gs.pipeline.pipeline_data import PipelineData
|
||||
import math
|
||||
|
||||
|
||||
class RetrogradeProcessor(BaseProcessor):
|
||||
"""逆行检测处理器:检测逆向行驶的车辆"""
|
||||
|
||||
def __init__(self, name: str = "逆行检测处理器"):
|
||||
super().__init__(name)
|
||||
# 可调参数
|
||||
self.TRUST_THRESHOLD = 3.0
|
||||
self.PARTIAL_OVERLAP_THRESHOLD = 0.1
|
||||
self.MAJORITY_OVERLAP_THRESHOLD = 0.5
|
||||
self.DEC_PER_OPPOSITE = 1.0
|
||||
# self.OPPOSITE_ANGLE_COS_THRESH = -0.9397
|
||||
self.OPPOSITE_ANGLE_COS_THRESH = math.cos(math.radians(1))
|
||||
self.VEHICLE_CLASSES = {2, 3, 5, 7}
|
||||
|
||||
# 存储上一帧的事件信息
|
||||
self.track_last_event = {}
|
||||
self.frame_idx = 0
|
||||
|
||||
def process(self, data: PipelineData) -> PipelineData:
|
||||
"""检测逆行行为"""
|
||||
if data.current_result is None:
|
||||
return data
|
||||
|
||||
self.frame_idx = data.frame_idx
|
||||
|
||||
# 获取涂鸦数据
|
||||
heat_map = data.get_data("graffiti", "heat_map")
|
||||
dir_x_map = data.get_data("graffiti", "dir_x_map")
|
||||
dir_y_map = data.get_data("graffiti", "dir_y_map")
|
||||
track_hist = data.get_data("graffiti", "track_hist")
|
||||
|
||||
if heat_map is None or track_hist is None:
|
||||
return data
|
||||
|
||||
boxes = getattr(data.current_result, "boxes", None)
|
||||
if boxes is None:
|
||||
return data
|
||||
|
||||
frame_events = {}
|
||||
events_all0 = {}
|
||||
# 处理每个检测框
|
||||
for box in boxes:
|
||||
try:
|
||||
# 解析检测框
|
||||
xyxy = getattr(box, "xyxy", None)
|
||||
if xyxy is None:
|
||||
continue
|
||||
|
||||
x1, y1, x2, y2 = map(float, xyxy[0].tolist())
|
||||
cls = int(getattr(box, "cls", -1))
|
||||
track_id = int(getattr(box, "id", -1))
|
||||
|
||||
# 仅处理车辆类
|
||||
if cls not in self.VEHICLE_CLASSES:
|
||||
continue
|
||||
|
||||
# 获取车辆历史轨迹
|
||||
hist = track_hist.get(track_id, [])
|
||||
if len(hist) < 2:
|
||||
continue
|
||||
|
||||
# 计算当前方向
|
||||
cp_prev = hist[-2]
|
||||
cp_cur = hist[-1]
|
||||
dir_vec = self._normalize_vec((cp_cur[0] - cp_prev[0], cp_cur[1] - cp_prev[1]))
|
||||
|
||||
# 计算车辆经过的区域
|
||||
vehicle_width = max(8.0, x2 - x1)
|
||||
poly = self._polygon_from_segment(cp_prev, cp_cur, vehicle_width)
|
||||
if poly is None:
|
||||
continue
|
||||
|
||||
# 创建ROI掩码
|
||||
frame_shape = heat_map.shape[:2] if len(heat_map.shape) == 2 else heat_map.shape
|
||||
bbox_roi, mask_roi = self._polygon_mask_roi_from_pts(poly, frame_shape)
|
||||
if mask_roi.size == 0:
|
||||
continue
|
||||
|
||||
# 计算与现有涂鸦的重叠
|
||||
x, y, w_roi, h_roi = bbox_roi
|
||||
heat_roi = heat_map[y:y + h_roi, x:x + w_roi]
|
||||
overlap_pixels = np.logical_and(mask_roi > 0, heat_roi > 0)
|
||||
overlap_ratio = float(overlap_pixels.sum()) / max(1.0, mask_roi.sum())
|
||||
|
||||
# 计算区域平均热力和方向
|
||||
region_heat = self._region_avg_heat(heat_roi, mask_roi)
|
||||
region_dir = self._region_avg_direction(dir_x_map[y:y + h_roi, x:x + w_roi],
|
||||
dir_y_map[y:y + h_roi, x:x + w_roi],
|
||||
mask_roi)
|
||||
|
||||
# 判断事件类型
|
||||
event = self._detect_retrograde_event(
|
||||
region_heat, region_dir, dir_vec, overlap_ratio
|
||||
)
|
||||
events_all0[track_id] = event
|
||||
|
||||
if event:
|
||||
frame_events[track_id] = event
|
||||
# 减少涂鸦区域热力(逆向行驶会减少涂鸦可信度)
|
||||
self._reduce_paint(heat_map, dir_x_map, dir_y_map, bbox_roi, mask_roi, overlap_ratio)
|
||||
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
# 存储事件信息
|
||||
data.put_data("retrograde", "current_events", frame_events)
|
||||
data.put_data("retrograde", "all_events", self.track_last_event)
|
||||
|
||||
# 输出逆行信息
|
||||
for track_id, event in events_all0.items():
|
||||
print(f"【逆行检测器】帧{self.frame_idx} - 检测到事件:{event} - 轨迹ID: {track_id}")
|
||||
|
||||
# 更新事件历史
|
||||
for track_id, event in frame_events.items():
|
||||
self.track_last_event[track_id] = (self.frame_idx, event)
|
||||
|
||||
return data
|
||||
|
||||
def _detect_retrograde_event(self, region_heat, region_dir, current_dir, overlap_ratio):
|
||||
"""检测逆行事件类型"""
|
||||
if region_heat >= 1.0 and region_dir is not None:
|
||||
cos_sim = self._cos_between(region_dir, current_dir)
|
||||
|
||||
if cos_sim < self.OPPOSITE_ANGLE_COS_THRESH:
|
||||
if region_heat >= self.TRUST_THRESHOLD:
|
||||
if overlap_ratio >= self.MAJORITY_OVERLAP_THRESHOLD:
|
||||
return "REVERSE"
|
||||
elif overlap_ratio >= self.PARTIAL_OVERLAP_THRESHOLD:
|
||||
return "CROSS_LINE"
|
||||
else:
|
||||
return "SUSPECT"
|
||||
else:
|
||||
return "SUSPECT"
|
||||
else:
|
||||
return "NORMAL"
|
||||
|
||||
return None
|
||||
|
||||
def _reduce_paint(self, heat_map, dir_x_map, dir_y_map, bbox_roi, mask_roi, overlap_ratio):
|
||||
"""减少涂鸦区域热力(逆向行驶惩罚)"""
|
||||
x, y, w, h = bbox_roi
|
||||
amount = self.DEC_PER_OPPOSITE * overlap_ratio
|
||||
dec = amount * mask_roi.astype(np.float32)
|
||||
|
||||
# 减少热力
|
||||
heat_roi = heat_map[y:y + h, x:x + w]
|
||||
heat_roi_before = heat_roi.copy()
|
||||
heat_roi_after = np.maximum(0.0, heat_roi_before - dec)
|
||||
|
||||
# 按比例减少方向强度
|
||||
eps = 1e-6
|
||||
ratio = np.ones_like(heat_roi_before)
|
||||
nonzero_mask = heat_roi_before > eps
|
||||
ratio[nonzero_mask] = heat_roi_after[nonzero_mask] / (heat_roi_before[nonzero_mask] + eps)
|
||||
|
||||
dir_x_map[y:y + h, x:x + w] *= ratio
|
||||
dir_y_map[y:y + h, x:x + w] *= ratio
|
||||
heat_map[y:y + h, x:x + w] = heat_roi_after
|
||||
|
||||
def _polygon_from_segment(self, center_prev, center_cur, veh_width_px):
|
||||
"""从两点构造矩形带"""
|
||||
x0, y0 = center_prev
|
||||
x1, y1 = center_cur
|
||||
vx = x1 - x0
|
||||
vy = y1 - y0
|
||||
dist = math.hypot(vx, vy)
|
||||
if dist == 0:
|
||||
return None
|
||||
|
||||
ux, uy = vx / dist, vy / dist
|
||||
px, py = -uy, ux
|
||||
half_w = veh_width_px / 2.0
|
||||
|
||||
p0 = (int(round(x0 + px * half_w)), int(round(y0 + py * half_w)))
|
||||
p1 = (int(round(x0 - px * half_w)), int(round(y0 - py * half_w)))
|
||||
p2 = (int(round(x1 - px * half_w)), int(round(y1 - py * half_w)))
|
||||
p3 = (int(round(x1 + px * half_w)), int(round(y1 + py * half_w)))
|
||||
|
||||
return np.array([p0, p1, p2, p3], dtype=np.int32)
|
||||
|
||||
def _polygon_mask_roi_from_pts(self, pts, img_shape):
|
||||
"""在多边形最小包围矩形内生成局部掩码"""
|
||||
if pts is None or len(pts) == 0:
|
||||
return (0, 0, 0, 0), np.zeros((0, 0), dtype=np.uint8)
|
||||
|
||||
x, y, w, h = cv2.boundingRect(pts)
|
||||
x2 = min(x + w, img_shape[1])
|
||||
y2 = min(y + h, img_shape[0])
|
||||
w = x2 - x
|
||||
h = y2 - y
|
||||
|
||||
if w <= 0 or h <= 0:
|
||||
return (0, 0, 0, 0), np.zeros((0, 0), dtype=np.uint8)
|
||||
|
||||
pts_roi = pts.copy()
|
||||
pts_roi[:, 0] -= x
|
||||
pts_roi[:, 1] -= y
|
||||
mask = np.zeros((h, w), dtype=np.uint8)
|
||||
cv2.fillPoly(mask, [pts_roi], 1)
|
||||
|
||||
return (x, y, w, h), mask
|
||||
|
||||
def _region_avg_heat(self, heat_roi, mask_roi):
|
||||
"""计算区域平均热力"""
|
||||
mask = mask_roi.astype(bool)
|
||||
if mask.sum() == 0:
|
||||
return 0.0
|
||||
return float(heat_roi[mask].mean())
|
||||
|
||||
def _region_avg_direction(self, dir_x_roi, dir_y_roi, mask_roi):
|
||||
"""计算区域平均方向"""
|
||||
mask = mask_roi.astype(bool)
|
||||
if mask.sum() == 0:
|
||||
return None
|
||||
|
||||
sx = dir_x_roi[mask].sum()
|
||||
sy = dir_y_roi[mask].sum()
|
||||
vec = np.array([sx, sy], dtype=np.float32)
|
||||
norm = np.linalg.norm(vec)
|
||||
|
||||
if norm == 0:
|
||||
return None
|
||||
return vec / norm
|
||||
|
||||
def _normalize_vec(self, v):
|
||||
"""归一化向量"""
|
||||
v = np.array(v, dtype=np.float32)
|
||||
norm = np.linalg.norm(v)
|
||||
if norm == 0:
|
||||
return np.array([0.0, 0.0], dtype=np.float32)
|
||||
return v / norm
|
||||
|
||||
def _cos_between(self, a, b):
|
||||
"""计算两个向量的余弦相似度"""
|
||||
an = self._normalize_vec(a)
|
||||
bn = self._normalize_vec(b)
|
||||
return float(np.dot(an, bn))
|
||||
@@ -0,0 +1,10 @@
|
||||
from ..base_processor import BaseProcessor
|
||||
from ..pipeline import Pipeline
|
||||
from ..pipeline_data import PipelineData
|
||||
|
||||
__all__ = [
|
||||
'BaseProcessor',
|
||||
'Pipeline',
|
||||
'PipelineData',
|
||||
'ResultLogger',
|
||||
]
|
||||
@@ -0,0 +1,198 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
import math
|
||||
from yolo_gs.pipeline.base_processor import BaseProcessor
|
||||
from yolo_gs.pipeline.pipeline_data import PipelineData
|
||||
|
||||
|
||||
class GraffitiVisualizer(BaseProcessor):
|
||||
"""涂鸦可视化处理器:绘制热力图和方向箭头"""
|
||||
|
||||
def __init__(self, name: str = "涂鸦可视化处理器"):
|
||||
super().__init__(name)
|
||||
# 可调参数
|
||||
self.MAX_HEAT = 10.0
|
||||
self.ARROW_LEN_MIN = 10.0
|
||||
self.GRID_STEP = 40
|
||||
self.GRID_HEAT_THRESHOLD = 0.6
|
||||
self.GRID_DRAW_EVERY = 2
|
||||
self.FULL_DECAY_EVERY = 30
|
||||
|
||||
self.frame_idx = 0
|
||||
|
||||
def process(self, data: PipelineData) -> PipelineData:
|
||||
"""绘制涂鸦可视化"""
|
||||
if data.frame is None:
|
||||
return data
|
||||
|
||||
self.frame_idx = data.frame_idx
|
||||
vis = data.frame.copy()
|
||||
|
||||
# 获取涂鸦数据
|
||||
heat_map = data.get_data("graffiti", "heat_map")
|
||||
dir_x_map = data.get_data("graffiti", "dir_x_map")
|
||||
dir_y_map = data.get_data("graffiti", "dir_y_map")
|
||||
track_hist = data.get_data("graffiti", "track_hist")
|
||||
current_events = data.get_data("retrograde", "current_events", {})
|
||||
all_events = data.get_data("retrograde", "all_events", {})
|
||||
|
||||
# 绘制热力图
|
||||
if heat_map is not None:
|
||||
heat_norm = np.clip(heat_map / self.MAX_HEAT, 0.0, 1.0)
|
||||
heat_u8 = (heat_norm * 255).astype(np.uint8)
|
||||
red_overlay = np.zeros_like(vis)
|
||||
red_overlay[:, :, 2] = heat_u8
|
||||
alpha = 0.45
|
||||
vis = cv2.addWeighted(vis, 1.0, red_overlay, alpha, 0)
|
||||
|
||||
# 绘制网格方向箭头
|
||||
if (dir_x_map is not None and dir_y_map is not None and
|
||||
self.frame_idx % self.GRID_DRAW_EVERY == 0):
|
||||
self._draw_grid_arrows(vis, heat_map, dir_x_map, dir_y_map)
|
||||
|
||||
# 绘制车辆轨迹和事件
|
||||
if track_hist is not None:
|
||||
self._draw_tracks_and_events(vis, track_hist, current_events, all_events)
|
||||
|
||||
# 绘制检测框和方向
|
||||
if data.current_result is not None:
|
||||
self._draw_detections(vis, data.current_result, track_hist)
|
||||
|
||||
# 添加调试信息
|
||||
self._draw_debug_info(vis)
|
||||
|
||||
data.frame = vis
|
||||
return data
|
||||
|
||||
def _draw_grid_arrows(self, vis, heat_map, dir_x_map, dir_y_map):
|
||||
"""在网格上绘制方向箭头"""
|
||||
h, w = heat_map.shape
|
||||
for gy in range(0, h, self.GRID_STEP):
|
||||
for gx in range(0, w, self.GRID_STEP):
|
||||
x0, y0 = gx, gy
|
||||
x1 = min(gx + self.GRID_STEP, w)
|
||||
y1 = min(gy + self.GRID_STEP, h)
|
||||
|
||||
win = heat_map[y0:y1, x0:x1]
|
||||
if win.size == 0:
|
||||
continue
|
||||
|
||||
avg_heat = float(win.mean())
|
||||
if avg_heat / self.MAX_HEAT >= self.GRID_HEAT_THRESHOLD:
|
||||
avg_dx = float(dir_x_map[y0:y1, x0:x1].sum())
|
||||
avg_dy = float(dir_y_map[y0:y1, x0:x1].sum())
|
||||
|
||||
if avg_dx == 0 and avg_dy == 0:
|
||||
continue
|
||||
|
||||
cell_center = (x0 + (x1 - x0) // 2, y0 + (y1 - y0) // 2)
|
||||
self._draw_arrow(
|
||||
vis, cell_center,
|
||||
(avg_dx * 5.0, avg_dy * 5.0),
|
||||
color=(0, 128, 255), thickness=2
|
||||
)
|
||||
|
||||
def _draw_tracks_and_events(self, vis, track_hist, current_events, all_events):
|
||||
"""绘制车辆轨迹和事件标签"""
|
||||
for tid, hist in track_hist.items():
|
||||
if len(hist) == 0:
|
||||
continue
|
||||
|
||||
# 绘制轨迹线
|
||||
pts = np.array(hist, dtype=np.int32)
|
||||
for i in range(1, len(pts)):
|
||||
cv2.line(vis, tuple(pts[i - 1]), tuple(pts[i]), (255, 255, 255), 1)
|
||||
|
||||
# 绘制当前位置
|
||||
cx, cy = hist[-1]
|
||||
cv2.circle(vis, (int(cx), int(cy)), 3, (255, 255, 255), -1)
|
||||
|
||||
# 绘制当前事件
|
||||
if tid in current_events:
|
||||
event = current_events[tid]
|
||||
txt = f"id:{tid} {event}"
|
||||
cv2.putText(
|
||||
vis, txt, (int(cx) + 6, int(cy) - 6),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2
|
||||
)
|
||||
# 绘制近期事件(过去5帧内)
|
||||
elif tid in all_events:
|
||||
last_frame_idx, event = all_events[tid]
|
||||
if event is not None and self.frame_idx - last_frame_idx <= 5:
|
||||
txt = f"id:{tid} {event}"
|
||||
cv2.putText(
|
||||
vis, txt, (int(cx) + 6, int(cy) - 6),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2
|
||||
)
|
||||
|
||||
def _draw_detections(self, vis, result, track_hist):
|
||||
"""绘制检测框和车辆方向"""
|
||||
boxes = getattr(result, "boxes", None)
|
||||
if boxes is None:
|
||||
return
|
||||
|
||||
for box in boxes:
|
||||
try:
|
||||
xyxy = getattr(box, "xyxy", None)
|
||||
if xyxy is None:
|
||||
continue
|
||||
|
||||
x1, y1, x2, y2 = map(int, map(float, xyxy[0].tolist()))
|
||||
track_id = int(getattr(box, "id", -1))
|
||||
cls = int(getattr(box, "cls", -1))
|
||||
|
||||
# 绘制检测框
|
||||
cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 255), 2)
|
||||
|
||||
# 绘制车辆ID
|
||||
cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
|
||||
cv2.putText(
|
||||
vis, f"id:{track_id}", (cx + 6, cy - 6),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1
|
||||
)
|
||||
|
||||
# 绘制车辆方向箭头
|
||||
if track_id in track_hist and len(track_hist[track_id]) >= 2:
|
||||
hist = track_hist[track_id]
|
||||
p0 = hist[-2]
|
||||
p1 = hist[-1]
|
||||
v = (p1[0] - p0[0], p1[1] - p0[1])
|
||||
self._draw_arrow(
|
||||
vis, (cx, cy), (v[0], v[1]),
|
||||
color=(0, 0, 255), thickness=2
|
||||
)
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
def _draw_arrow(self, img, start, vec, color=(0, 255, 0), thickness=2, tip_length=0.3):
|
||||
"""绘制箭头"""
|
||||
sx, sy = int(round(start[0])), int(round(start[1]))
|
||||
vx, vy = float(vec[0]), float(vec[1])
|
||||
norm = math.hypot(vx, vy)
|
||||
|
||||
if norm < 1e-3:
|
||||
return
|
||||
|
||||
scale = 1.0
|
||||
if norm < self.ARROW_LEN_MIN:
|
||||
scale = self.ARROW_LEN_MIN / (norm + 1e-6)
|
||||
|
||||
ex = int(round(sx + vx * scale))
|
||||
ey = int(round(sy + vy * scale))
|
||||
|
||||
cv2.arrowedLine(
|
||||
img, (sx, sy), (ex, ey),
|
||||
color, thickness, tipLength=tip_length
|
||||
)
|
||||
|
||||
def _draw_debug_info(self, vis):
|
||||
"""绘制调试信息"""
|
||||
cv2.putText(
|
||||
vis, f"Frame:{self.frame_idx}", (10, 20),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 2
|
||||
)
|
||||
cv2.putText(
|
||||
vis, f"HeatMax:{self.MAX_HEAT} TrustThr:3.0", (10, 45),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (200, 200, 200), 1
|
||||
)
|
||||
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
测试新创建的高级处理器
|
||||
"""
|
||||
from pipeline.handler.advanced_heatmap_processor import AdvancedHeatmapProcessor
|
||||
from pipeline.handler.enhanced_reverse_direction_processor import EnhancedRetrogradeProcessor
|
||||
|
||||
|
||||
def test_advanced_processors():
|
||||
print("测试高级处理器...")
|
||||
|
||||
# 测试AdvancedHeatmapProcessor
|
||||
advanced_heatmap_proc = AdvancedHeatmapProcessor()
|
||||
print(f"AdvancedHeatmapProcessor创建成功: {advanced_heatmap_proc.name}")
|
||||
|
||||
# 测试EnhancedRetrogradeProcessor
|
||||
enhanced_retrograde_proc = EnhancedRetrogradeProcessor()
|
||||
print(f"EnhancedRetrogradeProcessor创建成功: {enhanced_retrograde_proc.name}")
|
||||
|
||||
print("所有高级处理器测试完成!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_advanced_processors()
|
||||
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
测试处理器兼容性
|
||||
"""
|
||||
from pipeline.handler.advanced_heatmap_processor import AdvancedHeatmapProcessor
|
||||
from pipeline.handler.draw_heatmap_box_processor import DrawHeatMapBoxProcessor
|
||||
from pipeline.pipeline_data import PipelineData
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
||||
def test_compatibility():
|
||||
print("测试处理器兼容性...")
|
||||
|
||||
# 创建处理器
|
||||
heatmap_proc = AdvancedHeatmapProcessor()
|
||||
draw_proc = DrawHeatMapBoxProcessor()
|
||||
|
||||
print(f"处理器创建成功: {heatmap_proc.name}, {draw_proc.name}")
|
||||
|
||||
# 创建模拟数据
|
||||
data = PipelineData()
|
||||
# 创建一个模拟帧
|
||||
data.frame = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||
data.frame_idx = 1
|
||||
|
||||
# 由于我们无法创建真实的YOLO结果,我们跳过处理步骤,主要验证类定义和依赖
|
||||
print("处理器兼容性测试完成!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_compatibility()
|
||||
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
测试修改后的处理器
|
||||
"""
|
||||
from pipeline.handler.advanced_heatmap_processor import AdvancedHeatmapProcessor
|
||||
from pipeline.handler.draw_heatmap_box_processor import DrawHeatMapBoxProcessor
|
||||
from pipeline.handler.draw_id_box_processor import DrawIdBoxProcessor
|
||||
from pipeline.handler.draw_motion_vector_processor import DrawMotionVectorProcessor
|
||||
|
||||
|
||||
def test_modifications():
|
||||
print("测试修改后的处理器...")
|
||||
|
||||
# 测试所有处理器是否都能正确创建
|
||||
proc1 = AdvancedHeatmapProcessor()
|
||||
proc2 = DrawHeatMapBoxProcessor()
|
||||
proc3 = DrawIdBoxProcessor()
|
||||
proc4 = DrawMotionVectorProcessor()
|
||||
|
||||
print('所有处理器创建成功')
|
||||
print(f'高级热力图处理器: {proc1.name}')
|
||||
print(f'热力图框绘制处理器: {proc2.name}')
|
||||
print(f'ID框绘制处理器: {proc3.name}')
|
||||
print(f'运动向量绘制处理器: {proc4.name}')
|
||||
|
||||
print("修改验证完成!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_modifications()
|
||||
@@ -0,0 +1,39 @@
|
||||
"""
|
||||
测试更新后的系统,验证移除旧处理器后的功能
|
||||
"""
|
||||
from pipeline.handler import (
|
||||
AdvancedHeatmapProcessor,
|
||||
EnhancedRetrogradeProcessor,
|
||||
DrawIdBoxProcessor,
|
||||
DrawHeatMapBoxProcessor,
|
||||
DrawMotionVectorProcessor,
|
||||
ResultLogger,
|
||||
BoxFilter
|
||||
)
|
||||
|
||||
def test_updated_system():
|
||||
print("测试更新后的系统...")
|
||||
|
||||
# 测试所有处理器是否都能正确创建
|
||||
proc1 = AdvancedHeatmapProcessor()
|
||||
proc2 = EnhancedRetrogradeProcessor()
|
||||
proc3 = DrawIdBoxProcessor()
|
||||
proc4 = DrawHeatMapBoxProcessor()
|
||||
proc5 = DrawMotionVectorProcessor()
|
||||
proc6 = ResultLogger("TestLogger")
|
||||
proc7 = BoxFilter("TestFilter")
|
||||
|
||||
print('所有处理器创建成功')
|
||||
print(f'高级热力图处理器: {proc1.name}')
|
||||
print(f'增强型逆行检测处理器: {proc2.name}')
|
||||
print(f'ID框绘制处理器: {proc3.name}')
|
||||
print(f'热力图框绘制处理器: {proc4.name}')
|
||||
print(f'运动向量绘制处理器: {proc5.name}')
|
||||
print(f'日志处理器: {proc6.name}')
|
||||
print(f'过滤器: {proc7.name}')
|
||||
|
||||
print("系统更新测试完成!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_updated_system()
|
||||
Reference in New Issue
Block a user