增加训练yolo的代码

This commit is contained in:
gcw_4spBpAfv
2026-05-15 09:35:53 +08:00
parent dff5096164
commit 541418fd60
13 changed files with 953 additions and 140 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,343 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Stage2 黑三角 YOLO —— 在 Maix 设备上用本地图片测试(与线上 target_roi_yolo.try_black_triangle_boxes_work 完全一致)。
不在 PC 上跑 NPU需把脚本与 config / target_roi_yolo.py 同步到设备,并在设备上执行。
典型用法
--------
# 输入已是 Stage1 裁切(与你保存的 stage2_roi_*.jpg 一致)
python test/test_stage2_black_yolo_device.py /root/phot/stage2_roi_xxx.jpg
# 输入为整幅相机图,手动给出 Stage1 环靶 ROI与线上日志 ring全图=[rx0,ry0,rx1,ry1] 一致)
python test/test_stage2_black_yolo_device.py /root/phot/full.jpg --roi 197,196,507,461
# 对比 native / letterbox 坐标映射(排查 contain 训练与推理对齐)
python test/test_stage2_black_yolo_device.py ./crop.jpg --compare-coord
# 覆盖置信度、模型路径(仍读其余项自 config
python test/test_stage2_black_yolo_device.py ./crop.jpg --conf 0.25 -m /maixapp/apps/t11/model_270648.mud
# 只看 NPU 原始框(映射前):判断坐标是 ~224 网络空间还是归一化 0~1
python test/test_stage2_black_yolo_device.py ./crop.jpg --conf 0.05 --dump-raw 15
依赖:MaixPy(maix.nn)、OpenCV(cv2)、numpy;项目根须在 sys.path(本脚本已插入上级目录)。
"""
from __future__ import annotations
import argparse
import os
import sys
# Project root = parent directory of the folder holding this script. It is put
# at the front of sys.path so the project modules import when run as a script.
_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
if sys.path.count(_ROOT) == 0:
    sys.path.insert(0, _ROOT)
def _parse_roi(s: str) -> tuple[int, int, int, int]:
    """Parse an ``x0,y0,x1,y1`` string into a 4-tuple of ints.

    Space characters anywhere in the text are removed before splitting on commas.

    Raises:
        ValueError: if there are not exactly four comma-separated fields, or a
            field is not an integer literal.
    """
    fields = s.replace(" ", "").split(",")
    if len(fields) != 4:
        raise ValueError("ROI 需要 4 个整数x0,y0,x1,y1")
    x0, y0, x1, y1 = (int(field.strip()) for field in fields)
    return x0, y0, x1, y1
def _load_rgb_numpy(path: str) -> "object":
    """Read an image file with OpenCV and return it as a contiguous RGB uint8 array.

    Raises:
        FileNotFoundError: if ``cv2.imread`` returns ``None`` for ``path``.
    """
    import cv2
    import numpy as np

    decoded_bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if decoded_bgr is None:
        raise FileNotFoundError(f"cv2.imread 失败: {path}")
    # OpenCV decodes to BGR channel order; swap to RGB before handing the array on.
    as_rgb = cv2.cvtColor(decoded_bgr, cv2.COLOR_BGR2RGB)
    return np.ascontiguousarray(as_rgb, dtype=np.uint8)
def _draw_boxes_on_crop(
    slab_rgb,
    boxes: list[tuple[int, int, int, int]],
    labels: list[str] | None = None,
):
    """Draw boxes and text tags on a copy of an RGB crop and return it as BGR.

    Args:
        slab_rgb: RGB image array (the original note describes it as HxWx3 uint8).
        boxes: ``(x0, y0, x1, y1)`` boxes in crop coordinates.
        labels: optional per-box tags; a box without one is tagged ``s2_<index>``.

    Returns:
        A new BGR image with green rectangles and tags; ``slab_rgb`` is not modified.
    """
    import cv2

    green = (0, 255, 0)
    canvas = cv2.cvtColor(slab_rgb.copy(), cv2.COLOR_RGB2BGR)
    height, width = canvas.shape[:2]
    for idx, (bx0, by0, bx1, by1) in enumerate(boxes):
        left, top = int(bx0), int(by0)
        # The far corner is moved in by one pixel, kept inside the image and
        # never allowed to fall before the near corner.
        right = max(left, min(int(bx1) - 1, width - 1))
        bottom = max(top, min(int(by1) - 1, height - 1))
        cv2.rectangle(canvas, (left, top), (right, bottom), green, 2)
        if labels and idx < len(labels):
            tag = labels[idx]
        else:
            tag = f"s2_{idx}"
        text_origin = (left, max(0, top - 4))
        cv2.putText(canvas, tag, text_origin, cv2.FONT_HERSHEY_SIMPLEX, 0.5, green, 1, cv2.LINE_AA)
    return canvas
class _PrintLogger:
    """Minimal logger stand-in: info / warning / error all print the message to stdout."""

    @staticmethod
    def _emit(msg):
        print(msg)

    def info(self, msg):
        self._emit(msg)

    def warning(self, msg):
        self._emit(msg)

    def error(self, msg):
        self._emit(msg)
def _run_once(yroi_mod, img_rgb, roi_xyxy, logger):
    """Run the Stage2 detector once and return ``(boxes, roi_crop)``.

    ``boxes`` is whatever ``yroi_mod.try_black_triangle_boxes_work`` returns;
    ``roi_crop`` is an independent copy of ``img_rgb`` sliced to ``roi_xyxy``
    (``x0, y0, x1, y1``).
    """
    detected = yroi_mod.try_black_triangle_boxes_work(img_rgb, roi_xyxy, logger)
    left, top, right, bottom = roi_xyxy
    roi_view = img_rgb[top:bottom, left:right]
    return detected, roi_view.copy()
def _copy_dump_raw_rows(yroi_mod, objs):
    """Copy detection objects into plain tuples ``(class_id, score, x, y, w, h)``.

    The class id comes from ``yroi_mod._det_obj_class_id``; a missing or
    non-numeric ``score`` becomes ``0.0``. Only basic Python values are kept, so
    the returned rows hold no reference to the detection objects themselves.
    """

    def _score_of(det_obj):
        try:
            return float(getattr(det_obj, "score", 0.0))
        except (TypeError, ValueError):
            return 0.0

    return [
        (
            yroi_mod._det_obj_class_id(det_obj),
            _score_of(det_obj),
            float(det_obj.x),
            float(det_obj.y),
            float(det_obj.w),
            float(det_obj.h),
        )
        for det_obj in objs
    ]
def _dump_raw_and_hard_exit(det, yroi_mod, slab_for_det, rw_s, rh_s, net_w, net_h, conf_th, iou_th, limit):
    """Print the detector's raw boxes for one crop, then terminate the process.

    This function never returns: it finishes with ``os._exit(0)`` so no Python
    or native teardown runs afterwards. The original author's note says some
    MaixPy versions crash (SIGSEGV / pure virtual call) while YOLO detection
    results are being destructed, which is why this diagnostic path exits hard.

    Args:
        det: detector exposing ``detect(image, conf_th=..., iou_th=...)``.
        yroi_mod: module providing ``_normalize_objs`` and ``_det_obj_class_id``.
        slab_for_det: image array handed to ``maix.image.cv2image``.
        rw_s, rh_s: crop width / height, used only in the printed summary.
        net_w, net_h: network input size, used only in the printed summary.
        conf_th, iou_th: thresholds forwarded to ``det.detect``.
        limit: maximum number of rows to print.
    """
    from maix import image as maix_image

    # roi_maix / raw / objs stay bound until the process exits, as in the
    # original; releasing them earlier could trigger the destructor crash.
    roi_maix = maix_image.cv2image(slab_for_det, False, False)
    raw = det.detect(roi_maix, conf_th=conf_th, iou_th=iou_th)
    if raw is None:
        raw = []
    objs = yroi_mod._normalize_objs(raw)
    dump_rows = _copy_dump_raw_rows(yroi_mod, objs)
    raw_count = len(dump_rows)
    print(
        f"[DUMP-RAW] slab={rw_s}×{rh_s} net={net_w}×{net_h} "
        f"conf={conf_th} iou={iou_th} → NMS 后 raw 框数={raw_count}(与 coord_mode 无关)"
    )
    shown = max(0, min(int(limit), raw_count))
    for i, (cid, sc, x, y, ww, hh) in enumerate(dump_rows[:shown]):
        print(f" #{i} cls={cid} score={sc:.4f} xywh=({x:.3f},{y:.3f},{ww:.3f},{hh:.3f})")
    if dump_rows:
        xs = [row[2] for row in dump_rows]
        ws = [row[4] for row in dump_rows]
        print(
            f"[DUMP-RAW] hint: x 范围≈[{min(xs):.2f},{max(xs):.2f}] "
            f"w 范围≈[{min(ws):.2f},{max(ws):.2f}] — "
            f"若整体在 0~{net_w} 量级多为网络画布坐标→应用 letterbox"
            f"若 x,w 多在 0~1→可能是归一化需在代码里乘 net 尺寸"
        )
    print("[INFO] --dump-raw 已完成;为规避 MaixPy YOLO native 析构崩溃,测试进程将直接退出。")
    # Flush manually: os._exit skips the interpreter's normal buffer flushing.
    for stream in (sys.stdout, sys.stderr):
        stream.flush()
    os._exit(0)
def main():
    """Command-line entry point: run Stage2 black-triangle YOLO on one local image.

    Flow: parse arguments, import the project's ``config`` / ``target_roi_yolo``,
    temporarily override selected ``config.TRIANGLE_BLACK_YOLO_*`` values, run
    detection once per coordinate mode, save one annotated image per run and
    optionally dump the raw detections. Overridden config values are put back in
    a ``finally`` block (except on the ``--dump-raw`` path, which hard-exits).

    Exits with status 1 when the image, the project modules, the model file or
    the detector are unavailable, when ``--roi`` is malformed or degenerate, or
    when an annotated image cannot be written.
    """
    ap = argparse.ArgumentParser(
        description="Stage2 黑三角 YOLO 设备本地图测试",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    ap.add_argument("image", help="本地图片路径(设备上的路径)")
    ap.add_argument(
        "--roi",
        default="",
        metavar="x0,y0,x1,y1",
        help="可选。若填写image 为整幅图,在此图上取 Stage1 ROI 再跑 Stage2"
        "留空image 本身就是 Stage1 裁切图(默认)",
    )
    ap.add_argument("-o", "--output", default="", help="输出可视化路径;默认 原名_stage2_vis.jpg")
    ap.add_argument("-m", "--model", default="", help="覆盖 config.TRIANGLE_BLACK_YOLO_MODEL_PATH")
    ap.add_argument("--conf", type=float, default=None, help="覆盖 TRIANGLE_BLACK_YOLO_CONF_TH")
    ap.add_argument("--iou", type=float, default=None, help="覆盖 TRIANGLE_BLACK_YOLO_IOU_TH")
    ap.add_argument(
        "--coord",
        choices=["native", "letterbox"],
        default="",
        help="覆盖 TRIANGLE_BLACK_YOLO_COORD_MODE默认用 config",
    )
    ap.add_argument(
        "--compare-coord",
        action="store_true",
        help="各跑一次 native 与 letterbox输出两张图 *_stage2_native.jpg / *_stage2_letterbox.jpg",
    )
    ap.add_argument(
        "--fresh-detector",
        action="store_true",
        help="清掉 YOLO 缓存再测(换模型或排查缓存时用)",
    )
    ap.add_argument(
        "--allow-save-roi",
        action="store_true",
        help="不强制关闭 TRIANGLE_BLACK_YOLO_SAVE_ROI_CROP默认测试时会关掉以免写满相册目录",
    )
    ap.add_argument(
        "--dump-raw",
        type=int,
        default=0,
        metavar="N",
        help="打印前 N 个 detect 原始框 x,y,w,h,score,clscoord 映射前native/letterbox 共用同一批 raw",
    )
    args = ap.parse_args()

    img_path = os.path.abspath(args.image)
    if not os.path.isfile(img_path):
        print(f"[ERR] 找不到图片: {img_path}")
        sys.exit(1)
    try:
        import config as cfg
        import target_roi_yolo as yroi
    except ImportError as e:
        print(f"[ERR] 无法导入 config / target_roi_yolo: {e}")
        sys.exit(1)
    if args.fresh_detector:
        yroi.reset_yolo_detector_cache()

    # Back up and temporarily override config values (single process, sequential runs).
    absent = object()  # marks keys that did not exist on config before patching
    bak: dict[str, object] = {}

    def _patch(key: str, val: object):
        if key not in bak:
            bak[key] = getattr(cfg, key, absent)
        setattr(cfg, key, val)

    def _restore():
        for k, v in bak.items():
            if v is absent:
                # The key was created by this script: remove it instead of
                # leaving a stray ``None`` attribute behind.
                if hasattr(cfg, k):
                    delattr(cfg, k)
            else:
                setattr(cfg, k, v)

    try:
        _patch("TRIANGLE_BLACK_YOLO_ENABLE", True)
        if not args.allow_save_roi:
            _patch("TRIANGLE_BLACK_YOLO_SAVE_ROI_CROP", False)
        if args.model.strip():
            _patch("TRIANGLE_BLACK_YOLO_MODEL_PATH", args.model.strip())
        if args.conf is not None:
            _patch("TRIANGLE_BLACK_YOLO_CONF_TH", float(args.conf))
        if args.iou is not None:
            _patch("TRIANGLE_BLACK_YOLO_IOU_TH", float(args.iou))
        if args.coord and not args.compare_coord:
            _patch("TRIANGLE_BLACK_YOLO_COORD_MODE", args.coord)

        mp = getattr(cfg, "TRIANGLE_BLACK_YOLO_MODEL_PATH", "") or ""
        if not os.path.isfile(mp):
            print(f"[ERR] 模型文件不存在: {mp}")
            sys.exit(1)

        import cv2
        import numpy as np

        img_rgb = _load_rgb_numpy(img_path)
        h, w = int(img_rgb.shape[0]), int(img_rgb.shape[1])
        if args.roi.strip():
            try:
                rx0, ry0, rx1, ry1 = _parse_roi(args.roi.strip())
            except ValueError as e:
                # Report a bad --roi like every other input error instead of a traceback.
                print(f"[ERR] --roi 解析失败: {e}")
                sys.exit(1)
            if rx1 <= rx0 or ry1 <= ry0:
                print("[ERR] ROI 无效:需满足 x1>x0 且 y1>y0")
                sys.exit(1)
            # Clip as the original comment describes: the same way
            # target_roi_yolo.try_black_triangle_boxes_work does.
            rx0 = max(0, min(rx0, w - 1))
            ry0 = max(0, min(ry0, h - 1))
            rx1 = max(rx0 + 1, min(rx1, w))
            ry1 = max(ry0 + 1, min(ry1, h))
            ring_roi = (rx0, ry0, rx1, ry1)
            print(f"[INFO] 模式=整图+ROI ring={ring_roi} image={w}×{h}")
        else:
            ring_roi = (0, 0, w, h)
            print(f"[INFO] 模式=已是 Stage1 裁切 crop={w}×{h}")

        logger = _PrintLogger()
        det = yroi._get_detector(mp)
        if det is None:
            print("[ERR] 无法加载 nn.YOLOv5检查模型路径与 Maix 环境)")
            sys.exit(1)
        net_w = int(det.input_width())
        net_h = int(det.input_height())
        print(f"[INFO] model={mp} net_in={net_w}×{net_h}")

        rx0, ry0, rx1, ry1 = ring_roi
        slab_for_det = np.ascontiguousarray(img_rgb[ry0:ry1, rx0:rx1], dtype=np.uint8).copy()
        rh_s, rw_s = int(slab_for_det.shape[0]), int(slab_for_det.shape[1])

        if args.compare_coord:
            modes = ["native", "letterbox"]
        else:
            modes = [args.coord or getattr(cfg, "TRIANGLE_BLACK_YOLO_COORD_MODE", "native")]
        base, ext = os.path.splitext(img_path)
        ext = ext if ext else ".jpg"

        for mode in modes:
            _patch("TRIANGLE_BLACK_YOLO_COORD_MODE", mode)
            cur_coord = getattr(cfg, "TRIANGLE_BLACK_YOLO_COORD_MODE", mode)
            print(f"[INFO] --- TRIANGLE_BLACK_YOLO_COORD_MODE={cur_coord} ---")
            boxes, slab = _run_once(yroi, img_rgb, ring_roi, logger)
            print(
                f"[INFO] 子框数量={len(boxes)} conf={getattr(cfg, 'TRIANGLE_BLACK_YOLO_CONF_TH', '?')} "
                f"coord={cur_coord}"
            )
            for i, b in enumerate(boxes):
                print(f" s2_{i}: {b}")
            if args.compare_coord:
                out_path = f"{base}_stage2_{mode}{ext}"
            elif args.output.strip():
                out_path = args.output.strip()
            else:
                out_path = base + "_stage2_vis" + ext
            bgr = _draw_boxes_on_crop(slab, boxes)
            # cv2.imwrite reports failure through its return value, not an
            # exception; without this check "[OK] saved" would print regardless.
            if not cv2.imwrite(out_path, bgr, [int(cv2.IMWRITE_JPEG_QUALITY), 92]):
                print(f"[ERR] 写图失败: {out_path}")
                sys.exit(1)
            print(f"[OK] saved: {out_path}")

        if args.compare_coord:
            print(
                "[HINT] contain 训练时若 letterbox 对齐更好,请将 config 里 "
                "TRIANGLE_BLACK_YOLO_COORD_MODE 设为 letterbox"
            )
        if args.dump_raw > 0:
            conf_th = float(getattr(cfg, "TRIANGLE_BLACK_YOLO_CONF_TH", 0.5))
            iou_th = float(getattr(cfg, "TRIANGLE_BLACK_YOLO_IOU_TH", 0.45))
            print("\n[INFO] --dump-raw 放在最后执行,避免 raw native 对象影响 compare-coord 流程。")
            # Never returns: it calls os._exit, so the finally clause below is
            # skipped on this path (the process is gone anyway).
            _dump_raw_and_hard_exit(
                det,
                yroi,
                slab_for_det,
                rw_s,
                rh_s,
                net_w,
                net_h,
                conf_th,
                iou_th,
                args.dump_raw,
            )
    finally:
        _restore()


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,242 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
单张图片快速测试:三角形四角标记识别 + 单应性落点 + PnP 估距
用法(在板子上):
python3 test/test_triangle_one_image.py --image /root/phot/xxx.jpg --out /root/phot/tri_out.jpg
调参对比(不改代码,临时覆盖 config.TRIANGLE_*
python3 test/test_triangle_one_image.py --image /root/phot/xxx.jpg --preset shadow
python3 test/test_triangle_one_image.py --image /root/phot/xxx.jpg --max-interior-gray 160 --min-dark-ratio 0.20
"""
import argparse
import json
import os
import time
from typing import Any, Dict, Tuple
import cv2
import numpy as np
import config
import triangle_target as tri_mod
from triangle_target import (
detect_triangle_markers,
load_camera_from_xml,
load_triangle_positions,
try_triangle_scoring,
)
def _apply_overrides(args) -> None:
    """Apply command-line tuning overrides onto the ``config`` module in place.

    The ``shadow`` preset is applied first; individual options that are not
    ``None`` are applied afterwards and therefore take precedence over it.
    """
    if args.preset == "shadow":
        # Looser thresholds for shadowed / low-contrast scenes. CLAHE stays off
        # and only one adaptive block size is tried, to keep the run fast.
        shadow_values = (
            ("TRIANGLE_ENABLE_CLAHE_FALLBACK", False),
            ("TRIANGLE_MIN_CONTRAST_DIFF", 0),
            ("TRIANGLE_MAX_INTERIOR_GRAY", 160),
            ("TRIANGLE_DARK_PIXEL_GRAY", 160),
            ("TRIANGLE_MIN_DARK_RATIO", 0.20),
            ("TRIANGLE_ADAPTIVE_BLOCK_SIZES", (21,)),
        )
        for key, value in shadow_values:
            setattr(config, key, value)

    # (argparse attribute, config key, converter) in the original order.
    scalar_options = (
        ("max_interior_gray", "TRIANGLE_MAX_INTERIOR_GRAY", int),
        ("dark_pixel_gray", "TRIANGLE_DARK_PIXEL_GRAY", int),
        ("min_dark_ratio", "TRIANGLE_MIN_DARK_RATIO", float),
        ("min_contrast_diff", "TRIANGLE_MIN_CONTRAST_DIFF", int),
        ("detect_scale", "TRIANGLE_DETECT_SCALE", float),
    )
    for attr, key, convert in scalar_options:
        supplied = getattr(args, attr)
        if supplied is None:
            continue
        setattr(config, key, convert(supplied))

    if args.adaptive_blocks is not None:
        # Comma-separated integers; blank items are skipped.
        pieces = args.adaptive_blocks.split(",")
        setattr(config, "TRIANGLE_ADAPTIVE_BLOCK_SIZES", tuple(int(p) for p in pieces if p.strip()))
