#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ YOLO11 关键点检测训练脚本(靶纸四角)。 设备优先级(--device auto):Intel XPU > NVIDIA CUDA > CPU。 默认 imgsz=960;批大小默认 4(大图显存紧张时可再降)。 关于「业务像素误差」: Ultralytics 没有在 yaml 里设定「像素阈值」的选项;反向传播仍由 pose/kobj/box 等内部 loss 驱动。 - 监控:--pixel-metrics-every N(每 N 个 epoch 打印 mean/p95,并合并进 runs/.../results.csv,见 pose_pixel_metrics.py)。 - 选 best.pt / early stopping:加 --best-by-pixel,用验证集 mean 像素误差(与 pose_pixel_metrics 同一口径)代替 mAP 合成 fitness(fitness = -mean_px,越小越好)。 多卡 DDP(world_size>1)时会自动退回默认 mAP fitness。 XPU:Ultralytics BaseTrainer._get_memory / _clear_memory 把非 MPS、非 CPU 一律当 CUDA, 会在验证前调用 torch.cuda 而报错;本脚本在选用 XPU 时自动打补丁(见 _patch_ultralytics_trainer_for_xpu)。 务必使用 pose 任务:YOLO(...) 与 model.train(...) 均指定 task='pose'。若误用默认 detect, 会把 17 列 Pose 标注当成检测/分割解析,校验时出现「coordinates > 1」或 [2.] 等假象。 """ from __future__ import annotations import argparse import csv import gc import glob import os import tempfile from copy import deepcopy from pathlib import Path import torch from ultralytics import YOLO from pose_pixel_metrics import eval_val_pixel_error import warnings warnings.filterwarnings('ignore', message=".*scatter_add_kernel does not have a deterministic implementation.*") def _clear_ultralytics_label_caches(data_yaml_path: str) -> int: """删除 data.yaml 的 path 下 labels/*.cache。 Ultralytics 的校验缓存 hash 仅依赖「标签/图片路径字符串 + 各文件 size 之和」,不含文件内容; 修正 *.txt 后若总和巧合不变,可能继续加载旧 cache 并重播旧的 corrupt 日志,训练前应删掉。""" from ultralytics.utils import YAML try: cfg = YAML.load(data_yaml_path) except Exception: return 0 root = cfg.get("path") if not root: return 0 root = os.path.abspath(os.path.expanduser(str(root))) pattern = os.path.join(root, "labels", "*.cache") n = 0 for p in glob.glob(pattern): try: os.unlink(p) n += 1 except OSError: pass return n def _pick_device(explicit: str | None): """返回 ultralytics train/predict 可用的 device。""" if explicit and explicit != "auto": e = explicit.lower() if e == "xpu": if getattr(torch, "xpu", None) is None or not torch.xpu.is_available(): raise RuntimeError("指定了 --device xpu 但当前环境不可用") return torch.device("xpu") if e in ("0", "cuda", "gpu"): if not torch.cuda.is_available(): raise RuntimeError("指定了 CUDA 但不可用") return 0 if e == "cpu": return "cpu" return explicit if getattr(torch, "xpu", None) is not None and torch.xpu.is_available(): return torch.device("xpu") if torch.cuda.is_available(): return 0 return "cpu" def _default_amp(device) -> bool: if isinstance(device, torch.device) and device.type == "xpu": return False if device == "cpu": return False return True def _patch_ultralytics_for_xpu(): """为 Ultralytics 打补丁,使其能在 XPU 环境下正常训练和验证。""" import ultralytics.engine.trainer as ut_trainer import ultralytics.engine.validator as ut_validator from ultralytics.utils.torch_utils import select_device as _original_select_device # 1. 覆盖 select_device:Trainer 初始化传入 torch.device("xpu") 会走原版早返回; # 初始化后 args.device 会变成字符串 "xpu",中期 val 用 trainer.device,不调用 select_device; # 训练结束 final_eval 里 Validator 会 select_device("xpu"),且 validator 在 import 时已绑定原函数, # 只改 torch_utils 无效,必须同时修补 trainer/validator 模块内的引用。 def _patched_select_device(device="", *args, **kwargs): # Ultralytics 8.4.x: select_device(device="", newline=False, verbose=True) # Older forks sometimes passed extra positional args; forward everything. if isinstance(device, str): d = device.strip().lower() if d == "xpu" or d.startswith("xpu:"): return torch.device(device.strip()) return _original_select_device(device, *args, **kwargs) import ultralytics.utils.torch_utils ultralytics.utils.torch_utils.select_device = _patched_select_device ut_trainer.select_device = _patched_select_device ut_validator.select_device = _patched_select_device # 2. 修补 Trainer 的内存函数 BT = ut_trainer.BaseTrainer if not getattr(BT, "_archery_xpu_memory_patched", False): _orig_get_memory = BT._get_memory _orig_clear_memory = BT._clear_memory def _get_memory(self, fraction=False): if self.device.type != "xpu": return _orig_get_memory(self, fraction) # ... (原有的 XPU 内存获取逻辑保持不变) ... memory, total = 0, 0 try: idx = self.device.index if idx is None: idx = torch.xpu.current_device() memory = int(torch.xpu.memory_allocated(idx)) if fraction: total = int(torch.xpu.get_device_properties(idx).total_memory) except Exception: pass return (memory / total) if fraction and total > 0 else (memory / 2**30) def _clear_memory(self, threshold=None): if self.device.type != "xpu": return _orig_clear_memory(self, threshold) if threshold is not None: assert 0 <= threshold <= 1, "Threshold must be between 0 and 1." if self._get_memory(fraction=True) <= threshold: return gc.collect() if hasattr(torch.xpu, "empty_cache"): torch.xpu.empty_cache() BT._get_memory = _get_memory BT._clear_memory = _clear_memory BT._archery_xpu_memory_patched = True # 3. 修补 Validator 的内存函数 (关键是添加这部分) BV = ut_validator.BaseValidator if not getattr(BV, "_archery_xpu_memory_patched", False): # 为 Validator 添加同样的内存处理方法 BV._get_memory = _get_memory BV._clear_memory = _clear_memory BV._archery_xpu_memory_patched = True def _install_best_by_pixel_validate(data_yaml: str, imgsz: int, conf: float) -> None: """用验证集关键点像素 mean 替代 mAP fitness,驱动 best.pt 与 patience early stopping。""" import ultralytics.engine.trainer as ut from ultralytics.utils import RANK BT = ut.BaseTrainer if getattr(BT, "_archery_best_by_pixel_installed", False): return _orig_validate = BT.validate def validate(self): import torch.distributed as dist if self.ema and self.world_size > 1: for buffer in self.ema.ema.buffers(): dist.broadcast(buffer, src=0) metrics = self.validator(self) if metrics is None: return None, None orig_fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) use_pixel = self.world_size <= 1 and RANK in {-1, 0} mean_px: float | None = None if use_pixel: tmp_path: str | None = None try: fd, tmp_path = tempfile.mkstemp(suffix=".pt", prefix="archery_pxfit_") os.close(fd) from ultralytics.utils.torch_utils import unwrap_model core = unwrap_model(self.ema.ema if self.ema else self.model) torch.save({"ema": deepcopy(core).half(), "train_args": vars(self.args)}, tmp_path) probe = YOLO(tmp_path) stats = eval_val_pixel_error( probe, data_yaml, device=self.device, imgsz=imgsz, conf=conf, ) mean_px = stats.get("mean_px") if mean_px is None: raise RuntimeError("无有效 mean_px(检查 val 标签与检测是否为空)") except Exception as exc: print(f"\n⚠️ [best-by-pixel] 像素探针失败,本 epoch 仍用 mAP fitness: {exc}\n") mean_px = None finally: if tmp_path: try: os.unlink(tmp_path) except OSError: pass if mean_px is not None: fitness = -float(mean_px) metrics["metrics/mean_px(val)"] = float(mean_px) else: fitness = float(orig_fitness) if not self.best_fitness or self.best_fitness < fitness: self.best_fitness = fitness return metrics, fitness BT.validate = validate BT._archery_best_by_pixel_installed = True def _fmt_csv_metric(v: float | int | None) -> str: if v is None: return "" if isinstance(v, float): return f"{v:.6g}" return str(v) # 写入 results.csv 的列名(与 --best-by-pixel 的 metrics/mean_px(val) 区分,避免被 last.pt 回调覆盖 EMA 行) _PIXEL_METRIC_COLUMNS: tuple[tuple[str, str], ...] = ( ("pixel_error/mean_px", "mean_px"), ("pixel_error/median_px", "median_px"), ("pixel_error/p95_px", "p95_px"), ("pixel_error/max_px", "max_px"), ("pixel_error/n_points", "n_points"), ("pixel_error/n_images", "n_images"), ("pixel_error/skip_no_det", "skip_no_det"), ("pixel_error/skip_no_gt", "skip_no_gt"), ("pixel_error/skip_kpt_mismatch", "skip_kpt_mismatch"), ) def _merge_pixel_metrics_into_results_csv(save_dir: str | Path, epoch_1based: int, stats: dict) -> None: """在 Ultralytics 写完本 epoch 行之后,把像素指标列合并进 results.csv(扩展表头、补空列)。""" csv_path = Path(save_dir) / "results.csv" if not csv_path.is_file(): return try: with open(csv_path, newline="", encoding="utf-8") as f: rows = list(csv.reader(f)) except OSError: return if len(rows) < 2: return header = list(rows[0]) for col_name, _ in _PIXEL_METRIC_COLUMNS: if col_name not in header: header.append(col_name) for ri in range(1, len(rows)): rows[ri].append("") col_ix = {name: i for i, name in enumerate(header)} rows[0] = header target_ri: int | None = None for ri in range(1, len(rows)): row = rows[ri] while len(row) < len(header): row.append("") try: if int(float(row[0].strip())) == int(epoch_1based): target_ri = ri except (ValueError, IndexError): continue if target_ri is None: return row = rows[target_ri] while len(row) < len(header): row.append("") for col_name, sk in _PIXEL_METRIC_COLUMNS: row[col_ix[col_name]] = _fmt_csv_metric(stats.get(sk)) try: with open(csv_path, "w", newline="", encoding="utf-8") as f: w = csv.writer(f) w.writerows(rows) except OSError: pass def _make_pixel_metrics_callback(data_yaml: str, every: int, imgsz: int, conf: float = 0.25): def on_fit_epoch_end(trainer): from ultralytics.utils import RANK if RANK not in {-1, 0}: return if every <= 0: return ep = int(getattr(trainer, "epoch", -1)) if (ep + 1) % every != 0: return w = Path(trainer.save_dir) / "weights" / "last.pt" if not w.is_file(): return m = YOLO(str(w)) stats = eval_val_pixel_error( m, data_yaml, device=trainer.device, imgsz=imgsz, conf=conf, ) mean_px = stats.get("mean_px") p95_px = stats.get("p95_px") mean_s = f"{mean_px:.3f}" if mean_px is not None else "n/a" p95_s = f"{p95_px:.3f}" if p95_px is not None else "n/a" print( f"\n[pixel-metrics] epoch {ep + 1}: mean_px={mean_s} p95_px={p95_s} " f"n_points={stats.get('n_points', 0)} " f"skip(det/gt/k)={stats['skip_no_det']}/{stats['skip_no_gt']}/{stats['skip_kpt_mismatch']}\n" ) _merge_pixel_metrics_into_results_csv(trainer.save_dir, ep + 1, stats) return on_fit_epoch_end def main(): ap = argparse.ArgumentParser(description="YOLO Pose 训练(XPU/CUDA/CPU)") ap.add_argument("--data", default="datasets/dataset_pose.yaml", help="data.yaml") ap.add_argument("--model", default="yolo11x-pose.pt", help="预训练权重") ap.add_argument("--epochs", type=int, default=100) ap.add_argument("--imgsz", type=int, default=960, help="训练输入边长(默认 960)") ap.add_argument("--batch", type=int, default=4, help="批大小;OOM 时减小") ap.add_argument( "--device", default="auto", help="auto | xpu | 0 | cuda | cpu(auto:XPU 优先)", ) ap.add_argument( "--no-amp", action="store_true", help="关闭混合精度(默认:CUDA 开启,XPU/CPU 关闭)", ) ap.add_argument("--project", default="runs/pose") ap.add_argument("--name", default="target_pose_train") ap.add_argument("--workers", type=int, default=4) ap.add_argument( "--pixel-metrics-every", type=int, default=0, help="每 N 个 epoch 在 val 上打印像素误差并写入 results.csv 对应 epoch 行(0=关闭);需 labels 与 data.yaml 布局一致", ) ap.add_argument( "--pixel-metrics-conf", type=float, default=0.25, help="--pixel-metrics-every 时 predict 置信度阈值(默认 0.25)", ) ap.add_argument( "--best-by-pixel", action="store_true", help="best.pt 与 early stopping 按验证集 mean 像素误差(同 pose_pixel_metrics),fitness=-mean_px;单卡有效,DDP 自动退回 mAP", ) ap.add_argument( "--pixel-fitness-conf", type=float, default=0.25, help="--best-by-pixel 时 predict 置信度阈值(默认与 pixel-metrics 一致)", ) ap.add_argument( "--export-onnx", action="store_true", help="训练结束后导出 ONNX(需再设 --onnx-imgsz)", ) ap.add_argument( "--onnx-imgsz", type=int, nargs=2, metavar=("H", "W"), default=[224, 320], help="导出 ONNX 的 [高, 宽],默认 224 320(Maix 常用)", ) ap.add_argument( "--clear-label-cache", action="store_true", help="启动训练前删除 data.yaml 中 path 下的 labels/*.cache(修正标注后仍报 corrupt 时用)", ) args = ap.parse_args() device = _pick_device(None if args.device == "auto" else args.device) use_amp = False if args.no_amp else _default_amp(device) if isinstance(device, torch.device) and device.type == "xpu": print(f"✅ 使用 Intel XPU: {device}") elif device == 0 or device == "0": print(f"✅ 使用 CUDA: {torch.cuda.get_device_name(0)}") else: print("⚠️ 使用 CPU,训练会较慢") if isinstance(device, torch.device) and device.type == "xpu": _patch_ultralytics_for_xpu() data_yaml = args.data if not os.path.isabs(data_yaml): data_yaml = os.path.join(os.path.dirname(os.path.abspath(__file__)), data_yaml) if not os.path.exists(data_yaml): print(f"❌ 数据集配置不存在: {data_yaml}") return if args.clear_label_cache: n_rm = _clear_ultralytics_label_caches(data_yaml) print(f"🗑️ 已删除标签目录缓存 {n_rm} 个(labels/*.cache),将强制重新扫描标注。") print(f"📦 加载模型: {args.model}(固定 task=pose)") model = YOLO(args.model, task="pose") if args.best_by_pixel: _install_best_by_pixel_validate(data_yaml, args.imgsz, args.pixel_fitness_conf) print( "📌 已启用 --best-by-pixel:best.pt / patience 按验证集 mean 像素误差(fitness=-mean_px);" "反向传播仍为 Ultralytics 默认 pose/box loss。" ) if args.pixel_metrics_every > 0: model.add_callback( "on_fit_epoch_end", _make_pixel_metrics_callback( data_yaml, args.pixel_metrics_every, args.imgsz, conf=args.pixel_metrics_conf ), ) model.train( task="pose", data=data_yaml, epochs=args.epochs, imgsz=args.imgsz, batch=args.batch, name=args.name, project=args.project, exist_ok=True, save=True, save_period=5, device=device, workers=args.workers, lr0=0.0001, lrf=0.01, optimizer="AdamW", momentum=0.937, weight_decay=0.001, warmup_epochs=0, warmup_momentum=0.8, warmup_bias_lr=0.1, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, degrees=5.0, translate=0.0, scale=0.2, shear=0.0, perspective=0.0000, flipud=0.0, fliplr=0.5, mosaic=0.0, mixup=0.0, copy_paste=0.0, box=6, cls=0.5, dfl=1.5, pose=18.0, kobj=0.5, freeze=0, seed=42, verbose=True, amp=use_amp, patience=100, cos_lr=True, ) print("\n✅ 训练完成!") print(f"📁 best: {args.project}/{args.name}/weights/best.pt") print(f"📁 last: {args.project}/{args.name}/weights/last.pt") print("📊 仅看像素误差可运行: python pose_pixel_metrics.py --model --data --imgsz", args.imgsz) if args.export_onnx: h, w = args.onnx_imgsz print(f"📦 导出 ONNX imgsz=[{h}, {w}] ...") model.export(format="onnx", imgsz=[h, w], simplify=True, opset=17, dynamic=False) print("✅ ONNX 完成") if __name__ == "__main__": main()