增加训练yolo的代码

This commit is contained in:
gcw_4spBpAfv
2026-05-15 09:35:53 +08:00
parent dff5096164
commit 541418fd60
13 changed files with 953 additions and 140 deletions

11
adc.py
View File

@@ -4,14 +4,13 @@ from maix import time
a = adc.ADC(0, adc.RES_BIT_12) a = adc.ADC(0, adc.RES_BIT_12)
while True: while True:
raw_data = a.read() # raw_data = a.read()
print(f"ADC raw data:{raw_data}") # print(f"ADC raw data:{raw_data}")
# if raw_data > 2450: # if raw_data > 2450:
# print(f"ADC raw data:{raw_data}") # print(f"ADC raw data:{raw_data}")
# elif raw_data < 2000: # elif raw_data < 2000:
# print(f"ADC raw data:{raw_data}") # print(f"ADC raw data:{raw_data}")
# time.sleep_ms(50) time.sleep_ms(1)
# vol = a.read_vol() vol = int(a.read_vol() * 10) / 10
print(f"ADC vol:{vol:.1f}, {time.time():.4f}")
# print(f"ADC vol:{vol}")

View File

@@ -1,11 +1,12 @@
id: t11 id: t11
name: t11 name: t11
version: 1.2.11 version: 1.2.12
author: t11 author: t11
icon: '' icon: ''
desc: t11 desc: t11
files: files:
- 4g_download_manager.py - 4g_download_manager.py
- 4g_upload_manager.py
- app.yaml - app.yaml
- archery_netcore.cpython-311-riscv64-linux-gnu.so - archery_netcore.cpython-311-riscv64-linux-gnu.so
- at_client.py - at_client.py
@@ -34,3 +35,4 @@ files:
- vision.py - vision.py
- wifi_config_httpd.py - wifi_config_httpd.py
- wifi.py - wifi.py
- wpa_supplicant_conf.py

View File

@@ -90,5 +90,13 @@ opencv_interactive-calibration -t=chessboard -w=9 -h=6 -sz=0.025 -v="http://192.
D:\data\test_target_photo 是用来叠加的背景图 D:\data\test_target_photo 是用来叠加的背景图
7.1 生成靶纸及黑色三角形的截图的图片带动动但1.12的外框 7.1 生成靶纸及黑色三角形的截图的图片带动动但1.12的外框
bak
python .\synth_compose_yolo.py --perspective 0.04 --perspective-prob 0.8 --color-jitter 0.6 --bg-dir D:\data\test_target_photo --fg D:\code\shooting\target_photo\write.png --out ./synth_out --class-name triangle --zip ./maix_dataset.zip --num 60 --triangles-json archery_triangles_default.json --format voc --stage2-crop --stage2-pad-min 0.03 --stage2-pad-max 0.18 --motion-prob 0.9 --motion-kernel-max 8 --blur-max 0 --triangle-bbox-pad-frac 0.12 python .\synth_compose_yolo.py --perspective 0.04 --perspective-prob 0.8 --color-jitter 0.6 --bg-dir D:\data\test_target_photo --fg D:\code\shooting\target_photo\write.png --out ./synth_out --class-name triangle --zip ./maix_dataset.zip --num 60 --triangles-json archery_triangles_default.json --format voc --stage2-crop --stage2-pad-min 0.03 --stage2-pad-max 0.18 --motion-prob 0.9 --motion-kernel-max 8 --blur-max 0 --triangle-bbox-pad-frac 0.12
bak_2
python synth_keypoints_right_angle.py --bg-dir D:\data\test_target_photo --fg D:\code\shooting\target_photo\write.png --triangles-json archery_triangles_default.json --out ./synth_out --num 1000 --offscreen-shift-prob 0.3 --offscreen-shift-frac 0.4 --offscreen-min-visible 1 --stage2-crop --stage2-pad-min 0.03 --stage2-pad-max 0.18 --motion-prob 0.9 --motion-kernel-max 8 --blur-max 0 --perspective-mode planar --yaw-max-deg 10 --pitch-max-deg 8 --roll-max-deg 4 --planar-focal-frac 1.45 --perspective-prob 0.4
python synth_keypoints_right_angle.py --bg-dir D:\data\test_target_photo --fg D:\code\shooting\target_photo\write.png --triangles-json archery_triangles_default.json --out ./synth_out --num 1000 --offscreen-shift-prob 0.3 --offscreen-shift-frac 0.4 --offscreen-min-visible 1 --stage2-crop --stage2-pad-min 0.03 --stage2-pad-max 0.18 --motion-prob 1.0 --motion-kernel-max 8 --blur-max 0 --perspective-mode planar --yaw-max-deg 10 --pitch-max-deg 8 --roll-max-deg 4 --planar-focal-frac 1.45 --perspective-prob 0.4
python pose_pixel_metrics.py --model D:\code\archery\runs\pose\runs\pose\target_pose_train\weights\best.pt --data D:\code\archery\datasets\dataset_pose.yaml --imgsz 640

50
test/test_audio.py Normal file
View File

@@ -0,0 +1,50 @@
# test_audio.pyx
from maix import audio, time, app, gpio
def run_player_loop():
    """Pulse pin A25 and play ``/root/gun.wav`` repeatedly until the app exits.

    Each cycle drives A25 high for 200 ms, drives it low, starts playback and
    then waits 1000 ms before the next cycle.
    """
    # Audio player at volume 40.
    player = audio.Player("/root/gun.wav")
    player.volume(40)

    # A25 as an output, starting from the low level.
    pin = gpio.GPIO("A25", gpio.Mode.OUT)
    pin.value(0)

    def fire_once():
        pin.value(1)         # output high
        time.sleep_ms(200)   # hold for 200 ms
        pin.value(0)         # output low
        player.play()        # start playback
        time.sleep_ms(1000)  # wait 1 s before the next cycle

    while True:
        if app.need_exit():
            break
        fire_once()

    # NOTE(review): assumed to run once after the loop ends — confirm against
    # the original indentation.
    print("play finish!")
# Optional: trivial smoke-test helper.
def hello():
    """Return a fixed greeting string identifying this module."""
    greeting = "Hello from test_audio!"
    return greeting
# Optional: standalone GPIO check.
def init_led():
    """Configure pin A25 as an output, drive it low and return a status string."""
    pin = gpio.GPIO("A25", gpio.Mode.OUT)
    pin.value(0)
    status = "LED initialized"
    return status
# Optional: standalone playback check.
def test_play():
    """Start playing ``/root/gun.wav`` at volume 50 and return a status string."""
    player = audio.Player("/root/gun.wav")
    player.volume(50)
    player.play()
    status = "Playing..."
    return status
run_player_loop()