def _dump_config() -> Dict[str, Any]:
    """Return the current values of the triangle-detection settings in ``config``.

    Settings that are not defined on ``config`` are reported as ``None``.
    """
    reported = (
        "TRIANGLE_DETECT_SCALE",
        "TRIANGLE_SIZE_RANGE",
        "TRIANGLE_MAX_INTERIOR_GRAY",
        "TRIANGLE_DARK_PIXEL_GRAY",
        "TRIANGLE_MIN_DARK_RATIO",
        "TRIANGLE_MIN_CONTRAST_DIFF",
        "TRIANGLE_ADAPTIVE_BLOCK_SIZES",
        "TRIANGLE_MAX_FILTERED_FOR_COMBO",
        "TRIANGLE_EARLY_EXIT_CANDIDATES",
        "TRIANGLE_ENABLE_CLAHE_FALLBACK",
    )
    return {name: getattr(config, name, None) for name in reported}
def _draw_tri_debug(img_bgr: np.ndarray, tri: Dict[str, Any]) -> np.ndarray:
    """Return a copy of ``img_bgr`` annotated with the triangle-scoring result.

    Draws each marker's outline, centre dot and ``T<id>`` tag. When ``tri`` has a
    ``homography``, the point (0, 0) is mapped through its inverse and drawn.
    A short text summary is overlaid in the top-left corner.
    """
    green, red = (0, 255, 0), (0, 0, 255)
    font = cv2.FONT_HERSHEY_SIMPLEX
    canvas = img_bgr.copy()

    # Marker outlines, centres and ids.
    for marker in tri.get("markers") or []:
        pts = np.array(marker.get("corners", []), dtype=np.int32)
        if pts.size == 0:
            continue
        cv2.polylines(canvas, [pts], True, green, 2)
        # Fall back to the mean of the corners when no centre is supplied.
        center = marker.get("center") or (pts[:, 0].mean(), pts[:, 1].mean())
        px, py = int(center[0]), int(center[1])
        cv2.circle(canvas, (px, py), 4, red, -1)
        marker_id = marker.get("id", "?")
        cv2.putText(canvas, f"T{marker_id}", (px - 18, py - 10), font, 0.55, green, 1)

    # Target centre: back-project (0, 0) into the image through the inverse homography.
    homography = tri.get("homography")
    if homography is not None:
        try:
            inverse = np.linalg.inv(np.array(homography, dtype=np.float64))
            origin = np.array([[[0.0, 0.0]]], dtype=np.float32)
            projected = cv2.perspectiveTransform(origin, inverse)[0][0]
            bull_x, bull_y = int(projected[0]), int(projected[1])
            cv2.circle(canvas, (bull_x, bull_y), 5, red, -1)
            cv2.circle(canvas, (bull_x, bull_y), 10, red, 1)
        except Exception:
            # Best effort: an unusable homography just means no centre mark.
            pass

    # Text summary of the result.
    if tri.get("ok"):
        overlay = ["tri_ok=True"]
        if tri.get("dx_cm") is not None and tri.get("dy_cm") is not None:
            overlay.append(f"dx,dy=({tri['dx_cm']:.2f},{tri['dy_cm']:.2f})cm")
        if tri.get("distance_m") is not None:
            overlay.append(f"dist={float(tri['distance_m']):.2f}m")
    else:
        overlay = ["tri_ok=False"]
    for row, text in enumerate(overlay):
        cv2.putText(canvas, text, (10, 22 + row * 18), font, 0.55, green, 1)
    return canvas
