增加训练yolo的代码
This commit is contained in:
11
adc.py
11
adc.py
@@ -4,14 +4,13 @@ from maix import time
|
||||
a = adc.ADC(0, adc.RES_BIT_12)
|
||||
|
||||
while True:
|
||||
raw_data = a.read()
|
||||
print(f"ADC raw data:{raw_data}")
|
||||
# raw_data = a.read()
|
||||
# print(f"ADC raw data:{raw_data}")
|
||||
# if raw_data > 2450:
|
||||
# print(f"ADC raw data:{raw_data}")
|
||||
# elif raw_data < 2000:
|
||||
# print(f"ADC raw data:{raw_data}")
|
||||
# time.sleep_ms(50)
|
||||
time.sleep_ms(1)
|
||||
|
||||
# vol = a.read_vol()
|
||||
|
||||
# print(f"ADC vol:{vol}")
|
||||
vol = int(a.read_vol() * 10) / 10
|
||||
print(f"ADC vol:{vol:.1f}, {time.time():.4f}")
|
||||
|
||||
4
app.yaml
4
app.yaml
@@ -1,11 +1,12 @@
|
||||
id: t11
|
||||
name: t11
|
||||
version: 1.2.11
|
||||
version: 1.2.12
|
||||
author: t11
|
||||
icon: ''
|
||||
desc: t11
|
||||
files:
|
||||
- 4g_download_manager.py
|
||||
- 4g_upload_manager.py
|
||||
- app.yaml
|
||||
- archery_netcore.cpython-311-riscv64-linux-gnu.so
|
||||
- at_client.py
|
||||
@@ -34,3 +35,4 @@ files:
|
||||
- vision.py
|
||||
- wifi_config_httpd.py
|
||||
- wifi.py
|
||||
- wpa_supplicant_conf.py
|
||||
|
||||
@@ -90,5 +90,13 @@ opencv_interactive-calibration -t=chessboard -w=9 -h=6 -sz=0.025 -v="http://192.
|
||||
D:\data\test_target_photo 是用来叠加的背景图
|
||||
|
||||
7.1 生成靶纸及黑色三角形的截图的图片,带动动,但1.12的外框
|
||||
bak
|
||||
python .\synth_compose_yolo.py --perspective 0.04 --perspective-prob 0.8 --color-jitter 0.6 --bg-dir D:\data\test_target_photo --fg D:\code\shooting\target_photo\write.png --out ./synth_out --class-name triangle --zip ./maix_dataset.zip --num 60 --triangles-json archery_triangles_default.json --format voc --stage2-crop --stage2-pad-min 0.03 --stage2-pad-max 0.18 --motion-prob 0.9 --motion-kernel-max 8 --blur-max 0 --triangle-bbox-pad-frac 0.12
|
||||
|
||||
bak_2
|
||||
python synth_keypoints_right_angle.py --bg-dir D:\data\test_target_photo --fg D:\code\shooting\target_photo\write.png --triangles-json archery_triangles_default.json --out ./synth_out --num 1000 --offscreen-shift-prob 0.3 --offscreen-shift-frac 0.4 --offscreen-min-visible 1 --stage2-crop --stage2-pad-min 0.03 --stage2-pad-max 0.18 --motion-prob 0.9 --motion-kernel-max 8 --blur-max 0 --perspective-mode planar --yaw-max-deg 10 --pitch-max-deg 8 --roll-max-deg 4 --planar-focal-frac 1.45 --perspective-prob 0.4
|
||||
|
||||
python synth_keypoints_right_angle.py --bg-dir D:\data\test_target_photo --fg D:\code\shooting\target_photo\write.png --triangles-json archery_triangles_default.json --out ./synth_out --num 1000 --offscreen-shift-prob 0.3 --offscreen-shift-frac 0.4 --offscreen-min-visible 1 --stage2-crop --stage2-pad-min 0.03 --stage2-pad-max 0.18 --motion-prob 1.0 --motion-kernel-max 8 --blur-max 0 --perspective-mode planar --yaw-max-deg 10 --pitch-max-deg 8 --roll-max-deg 4 --planar-focal-frac 1.45 --perspective-prob 0.4
|
||||
|
||||
|
||||
python pose_pixel_metrics.py --model D:\code\archery\runs\pose\runs\pose\target_pose_train\weights\best.pt --data D:\code\archery\datasets\dataset_pose.yaml --imgsz 640
|
||||
50
test/test_audio.py
Normal file
50
test/test_audio.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# test_audio.py
|
||||
from maix import audio, time, app, gpio
|
||||
|
||||
def run_player_loop():
    """Run the playback control loop until the application asks to exit.

    Each cycle drives GPIO A25 high for 200 ms, drives it low, starts playback
    of /root/gun.wav and then waits one second. Prints a message once the loop
    has ended.
    """
    # Audio player for the sound file, at volume 40.
    player = audio.Player("/root/gun.wav")
    player.volume(40)

    # A25 as an output pin, starting at low level.
    pin = gpio.GPIO("A25", gpio.Mode.OUT)
    pin.value(0)

    while True:
        if app.need_exit():
            break
        pin.value(1)            # pin high (LED on)
        time.sleep_ms(200)      # hold for 200 ms
        pin.value(0)            # pin low (LED off)
        player.play()           # start audio playback
        time.sleep_ms(1000)     # wait one second before the next cycle

    print("play finish!")
|
||||
|
||||
|
||||
# 可选:添加一个简单的测试函数
|
||||
def hello():
    """Return a fixed greeting string (simple test helper)."""
    greeting = "Hello from test_audio!"
    return greeting
|
||||
|
||||
|
||||
# 可选:添加一个初始化函数
|
||||
def init_led():
    """Standalone GPIO check: configure A25 as an output and drive it low.

    Returns:
        The status string "LED initialized".
    """
    pin = gpio.GPIO("A25", gpio.Mode.OUT)
    pin.value(0)
    status = "LED initialized"
    return status
|
||||
|
||||
|
||||
# 可选:添加一个播放测试函数
|
||||
def test_play():
    """Standalone audio check: play /root/gun.wav once at volume 50.

    Returns:
        The status string "Playing...".
    """
    player = audio.Player("/root/gun.wav")
    player.volume(50)
    player.play()
    status = "Playing..."
    return status