25
test/test_button.py Normal file
View File

@@ -0,0 +1,25 @@
from maix import audio, time, app,gpio
# Button inputs. A third button on the "ADC" pin is currently disabled:
# button1 = gpio.GPIO("ADC", gpio.Mode.IN)
button3 = gpio.GPIO("A26", gpio.Mode.IN)  # known to work
button2 = gpio.GPIO("A16", gpio.Mode.IN)

from maix.peripheral import adc

# 12-bit ADC on channel 0. The object is created, but its reading is not
# printed at the moment (see the disabled read_vol() line below).
channel = 0
res_bit = adc.RES_BIT_12
_adc_obj = adc.ADC(channel, res_bit)

# Poll both buttons every 50 ms until the app is asked to exit.
while True:
    if app.need_exit():
        break
    # print(f"b1: {button1.value()}")
    print(f"b2: {button2.value()}")
    # print(_adc_obj.read_vol())
    print(f"b3: {button3.value()}")
    time.sleep_ms(50)

View File

@@ -3,7 +3,10 @@ from maix import camera, display, time
try: try:
print("Initializing camera...") print("Initializing camera...")
cam = camera.Camera(640, 480) cam = camera.Camera(640,480)
# cam = camera.Camera(1280,720)
# cam.get_exposure_us()
# print("Camera exposure: ", cam.get_exposure_us())
print("Camera initialized successfully!") print("Camera initialized successfully!")
disp = display.Display() disp = display.Display()

View File