def main():
    """Command-line entry point: triangle-marker detection and scoring on one image.

    Steps: parse arguments and apply config overrides, read the image, load the
    camera calibration and marker positions named in ``config``, run a
    standalone marker-candidate detection (to tell "no candidates" apart from
    "candidates found but scoring failed"), run ``try_triangle_scoring``, print
    the result as JSON and, with ``--out``, write an annotated image.

    Raises:
        SystemExit: when the input image cannot be read or the output image
            cannot be written.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--image", required=True, help="输入图片路径jpg/png")
    ap.add_argument("--out", default="", help="输出标注图片路径(可选)")
    ap.add_argument("--laser-x", type=int, default=-1, help="激光点 x像素默认用图像中心")
    ap.add_argument("--laser-y", type=int, default=-1, help="激光点 y像素默认用图像中心")
    ap.add_argument("--preset", choices=["", "shadow"], default="", help="调参预设shadow=阴影更鲁棒,不启 CLAHE")
    ap.add_argument("--max-interior-gray", type=int, default=None)
    ap.add_argument("--dark-pixel-gray", type=int, default=None)
    ap.add_argument("--min-dark-ratio", type=float, default=None)
    ap.add_argument("--min-contrast-diff", type=int, default=None)
    ap.add_argument("--detect-scale", type=float, default=None)
    ap.add_argument("--adaptive-blocks", default=None, help="例如: 11,21 ;为空表示不改")
    ap.add_argument("--verbose", action="store_true", help="输出更多检测阶段信息")
    args = ap.parse_args()
    _apply_overrides(args)

    # Per the original note, triangle_target logs through a logger manager that
    # may be uninitialised in an offline script; in verbose mode route its
    # _log hook to print so diagnostics are visible.
    if args.verbose:
        tri_mod._log = lambda msg: print(str(msg))

    img_bgr = cv2.imread(args.image, cv2.IMREAD_COLOR)
    if img_bgr is None:
        raise SystemExit(f"读图失败:{args.image}")
    # try_triangle_scoring is given RGB input; cv2.imread returns BGR.
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    h, w = img_bgr.shape[:2]
    if args.laser_x >= 0 and args.laser_y >= 0:
        laser_point = (int(args.laser_x), int(args.laser_y))
    else:
        laser_point = (w // 2, h // 2)

    K, dist = load_camera_from_xml(getattr(config, "CAMERA_CALIB_XML", ""))
    pos = load_triangle_positions(getattr(config, "TRIANGLE_POSITIONS_JSON", ""))
    # Read once; the value is used both for the standalone detection and for scoring.
    size_range = getattr(config, "TRIANGLE_SIZE_RANGE", (8, 500))
    print("[tri-test] image:", args.image, "shape:", (h, w))
    print("[tri-test] laser_point:", laser_point)
    print("[tri-test] calib_ok:", bool(K is not None and dist is not None), "pos_ok:", bool(pos))
    print("[tri-test] config:", json.dumps(_dump_config(), ensure_ascii=False))

    # Standalone candidate detection on a (possibly) downscaled image.
    scale = float(getattr(config, "TRIANGLE_DETECT_SCALE", 0.5) or 0.5)
    if not (0.05 <= scale <= 1.0):
        scale = 0.5
    long_side = max(h, w)
    max_dim = max(64, int(long_side * scale))
    if long_side > max_dim:
        det_scale = max_dim / long_side
        det_w = int(w * det_scale)
        det_h = int(h * det_scale)
        img_det = cv2.resize(img_bgr, (det_w, det_h), interpolation=cv2.INTER_LINEAR)
        inv_scale = 1.0 / det_scale
        # The size limits shrink with the image, with small lower bounds.
        size_range_det = (
            max(4, int(size_range[0] * det_scale)),
            max(8, int(size_range[1] * det_scale)),
        )
    else:
        img_det = img_bgr
        inv_scale = 1.0
        size_range_det = size_range
    gray = cv2.cvtColor(img_det, cv2.COLOR_BGR2GRAY)
    # Normalise a None result to an empty list so the len() below cannot fail.
    markers_det = detect_triangle_markers(
        gray,
        orig_gray=gray,
        size_range=size_range_det,
        verbose=bool(args.verbose),
    ) or []
    if inv_scale != 1.0:
        # Map candidate coordinates back to full-resolution pixels.
        for m in markers_det:
            m["center"] = [m["center"][0] * inv_scale, m["center"][1] * inv_scale]
            m["corners"] = [[c[0] * inv_scale, c[1] * inv_scale] for c in m["corners"]]
    print("[tri-test] markers_found:", len(markers_det), "ids:", [m.get("id") for m in markers_det])

    t0 = time.time()
    tri = try_triangle_scoring(
        img_rgb,
        laser_point,
        pos,
        K,
        dist,
        size_range=size_range,
    )
    dt_ms = int(round((time.time() - t0) * 1000))
    print("[tri-test] elapsed_ms:", dt_ms)
    print(json.dumps(tri, ensure_ascii=False, indent=2))

    if args.out:
        out_path = args.out
        # A directory (e.g. "./") gets a default file name; a path without an
        # extension defaults to .jpg.
        if out_path.endswith(("/", "\\")) or os.path.isdir(out_path):
            out_path = os.path.join(out_path, "tri_out.jpg")
        root, ext = os.path.splitext(out_path)
        if not ext:
            out_path = root + ".jpg"
        # If scoring returned no markers, draw the standalone candidates so the
        # image is still useful for visual inspection.
        tri_for_draw = tri if isinstance(tri, dict) else {"ok": False}
        if not tri_for_draw.get("markers") and markers_det:
            tri_for_draw = dict(tri_for_draw)
            tri_for_draw["markers"] = markers_det
        out_img = _draw_tri_debug(img_bgr, tri_for_draw)
        ok = cv2.imwrite(out_path, out_img)
        if not ok:
            raise SystemExit(f"写图失败(可能是不支持的扩展名):{out_path}")
        print("[tri-test] wrote:", out_path)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,257 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
本地图片 → Maix YOLOv5 检测 → 画框保存(用于核对坐标 mode / 多框 union
运行环境MaixCAM / MaixPy需 maix.image / maix.nn在项目根或任意目录执行均可。
示例:
python test/test_yolo_draw_boxes.py /root/phot/shot_xxx.jpg
python test/test_yolo_draw_boxes.py shot.jpg --loader cv2_rgb --conf 0.25
python test/test_yolo_draw_boxes.py shot.jpg --debug
python test/test_yolo_draw_boxes.py -h # 查看 --loader / --debug / --union 等全部参数
脚本版本与设备同步用20260206-yolo-vis
"""
from __future__ import annotations
import argparse
import os
import sys
# Project root = parent directory of the folder holding this script. It is put
# at the front of sys.path so the project modules import from any working directory.
_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
if sys.path.count(_ROOT) == 0:
    sys.path.insert(0, _ROOT)