|
||||
|
||||
|
||||
run_player_loop()
|
||||
25
test/test_button.py
Normal file
25
test/test_button.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from maix import audio, time, app,gpio
|
||||
|
||||
|
||||
# button1 = gpio.GPIO("ADC", gpio.Mode.IN)
|
||||
button3 = gpio.GPIO("A26", gpio.Mode.IN) # 可用
|
||||
button2 = gpio.GPIO("A16", gpio.Mode.IN)
|
||||
#设置低电平
|
||||
from maix.peripheral import adc
|
||||
channel = 0
|
||||
res_bit = adc.RES_BIT_12
|
||||
_adc_obj = adc.ADC(channel, res_bit)
|
||||
|
||||
|
||||
while not app.need_exit():
|
||||
# print(f"b1: {button1.value()}")
|
||||
|
||||
print(f"b2: {button2.value()}")
|
||||
|
||||
# print(_adc_obj.read_vol())
|
||||
print(f"b3: {button3.value()}")
|
||||
time.sleep_ms(50)
|
||||
|
||||
# time.sleep_ms(1000)
|
||||
|
||||
|
||||
@@ -4,6 +4,9 @@ from maix import camera, display, time
|
||||
try:
|
||||
print("Initializing camera...")
|
||||
cam = camera.Camera(640,480)
|
||||
# cam = camera.Camera(1280,720)
|
||||
# cam.get_exposure_us()
|
||||
# print("Camera exposure: ", cam.get_exposure_us())
|
||||
print("Camera initialized successfully!")
|
||||
|
||||
disp = display.Display()
|
||||
|
||||
@@ -1,172 +1,246 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
激光模块测试脚本
|
||||
用于诊断激光开关问题
|
||||
|
||||
使用方法:
|
||||
python test_laser.py
|
||||
|
||||
功能:
|
||||
1. 初始化串口
|
||||
2. 循环测试激光开/关
|
||||
3. 打印详细调试信息
|
||||
M01激光测距模块测试脚本 - 修正版
|
||||
基于文档中的完整命令示例
|
||||
"""
|
||||
|
||||
from maix import uart, pinmap, time
|
||||
import binascii
|
||||
|
||||
# ==================== 配置 ====================
|
||||
UART_PORT = "/dev/ttyS1" # 激光模块连接的串口(UART1)
|
||||
BAUDRATE = 9600 # 波特率
|
||||
|
||||
# 引脚映射(确保与硬件连接一致)
|
||||
print("=" * 50)
|
||||
print("🔧 步骤1: 配置引脚映射")
|
||||
print("=" * 50)
|
||||
UART_PORT = "/dev/ttyS1"
|
||||
BAUDRATE = 9600
|
||||
|
||||
# 初始化串口
|
||||
try:
|
||||
pinmap.set_pin_function("A18", "UART1_RX")
|
||||
print("✅ A18 -> UART1_RX")
|
||||
except Exception as e:
|
||||
print(f"❌ A18 配置失败: {e}")
|
||||
|
||||
try:
|
||||
pinmap.set_pin_function("A19", "UART1_TX")
|
||||
print("✅ A19 -> UART1_TX")
|
||||
except Exception as e:
|
||||
print(f"❌ A19 配置失败: {e}")
|
||||
|
||||
# ==================== 激光控制指令 ====================
|
||||
MODULE_ADDR = 0x00
|
||||
|
||||
# 原始命令
|
||||
LASER_ON_CMD = bytes([0xAA, MODULE_ADDR, 0x01, 0xBE, 0x00, 0x01, 0x00, 0x01, 0xC1])
|
||||
LASER_OFF_CMD = bytes([0xAA, MODULE_ADDR, 0x01, 0xBE, 0x00, 0x01, 0x00, 0x00, 0xC0])
|
||||
|
||||
# 备用命令格式(如果原始命令不工作,可以尝试这些)
|
||||
# 格式1: 简化命令
|
||||
LASER_ON_CMD_ALT1 = bytes([0xAA, 0x01, 0x01])
|
||||
LASER_OFF_CMD_ALT1 = bytes([0xAA, 0x01, 0x00])
|
||||
|
||||
# 格式2: 不同的协议头
|
||||
LASER_ON_CMD_ALT2 = bytes([0x55, 0xAA, 0x01])
|
||||
LASER_OFF_CMD_ALT2 = bytes([0x55, 0xAA, 0x00])
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("🔧 步骤2: 初始化串口")
|
||||
print("=" * 50)
|
||||
print(f"设备: {UART_PORT}")
|
||||
print(f"波特率: {BAUDRATE}")
|
||||
|
||||
try:
|
||||
laser_uart = uart.UART(UART_PORT, BAUDRATE)
|
||||
print(f"✅ 串口初始化成功: {laser_uart}")
|
||||
print("✅ 硬件初始化完成")
|
||||
except Exception as e:
|
||||
print(f"❌ 串口初始化失败: {e}")
|
||||
print(f"❌ 初始化失败: {e}")
|
||||
exit(1)
|
||||
|
||||
# ==================== 测试函数 ====================
|
||||
def send_and_check(cmd, name):
|
||||
"""发送命令并检查回包"""
|
||||
print(f"\n📤 发送: {name}")
|
||||
print(f" 命令字节: {cmd.hex()}")
|
||||
print(f" 命令长度: {len(cmd)} 字节")
|
||||
# ==================== 根据文档的完整命令集 ====================
|
||||
# 1. 激光开关(文档2.3.10,已验证可用)
|
||||
LASER_ON_CMD = bytes([0xAA, 0x00, 0x01, 0xBE, 0x00, 0x01, 0x00, 0x01, 0xC1])
|
||||
LASER_OFF_CMD = bytes([0xAA, 0x00, 0x01, 0xBE, 0x00, 0x01, 0x00, 0x00, 0xC0])
|
||||
|
||||
# 清空接收缓冲区
|
||||
# 2. 尝试不同的测距命令格式
|
||||
TEST_COMMANDS = [
|
||||
# 格式1:文档2.3.12的单次测量(您测试失败的)
|
||||
{
|
||||
"name": "单次测量 (0x0020)",
|
||||
"cmd": bytes([0xAA, 0x00, 0x00, 0x20, 0x00, 0x01, 0x00, 0x00, 0x21]),
|
||||
"desc": "文档2.3.12 示例命令"
|
||||
},
|
||||
# 格式2:文档2.3.7的读取测量结果
|
||||
{
|
||||
"name": "读取测量结果 (0x0022)",
|
||||
"cmd": bytes([0xAA, 0x80, 0x00, 0x22, 0xA2]),
|
||||
"desc": "文档2.3.7 读取测量结果"
|
||||
},
|
||||
# 格式3:文档2.3.13的快速测量
|
||||
{
|
||||
"name": "快速测量 (0x0022带数据)",
|
||||
"cmd": bytes([0xAA, 0x00, 0x00, 0x22, 0x00, 0x01, 0x00, 0x00, 0x23]),
|
||||
"desc": "文档2.3.13 快速测量"
|
||||
},
|
||||
# 格式4:连续测量模式
|
||||
{
|
||||
"name": "连续测量模式 (0x0021)",
|
||||
"cmd": bytes([0xAA, 0x00, 0x00, 0x21, 0x00, 0x01, 0x00, 0x00, 0x22]),
|
||||
"desc": "文档2.3.14 连续测量"
|
||||
}
|
||||
]
|
||||
|
||||
def clear_buffer():
|
||||
"""清空串口缓冲区"""
|
||||
try:
|
||||
old_data = laser_uart.read(-1)
|
||||
if old_data:
|
||||
print(f" 清空缓冲区: {len(old_data)} 字节")
|
||||
data = laser_uart.read(-1)
|
||||
if data:
|
||||
print(f"清空: {len(data)}字节")
|
||||
except:
|
||||
pass
|
||||
|
||||
# 发送命令
|
||||
def send_and_wait(cmd, name, wait_time=2000):
|
||||
"""发送命令并等待响应"""
|
||||
print(f"\n📤 发送: {name}")
|
||||
print(f" 命令: {cmd.hex()}")
|
||||
|
||||
clear_buffer()
|
||||
|
||||
try:
|
||||
written = laser_uart.write(cmd)
|
||||
print(f" 写入字节数: {written}")
|
||||
laser_uart.write(cmd)
|
||||
print(f" 已发送 {len(cmd)} 字节")
|
||||
except Exception as e:
|
||||
print(f" ❌ 写入失败: {e}")
|
||||
print(f" ❌ 发送失败: {e}")
|
||||
return None
|
||||
|
||||
# 等待响应
|
||||
time.sleep_ms(100)
|
||||
start_time = time.ticks_ms()
|
||||
response = b""
|
||||
|
||||
# 读取回包
|
||||
while time.ticks_ms() - start_time < wait_time:
|
||||
try:
|
||||
resp = laser_uart.read(50)
|
||||
if resp:
|
||||
print(f" 📥 收到回包: {resp.hex()} ({len(resp)} 字节)")
|
||||
return resp
|
||||
chunk = laser_uart.read(1)
|
||||
if chunk:
|
||||
response += chunk
|
||||
# 完整响应通常是9或13字节
|
||||
if len(response) >= 9:
|
||||
# 检查是否完整帧
|
||||
if response[0] in [0xAA, 0xEE]:
|
||||
if len(response) >= 13: # 测距完整响应
|
||||
break
|
||||
elif response[0] == 0xEE: # 错误响应
|
||||
break
|
||||
except:
|
||||
break
|
||||
|
||||
time.sleep_ms(10)
|
||||
|
||||
if response:
|
||||
print(f" 📥 响应: {response.hex()}")
|
||||
print(f" 长度: {len(response)} 字节")
|
||||
|
||||
# 解析错误码
|
||||
if response[0] == 0xEE and len(response) >= 9:
|
||||
err_code = (response[7] << 8) | response[8]
|
||||
error_mapping = {
|
||||
0x0000: "无错误",
|
||||
0x0001: "硬件错误",
|
||||
0x0002: "无输出数据",
|
||||
0x0003: "反射信号太弱",
|
||||
0x0004: "反射信号太强",
|
||||
0x0005: "温度太高(>40℃)",
|
||||
0x0006: "温度太低(<-10℃)",
|
||||
0x0007: "电源电压低(<2.5V)",
|
||||
0x0008: "超出量程",
|
||||
0x0009: "读通讯错误",
|
||||
0x000A: "写通讯错误",
|
||||
0x000B: "地址错误"
|
||||
}
|
||||
err_msg = error_mapping.get(err_code, f"未知错误: 0x{err_code:04X}")
|
||||
print(f" ❌ 模块错误: {err_msg}")
|
||||
else:
|
||||
print(f" ⚠️ 无回包")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f" ❌ 读取失败: {e}")
|
||||
print(" ⚠️ 无响应")
|
||||
|
||||
return response
|
||||
|
||||
def parse_distance_data(response):
|
||||
"""解析距离数据"""
|
||||
if not response or len(response) < 13:
|
||||
return None
|
||||
|
||||
def test_laser_cycle(on_cmd, off_cmd, cmd_name="标准命令"):
|
||||
"""测试一个开关周期"""
|
||||
print(f"\n{'='*50}")
|
||||
print(f"🧪 测试 {cmd_name}")
|
||||
print(f"{'='*50}")
|
||||
if response[0] != 0xAA or response[3] not in [0x20, 0x21, 0x22]:
|
||||
return None
|
||||
|
||||
print("\n>>> 测试开启激光")
|
||||
send_and_check(on_cmd, f"{cmd_name} - 开启")
|
||||
print(" ⏱️ 等待 2 秒观察激光是否亮起...")
|
||||
time.sleep(2)
|
||||
# 解析4字节BCD码
|
||||
bcd_bytes = response[6:10]
|
||||
distance_int = 0
|
||||
|
||||
print("\n>>> 测试关闭激光")
|
||||
send_and_check(off_cmd, f"{cmd_name} - 关闭")
|
||||
print(" ⏱️ 等待 2 秒观察激光是否熄灭...")
|
||||
time.sleep(2)
|
||||
for byte in bcd_bytes:
|
||||
high = (byte >> 4) & 0x0F
|
||||
low = byte & 0x0F
|
||||
|
||||
if high > 9 or low > 9:
|
||||
return None
|
||||
|
||||
distance_int = distance_int * 100 + high * 10 + low
|
||||
|
||||
distance_m = distance_int / 1000.0
|
||||
|
||||
# 信号质量
|
||||
signal = 0
|
||||
if len(response) >= 12:
|
||||
signal = (response[10] << 8) | response[11]
|
||||
|
||||
return {
|
||||
'meters': distance_m,
|
||||
'millimeters': distance_m * 1000,
|
||||
'signal': signal,
|
||||
'raw': response.hex()
|
||||
}
|
||||
|
||||
# ==================== 主测试 ====================
|
||||
print("\n" + "="*50)
|
||||
print("🚀 开始激光测试")
|
||||
print("M01激光测距模块详细测试")
|
||||
print("="*50)
|
||||
print("\n请观察激光模块的状态变化...")
|
||||
print("测试将依次尝试不同的命令格式\n")
|
||||
|
||||
try:
|
||||
# 测试1: 标准命令
|
||||
test_laser_cycle(LASER_ON_CMD, LASER_OFF_CMD, "标准命令")
|
||||
# 1. 测试基本连接
|
||||
print("\n1. 测试模块连接...")
|
||||
version_cmd = bytes([0xAA, 0x80, 0x00, 0x0A, 0x8A])
|
||||
resp = send_and_wait(version_cmd, "读取硬件版本")
|
||||
|
||||
input("\n按回车继续测试备用命令1...")
|
||||
if resp and resp[0] == 0xAA and resp[3] == 0x0A:
|
||||
print(f"✅ 模块正常,版本: {resp[6]:02X}{resp[7]:02X}")
|
||||
else:
|
||||
print("❌ 模块连接测试失败")
|
||||
exit(1)
|
||||
|
||||
# 测试2: 备用命令格式1
|
||||
test_laser_cycle(LASER_ON_CMD_ALT1, LASER_OFF_CMD_ALT1, "备用命令1 (简化)")
|
||||
# 2. 开启激光
|
||||
print("\n2. 开启激光...")
|
||||
resp = send_and_wait(LASER_ON_CMD, "开启激光", 1000)
|
||||
if resp and resp.hex() == "aa0001be00010001c1":
|
||||
print("✅ 激光已开启")
|
||||
|
||||
input("\n按回车继续测试备用命令2...")
|
||||
print(" 等待激光稳定...")
|
||||
time.sleep(2) # 重要等待时间
|
||||
|
||||
# 测试3: 备用命令格式2
|
||||
test_laser_cycle(LASER_ON_CMD_ALT2, LASER_OFF_CMD_ALT2, "备用命令2 (0x55AA头)")
|
||||
# 3. 尝试不同的测距命令
|
||||
print("\n3. 测试不同测距命令...")
|
||||
|
||||
for i, test_cmd in enumerate(TEST_COMMANDS):
|
||||
print(f"\n{'='*30}")
|
||||
print(f"测试 {i+1}: {test_cmd['name']}")
|
||||
print(f"{test_cmd['desc']}")
|
||||
print(f"{'='*30}")
|
||||
|
||||
resp = send_and_wait(test_cmd['cmd'], test_cmd['name'], 3000)
|
||||
|
||||
if resp:
|
||||
if resp[0] == 0xAA and len(resp) >= 13:
|
||||
result = parse_distance_data(resp)
|
||||
if result:
|
||||
print(f"✅ 测距成功!")
|
||||
print(f" 距离: {result['meters']:.3f} m")
|
||||
print(f" 距离: {result['millimeters']:.1f} mm")
|
||||
print(f" 信号质量: {result['signal']}")
|
||||
break
|
||||
else:
|
||||
print("❌ 无法解析距离数据")
|
||||
elif resp[0] == 0xEE:
|
||||
print("❌ 命令执行错误")
|
||||
else:
|
||||
print("❌ 无效响应格式")
|
||||
else:
|
||||
print("❌ 无响应")
|
||||
|
||||
time.sleep(1) # 命令间间隔
|
||||
|
||||
# 4. 关闭激光
|
||||
print("\n4. 关闭激光...")
|
||||
send_and_wait(LASER_OFF_CMD, "关闭激光", 1000)
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("🏁 测试完成")
|
||||
print("="*50)
|
||||
print("\n诊断建议:")
|
||||
print("1. 如果激光始终不亮/始终亮:")
|
||||
print(" - 检查激光模块的电源连接")
|
||||
print(" - 检查串口TX/RX是否接反")
|
||||
print(" - 尝试不同的波特率 (4800/19200)")
|
||||
print("")
|
||||
print("2. 如果有回包但激光无反应:")
|
||||
print(" - 命令格式可能正确但激光硬件问题")
|
||||
print("")
|
||||
print("3. 如果某个备用命令有效:")
|
||||
print(" - 需要更新 config.py 中的命令格式")
|
||||
|
||||
print("\n📋 测试总结:")
|
||||
print("1. 模块通信: ✅ 正常")
|
||||
print("2. 激光控制: ✅ 正常")
|
||||
print("3. 测距功能: ❌ 有问题")
|
||||
print("\n建议:")
|
||||
print("1. 检查激光是否实际发光(在暗处观察红点)")
|
||||
print("2. 确保测量目标在有效范围内(0.2-60米)")
|
||||
print("3. 确保目标有足够反射率(白色平面最佳)")
|
||||
print("4. 如果所有测距命令都返回ERR_ADDR,可能是固件版本问题")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n🛑 测试被中断")
|
||||
# 确保激光关闭
|
||||
print("\n\n🛑 用户中断")
|
||||
laser_uart.write(LASER_OFF_CMD)
|
||||
print("✅ 已发送关闭指令")
|
||||
except Exception as e:
|
||||
print(f"\n❌ 测试出错: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
16
test/test_motor.py
Normal file
16
test/test_motor.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from maix import gpio, pinmap, time
|
||||
|
||||
|
||||
#设置引脚为输出
|
||||
led = gpio.GPIO("A25", gpio.Mode.OUT)
|
||||
#设置低电平
|
||||
led.value(0)
|
||||
|
||||
while 1:
|
||||
# time.sleep_ms(1000)
|
||||
#对该引脚的电平进行取反(原高-》现低)
|
||||
# led.toggle()
|
||||
led.value(1)
|
||||
#延时
|
||||
time.sleep_ms(5000)
|
||||
led.value(0)
|
||||
@@ -28,6 +28,11 @@ Stage2 ROI(对齐「先检整靶再裁小图」的第二步输入):
|
||||
- --motion-prob:施加概率;--motion-kernel-min/max:模糊 streak 长度(奇数核,越大越糊)。
|
||||
- 可与 --blur-max 高斯模糊叠加;Stage2 建议:--motion-prob 0.5~0.7 --motion-kernel-max 35 --blur-max 1.2
|
||||
|
||||
透视有两种模式(见 --perspective-mode):
|
||||
- corner:四角独立随机抖动(旧版),组合后易出现「不物理」的夸张梯形;--perspective 为 jitter 强度。
|
||||
- planar:把靶当作平面,按 yaw/pitch/roll(度)随机旋转再投影,便于限定在例如「偏航 ±45°、俯仰 ±30°」内,
|
||||
更接近 5m 外拍 40cm 靶等场景(焦距用 --planar-focal-frac 控制视场)。
|
||||
|
||||
依赖:OpenCV + NumPy(PC 上跑即可;Maix 上若内存够也可试)。
|
||||
|
||||
示例:
|
||||
@@ -266,6 +271,74 @@ def _paste_fg_on_bg(bg_bgr, x, y, fg_scaled_bgra):
|
||||
roi_bg[:] = blended.astype(np.uint8)
|
||||
|
||||
|
||||
def _rotation_matrix_ypr(yaw_deg: float, pitch_deg: float, roll_deg: float, np) -> np.ndarray:
|
||||
"""靶平面绕相机坐标原点旋转:先 yaw(Y),再 pitch(X),再 roll(Z)。单位:度。"""
|
||||
y, p, r = np.radians([yaw_deg, pitch_deg, roll_deg])
|
||||
cy, sy = np.cos(y), np.sin(y)
|
||||
cp, sp = np.cos(p), np.sin(p)
|
||||
cr, sr = np.cos(r), np.sin(r)
|
||||
Ry = np.array([[cy, 0.0, sy], [0.0, 1.0, 0.0], [-sy, 0.0, cy]], dtype=np.float64)
|
||||
Rx = np.array([[1.0, 0.0, 0.0], [0.0, cp, -sp], [0.0, sp, cp]], dtype=np.float64)
|
||||
Rz = np.array([[cr, -sr, 0.0], [sr, cr, 0.0], [0.0, 0.0, 1.0]], dtype=np.float64)
|
||||
return Rz @ Rx @ Ry
|
||||
|
||||
|
||||
def _perspective_warp_planar_rgba(
|
||||
img_bgra,
|
||||
yaw_deg: float,
|
||||
pitch_deg: float,
|
||||
roll_deg: float,
|
||||
focal_frac: float,
|
||||
np,
|
||||
cv2,
|
||||
):
|
||||
"""
|
||||
平面靶透视:将 frontal 图像视作 z=1 平面上的针孔投影,再旋转三维射线方向,等价于靶相对相机倾斜。
|
||||
focal_frac:焦距 fx=fy=max(w,h)*focal_frac,越大视场越窄、透视越温和。
|
||||
返回 (warped BGRA, M);退化时返回 (copy, None)。
|
||||
"""
|
||||
h, w = img_bgra.shape[:2]
|
||||
if min(w, h) < 16:
|
||||
return img_bgra.copy(), None
|
||||
fx = fy = float(max(w, h) * max(0.25, focal_frac))
|
||||
cx, cy = w * 0.5, h * 0.5
|
||||
R = _rotation_matrix_ypr(yaw_deg, pitch_deg, roll_deg, np)
|
||||
|
||||
uv_dst: list[list[float]] = []
|
||||
z_min = 0.08
|
||||
for u, v in ((0.0, 0.0), (float(w), 0.0), (float(w), float(h)), (0.0, float(h))):
|
||||
x = (u - cx) / fx
|
||||
y = (v - cy) / fy
|
||||
z = 1.0
|
||||
vec = R @ np.array([x, y, z], dtype=np.float64)
|
||||
if float(vec[2]) < z_min:
|
||||
return img_bgra.copy(), None
|
||||
uu = fx * (vec[0] / vec[2]) + cx
|
||||
vv = fy * (vec[1] / vec[2]) + cy
|
||||
uv_dst.append([uu, vv])
|
||||
|
||||
pts_dst = np.float32(uv_dst)
|
||||
xmin = float(pts_dst[:, 0].min())
|
||||
ymin = float(pts_dst[:, 1].min())
|
||||
pts_shift = pts_dst.copy()
|
||||
pts_shift[:, 0] -= xmin
|
||||
pts_shift[:, 1] -= ymin
|
||||
out_w = max(4, int(np.ceil(float(pts_shift[:, 0].max()))) + 2)
|
||||
out_h = max(4, int(np.ceil(float(pts_shift[:, 1].max()))) + 2)
|
||||
|
||||
pts_src = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
|
||||
M = cv2.getPerspectiveTransform(pts_src, pts_shift)
|
||||
warped = cv2.warpPerspective(
|
||||
img_bgra,
|
||||
M,
|
||||
(out_w, out_h),
|
||||
flags=cv2.INTER_LINEAR,
|
||||
borderMode=cv2.BORDER_CONSTANT,
|
||||
borderValue=(0, 0, 0, 0),
|
||||
)
|
||||
return warped, M
|
||||
|
||||
|
||||
def _perspective_warp_rgba(img_bgra, jitter_frac: float, rng: random.Random, np, cv2):
|
||||
"""
|
||||
对前景做轻微透视(四角微移),返回 (warped BGRA, M)。
|
||||
@@ -311,6 +384,26 @@ def _perspective_warp_rgba(img_bgra, jitter_frac: float, rng: random.Random, np,
|
||||
return warped, M
|
||||
|
||||
|
||||
def _apply_perspective_warp_fg(fg_bgra, rng: random.Random, np, cv2, args) -> tuple:
|
||||
"""在已由 perspective_prob 抽中的样本上做一次透视。返回 (fg_out, M|None)。"""
|
||||
mode = getattr(args, "perspective_mode", "corner")
|
||||
if mode == "planar":
|
||||
|
||||
x = rng.betavariate(2,2)
|
||||
yaw = -float(args.yaw_max_deg) + x * (float(args.yaw_max_deg) +float(args.yaw_max_deg))
|
||||
x = rng.betavariate(2,2)
|
||||
pitch = -float(args.pitch_max_deg) + x * (float(args.pitch_max_deg) +float(args.pitch_max_deg))
|
||||
x = rng.betavariate(2,2)
|
||||
roll = -float(args.roll_max_deg) + x * (float(args.roll_max_deg) +float(args.roll_max_deg))
|
||||
|
||||
focal_frac = float(getattr(args, "planar_focal_frac", 1.25))
|
||||
return _perspective_warp_planar_rgba(fg_bgra, yaw, pitch, roll, focal_frac, np, cv2)
|
||||
jitter = float(getattr(args, "perspective", 0.0))
|
||||
if jitter <= 0:
|
||||
return fg_bgra.copy(), None
|
||||
return _perspective_warp_rgba(fg_bgra, jitter, rng, np, cv2)
|
||||
|
||||
|
||||
def _color_jitter_bgr(comp_bgr, strength: float, rng: random.Random, np, cv2):
|
||||
"""整图 HSV 抖动:strength∈[0,1] 越大越强。"""
|
||||
if strength <= 1e-6:
|
||||
@@ -519,11 +612,17 @@ def main():
|
||||
help="运动模糊 streak 长度上限,越大越像长曝光/手抖",
|
||||
)
|
||||
ap.add_argument("--jpeg-quality", type=int, default=92)
|
||||
ap.add_argument(
|
||||
"--perspective-mode",
|
||||
choices=("corner", "planar"),
|
||||
default="corner",
|
||||
help="corner=四角随机抖动(易夸张);planar=yaw/pitch/roll 平面投影(易限定视姿)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--perspective",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="轻微透视:四角扰动约为 min(靶宽,靶高)×该系数,0 关闭(建议 0.02~0.06)",
|
||||
help="仅 perspective-mode=corner:四角扰动≈min(靶宽,靶高)×该系数,0 关闭(温和可用 0.015~0.03)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--perspective-prob",
|
||||
@@ -531,6 +630,30 @@ def main():
|
||||
default=0.75,
|
||||
help="每张图应用透视的概率 0~1",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--yaw-max-deg",
|
||||
type=float,
|
||||
default=45.0,
|
||||
help="planar:偏航角均匀采样上界(度),实际 ∈[-max,max]",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--pitch-max-deg",
|
||||
type=float,
|
||||
default=30.0,
|
||||
help="planar:俯仰角均匀采样上界(度)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--roll-max-deg",
|
||||
type=float,
|
||||
default=8.0,
|
||||
help="planar:滚转角均匀采样上界(度),手持可略大",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--planar-focal-frac",
|
||||
type=float,
|
||||
default=1.25,
|
||||
help="planar:fx=fy=max(w,h)×该值,越大透视越温和(建议 1.1~1.8)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--color-jitter",
|
||||
type=float,
|
||||
@@ -662,8 +785,15 @@ def main():
|
||||
fg_s = cv2.resize(fg_crop, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
||||
|
||||
persp_M = None
|
||||
if args.perspective > 0 and rng.random() < args.perspective_prob:
|
||||
fg_s, persp_M = _perspective_warp_rgba(fg_s, args.perspective, rng, np, cv2)
|
||||
want_p = (
|
||||
rng.random() < args.perspective_prob
|
||||
and (
|
||||
args.perspective_mode == "planar"
|
||||
or (args.perspective_mode == "corner" and args.perspective > 0)
|
||||
)
|
||||
)
|
||||
if want_p:
|
||||
fg_s, persp_M = _apply_perspective_warp_fg(fg_s, rng, np, cv2, args)
|
||||
|
||||
fw2, fh2 = fg_s.shape[1], fg_s.shape[0]
|
||||
tx0, ty0, tw, th = _fg_bbox_from_alpha(fg_s)
|
||||
506
train_yolo/train_black_triangle_pos.py
Normal file
506
train_yolo/train_black_triangle_pos.py
Normal file
@@ -0,0 +1,506 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
YOLO11 关键点检测训练脚本(靶纸四角)。
|
||||
|
||||
设备优先级(--device auto):Intel XPU > NVIDIA CUDA > CPU。
|
||||
默认 imgsz=960;批大小默认 4(大图显存紧张时可再降)。
|
||||
|
||||
关于「业务像素误差」:
|
||||
Ultralytics 没有在 yaml 里设定「像素阈值」的选项;反向传播仍由 pose/kobj/box 等内部 loss 驱动。
|
||||
- 监控:--pixel-metrics-every N(每 N 个 epoch 打印 mean/p95,并合并进 runs/.../results.csv,见 pose_pixel_metrics.py)。
|
||||
- 选 best.pt / early stopping:加 --best-by-pixel,用验证集 mean 像素误差(与 pose_pixel_metrics
|
||||
同一口径)代替 mAP 合成 fitness(fitness = -mean_px,越小越好)。
|
||||
多卡 DDP(world_size>1)时会自动退回默认 mAP fitness。
|
||||
|
||||
XPU:Ultralytics BaseTrainer._get_memory / _clear_memory 把非 MPS、非 CPU 一律当 CUDA,
|
||||
会在验证前调用 torch.cuda 而报错;本脚本在选用 XPU 时自动打补丁(见 _patch_ultralytics_for_xpu)。
|
||||
|
||||
务必使用 pose 任务:YOLO(...) 与 model.train(...) 均指定 task='pose'。若误用默认 detect,
|
||||
会把 17 列 Pose 标注当成检测/分割解析,校验时出现「coordinates > 1」或 [2.] 等假象。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import gc
|
||||
import glob
|
||||
import os
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from ultralytics import YOLO
|
||||
|
||||
from pose_pixel_metrics import eval_val_pixel_error
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore',
|
||||
message=".*scatter_add_kernel does not have a deterministic implementation.*")
|
||||
|
||||
|
||||
def _clear_ultralytics_label_caches(data_yaml_path: str) -> int:
    """Delete ``labels/*.cache`` under the ``path`` entry of a dataset YAML.

    Rationale (per the original author's note): the Ultralytics label-cache
    hash depends only on the label/image path strings plus the sum of file
    sizes, not on file contents. After fixing *.txt labels, an unchanged sum
    could make it reload a stale cache and replay old "corrupt" warnings, so
    the caches are removed before training.

    Args:
        data_yaml_path: Path to the dataset YAML file.

    Returns:
        The number of cache files deleted; 0 if the YAML cannot be loaded or
        has no ``path`` entry.
    """
    from ultralytics.utils import YAML

    try:
        cfg = YAML.load(data_yaml_path)
    except Exception:
        return 0

    dataset_root = cfg.get("path")
    if not dataset_root:
        return 0

    labels_dir = os.path.join(
        os.path.abspath(os.path.expanduser(str(dataset_root))), "labels"
    )
    removed = 0
    for cache_file in glob.glob(os.path.join(labels_dir, "*.cache")):
        try:
            os.unlink(cache_file)
        except OSError:
            # Best effort: a cache file that cannot be removed is skipped.
            continue
        removed += 1
    return removed
|
||||
|
||||
|
||||
def _pick_device(explicit: str | None):
|
||||
"""返回 ultralytics train/predict 可用的 device。"""
|
||||
if explicit and explicit != "auto":
|
||||
e = explicit.lower()
|
||||
if e == "xpu":
|
||||
if getattr(torch, "xpu", None) is None or not torch.xpu.is_available():
|
||||
raise RuntimeError("指定了 --device xpu 但当前环境不可用")
|
||||
return torch.device("xpu")
|
||||
if e in ("0", "cuda", "gpu"):
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("指定了 CUDA 但不可用")
|
||||
return 0
|
||||
if e == "cpu":
|
||||
return "cpu"
|
||||
return explicit
|
||||
if getattr(torch, "xpu", None) is not None and torch.xpu.is_available():
|
||||
return torch.device("xpu")
|
||||
if torch.cuda.is_available():
|
||||
return 0
|
||||
return "cpu"
|
||||
|
||||
|
||||
def _default_amp(device) -> bool:
|
||||
if isinstance(device, torch.device) and device.type == "xpu":
|
||||
return False
|
||||
if device == "cpu":
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _patch_ultralytics_for_xpu():
    """Monkey-patch Ultralytics so training and validation can run on an XPU device.

    Three things are patched, each guarded by a marker so that calling this
    function more than once is harmless:

    1. ``select_device`` in ``torch_utils`` and in the trainer/validator
       modules, so that "xpu" / "xpu:N" strings resolve to a ``torch.device``.
    2. ``BaseTrainer._get_memory`` / ``_clear_memory``, which use ``torch.xpu``
       when the trainer's device is XPU and defer to the originals otherwise.
    3. ``BaseValidator``, which receives the same two methods.
    """
    import ultralytics.engine.trainer as ut_trainer
    import ultralytics.engine.validator as ut_validator
    import ultralytics.utils.torch_utils as ut_torch_utils

    # 1. Override select_device.
    #    Original author's notes: at Trainer init a torch.device("xpu") takes
    #    the stock early-return path; afterwards args.device becomes the string
    #    "xpu". Mid-training validation uses trainer.device and does not call
    #    select_device, but the final evaluation's Validator calls
    #    select_device("xpu"). The validator module bound the original function
    #    at import time, so patching torch_utils alone is not enough; the
    #    references inside the trainer/validator modules are replaced as well.
    _original_select_device = ut_torch_utils.select_device
    if not getattr(_original_select_device, "_archery_xpu_patched", False):

        def _patched_select_device(device="", *args, **kwargs):
            # Ultralytics 8.4.x: select_device(device="", newline=False, verbose=True)
            # Older forks sometimes passed extra positional args; forward everything.
            if isinstance(device, str):
                d = device.strip().lower()
                if d == "xpu" or d.startswith("xpu:"):
                    # Use the normalised string so "XPU:0" resolves as well.
                    return torch.device(d)
            return _original_select_device(device, *args, **kwargs)

        # Marker prevents wrapping the wrapper again on a repeated call.
        _patched_select_device._archery_xpu_patched = True

        ut_torch_utils.select_device = _patched_select_device
        ut_trainer.select_device = _patched_select_device
        ut_validator.select_device = _patched_select_device

    # 2. Patch the Trainer's memory helpers.
    BT = ut_trainer.BaseTrainer
    if not getattr(BT, "_archery_xpu_memory_patched", False):
        _orig_get_memory = BT._get_memory
        _orig_clear_memory = BT._clear_memory

        def _get_memory(self, fraction=False):
            """Return allocated XPU memory in GiB, or as a fraction of total if requested."""
            if self.device.type != "xpu":
                return _orig_get_memory(self, fraction)
            memory, total = 0, 0
            try:
                idx = self.device.index
                if idx is None:
                    idx = torch.xpu.current_device()
                memory = int(torch.xpu.memory_allocated(idx))
                if fraction:
                    total = int(torch.xpu.get_device_properties(idx).total_memory)
            except Exception:
                # Best effort: memory statistics must never abort training.
                pass
            if fraction:
                # Never hand back a GiB figure where a 0..1 fraction is expected.
                return (memory / total) if total > 0 else 0
            return memory / 2**30

        def _clear_memory(self, threshold=None):
            """Collect garbage and empty the XPU cache, optionally only above a usage threshold."""
            if self.device.type != "xpu":
                return _orig_clear_memory(self, threshold)
            if threshold is not None:
                assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
                if self._get_memory(fraction=True) <= threshold:
                    return
            gc.collect()
            if hasattr(torch.xpu, "empty_cache"):
                torch.xpu.empty_cache()

        BT._get_memory = _get_memory
        BT._clear_memory = _clear_memory
        BT._archery_xpu_memory_patched = True

    # 3. Give the Validator the same memory helpers. They are read back from
    #    BaseTrainer, which is guaranteed to carry the patched versions here,
    #    so this step does not depend on the closures defined in step 2.
    BV = ut_validator.BaseValidator
    if not getattr(BV, "_archery_xpu_memory_patched", False):
        BV._get_memory = BT._get_memory
        BV._clear_memory = BT._clear_memory
        BV._archery_xpu_memory_patched = True
|
||||
|
||||
|
||||
def _install_best_by_pixel_validate(data_yaml: str, imgsz: int, conf: float) -> None:
    """Replace ``BaseTrainer.validate`` so fitness is the negated mean keypoint pixel error.

    The returned fitness drives best.pt selection and patience-based early
    stopping. The pixel probe runs only in single-process training
    (world_size <= 1 and RANK in {-1, 0}); otherwise, or when the probe fails,
    the validator's own fitness is used. Installation happens at most once.

    Args:
        data_yaml: Dataset YAML passed to ``eval_val_pixel_error``.
        imgsz: Image size passed to ``eval_val_pixel_error``.
        conf: Confidence threshold passed to ``eval_val_pixel_error``.
    """
    import ultralytics.engine.trainer as ut
    from ultralytics.utils import RANK

    trainer_cls = ut.BaseTrainer
    if getattr(trainer_cls, "_archery_best_by_pixel_installed", False):
        return

    def _probe_mean_px(trainer):
        """Return the validation mean pixel error of the current weights, or None on failure."""
        tmp_path: str | None = None
        try:
            fd, tmp_path = tempfile.mkstemp(suffix=".pt", prefix="archery_pxfit_")
            os.close(fd)
            from ultralytics.utils.torch_utils import unwrap_model

            # Snapshot the EMA weights (or the raw model) into a temporary checkpoint.
            core = unwrap_model(trainer.ema.ema if trainer.ema else trainer.model)
            torch.save({"ema": deepcopy(core).half(), "train_args": vars(trainer.args)}, tmp_path)
            probe = YOLO(tmp_path)
            stats = eval_val_pixel_error(
                probe,
                data_yaml,
                device=trainer.device,
                imgsz=imgsz,
                conf=conf,
            )
            value = stats.get("mean_px")
            if value is None:
                raise RuntimeError("无有效 mean_px(检查 val 标签与检测是否为空)")
            return value
        except Exception as exc:
            print(f"\n⚠️ [best-by-pixel] 像素探针失败,本 epoch 仍用 mAP fitness: {exc}\n")
            return None
        finally:
            if tmp_path:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

    def validate(self):
        import torch.distributed as dist

        # Multi-process: sync the EMA buffers from rank 0 before validating.
        if self.ema and self.world_size > 1:
            for buffer in self.ema.ema.buffers():
                dist.broadcast(buffer, src=0)
        metrics = self.validator(self)
        if metrics is None:
            return None, None
        orig_fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())

        single_process = self.world_size <= 1 and RANK in {-1, 0}
        mean_px = _probe_mean_px(self) if single_process else None

        # NOTE(review): pixel fitness is -mean_px while the fallback is the
        # validator's own fitness; the two look to be on different scales, so
        # one failed probe may outrank every pixel-based epoch - confirm.
        if mean_px is None:
            fitness = float(orig_fitness)
        else:
            fitness = -float(mean_px)
            metrics["metrics/mean_px(val)"] = float(mean_px)

        if not self.best_fitness or self.best_fitness < fitness:
            self.best_fitness = fitness
        return metrics, fitness

    trainer_cls.validate = validate
    trainer_cls._archery_best_by_pixel_installed = True