@@ -1,172 +1,246 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
激光模块测试脚本 M01激光测距模块测试脚本 - 修正版
用于诊断激光开关问题 基于文档中的完整命令示例
使用方法:
python test_laser.py
功能:
1. 初始化串口
2. 循环测试激光开/关
3. 打印详细调试信息
""" """
from maix import uart, pinmap, time from maix import uart, pinmap, time
import binascii
# ==================== 配置 ==================== # ==================== 配置 ====================
UART_PORT = "/dev/ttyS1" # 激光模块连接的串口UART1 UART_PORT = "/dev/ttyS1"
BAUDRATE = 9600 # 波特率 BAUDRATE = 9600
# 引脚映射(确保与硬件连接一致)
print("=" * 50)
print("🔧 步骤1: 配置引脚映射")
print("=" * 50)
# 初始化串口
try: try:
pinmap.set_pin_function("A18", "UART1_RX") pinmap.set_pin_function("A18", "UART1_RX")
print("✅ A18 -> UART1_RX")
except Exception as e:
print(f"❌ A18 配置失败: {e}")
try:
pinmap.set_pin_function("A19", "UART1_TX") pinmap.set_pin_function("A19", "UART1_TX")
print("✅ A19 -> UART1_TX")
except Exception as e:
print(f"❌ A19 配置失败: {e}")
# ==================== 激光控制指令 ====================
MODULE_ADDR = 0x00
# 原始命令
LASER_ON_CMD = bytes([0xAA, MODULE_ADDR, 0x01, 0xBE, 0x00, 0x01, 0x00, 0x01, 0xC1])
LASER_OFF_CMD = bytes([0xAA, MODULE_ADDR, 0x01, 0xBE, 0x00, 0x01, 0x00, 0x00, 0xC0])
# 备用命令格式(如果原始命令不工作,可以尝试这些)
# 格式1: 简化命令
LASER_ON_CMD_ALT1 = bytes([0xAA, 0x01, 0x01])
LASER_OFF_CMD_ALT1 = bytes([0xAA, 0x01, 0x00])
# 格式2: 不同的协议头
LASER_ON_CMD_ALT2 = bytes([0x55, 0xAA, 0x01])
LASER_OFF_CMD_ALT2 = bytes([0x55, 0xAA, 0x00])
print("\n" + "=" * 50)
print("🔧 步骤2: 初始化串口")
print("=" * 50)
print(f"设备: {UART_PORT}")
print(f"波特率: {BAUDRATE}")
try:
laser_uart = uart.UART(UART_PORT, BAUDRATE) laser_uart = uart.UART(UART_PORT, BAUDRATE)
print(f"串口初始化成功: {laser_uart}") print("硬件初始化")
except Exception as e: except Exception as e:
print(f"串口初始化失败: {e}") print(f"❌ 初始化失败: {e}")
exit(1) exit(1)
# ==================== 测试函数 ==================== # ==================== 根据文档的完整命令集 ====================
def send_and_check(cmd, name): # 1. 激光开关文档2.3.10,已验证可用)
"""发送命令并检查回包""" LASER_ON_CMD = bytes([0xAA, 0x00, 0x01, 0xBE, 0x00, 0x01, 0x00, 0x01, 0xC1])
print(f"\n📤 发送: {name}") LASER_OFF_CMD = bytes([0xAA, 0x00, 0x01, 0xBE, 0x00, 0x01, 0x00, 0x00, 0xC0])
print(f" 命令字节: {cmd.hex()}")
print(f" 命令长度: {len(cmd)} 字节") # 2. 尝试不同的测距命令格式
TEST_COMMANDS = [
# 清空接收缓冲区 # 格式1文档2.3.12的单次测量(您测试失败的)
{
"name": "单次测量 (0x0020)",
"cmd": bytes([0xAA, 0x00, 0x00, 0x20, 0x00, 0x01, 0x00, 0x00, 0x21]),
"desc": "文档2.3.12 示例命令"
},
# 格式2文档2.3.7的读取测量结果
{
"name": "读取测量结果 (0x0022)",
"cmd": bytes([0xAA, 0x80, 0x00, 0x22, 0xA2]),
"desc": "文档2.3.7 读取测量结果"
},
# 格式3文档2.3.13的快速测量
{
"name": "快速测量 (0x0022带数据)",
"cmd": bytes([0xAA, 0x00, 0x00, 0x22, 0x00, 0x01, 0x00, 0x00, 0x23]),
"desc": "文档2.3.13 快速测量"
},
# 格式4连续测量模式
{
"name": "连续测量模式 (0x0021)",
"cmd": bytes([0xAA, 0x00, 0x00, 0x21, 0x00, 0x01, 0x00, 0x00, 0x22]),
"desc": "文档2.3.14 连续测量"
}
]
def clear_buffer():
"""清空串口缓冲区"""
try: try:
old_data = laser_uart.read(-1) data = laser_uart.read(-1)
if old_data: if data:
print(f" 清空缓冲区: {len(old_data)} 字节") print(f"清空: {len(data)}字节")
except: except:
pass pass
def send_and_wait(cmd, name, wait_time=2000):
"""发送命令并等待响应"""
print(f"\n📤 发送: {name}")
print(f" 命令: {cmd.hex()}")
clear_buffer()
# 发送命令
try: try:
written = laser_uart.write(cmd) laser_uart.write(cmd)
print(f" 写入字节数: {written}") print(f" 已发送 {len(cmd)} 字节")
except Exception as e: except Exception as e:
print(f"写入失败: {e}") print(f"发送失败: {e}")
return None return None
# 等待响应 # 等待响应
time.sleep_ms(100) start_time = time.ticks_ms()
response = b""
# 读取回包 while time.ticks_ms() - start_time < wait_time:
try: try:
resp = laser_uart.read(50) chunk = laser_uart.read(1)
if resp: if chunk:
print(f" 📥 收到回包: {resp.hex()} ({len(resp)} 字节)") response += chunk
return resp # 完整响应通常是9或13字节
else: if len(response) >= 9:
print(f" ⚠️ 无回包") # 检查是否完整帧
return None if response[0] in [0xAA, 0xEE]:
except Exception as e: if len(response) >= 13: # 测距完整响应
print(f" ❌ 读取失败: {e}") break
return None elif response[0] == 0xEE: # 错误响应
break
except:
break
time.sleep_ms(10)
if response:
print(f" 📥 响应: {response.hex()}")
print(f" 长度: {len(response)} 字节")
# 解析错误码
if response[0] == 0xEE and len(response) >= 9:
err_code = (response[7] << 8) | response[8]
error_mapping = {
0x0000: "无错误",
0x0001: "硬件错误",
0x0002: "无输出数据",
0x0003: "反射信号太弱",
0x0004: "反射信号太强",
0x0005: "温度太高(>40℃)",
0x0006: "温度太低(<-10℃)",
0x0007: "电源电压低(<2.5V)",
0x0008: "超出量程",
0x0009: "读通讯错误",
0x000A: "写通讯错误",
0x000B: "地址错误"
}
err_msg = error_mapping.get(err_code, f"未知错误: 0x{err_code:04X}")
print(f" ❌ 模块错误: {err_msg}")
else:
print(" ⚠️ 无响应")
return response
def test_laser_cycle(on_cmd, off_cmd, cmd_name="标准命令"): def parse_distance_data(response):
"""测试一个开关周期""" """解析距离数据"""
print(f"\n{'='*50}") if not response or len(response) < 13:
print(f"🧪 测试 {cmd_name}") return None
print(f"{'='*50}")
print("\n>>> 测试开启激光") if response[0] != 0xAA or response[3] not in [0x20, 0x21, 0x22]:
send_and_check(on_cmd, f"{cmd_name} - 开启") return None
print(" ⏱️ 等待 2 秒观察激光是否亮起...")
time.sleep(2)
print("\n>>> 测试关闭激光") # 解析4字节BCD码
send_and_check(off_cmd, f"{cmd_name} - 关闭") bcd_bytes = response[6:10]
print(" ⏱️ 等待 2 秒观察激光是否熄灭...") distance_int = 0
time.sleep(2)
for byte in bcd_bytes:
high = (byte >> 4) & 0x0F
low = byte & 0x0F
if high > 9 or low > 9:
return None
distance_int = distance_int * 100 + high * 10 + low
distance_m = distance_int / 1000.0
# 信号质量
signal = 0
if len(response) >= 12:
signal = (response[10] << 8) | response[11]
return {
'meters': distance_m,
'millimeters': distance_m * 1000,
'signal': signal,
'raw': response.hex()
}
# ==================== 主测试 ==================== # ==================== 主测试 ====================
print("\n" + "=" * 50) print("\n" + "="*50)
print("🚀 开始激光测试") print("M01激光测距模块详细测试")
print("=" * 50) print("="*50)
print("\n请观察激光模块的状态变化...")
print("测试将依次尝试不同的命令格式\n")
try: try:
# 测试1: 标准命令 # 1. 测试基本连接
test_laser_cycle(LASER_ON_CMD, LASER_OFF_CMD, "标准命令") print("\n1. 测试模块连接...")
version_cmd = bytes([0xAA, 0x80, 0x00, 0x0A, 0x8A])
resp = send_and_wait(version_cmd, "读取硬件版本")
input("\n按回车继续测试备用命令1...") if resp and resp[0] == 0xAA and resp[3] == 0x0A:
print(f"✅ 模块正常,版本: {resp[6]:02X}{resp[7]:02X}")
else:
print("❌ 模块连接测试失败")
exit(1)
# 测试2: 备用命令格式1 # 2. 开启激光
test_laser_cycle(LASER_ON_CMD_ALT1, LASER_OFF_CMD_ALT1, "备用命令1 (简化)") print("\n2. 开启激光...")
resp = send_and_wait(LASER_ON_CMD, "开启激光", 1000)
if resp and resp.hex() == "aa0001be00010001c1":
print("✅ 激光已开启")
input("\n按回车继续测试备用命令2...") print(" 等待激光稳定...")
time.sleep(2) # 重要等待时间
# 测试3: 备用命令格式2 # 3. 尝试不同的测距命令
test_laser_cycle(LASER_ON_CMD_ALT2, LASER_OFF_CMD_ALT2, "备用命令2 (0x55AA头)") print("\n3. 测试不同测距命令...")
print("\n" + "=" * 50) for i, test_cmd in enumerate(TEST_COMMANDS):
print(f"\n{'='*30}")
print(f"测试 {i+1}: {test_cmd['name']}")
print(f"{test_cmd['desc']}")
print(f"{'='*30}")
resp = send_and_wait(test_cmd['cmd'], test_cmd['name'], 3000)
if resp:
if resp[0] == 0xAA and len(resp) >= 13:
result = parse_distance_data(resp)
if result:
print(f"✅ 测距成功!")
print(f" 距离: {result['meters']:.3f} m")
print(f" 距离: {result['millimeters']:.1f} mm")
print(f" 信号质量: {result['signal']}")
break
else:
print("❌ 无法解析距离数据")
elif resp[0] == 0xEE:
print("❌ 命令执行错误")
else:
print("❌ 无效响应格式")
else:
print("❌ 无响应")
time.sleep(1) # 命令间间隔
# 4. 关闭激光
print("\n4. 关闭激光...")
send_and_wait(LASER_OFF_CMD, "关闭激光", 1000)
print("\n" + "="*50)
print("🏁 测试完成") print("🏁 测试完成")
print("=" * 50) print("="*50)
print("\n诊断建议:")
print("1. 如果激光始终不亮/始终亮") print("\n📋 测试总结")
print(" - 检查激光模块的电源连接") print("1. 模块通信: ✅ 正常")
print(" - 检查串口TX/RX是否接反") print("2. 激光控制: ✅ 正常")
print(" - 尝试不同的波特率 (4800/19200)") print("3. 测距功能: ❌ 有问题")
print("") print("\n建议:")
print("2. 如果有回包但激光无反应:") print("1. 检查激光是否实际发光(在暗处观察红点)")
print(" - 命令格式可能正确但激光硬件问题") print("2. 确保测量目标在有效范围内0.2-60米")
print("") print("3. 确保目标有足够反射率(白色平面最佳)")
print("3. 如果某个备用命令有效:") print("4. 如果所有测距命令都返回ERR_ADDR可能是固件版本问题")
print(" - 需要更新 config.py 中的命令格式")
except KeyboardInterrupt: except KeyboardInterrupt:
print("\n\n🛑 测试被中断") print("\n\n🛑 用户中断")
# 确保激光关闭
laser_uart.write(LASER_OFF_CMD) laser_uart.write(LASER_OFF_CMD)
print("✅ 已发送关闭指令") print("✅ 已发送关闭指令")
except Exception as e: except Exception as e:
print(f"\n❌ 测试出错: {e}") print(f"\n❌ 测试出错: {e}")
import traceback
traceback.print_exc()

