# Files
# archery/train_yolo/train_black_triangle_pos.py
# 2026-05-15 09:35:53 +08:00
#
# 507 lines
# 18 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
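"""
YOLO11 keypoint-detection training script (four corners of the target paper).

Device priority (--device auto): Intel XPU > NVIDIA CUDA > CPU.
Defaults: imgsz=960 and batch size 4 (lower the batch further if large images
exhaust device memory).

About the "business pixel error":
Ultralytics has no yaml option for a "pixel threshold"; back-propagation is
still driven by the internal pose/kobj/box losses.
- Monitoring: --pixel-metrics-every N prints mean/p95 every N epochs and merges
  them into runs/.../results.csv (see pose_pixel_metrics.py).
- Selecting best.pt / early stopping: add --best-by-pixel to use the
  validation-set mean pixel error (same definition as pose_pixel_metrics)
  instead of the composite mAP fitness (fitness = -mean_px; a smaller error is
  better). With multi-GPU DDP (world_size > 1) it automatically falls back to
  the default mAP fitness.

XPU: Ultralytics BaseTrainer._get_memory / _clear_memory treat every device
that is neither MPS nor CPU as CUDA and call torch.cuda before validation,
which raises an error; this script patches them automatically when XPU is
selected (see _patch_ultralytics_for_xpu).

Always use the pose task: both YOLO(...) and model.train(...) specify
task='pose'. If the default detect task is used by mistake, the 17-column pose
labels are parsed as detection/segmentation labels and label checking reports
spurious "coordinates > 1" or [2.] errors.
"""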
from __future__ import annotations
import argparse
import csv
import gc
import glob
import os
import tempfile
from copy import deepcopy
from pathlib import Path
import torch
from ultralytics import YOLO
from pose_pixel_metrics import eval_val_pixel_error
import warnings
warnings.filterwarnings('ignore',
message=".*scatter_add_kernel does not have a deterministic implementation.*")
def _clear_ultralytics_label_caches(data_yaml_path: str) -> int:
    """Delete every ``labels/*.cache`` file under the ``path`` entry of a data.yaml.

    Per the original author's note, the Ultralytics label-check cache hash only
    depends on the label/image path strings plus the sum of the file sizes, not
    on file contents. After fixing ``*.txt`` labels the sum may happen to stay
    the same, so a stale cache (and its old "corrupt" log lines) could be
    reused; removing the cache before training forces a rescan.

    Args:
        data_yaml_path: Path of the dataset yaml to read.

    Returns:
        Number of cache files removed; ``0`` when the yaml cannot be loaded or
        has no ``path`` entry.
    """
    from ultralytics.utils import YAML

    try:
        dataset_cfg = YAML.load(data_yaml_path)
    except Exception:
        # Best effort: an unreadable yaml simply means nothing gets cleared.
        return 0

    dataset_root = dataset_cfg.get("path")
    if not dataset_root:
        return 0

    labels_dir = os.path.join(
        os.path.abspath(os.path.expanduser(str(dataset_root))), "labels"
    )
    removed = 0
    for cache_file in glob.glob(os.path.join(labels_dir, "*.cache")):
        try:
            os.unlink(cache_file)
        except OSError:
            # Files that cannot be deleted are skipped and not counted.
            continue
        removed += 1
    return removed
def _pick_device(explicit: str | None):
    """Resolve the ``--device`` option into a value Ultralytics train/predict accepts.

    Matching is case-insensitive and ignores surrounding whitespace. An explicit
    ``xpu:N`` request is handled like ``xpu`` (availability check and a
    ``torch.device`` return value) instead of being passed through as a string.

    Args:
        explicit: Requested device, or ``None`` / ``""`` / ``"auto"`` to
            auto-detect.

    Returns:
        ``torch.device("xpu")`` or ``torch.device("xpu:N")`` for Intel XPU,
        ``0`` for ``"0"`` / ``"cuda"`` / ``"gpu"``, ``"cpu"`` for CPU, or the
        trimmed request string unchanged for anything else (e.g. ``"1"``).
        With auto-detection the order is XPU, then CUDA, then CPU.

    Raises:
        RuntimeError: If XPU or CUDA was requested explicitly but is unavailable.
    """

    def _xpu_available() -> bool:
        # torch builds without XPU support may not expose ``torch.xpu`` at all.
        xpu = getattr(torch, "xpu", None)
        return xpu is not None and xpu.is_available()

    requested = (explicit or "").strip()
    key = requested.lower()

    if key and key != "auto":
        if key == "xpu" or key.startswith("xpu:"):
            if not _xpu_available():
                raise RuntimeError("指定了 --device xpu 但当前环境不可用")
            return torch.device(key)
        if key in ("0", "cuda", "gpu"):
            if not torch.cuda.is_available():
                raise RuntimeError("指定了 CUDA 但不可用")
            return 0
        if key == "cpu":
            return "cpu"
        # Anything else is handed to Ultralytics as-is.
        return requested

    if _xpu_available():
        return torch.device("xpu")
    if torch.cuda.is_available():
        return 0
    return "cpu"
def _default_amp(device) -> bool:
    """Return the default mixed-precision setting for ``device``.

    AMP defaults to off for an XPU ``torch.device`` and for the string
    ``"cpu"``; every other device value gets AMP on.
    """
    targets_xpu = isinstance(device, torch.device) and device.type == "xpu"
    return not (targets_xpu or device == "cpu")
def _patch_ultralytics_for_xpu() -> None:
    """Patch Ultralytics so training and validation can run on an Intel XPU.

    Three things are replaced at runtime:

    1. ``select_device`` in ``ultralytics.utils.torch_utils`` and the copies of
       that name already bound inside the trainer and validator modules.
    2. ``BaseTrainer._get_memory`` / ``BaseTrainer._clear_memory``.
    3. The same two helpers are attached to ``BaseValidator``.

    The memory patches are guarded by an ``_archery_xpu_memory_patched`` class
    flag so they are applied once per class.
    """
    import ultralytics.engine.trainer as ut_trainer
    import ultralytics.engine.validator as ut_validator
    from ultralytics.utils.torch_utils import select_device as _original_select_device

    # 1. Override select_device. When the Trainer is initialised with
    #    torch.device("xpu") the original function returns early; after
    #    initialisation args.device becomes the string "xpu". Mid-training
    #    validation uses trainer.device and does not call select_device, but the
    #    Validator created by final_eval at the end of training calls
    #    select_device("xpu"). The validator module bound the original function
    #    at import time, so patching torch_utils alone has no effect: the
    #    references inside the trainer/validator modules must be patched too.
    # NOTE(review): this wrapper is re-created on every call, so calling this
    # function twice wraps the already-patched select_device again.
    def _patched_select_device(device="", *args, **kwargs):
        # Ultralytics 8.4.x: select_device(device="", newline=False, verbose=True)
        # Older forks sometimes passed extra positional args; forward everything.
        if isinstance(device, str):
            d = device.strip().lower()
            if d == "xpu" or d.startswith("xpu:"):
                # NOTE(review): the check uses the lower-cased text but the
                # original casing is passed to torch.device - confirm that an
                # upper-case "XPU" spelling is accepted.
                return torch.device(device.strip())
        return _original_select_device(device, *args, **kwargs)

    import ultralytics.utils.torch_utils
    ultralytics.utils.torch_utils.select_device = _patched_select_device
    ut_trainer.select_device = _patched_select_device
    ut_validator.select_device = _patched_select_device

    # 2. Patch the Trainer memory helpers.
    BT = ut_trainer.BaseTrainer
    if not getattr(BT, "_archery_xpu_memory_patched", False):
        # Keep the originals so non-XPU devices still use the stock behaviour.
        _orig_get_memory = BT._get_memory
        _orig_clear_memory = BT._clear_memory

        def _get_memory(self, fraction=False):
            if self.device.type != "xpu":
                return _orig_get_memory(self, fraction)
            # XPU branch: read allocated (and, for a fraction, total) memory
            # through torch.xpu.
            memory, total = 0, 0
            try:
                idx = self.device.index
                if idx is None:
                    idx = torch.xpu.current_device()
                memory = int(torch.xpu.memory_allocated(idx))
                if fraction:
                    total = int(torch.xpu.get_device_properties(idx).total_memory)
            except Exception:
                # Best effort: on any failure the zeros above are reported.
                pass
            # Fraction of total memory when requested and known; otherwise the
            # allocated amount divided by 2**30 (GiB).
            return (memory / total) if fraction and total > 0 else (memory / 2**30)

        def _clear_memory(self, threshold=None):
            if self.device.type != "xpu":
                return _orig_clear_memory(self, threshold)
            if threshold is not None:
                assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
                # Skip the cleanup while usage is at or below the threshold.
                if self._get_memory(fraction=True) <= threshold:
                    return
            gc.collect()
            if hasattr(torch.xpu, "empty_cache"):
                torch.xpu.empty_cache()

        BT._get_memory = _get_memory
        BT._clear_memory = _clear_memory
        BT._archery_xpu_memory_patched = True

    # 3. Patch the Validator memory helpers (adding this part is the key fix).
    # NOTE(review): _get_memory/_clear_memory only exist as locals when the
    # trainer branch above ran during this call. Both flags are set in the same
    # call, so the two branches are expected to run together - confirm.
    BV = ut_validator.BaseValidator
    if not getattr(BV, "_archery_xpu_memory_patched", False):
        # Give the Validator the same memory-handling methods.
        BV._get_memory = _get_memory
        BV._clear_memory = _clear_memory
        BV._archery_xpu_memory_patched = True
def _install_best_by_pixel_validate(data_yaml: str, imgsz: int, conf: float) -> None:
    """Make ``BaseTrainer.validate`` rank epochs by validation keypoint pixel error.

    The replacement runs the normal validator and then, in a single-process run,
    evaluates the current EMA (or raw) weights with ``eval_val_pixel_error`` and
    returns ``fitness = -mean_px``. That value drives both ``best.pt`` selection
    and patience-based early stopping. With ``world_size > 1`` the validator's
    own fitness is used unchanged.

    Pixel fitness is never positive while the mAP fitness is never negative, so
    the two scales must not be compared with each other:

    * If the probe fails before any pixel fitness was recorded, the epoch falls
      back to the mAP fitness.
    * If the probe fails after a pixel fitness is already the running best, the
      epoch gets ``-inf`` so it can never replace a measured epoch.
    * When the first pixel fitness arrives after such a fallback, the stale
      positive best values are discarded.

    Installation happens at most once per process, guarded by the class flag
    ``_archery_best_by_pixel_installed``.

    Args:
        data_yaml: Dataset yaml handed to ``eval_val_pixel_error``.
        imgsz: Image size used for the pixel probe.
        conf: Prediction confidence threshold used for the pixel probe.
    """
    import math

    import ultralytics.engine.trainer as ut
    from ultralytics.utils import RANK

    BT = ut.BaseTrainer
    if getattr(BT, "_archery_best_by_pixel_installed", False):
        return

    def _probe_mean_px(trainer):
        """Return ``(mean_px, None)`` on success or ``(None, exception)`` on failure."""
        tmp_path: str | None = None
        try:
            fd, tmp_path = tempfile.mkstemp(suffix=".pt", prefix="archery_pxfit_")
            os.close(fd)
            from ultralytics.utils.torch_utils import unwrap_model

            # Snapshot the weights into a temporary checkpoint so they can be
            # loaded through the regular YOLO(...) entry point.
            core = unwrap_model(trainer.ema.ema if trainer.ema else trainer.model)
            torch.save({"ema": deepcopy(core).half(), "train_args": vars(trainer.args)}, tmp_path)
            probe = YOLO(tmp_path)
            stats = eval_val_pixel_error(
                probe,
                data_yaml,
                device=trainer.device,
                imgsz=imgsz,
                conf=conf,
            )
            mean_px = stats.get("mean_px")
            if mean_px is None:
                raise RuntimeError("无有效 mean_px检查 val 标签与检测是否为空)")
            mean_px = float(mean_px)
            if not math.isfinite(mean_px):
                # NaN/inf would poison every later fitness comparison.
                raise RuntimeError(f"mean_px 非有限值: {mean_px}")
            return mean_px, None
        except Exception as exc:
            # The probe is best effort; the caller decides how to score the epoch.
            return None, exc
        finally:
            if tmp_path:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

    def validate(self):
        import torch.distributed as dist

        if self.ema and self.world_size > 1:
            for buffer in self.ema.ema.buffers():
                dist.broadcast(buffer, src=0)
        metrics = self.validator(self)
        if metrics is None:
            return None, None
        orig_fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())

        if self.world_size > 1 or RANK not in {-1, 0}:
            # Multi-process run: keep the stock mAP fitness.
            fitness = float(orig_fitness)
        else:
            mean_px, error = _probe_mean_px(self)
            best = self.best_fitness
            if mean_px is not None:
                fitness = -mean_px
                metrics["metrics/mean_px(val)"] = mean_px
                if best is not None and best > 0:
                    # A positive best can only come from an earlier mAP
                    # fallback; drop it so pixel fitness can take over.
                    self.best_fitness = None
                    # Assumes the early-stopping helper keeps its running best
                    # in ``best_fitness`` - TODO confirm for the installed
                    # Ultralytics version.
                    stopper = getattr(self, "stopper", None)
                    if stopper is not None and getattr(stopper, "best_fitness", 0) > 0:
                        stopper.best_fitness = float("-inf")
            else:
                # Emit the key on every epoch so the metric set does not change
                # between epochs.
                metrics["metrics/mean_px(val)"] = float("nan")
                if best is not None and best < 0:
                    print(
                        f"\n⚠️ [best-by-pixel] 像素探针失败,本 epoch 不参与 best/early stopping 评选: {error}\n"
                    )
                    fitness = float("-inf")
                else:
                    print(f"\n⚠️ [best-by-pixel] 像素探针失败,本 epoch 仍用 mAP fitness: {error}\n")
                    fitness = float(orig_fitness)

        if not self.best_fitness or self.best_fitness < fitness:
            self.best_fitness = fitness
        return metrics, fitness

    BT.validate = validate
    BT._archery_best_by_pixel_installed = True
def _fmt_csv_metric(v: float | int | None) -> str:
    """Render one metric value as the text stored in a results.csv cell.

    ``None`` becomes an empty cell, floats use six significant digits
    (``.6g``) and every other value is passed through ``str``.
    """
    if isinstance(v, float):
        return format(v, ".6g")
    return "" if v is None else str(v)
# Column names written into results.csv. They are kept distinct from the
# metrics/mean_px(val) column of --best-by-pixel so that the last.pt callback
# does not overwrite the EMA row.
# Each entry is (results.csv column name, key looked up in the stats dict).
_PIXEL_METRIC_COLUMNS: tuple[tuple[str, str], ...] = (
    ("pixel_error/mean_px", "mean_px"),
    ("pixel_error/median_px", "median_px"),
    ("pixel_error/p95_px", "p95_px"),
    ("pixel_error/max_px", "max_px"),
    ("pixel_error/n_points", "n_points"),
    ("pixel_error/n_images", "n_images"),
    ("pixel_error/skip_no_det", "skip_no_det"),
    ("pixel_error/skip_no_gt", "skip_no_gt"),
    ("pixel_error/skip_kpt_mismatch", "skip_kpt_mismatch"),
)
def _merge_pixel_metrics_into_results_csv(save_dir: str | Path, epoch_1based: int, stats: dict) -> None:
    """Merge pixel-error columns into the row of ``results.csv`` for one epoch.

    Intended to run after Ultralytics has written the epoch row. Missing pixel
    columns are appended to the header, short rows are padded with empty cells,
    and the row whose first cell equals ``epoch_1based`` receives the values
    from ``stats``. If several rows match, the last one is used.

    Nothing is written when the file is missing or unreadable, has no data
    rows, or contains no row for the epoch. A failed write is ignored.

    Args:
        save_dir: Run directory that contains ``results.csv``.
        epoch_1based: Epoch number as stored in the first csv column.
        stats: Mapping read with the keys listed in ``_PIXEL_METRIC_COLUMNS``.
    """
    results_file = Path(save_dir) / "results.csv"
    if not results_file.is_file():
        return
    try:
        with open(results_file, newline="", encoding="utf-8") as src:
            table = list(csv.reader(src))
    except OSError:
        return
    if len(table) < 2:
        return

    columns = list(table[0])
    # The slice shares the row objects with ``table``, so edits show up there.
    records = table[1:]
    for name, _ in _PIXEL_METRIC_COLUMNS:
        if name in columns:
            continue
        columns.append(name)
        for record in records:
            record.append("")
    table[0] = columns
    position = {name: idx for idx, name in enumerate(columns)}

    width = len(columns)
    matched = None
    for record in records:
        shortfall = width - len(record)
        if shortfall > 0:
            record.extend([""] * shortfall)
        try:
            # The epoch cell may be written as a float string such as "3.0".
            if int(float(record[0].strip())) == int(epoch_1based):
                matched = record
        except (ValueError, IndexError):
            continue
    if matched is None:
        return

    for name, stat_key in _PIXEL_METRIC_COLUMNS:
        matched[position[name]] = _fmt_csv_metric(stats.get(stat_key))

    try:
        with open(results_file, "w", newline="", encoding="utf-8") as dst:
            csv.writer(dst).writerows(table)
    except OSError:
        pass
def _make_pixel_metrics_callback(data_yaml: str, every: int, imgsz: int, conf: float = 0.25):
    """Build an ``on_fit_epoch_end`` callback that reports validation pixel error.

    Every ``every`` epochs the callback loads ``weights/last.pt`` from the
    trainer's save directory, runs ``eval_val_pixel_error`` on it, prints
    mean/p95 and the skip counters, and merges the statistics into
    ``results.csv``.

    The callback is monitoring only: if loading or evaluating the checkpoint
    raises, a warning is printed and training continues.

    Args:
        data_yaml: Dataset yaml handed to ``eval_val_pixel_error``.
        every: Evaluate every this many epochs; values <= 0 disable the callback.
        imgsz: Image size used for the evaluation.
        conf: Prediction confidence threshold used for the evaluation.

    Returns:
        A callable taking the trainer, suitable for ``model.add_callback``.
    """

    def on_fit_epoch_end(trainer):
        # Cheap guards first, so a disabled or off-schedule call does no work.
        if every <= 0:
            return
        ep = int(getattr(trainer, "epoch", -1))
        if (ep + 1) % every != 0:
            return

        from ultralytics.utils import RANK

        # Only run when Ultralytics' RANK is -1 or 0.
        if RANK not in {-1, 0}:
            return
        weights = Path(trainer.save_dir) / "weights" / "last.pt"
        if not weights.is_file():
            return

        try:
            m = YOLO(str(weights))
            stats = eval_val_pixel_error(
                m,
                data_yaml,
                device=trainer.device,
                imgsz=imgsz,
                conf=conf,
            )
        except Exception as exc:
            # A reporting hook must not abort a long training run.
            print(f"\n⚠️ [pixel-metrics] epoch {ep + 1}: 像素误差评估失败,已跳过: {exc}\n")
            return

        mean_px = stats.get("mean_px")
        p95_px = stats.get("p95_px")
        mean_s = f"{mean_px:.3f}" if mean_px is not None else "n/a"
        p95_s = f"{p95_px:.3f}" if p95_px is not None else "n/a"
        print(
            f"\n[pixel-metrics] epoch {ep + 1}: mean_px={mean_s} p95_px={p95_s} "
            f"n_points={stats.get('n_points', 0)} "
            f"skip(det/gt/k)={stats.get('skip_no_det', 0)}/{stats.get('skip_no_gt', 0)}"
            f"/{stats.get('skip_kpt_mismatch', 0)}\n"
        )
        _merge_pixel_metrics_into_results_csv(trainer.save_dir, ep + 1, stats)

    return on_fit_epoch_end
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the pose training script."""
    ap = argparse.ArgumentParser(description="YOLO Pose 训练XPU/CUDA/CPU")
    ap.add_argument("--data", default="datasets/dataset_pose.yaml", help="data.yaml")
    ap.add_argument("--model", default="yolo11x-pose.pt", help="预训练权重")
    ap.add_argument("--epochs", type=int, default=100)
    ap.add_argument("--imgsz", type=int, default=960, help="训练输入边长(默认 960")
    ap.add_argument("--batch", type=int, default=4, help="批大小OOM 时减小")
    ap.add_argument(
        "--device",
        default="auto",
        help="auto | xpu | 0 | cuda | cpuautoXPU 优先)",
    )
    ap.add_argument(
        "--no-amp",
        action="store_true",
        help="关闭混合精度默认CUDA 开启XPU/CPU 关闭)",
    )
    ap.add_argument("--project", default="runs/pose")
    ap.add_argument("--name", default="target_pose_train")
    ap.add_argument("--workers", type=int, default=4)
    ap.add_argument(
        "--pixel-metrics-every",
        type=int,
        default=0,
        help="每 N 个 epoch 在 val 上打印像素误差并写入 results.csv 对应 epoch 行0=关闭);需 labels 与 data.yaml 布局一致",
    )
    ap.add_argument(
        "--pixel-metrics-conf",
        type=float,
        default=0.25,
        help="--pixel-metrics-every 时 predict 置信度阈值(默认 0.25",
    )
    ap.add_argument(
        "--best-by-pixel",
        action="store_true",
        help="best.pt 与 early stopping 按验证集 mean 像素误差(同 pose_pixel_metricsfitness=-mean_px单卡有效DDP 自动退回 mAP",
    )
    ap.add_argument(
        "--pixel-fitness-conf",
        type=float,
        default=0.25,
        help="--best-by-pixel 时 predict 置信度阈值(默认与 pixel-metrics 一致)",
    )
    ap.add_argument(
        "--export-onnx",
        action="store_true",
        help="训练结束后导出 ONNX需再设 --onnx-imgsz",
    )
    ap.add_argument(
        "--onnx-imgsz",
        type=int,
        nargs=2,
        metavar=("H", "W"),
        default=[224, 320],
        help="导出 ONNX 的 [高, 宽],默认 224 320Maix 常用)",
    )
    ap.add_argument(
        "--clear-label-cache",
        action="store_true",
        help="启动训练前删除 data.yaml 中 path 下的 labels/*.cache修正标注后仍报 corrupt 时用)",
    )
    return ap


def main():
    """Parse the command line, train the pose model and optionally export ONNX.

    Steps: pick the device, apply the XPU patches when needed, resolve the
    dataset yaml, optionally clear label caches, install the optional pixel
    fitness / pixel metrics hooks, train with the fixed hyper-parameters below,
    then report the weight files and optionally export to ONNX.

    Raises:
        SystemExit: With status 1 when the dataset yaml cannot be found.
        RuntimeError: Propagated from ``_pick_device`` when an explicitly
            requested accelerator is unavailable.
    """
    args = _build_arg_parser().parse_args()

    device = _pick_device(None if args.device == "auto" else args.device)
    use_amp = False if args.no_amp else _default_amp(device)
    on_xpu = isinstance(device, torch.device) and device.type == "xpu"

    if on_xpu:
        print(f"✅ 使用 Intel XPU: {device}")
    elif device == 0 or device == "0":
        print(f"✅ 使用 CUDA: {torch.cuda.get_device_name(0)}")
    elif device == "cpu":
        print("⚠️ 使用 CPU训练会较慢")
    else:
        # Any other spec (e.g. "1" or "cuda:1") is passed through to
        # Ultralytics; do not mislabel it as CPU.
        print(f"✅ 使用设备: {device}")

    if on_xpu:
        _patch_ultralytics_for_xpu()

    data_yaml = args.data
    if not os.path.isabs(data_yaml):
        script_relative = os.path.join(os.path.dirname(os.path.abspath(__file__)), data_yaml)
        # Script-relative paths keep priority; a path that only exists relative
        # to the current working directory is accepted as a fallback.
        if os.path.exists(script_relative) or not os.path.exists(data_yaml):
            data_yaml = script_relative
        else:
            data_yaml = os.path.abspath(data_yaml)
    if not os.path.exists(data_yaml):
        print(f"❌ 数据集配置不存在: {data_yaml}")
        # Exit non-zero so wrappers and CI notice the failure.
        raise SystemExit(1)

    if args.clear_label_cache:
        n_rm = _clear_ultralytics_label_caches(data_yaml)
        print(f"🗑️ 已删除标签目录缓存 {n_rm}labels/*.cache将强制重新扫描标注。")

    print(f"📦 加载模型: {args.model}(固定 task=pose")
    model = YOLO(args.model, task="pose")

    if args.best_by_pixel:
        _install_best_by_pixel_validate(data_yaml, args.imgsz, args.pixel_fitness_conf)
        print(
            "📌 已启用 --best-by-pixelbest.pt / patience 按验证集 mean 像素误差fitness=-mean_px"
            "反向传播仍为 Ultralytics 默认 pose/box loss。"
        )
    if args.pixel_metrics_every > 0:
        model.add_callback(
            "on_fit_epoch_end",
            _make_pixel_metrics_callback(
                data_yaml, args.pixel_metrics_every, args.imgsz, conf=args.pixel_metrics_conf
            ),
        )

    model.train(
        task="pose",
        data=data_yaml,
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        name=args.name,
        project=args.project,
        exist_ok=True,
        save=True,
        save_period=5,
        device=device,
        workers=args.workers,
        # Optimizer and learning-rate schedule.
        lr0=0.0001,
        lrf=0.01,
        optimizer="AdamW",
        momentum=0.937,
        weight_decay=0.001,
        warmup_epochs=0,
        warmup_momentum=0.8,
        warmup_bias_lr=0.1,
        # Augmentation.
        hsv_h=0.015,
        hsv_s=0.7,
        hsv_v=0.4,
        degrees=5.0,
        translate=0.0,
        scale=0.2,
        shear=0.0,
        perspective=0.0000,
        flipud=0.0,
        fliplr=0.5,
        mosaic=0.0,
        mixup=0.0,
        copy_paste=0.0,
        # Loss weights.
        box=6,
        cls=0.5,
        dfl=1.5,
        pose=18.0,
        kobj=0.5,
        freeze=0,
        seed=42,
        verbose=True,
        amp=use_amp,
        patience=100,
        cos_lr=True,
    )

    # Prefer the directory the trainer actually used; fall back to the CLI
    # values when the trainer does not expose it.
    trainer = getattr(model, "trainer", None)
    save_dir = getattr(trainer, "save_dir", None) or Path(args.project) / args.name
    weights_dir = Path(save_dir) / "weights"

    print("\n✅ 训练完成!")
    print(f"📁 best: {weights_dir / 'best.pt'}")
    print(f"📁 last: {weights_dir / 'last.pt'}")
    print("📊 仅看像素误差可运行: python pose_pixel_metrics.py --model <best.pt> --data <yaml> --imgsz", args.imgsz)

    if args.export_onnx:
        h, w = args.onnx_imgsz
        print(f"📦 导出 ONNX imgsz=[{h}, {w}] ...")
        model.export(format="onnx", imgsz=[h, w], simplify=True, opset=17, dynamic=False)
        print("✅ ONNX 完成")
# Run the training entry point only when the file is executed as a script.
if __name__ == "__main__":
    main()