|
||||
|
||||
|
||||
def _fmt_csv_metric(v: float | int | None) -> str:
|
||||
if v is None:
|
||||
return ""
|
||||
if isinstance(v, float):
|
||||
return f"{v:.6g}"
|
||||
return str(v)
|
||||
|
||||
|
||||
# 写入 results.csv 的列名(与 --best-by-pixel 的 metrics/mean_px(val) 区分,避免被 last.pt 回调覆盖 EMA 行)
|
||||
_PIXEL_METRIC_COLUMNS: tuple[tuple[str, str], ...] = (
|
||||
("pixel_error/mean_px", "mean_px"),
|
||||
("pixel_error/median_px", "median_px"),
|
||||
("pixel_error/p95_px", "p95_px"),
|
||||
("pixel_error/max_px", "max_px"),
|
||||
("pixel_error/n_points", "n_points"),
|
||||
("pixel_error/n_images", "n_images"),
|
||||
("pixel_error/skip_no_det", "skip_no_det"),
|
||||
("pixel_error/skip_no_gt", "skip_no_gt"),
|
||||
("pixel_error/skip_kpt_mismatch", "skip_kpt_mismatch"),
|
||||
)
|
||||
|
||||
|
||||
def _merge_pixel_metrics_into_results_csv(save_dir: str | Path, epoch_1based: int, stats: dict) -> None:
|
||||
"""在 Ultralytics 写完本 epoch 行之后,把像素指标列合并进 results.csv(扩展表头、补空列)。"""
|
||||
csv_path = Path(save_dir) / "results.csv"
|
||||
if not csv_path.is_file():
|
||||
return
|
||||
try:
|
||||
with open(csv_path, newline="", encoding="utf-8") as f:
|
||||
rows = list(csv.reader(f))
|
||||
except OSError:
|
||||
return
|
||||
if len(rows) < 2:
|
||||
return
|
||||
header = list(rows[0])
|
||||
for col_name, _ in _PIXEL_METRIC_COLUMNS:
|
||||
if col_name not in header:
|
||||
header.append(col_name)
|
||||
for ri in range(1, len(rows)):
|
||||
rows[ri].append("")
|
||||
col_ix = {name: i for i, name in enumerate(header)}
|
||||
rows[0] = header
|
||||
target_ri: int | None = None
|
||||
for ri in range(1, len(rows)):
|
||||
row = rows[ri]
|
||||
while len(row) < len(header):
|
||||
row.append("")
|
||||
try:
|
||||
if int(float(row[0].strip())) == int(epoch_1based):
|
||||
target_ri = ri
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
if target_ri is None:
|
||||
return
|
||||
row = rows[target_ri]
|
||||
while len(row) < len(header):
|
||||
row.append("")
|
||||
for col_name, sk in _PIXEL_METRIC_COLUMNS:
|
||||
row[col_ix[col_name]] = _fmt_csv_metric(stats.get(sk))
|
||||
try:
|
||||
with open(csv_path, "w", newline="", encoding="utf-8") as f:
|
||||
w = csv.writer(f)
|
||||
w.writerows(rows)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _make_pixel_metrics_callback(data_yaml: str, every: int, imgsz: int, conf: float = 0.25):
    """Build an ``on_fit_epoch_end`` callback that reports validation pixel error.

    Every ``every`` epochs the callback loads ``weights/last.pt`` from the
    trainer's save directory, runs ``eval_val_pixel_error`` on it, prints a
    one-line summary and merges the statistics into ``results.csv``.

    The callback is monitoring only: if the probe raises, a warning is printed
    and training continues.

    Args:
        data_yaml: Dataset YAML handed to ``eval_val_pixel_error``.
        every: Evaluation period in epochs; values ``<= 0`` disable the callback.
        imgsz: Image size forwarded to ``eval_val_pixel_error``.
        conf: Confidence threshold forwarded to ``eval_val_pixel_error``.

    Returns:
        A callable taking the trainer, suitable for ``model.add_callback``.
    """

    def on_fit_epoch_end(trainer):
        # Cheapest check first: a disabled callback does nothing at all.
        if every <= 0:
            return

        from ultralytics.utils import RANK

        # Only the process with RANK -1 or 0 evaluates and edits the CSV.
        if RANK not in {-1, 0}:
            return
        ep = int(getattr(trainer, "epoch", -1))
        if (ep + 1) % every != 0:
            return
        w = Path(trainer.save_dir) / "weights" / "last.pt"
        if not w.is_file():
            return
        try:
            m = YOLO(str(w))
            stats = eval_val_pixel_error(
                m,
                data_yaml,
                device=trainer.device,
                imgsz=imgsz,
                conf=conf,
            )
        except Exception as exc:
            # A failed probe must not abort a long training run.
            print(f"\n⚠️ [pixel-metrics] epoch {ep + 1} 像素指标评估失败,已跳过: {exc}\n")
            return
        mean_px = stats.get("mean_px")
        p95_px = stats.get("p95_px")
        mean_s = f"{mean_px:.3f}" if mean_px is not None else "n/a"
        p95_s = f"{p95_px:.3f}" if p95_px is not None else "n/a"
        print(
            f"\n[pixel-metrics] epoch {ep + 1}: mean_px={mean_s} p95_px={p95_s} "
            f"n_points={stats.get('n_points', 0)} "
            f"skip(det/gt/k)={stats.get('skip_no_det', 0)}/{stats.get('skip_no_gt', 0)}"
            f"/{stats.get('skip_kpt_mismatch', 0)}\n"
        )
        _merge_pixel_metrics_into_results_csv(trainer.save_dir, ep + 1, stats)

    return on_fit_epoch_end
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for YOLO pose training.

    Parses the options, selects a device and AMP setting, resolves the dataset
    YAML, optionally clears label caches, loads the model with ``task="pose"``,
    installs the optional pixel-error hooks, runs ``model.train`` with a fixed
    hyperparameter set, and optionally exports ONNX afterwards.

    Returns ``None``; when the dataset YAML does not exist it prints an error
    and returns early without training.
    """
    ap = argparse.ArgumentParser(description="YOLO Pose 训练(XPU/CUDA/CPU)")
    ap.add_argument("--data", default="datasets/dataset_pose.yaml", help="data.yaml")
    ap.add_argument("--model", default="yolo11x-pose.pt", help="预训练权重")
    ap.add_argument("--epochs", type=int, default=100)
    ap.add_argument("--imgsz", type=int, default=960, help="训练输入边长(默认 960)")
    ap.add_argument("--batch", type=int, default=4, help="批大小;OOM 时减小")
    ap.add_argument(
        "--device",
        default="auto",
        help="auto | xpu | 0 | cuda | cpu(auto:XPU 优先)",
    )
    ap.add_argument(
        "--no-amp",
        action="store_true",
        help="关闭混合精度(默认:CUDA 开启,XPU/CPU 关闭)",
    )
    ap.add_argument("--project", default="runs/pose")
    ap.add_argument("--name", default="target_pose_train")
    ap.add_argument("--workers", type=int, default=4)
    # Periodic pixel-error reporting (evaluates weights/last.pt via a callback).
    ap.add_argument(
        "--pixel-metrics-every",
        type=int,
        default=0,
        help="每 N 个 epoch 在 val 上打印像素误差并写入 results.csv 对应 epoch 行(0=关闭);需 labels 与 data.yaml 布局一致",
    )
    ap.add_argument(
        "--pixel-metrics-conf",
        type=float,
        default=0.25,
        help="--pixel-metrics-every 时 predict 置信度阈值(默认 0.25)",
    )
    # Select best.pt / early stopping by mean pixel error instead of mAP fitness.
    ap.add_argument(
        "--best-by-pixel",
        action="store_true",
        help="best.pt 与 early stopping 按验证集 mean 像素误差(同 pose_pixel_metrics),fitness=-mean_px;单卡有效,DDP 自动退回 mAP",
    )
    ap.add_argument(
        "--pixel-fitness-conf",
        type=float,
        default=0.25,
        help="--best-by-pixel 时 predict 置信度阈值(默认与 pixel-metrics 一致)",
    )
    ap.add_argument(
        "--export-onnx",
        action="store_true",
        help="训练结束后导出 ONNX(需再设 --onnx-imgsz)",
    )
    ap.add_argument(
        "--onnx-imgsz",
        type=int,
        nargs=2,
        metavar=("H", "W"),
        default=[224, 320],
        help="导出 ONNX 的 [高, 宽],默认 224 320(Maix 常用)",
    )
    ap.add_argument(
        "--clear-label-cache",
        action="store_true",
        help="启动训练前删除 data.yaml 中 path 下的 labels/*.cache(修正标注后仍报 corrupt 时用)",
    )
    args = ap.parse_args()

    # "auto" is passed to _pick_device as None; any other value is passed through.
    device = _pick_device(None if args.device == "auto" else args.device)
    # --no-amp forces AMP off; otherwise _default_amp(device) decides.
    use_amp = False if args.no_amp else _default_amp(device)

    # NOTE(review): any device other than an xpu torch.device, 0 or "0" reaches the
    # CPU message below — confirm _pick_device never returns e.g. "cuda" unchanged.
    if isinstance(device, torch.device) and device.type == "xpu":
        print(f"✅ 使用 Intel XPU: {device}")
    elif device == 0 or device == "0":
        print(f"✅ 使用 CUDA: {torch.cuda.get_device_name(0)}")
    else:
        print("⚠️ 使用 CPU,训练会较慢")

    # The Ultralytics patch is applied only when running on XPU.
    if isinstance(device, torch.device) and device.type == "xpu":
        _patch_ultralytics_for_xpu()

    # Relative dataset paths are resolved against this script's directory,
    # not the current working directory.
    data_yaml = args.data
    if not os.path.isabs(data_yaml):
        data_yaml = os.path.join(os.path.dirname(os.path.abspath(__file__)), data_yaml)
    if not os.path.exists(data_yaml):
        print(f"❌ 数据集配置不存在: {data_yaml}")
        return

    if args.clear_label_cache:
        n_rm = _clear_ultralytics_label_caches(data_yaml)
        print(f"🗑️ 已删除标签目录缓存 {n_rm} 个(labels/*.cache),将强制重新扫描标注。")

    print(f"📦 加载模型: {args.model}(固定 task=pose)")
    model = YOLO(args.model, task="pose")

    # Must be installed before model.train() so the patched validate is in effect.
    if args.best_by_pixel:
        _install_best_by_pixel_validate(data_yaml, args.imgsz, args.pixel_fitness_conf)
        print(
            "📌 已启用 --best-by-pixel:best.pt / patience 按验证集 mean 像素误差(fitness=-mean_px);"
            "反向传播仍为 Ultralytics 默认 pose/box loss。"
        )

    if args.pixel_metrics_every > 0:
        model.add_callback(
            "on_fit_epoch_end",
            _make_pixel_metrics_callback(
                data_yaml, args.pixel_metrics_every, args.imgsz, conf=args.pixel_metrics_conf
            ),
        )

    # All keyword arguments below are forwarded as-is to Ultralytics' trainer.
    model.train(
        # Run identity, data and checkpointing.
        task="pose",
        data=data_yaml,
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        name=args.name,
        project=args.project,
        exist_ok=True,
        save=True,
        save_period=5,
        device=device,
        workers=args.workers,
        # Optimizer and learning-rate schedule.
        lr0=0.0001,
        lrf=0.01,
        optimizer="AdamW",
        momentum=0.937,
        weight_decay=0.001,
        warmup_epochs=0,
        warmup_momentum=0.8,
        warmup_bias_lr=0.1,
        # Augmentation settings.
        hsv_h=0.015,
        hsv_s=0.7,
        hsv_v=0.4,
        degrees=5.0,
        translate=0.0,
        scale=0.2,
        shear=0.0,
        perspective=0.0000,
        flipud=0.0,
        # NOTE(review): horizontal flips also mirror keypoints — confirm the
        # data.yaml declares a matching flip_idx for this keypoint layout.
        fliplr=0.5,
        mosaic=0.0,
        mixup=0.0,
        copy_paste=0.0,
        # Loss weighting (presumably the Ultralytics per-term gains; see its config docs).
        box=6,
        cls=0.5,
        dfl=1.5,
        pose=18.0,
        kobj=0.5,
        # Miscellaneous.
        freeze=0,
        seed=42,
        verbose=True,
        amp=use_amp,
        patience=100,
        cos_lr=True,
    )

    print("\n✅ 训练完成!")
    # These paths are composed from --project/--name, not read back from the trainer.
    print(f"📁 best: {args.project}/{args.name}/weights/best.pt")
    print(f"📁 last: {args.project}/{args.name}/weights/last.pt")
    print("📊 仅看像素误差可运行: python pose_pixel_metrics.py --model <best.pt> --data <yaml> --imgsz", args.imgsz)

    if args.export_onnx:
        # Exports the `model` object as it stands after train() returns.
        h, w = args.onnx_imgsz
        print(f"📦 导出 ONNX imgsz=[{h}, {w}] ...")
        model.export(format="onnx", imgsz=[h, w], simplify=True, opset=17, dynamic=False)
        print("✅ ONNX 完成")
|
||||
|
||||
|
||||
# Run the trainer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user