16
test/test_motor.py Normal file
View File

@@ -0,0 +1,16 @@
from maix import gpio, pinmap, time
# Configure pin A25 as an output and start from the low level.
led = gpio.GPIO("A25", gpio.Mode.OUT)
led.value(0)

# Drive the pin high for 5 s per iteration, forever.
# NOTE(review): there is no delay after the pin is driven low, so it goes high
# again immediately; the 1 s off-delay and the toggle() variant are disabled.
# Confirm whether an off period is intended.
while True:
    led.value(1)
    time.sleep_ms(5000)
    led.value(0)

View File

@@ -28,6 +28,11 @@ Stage2 ROI对齐「先检整靶再裁小图」的第二步输入
- --motion-prob施加概率--motion-kernel-min/max模糊 streak 长度奇数核越大越糊 - --motion-prob施加概率--motion-kernel-min/max模糊 streak 长度奇数核越大越糊
- 可与 --blur-max 高斯模糊叠加Stage2 建议--motion-prob 0.5~0.7 --motion-kernel-max 35 --blur-max 1.2 - 可与 --blur-max 高斯模糊叠加Stage2 建议--motion-prob 0.5~0.7 --motion-kernel-max 35 --blur-max 1.2
透视有两种模式（--perspective-mode）：
- corner：四角独立随机抖动（旧版），组合后易出现不物理的夸张梯形；--perspective 为 jitter 强度。
- planar：把靶当作平面，按 yaw/pitch/roll 随机旋转再投影，便于限定在例如偏航 ±45°、俯仰 ±30°，
更接近 5m 外拍 40cm 靶等场景；焦距用 --planar-focal-frac 控制视场。
依赖OpenCV + NumPyPC 上跑即可Maix 上若内存够也可试 依赖OpenCV + NumPyPC 上跑即可Maix 上若内存够也可试
示例 示例
@@ -266,6 +271,74 @@ def _paste_fg_on_bg(bg_bgr, x, y, fg_scaled_bgra):
roi_bg[:] = blended.astype(np.uint8) roi_bg[:] = blended.astype(np.uint8)
def _rotation_matrix_ypr(yaw_deg: float, pitch_deg: float, roll_deg: float, np) -> np.ndarray:
"""靶平面绕相机坐标原点旋转:先 yaw(Y),再 pitch(X),再 roll(Z)。单位:度。"""
y, p, r = np.radians([yaw_deg, pitch_deg, roll_deg])
cy, sy = np.cos(y), np.sin(y)
cp, sp = np.cos(p), np.sin(p)
cr, sr = np.cos(r), np.sin(r)
Ry = np.array([[cy, 0.0, sy], [0.0, 1.0, 0.0], [-sy, 0.0, cy]], dtype=np.float64)
Rx = np.array([[1.0, 0.0, 0.0], [0.0, cp, -sp], [0.0, sp, cp]], dtype=np.float64)
Rz = np.array([[cr, -sr, 0.0], [sr, cr, 0.0], [0.0, 0.0, 1.0]], dtype=np.float64)
return Rz @ Rx @ Ry
def _perspective_warp_planar_rgba(
img_bgra,
yaw_deg: float,
pitch_deg: float,
roll_deg: float,
focal_frac: float,
np,
cv2,
):
"""
平面靶透视 frontal 图像视作 z=1 平面上的针孔投影再旋转三维射线方向等价于靶相对相机倾斜
focal_frac焦距 fx=fy=max(w,h)*focal_frac越大视场越窄透视越温和
返回 (warped BGRA, M)退化时返回 (copy, None)
"""
h, w = img_bgra.shape[:2]
if min(w, h) < 16:
return img_bgra.copy(), None
fx = fy = float(max(w, h) * max(0.25, focal_frac))
cx, cy = w * 0.5, h * 0.5
R = _rotation_matrix_ypr(yaw_deg, pitch_deg, roll_deg, np)
uv_dst: list[list[float]] = []
z_min = 0.08
for u, v in ((0.0, 0.0), (float(w), 0.0), (float(w), float(h)), (0.0, float(h))):
x = (u - cx) / fx
y = (v - cy) / fy
z = 1.0
vec = R @ np.array([x, y, z], dtype=np.float64)
if float(vec[2]) < z_min:
return img_bgra.copy(), None
uu = fx * (vec[0] / vec[2]) + cx
vv = fy * (vec[1] / vec[2]) + cy
uv_dst.append([uu, vv])
pts_dst = np.float32(uv_dst)
xmin = float(pts_dst[:, 0].min())
ymin = float(pts_dst[:, 1].min())
pts_shift = pts_dst.copy()
pts_shift[:, 0] -= xmin
pts_shift[:, 1] -= ymin
out_w = max(4, int(np.ceil(float(pts_shift[:, 0].max()))) + 2)
out_h = max(4, int(np.ceil(float(pts_shift[:, 1].max()))) + 2)
pts_src = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
M = cv2.getPerspectiveTransform(pts_src, pts_shift)
warped = cv2.warpPerspective(
img_bgra,
M,
(out_w, out_h),
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=(0, 0, 0, 0),
)
return warped, M
def _perspective_warp_rgba(img_bgra, jitter_frac: float, rng: random.Random, np, cv2): def _perspective_warp_rgba(img_bgra, jitter_frac: float, rng: random.Random, np, cv2):
""" """
对前景做轻微透视四角微移返回 (warped BGRA, M) 对前景做轻微透视四角微移返回 (warped BGRA, M)
@@ -311,6 +384,26 @@ def _perspective_warp_rgba(img_bgra, jitter_frac: float, rng: random.Random, np,
return warped, M return warped, M
def _apply_perspective_warp_fg(fg_bgra, rng: random.Random, np, cv2, args) -> tuple:
"""在已由 perspective_prob 抽中的样本上做一次透视。返回 (fg_out, M|None)。"""
mode = getattr(args, "perspective_mode", "corner")
if mode == "planar":
x = rng.betavariate(2,2)
yaw = -float(args.yaw_max_deg) + x * (float(args.yaw_max_deg) +float(args.yaw_max_deg))
x = rng.betavariate(2,2)
pitch = -float(args.pitch_max_deg) + x * (float(args.pitch_max_deg) +float(args.pitch_max_deg))
x = rng.betavariate(2,2)
roll = -float(args.roll_max_deg) + x * (float(args.roll_max_deg) +float(args.roll_max_deg))
focal_frac = float(getattr(args, "planar_focal_frac", 1.25))
return _perspective_warp_planar_rgba(fg_bgra, yaw, pitch, roll, focal_frac, np, cv2)
jitter = float(getattr(args, "perspective", 0.0))
if jitter <= 0:
return fg_bgra.copy(), None
return _perspective_warp_rgba(fg_bgra, jitter, rng, np, cv2)
def _color_jitter_bgr(comp_bgr, strength: float, rng: random.Random, np, cv2): def _color_jitter_bgr(comp_bgr, strength: float, rng: random.Random, np, cv2):
"""整图 HSV 抖动strength∈[0,1] 越大越强。""" """整图 HSV 抖动strength∈[0,1] 越大越强。"""
if strength <= 1e-6: if strength <= 1e-6:
@@ -519,11 +612,17 @@ def main():
help="运动模糊 streak 长度上限,越大越像长曝光/手抖", help="运动模糊 streak 长度上限,越大越像长曝光/手抖",
) )
ap.add_argument("--jpeg-quality", type=int, default=92) ap.add_argument("--jpeg-quality", type=int, default=92)
ap.add_argument(
"--perspective-mode",
choices=("corner", "planar"),
default="corner",
help="corner=四角随机抖动易夸张planar=yaw/pitch/roll 平面投影(易限定视姿)",
)
ap.add_argument( ap.add_argument(
"--perspective", "--perspective",
type=float, type=float,
default=0.0, default=0.0,
help="轻微透视:四角扰动约为 min(靶宽,靶高)×该系数0 关闭(建议 0.02~0.06", help="仅 perspective-mode=corner:四角扰动min(靶宽,靶高)×该系数0 关闭(温和可用 0.015~0.03",
) )
ap.add_argument( ap.add_argument(
"--perspective-prob", "--perspective-prob",
@@ -531,6 +630,30 @@ def main():
default=0.75, default=0.75,
help="每张图应用透视的概率 0~1", help="每张图应用透视的概率 0~1",
) )
ap.add_argument(
"--yaw-max-deg",
type=float,
default=45.0,
help="planar偏航角均匀采样上界实际 ∈[-max,max]",
)
ap.add_argument(
"--pitch-max-deg",
type=float,
default=30.0,
help="planar俯仰角均匀采样上界",
)
ap.add_argument(
"--roll-max-deg",
type=float,
default=8.0,
help="planar滚转角均匀采样上界手持可略大",
)
ap.add_argument(
"--planar-focal-frac",
type=float,
default=1.25,
help="planarfx=fy=max(w,h)×该值,越大透视越温和(建议 1.1~1.8",
)
ap.add_argument( ap.add_argument(
"--color-jitter", "--color-jitter",
type=float, type=float,
@@ -662,8 +785,15 @@ def main():
fg_s = cv2.resize(fg_crop, (new_w, new_h), interpolation=cv2.INTER_AREA) fg_s = cv2.resize(fg_crop, (new_w, new_h), interpolation=cv2.INTER_AREA)
persp_M = None persp_M = None
if args.perspective > 0 and rng.random() < args.perspective_prob: want_p = (
fg_s, persp_M = _perspective_warp_rgba(fg_s, args.perspective, rng, np, cv2) rng.random() < args.perspective_prob
and (
args.perspective_mode == "planar"
or (args.perspective_mode == "corner" and args.perspective > 0)
)
)
if want_p:
fg_s, persp_M = _apply_perspective_warp_fg(fg_s, rng, np, cv2, args)
fw2, fh2 = fg_s.shape[1], fg_s.shape[0] fw2, fh2 = fg_s.shape[1], fg_s.shape[0]
tx0, ty0, tw, th = _fg_bbox_from_alpha(fg_s) tx0, ty0, tw, th = _fg_bbox_from_alpha(fg_s)

View File

@@ -0,0 +1,506 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
YOLO11 关键点检测训练脚本(靶纸四角)。
设备优先级(--device autoIntel XPU > NVIDIA CUDA > CPU。
默认 imgsz=960批大小默认 4大图显存紧张时可再降
关于「业务像素误差」:
Ultralytics 没有在 yaml 里设定「像素阈值」的选项;反向传播仍由 pose/kobj/box 等内部 loss 驱动。
- 监控:--pixel-metrics-every N每 N 个 epoch 打印 mean/p95并合并进 runs/.../results.csv见 pose_pixel_metrics.py
- 选 best.pt / early stopping加 --best-by-pixel用验证集 mean 像素误差(与 pose_pixel_metrics
同一口径)代替 mAP 合成 fitnessfitness = -mean_px越小越好
多卡 DDPworld_size>1时会自动退回默认 mAP fitness。
XPU：Ultralytics BaseTrainer._get_memory / _clear_memory 把非 MPS、非 CPU 一律当 CUDA，
会在验证前调用 torch.cuda 而报错；本脚本在选用 XPU 时自动打补丁（见 _patch_ultralytics_for_xpu）。
务必使用 pose 任务YOLO(...) 与 model.train(...) 均指定 task='pose'。若误用默认 detect
会把 17 列 Pose 标注当成检测/分割解析校验时出现「coordinates > 1」或 [2.] 等假象。
"""
from __future__ import annotations
import argparse
import csv
import gc
import glob
import os
import tempfile
from copy import deepcopy
from pathlib import Path
import torch
from ultralytics import YOLO
from pose_pixel_metrics import eval_val_pixel_error
import warnings
warnings.filterwarnings('ignore',
message=".*scatter_add_kernel does not have a deterministic implementation.*")
def _clear_ultralytics_label_caches(data_yaml_path: str) -> int:
    """Delete ``labels/*.cache`` under the dataset ``path`` named in data.yaml.

    Rationale carried over from the original note: the Ultralytics label-cache
    hash is said to depend only on label/image path strings and summed file
    sizes, not on file contents, so a stale cache can survive a label fix.

    Args:
        data_yaml_path: Path to the dataset YAML file.

    Returns:
        Number of cache files removed. Returns 0 when the YAML cannot be
        loaded or has no ``path`` entry.
    """
    from ultralytics.utils import YAML

    try:
        dataset_cfg = YAML.load(data_yaml_path)
    except Exception:
        return 0

    dataset_root = dataset_cfg.get("path")
    if not dataset_root:
        return 0
    dataset_root = os.path.abspath(os.path.expanduser(str(dataset_root)))

    removed = 0
    for cache_file in glob.glob(os.path.join(dataset_root, "labels", "*.cache")):
        try:
            os.unlink(cache_file)
        except OSError:
            # Best effort: a file that cannot be removed is simply not counted.
            continue
        removed += 1
    return removed
def _pick_device(explicit: str | None):
"""返回 ultralytics train/predict 可用的 device。"""
if explicit and explicit != "auto":
e = explicit.lower()
if e == "xpu":
if getattr(torch, "xpu", None) is None or not torch.xpu.is_available():
raise RuntimeError("指定了 --device xpu 但当前环境不可用")
return torch.device("xpu")
if e in ("0", "cuda", "gpu"):
if not torch.cuda.is_available():
raise RuntimeError("指定了 CUDA 但不可用")
return 0
if e == "cpu":
return "cpu"
return explicit
if getattr(torch, "xpu", None) is not None and torch.xpu.is_available():
return torch.device("xpu")
if torch.cuda.is_available():
return 0
return "cpu"
def _default_amp(device) -> bool:
if isinstance(device, torch.device) and device.type == "xpu":
return False
if device == "cpu":
return False
return True
def _patch_ultralytics_for_xpu():
    """Monkey-patch Ultralytics so training and validation can run on Intel XPU.

    Three things are patched:

    1. ``select_device`` in ``torch_utils`` and the references already bound
       in the trainer and validator modules, so an ``"xpu"`` / ``"xpu:N"``
       string resolves to a ``torch.device`` instead of reaching the original.
    2. ``BaseTrainer._get_memory`` / ``_clear_memory``, which use
       ``torch.xpu`` when the trainer's device is XPU and defer to the
       original implementations otherwise.
    3. The same two functions are attached to ``BaseValidator``.

    Steps 2 and 3 are each guarded by a class-level flag so they are applied
    at most once per process.
    """
    import ultralytics.engine.trainer as ut_trainer
    import ultralytics.engine.validator as ut_validator
    from ultralytics.utils.torch_utils import select_device as _original_select_device

    # 1. select_device. Per the original note: the trainer accepts a
    #    torch.device("xpu") at init, but args.device later becomes the string
    #    "xpu" and the final evaluation's validator calls select_device("xpu").
    #    The trainer/validator modules bound the function at import time, so
    #    patching torch_utils alone is not enough.
    def _patched_select_device(device="", *args, **kwargs):
        # Extra positional/keyword arguments are forwarded untouched so the
        # wrapper works across select_device signatures.
        if isinstance(device, str):
            d = device.strip().lower()
            if d == "xpu" or d.startswith("xpu:"):
                # Use the normalised string: the check above ignores case, so
                # the device must be built from the lower-cased form too.
                return torch.device(d)
        return _original_select_device(device, *args, **kwargs)

    import ultralytics.utils.torch_utils

    ultralytics.utils.torch_utils.select_device = _patched_select_device
    ut_trainer.select_device = _patched_select_device
    ut_validator.select_device = _patched_select_device

    # 2. Trainer memory helpers.
    BT = ut_trainer.BaseTrainer
    if not getattr(BT, "_archery_xpu_memory_patched", False):
        _orig_get_memory = BT._get_memory
        _orig_clear_memory = BT._clear_memory

        def _get_memory(self, fraction=False):
            """Return allocated XPU memory in GiB, or as a fraction of total."""
            if self.device.type != "xpu":
                return _orig_get_memory(self, fraction)
            memory, total = 0, 0
            try:
                idx = self.device.index
                if idx is None:
                    idx = torch.xpu.current_device()
                memory = int(torch.xpu.memory_allocated(idx))
                if fraction:
                    total = int(torch.xpu.get_device_properties(idx).total_memory)
            except Exception:
                # Best effort: report zero if the XPU runtime cannot be queried.
                pass
            return (memory / total) if fraction and total > 0 else (memory / 2**30)

        def _clear_memory(self, threshold=None):
            """Collect garbage and empty the XPU cache, optionally only above ``threshold``."""
            if self.device.type != "xpu":
                return _orig_clear_memory(self, threshold)
            if threshold is not None:
                assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
                if self._get_memory(fraction=True) <= threshold:
                    return
            gc.collect()
            if hasattr(torch.xpu, "empty_cache"):
                torch.xpu.empty_cache()

        BT._get_memory = _get_memory
        BT._clear_memory = _clear_memory
        BT._archery_xpu_memory_patched = True

    # 3. Validator memory helpers.
    # NOTE(review): these reuse the closures created in step 2, so they only
    # exist when step 2 ran in this call, and their non-XPU fallback is the
    # trainer's original implementation. Confirm BaseValidator uses these hooks.
    BV = ut_validator.BaseValidator
    if not getattr(BV, "_archery_xpu_memory_patched", False):
        BV._get_memory = _get_memory
        BV._clear_memory = _clear_memory
        BV._archery_xpu_memory_patched = True
def _install_best_by_pixel_validate(data_yaml: str, imgsz: int, conf: float) -> None:
    """Replace ``BaseTrainer.validate`` so fitness is ``-mean_px`` on the val set.

    The replacement runs the normal validator, then (single process only)
    saves the EMA/model weights to a temporary checkpoint, loads it with
    ``YOLO`` and measures the mean keypoint pixel error via
    ``eval_val_pixel_error``. If that succeeds, fitness is ``-mean_px`` and
    the value is added to the metrics as ``metrics/mean_px(val)``; otherwise
    the validator's own fitness is used for that epoch.

    Installed at most once per process (guarded by a class-level flag).

    Args:
        data_yaml: Dataset YAML passed to the pixel-error evaluation.
        imgsz: Image size passed to the pixel-error evaluation.
        conf: Confidence threshold passed to the pixel-error evaluation.

    NOTE(review): a failed probe falls back to the validator's fitness, which
    is on a different scale from ``-mean_px``; one fallback epoch may then
    outrank every pixel-based epoch. Confirm this is acceptable.
    NOTE(review): fitness is negative here; verify how Ultralytics' early
    stopper initialises and compares its best fitness before relying on
    ``patience`` in this mode.
    """
    import ultralytics.engine.trainer as ut
    from ultralytics.utils import RANK

    BT = ut.BaseTrainer
    if getattr(BT, "_archery_best_by_pixel_installed", False):
        return
    # Kept for reference; the replacement below does not call it.
    _orig_validate = BT.validate

    def validate(self):
        import torch.distributed as dist

        # Multi-process runs: sync EMA buffers from rank 0 before validating.
        if self.ema and self.world_size > 1:
            for buffer in self.ema.ema.buffers():
                dist.broadcast(buffer, src=0)
        metrics = self.validator(self)
        if metrics is None:
            return None, None
        # The validator's own fitness; default is the negated training loss.
        orig_fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())
        # The pixel probe only runs in single-process training.
        use_pixel = self.world_size <= 1 and RANK in {-1, 0}
        mean_px: float | None = None
        if use_pixel:
            tmp_path: str | None = None
            try:
                fd, tmp_path = tempfile.mkstemp(suffix=".pt", prefix="archery_pxfit_")
                os.close(fd)
                from ultralytics.utils.torch_utils import unwrap_model

                # Save a half-precision copy of the EMA (or raw) model so it
                # can be loaded as a standalone checkpoint.
                core = unwrap_model(self.ema.ema if self.ema else self.model)
                torch.save({"ema": deepcopy(core).half(), "train_args": vars(self.args)}, tmp_path)
                probe = YOLO(tmp_path)
                stats = eval_val_pixel_error(
                    probe,
                    data_yaml,
                    device=self.device,
                    imgsz=imgsz,
                    conf=conf,
                )
                mean_px = stats.get("mean_px")
                if mean_px is None:
                    raise RuntimeError("无有效 mean_px检查 val 标签与检测是否为空)")
            except Exception as exc:
                # Any probe failure falls back to the validator's fitness.
                print(f"\n⚠️ [best-by-pixel] 像素探针失败，本 epoch 仍用 mAP fitness: {exc}\n")
                mean_px = None
            finally:
                # Always remove the temporary checkpoint.
                if tmp_path:
                    try:
                        os.unlink(tmp_path)
                    except OSError:
                        pass
        if mean_px is not None:
            # Lower pixel error is better, so negate it for use as fitness.
            fitness = -float(mean_px)
            metrics["metrics/mean_px(val)"] = float(mean_px)
        else:
            fitness = float(orig_fitness)
        if not self.best_fitness or self.best_fitness < fitness:
            self.best_fitness = fitness
        return metrics, fitness

    BT.validate = validate
    BT._archery_best_by_pixel_installed = True
def _fmt_csv_metric(v: float | int | None) -> str:
if v is None:
return ""
if isinstance(v, float):
return f"{v:.6g}"
return str(v)
# Columns written into results.csv, as (CSV column name, stats key) pairs.
# The names are kept distinct from the metrics/mean_px(val) column added by
# --best-by-pixel so the two sources do not overwrite each other.
_PIXEL_METRIC_COLUMNS: tuple[tuple[str, str], ...] = (
    ("pixel_error/mean_px", "mean_px"),
    ("pixel_error/median_px", "median_px"),
    ("pixel_error/p95_px", "p95_px"),
    ("pixel_error/max_px", "max_px"),
    ("pixel_error/n_points", "n_points"),
    ("pixel_error/n_images", "n_images"),
    ("pixel_error/skip_no_det", "skip_no_det"),
    ("pixel_error/skip_no_gt", "skip_no_gt"),
    ("pixel_error/skip_kpt_mismatch", "skip_kpt_mismatch"),
)
def _merge_pixel_metrics_into_results_csv(save_dir: str | Path, epoch_1based: int, stats: dict) -> None:
"""在 Ultralytics 写完本 epoch 行之后,把像素指标列合并进 results.csv扩展表头、补空列"""
csv_path = Path(save_dir) / "results.csv"
if not csv_path.is_file():
return
try:
with open(csv_path, newline="", encoding="utf-8") as f:
rows = list(csv.reader(f))
except OSError:
return
if len(rows) < 2:
return
header = list(rows[0])
for col_name, _ in _PIXEL_METRIC_COLUMNS:
if col_name not in header:
header.append(col_name)
for ri in range(1, len(rows)):
rows[ri].append("")
col_ix = {name: i for i, name in enumerate(header)}
rows[0] = header
target_ri: int | None = None
for ri in range(1, len(rows)):
row = rows[ri]
while len(row) < len(header):
row.append("")
try:
if int(float(row[0].strip())) == int(epoch_1based):
target_ri = ri
except (ValueError, IndexError):
continue
if target_ri is None:
return
row = rows[target_ri]
while len(row) < len(header):
row.append("")
for col_name, sk in _PIXEL_METRIC_COLUMNS:
row[col_ix[col_name]] = _fmt_csv_metric(stats.get(sk))
try:
with open(csv_path, "w", newline="", encoding="utf-8") as f:
w = csv.writer(f)
w.writerows(rows)
except OSError:
pass
def _make_pixel_metrics_callback(data_yaml: str, every: int, imgsz: int, conf: float = 0.25):
def on_fit_epoch_end(trainer):
from ultralytics.utils import RANK
if RANK not in {-1, 0}:
return
if every <= 0:
return
ep = int(getattr(trainer, "epoch", -1))
if (ep + 1) % every != 0:
return
w = Path(trainer.save_dir) / "weights" / "last.pt"
if not w.is_file():
return
m = YOLO(str(w))
stats = eval_val_pixel_error(
m,
data_yaml,
device=trainer.device,
imgsz=imgsz,
conf=conf,
)
mean_px = stats.get("mean_px")
p95_px = stats.get("p95_px")
mean_s = f"{mean_px:.3f}" if mean_px is not None else "n/a"
p95_s = f"{p95_px:.3f}" if p95_px is not None else "n/a"
print(
f"\n[pixel-metrics] epoch {ep + 1}: mean_px={mean_s} p95_px={p95_s} "
f"n_points={stats.get('n_points', 0)} "
f"skip(det/gt/k)={stats['skip_no_det']}/{stats['skip_no_gt']}/{stats['skip_kpt_mismatch']}\n"
)
_merge_pixel_metrics_into_results_csv(trainer.save_dir, ep + 1, stats)
return on_fit_epoch_end
def main():
    """Parse the CLI, pick a device, train the YOLO pose model and optionally export ONNX.

    Steps: resolve the device and AMP default, apply the XPU patch when an
    XPU device is used, resolve the dataset YAML (relative paths are taken
    relative to this script's directory), optionally clear label caches,
    install the optional pixel-error hooks, train with fixed hyper-parameters,
    then print the weight locations and export ONNX if requested.

    Returns early without training when the dataset YAML does not exist.
    """
    ap = argparse.ArgumentParser(description="YOLO Pose 训练XPU/CUDA/CPU")
    ap.add_argument("--data", default="datasets/dataset_pose.yaml", help="data.yaml")
    ap.add_argument("--model", default="yolo11x-pose.pt", help="预训练权重")
    ap.add_argument("--epochs", type=int, default=100)
    ap.add_argument("--imgsz", type=int, default=960, help="训练输入边长（默认 960")
    ap.add_argument("--batch", type=int, default=4, help="批大小OOM 时减小")
    ap.add_argument(
        "--device",
        default="auto",
        help="auto | xpu | 0 | cuda | cpuautoXPU 优先)",
    )
    ap.add_argument(
        "--no-amp",
        action="store_true",
        help="关闭混合精度默认CUDA 开启XPU/CPU 关闭)",
    )
    ap.add_argument("--project", default="runs/pose")
    ap.add_argument("--name", default="target_pose_train")
    ap.add_argument("--workers", type=int, default=4)
    ap.add_argument(
        "--pixel-metrics-every",
        type=int,
        default=0,
        help="每 N 个 epoch 在 val 上打印像素误差并写入 results.csv 对应 epoch 行0=关闭）；需 labels 与 data.yaml 布局一致",
    )
    ap.add_argument(
        "--pixel-metrics-conf",
        type=float,
        default=0.25,
        help="--pixel-metrics-every 时 predict 置信度阈值（默认 0.25",
    )
    ap.add_argument(
        "--best-by-pixel",
        action="store_true",
        help="best.pt 与 early stopping 按验证集 mean 像素误差（同 pose_pixel_metricsfitness=-mean_px单卡有效DDP 自动退回 mAP",
    )
    ap.add_argument(
        "--pixel-fitness-conf",
        type=float,
        default=0.25,
        help="--best-by-pixel 时 predict 置信度阈值（默认与 pixel-metrics 一致）",
    )
    ap.add_argument(
        "--export-onnx",
        action="store_true",
        help="训练结束后导出 ONNX需再设 --onnx-imgsz",
    )
    ap.add_argument(
        "--onnx-imgsz",
        type=int,
        nargs=2,
        metavar=("H", "W"),
        default=[224, 320],
        help="导出 ONNX 的 [高, 宽]，默认 224 320Maix 常用)",
    )
    ap.add_argument(
        "--clear-label-cache",
        action="store_true",
        help="启动训练前删除 data.yaml 中 path 下的 labels/*.cache修正标注后仍报 corrupt 时用)",
    )
    args = ap.parse_args()

    device = _pick_device(None if args.device == "auto" else args.device)
    use_amp = False if args.no_amp else _default_amp(device)

    # Report the chosen device. Pass-through values (for example "cuda:1" or
    # "mps") get their own line instead of being reported as CPU.
    using_xpu = isinstance(device, torch.device) and device.type == "xpu"
    if using_xpu:
        print(f"✅ 使用 Intel XPU: {device}")
    elif device == 0 or device == "0":
        print(f"✅ 使用 CUDA: {torch.cuda.get_device_name(0)}")
    elif device == "cpu":
        print("⚠️ 使用 CPU训练会较慢")
    else:
        print(f"✅ 使用设备: {device}")
    if using_xpu:
        _patch_ultralytics_for_xpu()

    # A relative --data path is resolved against this script's directory.
    data_yaml = args.data
    if not os.path.isabs(data_yaml):
        data_yaml = os.path.join(os.path.dirname(os.path.abspath(__file__)), data_yaml)
    if not os.path.exists(data_yaml):
        print(f"❌ 数据集配置不存在: {data_yaml}")
        return

    if args.clear_label_cache:
        n_rm = _clear_ultralytics_label_caches(data_yaml)
        print(f"🗑️ 已删除标签目录缓存 {n_rm}labels/*.cache将强制重新扫描标注。")

    print(f"📦 加载模型: {args.model}（固定 task=pose")
    model = YOLO(args.model, task="pose")

    if args.best_by_pixel:
        _install_best_by_pixel_validate(data_yaml, args.imgsz, args.pixel_fitness_conf)
        print(
            "📌 已启用 --best-by-pixelbest.pt / patience 按验证集 mean 像素误差fitness=-mean_px"
            "反向传播仍为 Ultralytics 默认 pose/box loss。"
        )
    if args.pixel_metrics_every > 0:
        model.add_callback(
            "on_fit_epoch_end",
            _make_pixel_metrics_callback(
                data_yaml, args.pixel_metrics_every, args.imgsz, conf=args.pixel_metrics_conf
            ),
        )

    model.train(
        task="pose",
        data=data_yaml,
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        name=args.name,
        project=args.project,
        exist_ok=True,
        save=True,
        save_period=5,
        device=device,
        workers=args.workers,
        lr0=0.0001,
        lrf=0.01,
        optimizer="AdamW",
        momentum=0.937,
        weight_decay=0.001,
        warmup_epochs=0,
        warmup_momentum=0.8,
        warmup_bias_lr=0.1,
        hsv_h=0.015,
        hsv_s=0.7,
        hsv_v=0.4,
        degrees=5.0,
        translate=0.0,
        scale=0.2,
        shear=0.0,
        perspective=0.0000,
        flipud=0.0,
        fliplr=0.5,
        mosaic=0.0,
        mixup=0.0,
        copy_paste=0.0,
        box=6,
        cls=0.5,
        dfl=1.5,
        pose=18.0,
        kobj=0.5,
        freeze=0,
        seed=42,
        verbose=True,
        amp=use_amp,
        patience=100,
        cos_lr=True,
    )

    # Prefer the directory the trainer actually wrote to; it can differ from
    # "<project>/<name>". Fall back to the configured names if the trainer
    # does not expose it.
    save_dir = getattr(getattr(model, "trainer", None), "save_dir", None)
    if save_dir:
        weights_dir = Path(save_dir) / "weights"
        best_path = str(weights_dir / "best.pt")
        last_path = str(weights_dir / "last.pt")
    else:
        best_path = f"{args.project}/{args.name}/weights/best.pt"
        last_path = f"{args.project}/{args.name}/weights/last.pt"

    print("\n✅ 训练完成!")
    print(f"📁 best: {best_path}")
    print(f"📁 last: {last_path}")
    print("📊 仅看像素误差可运行: python pose_pixel_metrics.py --model <best.pt> --data <yaml> --imgsz", args.imgsz)

    if args.export_onnx:
        h, w = args.onnx_imgsz
        print(f"📦 导出 ONNX imgsz=[{h}, {w}] ...")
        model.export(format="onnx", imgsz=[h, w], simplify=True, opset=17, dynamic=False)
        print("✅ ONNX 完成")
if __name__ == "__main__":
main()