258 lines
9.2 KiB
Python
258 lines
9.2 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
本地图片 → Maix YOLOv5 检测 → 画框保存(用于核对坐标 mode / 多框 union)。
|
||
|
||
运行环境:MaixCAM / MaixPy(需 maix.image / maix.nn),在项目根或任意目录执行均可。
|
||
|
||
示例:
|
||
python test/test_yolo_draw_boxes.py /root/phot/shot_xxx.jpg
|
||
python test/test_yolo_draw_boxes.py shot.jpg --loader cv2_rgb --conf 0.25
|
||
python test/test_yolo_draw_boxes.py shot.jpg --debug
|
||
python -h # 查看 --loader / --debug / --union 等全部参数
|
||
|
||
脚本版本(与设备同步用):20260206-yolo-vis
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import os
|
||
import sys
|
||
|
||
# Make the parent of this script's directory importable, so the project-local
# modules imported inside main() resolve no matter where the script is run from.
_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)
|
||
|
||
|
||
def _load_maix_image(path: str, image_mod):
    """Decode the file at ``path`` with ``image_mod.load`` and return the result.

    Caveat carried over from the original note: for some JPEGs the decoded
    pixel layout differs from a ``camera.read()`` frame, which can leave the
    NPU with no detections at all.
    """
    decoded = image_mod.load(path)
    return decoded
|
||
|
||
|
||
def _load_cv2_rgb_as_maix(path: str, image_mod):
    """Read ``path`` with OpenCV, convert BGR -> RGB and wrap it as a Maix image.

    The result is meant to feed a YOLO model whose input type is RGB. Per the
    original note, this mirrors the inverse of the ``image2cv`` conversion used
    in ``shoot_manager`` (that module is not visible from here).

    Args:
        path: Image file to read.
        image_mod: The ``maix.image`` module (anything exposing ``cv2image``).

    Returns:
        Whatever ``image_mod.cv2image`` returns for the RGB array.

    Raises:
        FileNotFoundError: ``cv2.imread`` returned ``None`` for ``path``.
    """
    import cv2

    arr = cv2.imread(path, cv2.IMREAD_COLOR)
    if arr is None:
        raise FileNotFoundError(f"cv2.imread 失败: {path}")
    arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
    # Arguments are (array, bgr, copy). copy must be True here: `arr` is local to
    # this function, and with copy=False the Maix image would keep pointing at
    # the array's buffer after it has been released on return.
    return image_mod.cv2image(arr, False, True)
|
||
|
||
|
||
def _build_arg_parser():
    """Create the argparse parser describing every option of this test script."""
    ap = argparse.ArgumentParser(
        description="YOLO 画框测试(Maix)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="若提示 unrecognized arguments: --debug,说明设备上脚本未更新,请同步仓库中的 test/test_yolo_draw_boxes.py",
    )
    ap.add_argument("image", help="输入图片路径")
    ap.add_argument("-o", "--output", default="", help="输出图片路径;默认 原名_yolo_vis.jpg")
    ap.add_argument("-m", "--model", default="", help="覆盖 config.TRIANGLE_YOLO_MODEL_PATH")
    ap.add_argument("--conf", type=float, default=None, help="置信度阈值")
    ap.add_argument("--iou", type=float, default=None, help="NMS IoU")
    ap.add_argument(
        "--coord",
        choices=["native", "letterbox"],
        default="",
        help="坐标映射;默认读 config.TRIANGLE_YOLO_COORD_MODE",
    )
    ap.add_argument(
        "--union",
        action="store_true",
        help="按 TRIANGLE_YOLO_RING_CLASS_IDS 过滤后画合并外接矩形(与线上 ROI merge=union 一致)",
    )
    ap.add_argument(
        "--loader",
        choices=["auto", "maix", "cv2_rgb"],
        default="auto",
        help="auto: 先 maix.load,0 框则改用 cv2 RGB(推荐排查「有图但始终 0 框」)",
    )
    ap.add_argument(
        "--debug",
        action="store_true",
        help="打印 detect 原始返回类型与 repr(截断)",
    )
    return ap


def _clamp_xyxy(x0, y0, x1, y1, src_w, src_h):
    """Clamp a corner-format box into a ``src_w`` x ``src_h`` image.

    The top-left corner is limited to the last valid pixel and the bottom-right
    corner to the image size, while keeping the box at least 1 pixel wide and
    tall. Returns the clamped ``(x0, y0, x1, y1)``.
    """
    cx0 = max(0, min(x0, src_w - 1))
    cy0 = max(0, min(y0, src_h - 1))
    cx1 = max(cx0 + 1, min(x1, src_w))
    cy1 = max(cy0 + 1, min(y1, src_h))
    return cx0, cy0, cx1, cy1


