增加训练yolo的代码
This commit is contained in:
1005
train_yolo/synth_compose_yolo.py
Normal file
1005
train_yolo/synth_compose_yolo.py
Normal file
File diff suppressed because it is too large
Load Diff
343
train_yolo/test_stage2_black_yolo_device.py
Normal file
343
train_yolo/test_stage2_black_yolo_device.py
Normal file
@@ -0,0 +1,343 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Stage2 黑三角 YOLO —— 在 Maix 设备上用本地图片测试(与线上 target_roi_yolo.try_black_triangle_boxes_work 完全一致)。
|
||||
|
||||
不在 PC 上跑 NPU;需把脚本与 config / target_roi_yolo.py 同步到设备,并在设备上执行。
|
||||
|
||||
典型用法
|
||||
--------
|
||||
# 输入已是 Stage1 裁切(与你保存的 stage2_roi_*.jpg 一致)
|
||||
python test/test_stage2_black_yolo_device.py /root/phot/stage2_roi_xxx.jpg
|
||||
|
||||
# 输入为整幅相机图,手动给出 Stage1 环靶 ROI(与线上日志 ring全图=[rx0,ry0,rx1,ry1] 一致)
|
||||
python test/test_stage2_black_yolo_device.py /root/phot/full.jpg --roi 197,196,507,461
|
||||
|
||||
# 对比 native / letterbox 坐标映射(排查 contain 训练与推理对齐)
|
||||
python test/test_stage2_black_yolo_device.py ./crop.jpg --compare-coord
|
||||
|
||||
# 覆盖置信度、模型路径(仍读其余项自 config)
|
||||
python test/test_stage2_black_yolo_device.py ./crop.jpg --conf 0.25 -m /maixapp/apps/t11/model_270648.mud
|
||||
|
||||
# 只看 NPU 原始框(映射前):判断坐标是 ~224 网络空间还是归一化 0~1
|
||||
python test/test_stage2_black_yolo_device.py ./crop.jpg --conf 0.05 --dump-raw 15
|
||||
|
||||
依赖:MaixPy(maix.nn)、OpenCV(cv2)、numpy;项目根须在 sys.path(本脚本已插入上级目录)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if _ROOT not in sys.path:
|
||||
sys.path.insert(0, _ROOT)
|
||||
|
||||
|
||||
def _parse_roi(s: str) -> tuple[int, int, int, int]:
|
||||
parts = [p.strip() for p in s.replace(" ", "").split(",")]
|
||||
if len(parts) != 4:
|
||||
raise ValueError("ROI 需要 4 个整数:x0,y0,x1,y1")
|
||||
return tuple(int(x) for x in parts) # type: ignore[return-value]
|
||||
|
||||
|
||||
def _load_rgb_numpy(path: str) -> "object":
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
bgr = cv2.imread(path, cv2.IMREAD_COLOR)
|
||||
if bgr is None:
|
||||
raise FileNotFoundError(f"cv2.imread 失败: {path}")
|
||||
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
||||
return np.ascontiguousarray(rgb, dtype=np.uint8)
|
||||
|
||||
|
||||
def _draw_boxes_on_crop(
|
||||
slab_rgb,
|
||||
boxes: list[tuple[int, int, int, int]],
|
||||
labels: list[str] | None = None,
|
||||
):
|
||||
"""slab_rgb: H×W×3 RGB uint8;boxes 为扩 margin 后的 Stage2 子框(与线上绿框一致)。"""
|
||||
import cv2
|
||||
|
||||
vis = slab_rgb.copy()
|
||||
bgr = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)
|
||||
rh, rw = bgr.shape[:2]
|
||||
for i, (bx0, by0, bx1, by1) in enumerate(boxes):
|
||||
x0, y0 = int(bx0), int(by0)
|
||||
x1, y1 = int(bx1) - 1, int(by1) - 1
|
||||
x1 = max(x0, min(x1, rw - 1))
|
||||
y1 = max(y0, min(y1, rh - 1))
|
||||
cv2.rectangle(bgr, (x0, y0), (x1, y1), (0, 255, 0), 2)
|
||||
tag = labels[i] if labels and i < len(labels) else f"s2_{i}"
|
||||
cv2.putText(
|
||||
bgr,
|
||||
tag,
|
||||
(x0, max(0, y0 - 4)),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5,
|
||||
(0, 255, 0),
|
||||
1,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
return bgr
|
||||
|
||||
|
||||
class _PrintLogger:
|
||||
def info(self, msg):
|
||||
print(msg)
|
||||
|
||||
def warning(self, msg):
|
||||
print(msg)
|
||||
|
||||
def error(self, msg):
|
||||
print(msg)
|
||||
|
||||
|
||||
def _run_once(yroi_mod, img_rgb, roi_xyxy, logger):
|
||||
boxes = yroi_mod.try_black_triangle_boxes_work(img_rgb, roi_xyxy, logger)
|
||||
rx0, ry0, rx1, ry1 = roi_xyxy
|
||||
slab = img_rgb[ry0:ry1, rx0:rx1].copy()
|
||||
return boxes, slab
|
||||
|
||||
|
||||
def _copy_dump_raw_rows(yroi_mod, objs):
|
||||
"""把 Maix detect 返回对象拷贝成基础类型,避免 native 对象跨下一次 detect 存活。"""
|
||||
rows = []
|
||||
for o in objs:
|
||||
cid = yroi_mod._det_obj_class_id(o)
|
||||
try:
|
||||
sc = float(getattr(o, "score", 0.0))
|
||||
except (TypeError, ValueError):
|
||||
sc = 0.0
|
||||
rows.append((cid, sc, float(o.x), float(o.y), float(o.w), float(o.h)))
|
||||
return rows
|
||||
|
||||
|
||||
def _dump_raw_and_hard_exit(det, yroi_mod, slab_for_det, rw_s, rh_s, net_w, net_h, conf_th, iou_th, limit):
|
||||
"""
|
||||
MaixPy 某些版本在 YOLO detect 返回对象正常析构时会 SIGSEGV/pure virtual。
|
||||
raw dump 是诊断路径,打印完成后硬退出,绕过 Python/native 析构链。
|
||||
"""
|
||||
from maix import image as maix_image
|
||||
|
||||
roi_maix = maix_image.cv2image(slab_for_det, False, False)
|
||||
raw = det.detect(roi_maix, conf_th=conf_th, iou_th=iou_th)
|
||||
objs = yroi_mod._normalize_objs(raw if raw is not None else [])
|
||||
dump_rows = _copy_dump_raw_rows(yroi_mod, objs)
|
||||
raw_count = len(dump_rows)
|
||||
print(
|
||||
f"[DUMP-RAW] slab={rw_s}×{rh_s} net={net_w}×{net_h} "
|
||||
f"conf={conf_th} iou={iou_th} → NMS 后 raw 框数={raw_count}(与 coord_mode 无关)"
|
||||
)
|
||||
npr = min(int(limit), raw_count)
|
||||
for i in range(npr):
|
||||
cid, sc, x, y, ww, hh = dump_rows[i]
|
||||
print(f" #{i} cls={cid} score={sc:.4f} xywh=({x:.3f},{y:.3f},{ww:.3f},{hh:.3f})")
|
||||
if dump_rows:
|
||||
xs = [r[2] for r in dump_rows]
|
||||
ws = [r[4] for r in dump_rows]
|
||||
print(
|
||||
f"[DUMP-RAW] hint: x 范围≈[{min(xs):.2f},{max(xs):.2f}] "
|
||||
f"w 范围≈[{min(ws):.2f},{max(ws):.2f}] — "
|
||||
f"若整体在 0~{net_w} 量级多为网络画布坐标→应用 letterbox;"
|
||||
f"若 x,w 多在 0~1→可能是归一化,需在代码里乘 net 尺寸"
|
||||
)
|
||||
print("[INFO] --dump-raw 已完成;为规避 MaixPy YOLO native 析构崩溃,测试进程将直接退出。")
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
os._exit(0)
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(
|
||||
description="Stage2 黑三角 YOLO 设备本地图测试",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=__doc__,
|
||||
)
|
||||
ap.add_argument("image", help="本地图片路径(设备上的路径)")
|
||||
ap.add_argument(
|
||||
"--roi",
|
||||
default="",
|
||||
metavar="x0,y0,x1,y1",
|
||||
help="可选。若填写:image 为整幅图,在此图上取 Stage1 ROI 再跑 Stage2;"
|
||||
"留空:image 本身就是 Stage1 裁切图(默认)",
|
||||
)
|
||||
ap.add_argument("-o", "--output", default="", help="输出可视化路径;默认 原名_stage2_vis.jpg")
|
||||
ap.add_argument("-m", "--model", default="", help="覆盖 config.TRIANGLE_BLACK_YOLO_MODEL_PATH")
|
||||
ap.add_argument("--conf", type=float, default=None, help="覆盖 TRIANGLE_BLACK_YOLO_CONF_TH")
|
||||
ap.add_argument("--iou", type=float, default=None, help="覆盖 TRIANGLE_BLACK_YOLO_IOU_TH")
|
||||
ap.add_argument(
|
||||
"--coord",
|
||||
choices=["native", "letterbox"],
|
||||
default="",
|
||||
help="覆盖 TRIANGLE_BLACK_YOLO_COORD_MODE;默认用 config",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--compare-coord",
|
||||
action="store_true",
|
||||
help="各跑一次 native 与 letterbox,输出两张图 *_stage2_native.jpg / *_stage2_letterbox.jpg",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--fresh-detector",
|
||||
action="store_true",
|
||||
help="清掉 YOLO 缓存再测(换模型或排查缓存时用)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--allow-save-roi",
|
||||
action="store_true",
|
||||
help="不强制关闭 TRIANGLE_BLACK_YOLO_SAVE_ROI_CROP(默认测试时会关掉以免写满相册目录)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--dump-raw",
|
||||
type=int,
|
||||
default=0,
|
||||
metavar="N",
|
||||
help="打印前 N 个 detect 原始框 x,y,w,h,score,cls(coord 映射前;native/letterbox 共用同一批 raw)",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
img_path = os.path.abspath(args.image)
|
||||
if not os.path.isfile(img_path):
|
||||
print(f"[ERR] 找不到图片: {img_path}")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import config as cfg
|
||||
import target_roi_yolo as yroi
|
||||
except ImportError as e:
|
||||
print(f"[ERR] 无法导入 config / target_roi_yolo: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if args.fresh_detector:
|
||||
yroi.reset_yolo_detector_cache()
|
||||
|
||||
# 备份并临时覆盖 config(单进程顺序跑)
|
||||
bak: dict[str, object] = {}
|
||||
|
||||
def _patch(key: str, val: object):
|
||||
if key not in bak:
|
||||
bak[key] = getattr(cfg, key, None)
|
||||
setattr(cfg, key, val)
|
||||
|
||||
def _restore():
|
||||
for k, v in bak.items():
|
||||
setattr(cfg, k, v)
|
||||
|
||||
try:
|
||||
_patch("TRIANGLE_BLACK_YOLO_ENABLE", True)
|
||||
if not args.allow_save_roi:
|
||||
_patch("TRIANGLE_BLACK_YOLO_SAVE_ROI_CROP", False)
|
||||
if args.model.strip():
|
||||
_patch("TRIANGLE_BLACK_YOLO_MODEL_PATH", args.model.strip())
|
||||
if args.conf is not None:
|
||||
_patch("TRIANGLE_BLACK_YOLO_CONF_TH", float(args.conf))
|
||||
if args.iou is not None:
|
||||
_patch("TRIANGLE_BLACK_YOLO_IOU_TH", float(args.iou))
|
||||
if args.coord and not args.compare_coord:
|
||||
_patch("TRIANGLE_BLACK_YOLO_COORD_MODE", args.coord)
|
||||
|
||||
mp = getattr(cfg, "TRIANGLE_BLACK_YOLO_MODEL_PATH", "") or ""
|
||||
if not os.path.isfile(mp):
|
||||
print(f"[ERR] 模型文件不存在: {mp}")
|
||||
sys.exit(1)
|
||||
|
||||
img_rgb = _load_rgb_numpy(img_path)
|
||||
h, w = int(img_rgb.shape[0]), int(img_rgb.shape[1])
|
||||
|
||||
if args.roi.strip():
|
||||
roi_xyxy = _parse_roi(args.roi.strip())
|
||||
rx0, ry0, rx1, ry1 = [int(round(float(v))) for v in roi_xyxy]
|
||||
if rx1 <= rx0 or ry1 <= ry0:
|
||||
print("[ERR] ROI 无效:需满足 x1>x0 且 y1>y0")
|
||||
sys.exit(1)
|
||||
# 与 target_roi_yolo.try_black_triangle_boxes_work 相同的 clip
|
||||
rx0 = max(0, min(rx0, w - 1))
|
||||
ry0 = max(0, min(ry0, h - 1))
|
||||
rx1 = max(rx0 + 1, min(rx1, w))
|
||||
ry1 = max(ry0 + 1, min(ry1, h))
|
||||
ring_roi = (rx0, ry0, rx1, ry1)
|
||||
print(f"[INFO] 模式=整图+ROI ring={ring_roi} image={w}×{h}")
|
||||
else:
|
||||
ring_roi = (0, 0, w, h)
|
||||
print(f"[INFO] 模式=已是 Stage1 裁切 crop={w}×{h}")
|
||||
|
||||
logger = _PrintLogger()
|
||||
det = yroi._get_detector(mp)
|
||||
if det is None:
|
||||
print("[ERR] 无法加载 nn.YOLOv5(检查模型路径与 Maix 环境)")
|
||||
sys.exit(1)
|
||||
net_w = int(det.input_width())
|
||||
net_h = int(det.input_height())
|
||||
print(f"[INFO] model={mp} net_in={net_w}×{net_h}")
|
||||
|
||||
rx0, ry0, rx1, ry1 = ring_roi
|
||||
import numpy as np
|
||||
|
||||
slab_for_det = np.ascontiguousarray(img_rgb[ry0:ry1, rx0:rx1], dtype=np.uint8).copy()
|
||||
rh_s, rw_s = int(slab_for_det.shape[0]), int(slab_for_det.shape[1])
|
||||
|
||||
modes = ["native", "letterbox"] if args.compare_coord else [
|
||||
(args.coord or getattr(cfg, "TRIANGLE_BLACK_YOLO_COORD_MODE", "native"))
|
||||
]
|
||||
|
||||
base, ext = os.path.splitext(img_path)
|
||||
ext = ext if ext else ".jpg"
|
||||
|
||||
for mode in modes:
|
||||
_patch("TRIANGLE_BLACK_YOLO_COORD_MODE", mode)
|
||||
cur_coord = getattr(cfg, "TRIANGLE_BLACK_YOLO_COORD_MODE", mode)
|
||||
print(f"[INFO] --- TRIANGLE_BLACK_YOLO_COORD_MODE={cur_coord} ---")
|
||||
|
||||
boxes, slab = _run_once(yroi, img_rgb, ring_roi, logger)
|
||||
print(
|
||||
f"[INFO] 子框数量={len(boxes)} conf={getattr(cfg, 'TRIANGLE_BLACK_YOLO_CONF_TH', '?')} "
|
||||
f"coord={cur_coord}"
|
||||
)
|
||||
for i, b in enumerate(boxes):
|
||||
print(f" s2_{i}: {b}")
|
||||
|
||||
if args.compare_coord:
|
||||
out_path = f"{base}_stage2_{mode}{ext}"
|
||||
elif args.output.strip():
|
||||
out_path = args.output.strip()
|
||||
else:
|
||||
out_path = base + "_stage2_vis" + ext
|
||||
|
||||
import cv2
|
||||
|
||||
bgr = _draw_boxes_on_crop(slab, boxes)
|
||||
cv2.imwrite(out_path, bgr, [int(cv2.IMWRITE_JPEG_QUALITY), 92])
|
||||
print(f"[OK] saved: {out_path}")
|
||||
|
||||
if args.compare_coord:
|
||||
print(
|
||||
"[HINT] contain 训练时若 letterbox 对齐更好,请将 config 里 "
|
||||
"TRIANGLE_BLACK_YOLO_COORD_MODE 设为 letterbox"
|
||||
)
|
||||
|
||||
if args.dump_raw > 0:
|
||||
conf_th = float(getattr(cfg, "TRIANGLE_BLACK_YOLO_CONF_TH", 0.5))
|
||||
iou_th = float(getattr(cfg, "TRIANGLE_BLACK_YOLO_IOU_TH", 0.45))
|
||||
print("\n[INFO] --dump-raw 放在最后执行,避免 raw native 对象影响 compare-coord 流程。")
|
||||
_dump_raw_and_hard_exit(
|
||||
det,
|
||||
yroi,
|
||||
slab_for_det,
|
||||
rw_s,
|
||||
rh_s,
|
||||
net_w,
|
||||
net_h,
|
||||
conf_th,
|
||||
iou_th,
|
||||
args.dump_raw,
|
||||
)
|
||||
|
||||
finally:
|
||||
_restore()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
242
train_yolo/test_triangle_one_image.py
Normal file
242
train_yolo/test_triangle_one_image.py
Normal file
@@ -0,0 +1,242 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
单张图片快速测试:三角形四角标记识别 + 单应性落点 + PnP 估距
|
||||
|
||||
用法(在板子上):
|
||||
python3 test/test_triangle_one_image.py --image /root/phot/xxx.jpg --out /root/phot/tri_out.jpg
|
||||
|
||||
调参对比(不改代码,临时覆盖 config.TRIANGLE_*):
|
||||
python3 test/test_triangle_one_image.py --image /root/phot/xxx.jpg --preset shadow
|
||||
python3 test/test_triangle_one_image.py --image /root/phot/xxx.jpg --max-interior-gray 160 --min-dark-ratio 0.20
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
import config
|
||||
import triangle_target as tri_mod
|
||||
from triangle_target import (
|
||||
detect_triangle_markers,
|
||||
load_camera_from_xml,
|
||||
load_triangle_positions,
|
||||
try_triangle_scoring,
|
||||
)
|
||||
|
||||
|
||||
def _apply_overrides(args) -> None:
|
||||
# 预设:阴影/低对比度场景更宽松(尽量保持速度:不启 CLAHE)
|
||||
if args.preset == "shadow":
|
||||
setattr(config, "TRIANGLE_ENABLE_CLAHE_FALLBACK", False)
|
||||
setattr(config, "TRIANGLE_MIN_CONTRAST_DIFF", 0)
|
||||
setattr(config, "TRIANGLE_MAX_INTERIOR_GRAY", 160)
|
||||
setattr(config, "TRIANGLE_DARK_PIXEL_GRAY", 160)
|
||||
setattr(config, "TRIANGLE_MIN_DARK_RATIO", 0.20)
|
||||
# adaptive 只在 Otsu 失败时尝试,保持尝试次数很少
|
||||
setattr(config, "TRIANGLE_ADAPTIVE_BLOCK_SIZES", (21,))
|
||||
|
||||
# 手动覆盖(优先级高于 preset)
|
||||
if args.max_interior_gray is not None:
|
||||
setattr(config, "TRIANGLE_MAX_INTERIOR_GRAY", int(args.max_interior_gray))
|
||||
if args.dark_pixel_gray is not None:
|
||||
setattr(config, "TRIANGLE_DARK_PIXEL_GRAY", int(args.dark_pixel_gray))
|
||||
if args.min_dark_ratio is not None:
|
||||
setattr(config, "TRIANGLE_MIN_DARK_RATIO", float(args.min_dark_ratio))
|
||||
if args.min_contrast_diff is not None:
|
||||
setattr(config, "TRIANGLE_MIN_CONTRAST_DIFF", int(args.min_contrast_diff))
|
||||
if args.detect_scale is not None:
|
||||
setattr(config, "TRIANGLE_DETECT_SCALE", float(args.detect_scale))
|
||||
if args.adaptive_blocks is not None:
|
||||
bs = tuple(int(x) for x in args.adaptive_blocks.split(",") if x.strip())
|
||||
setattr(config, "TRIANGLE_ADAPTIVE_BLOCK_SIZES", bs)
|
||||
|
||||
|
||||
def _dump_config() -> Dict[str, Any]:
|
||||
keys = [
|
||||
"TRIANGLE_DETECT_SCALE",
|
||||
"TRIANGLE_SIZE_RANGE",
|
||||
"TRIANGLE_MAX_INTERIOR_GRAY",
|
||||
"TRIANGLE_DARK_PIXEL_GRAY",
|
||||
"TRIANGLE_MIN_DARK_RATIO",
|
||||
"TRIANGLE_MIN_CONTRAST_DIFF",
|
||||
"TRIANGLE_ADAPTIVE_BLOCK_SIZES",
|
||||
"TRIANGLE_MAX_FILTERED_FOR_COMBO",
|
||||
"TRIANGLE_EARLY_EXIT_CANDIDATES",
|
||||
"TRIANGLE_ENABLE_CLAHE_FALLBACK",
|
||||
]
|
||||
out = {}
|
||||
for k in keys:
|
||||
out[k] = getattr(config, k, None)
|
||||
return out
|
||||
|
||||
|
||||
def _draw_tri_debug(img_bgr: np.ndarray, tri: Dict[str, Any]) -> np.ndarray:
|
||||
out = img_bgr.copy()
|
||||
markers = tri.get("markers") or []
|
||||
|
||||
# 画三角形轮廓 + center + id
|
||||
for m in markers:
|
||||
corners = np.array(m.get("corners", []), dtype=np.int32)
|
||||
if corners.size == 0:
|
||||
continue
|
||||
cv2.polylines(out, [corners], True, (0, 255, 0), 2)
|
||||
c = m.get("center") or (corners[:, 0].mean(), corners[:, 1].mean())
|
||||
cx, cy = int(c[0]), int(c[1])
|
||||
cv2.circle(out, (cx, cy), 4, (0, 0, 255), -1)
|
||||
mid = m.get("id", "?")
|
||||
cv2.putText(out, f"T{mid}", (cx - 18, cy - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 255, 0), 1)
|
||||
|
||||
# 若有 homography,画靶心(把 (0,0) 反投影到图像)
|
||||
H = tri.get("homography")
|
||||
if H is not None:
|
||||
try:
|
||||
H = np.array(H, dtype=np.float64)
|
||||
H_inv = np.linalg.inv(H)
|
||||
c_img = cv2.perspectiveTransform(np.array([[[0.0, 0.0]]], dtype=np.float32), H_inv)[0][0]
|
||||
ocx, ocy = int(c_img[0]), int(c_img[1])
|
||||
cv2.circle(out, (ocx, ocy), 5, (0, 0, 255), -1)
|
||||
cv2.circle(out, (ocx, ocy), 10, (0, 0, 255), 1)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 叠加结果信息
|
||||
lines = []
|
||||
if tri.get("ok"):
|
||||
lines.append("tri_ok=True")
|
||||
if tri.get("dx_cm") is not None and tri.get("dy_cm") is not None:
|
||||
lines.append(f"dx,dy=({tri['dx_cm']:.2f},{tri['dy_cm']:.2f})cm")
|
||||
if tri.get("distance_m") is not None:
|
||||
lines.append(f"dist={float(tri['distance_m']):.2f}m")
|
||||
else:
|
||||
lines.append("tri_ok=False")
|
||||
|
||||
y0 = 22
|
||||
for i, t in enumerate(lines):
|
||||
cv2.putText(out, t, (10, y0 + i * 18), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 255, 0), 1)
|
||||
return out
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--image", required=True, help="输入图片路径(jpg/png)")
|
||||
ap.add_argument("--out", default="", help="输出标注图片路径(可选)")
|
||||
ap.add_argument("--laser-x", type=int, default=-1, help="激光点 x(像素),默认用图像中心")
|
||||
ap.add_argument("--laser-y", type=int, default=-1, help="激光点 y(像素),默认用图像中心")
|
||||
ap.add_argument("--preset", choices=["", "shadow"], default="", help="调参预设(shadow=阴影更鲁棒,不启 CLAHE)")
|
||||
ap.add_argument("--max-interior-gray", type=int, default=None)
|
||||
ap.add_argument("--dark-pixel-gray", type=int, default=None)
|
||||
ap.add_argument("--min-dark-ratio", type=float, default=None)
|
||||
ap.add_argument("--min-contrast-diff", type=int, default=None)
|
||||
ap.add_argument("--detect-scale", type=float, default=None)
|
||||
ap.add_argument("--adaptive-blocks", default=None, help="例如: 11,21 ;为空表示不改")
|
||||
ap.add_argument("--verbose", action="store_true", help="输出更多检测阶段信息")
|
||||
args = ap.parse_args()
|
||||
|
||||
_apply_overrides(args)
|
||||
# triangle_target.py 的日志默认写到 logger_manager;在离线脚本里 logger 可能未初始化。
|
||||
# verbose 模式下把 _log 重定向为 print,方便直接看到诊断信息。
|
||||
if args.verbose:
|
||||
try:
|
||||
tri_mod._log = lambda msg: print(str(msg))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
img_bgr = cv2.imread(args.image, cv2.IMREAD_COLOR)
|
||||
if img_bgr is None:
|
||||
raise SystemExit(f"读图失败:{args.image}")
|
||||
# triangle_target.try_triangle_scoring 约定输入为 RGB;OpenCV imread 返回 BGR
|
||||
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
||||
|
||||
h, w = img_bgr.shape[:2]
|
||||
if args.laser_x >= 0 and args.laser_y >= 0:
|
||||
laser_point = (int(args.laser_x), int(args.laser_y))
|
||||
else:
|
||||
laser_point = (w // 2, h // 2)
|
||||
|
||||
K, dist = load_camera_from_xml(getattr(config, "CAMERA_CALIB_XML", ""))
|
||||
pos = load_triangle_positions(getattr(config, "TRIANGLE_POSITIONS_JSON", ""))
|
||||
|
||||
print("[tri-test] image:", args.image, "shape:", (h, w))
|
||||
print("[tri-test] laser_point:", laser_point)
|
||||
print("[tri-test] calib_ok:", bool(K is not None and dist is not None), "pos_ok:", bool(pos))
|
||||
print("[tri-test] config:", json.dumps(_dump_config(), ensure_ascii=False))
|
||||
|
||||
# 先单独跑一次三角形候选检测,便于区分“没找到候选” vs “找到候选但评分/单应性失败”
|
||||
scale = float(getattr(config, "TRIANGLE_DETECT_SCALE", 0.5) or 0.5)
|
||||
if not (0.05 <= scale <= 1.0):
|
||||
scale = 0.5
|
||||
long_side = max(h, w)
|
||||
max_dim = max(64, int(long_side * scale))
|
||||
if long_side > max_dim:
|
||||
det_scale = max_dim / long_side
|
||||
det_w = int(w * det_scale)
|
||||
det_h = int(h * det_scale)
|
||||
img_det = cv2.resize(img_bgr, (det_w, det_h), interpolation=cv2.INTER_LINEAR)
|
||||
inv_scale = 1.0 / det_scale
|
||||
size_range_det = (
|
||||
max(4, int(getattr(config, "TRIANGLE_SIZE_RANGE", (8, 500))[0] * det_scale)),
|
||||
max(8, int(getattr(config, "TRIANGLE_SIZE_RANGE", (8, 500))[1] * det_scale)),
|
||||
)
|
||||
else:
|
||||
img_det = img_bgr
|
||||
inv_scale = 1.0
|
||||
size_range_det = getattr(config, "TRIANGLE_SIZE_RANGE", (8, 500))
|
||||
|
||||
gray = cv2.cvtColor(img_det, cv2.COLOR_BGR2GRAY)
|
||||
markers_det = detect_triangle_markers(
|
||||
gray,
|
||||
orig_gray=gray,
|
||||
size_range=size_range_det,
|
||||
verbose=bool(args.verbose),
|
||||
)
|
||||
if inv_scale != 1.0 and markers_det:
|
||||
for m in markers_det:
|
||||
m["center"] = [m["center"][0] * inv_scale, m["center"][1] * inv_scale]
|
||||
m["corners"] = [[c[0] * inv_scale, c[1] * inv_scale] for c in m["corners"]]
|
||||
|
||||
print("[tri-test] markers_found:", len(markers_det), "ids:", [m.get("id") for m in markers_det])
|
||||
|
||||
t0 = time.time()
|
||||
tri = try_triangle_scoring(
|
||||
img_rgb, # try_triangle_scoring 期望 RGB
|
||||
laser_point,
|
||||
pos,
|
||||
K,
|
||||
dist,
|
||||
size_range=getattr(config, "TRIANGLE_SIZE_RANGE", (8, 500)),
|
||||
)
|
||||
dt_ms = int(round((time.time() - t0) * 1000))
|
||||
|
||||
print("[tri-test] elapsed_ms:", dt_ms)
|
||||
print(json.dumps(tri, ensure_ascii=False, indent=2))
|
||||
|
||||
if args.out:
|
||||
out_path = args.out
|
||||
# 允许传目录(如 ./),自动生成文件名;未带扩展名时默认 .jpg
|
||||
if out_path.endswith("/") or out_path.endswith("\\") or os.path.isdir(out_path):
|
||||
out_path = os.path.join(out_path, "tri_out.jpg")
|
||||
root, ext = os.path.splitext(out_path)
|
||||
if not ext:
|
||||
out_path = root + ".jpg"
|
||||
|
||||
# 若 try_triangle_scoring 失败且没带回 markers,至少把候选 markers 画出来,方便肉眼判断
|
||||
tri_for_draw = tri if isinstance(tri, dict) else {"ok": False}
|
||||
if not tri_for_draw.get("markers") and markers_det:
|
||||
tri_for_draw = dict(tri_for_draw)
|
||||
tri_for_draw["markers"] = markers_det
|
||||
out_img = _draw_tri_debug(img_bgr, tri_for_draw)
|
||||
ok = cv2.imwrite(out_path, out_img)
|
||||
if not ok:
|
||||
raise SystemExit(f"写图失败(可能是不支持的扩展名):{out_path}")
|
||||
print("[tri-test] wrote:", out_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
257
train_yolo/test_yolo_draw_boxes.py
Normal file
257
train_yolo/test_yolo_draw_boxes.py
Normal file
@@ -0,0 +1,257 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
本地图片 → Maix YOLOv5 检测 → 画框保存(用于核对坐标 mode / 多框 union)。
|
||||
|
||||
运行环境:MaixCAM / MaixPy(需 maix.image / maix.nn),在项目根或任意目录执行均可。
|
||||
|
||||
示例:
|
||||
python test/test_yolo_draw_boxes.py /root/phot/shot_xxx.jpg
|
||||
python test/test_yolo_draw_boxes.py shot.jpg --loader cv2_rgb --conf 0.25
|
||||
python test/test_yolo_draw_boxes.py shot.jpg --debug
|
||||
python -h # 查看 --loader / --debug / --union 等全部参数
|
||||
|
||||
脚本版本(与设备同步用):20260206-yolo-vis
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if _ROOT not in sys.path:
|
||||
sys.path.insert(0, _ROOT)
|
||||
|
||||
|
||||
def _load_maix_image(path: str, image_mod):
|
||||
"""maix.image.load(部分 JPEG 解码后与 camera.read() 像素布局不一致,可能导致 NPU 全空)。"""
|
||||
return image_mod.load(path)
|
||||
|
||||
|
||||
def _load_cv2_rgb_as_maix(path: str, image_mod):
|
||||
"""
|
||||
OpenCV 读盘为 BGR → 转 RGB → 与 shoot_manager 里 image2cv 逆过程一致,供 YOLO input type: rgb。
|
||||
"""
|
||||
import cv2
|
||||
|
||||
arr = cv2.imread(path, cv2.IMREAD_COLOR)
|
||||
if arr is None:
|
||||
raise FileNotFoundError(f"cv2.imread 失败: {path}")
|
||||
arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
|
||||
return image_mod.cv2image(arr, False, False)
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(
|
||||
description="YOLO 画框测试(Maix)",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="若提示 unrecognized arguments: --debug,说明设备上脚本未更新,请同步仓库中的 test/test_yolo_draw_boxes.py",
|
||||
)
|
||||
ap.add_argument("image", help="输入图片路径")
|
||||
ap.add_argument("-o", "--output", default="", help="输出图片路径;默认 原名_yolo_vis.jpg")
|
||||
ap.add_argument("-m", "--model", default="", help="覆盖 config.TRIANGLE_YOLO_MODEL_PATH")
|
||||
ap.add_argument("--conf", type=float, default=None, help="置信度阈值")
|
||||
ap.add_argument("--iou", type=float, default=None, help="NMS IoU")
|
||||
ap.add_argument(
|
||||
"--coord",
|
||||
choices=["native", "letterbox"],
|
||||
default="",
|
||||
help="坐标映射;默认读 config.TRIANGLE_YOLO_COORD_MODE",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--union",
|
||||
action="store_true",
|
||||
help="按 TRIANGLE_YOLO_RING_CLASS_IDS 过滤后画合并外接矩形(与线上 ROI merge=union 一致)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--loader",
|
||||
choices=["auto", "maix", "cv2_rgb"],
|
||||
default="auto",
|
||||
help="auto: 先 maix.load,0 框则改用 cv2 RGB(推荐排查「有图但始终 0 框」)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="打印 detect 原始返回类型与 repr(截断)",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
try:
|
||||
from maix import image, nn
|
||||
except ImportError:
|
||||
print("[ERR] 需要 MaixPy(maix.image / maix.nn),请在 MaixCAM 上运行。")
|
||||
sys.exit(1)
|
||||
|
||||
import config as cfg
|
||||
import target_roi_yolo as yroi
|
||||
|
||||
img_path = os.path.abspath(args.image)
|
||||
if not os.path.isfile(img_path):
|
||||
print(f"[ERR] 找不到图片: {img_path}")
|
||||
sys.exit(1)
|
||||
|
||||
model_path = (args.model or getattr(cfg, "TRIANGLE_YOLO_MODEL_PATH", "") or "").strip()
|
||||
if not os.path.isfile(model_path):
|
||||
print(f"[ERR] 模型文件不存在: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
conf_th = (
|
||||
float(args.conf)
|
||||
if args.conf is not None
|
||||
else float(getattr(cfg, "TRIANGLE_YOLO_CONF_TH", 0.5))
|
||||
)
|
||||
iou_th = (
|
||||
float(args.iou)
|
||||
if args.iou is not None
|
||||
else float(getattr(cfg, "TRIANGLE_YOLO_IOU_TH", 0.45))
|
||||
)
|
||||
coord_mode = (args.coord or getattr(cfg, "TRIANGLE_YOLO_COORD_MODE", "native")).lower()
|
||||
|
||||
out_path = args.output.strip()
|
||||
if not out_path:
|
||||
base, ext = os.path.splitext(img_path)
|
||||
ext = ext if ext else ".jpg"
|
||||
out_path = base + "_yolo_vis" + ext
|
||||
|
||||
det = nn.YOLOv5(model=model_path, dual_buff=False)
|
||||
net_w = int(det.input_width())
|
||||
net_h = int(det.input_height())
|
||||
|
||||
def _run_detect(maix_img, tag: str):
|
||||
r = det.detect(maix_img, conf_th=conf_th, iou_th=iou_th)
|
||||
if args.debug:
|
||||
rlen = len(r) if r is not None and hasattr(r, "__len__") else "n/a"
|
||||
rrepr = repr(r)
|
||||
if len(rrepr) > 300:
|
||||
rrepr = rrepr[:300] + "..."
|
||||
print(f"[DEBUG] loader={tag} raw_type={type(r)} len={rlen} repr={rrepr}")
|
||||
return yroi._normalize_objs(r if r is not None else []), maix_img, tag
|
||||
|
||||
img = None
|
||||
load_tag = ""
|
||||
objs = []
|
||||
|
||||
if args.loader == "cv2_rgb":
|
||||
img = _load_cv2_rgb_as_maix(img_path, image)
|
||||
load_tag = "cv2_rgb"
|
||||
objs, img, load_tag = _run_detect(img, load_tag)
|
||||
elif args.loader == "maix":
|
||||
img = _load_maix_image(img_path, image)
|
||||
load_tag = "maix_load"
|
||||
objs, img, load_tag = _run_detect(img, load_tag)
|
||||
else:
|
||||
# auto
|
||||
img = _load_maix_image(img_path, image)
|
||||
load_tag = "maix_load"
|
||||
objs, img, load_tag = _run_detect(img, load_tag)
|
||||
if len(objs) == 0:
|
||||
print(
|
||||
"[WARN] maix.image.load 在 conf_th=%s 下仍为 0 框,改用 cv2 BGR→RGB→cv2image 重试(常见可恢复)"
|
||||
% conf_th
|
||||
)
|
||||
img2 = _load_cv2_rgb_as_maix(img_path, image)
|
||||
objs, img, load_tag = _run_detect(img2, "cv2_rgb_retry")
|
||||
|
||||
src_w, src_h = img.width(), img.height()
|
||||
|
||||
labels = getattr(det, "labels", None)
|
||||
|
||||
def _label(cid: int) -> str:
|
||||
if labels is None:
|
||||
return str(cid)
|
||||
try:
|
||||
return str(labels[int(cid)])
|
||||
except Exception:
|
||||
return str(cid)
|
||||
|
||||
print(
|
||||
f"[INFO] loader={load_tag} image={src_w}×{src_h}, net_in={net_w}×{net_h}, "
|
||||
f"coord={coord_mode}, conf_th={conf_th}, iou_th={iou_th}"
|
||||
)
|
||||
print(f"[INFO] NMS 后检测框数量={len(objs)} → {out_path}")
|
||||
if len(objs) == 0:
|
||||
print(
|
||||
"[HINT] 仍为 0 框时常见原因:\n"
|
||||
" 1) 强制 cv2 路径: --loader cv2_rgb\n"
|
||||
" 2) NMS 过严: --iou 0.95\n"
|
||||
" 3) 图与训练分布差太大 / 模型未见过该场景\n"
|
||||
" 4) 用 camera.read() 一帧存盘再测,对比 file 与实时是否一致"
|
||||
)
|
||||
|
||||
# 颜色:按类别轮换(仅有 COLOR_* 时常量时用)
|
||||
color_cycle = []
|
||||
for name in ("RED", "GREEN", "BLUE", "ORANGE", "YELLOW", "CYAN", "MAGENTA"):
|
||||
c = getattr(image, f"COLOR_{name}", None)
|
||||
if c is not None:
|
||||
color_cycle.append(c)
|
||||
if not color_cycle:
|
||||
color_cycle = [getattr(image, "COLOR_RED", 0)]
|
||||
|
||||
for i, o in enumerate(objs):
|
||||
cid = yroi._det_obj_class_id(o)
|
||||
if cid is None:
|
||||
cid = -1
|
||||
try:
|
||||
sc = float(o.score)
|
||||
except Exception:
|
||||
sc = 0.0
|
||||
x0, y0, x1, y1 = yroi._det_to_src_xyxy(o, coord_mode, src_w, src_h, net_w, net_h)
|
||||
ix = int(max(0, min(x0, src_w - 1)))
|
||||
iy = int(max(0, min(y0, src_h - 1)))
|
||||
iw = int(max(1, min(x1 - x0, src_w - ix)))
|
||||
ih = int(max(1, min(y1 - y0, src_h - iy)))
|
||||
col = color_cycle[cid % len(color_cycle)] if cid >= 0 else color_cycle[0]
|
||||
img.draw_rect(ix, iy, iw, ih, color=col)
|
||||
ty = max(0, iy - 14)
|
||||
msg = f"{_label(cid)} {sc:.2f}"
|
||||
img.draw_string(ix, ty, msg, color=col)
|
||||
print(f" #{i} cls={cid} {_label(cid)} score={sc:.3f} xywh=({ix},{iy},{iw},{ih})")
|
||||
|
||||
if args.union:
|
||||
class_ids = getattr(cfg, "TRIANGLE_YOLO_RING_CLASS_IDS", (0,))
|
||||
if isinstance(class_ids, int):
|
||||
class_ids = (class_ids,)
|
||||
cand = [o for o in objs if yroi._det_obj_class_id(o) in class_ids]
|
||||
if cand:
|
||||
xy_list = [
|
||||
yroi._det_to_src_xyxy(o, coord_mode, src_w, src_h, net_w, net_h) for o in cand
|
||||
]
|
||||
merged = yroi._merge_roi_xyxy(xy_list, "union")
|
||||
if merged:
|
||||
mx0, my0, mx1, my1 = merged
|
||||
mx0 = max(0, min(mx0, src_w - 1))
|
||||
my0 = max(0, min(my0, src_h - 1))
|
||||
mx1 = max(mx0 + 1, min(mx1, src_w))
|
||||
my1 = max(my0 + 1, min(my1, src_h))
|
||||
uw, uh = int(mx1 - mx0), int(my1 - my0)
|
||||
ucol = getattr(image, "COLOR_GREEN", color_cycle[0])
|
||||
# 画粗一点的 union:描两遍错位矩形简易模拟加粗
|
||||
for d in (0, 2):
|
||||
img.draw_rect(
|
||||
int(mx0) - d,
|
||||
int(my0) - d,
|
||||
uw + 2 * d,
|
||||
uh + 2 * d,
|
||||
color=ucol,
|
||||
)
|
||||
img.draw_string(
|
||||
int(mx0),
|
||||
max(0, int(my0) - 28),
|
||||
f"UNION ({len(cand)} boxes)",
|
||||
color=ucol,
|
||||
)
|
||||
print(f"[INFO] UNION [{int(mx0)},{int(my0)},{int(mx1)},{int(my1)}] from {len(cand)} boxes")
|
||||
else:
|
||||
print("[WARN] --union 但 RING_CLASS_IDS 过滤后无框")
|
||||
|
||||
try:
|
||||
img.save(out_path, quality=95)
|
||||
except TypeError:
|
||||
img.save(out_path)
|
||||
print(f"[OK] saved: {out_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
506
train_yolo/train_black_triangle_pos.py
Normal file
506
train_yolo/train_black_triangle_pos.py
Normal file
@@ -0,0 +1,506 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
YOLO11 关键点检测训练脚本(靶纸四角)。
|
||||
|
||||
设备优先级(--device auto):Intel XPU > NVIDIA CUDA > CPU。
|
||||
默认 imgsz=960;批大小默认 4(大图显存紧张时可再降)。
|
||||
|
||||
关于「业务像素误差」:
|
||||
Ultralytics 没有在 yaml 里设定「像素阈值」的选项;反向传播仍由 pose/kobj/box 等内部 loss 驱动。
|
||||
- 监控:--pixel-metrics-every N(每 N 个 epoch 打印 mean/p95,并合并进 runs/.../results.csv,见 pose_pixel_metrics.py)。
|
||||
- 选 best.pt / early stopping:加 --best-by-pixel,用验证集 mean 像素误差(与 pose_pixel_metrics
|
||||
同一口径)代替 mAP 合成 fitness(fitness = -mean_px,越小越好)。
|
||||
多卡 DDP(world_size>1)时会自动退回默认 mAP fitness。
|
||||
|
||||
XPU:Ultralytics BaseTrainer._get_memory / _clear_memory 把非 MPS、非 CPU 一律当 CUDA,
|
||||
会在验证前调用 torch.cuda 而报错;本脚本在选用 XPU 时自动打补丁(见 _patch_ultralytics_trainer_for_xpu)。
|
||||
|
||||
务必使用 pose 任务:YOLO(...) 与 model.train(...) 均指定 task='pose'。若误用默认 detect,
|
||||
会把 17 列 Pose 标注当成检测/分割解析,校验时出现「coordinates > 1」或 [2.] 等假象。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import gc
|
||||
import glob
|
||||
import os
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from ultralytics import YOLO
|
||||
|
||||
from pose_pixel_metrics import eval_val_pixel_error
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore',
|
||||
message=".*scatter_add_kernel does not have a deterministic implementation.*")
|
||||
|
||||
|
||||
def _clear_ultralytics_label_caches(data_yaml_path: str) -> int:
|
||||
"""删除 data.yaml 的 path 下 labels/*.cache。
|
||||
|
||||
Ultralytics 的校验缓存 hash 仅依赖「标签/图片路径字符串 + 各文件 size 之和」,不含文件内容;
|
||||
修正 *.txt 后若总和巧合不变,可能继续加载旧 cache 并重播旧的 corrupt 日志,训练前应删掉。"""
|
||||
from ultralytics.utils import YAML
|
||||
|
||||
try:
|
||||
cfg = YAML.load(data_yaml_path)
|
||||
except Exception:
|
||||
return 0
|
||||
root = cfg.get("path")
|
||||
if not root:
|
||||
return 0
|
||||
root = os.path.abspath(os.path.expanduser(str(root)))
|
||||
pattern = os.path.join(root, "labels", "*.cache")
|
||||
n = 0
|
||||
for p in glob.glob(pattern):
|
||||
try:
|
||||
os.unlink(p)
|
||||
n += 1
|
||||
except OSError:
|
||||
pass
|
||||
return n
|
||||
|
||||
|
||||
def _pick_device(explicit: str | None):
|
||||
"""返回 ultralytics train/predict 可用的 device。"""
|
||||
if explicit and explicit != "auto":
|
||||
e = explicit.lower()
|
||||
if e == "xpu":
|
||||
if getattr(torch, "xpu", None) is None or not torch.xpu.is_available():
|
||||
raise RuntimeError("指定了 --device xpu 但当前环境不可用")
|
||||
return torch.device("xpu")
|
||||
if e in ("0", "cuda", "gpu"):
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("指定了 CUDA 但不可用")
|
||||
return 0
|
||||
if e == "cpu":
|
||||
return "cpu"
|
||||
return explicit
|
||||
if getattr(torch, "xpu", None) is not None and torch.xpu.is_available():
|
||||
return torch.device("xpu")
|
||||
if torch.cuda.is_available():
|
||||
return 0
|
||||
return "cpu"
|
||||
|
||||
|
||||
def _default_amp(device) -> bool:
|
||||
if isinstance(device, torch.device) and device.type == "xpu":
|
||||
return False
|
||||
if device == "cpu":
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _patch_ultralytics_for_xpu():
|
||||
"""为 Ultralytics 打补丁,使其能在 XPU 环境下正常训练和验证。"""
|
||||
import ultralytics.engine.trainer as ut_trainer
|
||||
import ultralytics.engine.validator as ut_validator
|
||||
from ultralytics.utils.torch_utils import select_device as _original_select_device
|
||||
|
||||
# 1. 覆盖 select_device:Trainer 初始化传入 torch.device("xpu") 会走原版早返回;
|
||||
# 初始化后 args.device 会变成字符串 "xpu",中期 val 用 trainer.device,不调用 select_device;
|
||||
# 训练结束 final_eval 里 Validator 会 select_device("xpu"),且 validator 在 import 时已绑定原函数,
|
||||
# 只改 torch_utils 无效,必须同时修补 trainer/validator 模块内的引用。
|
||||
def _patched_select_device(device="", *args, **kwargs):
|
||||
# Ultralytics 8.4.x: select_device(device="", newline=False, verbose=True)
|
||||
# Older forks sometimes passed extra positional args; forward everything.
|
||||
if isinstance(device, str):
|
||||
d = device.strip().lower()
|
||||
if d == "xpu" or d.startswith("xpu:"):
|
||||
return torch.device(device.strip())
|
||||
return _original_select_device(device, *args, **kwargs)
|
||||
|
||||
import ultralytics.utils.torch_utils
|
||||
|
||||
ultralytics.utils.torch_utils.select_device = _patched_select_device
|
||||
ut_trainer.select_device = _patched_select_device
|
||||
ut_validator.select_device = _patched_select_device
|
||||
|
||||
# 2. 修补 Trainer 的内存函数
|
||||
BT = ut_trainer.BaseTrainer
|
||||
if not getattr(BT, "_archery_xpu_memory_patched", False):
|
||||
_orig_get_memory = BT._get_memory
|
||||
_orig_clear_memory = BT._clear_memory
|
||||
|
||||
def _get_memory(self, fraction=False):
|
||||
if self.device.type != "xpu":
|
||||
return _orig_get_memory(self, fraction)
|
||||
# ... (原有的 XPU 内存获取逻辑保持不变) ...
|
||||
memory, total = 0, 0
|
||||
try:
|
||||
idx = self.device.index
|
||||
if idx is None:
|
||||
idx = torch.xpu.current_device()
|
||||
memory = int(torch.xpu.memory_allocated(idx))
|
||||
if fraction:
|
||||
total = int(torch.xpu.get_device_properties(idx).total_memory)
|
||||
except Exception:
|
||||
pass
|
||||
return (memory / total) if fraction and total > 0 else (memory / 2**30)
|
||||
|
||||
def _clear_memory(self, threshold=None):
|
||||
if self.device.type != "xpu":
|
||||
return _orig_clear_memory(self, threshold)
|
||||
if threshold is not None:
|
||||
assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
|
||||
if self._get_memory(fraction=True) <= threshold:
|
||||
return
|
||||
gc.collect()
|
||||
if hasattr(torch.xpu, "empty_cache"):
|
||||
torch.xpu.empty_cache()
|
||||
|
||||
BT._get_memory = _get_memory
|
||||
BT._clear_memory = _clear_memory
|
||||
BT._archery_xpu_memory_patched = True
|
||||
|
||||
# 3. 修补 Validator 的内存函数 (关键是添加这部分)
|
||||
BV = ut_validator.BaseValidator
|
||||
if not getattr(BV, "_archery_xpu_memory_patched", False):
|
||||
# 为 Validator 添加同样的内存处理方法
|
||||
BV._get_memory = _get_memory
|
||||
BV._clear_memory = _clear_memory
|
||||
BV._archery_xpu_memory_patched = True
|
||||
|
||||
|
||||
def _install_best_by_pixel_validate(data_yaml: str, imgsz: int, conf: float) -> None:
|
||||
"""用验证集关键点像素 mean 替代 mAP fitness,驱动 best.pt 与 patience early stopping。"""
|
||||
import ultralytics.engine.trainer as ut
|
||||
from ultralytics.utils import RANK
|
||||
|
||||
BT = ut.BaseTrainer
|
||||
if getattr(BT, "_archery_best_by_pixel_installed", False):
|
||||
return
|
||||
|
||||
_orig_validate = BT.validate
|
||||
|
||||
def validate(self):
|
||||
import torch.distributed as dist
|
||||
|
||||
if self.ema and self.world_size > 1:
|
||||
for buffer in self.ema.ema.buffers():
|
||||
dist.broadcast(buffer, src=0)
|
||||
metrics = self.validator(self)
|
||||
if metrics is None:
|
||||
return None, None
|
||||
orig_fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())
|
||||
|
||||
use_pixel = self.world_size <= 1 and RANK in {-1, 0}
|
||||
mean_px: float | None = None
|
||||
if use_pixel:
|
||||
tmp_path: str | None = None
|
||||
try:
|
||||
fd, tmp_path = tempfile.mkstemp(suffix=".pt", prefix="archery_pxfit_")
|
||||
os.close(fd)
|
||||
from ultralytics.utils.torch_utils import unwrap_model
|
||||
|
||||
core = unwrap_model(self.ema.ema if self.ema else self.model)
|
||||
torch.save({"ema": deepcopy(core).half(), "train_args": vars(self.args)}, tmp_path)
|
||||
probe = YOLO(tmp_path)
|
||||
stats = eval_val_pixel_error(
|
||||
probe,
|
||||
data_yaml,
|
||||
device=self.device,
|
||||
imgsz=imgsz,
|
||||
conf=conf,
|
||||
)
|
||||
mean_px = stats.get("mean_px")
|
||||
if mean_px is None:
|
||||
raise RuntimeError("无有效 mean_px(检查 val 标签与检测是否为空)")
|
||||
except Exception as exc:
|
||||
print(f"\n⚠️ [best-by-pixel] 像素探针失败,本 epoch 仍用 mAP fitness: {exc}\n")
|
||||
mean_px = None
|
||||
finally:
|
||||
if tmp_path:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
if mean_px is not None:
|
||||
fitness = -float(mean_px)
|
||||
metrics["metrics/mean_px(val)"] = float(mean_px)
|
||||
else:
|
||||
fitness = float(orig_fitness)
|
||||
|
||||
if not self.best_fitness or self.best_fitness < fitness:
|
||||
self.best_fitness = fitness
|
||||
return metrics, fitness
|
||||
|
||||
BT.validate = validate
|
||||
BT._archery_best_by_pixel_installed = True
|
||||
|
||||
|
||||
def _fmt_csv_metric(v: float | int | None) -> str:
|
||||
if v is None:
|
||||
return ""
|
||||
if isinstance(v, float):
|
||||
return f"{v:.6g}"
|
||||
return str(v)
|
||||
|
||||
|
||||
# 写入 results.csv 的列名(与 --best-by-pixel 的 metrics/mean_px(val) 区分,避免被 last.pt 回调覆盖 EMA 行)
|
||||
_PIXEL_METRIC_COLUMNS: tuple[tuple[str, str], ...] = (
|
||||
("pixel_error/mean_px", "mean_px"),
|
||||
("pixel_error/median_px", "median_px"),
|
||||
("pixel_error/p95_px", "p95_px"),
|
||||
("pixel_error/max_px", "max_px"),
|
||||
("pixel_error/n_points", "n_points"),
|
||||
("pixel_error/n_images", "n_images"),
|
||||
("pixel_error/skip_no_det", "skip_no_det"),
|
||||
("pixel_error/skip_no_gt", "skip_no_gt"),
|
||||
("pixel_error/skip_kpt_mismatch", "skip_kpt_mismatch"),
|
||||
)
|
||||
|
||||
|
||||
def _merge_pixel_metrics_into_results_csv(save_dir: str | Path, epoch_1based: int, stats: dict) -> None:
|
||||
"""在 Ultralytics 写完本 epoch 行之后,把像素指标列合并进 results.csv(扩展表头、补空列)。"""
|
||||
csv_path = Path(save_dir) / "results.csv"
|
||||
if not csv_path.is_file():
|
||||
return
|
||||
try:
|
||||
with open(csv_path, newline="", encoding="utf-8") as f:
|
||||
rows = list(csv.reader(f))
|
||||
except OSError:
|
||||
return
|
||||
if len(rows) < 2:
|
||||
return
|
||||
header = list(rows[0])
|
||||
for col_name, _ in _PIXEL_METRIC_COLUMNS:
|
||||
if col_name not in header:
|
||||
header.append(col_name)
|
||||
for ri in range(1, len(rows)):
|
||||
rows[ri].append("")
|
||||
col_ix = {name: i for i, name in enumerate(header)}
|
||||
rows[0] = header
|
||||
target_ri: int | None = None
|
||||
for ri in range(1, len(rows)):
|
||||
row = rows[ri]
|
||||
while len(row) < len(header):
|
||||
row.append("")
|
||||
try:
|
||||
if int(float(row[0].strip())) == int(epoch_1based):
|
||||
target_ri = ri
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
if target_ri is None:
|
||||
return
|
||||
row = rows[target_ri]
|
||||
while len(row) < len(header):
|
||||
row.append("")
|
||||
for col_name, sk in _PIXEL_METRIC_COLUMNS:
|
||||
row[col_ix[col_name]] = _fmt_csv_metric(stats.get(sk))
|
||||
try:
|
||||
with open(csv_path, "w", newline="", encoding="utf-8") as f:
|
||||
w = csv.writer(f)
|
||||
w.writerows(rows)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _make_pixel_metrics_callback(data_yaml: str, every: int, imgsz: int, conf: float = 0.25):
|
||||
def on_fit_epoch_end(trainer):
|
||||
from ultralytics.utils import RANK
|
||||
|
||||
if RANK not in {-1, 0}:
|
||||
return
|
||||
if every <= 0:
|
||||
return
|
||||
ep = int(getattr(trainer, "epoch", -1))
|
||||
if (ep + 1) % every != 0:
|
||||
return
|
||||
w = Path(trainer.save_dir) / "weights" / "last.pt"
|
||||
if not w.is_file():
|
||||
return
|
||||
m = YOLO(str(w))
|
||||
stats = eval_val_pixel_error(
|
||||
m,
|
||||
data_yaml,
|
||||
device=trainer.device,
|
||||
imgsz=imgsz,
|
||||
conf=conf,
|
||||
)
|
||||
mean_px = stats.get("mean_px")
|
||||
p95_px = stats.get("p95_px")
|
||||
mean_s = f"{mean_px:.3f}" if mean_px is not None else "n/a"
|
||||
p95_s = f"{p95_px:.3f}" if p95_px is not None else "n/a"
|
||||
print(
|
||||
f"\n[pixel-metrics] epoch {ep + 1}: mean_px={mean_s} p95_px={p95_s} "
|
||||
f"n_points={stats.get('n_points', 0)} "
|
||||
f"skip(det/gt/k)={stats['skip_no_det']}/{stats['skip_no_gt']}/{stats['skip_kpt_mismatch']}\n"
|
||||
)
|
||||
_merge_pixel_metrics_into_results_csv(trainer.save_dir, ep + 1, stats)
|
||||
|
||||
return on_fit_epoch_end
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description="YOLO Pose 训练(XPU/CUDA/CPU)")
|
||||
ap.add_argument("--data", default="datasets/dataset_pose.yaml", help="data.yaml")
|
||||
ap.add_argument("--model", default="yolo11x-pose.pt", help="预训练权重")
|
||||
ap.add_argument("--epochs", type=int, default=100)
|
||||
ap.add_argument("--imgsz", type=int, default=960, help="训练输入边长(默认 960)")
|
||||
ap.add_argument("--batch", type=int, default=4, help="批大小;OOM 时减小")
|
||||
ap.add_argument(
|
||||
"--device",
|
||||
default="auto",
|
||||
help="auto | xpu | 0 | cuda | cpu(auto:XPU 优先)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--no-amp",
|
||||
action="store_true",
|
||||
help="关闭混合精度(默认:CUDA 开启,XPU/CPU 关闭)",
|
||||
)
|
||||
ap.add_argument("--project", default="runs/pose")
|
||||
ap.add_argument("--name", default="target_pose_train")
|
||||
ap.add_argument("--workers", type=int, default=4)
|
||||
ap.add_argument(
|
||||
"--pixel-metrics-every",
|
||||
type=int,
|
||||
default=0,
|
||||
help="每 N 个 epoch 在 val 上打印像素误差并写入 results.csv 对应 epoch 行(0=关闭);需 labels 与 data.yaml 布局一致",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--pixel-metrics-conf",
|
||||
type=float,
|
||||
default=0.25,
|
||||
help="--pixel-metrics-every 时 predict 置信度阈值(默认 0.25)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--best-by-pixel",
|
||||
action="store_true",
|
||||
help="best.pt 与 early stopping 按验证集 mean 像素误差(同 pose_pixel_metrics),fitness=-mean_px;单卡有效,DDP 自动退回 mAP",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--pixel-fitness-conf",
|
||||
type=float,
|
||||
default=0.25,
|
||||
help="--best-by-pixel 时 predict 置信度阈值(默认与 pixel-metrics 一致)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--export-onnx",
|
||||
action="store_true",
|
||||
help="训练结束后导出 ONNX(需再设 --onnx-imgsz)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--onnx-imgsz",
|
||||
type=int,
|
||||
nargs=2,
|
||||
metavar=("H", "W"),
|
||||
default=[224, 320],
|
||||
help="导出 ONNX 的 [高, 宽],默认 224 320(Maix 常用)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--clear-label-cache",
|
||||
action="store_true",
|
||||
help="启动训练前删除 data.yaml 中 path 下的 labels/*.cache(修正标注后仍报 corrupt 时用)",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
device = _pick_device(None if args.device == "auto" else args.device)
|
||||
use_amp = False if args.no_amp else _default_amp(device)
|
||||
|
||||
if isinstance(device, torch.device) and device.type == "xpu":
|
||||
print(f"✅ 使用 Intel XPU: {device}")
|
||||
elif device == 0 or device == "0":
|
||||
print(f"✅ 使用 CUDA: {torch.cuda.get_device_name(0)}")
|
||||
else:
|
||||
print("⚠️ 使用 CPU,训练会较慢")
|
||||
|
||||
if isinstance(device, torch.device) and device.type == "xpu":
|
||||
_patch_ultralytics_for_xpu()
|
||||
|
||||
data_yaml = args.data
|
||||
if not os.path.isabs(data_yaml):
|
||||
data_yaml = os.path.join(os.path.dirname(os.path.abspath(__file__)), data_yaml)
|
||||
if not os.path.exists(data_yaml):
|
||||
print(f"❌ 数据集配置不存在: {data_yaml}")
|
||||
return
|
||||
|
||||
if args.clear_label_cache:
|
||||
n_rm = _clear_ultralytics_label_caches(data_yaml)
|
||||
print(f"🗑️ 已删除标签目录缓存 {n_rm} 个(labels/*.cache),将强制重新扫描标注。")
|
||||
|
||||
print(f"📦 加载模型: {args.model}(固定 task=pose)")
|
||||
model = YOLO(args.model, task="pose")
|
||||
|
||||
if args.best_by_pixel:
|
||||
_install_best_by_pixel_validate(data_yaml, args.imgsz, args.pixel_fitness_conf)
|
||||
print(
|
||||
"📌 已启用 --best-by-pixel:best.pt / patience 按验证集 mean 像素误差(fitness=-mean_px);"
|
||||
"反向传播仍为 Ultralytics 默认 pose/box loss。"
|
||||
)
|
||||
|
||||
if args.pixel_metrics_every > 0:
|
||||
model.add_callback(
|
||||
"on_fit_epoch_end",
|
||||
_make_pixel_metrics_callback(
|
||||
data_yaml, args.pixel_metrics_every, args.imgsz, conf=args.pixel_metrics_conf
|
||||
),
|
||||
)
|
||||
|
||||
model.train(
|
||||
task="pose",
|
||||
data=data_yaml,
|
||||
epochs=args.epochs,
|
||||
imgsz=args.imgsz,
|
||||
batch=args.batch,
|
||||
name=args.name,
|
||||
project=args.project,
|
||||
exist_ok=True,
|
||||
save=True,
|
||||
save_period=5,
|
||||
device=device,
|
||||
workers=args.workers,
|
||||
lr0=0.0001,
|
||||
lrf=0.01,
|
||||
optimizer="AdamW",
|
||||
momentum=0.937,
|
||||
weight_decay=0.001,
|
||||
warmup_epochs=0,
|
||||
warmup_momentum=0.8,
|
||||
warmup_bias_lr=0.1,
|
||||
hsv_h=0.015,
|
||||
hsv_s=0.7,
|
||||
hsv_v=0.4,
|
||||
degrees=5.0,
|
||||
translate=0.0,
|
||||
scale=0.2,
|
||||
shear=0.0,
|
||||
perspective=0.0000,
|
||||
flipud=0.0,
|
||||
fliplr=0.5,
|
||||
mosaic=0.0,
|
||||
mixup=0.0,
|
||||
copy_paste=0.0,
|
||||
box=6,
|
||||
cls=0.5,
|
||||
dfl=1.5,
|
||||
pose=18.0,
|
||||
kobj=0.5,
|
||||
freeze=0,
|
||||
seed=42,
|
||||
verbose=True,
|
||||
amp=use_amp,
|
||||
patience=100,
|
||||
cos_lr=True,
|
||||
)
|
||||
|
||||
print("\n✅ 训练完成!")
|
||||
print(f"📁 best: {args.project}/{args.name}/weights/best.pt")
|
||||
print(f"📁 last: {args.project}/{args.name}/weights/last.pt")
|
||||
print("📊 仅看像素误差可运行: python pose_pixel_metrics.py --model <best.pt> --data <yaml> --imgsz", args.imgsz)
|
||||
|
||||
if args.export_onnx:
|
||||
h, w = args.onnx_imgsz
|
||||
print(f"📦 导出 ONNX imgsz=[{h}, {w}] ...")
|
||||
model.export(format="onnx", imgsz=[h, w], simplify=True, opset=17, dynamic=False)
|
||||
print("✅ ONNX 完成")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user