507 lines
18 KiB
Python
507 lines
18 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
YOLO11 关键点检测训练脚本(靶纸四角)。
|
||
|
||
设备优先级(--device auto):Intel XPU > NVIDIA CUDA > CPU。
|
||
默认 imgsz=960;批大小默认 4(大图显存紧张时可再降)。
|
||
|
||
关于「业务像素误差」:
|
||
Ultralytics 没有在 yaml 里设定「像素阈值」的选项;反向传播仍由 pose/kobj/box 等内部 loss 驱动。
|
||
- 监控:--pixel-metrics-every N(每 N 个 epoch 打印 mean/p95,并合并进 runs/.../results.csv,见 pose_pixel_metrics.py)。
|
||
- 选 best.pt / early stopping:加 --best-by-pixel,用验证集 mean 像素误差(与 pose_pixel_metrics
|
||
同一口径)代替 mAP 合成 fitness(fitness = -mean_px,越小越好)。
|
||
多卡 DDP(world_size>1)时会自动退回默认 mAP fitness。
|
||
|
||
XPU:Ultralytics BaseTrainer._get_memory / _clear_memory 把非 MPS、非 CPU 一律当 CUDA,
|
||
会在验证前调用 torch.cuda 而报错;本脚本在选用 XPU 时自动打补丁(见 _patch_ultralytics_for_xpu)。
|
||
|
||
务必使用 pose 任务:YOLO(...) 与 model.train(...) 均指定 task='pose'。若误用默认 detect,
|
||
会把 17 列 Pose 标注当成检测/分割解析,校验时出现「coordinates > 1」或 [2.] 等假象。
|
||
"""
|
||
|
||
from __future__ import annotations

import argparse
import csv
import gc
import glob
import math
import numbers
import os
import tempfile
import warnings
from copy import deepcopy
from pathlib import Path

import torch
from ultralytics import YOLO

from pose_pixel_metrics import eval_val_pixel_error
|
||
# PyTorch emits a "<kernel> does not have a deterministic implementation" UserWarning for
# scatter_add (raised when deterministic algorithms are enabled in warn-only mode). The kernel
# name is backend specific -- e.g. "scatter_add_kernel", or "scatter_add_cuda_kernel" on CUDA --
# so match any "scatter_add*_kernel" spelling rather than a single one.
warnings.filterwarnings(
    "ignore",
    message=r".*scatter_add\w*_kernel does not have a deterministic implementation.*",
)
|
||
|
||
|
||
def _clear_ultralytics_label_caches(data_yaml_path: str) -> int:
    """Delete the ``labels/*.cache`` files under the dataset root of ``data_yaml_path``.

    Ultralytics' label-cache hash only covers the label/image path strings plus the sum of the
    file sizes, not the file contents. After fixing ``*.txt`` labels the summed size can stay the
    same by coincidence, in which case the stale cache (and its old "corrupt" log lines) would be
    replayed; call this before training to force a rescan.

    The dataset root is the yaml's ``path`` entry (relative values are taken relative to the
    current working directory). When ``path`` is missing or empty, the directory containing the
    yaml itself is used -- presumably how Ultralytics resolves it too; verify against the
    installed version.

    Args:
        data_yaml_path: Path to the dataset ``data.yaml``.

    Returns:
        Number of cache files actually deleted (0 when the yaml cannot be read or is not a
        mapping). Files that cannot be deleted are reported and skipped.
    """
    from ultralytics.utils import YAML

    try:
        cfg = YAML.load(data_yaml_path)
    except Exception:
        # Best-effort helper: an unreadable yaml simply means nothing is cleared.
        return 0
    if not isinstance(cfg, dict):
        return 0

    root = cfg.get("path")
    if root:
        root_dir = os.path.abspath(os.path.expanduser(str(root)))
    else:
        root_dir = os.path.dirname(os.path.abspath(os.path.expanduser(str(data_yaml_path))))

    # Escape the root so directory names containing glob metacharacters ("[", "]", "*", "?")
    # are matched literally; only the "*.cache" part is a pattern.
    pattern = os.path.join(glob.escape(root_dir), "labels", "*.cache")
    removed = 0
    for cache_file in glob.glob(pattern):
        try:
            os.unlink(cache_file)
        except OSError as exc:
            print(f"⚠️ 无法删除标签缓存 {cache_file}: {exc}")
            continue
        removed += 1
    return removed
|
||
|
||
|
||
def _pick_device(explicit: str | None):
    """Resolve a ``--device`` value into a device Ultralytics train/predict accept.

    Args:
        explicit: The user's choice. ``None``, ``""`` or ``"auto"`` (any case) selects
            automatically with priority Intel XPU > CUDA (index 0) > CPU. ``"xpu"`` / ``"xpu:N"``
            -> ``torch.device``; ``"0"`` / ``"cuda"`` / ``"gpu"`` -> CUDA index ``0``;
            ``"cpu"`` -> ``"cpu"``. Any other value (e.g. ``"cuda:1"``, ``"0,1"``, ``"mps"``) is
            returned unchanged for Ultralytics to interpret.

    Returns:
        ``torch.device`` for XPU, the int ``0`` for the first CUDA GPU, ``"cpu"``, or the original
        string for pass-through values.

    Raises:
        RuntimeError: XPU or CUDA was requested explicitly but is not available (or the XPU
            device string is malformed).
    """

    def _xpu_available() -> bool:
        # torch.xpu only exists in builds with Intel GPU support.
        return getattr(torch, "xpu", None) is not None and torch.xpu.is_available()

    choice = (explicit or "").strip().lower()
    if choice and choice != "auto":
        # "xpu:N" is checked too, so an indexed XPU request gets the same availability check
        # and the same torch.device return type as plain "xpu".
        if choice == "xpu" or choice.startswith("xpu:"):
            if not _xpu_available():
                raise RuntimeError("指定了 --device xpu 但当前环境不可用")
            return torch.device(choice)
        if choice in ("0", "cuda", "gpu"):
            if not torch.cuda.is_available():
                raise RuntimeError("指定了 CUDA 但不可用")
            return 0
        if choice == "cpu":
            return "cpu"
        return explicit

    if _xpu_available():
        return torch.device("xpu")
    if torch.cuda.is_available():
        return 0
    return "cpu"
|
||
|
||
|
||
def _default_amp(device) -> bool:
    """Return the default AMP (automatic mixed precision) setting for ``device``.

    AMP is off for Intel XPU and CPU and on for everything else (CUDA indices, ``"cuda:N"`` and
    other pass-through values). ``torch.device`` objects and plain strings/ints are classified
    the same way, so ``"xpu:0"`` behaves like ``torch.device("xpu")`` and ``torch.device("cpu")``
    like ``"cpu"``.
    """
    if isinstance(device, torch.device):
        dev_type = device.type
    else:
        # "xpu:1" -> "xpu", 0 -> "0", "cuda:0" -> "cuda".
        dev_type = str(device).strip().lower().split(":", 1)[0]
    return dev_type not in ("xpu", "cpu")
|
||
|
||
|
||
def _patch_ultralytics_for_xpu():
    """Monkey-patch Ultralytics so training and validation work on Intel XPU.

    1. ``select_device``: the trainer is built with ``torch.device("xpu")`` (for which the
       original function returns early), but afterwards ``args.device`` becomes the string
       ``"xpu"``. Mid-training validation reuses ``trainer.device``, yet the final-eval
       ``Validator`` calls ``select_device("xpu")``. The engine ``trainer``/``validator`` modules
       bound the function at import time, so besides ``ultralytics.utils.torch_utils`` their
       module-level references are replaced as well.
    2. ``BaseTrainer._get_memory`` / ``_clear_memory`` treat every non-MPS, non-CPU device as
       CUDA and call ``torch.cuda``. XPU-aware versions are installed on ``BaseTrainer`` and
       mirrored onto ``BaseValidator``.

    Safe to call more than once: wrappers are not stacked and already-patched classes are skipped.
    """
    import ultralytics.engine.trainer as ut_trainer
    import ultralytics.engine.validator as ut_validator
    import ultralytics.utils.torch_utils as ut_torch_utils

    # --- 1. select_device -------------------------------------------------------------------
    current_select = ut_torch_utils.select_device
    if getattr(current_select, "_archery_xpu_patched", False):
        patched_select = current_select  # an earlier call already wrapped it
    else:
        original_select = current_select

        def patched_select(device="", *args, **kwargs):
            # Ultralytics 8.4.x: select_device(device="", newline=False, verbose=True);
            # older forks pass extra positional args, so forward everything untouched.
            if isinstance(device, str):
                d = device.strip().lower()
                if d == "xpu" or d.startswith("xpu:"):
                    # Build from the normalized text: torch.device() is case-sensitive.
                    return torch.device(d)
            return original_select(device, *args, **kwargs)

        patched_select._archery_xpu_patched = True

    ut_torch_utils.select_device = patched_select
    ut_trainer.select_device = patched_select
    ut_validator.select_device = patched_select

    # --- 2. BaseTrainer memory helpers ------------------------------------------------------
    BT = ut_trainer.BaseTrainer
    if not getattr(BT, "_archery_xpu_memory_patched", False):
        orig_get_memory = BT._get_memory
        orig_clear_memory = BT._clear_memory

        def _get_memory(self, fraction=False):
            """Allocated XPU memory in GiB, or the used fraction (0-1) when ``fraction``.

            Non-XPU devices defer to the original implementation. A fraction request whose
            total cannot be determined reports 0.0 rather than a GiB figure, so callers never
            compare GiB against a 0-1 threshold.
            """
            if self.device.type != "xpu":
                return orig_get_memory(self, fraction)
            memory, total = 0, 0
            try:
                idx = self.device.index
                if idx is None:
                    idx = torch.xpu.current_device()
                memory = int(torch.xpu.memory_allocated(idx))
                if fraction:
                    total = int(torch.xpu.get_device_properties(idx).total_memory)
            except Exception:
                # Memory stats are informational only; never let them break training.
                pass
            if fraction:
                return memory / total if total > 0 else 0.0
            return memory / 2**30

        def _clear_memory(self, threshold=None):
            """Run GC and empty the XPU cache, optionally only above a used-memory fraction."""
            if self.device.type != "xpu":
                return orig_clear_memory(self, threshold)
            if threshold is not None:
                assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
                if self._get_memory(fraction=True) <= threshold:
                    return
            gc.collect()
            if hasattr(torch.xpu, "empty_cache"):
                torch.xpu.empty_cache()

        BT._get_memory = _get_memory
        BT._clear_memory = _clear_memory
        BT._archery_xpu_memory_patched = True

    # --- 3. BaseValidator: reuse the (now XPU-aware) trainer helpers ------------------------
    # Read them back from BaseTrainer so this also works when BaseTrainer was patched by an
    # earlier call (the local helper names above only exist on the first call).
    BV = ut_validator.BaseValidator
    if not getattr(BV, "_archery_xpu_memory_patched", False):
        BV._get_memory = BT._get_memory
        BV._clear_memory = BT._clear_memory
        BV._archery_xpu_memory_patched = True
|
||
|
||
|
||
def _install_best_by_pixel_validate(data_yaml: str, imgsz: int, conf: float) -> None:
    """Replace ``BaseTrainer.validate`` so best.pt / patience follow the val mean pixel error.

    The replacement first runs the normal Ultralytics validator, so mAP metrics are still
    computed and logged. On a single process it then saves the current EMA (or raw) model to a
    temporary checkpoint, loads it with ``YOLO`` and evaluates it with ``eval_val_pixel_error``.
    The returned fitness is ``-mean_px``: Ultralytics keeps the highest fitness, and a smaller
    pixel error is better.

    The fitness scale never mixes within a run:

    * DDP (``world_size > 1``) or a non-main rank: no probe; the stock (mAP) fitness is used.
    * Single process, probe failed or ``mean_px`` missing/non-finite: the epoch returns
      ``fitness=None``. Falling back to the positive mAP fitness would outrank every
      ``-mean_px`` epoch and freeze best.pt.

    In single-process mode ``metrics["metrics/mean_px(val)"]`` is always set (NaN on failure),
    so the results.csv column count stays constant.

    Installation is idempotent.

    Args:
        data_yaml: Dataset yaml passed to ``eval_val_pixel_error``.
        imgsz: Inference size for the probe.
        conf: Prediction confidence threshold for the probe.
    """
    import ultralytics.engine.trainer as ut
    from ultralytics.utils import RANK

    BT = ut.BaseTrainer
    if getattr(BT, "_archery_best_by_pixel_installed", False):
        return

    def _probe_mean_px(trainer) -> float:
        """Evaluate the trainer's current weights on the val split and return mean pixel error.

        Raises on any failure, including a missing or non-finite ``mean_px``.
        """
        from ultralytics.utils.torch_utils import unwrap_model

        fd, tmp_path = tempfile.mkstemp(suffix=".pt", prefix="archery_pxfit_")
        os.close(fd)
        try:
            # Prefer EMA weights when the trainer keeps them; otherwise the raw model.
            core = unwrap_model(trainer.ema.ema if trainer.ema else trainer.model)
            torch.save({"ema": deepcopy(core).half(), "train_args": vars(trainer.args)}, tmp_path)
            stats = eval_val_pixel_error(
                YOLO(tmp_path),
                data_yaml,
                device=trainer.device,
                imgsz=imgsz,
                conf=conf,
            )
        finally:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
        mean_px = stats.get("mean_px")
        if mean_px is None or not math.isfinite(float(mean_px)):
            raise RuntimeError("无有效 mean_px(检查 val 标签与检测是否为空)")
        return float(mean_px)

    def validate(self):
        """Drop-in ``BaseTrainer.validate``: return ``(metrics, fitness)``."""
        import torch.distributed as dist

        if self.ema and self.world_size > 1:
            # Broadcast rank-0 EMA buffers so every rank validates identical weights.
            for buffer in self.ema.ema.buffers():
                dist.broadcast(buffer, src=0)
        metrics = self.validator(self)
        if metrics is None:
            return None, None
        # Stock fitness; when the validator provides none, the negated training loss is used.
        stock_fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())

        if self.world_size > 1 or RANK not in {-1, 0}:
            # DDP / non-main rank: keep the stock fitness so all ranks agree.
            fitness = float(stock_fitness)
        else:
            try:
                mean_px = _probe_mean_px(self)
            except Exception as exc:
                print(
                    f"\n⚠️ [best-by-pixel] 像素探针失败,本 epoch 不参与 best.pt / early stopping 比较: {exc}\n"
                )
                # NaN keeps the CSV column present. fitness=None mirrors the (None, None) path
                # above; assumes Ultralytics' stopper/save logic tolerates a None fitness
                # (as with val=False) -- verify against the installed version.
                metrics["metrics/mean_px(val)"] = float("nan")
                return metrics, None
            metrics["metrics/mean_px(val)"] = mean_px
            fitness = -mean_px

        if not self.best_fitness or self.best_fitness < fitness:
            self.best_fitness = fitness
        return metrics, fitness

    BT.validate = validate
    BT._archery_best_by_pixel_installed = True
|
||
|
||
|
||
def _fmt_csv_metric(v: float | int | None) -> str:
    """Format one pixel-metric value for a results.csv cell.

    * ``None`` -> ``""`` (empty cell).
    * Integers, including ``bool`` and NumPy integer scalars -> ``str(v)``.
    * Other real numbers, i.e. Python ``float`` and NumPy float scalars such as ``float32``
      (which are not ``float`` subclasses) -> 6 significant digits (``.6g``).
    * Anything else -> ``str(v)``.
    """
    if v is None:
        return ""
    if isinstance(v, numbers.Integral):
        return str(v)
    if isinstance(v, numbers.Real):
        # float() first: not every Real implements the "g" format spec itself.
        return f"{float(v):.6g}"
    return str(v)
|
||
|
||
|
||
# Columns the --pixel-metrics-every callback writes into results.csv, as pairs of
# (CSV header, key in the eval_val_pixel_error stats dict). The "pixel_error/" prefix keeps them
# separate from the --best-by-pixel column "metrics/mean_px(val)", so the last.pt-based numbers
# never overwrite the EMA-probe values.
_PIXEL_METRIC_COLUMNS: tuple[tuple[str, str], ...] = tuple(
    (f"pixel_error/{stat_key}", stat_key)
    for stat_key in (
        "mean_px",
        "median_px",
        "p95_px",
        "max_px",
        "n_points",
        "n_images",
        "skip_no_det",
        "skip_no_gt",
        "skip_kpt_mismatch",
    )
)
|
||
|
||
|
||
def _merge_pixel_metrics_into_results_csv(save_dir: str | Path, epoch_1based: int, stats: dict) -> None:
    """Merge pixel-error stats into the ``epoch_1based`` row of ``<save_dir>/results.csv``.

    Meant to run after Ultralytics has written the current epoch's row:

    * every ``_PIXEL_METRIC_COLUMNS`` header missing from the file is appended;
    * non-blank data rows are padded with empty cells up to the header width;
    * the last row whose first cell parses as ``epoch_1based`` receives the values formatted by
      ``_fmt_csv_metric``.

    Best-effort: a missing/unreadable file, a file without data rows, or no matching epoch row
    leaves the file untouched; a failed rewrite is reported but not raised.

    Args:
        save_dir: Training run directory containing ``results.csv``.
        epoch_1based: Epoch number as written in the first CSV column.
        stats: Stats dict from ``eval_val_pixel_error``; missing keys become empty cells.
    """
    csv_path = Path(save_dir) / "results.csv"
    if not csv_path.is_file():
        return
    try:
        with open(csv_path, newline="", encoding="utf-8") as f:
            rows = list(csv.reader(f))
    except OSError:
        return
    if len(rows) < 2:
        return

    header = list(rows[0])
    # Compare stripped names so header cells padded with spaces still match.
    names = [cell.strip() for cell in header]
    for col_name, _ in _PIXEL_METRIC_COLUMNS:
        if col_name not in names:
            header.append(col_name)
            names.append(col_name)
    rows[0] = header
    col_ix = {name: i for i, name in enumerate(names)}

    width = len(header)
    target_ri: int | None = None
    for ri in range(1, len(rows)):
        row = rows[ri]
        if not row:
            continue  # keep blank lines blank instead of turning them into ",,,," rows
        if len(row) < width:
            row.extend([""] * (width - len(row)))
        try:
            if int(float(row[0].strip())) == int(epoch_1based):
                target_ri = ri  # keep scanning: the last matching row wins
        except (ValueError, OverflowError):  # non-numeric, NaN or infinite epoch cell
            continue
    if target_ri is None:
        return

    row = rows[target_ri]
    for col_name, stat_key in _PIXEL_METRIC_COLUMNS:
        row[col_ix[col_name]] = _fmt_csv_metric(stats.get(stat_key))

    try:
        with open(csv_path, "w", newline="", encoding="utf-8") as f:
            # csv.writer defaults to "\r\n". Use the platform line ending that an ordinary
            # text-mode write of "\n" produces, so rewritten rows match lines appended that way
            # (presumably how Ultralytics appends its rows -- verify).
            csv.writer(f, lineterminator=os.linesep).writerows(rows)
    except OSError as exc:
        print(f"⚠️ [pixel-metrics] 写入 {csv_path} 失败: {exc}")
|
||
|
||
|
||
def _make_pixel_metrics_callback(data_yaml: str, every: int, imgsz: int, conf: float = 0.25):
    """Create an ``on_fit_epoch_end`` callback that reports validation keypoint pixel error.

    Args:
        data_yaml: Dataset yaml passed to ``eval_val_pixel_error``.
        every: Evaluate every ``every`` epochs (1-based epoch numbers); ``<= 0`` disables it.
        imgsz: Inference size for the evaluation.
        conf: Prediction confidence threshold for the evaluation.

    Returns:
        The callback. On the main process it loads ``<save_dir>/weights/last.pt``, evaluates it
        on the val split, prints mean/p95 plus skip counters, and merges the stats into that
        epoch's results.csv row. An evaluation error is printed and that epoch is skipped; it
        does not abort training.
    """

    def on_fit_epoch_end(trainer):
        from ultralytics.utils import RANK

        if RANK not in {-1, 0}:
            return
        if every <= 0:
            return
        ep = int(getattr(trainer, "epoch", -1))  # 0-based; reported as ep + 1
        if (ep + 1) % every != 0:
            return
        weights = Path(trainer.save_dir) / "weights" / "last.pt"
        if not weights.is_file():
            return

        try:
            stats = eval_val_pixel_error(
                YOLO(str(weights)),
                data_yaml,
                device=trainer.device,
                imgsz=imgsz,
                conf=conf,
            )
        except Exception as exc:
            # Monitoring only: a failed evaluation must not kill a long training run.
            print(f"\n⚠️ [pixel-metrics] epoch {ep + 1}: 像素误差评估失败,已跳过: {exc}\n")
            return

        mean_px = stats.get("mean_px")
        p95_px = stats.get("p95_px")
        mean_s = f"{mean_px:.3f}" if mean_px is not None else "n/a"
        p95_s = f"{p95_px:.3f}" if p95_px is not None else "n/a"
        skips = "/".join(
            str(stats.get(key, "?")) for key in ("skip_no_det", "skip_no_gt", "skip_kpt_mismatch")
        )
        print(
            f"\n[pixel-metrics] epoch {ep + 1}: mean_px={mean_s} p95_px={p95_s} "
            f"n_points={stats.get('n_points', 0)} "
            f"skip(det/gt/k)={skips}\n"
        )
        _merge_pixel_metrics_into_results_csv(trainer.save_dir, ep + 1, stats)

    return on_fit_epoch_end
|
||
|
||
|
||
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this training script."""
    ap = argparse.ArgumentParser(description="YOLO Pose 训练(XPU/CUDA/CPU)")
    ap.add_argument("--data", default="datasets/dataset_pose.yaml", help="data.yaml")
    ap.add_argument("--model", default="yolo11x-pose.pt", help="预训练权重")
    ap.add_argument("--epochs", type=int, default=100)
    ap.add_argument("--imgsz", type=int, default=960, help="训练输入边长(默认 960)")
    ap.add_argument("--batch", type=int, default=4, help="批大小;OOM 时减小")
    ap.add_argument(
        "--device",
        default="auto",
        help="auto | xpu | 0 | cuda | cpu(auto:XPU 优先)",
    )
    ap.add_argument(
        "--no-amp",
        action="store_true",
        help="关闭混合精度(默认:CUDA 开启,XPU/CPU 关闭)",
    )
    ap.add_argument("--project", default="runs/pose")
    ap.add_argument("--name", default="target_pose_train")
    ap.add_argument("--workers", type=int, default=4)
    ap.add_argument(
        "--patience",
        type=int,
        default=100,
        help="early stopping 耐心轮数(默认 100;配合 --best-by-pixel 时按像素误差判断)",
    )
    ap.add_argument(
        "--pixel-metrics-every",
        type=int,
        default=0,
        help="每 N 个 epoch 在 val 上打印像素误差并写入 results.csv 对应 epoch 行(0=关闭);需 labels 与 data.yaml 布局一致",
    )
    ap.add_argument(
        "--pixel-metrics-conf",
        type=float,
        default=0.25,
        help="--pixel-metrics-every 时 predict 置信度阈值(默认 0.25)",
    )
    ap.add_argument(
        "--best-by-pixel",
        action="store_true",
        help="best.pt 与 early stopping 按验证集 mean 像素误差(同 pose_pixel_metrics),fitness=-mean_px;单卡有效,DDP 自动退回 mAP",
    )
    ap.add_argument(
        "--pixel-fitness-conf",
        type=float,
        default=0.25,
        help="--best-by-pixel 时 predict 置信度阈值(默认与 pixel-metrics 一致)",
    )
    ap.add_argument(
        "--export-onnx",
        action="store_true",
        help="训练结束后导出 ONNX(尺寸由 --onnx-imgsz 指定)",
    )
    ap.add_argument(
        "--onnx-imgsz",
        type=int,
        nargs=2,
        metavar=("H", "W"),
        default=[224, 320],
        help="导出 ONNX 的 [高, 宽],默认 224 320(Maix 常用)",
    )
    ap.add_argument(
        "--clear-label-cache",
        action="store_true",
        help="启动训练前删除 data.yaml 中 path 下的 labels/*.cache(修正标注后仍报 corrupt 时用)",
    )
    return ap


def _is_xpu_device(device) -> bool:
    """True for ``torch.device("xpu[:N]")`` or an ``"xpu"`` / ``"xpu:N"`` string (any case)."""
    if isinstance(device, torch.device):
        return device.type == "xpu"
    return isinstance(device, str) and device.strip().lower().split(":", 1)[0] == "xpu"


def _describe_device(device) -> str:
    """Return the start-up banner for the device chosen by ``_pick_device``.

    XPU and CPU get fixed banners. A bare CUDA index (``0``, ``"1"``) shows the GPU name, or
    ``cuda:N`` when the name cannot be queried. Any other value (``"cuda:1"``, ``"0,1"``,
    ``"mps"``) is shown as-is instead of being misreported as CPU.
    """
    if _is_xpu_device(device):
        return f"✅ 使用 Intel XPU: {device}"
    text = str(device).strip().lower()
    if text == "cpu":
        return "⚠️ 使用 CPU,训练会较慢"
    if text.isdigit():
        idx = int(text)
        try:
            name = torch.cuda.get_device_name(idx)
        except Exception:  # banner only: no CUDA build / bad index must not abort start-up
            name = f"cuda:{idx}"
        return f"✅ 使用 CUDA: {name}"
    return f"✅ 使用设备: {device}"


def _resolve_data_yaml(path: str) -> str:
    """Resolve ``--data`` to an absolute path.

    Absolute paths (after ``~`` expansion) are returned unchanged. Relative paths are resolved
    against this script's directory first; if nothing exists there but the path exists relative
    to the current working directory, that one is used. Otherwise the script-relative path is
    returned so the caller can report it as missing.
    """
    path = os.path.expanduser(path)
    if os.path.isabs(path):
        return path
    script_rel = os.path.join(os.path.dirname(os.path.abspath(__file__)), path)
    if not os.path.exists(script_rel):
        cwd_rel = os.path.abspath(path)
        if os.path.exists(cwd_rel):
            return cwd_rel
    return script_rel


def main() -> int:
    """Parse arguments, train the pose model and optionally export ONNX.

    Returns:
        Process exit code: 0 on success, 1 when the dataset yaml does not exist.
    """
    args = _build_arg_parser().parse_args()

    device = _pick_device(None if args.device == "auto" else args.device)
    use_amp = False if args.no_amp else _default_amp(device)

    print(_describe_device(device))
    if _is_xpu_device(device):
        _patch_ultralytics_for_xpu()

    data_yaml = _resolve_data_yaml(args.data)
    if not os.path.exists(data_yaml):
        print(f"❌ 数据集配置不存在: {data_yaml}")
        return 1

    if args.clear_label_cache:
        n_rm = _clear_ultralytics_label_caches(data_yaml)
        print(f"🗑️ 已删除标签目录缓存 {n_rm} 个(labels/*.cache),将强制重新扫描标注。")

    print(f"📦 加载模型: {args.model}(固定 task=pose)")
    # task="pose" is required: with the default detect task the pose label rows are mis-parsed.
    model = YOLO(args.model, task="pose")

    if args.best_by_pixel:
        _install_best_by_pixel_validate(data_yaml, args.imgsz, args.pixel_fitness_conf)
        print(
            "📌 已启用 --best-by-pixel:best.pt / patience 按验证集 mean 像素误差(fitness=-mean_px);"
            "反向传播仍为 Ultralytics 默认 pose/box loss。"
        )

    if args.pixel_metrics_every > 0:
        model.add_callback(
            "on_fit_epoch_end",
            _make_pixel_metrics_callback(
                data_yaml, args.pixel_metrics_every, args.imgsz, conf=args.pixel_metrics_conf
            ),
        )

    model.train(
        task="pose",
        data=data_yaml,
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        name=args.name,
        project=args.project,
        exist_ok=True,
        save=True,
        save_period=5,
        device=device,
        workers=args.workers,
        lr0=0.0001,
        lrf=0.01,
        optimizer="AdamW",
        momentum=0.937,
        weight_decay=0.001,
        warmup_epochs=0,
        warmup_momentum=0.8,
        warmup_bias_lr=0.1,
        hsv_h=0.015,
        hsv_s=0.7,
        hsv_v=0.4,
        degrees=5.0,
        translate=0.0,
        scale=0.2,
        shear=0.0,
        perspective=0.0000,
        flipud=0.0,
        fliplr=0.5,
        mosaic=0.0,
        mixup=0.0,
        copy_paste=0.0,
        box=6,
        cls=0.5,
        dfl=1.5,
        pose=18.0,
        kobj=0.5,
        freeze=0,
        seed=42,
        verbose=True,
        amp=use_amp,
        patience=args.patience,
        cos_lr=True,
    )

    # Prefer the trainer's actual run directory; fall back to project/name.
    trainer_dir = getattr(getattr(model, "trainer", None), "save_dir", None)
    save_dir = Path(trainer_dir) if trainer_dir else Path(args.project) / args.name

    print("\n✅ 训练完成!")
    print(f"📁 best: {save_dir / 'weights' / 'best.pt'}")
    print(f"📁 last: {save_dir / 'weights' / 'last.pt'}")
    print("📊 仅看像素误差可运行: python pose_pixel_metrics.py --model <best.pt> --data <yaml> --imgsz", args.imgsz)

    if args.export_onnx:
        h, w = args.onnx_imgsz
        print(f"📦 导出 ONNX imgsz=[{h}, {w}] ...")
        model.export(format="onnx", imgsz=[h, w], simplify=True, opset=17, dynamic=False)
        print("✅ ONNX 完成")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
|