Files
HighwayEventDet/pipeline/handler/graffiti_visualizer.py
T
zimoyin 23555e0cc9 feat(pipeline): 添加逆行处理器
- 新增 DrawDirectionProcessor 用于绘制车辆行驶方向和轨迹
- 新增 DrawGraffitiProcessor 用于绘制热力图和网格方向箭头
- 新增 DrawObjectBoxProcessor 用于绘制检测框、ID、标签和置信度
- 新增 GraffitiVisualizer 用于涂鸦可视化处理
- 新增 GraffitiProcessor 用于计算车辆轨迹并更新车道涂鸦
2026-01-10 09:41:18 +08:00

198 lines
7.0 KiB
Python

import numpy as np
import cv2
import math
from yolo_gs.pipeline.base_processor import BaseProcessor
from yolo_gs.pipeline.pipeline_data import PipelineData
class GraffitiVisualizer(BaseProcessor):
    """Graffiti visualisation processor.

    Replaces ``data.frame`` with an annotated copy containing, in drawing
    order: a red heat-map overlay, per-grid-cell direction arrows, track
    trails with event labels, detection boxes with per-vehicle direction
    arrows, and a small debug text block.
    """

    def __init__(self, name: str = "涂鸦可视化处理器"):
        """Create the processor.

        Args:
            name: Processor name forwarded to ``BaseProcessor``.
        """
        super().__init__(name)
        # Tunable parameters.
        self.MAX_HEAT = 10.0            # heat value mapped to full overlay intensity
        self.ARROW_LEN_MIN = 10.0       # arrows shorter than this (px) are stretched
        self.GRID_STEP = 40             # side length (px) of one grid cell
        self.GRID_HEAT_THRESHOLD = 0.6  # min normalised mean heat for a cell arrow
        self.GRID_DRAW_EVERY = 2        # draw grid arrows only every N-th frame
        self.FULL_DECAY_EVERY = 30      # not read anywhere in this class
        self.RECENT_EVENT_FRAMES = 5    # keep showing a past event for this many frames
        self.frame_idx = 0

    def process(self, data: PipelineData) -> PipelineData:
        """Draw all graffiti overlays onto a copy of the current frame.

        Args:
            data: Pipeline payload. Reads ``frame``, ``frame_idx``,
                ``current_result`` and the ``"graffiti"`` / ``"retrograde"``
                entries available through ``get_data``.

        Returns:
            The same ``data`` object. ``data.frame`` is replaced by the
            annotated copy, unless the incoming frame is ``None``, in which
            case ``data`` is returned untouched.
        """
        if data.frame is None:
            return data

        self.frame_idx = data.frame_idx
        vis = data.frame.copy()

        # Fetch the graffiti / retrograde state produced by earlier stages.
        heat_map = data.get_data("graffiti", "heat_map")
        dir_x_map = data.get_data("graffiti", "dir_x_map")
        dir_y_map = data.get_data("graffiti", "dir_y_map")
        track_hist = data.get_data("graffiti", "track_hist")
        # ``or {}`` also covers a stored value of None, not just a missing key.
        current_events = data.get_data("retrograde", "current_events", {}) or {}
        all_events = data.get_data("retrograde", "all_events", {}) or {}

        # Heat map: written into channel index 2 and blended additively.
        if heat_map is not None:
            heat_norm = np.clip(heat_map / self.MAX_HEAT, 0.0, 1.0)
            heat_u8 = (heat_norm * 255).astype(np.uint8)
            red_overlay = np.zeros_like(vis)
            red_overlay[:, :, 2] = heat_u8
            alpha = 0.45
            vis = cv2.addWeighted(vis, 1.0, red_overlay, alpha, 0)

        # Grid direction arrows. The heat map is required as well because
        # the per-cell threshold is computed from it.
        maps_ready = (
            heat_map is not None
            and dir_x_map is not None
            and dir_y_map is not None
        )
        if maps_ready and self.frame_idx % self.GRID_DRAW_EVERY == 0:
            self._draw_grid_arrows(vis, heat_map, dir_x_map, dir_y_map)

        # Track trails and event labels.
        if track_hist is not None:
            self._draw_tracks_and_events(vis, track_hist, current_events, all_events)

        # Detection boxes and per-vehicle direction arrows.
        if data.current_result is not None:
            self._draw_detections(vis, data.current_result, track_hist)

        # Debug text.
        self._draw_debug_info(vis)

        data.frame = vis
        return data

    def _draw_grid_arrows(self, vis, heat_map, dir_x_map, dir_y_map):
        """Draw one direction arrow per sufficiently hot grid cell.

        Args:
            vis: Image drawn on in place.
            heat_map: 2-D array of heat values; its shape defines the grid.
            dir_x_map: Array of x direction components, indexed like
                ``heat_map``.
            dir_y_map: Array of y direction components, indexed like
                ``heat_map``.
        """
        h, w = heat_map.shape[:2]
        step = self.GRID_STEP
        for y0 in range(0, h, step):
            y1 = min(y0 + step, h)
            for x0 in range(0, w, step):
                x1 = min(x0 + step, w)
                win = heat_map[y0:y1, x0:x1]
                if win.size == 0:
                    continue
                if float(win.mean()) / self.MAX_HEAT < self.GRID_HEAT_THRESHOLD:
                    continue
                # NOTE(review): the direction components are summed over the
                # cell, not averaged, so arrow length scales with cell area.
                # Kept as in the original - confirm this is intended.
                sum_dx = float(dir_x_map[y0:y1, x0:x1].sum())
                sum_dy = float(dir_y_map[y0:y1, x0:x1].sum())
                if sum_dx == 0 and sum_dy == 0:
                    continue
                cell_center = (x0 + (x1 - x0) // 2, y0 + (y1 - y0) // 2)
                self._draw_arrow(
                    vis, cell_center,
                    (sum_dx * 5.0, sum_dy * 5.0),
                    color=(0, 128, 255), thickness=2,
                )

    def _draw_event_label(self, vis, tid, event, cx, cy):
        """Draw the ``id:<tid> <event>`` label next to point ``(cx, cy)``."""
        cv2.putText(
            vis, f"id:{tid} {event}", (cx + 6, cy - 6),
            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2,
        )

    def _draw_tracks_and_events(self, vis, track_hist, current_events, all_events):
        """Draw each track's trail, its latest position and its event label.

        Args:
            vis: Image drawn on in place.
            track_hist: Mapping of track id to a sequence of ``(x, y)`` points.
            current_events: Mapping of track id to the event active this
                frame.
            all_events: Mapping of track id to ``(last_frame_idx, event)``.
        """
        for tid, hist in track_hist.items():
            if len(hist) == 0:
                continue

            # Trail: a single open polyline instead of one cv2.line per segment.
            pts = np.array(hist, dtype=np.int32)
            if len(pts) >= 2:
                cv2.polylines(
                    vis, [pts.reshape(-1, 1, 2)], False, (255, 255, 255), 1
                )

            # Latest position.
            last_x, last_y = hist[-1]
            cx, cy = int(last_x), int(last_y)
            cv2.circle(vis, (cx, cy), 3, (255, 255, 255), -1)

            if tid in current_events:
                # An event active on this frame is always labelled.
                self._draw_event_label(vis, tid, current_events[tid], cx, cy)
            elif tid in all_events:
                # Otherwise keep a past event visible for a few more frames.
                last_frame_idx, event = all_events[tid]
                age = self.frame_idx - last_frame_idx
                if event is not None and age <= self.RECENT_EVENT_FRAMES:
                    self._draw_event_label(vis, tid, event, cx, cy)

    def _draw_detections(self, vis, result, track_hist):
        """Draw detection boxes, id labels and per-vehicle direction arrows.

        Args:
            vis: Image drawn on in place.
            result: Detection result; its ``boxes`` attribute is iterated.
                Nothing is drawn when the attribute is missing or ``None``.
            track_hist: Mapping of track id to a sequence of ``(x, y)``
                points, or ``None`` when no history is available.
        """
        boxes = getattr(result, "boxes", None)
        if boxes is None:
            return
        if track_hist is None:
            track_hist = {}

        for box in boxes:
            xyxy = getattr(box, "xyxy", None)
            if xyxy is None:
                continue

            # Only field parsing is best-effort; a malformed box is skipped.
            try:
                x1, y1, x2, y2 = map(int, map(float, xyxy[0].tolist()))
                raw_id = getattr(box, "id", None)
                # A box whose id is None is still drawn, labelled with -1,
                # instead of being dropped by a swallowed TypeError.
                track_id = int(raw_id) if raw_id is not None else -1
            except (AttributeError, IndexError, TypeError, ValueError, RuntimeError):
                continue

            # Detection box.
            cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 255), 2)

            # Id label next to the box centre.
            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
            cv2.putText(
                vis, f"id:{track_id}", (cx + 6, cy - 6),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1,
            )

            # Direction arrow from the last two history points.
            if track_id in track_hist and len(track_hist[track_id]) >= 2:
                hist = track_hist[track_id]
                p0, p1 = hist[-2], hist[-1]
                self._draw_arrow(
                    vis, (cx, cy), (p1[0] - p0[0], p1[1] - p0[1]),
                    color=(0, 0, 255), thickness=2,
                )

    def _draw_arrow(self, img, start, vec, color=(0, 255, 0), thickness=2, tip_length=0.3):
        """Draw an arrow from ``start`` along ``vec``.

        Vectors with a norm below ``1e-3`` are not drawn. Vectors shorter
        than ``ARROW_LEN_MIN`` are scaled up to roughly that length so the
        arrow stays visible.

        Args:
            img: Image drawn on in place.
            start: ``(x, y)`` start point.
            vec: ``(dx, dy)`` direction vector in pixels.
            color: Colour tuple passed to ``cv2.arrowedLine``.
            thickness: Line thickness in pixels.
            tip_length: Arrow-tip length relative to the arrow length.
        """
        sx, sy = int(round(start[0])), int(round(start[1]))
        vx, vy = float(vec[0]), float(vec[1])
        norm = math.hypot(vx, vy)
        if norm < 1e-3:
            return
        scale = 1.0
        if norm < self.ARROW_LEN_MIN:
            scale = self.ARROW_LEN_MIN / (norm + 1e-6)
        ex = int(round(sx + vx * scale))
        ey = int(round(sy + vy * scale))
        cv2.arrowedLine(
            img, (sx, sy), (ex, ey),
            color, thickness, tipLength=tip_length,
        )

    def _draw_debug_info(self, vis):
        """Draw the frame index and heat settings in the top-left corner."""
        cv2.putText(
            vis, f"Frame:{self.frame_idx}", (10, 20),
            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 2,
        )
        cv2.putText(
            vis, f"HeatMax:{self.MAX_HEAT} TrustThr:3.0", (10, 45),
            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (200, 200, 200), 1,
        )