23555e0cc9
- 新增 DrawDirectionProcessor 用于绘制车辆行驶方向和轨迹 - 新增 DrawGraffitiProcessor 用于绘制热力图和网格方向箭头 - 新增 DrawObjectBoxProcessor 用于绘制检测框、ID、标签和置信度 - 新增 GraffitiVisualizer 用于涂鸦可视化处理 - 新增 GraffitiProcessor 用于计算车辆轨迹并更新车道涂鸦
198 lines
6.5 KiB
Python
198 lines
6.5 KiB
Python
import cv2
|
|
import numpy as np
|
|
import math
|
|
from yolo_gs.pipeline.base_processor import BaseProcessor
|
|
from yolo_gs.pipeline.pipeline_data import PipelineData
|
|
|
|
|
|
class DrawGraffitiProcessor(BaseProcessor):
|
|
"""涂鸦绘制处理器:绘制热力图和网格方向箭头"""
|
|
|
|
def __init__(self,
|
|
name: str = "涂鸦绘制处理器",
|
|
heatmap_opacity: float = 0.45,
|
|
grid_step: int = 40,
|
|
grid_heat_threshold: float = 0.6,
|
|
grid_draw_every: int = 2,
|
|
max_heat: float = 10.0):
|
|
super().__init__(name)
|
|
self.heatmap_opacity = heatmap_opacity
|
|
self.grid_step = grid_step
|
|
self.grid_heat_threshold = grid_heat_threshold
|
|
self.grid_draw_every = grid_draw_every
|
|
self.max_heat = max_heat
|
|
|
|
# 箭头参数
|
|
self.arrow_length_min = 10.0
|
|
self.grid_arrow_color = (0, 128, 255) # 橙色 - 网格箭头
|
|
|
|
# 帧计数器
|
|
self.frame_idx = 0
|
|
|
|
def process(self, data: PipelineData) -> PipelineData:
|
|
"""绘制热力图和网格方向箭头"""
|
|
if data.frame is None:
|
|
return data
|
|
|
|
self.frame_idx = data.frame_idx
|
|
vis = data.frame.copy()
|
|
|
|
# 获取涂鸦数据
|
|
heat_map = data.get_data("graffiti", "heat_map")
|
|
dir_x_map = data.get_data("graffiti", "dir_x_map")
|
|
dir_y_map = data.get_data("graffiti", "dir_y_map")
|
|
|
|
# 绘制热力图
|
|
if heat_map is not None:
|
|
vis = self._draw_heatmap(vis, heat_map)
|
|
|
|
# 绘制网格方向箭头
|
|
if (dir_x_map is not None and dir_y_map is not None and
|
|
self.frame_idx % self.grid_draw_every == 0):
|
|
vis = self._draw_grid_arrows(vis, heat_map, dir_x_map, dir_y_map)
|
|
|
|
# TODO 绘制调试信息,这个应该是单独的处理器
|
|
vis = self._draw_debug_info(vis)
|
|
|
|
data.frame = vis
|
|
return data
|
|
|
|
def _draw_heatmap(self, vis, heat_map):
|
|
"""绘制热力图叠加层"""
|
|
# 归一化热力图
|
|
heat_norm = np.clip(heat_map / self.max_heat, 0.0, 1.0)
|
|
heat_u8 = (heat_norm * 255).astype(np.uint8)
|
|
|
|
# 创建红色通道叠加层
|
|
red_overlay = np.zeros_like(vis)
|
|
red_overlay[:, :, 2] = heat_u8 # 红色通道
|
|
|
|
# 叠加热力图
|
|
vis = cv2.addWeighted(vis, 1.0, red_overlay, self.heatmap_opacity, 0)
|
|
|
|
return vis
|
|
|
|
def _draw_grid_arrows(self, vis, heat_map, dir_x_map, dir_y_map):
|
|
"""在网格上绘制持久方向箭头"""
|
|
h, w = heat_map.shape
|
|
|
|
for gy in range(0, h, self.grid_step):
|
|
for gx in range(0, w, self.grid_step):
|
|
x0, y0 = gx, gy
|
|
x1 = min(gx + self.grid_step, w)
|
|
y1 = min(gy + self.grid_step, h)
|
|
|
|
# 检查网格区域热力
|
|
heat_window = heat_map[y0:y1, x0:x1]
|
|
if heat_window.size == 0:
|
|
continue
|
|
|
|
avg_heat = float(heat_window.mean())
|
|
|
|
# 如果平均热力达到阈值,绘制方向箭头
|
|
if avg_heat / self.max_heat >= self.grid_heat_threshold:
|
|
avg_dx = float(dir_x_map[y0:y1, x0:x1].sum())
|
|
avg_dy = float(dir_y_map[y0:y1, x0:x1].sum())
|
|
|
|
if avg_dx == 0 and avg_dy == 0:
|
|
continue
|
|
|
|
# 计算网格中心点
|
|
cell_center = (x0 + (x1 - x0) // 2, y0 + (y1 - y0) // 2)
|
|
|
|
# 绘制网格边界(可选,用于调试)
|
|
# cv2.rectangle(vis, (x0, y0), (x1, y1), (100, 100, 100), 1)
|
|
|
|
# 绘制方向箭头
|
|
self._draw_arrow(vis, cell_center,
|
|
(avg_dx * 5.0, avg_dy * 5.0))
|
|
|
|
return vis
|
|
|
|
def _draw_arrow(self, img, start, vec, color=None, thickness=2):
|
|
"""绘制箭头"""
|
|
if color is None:
|
|
color = self.grid_arrow_color
|
|
|
|
sx, sy = int(round(start[0])), int(round(start[1]))
|
|
vx, vy = float(vec[0]), float(vec[1])
|
|
|
|
norm = math.hypot(vx, vy)
|
|
if norm < 1e-3:
|
|
return
|
|
|
|
# 确保箭头有最小长度
|
|
scale = 1.0
|
|
if norm < self.arrow_length_min:
|
|
scale = self.arrow_length_min / (norm + 1e-6)
|
|
|
|
ex = int(round(sx + vx * scale))
|
|
ey = int(round(sy + vy * scale))
|
|
|
|
cv2.arrowedLine(img,
|
|
(sx, sy),
|
|
(ex, ey),
|
|
color, thickness,
|
|
tipLength=0.3)
|
|
|
|
def _draw_debug_info(self, vis):
|
|
"""绘制调试信息"""
|
|
# 绘制帧号
|
|
cv2.putText(vis,
|
|
f"Frame:{self.frame_idx}",
|
|
(10, 20),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.6,
|
|
(200, 200, 200), 2)
|
|
|
|
# 绘制热力图参数
|
|
cv2.putText(vis,
|
|
f"HeatMax:{self.max_heat} TrustThr:3.0",
|
|
(10, 45),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5,
|
|
(200, 200, 200), 1)
|
|
|
|
# 绘制网格参数
|
|
cv2.putText(vis,
|
|
f"Grid:{self.grid_step}px Thresh:{self.grid_heat_threshold}",
|
|
(10, 65),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5,
|
|
(200, 200, 200), 1)
|
|
|
|
# 绘制热力图图例
|
|
legend_y = vis.shape[0] - 30
|
|
legend_x = 10
|
|
|
|
# 图例标题
|
|
cv2.putText(vis, "Heat Legend:",
|
|
(legend_x, legend_y),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5,
|
|
(200, 200, 200), 1)
|
|
|
|
# 绘制颜色条
|
|
colorbar_width = 100
|
|
colorbar_height = 10
|
|
colorbar_x = legend_x + 80
|
|
colorbar_y = legend_y - 5
|
|
|
|
# 创建渐变颜色条
|
|
for i in range(colorbar_width):
|
|
intensity = i / colorbar_width
|
|
color = (0, 0, int(intensity * 255))
|
|
cv2.line(vis,
|
|
(colorbar_x + i, colorbar_y),
|
|
(colorbar_x + i, colorbar_y + colorbar_height),
|
|
color, 1)
|
|
|
|
# 标注最小值
|
|
cv2.putText(vis, "0",
|
|
(colorbar_x - 10, colorbar_y + 15),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.4,
|
|
(200, 200, 200), 1)
|
|
|
|
# 标注最大值
|
|
cv2.putText(vis, f"{self.max_heat}",
|
|
(colorbar_x + colorbar_width - 10, colorbar_y + 15),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.4,
|
|
(200, 200, 200), 1)
|
|
|
|
return vis |