def _load_maix_image(path: str, image_mod):
    """Load ``path`` with ``image_mod.load`` and return the result unchanged.

    Original author's note: for some JPEGs the decoded pixel layout differs from
    what ``camera.read()`` yields, which can leave the NPU with no detections.
    """
    loaded = image_mod.load(path)
    return loaded
def _load_cv2_rgb_as_maix(path: str, image_mod):
    """Read an image with OpenCV, convert BGR to RGB and wrap it via ``image_mod.cv2image``.

    Original author's note: this mirrors the inverse of the ``image2cv`` step in
    shoot_manager and is meant for a YOLO model whose input type is rgb.

    Raises:
        FileNotFoundError: if ``cv2.imread`` returns ``None`` for ``path``.
    """
    import cv2

    pixels_bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if pixels_bgr is None:
        raise FileNotFoundError(f"cv2.imread 失败: {path}")
    pixels_rgb = cv2.cvtColor(pixels_bgr, cv2.COLOR_BGR2RGB)
    return image_mod.cv2image(pixels_rgb, False, False)
def main():
    """Command-line entry point: run Maix YOLOv5 on a local image and save it with boxes drawn.

    Steps: parse arguments, import MaixPy and the project modules, resolve the
    model path and thresholds (command line first, then ``config``), load the
    image with the chosen loader (``auto`` retries through OpenCV when the
    Maix loader yields zero boxes), draw every detection, optionally draw the
    union of the ring-class boxes, and save the annotated image.

    Exits with status 1 when MaixPy, the project modules, the image or the
    model file are unavailable.
    """
    ap = argparse.ArgumentParser(
        description="YOLO 画框测试Maix",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="若提示 unrecognized arguments: --debug说明设备上脚本未更新请同步仓库中的 test/test_yolo_draw_boxes.py",
    )
    ap.add_argument("image", help="输入图片路径")
    ap.add_argument("-o", "--output", default="", help="输出图片路径;默认 原名_yolo_vis.jpg")
    ap.add_argument("-m", "--model", default="", help="覆盖 config.TRIANGLE_YOLO_MODEL_PATH")
    ap.add_argument("--conf", type=float, default=None, help="置信度阈值")
    ap.add_argument("--iou", type=float, default=None, help="NMS IoU")
    ap.add_argument(
        "--coord",
        choices=["native", "letterbox"],
        default="",
        help="坐标映射;默认读 config.TRIANGLE_YOLO_COORD_MODE",
    )
    ap.add_argument(
        "--union",
        action="store_true",
        help="按 TRIANGLE_YOLO_RING_CLASS_IDS 过滤后画合并外接矩形(与线上 ROI merge=union 一致)",
    )
    ap.add_argument(
        "--loader",
        choices=["auto", "maix", "cv2_rgb"],
        default="auto",
        help="auto: 先 maix.load0 框则改用 cv2 RGB推荐排查「有图但始终 0 框」)",
    )
    ap.add_argument(
        "--debug",
        action="store_true",
        help="打印 detect 原始返回类型与 repr截断",
    )
    args = ap.parse_args()

    try:
        from maix import image, nn
    except ImportError:
        print("[ERR] 需要 MaixPymaix.image / maix.nn请在 MaixCAM 上运行。")
        sys.exit(1)
    # Guard the project imports too, so a missing module gives a clear message
    # rather than a traceback.
    try:
        import config as cfg
        import target_roi_yolo as yroi
    except ImportError as e:
        print(f"[ERR] 无法导入 config / target_roi_yolo: {e}")
        sys.exit(1)

    img_path = os.path.abspath(args.image)
    if not os.path.isfile(img_path):
        print(f"[ERR] 找不到图片: {img_path}")
        sys.exit(1)
    model_path = (args.model or getattr(cfg, "TRIANGLE_YOLO_MODEL_PATH", "") or "").strip()
    if not os.path.isfile(model_path):
        print(f"[ERR] 模型文件不存在: {model_path}")
        sys.exit(1)

    conf_th = (
        float(args.conf)
        if args.conf is not None
        else float(getattr(cfg, "TRIANGLE_YOLO_CONF_TH", 0.5))
    )
    iou_th = (
        float(args.iou)
        if args.iou is not None
        else float(getattr(cfg, "TRIANGLE_YOLO_IOU_TH", 0.45))
    )
    coord_mode = (args.coord or getattr(cfg, "TRIANGLE_YOLO_COORD_MODE", "native")).lower()
    out_path = args.output.strip()
    if not out_path:
        base, ext = os.path.splitext(img_path)
        ext = ext if ext else ".jpg"
        out_path = base + "_yolo_vis" + ext

    det = nn.YOLOv5(model=model_path, dual_buff=False)
    net_w = int(det.input_width())
    net_h = int(det.input_height())

    def _run_detect(maix_img, tag: str):
        """Run detection on one image and return the normalised object list."""
        r = det.detect(maix_img, conf_th=conf_th, iou_th=iou_th)
        if args.debug:
            rlen = len(r) if r is not None and hasattr(r, "__len__") else "n/a"
            rrepr = repr(r)
            if len(rrepr) > 300:
                rrepr = rrepr[:300] + "..."
            print(f"[DEBUG] loader={tag} raw_type={type(r)} len={rlen} repr={rrepr}")
        return yroi._normalize_objs(r if r is not None else [])

    if args.loader == "cv2_rgb":
        load_tag = "cv2_rgb"
        img = _load_cv2_rgb_as_maix(img_path, image)
        objs = _run_detect(img, load_tag)
    else:
        # "maix" and "auto" both start with the Maix loader.
        load_tag = "maix_load"
        img = _load_maix_image(img_path, image)
        objs = _run_detect(img, load_tag)
        if args.loader == "auto" and len(objs) == 0:
            print(
                "[WARN] maix.image.load 在 conf_th=%s 下仍为 0 框,改用 cv2 BGR→RGB→cv2image 重试(常见可恢复)"
                % conf_th
            )
            load_tag = "cv2_rgb_retry"
            img = _load_cv2_rgb_as_maix(img_path, image)
            objs = _run_detect(img, load_tag)

    src_w, src_h = img.width(), img.height()
    labels = getattr(det, "labels", None)

    def _label(cid: int) -> str:
        """Return the model's label for a class id, or the id itself as text."""
        if labels is None:
            return str(cid)
        try:
            return str(labels[int(cid)])
        except Exception:
            return str(cid)

    print(
        f"[INFO] loader={load_tag} image={src_w}×{src_h}, net_in={net_w}×{net_h}, "
        f"coord={coord_mode}, conf_th={conf_th}, iou_th={iou_th}"
    )
    # The box count and the output path used to be printed fused together;
    # separate them so the line is readable.
    print(f"[INFO] NMS 后检测框数量={len(objs)} → {out_path}")
    if len(objs) == 0:
        print(
            "[HINT] 仍为 0 框时常见原因:\n"
            " 1) 强制 cv2 路径: --loader cv2_rgb\n"
            " 2) NMS 过严: --iou 0.95\n"
            " 3) 图与训练分布差太大 / 模型未见过该场景\n"
            " 4) 用 camera.read() 一帧存盘再测,对比 file 与实时是否一致"
        )

    # Colours rotate by class id; only COLOR_* constants that exist are used.
    color_cycle = []
    for name in ("RED", "GREEN", "BLUE", "ORANGE", "YELLOW", "CYAN", "MAGENTA"):
        c = getattr(image, f"COLOR_{name}", None)
        if c is not None:
            color_cycle.append(c)
    if not color_cycle:
        color_cycle = [getattr(image, "COLOR_RED", 0)]

    for i, o in enumerate(objs):
        cid = yroi._det_obj_class_id(o)
        if cid is None:
            cid = -1
        try:
            sc = float(o.score)
        except Exception:
            sc = 0.0
        x0, y0, x1, y1 = yroi._det_to_src_xyxy(o, coord_mode, src_w, src_h, net_w, net_h)
        # Clamp both corners, then derive the size from the clamped corners
        # (same scheme as the union box below). Using x1 - x0 with an
        # unclamped x0 over-sized boxes that start outside the image.
        ix = int(max(0, min(x0, src_w - 1)))
        iy = int(max(0, min(y0, src_h - 1)))
        ex = int(max(ix + 1, min(x1, src_w)))
        ey = int(max(iy + 1, min(y1, src_h)))
        iw, ih = ex - ix, ey - iy
        col = color_cycle[cid % len(color_cycle)] if cid >= 0 else color_cycle[0]
        img.draw_rect(ix, iy, iw, ih, color=col)
        ty = max(0, iy - 14)
        msg = f"{_label(cid)} {sc:.2f}"
        img.draw_string(ix, ty, msg, color=col)
        print(f" #{i} cls={cid} {_label(cid)} score={sc:.3f} xywh=({ix},{iy},{iw},{ih})")

    if args.union:
        class_ids = getattr(cfg, "TRIANGLE_YOLO_RING_CLASS_IDS", (0,))
        if isinstance(class_ids, int):
            class_ids = (class_ids,)
        cand = [o for o in objs if yroi._det_obj_class_id(o) in class_ids]
        if cand:
            xy_list = [
                yroi._det_to_src_xyxy(o, coord_mode, src_w, src_h, net_w, net_h) for o in cand
            ]
            merged = yroi._merge_roi_xyxy(xy_list, "union")
            if merged:
                mx0, my0, mx1, my1 = merged
                mx0 = max(0, min(mx0, src_w - 1))
                my0 = max(0, min(my0, src_h - 1))
                mx1 = max(mx0 + 1, min(mx1, src_w))
                my1 = max(my0 + 1, min(my1, src_h))
                uw, uh = int(mx1 - mx0), int(my1 - my0)
                ucol = getattr(image, "COLOR_GREEN", color_cycle[0])
                # Draw two offset rectangles as a cheap way to thicken the union outline.
                for d in (0, 2):
                    img.draw_rect(
                        int(mx0) - d,
                        int(my0) - d,
                        uw + 2 * d,
                        uh + 2 * d,
                        color=ucol,
                    )
                img.draw_string(
                    int(mx0),
                    max(0, int(my0) - 28),
                    f"UNION ({len(cand)} boxes)",
                    color=ucol,
                )
                print(f"[INFO] UNION [{int(mx0)},{int(my0)},{int(mx1)},{int(my1)}] from {len(cand)} boxes")
        else:
            print("[WARN] --union 但 RING_CLASS_IDS 过滤后无框")

    # Retry without the keyword when this save() does not accept ``quality``.
    try:
        img.save(out_path, quality=95)
    except TypeError:
        img.save(out_path)
    print(f"[OK] saved: {out_path}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,506 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
YOLO11 关键点检测训练脚本(靶纸四角)。
设备优先级(--device autoIntel XPU > NVIDIA CUDA > CPU。
默认 imgsz=960批大小默认 4大图显存紧张时可再降
关于「业务像素误差」:
Ultralytics 没有在 yaml 里设定「像素阈值」的选项;反向传播仍由 pose/kobj/box 等内部 loss 驱动。
- 监控:--pixel-metrics-every N每 N 个 epoch 打印 mean/p95并合并进 runs/.../results.csv见 pose_pixel_metrics.py
- 选 best.pt / early stopping加 --best-by-pixel用验证集 mean 像素误差(与 pose_pixel_metrics
同一口径)代替 mAP 合成 fitnessfitness = -mean_px越小越好
多卡 DDPworld_size>1时会自动退回默认 mAP fitness。
XPUUltralytics BaseTrainer._get_memory / _clear_memory 把非 MPS、非 CPU 一律当 CUDA
会在验证前调用 torch.cuda 而报错;本脚本在选用 XPU 时自动打补丁(见 _patch_ultralytics_for_xpu)。
务必使用 pose 任务YOLO(...) 与 model.train(...) 均指定 task='pose'。若误用默认 detect
会把 17 列 Pose 标注当成检测/分割解析校验时出现「coordinates > 1」或 [2.] 等假象。
"""
from __future__ import annotations
import argparse
import csv
import gc
import glob
import os
import tempfile
from copy import deepcopy
from pathlib import Path
import torch
from ultralytics import YOLO
from pose_pixel_metrics import eval_val_pixel_error
import warnings
warnings.filterwarnings('ignore',
message=".*scatter_add_kernel does not have a deterministic implementation.*")
def _clear_ultralytics_label_caches(data_yaml_path: str) -> int:
    """Delete ``labels/*.cache`` under the ``path`` entry of a data.yaml.

    Original author's rationale: the Ultralytics label-cache hash depends only on
    the label/image path strings plus the sum of file sizes, not on file
    contents, so corrected ``*.txt`` labels may still load a stale cache and
    replay old "corrupt" messages. Clearing the cache before training avoids that.

    Returns:
        Number of cache files removed; 0 when the yaml cannot be loaded or has
        no ``path`` entry.
    """
    from ultralytics.utils import YAML

    try:
        dataset_cfg = YAML.load(data_yaml_path)
    except Exception:
        return 0
    dataset_root = dataset_cfg.get("path")
    if not dataset_root:
        return 0

    labels_dir = os.path.join(os.path.abspath(os.path.expanduser(str(dataset_root))), "labels")
    removed = 0
    for cache_file in glob.glob(os.path.join(labels_dir, "*.cache")):
        try:
            os.unlink(cache_file)
        except OSError:
            # Could not delete this one; leave it and keep going.
            continue
        removed += 1
    return removed
def _pick_device(explicit: str | None):
    """Choose the device value handed to Ultralytics train/predict.

    With no value (or ``"auto"``) the order is Intel XPU, then CUDA, then CPU.
    ``"xpu"``, ``"0"``/``"cuda"``/``"gpu"`` and ``"cpu"`` (case-insensitive) are
    checked and normalised; any other explicit value is returned untouched.

    Returns:
        ``torch.device("xpu")``, the int ``0`` for CUDA, the string ``"cpu"``,
        or the caller's own string.

    Raises:
        RuntimeError: if XPU or CUDA was requested explicitly but is unavailable.
    """

    def _xpu_ready() -> bool:
        return getattr(torch, "xpu", None) is not None and torch.xpu.is_available()

    if not explicit or explicit == "auto":
        if _xpu_ready():
            return torch.device("xpu")
        return 0 if torch.cuda.is_available() else "cpu"

    wanted = explicit.lower()
    if wanted == "xpu":
        if not _xpu_ready():
            raise RuntimeError("指定了 --device xpu 但当前环境不可用")
        return torch.device("xpu")
    if wanted in ("0", "cuda", "gpu"):
        if not torch.cuda.is_available():
            raise RuntimeError("指定了 CUDA 但不可用")
        return 0
    if wanted == "cpu":
        return "cpu"
    return explicit
def _default_amp(device) -> bool:
    """Return whether AMP should be on by default for ``device``.

    AMP is off for XPU and CPU and on for everything else (for example the
    int CUDA index ``0``).

    ``device`` may be a ``torch.device``, a string such as ``"cpu"`` or
    ``"xpu:0"``, or an int index. The device type is extracted explicitly
    because ``torch.device("cpu") == "cpu"`` is False, which previously left AMP
    enabled for a CPU ``torch.device`` and for string forms like ``"xpu:0"``.
    """
    if isinstance(device, torch.device):
        dev_type = device.type
    elif isinstance(device, str):
        # "xpu:0" -> "xpu"; a bare index string such as "0" stays "0".
        dev_type = device.strip().lower().split(":", 1)[0]
    else:
        dev_type = None
    return dev_type not in ("xpu", "cpu")
def _patch_ultralytics_for_xpu():
    """Monkeypatch Ultralytics so training and validation can run on an XPU device.

    Three patches are applied, and calling the function again is harmless:

    1. ``select_device`` is wrapped so that ``"xpu"`` / ``"xpu:N"`` strings
       return a ``torch.device`` directly. The wrapper is installed in
       ``torch_utils`` and also in the trainer and validator modules, which
       bound the original function at import time.
    2. ``BaseTrainer._get_memory`` / ``_clear_memory`` use ``torch.xpu`` when
       the trainer's device is XPU and defer to the originals otherwise.
    3. ``BaseValidator`` receives the same two methods.
    """
    import ultralytics.engine.trainer as ut_trainer
    import ultralytics.engine.validator as ut_validator
    import ultralytics.utils.torch_utils as ut_torch_utils

    # 1. select_device. Wrap only once: wrapping on every call would stack
    #    wrappers around wrappers.
    current_select_device = ut_torch_utils.select_device
    if getattr(current_select_device, "_archery_xpu_patched", False):
        patched_select_device = current_select_device
    else:
        original_select_device = current_select_device

        def patched_select_device(device="", *args, **kwargs):
            # Extra positional/keyword arguments are forwarded untouched so the
            # wrapper works with differing select_device signatures.
            if isinstance(device, str):
                d = device.strip().lower()
                if d == "xpu" or d.startswith("xpu:"):
                    return torch.device(device.strip())
            return original_select_device(device, *args, **kwargs)

        patched_select_device._archery_xpu_patched = True

    ut_torch_utils.select_device = patched_select_device
    ut_trainer.select_device = patched_select_device
    ut_validator.select_device = patched_select_device

    # 2. Trainer memory helpers.
    BT = ut_trainer.BaseTrainer
    if not getattr(BT, "_archery_xpu_memory_patched", False):
        _orig_get_memory = BT._get_memory
        _orig_clear_memory = BT._clear_memory

        def _get_memory(self, fraction=False):
            if self.device.type != "xpu":
                return _orig_get_memory(self, fraction)
            memory, total = 0, 0
            try:
                idx = self.device.index
                if idx is None:
                    idx = torch.xpu.current_device()
                memory = int(torch.xpu.memory_allocated(idx))
                if fraction:
                    total = int(torch.xpu.get_device_properties(idx).total_memory)
            except Exception:
                # Best effort: memory numbers are unavailable on this runtime.
                pass
            if fraction:
                # An unknown total yields 0.0, never a GiB figure dressed up as a fraction.
                return memory / total if total > 0 else 0.0
            return memory / 2**30

        def _clear_memory(self, threshold=None):
            if self.device.type != "xpu":
                return _orig_clear_memory(self, threshold)
            if threshold is not None:
                assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
                if self._get_memory(fraction=True) <= threshold:
                    return
            gc.collect()
            if hasattr(torch.xpu, "empty_cache"):
                torch.xpu.empty_cache()

        BT._get_memory = _get_memory
        BT._clear_memory = _clear_memory
        BT._archery_xpu_memory_patched = True

    # 3. Validator memory helpers. They are taken from the (now patched)
    #    trainer class, so this step does not depend on step 2 having run in
    #    this same call.
    BV = ut_validator.BaseValidator
    if not getattr(BV, "_archery_xpu_memory_patched", False):
        BV._get_memory = BT._get_memory
        BV._clear_memory = BT._clear_memory
        BV._archery_xpu_memory_patched = True
def _install_best_by_pixel_validate(data_yaml: str, imgsz: int, conf: float) -> None:
    """Replace ``BaseTrainer.validate`` so fitness is the negated validation mean pixel error.

    The fitness returned by ``validate`` is what the trainer uses to pick
    ``best.pt`` and to drive patience-based early stopping (per the original
    docstring). Installation happens once per process; later calls return
    immediately.

    Args:
        data_yaml: dataset yaml path forwarded to ``eval_val_pixel_error``.
        imgsz: image size forwarded to ``eval_val_pixel_error``.
        conf: confidence value forwarded to ``eval_val_pixel_error``.
    """
    import ultralytics.engine.trainer as ut
    from ultralytics.utils import RANK
    BT = ut.BaseTrainer
    # Idempotence guard: never wrap validate twice.
    if getattr(BT, "_archery_best_by_pixel_installed", False):
        return
    # NOTE(review): the original method is captured but never called below.
    _orig_validate = BT.validate

    def validate(self):
        import torch.distributed as dist
        # Multi-process run: sync EMA buffers from rank 0 before validating.
        if self.ema and self.world_size > 1:
            for buffer in self.ema.ema.buffers():
                dist.broadcast(buffer, src=0)
        metrics = self.validator(self)
        if metrics is None:
            return None, None
        # Keep the validator's own fitness as the fallback value.
        orig_fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())
        # The pixel probe runs only for a single-process run on the main rank.
        use_pixel = self.world_size <= 1 and RANK in {-1, 0}
        mean_px: float | None = None
        if use_pixel:
            tmp_path: str | None = None
            try:
                # Save a half-precision copy of the EMA (or the model) to a
                # temporary checkpoint and evaluate it through a fresh YOLO object.
                fd, tmp_path = tempfile.mkstemp(suffix=".pt", prefix="archery_pxfit_")
                os.close(fd)
                from ultralytics.utils.torch_utils import unwrap_model
                core = unwrap_model(self.ema.ema if self.ema else self.model)
                torch.save({"ema": deepcopy(core).half(), "train_args": vars(self.args)}, tmp_path)
                probe = YOLO(tmp_path)
                stats = eval_val_pixel_error(
                    probe,
                    data_yaml,
                    device=self.device,
                    imgsz=imgsz,
                    conf=conf,
                )
                mean_px = stats.get("mean_px")
                if mean_px is None:
                    raise RuntimeError("无有效 mean_px检查 val 标签与检测是否为空)")
            except Exception as exc:
                # Any probe failure falls back to the validator's fitness for this epoch.
                print(f"\n⚠️ [best-by-pixel] 像素探针失败,本 epoch 仍用 mAP fitness: {exc}\n")
                mean_px = None
            finally:
                # Always remove the temporary checkpoint.
                if tmp_path:
                    try:
                        os.unlink(tmp_path)
                    except OSError:
                        pass
        if mean_px is not None:
            # Negated so that a smaller pixel error gives a larger fitness.
            fitness = -float(mean_px)
            metrics["metrics/mean_px(val)"] = float(mean_px)
        else:
            fitness = float(orig_fitness)
        # NOTE(review): a fallback epoch yields a fitness on the validator's
        # scale while other epochs yield -mean_px; this comparison mixes the
        # two. Confirm that is acceptable.
        if not self.best_fitness or self.best_fitness < fitness:
            self.best_fitness = fitness
        return metrics, fitness

    BT.validate = validate
    BT._archery_best_by_pixel_installed = True
def _fmt_csv_metric(v: float | int | None) -> str:
    """Format one metric for a CSV cell.

    ``None`` becomes an empty string, floats use six significant digits
    (``.6g``), and anything else goes through ``str``.
    """
    if v is None:
        return ""
    return f"{v:.6g}" if isinstance(v, float) else str(v)
# (results.csv column name, key in the stats dict) pairs. Per the original note,
# the "pixel_error/" names are kept distinct from the "metrics/mean_px(val)"
# column of --best-by-pixel so the last.pt callback does not overwrite the EMA
# row's value.
_PIXEL_METRIC_COLUMNS: tuple[tuple[str, str], ...] = tuple(
    (f"pixel_error/{stat_key}", stat_key)
    for stat_key in (
        "mean_px",
        "median_px",
        "p95_px",
        "max_px",
        "n_points",
        "n_images",
        "skip_no_det",
        "skip_no_gt",
        "skip_kpt_mismatch",
    )
)
def _merge_pixel_metrics_into_results_csv(save_dir: str | Path, epoch_1based: int, stats: dict) -> None:
    """Merge pixel-error columns into one epoch row of ``results.csv``.

    Meant to run after the trainer has written the row for the epoch. Every
    column named in ``_PIXEL_METRIC_COLUMNS`` that is missing from the header
    is appended, all data rows are padded with empty cells up to the header
    width, and the cells of the row whose first field equals ``epoch_1based``
    are filled from ``stats``. If several rows carry that epoch number, the
    last one is updated.

    The merge is best-effort: a missing file, a file without data rows, an
    epoch that is not present, or an I/O error leaves ``results.csv`` as it
    was. I/O errors are reported with a printed warning and never raised.
    The file is replaced through a sibling temporary file so that an
    interrupted write cannot leave a truncated ``results.csv`` behind.

    Args:
        save_dir: Run directory that contains ``results.csv``.
        epoch_1based: Epoch number compared against the first CSV column.
        stats: Mapping the metric values are read from with ``.get``; keys
            that are absent produce empty cells.
    """
    csv_path = Path(save_dir) / "results.csv"
    if not csv_path.is_file():
        return
    try:
        with open(csv_path, newline="", encoding="utf-8") as f:
            rows = list(csv.reader(f))
    except OSError as exc:
        print(f"⚠️ [pixel-metrics] cannot read {csv_path}: {exc}")
        return
    if len(rows) < 2:
        # Header only (or empty file): there is no epoch row to update.
        return

    header = list(rows[0])
    for col_name, _ in _PIXEL_METRIC_COLUMNS:
        if col_name not in header:
            header.append(col_name)
    rows[0] = header
    col_ix = {name: i for i, name in enumerate(header)}

    wanted_epoch = int(epoch_1based)
    target_row: list[str] | None = None
    for row in rows[1:]:
        # Rows may be shorter than the header (e.g. rows written after an
        # earlier merge widened it); pad so every column index is valid.
        row.extend([""] * (len(header) - len(row)))
        try:
            # The epoch cell is parsed via float() so "3", "3.0" and padded
            # variants all match.
            if int(float(row[0].strip())) == wanted_epoch:
                target_row = row
        except (ValueError, OverflowError, IndexError):
            # Non-numeric or non-finite first cell: not an epoch row.
            continue
    if target_row is None:
        return

    for col_name, stat_key in _PIXEL_METRIC_COLUMNS:
        target_row[col_ix[col_name]] = _fmt_csv_metric(stats.get(stat_key))

    # Write next to the target and swap it in, so results.csv is never
    # observed half-written.
    tmp_path = csv_path.with_name(csv_path.name + ".tmp")
    try:
        with open(tmp_path, "w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerows(rows)
        tmp_path.replace(csv_path)
    except OSError as exc:
        print(f"⚠️ [pixel-metrics] cannot update {csv_path}: {exc}")
    finally:
        # After a successful replace the temp file no longer exists; on any
        # failure this removes the leftover. Cleanup itself is best-effort.
        try:
            tmp_path.unlink()
        except OSError:
            pass
def _make_pixel_metrics_callback(data_yaml: str, every: int, imgsz: int, conf: float = 0.25):
    """Build an ``on_fit_epoch_end`` callback that reports validation pixel error.

    On every ``every``-th epoch the callback loads ``weights/last.pt`` from
    the trainer's save directory, runs ``eval_val_pixel_error`` on it, prints
    a one-line summary and merges the statistics into ``results.csv``.

    The callback is purely diagnostic: if the evaluation raises, a warning is
    printed and the epoch is skipped instead of aborting the training run.

    Args:
        data_yaml: Dataset YAML path forwarded to ``eval_val_pixel_error``.
        every: Evaluation period in epochs; values ``<= 0`` make the callback
            a no-op.
        imgsz: Image size forwarded to ``eval_val_pixel_error``.
        conf: Confidence threshold forwarded to ``eval_val_pixel_error``.

    Returns:
        A function taking the trainer object, for registration under the
        ``on_fit_epoch_end`` event.
    """

    def on_fit_epoch_end(trainer):
        from ultralytics.utils import RANK

        # Only the process with rank -1 or 0 evaluates and edits results.csv.
        if RANK not in {-1, 0} or every <= 0:
            return
        # trainer.epoch is treated as 0-based here; ep + 1 is the number that
        # is printed and matched against the CSV epoch column.
        ep = int(getattr(trainer, "epoch", -1))
        if (ep + 1) % every != 0:
            return
        weights = Path(trainer.save_dir) / "weights" / "last.pt"
        if not weights.is_file():
            return

        try:
            stats = eval_val_pixel_error(
                YOLO(str(weights)),
                data_yaml,
                device=trainer.device,
                imgsz=imgsz,
                conf=conf,
            )
        except Exception as exc:
            # Broad on purpose: a failed diagnostic must not kill training.
            print(f"\n⚠️ [pixel-metrics] epoch {ep + 1}: evaluation failed, skipped: {exc}\n")
            return

        mean_px = stats.get("mean_px")
        p95_px = stats.get("p95_px")
        mean_s = f"{mean_px:.3f}" if mean_px is not None else "n/a"
        p95_s = f"{p95_px:.3f}" if p95_px is not None else "n/a"
        # Skip counters are read with .get like every other statistic, so a
        # missing key cannot raise from inside the callback.
        print(
            f"\n[pixel-metrics] epoch {ep + 1}: mean_px={mean_s} p95_px={p95_s} "
            f"n_points={stats.get('n_points', 0)} "
            f"skip(det/gt/k)={stats.get('skip_no_det', 0)}/{stats.get('skip_no_gt', 0)}/"
            f"{stats.get('skip_kpt_mismatch', 0)}\n"
        )
        _merge_pixel_metrics_into_results_csv(trainer.save_dir, ep + 1, stats)

    return on_fit_epoch_end
def main():
    """Command-line entry point for pose training.

    Parses the options, selects the device and AMP setting, resolves the
    dataset YAML (relative paths are taken relative to this file), optionally
    clears label caches, installs the optional pixel-error hooks, runs
    ``model.train`` with the fixed hyper-parameters below and, when requested,
    exports the trained model to ONNX.
    """
    parser = argparse.ArgumentParser(description="YOLO Pose 训练XPU/CUDA/CPU")
    # (flags, add_argument keyword arguments), registered in this order.
    option_table = [
        (("--data",), dict(default="datasets/dataset_pose.yaml", help="data.yaml")),
        (("--model",), dict(default="yolo11x-pose.pt", help="预训练权重")),
        (("--epochs",), dict(type=int, default=100)),
        (("--imgsz",), dict(type=int, default=960, help="训练输入边长(默认 960")),
        (("--batch",), dict(type=int, default=4, help="批大小OOM 时减小")),
        (
            ("--device",),
            dict(default="auto", help="auto | xpu | 0 | cuda | cpuautoXPU 优先)"),
        ),
        (
            ("--no-amp",),
            dict(action="store_true", help="关闭混合精度默认CUDA 开启XPU/CPU 关闭)"),
        ),
        (("--project",), dict(default="runs/pose")),
        (("--name",), dict(default="target_pose_train")),
        (("--workers",), dict(type=int, default=4)),
        (
            ("--pixel-metrics-every",),
            dict(
                type=int,
                default=0,
                help="每 N 个 epoch 在 val 上打印像素误差并写入 results.csv 对应 epoch 行0=关闭);需 labels 与 data.yaml 布局一致",
            ),
        ),
        (
            ("--pixel-metrics-conf",),
            dict(
                type=float,
                default=0.25,
                help="--pixel-metrics-every 时 predict 置信度阈值(默认 0.25",
            ),
        ),
        (
            ("--best-by-pixel",),
            dict(
                action="store_true",
                help="best.pt 与 early stopping 按验证集 mean 像素误差(同 pose_pixel_metricsfitness=-mean_px单卡有效DDP 自动退回 mAP",
            ),
        ),
        (
            ("--pixel-fitness-conf",),
            dict(
                type=float,
                default=0.25,
                help="--best-by-pixel 时 predict 置信度阈值(默认与 pixel-metrics 一致)",
            ),
        ),
        (
            ("--export-onnx",),
            dict(action="store_true", help="训练结束后导出 ONNX需再设 --onnx-imgsz"),
        ),
        (
            ("--onnx-imgsz",),
            dict(
                type=int,
                nargs=2,
                metavar=("H", "W"),
                default=[224, 320],
                help="导出 ONNX 的 [高, 宽],默认 224 320Maix 常用)",
            ),
        ),
        (
            ("--clear-label-cache",),
            dict(
                action="store_true",
                help="启动训练前删除 data.yaml 中 path 下的 labels/*.cache修正标注后仍报 corrupt 时用)",
            ),
        ),
    ]
    for flags, spec in option_table:
        parser.add_argument(*flags, **spec)
    args = parser.parse_args()

    # "auto" is passed on as None so that _pick_device makes the choice.
    requested_device = None if args.device == "auto" else args.device
    device = _pick_device(requested_device)
    if args.no_amp:
        use_amp = False
    else:
        use_amp = _default_amp(device)

    on_xpu = isinstance(device, torch.device) and device.type == "xpu"
    if on_xpu:
        print(f"✅ 使用 Intel XPU: {device}")
        _patch_ultralytics_for_xpu()
    elif device == 0 or device == "0":
        print(f"✅ 使用 CUDA: {torch.cuda.get_device_name(0)}")
    else:
        print("⚠️ 使用 CPU训练会较慢")

    # Relative dataset paths are resolved against this script's directory.
    data_yaml = args.data
    if not os.path.isabs(data_yaml):
        script_dir = os.path.dirname(os.path.abspath(__file__))
        data_yaml = os.path.join(script_dir, data_yaml)
    if not os.path.exists(data_yaml):
        print(f"❌ 数据集配置不存在: {data_yaml}")
        return

    if args.clear_label_cache:
        removed = _clear_ultralytics_label_caches(data_yaml)
        print(f"🗑️ 已删除标签目录缓存 {removed}labels/*.cache将强制重新扫描标注。")

    print(f"📦 加载模型: {args.model}(固定 task=pose")
    model = YOLO(args.model, task="pose")

    if args.best_by_pixel:
        _install_best_by_pixel_validate(data_yaml, args.imgsz, args.pixel_fitness_conf)
        print(
            "📌 已启用 --best-by-pixelbest.pt / patience 按验证集 mean 像素误差fitness=-mean_px"
            "反向传播仍为 Ultralytics 默认 pose/box loss。"
        )

    if args.pixel_metrics_every > 0:
        pixel_callback = _make_pixel_metrics_callback(
            data_yaml, args.pixel_metrics_every, args.imgsz, conf=args.pixel_metrics_conf
        )
        model.add_callback("on_fit_epoch_end", pixel_callback)

    train_settings = dict(
        # Run identity and I/O.
        task="pose",
        data=data_yaml,
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        name=args.name,
        project=args.project,
        exist_ok=True,
        save=True,
        save_period=5,
        device=device,
        workers=args.workers,
        # Optimizer and schedule.
        lr0=0.0001,
        lrf=0.01,
        optimizer="AdamW",
        momentum=0.937,
        weight_decay=0.001,
        warmup_epochs=0,
        warmup_momentum=0.8,
        warmup_bias_lr=0.1,
        # Augmentation.
        hsv_h=0.015,
        hsv_s=0.7,
        hsv_v=0.4,
        degrees=5.0,
        translate=0.0,
        scale=0.2,
        shear=0.0,
        perspective=0.0000,
        flipud=0.0,
        fliplr=0.5,
        mosaic=0.0,
        mixup=0.0,
        copy_paste=0.0,
        # Loss gains.
        box=6,
        cls=0.5,
        dfl=1.5,
        pose=18.0,
        kobj=0.5,
        # Miscellaneous.
        freeze=0,
        seed=42,
        verbose=True,
        amp=use_amp,
        patience=100,
        cos_lr=True,
    )
    model.train(**train_settings)

    print("\n✅ 训练完成!")
    for tag in ("best", "last"):
        print(f"📁 {tag}: {args.project}/{args.name}/weights/{tag}.pt")
    print("📊 仅看像素误差可运行: python pose_pixel_metrics.py --model <best.pt> --data <yaml> --imgsz", args.imgsz)

    if not args.export_onnx:
        return
    h, w = args.onnx_imgsz
    print(f"📦 导出 ONNX imgsz=[{h}, {w}] ...")
    model.export(format="onnx", imgsz=[h, w], simplify=True, opset=17, dynamic=False)
    print("✅ ONNX 完成")
# Run the training CLI only when this file is executed as a script, so that
# importing the module does not start a training run.
if __name__ == "__main__":
    main()