def main():
    """Run YOLOv5 on one image file, draw the detections and save the result.

    Steps: parse the CLI options, resolve model path / thresholds / coordinate
    mode (CLI values win over ``config``), load the image with the selected
    loader, run detection, draw one rectangle plus label per box, optionally
    draw the union rectangle of the configured ring classes, and save the
    annotated image.

    Exits with status 1 when MaixPy cannot be imported or when the image or the
    model file does not exist.
    """
    args = _build_arg_parser().parse_args()

    try:
        from maix import image, nn
    except ImportError:
        print("[ERR] 需要 MaixPy(maix.image / maix.nn),请在 MaixCAM 上运行。")
        sys.exit(1)

    # Project-local modules, imported late so `-h` and the MaixPy check work first.
    import config as cfg
    import target_roi_yolo as yroi

    img_path = os.path.abspath(args.image)
    if not os.path.isfile(img_path):
        print(f"[ERR] 找不到图片: {img_path}")
        sys.exit(1)

    model_path = (args.model or getattr(cfg, "TRIANGLE_YOLO_MODEL_PATH", "") or "").strip()
    if not os.path.isfile(model_path):
        print(f"[ERR] 模型文件不存在: {model_path}")
        sys.exit(1)

    conf_th = (
        float(args.conf)
        if args.conf is not None
        else float(getattr(cfg, "TRIANGLE_YOLO_CONF_TH", 0.5))
    )
    iou_th = (
        float(args.iou)
        if args.iou is not None
        else float(getattr(cfg, "TRIANGLE_YOLO_IOU_TH", 0.45))
    )
    coord_mode = (args.coord or getattr(cfg, "TRIANGLE_YOLO_COORD_MODE", "native")).lower()

    out_path = args.output.strip()
    if not out_path:
        # Default output: "<input name>_yolo_vis<ext>" next to the input image.
        base, ext = os.path.splitext(img_path)
        out_path = base + "_yolo_vis" + (ext if ext else ".jpg")

    det = nn.YOLOv5(model=model_path, dual_buff=False)
    net_w = int(det.input_width())
    net_h = int(det.input_height())

    def _run_detect(maix_img, tag: str):
        """Run the detector on ``maix_img`` and return the normalized objects."""
        r = det.detect(maix_img, conf_th=conf_th, iou_th=iou_th)
        if args.debug:
            rlen = len(r) if r is not None and hasattr(r, "__len__") else "n/a"
            rrepr = repr(r)
            if len(rrepr) > 300:
                rrepr = rrepr[:300] + "..."
            print(f"[DEBUG] loader={tag} raw_type={type(r)} len={rlen} repr={rrepr}")
        return yroi._normalize_objs(r if r is not None else [])

    if args.loader == "cv2_rgb":
        img, load_tag = _load_cv2_rgb_as_maix(img_path, image), "cv2_rgb"
    else:
        # Both "maix" and "auto" start with the native Maix loader.
        img, load_tag = _load_maix_image(img_path, image), "maix_load"
    objs = _run_detect(img, load_tag)

    if args.loader == "auto" and len(objs) == 0:
        print(
            "[WARN] maix.image.load 在 conf_th=%s 下仍为 0 框,改用 cv2 BGR→RGB→cv2image 重试(常见可恢复)"
            % conf_th
        )
        try:
            retry_img = _load_cv2_rgb_as_maix(img_path, image)
        except (ImportError, FileNotFoundError) as exc:
            # The retry is only a fallback: if cv2 is missing or cannot read the
            # file, keep the image that maix already loaded instead of aborting.
            print(f"[WARN] cv2 重试不可用,保留 maix.image.load 结果: {exc}")
        else:
            img, load_tag = retry_img, "cv2_rgb_retry"
            objs = _run_detect(img, load_tag)

    src_w, src_h = img.width(), img.height()

    labels = getattr(det, "labels", None)

    def _label(cid: int) -> str:
        """Return the label text for class ``cid``, or the id itself as text."""
        # Negative ids are the "unknown class" sentinel used below; indexing
        # with them would silently pick a label from the end of the list.
        if labels is None or cid < 0:
            return str(cid)
        try:
            return str(labels[int(cid)])
        except Exception:
            # `labels` comes from the detector object and its failure modes are
            # not known here; any lookup problem falls back to the numeric id.
            return str(cid)

    print(
        f"[INFO] loader={load_tag} image={src_w}×{src_h}, net_in={net_w}×{net_h}, "
        f"coord={coord_mode}, conf_th={conf_th}, iou_th={iou_th}"
    )
    print(f"[INFO] NMS 后检测框数量={len(objs)} → {out_path}")
    if len(objs) == 0:
        print(
            "[HINT] 仍为 0 框时常见原因:\n"
            " 1) 强制 cv2 路径: --loader cv2_rgb\n"
            " 2) NMS 过严: --iou 0.95\n"
            " 3) 图与训练分布差太大 / 模型未见过该场景\n"
            " 4) 用 camera.read() 一帧存盘再测,对比 file 与实时是否一致"
        )

    # Colours rotate by class id; only COLOR_* constants that exist are used.
    color_names = ("RED", "GREEN", "BLUE", "ORANGE", "YELLOW", "CYAN", "MAGENTA")
    color_cycle = [
        c
        for c in (getattr(image, f"COLOR_{name}", None) for name in color_names)
        if c is not None
    ]
    if not color_cycle:
        color_cycle = [getattr(image, "COLOR_RED", 0)]

    for i, o in enumerate(objs):
        cid = yroi._det_obj_class_id(o)
        if cid is None:
            cid = -1
        try:
            sc = float(o.score)
        except (AttributeError, TypeError, ValueError):
            sc = 0.0
        x0, y0, x1, y1 = yroi._det_to_src_xyxy(o, coord_mode, src_w, src_h, net_w, net_h)
        # Clamp the corners first and derive the size from the clamped corners,
        # so a box starting outside the image is not drawn past its real edge.
        cx0, cy0, cx1, cy1 = _clamp_xyxy(x0, y0, x1, y1, src_w, src_h)
        ix, iy = int(cx0), int(cy0)
        iw, ih = int(cx1 - cx0), int(cy1 - cy0)
        col = color_cycle[cid % len(color_cycle)] if cid >= 0 else color_cycle[0]
        img.draw_rect(ix, iy, iw, ih, color=col)
        ty = max(0, iy - 14)  # put the caption just above the box when it fits
        msg = f"{_label(cid)} {sc:.2f}"
        img.draw_string(ix, ty, msg, color=col)
        print(f" #{i} cls={cid} {_label(cid)} score={sc:.3f} xywh=({ix},{iy},{iw},{ih})")

    if args.union:
        class_ids = getattr(cfg, "TRIANGLE_YOLO_RING_CLASS_IDS", (0,))
        if isinstance(class_ids, int):
            class_ids = (class_ids,)
        cand = [o for o in objs if yroi._det_obj_class_id(o) in class_ids]
        if cand:
            xy_list = [
                yroi._det_to_src_xyxy(o, coord_mode, src_w, src_h, net_w, net_h) for o in cand
            ]
            merged = yroi._merge_roi_xyxy(xy_list, "union")
            if merged:
                mx0, my0, mx1, my1 = _clamp_xyxy(*merged, src_w, src_h)
                uw, uh = int(mx1 - mx0), int(my1 - my0)
                ucol = getattr(image, "COLOR_GREEN", color_cycle[0])
                # Make the union look thicker: draw it twice, the second time
                # grown by 2 px on every side.
                for d in (0, 2):
                    img.draw_rect(
                        int(mx0) - d,
                        int(my0) - d,
                        uw + 2 * d,
                        uh + 2 * d,
                        color=ucol,
                    )
                img.draw_string(
                    int(mx0),
                    max(0, int(my0) - 28),
                    f"UNION ({len(cand)} boxes)",
                    color=ucol,
                )
                print(f"[INFO] UNION [{int(mx0)},{int(my0)},{int(mx1)},{int(my1)}] from {len(cand)} boxes")
        else:
            print("[WARN] --union 但 RING_CLASS_IDS 过滤后无框")

    try:
        img.save(out_path, quality=95)
    except TypeError:
        # Fallback for a save() that does not accept the `quality` keyword.
        img.save(out_path)
    print(f"[OK] saved: {out_path}")
|
||
|
||
|
||
# Script entry point: run main() only when executed directly, not when imported.
if __name__ == "__main__":
    main()
|