From 541418fd606cdbcd7c9b415a2ecd7dcfd20920f5 Mon Sep 17 00:00:00 2001 From: gcw_4spBpAfv Date: Fri, 15 May 2026 09:35:53 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E8=AE=AD=E7=BB=83yolo?= =?UTF-8?q?=E7=9A=84=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- adc.py | 11 +- app.yaml | 4 +- design_doc/command_record.md | 8 + test/test_audio.py | 50 ++ test/test_button.py | 25 + test/test_cammera.py | 5 +- test/test_laser.py | 332 +++++++----- test/test_motor.py | 16 + {test => train_yolo}/synth_compose_yolo.py | 136 ++++- .../test_stage2_black_yolo_device.py | 0 .../test_triangle_one_image.py | 0 {test => train_yolo}/test_yolo_draw_boxes.py | 0 train_yolo/train_black_triangle_pos.py | 506 ++++++++++++++++++ 13 files changed, 953 insertions(+), 140 deletions(-) create mode 100644 test/test_audio.py create mode 100644 test/test_button.py create mode 100644 test/test_motor.py rename {test => train_yolo}/synth_compose_yolo.py (85%) rename {test => train_yolo}/test_stage2_black_yolo_device.py (100%) rename {test => train_yolo}/test_triangle_one_image.py (100%) rename {test => train_yolo}/test_yolo_draw_boxes.py (100%) create mode 100644 train_yolo/train_black_triangle_pos.py diff --git a/adc.py b/adc.py index 31c2815..0584e63 100644 --- a/adc.py +++ b/adc.py @@ -4,14 +4,13 @@ from maix import time a = adc.ADC(0, adc.RES_BIT_12) while True: - raw_data = a.read() - print(f"ADC raw data:{raw_data}") + # raw_data = a.read() + # print(f"ADC raw data:{raw_data}") # if raw_data > 2450: # print(f"ADC raw data:{raw_data}") # elif raw_data < 2000: # print(f"ADC raw data:{raw_data}") - # time.sleep_ms(50) + time.sleep_ms(1) - # vol = a.read_vol() - - # print(f"ADC vol:{vol}") \ No newline at end of file + vol = int(a.read_vol() * 10) / 10 + print(f"ADC vol:{vol:.1f}, {time.time():.4f}") diff --git a/app.yaml b/app.yaml index 705f9a6..691366e 100644 --- a/app.yaml +++ b/app.yaml @@ -1,11 +1,12 @@ id: t11 name: 
t11 -version: 1.2.11 +version: 1.2.12 author: t11 icon: '' desc: t11 files: - 4g_download_manager.py + - 4g_upload_manager.py - app.yaml - archery_netcore.cpython-311-riscv64-linux-gnu.so - at_client.py @@ -34,3 +35,4 @@ files: - vision.py - wifi_config_httpd.py - wifi.py + - wpa_supplicant_conf.py diff --git a/design_doc/command_record.md b/design_doc/command_record.md index 15efb6f..2ead611 100644 --- a/design_doc/command_record.md +++ b/design_doc/command_record.md @@ -90,5 +90,13 @@ opencv_interactive-calibration -t=chessboard -w=9 -h=6 -sz=0.025 -v="http://192. D:\data\test_target_photo 是用来叠加的背景图 7.1 生成靶纸及黑色三角形的截图的图片,带动动,但1.12的外框 +bak python .\synth_compose_yolo.py --perspective 0.04 --perspective-prob 0.8 --color-jitter 0.6 --bg-dir D:\data\test_target_photo --fg D:\code\shooting\target_photo\write.png --out ./synth_out --class-name triangle --zip ./maix_dataset.zip --num 60 --triangles-json archery_triangles_default.json --format voc --stage2-crop --stage2-pad-min 0.03 --stage2-pad-max 0.18 --motion-prob 0.9 --motion-kernel-max 8 --blur-max 0 --triangle-bbox-pad-frac 0.12 +bak_2 +python synth_keypoints_right_angle.py --bg-dir D:\data\test_target_photo --fg D:\code\shooting\target_photo\write.png --triangles-json archery_triangles_default.json --out ./synth_out --num 1000 --offscreen-shift-prob 0.3 --offscreen-shift-frac 0.4 --offscreen-min-visible 1 --stage2-crop --stage2-pad-min 0.03 --stage2-pad-max 0.18 --motion-prob 0.9 --motion-kernel-max 8 --blur-max 0 --perspective-mode planar --yaw-max-deg 10 --pitch-max-deg 8 --roll-max-deg 4 --planar-focal-frac 1.45 --perspective-prob 0.4 + +python synth_keypoints_right_angle.py --bg-dir D:\data\test_target_photo --fg D:\code\shooting\target_photo\write.png --triangles-json archery_triangles_default.json --out ./synth_out --num 1000 --offscreen-shift-prob 0.3 --offscreen-shift-frac 0.4 --offscreen-min-visible 1 --stage2-crop --stage2-pad-min 0.03 --stage2-pad-max 0.18 --motion-prob 1.0 --motion-kernel-max 8 
--blur-max 0 --perspective-mode planar --yaw-max-deg 10 --pitch-max-deg 8 --roll-max-deg 4 --planar-focal-frac 1.45 --perspective-prob 0.4 + + + python pose_pixel_metrics.py --model D:\code\archery\runs\pose\runs\pose\target_pose_train\weights\best.pt --data D:\code\archery\datasets\dataset_pose.yaml --imgsz 640 \ No newline at end of file diff --git a/test/test_audio.py b/test/test_audio.py new file mode 100644 index 0000000..c299f63 --- /dev/null +++ b/test/test_audio.py @@ -0,0 +1,50 @@ +# test_audio.pyx +from maix import audio, time, app, gpio + +def run_player_loop(): + """ + 播放控制主循环函数 + """ + # 初始化音频播放器 + p = audio.Player("/root/gun.wav") + p.volume(40) + + # 初始化 GPIO 引脚为输出 + led = gpio.GPIO("A25", gpio.Mode.OUT) + # 设置低电平 + led.value(0) + + # 主循环 + while not app.need_exit(): + led.value(1) # 点亮 LED + time.sleep_ms(200) # 保持 200ms + led.value(0) # 熄灭 LED + p.play() # 播放音频 + time.sleep_ms(1000) # 等待 1 秒 + + print("play finish!") + + +# 可选:添加一个简单的测试函数 +def hello(): + return "Hello from test_audio!" + + +# 可选:添加一个初始化函数 +def init_led(): + """单独测试 GPIO""" + led = gpio.GPIO("A25", gpio.Mode.OUT) + led.value(0) + return "LED initialized" + + +# 可选:添加一个播放测试函数 +def test_play(): + """单独测试音频播放""" + p = audio.Player("/root/gun.wav") + p.volume(50) + p.play() + return "Playing..." 
+ + +run_player_loop() diff --git a/test/test_button.py b/test/test_button.py new file mode 100644 index 0000000..6fa118b --- /dev/null +++ b/test/test_button.py @@ -0,0 +1,25 @@ +from maix import audio, time, app,gpio + + +# button1 = gpio.GPIO("ADC", gpio.Mode.IN) +button3 = gpio.GPIO("A26", gpio.Mode.IN) # 可用 +button2 = gpio.GPIO("A16", gpio.Mode.IN) +#设置低电平 +from maix.peripheral import adc +channel = 0 +res_bit = adc.RES_BIT_12 +_adc_obj = adc.ADC(channel, res_bit) + + +while not app.need_exit(): + # print(f"b1: {button1.value()}") + + print(f"b2: {button2.value()}") + + # print(_adc_obj.read_vol()) + print(f"b3: {button3.value()}") + time.sleep_ms(50) + + # time.sleep_ms(1000) + + diff --git a/test/test_cammera.py b/test/test_cammera.py index f28eb14..e07de75 100644 --- a/test/test_cammera.py +++ b/test/test_cammera.py @@ -3,7 +3,10 @@ from maix import camera, display, time try: print("Initializing camera...") - cam = camera.Camera(640, 480) + cam = camera.Camera(640,480) + # cam = camera.Camera(1280,720) + # cam.get_exposure_us() + # print("Camera exposure: ", cam.get_exposure_us()) print("Camera initialized successfully!") disp = display.Display() diff --git a/test/test_laser.py b/test/test_laser.py index e54d743..a20a379 100644 --- a/test/test_laser.py +++ b/test/test_laser.py @@ -1,172 +1,246 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -激光模块测试脚本 -用于诊断激光开关问题 - -使用方法: -python test_laser.py - -功能: -1. 初始化串口 -2. 循环测试激光开/关 -3. 
打印详细调试信息 +M01激光测距模块测试脚本 - 修正版 +基于文档中的完整命令示例 """ from maix import uart, pinmap, time +import binascii # ==================== 配置 ==================== -UART_PORT = "/dev/ttyS1" # 激光模块连接的串口(UART1) -BAUDRATE = 9600 # 波特率 - -# 引脚映射(确保与硬件连接一致) -print("=" * 50) -print("🔧 步骤1: 配置引脚映射") -print("=" * 50) +UART_PORT = "/dev/ttyS1" +BAUDRATE = 9600 +# 初始化串口 try: pinmap.set_pin_function("A18", "UART1_RX") - print("✅ A18 -> UART1_RX") -except Exception as e: - print(f"❌ A18 配置失败: {e}") - -try: pinmap.set_pin_function("A19", "UART1_TX") - print("✅ A19 -> UART1_TX") -except Exception as e: - print(f"❌ A19 配置失败: {e}") - -# ==================== 激光控制指令 ==================== -MODULE_ADDR = 0x00 - -# 原始命令 -LASER_ON_CMD = bytes([0xAA, MODULE_ADDR, 0x01, 0xBE, 0x00, 0x01, 0x00, 0x01, 0xC1]) -LASER_OFF_CMD = bytes([0xAA, MODULE_ADDR, 0x01, 0xBE, 0x00, 0x01, 0x00, 0x00, 0xC0]) - -# 备用命令格式(如果原始命令不工作,可以尝试这些) -# 格式1: 简化命令 -LASER_ON_CMD_ALT1 = bytes([0xAA, 0x01, 0x01]) -LASER_OFF_CMD_ALT1 = bytes([0xAA, 0x01, 0x00]) - -# 格式2: 不同的协议头 -LASER_ON_CMD_ALT2 = bytes([0x55, 0xAA, 0x01]) -LASER_OFF_CMD_ALT2 = bytes([0x55, 0xAA, 0x00]) - -print("\n" + "=" * 50) -print("🔧 步骤2: 初始化串口") -print("=" * 50) -print(f"设备: {UART_PORT}") -print(f"波特率: {BAUDRATE}") - -try: laser_uart = uart.UART(UART_PORT, BAUDRATE) - print(f"✅ 串口初始化成功: {laser_uart}") + print("✅ 硬件初始化完成") except Exception as e: - print(f"❌ 串口初始化失败: {e}") + print(f"❌ 初始化失败: {e}") exit(1) -# ==================== 测试函数 ==================== -def send_and_check(cmd, name): - """发送命令并检查回包""" - print(f"\n📤 发送: {name}") - print(f" 命令字节: {cmd.hex()}") - print(f" 命令长度: {len(cmd)} 字节") - - # 清空接收缓冲区 +# ==================== 根据文档的完整命令集 ==================== +# 1. 激光开关(文档2.3.10,已验证可用) +LASER_ON_CMD = bytes([0xAA, 0x00, 0x01, 0xBE, 0x00, 0x01, 0x00, 0x01, 0xC1]) +LASER_OFF_CMD = bytes([0xAA, 0x00, 0x01, 0xBE, 0x00, 0x01, 0x00, 0x00, 0xC0]) + +# 2. 
尝试不同的测距命令格式 +TEST_COMMANDS = [ + # 格式1:文档2.3.12的单次测量(您测试失败的) + { + "name": "单次测量 (0x0020)", + "cmd": bytes([0xAA, 0x00, 0x00, 0x20, 0x00, 0x01, 0x00, 0x00, 0x21]), + "desc": "文档2.3.12 示例命令" + }, + # 格式2:文档2.3.7的读取测量结果 + { + "name": "读取测量结果 (0x0022)", + "cmd": bytes([0xAA, 0x80, 0x00, 0x22, 0xA2]), + "desc": "文档2.3.7 读取测量结果" + }, + # 格式3:文档2.3.13的快速测量 + { + "name": "快速测量 (0x0022带数据)", + "cmd": bytes([0xAA, 0x00, 0x00, 0x22, 0x00, 0x01, 0x00, 0x00, 0x23]), + "desc": "文档2.3.13 快速测量" + }, + # 格式4:连续测量模式 + { + "name": "连续测量模式 (0x0021)", + "cmd": bytes([0xAA, 0x00, 0x00, 0x21, 0x00, 0x01, 0x00, 0x00, 0x22]), + "desc": "文档2.3.14 连续测量" + } +] + +def clear_buffer(): + """清空串口缓冲区""" try: - old_data = laser_uart.read(-1) - if old_data: - print(f" 清空缓冲区: {len(old_data)} 字节") + data = laser_uart.read(-1) + if data: + print(f"清空: {len(data)}字节") except: pass + +def send_and_wait(cmd, name, wait_time=2000): + """发送命令并等待响应""" + print(f"\n📤 发送: {name}") + print(f" 命令: {cmd.hex()}") + + clear_buffer() - # 发送命令 try: - written = laser_uart.write(cmd) - print(f" 写入字节数: {written}") + laser_uart.write(cmd) + print(f" 已发送 {len(cmd)} 字节") except Exception as e: - print(f" ❌ 写入失败: {e}") + print(f" ❌ 发送失败: {e}") return None # 等待响应 - time.sleep_ms(100) + start_time = time.ticks_ms() + response = b"" - # 读取回包 - try: - resp = laser_uart.read(50) - if resp: - print(f" 📥 收到回包: {resp.hex()} ({len(resp)} 字节)") - return resp - else: - print(f" ⚠️ 无回包") - return None - except Exception as e: - print(f" ❌ 读取失败: {e}") - return None + while time.ticks_ms() - start_time < wait_time: + try: + chunk = laser_uart.read(1) + if chunk: + response += chunk + # 完整响应通常是9或13字节 + if len(response) >= 9: + # 检查是否完整帧 + if response[0] in [0xAA, 0xEE]: + if len(response) >= 13: # 测距完整响应 + break + elif response[0] == 0xEE: # 错误响应 + break + except: + break + + time.sleep_ms(10) + + if response: + print(f" 📥 响应: {response.hex()}") + print(f" 长度: {len(response)} 字节") + + # 解析错误码 + if response[0] == 0xEE and len(response) >= 
9: + err_code = (response[7] << 8) | response[8] + error_mapping = { + 0x0000: "无错误", + 0x0001: "硬件错误", + 0x0002: "无输出数据", + 0x0003: "反射信号太弱", + 0x0004: "反射信号太强", + 0x0005: "温度太高(>40℃)", + 0x0006: "温度太低(<-10℃)", + 0x0007: "电源电压低(<2.5V)", + 0x0008: "超出量程", + 0x0009: "读通讯错误", + 0x000A: "写通讯错误", + 0x000B: "地址错误" + } + err_msg = error_mapping.get(err_code, f"未知错误: 0x{err_code:04X}") + print(f" ❌ 模块错误: {err_msg}") + else: + print(" ⚠️ 无响应") + + return response -def test_laser_cycle(on_cmd, off_cmd, cmd_name="标准命令"): - """测试一个开关周期""" - print(f"\n{'='*50}") - print(f"🧪 测试 {cmd_name}") - print(f"{'='*50}") +def parse_distance_data(response): + """解析距离数据""" + if not response or len(response) < 13: + return None - print("\n>>> 测试开启激光") - send_and_check(on_cmd, f"{cmd_name} - 开启") - print(" ⏱️ 等待 2 秒观察激光是否亮起...") - time.sleep(2) + if response[0] != 0xAA or response[3] not in [0x20, 0x21, 0x22]: + return None - print("\n>>> 测试关闭激光") - send_and_check(off_cmd, f"{cmd_name} - 关闭") - print(" ⏱️ 等待 2 秒观察激光是否熄灭...") - time.sleep(2) + # 解析4字节BCD码 + bcd_bytes = response[6:10] + distance_int = 0 + + for byte in bcd_bytes: + high = (byte >> 4) & 0x0F + low = byte & 0x0F + + if high > 9 or low > 9: + return None + + distance_int = distance_int * 100 + high * 10 + low + + distance_m = distance_int / 1000.0 + + # 信号质量 + signal = 0 + if len(response) >= 12: + signal = (response[10] << 8) | response[11] + + return { + 'meters': distance_m, + 'millimeters': distance_m * 1000, + 'signal': signal, + 'raw': response.hex() + } # ==================== 主测试 ==================== -print("\n" + "=" * 50) -print("🚀 开始激光测试") -print("=" * 50) -print("\n请观察激光模块的状态变化...") -print("测试将依次尝试不同的命令格式\n") +print("\n" + "="*50) +print("M01激光测距模块详细测试") +print("="*50) try: - # 测试1: 标准命令 - test_laser_cycle(LASER_ON_CMD, LASER_OFF_CMD, "标准命令") + # 1. 测试基本连接 + print("\n1. 
测试模块连接...") + version_cmd = bytes([0xAA, 0x80, 0x00, 0x0A, 0x8A]) + resp = send_and_wait(version_cmd, "读取硬件版本") - input("\n按回车继续测试备用命令1...") + if resp and resp[0] == 0xAA and resp[3] == 0x0A: + print(f"✅ 模块正常,版本: {resp[6]:02X}{resp[7]:02X}") + else: + print("❌ 模块连接测试失败") + exit(1) - # 测试2: 备用命令格式1 - test_laser_cycle(LASER_ON_CMD_ALT1, LASER_OFF_CMD_ALT1, "备用命令1 (简化)") + # 2. 开启激光 + print("\n2. 开启激光...") + resp = send_and_wait(LASER_ON_CMD, "开启激光", 1000) + if resp and resp.hex() == "aa0001be00010001c1": + print("✅ 激光已开启") - input("\n按回车继续测试备用命令2...") + print(" 等待激光稳定...") + time.sleep(2) # 重要等待时间 - # 测试3: 备用命令格式2 - test_laser_cycle(LASER_ON_CMD_ALT2, LASER_OFF_CMD_ALT2, "备用命令2 (0x55AA头)") + # 3. 尝试不同的测距命令 + print("\n3. 测试不同测距命令...") - print("\n" + "=" * 50) + for i, test_cmd in enumerate(TEST_COMMANDS): + print(f"\n{'='*30}") + print(f"测试 {i+1}: {test_cmd['name']}") + print(f"{test_cmd['desc']}") + print(f"{'='*30}") + + resp = send_and_wait(test_cmd['cmd'], test_cmd['name'], 3000) + + if resp: + if resp[0] == 0xAA and len(resp) >= 13: + result = parse_distance_data(resp) + if result: + print(f"✅ 测距成功!") + print(f" 距离: {result['meters']:.3f} m") + print(f" 距离: {result['millimeters']:.1f} mm") + print(f" 信号质量: {result['signal']}") + break + else: + print("❌ 无法解析距离数据") + elif resp[0] == 0xEE: + print("❌ 命令执行错误") + else: + print("❌ 无效响应格式") + else: + print("❌ 无响应") + + time.sleep(1) # 命令间间隔 + + # 4. 关闭激光 + print("\n4. 关闭激光...") + send_and_wait(LASER_OFF_CMD, "关闭激光", 1000) + + print("\n" + "="*50) print("🏁 测试完成") - print("=" * 50) - print("\n诊断建议:") - print("1. 如果激光始终不亮/始终亮:") - print(" - 检查激光模块的电源连接") - print(" - 检查串口TX/RX是否接反") - print(" - 尝试不同的波特率 (4800/19200)") - print("") - print("2. 如果有回包但激光无反应:") - print(" - 命令格式可能正确但激光硬件问题") - print("") - print("3. 如果某个备用命令有效:") - print(" - 需要更新 config.py 中的命令格式") - + print("="*50) + + print("\n📋 测试总结:") + print("1. 模块通信: ✅ 正常") + print("2. 激光控制: ✅ 正常") + print("3. 测距功能: ❌ 有问题") + print("\n建议:") + print("1. 
检查激光是否实际发光(在暗处观察红点)") + print("2. 确保测量目标在有效范围内(0.2-60米)") + print("3. 确保目标有足够反射率(白色平面最佳)") + print("4. 如果所有测距命令都返回ERR_ADDR,可能是固件版本问题") + except KeyboardInterrupt: - print("\n\n🛑 测试被中断") - # 确保激光关闭 + print("\n\n🛑 用户中断") laser_uart.write(LASER_OFF_CMD) print("✅ 已发送关闭指令") except Exception as e: - print(f"\n❌ 测试出错: {e}") - import traceback - traceback.print_exc() - - - - - + print(f"\n❌ 测试出错: {e}") \ No newline at end of file diff --git a/test/test_motor.py b/test/test_motor.py new file mode 100644 index 0000000..28bc06c --- /dev/null +++ b/test/test_motor.py @@ -0,0 +1,16 @@ +from maix import gpio, pinmap, time + + +#设置引脚为输出 +led = gpio.GPIO("A25", gpio.Mode.OUT) +#设置低电平 +led.value(0) + +while 1: + # time.sleep_ms(1000) + #对该引脚的电平进行取反(原高-》现低) + # led.toggle() + led.value(1) + #延时 + time.sleep_ms(5000) + led.value(0) \ No newline at end of file diff --git a/test/synth_compose_yolo.py b/train_yolo/synth_compose_yolo.py similarity index 85% rename from test/synth_compose_yolo.py rename to train_yolo/synth_compose_yolo.py index 280b250..529a6f0 100644 --- a/test/synth_compose_yolo.py +++ b/train_yolo/synth_compose_yolo.py @@ -28,6 +28,11 @@ Stage2 ROI(对齐「先检整靶再裁小图」的第二步输入): - --motion-prob:施加概率;--motion-kernel-min/max:模糊 streak 长度(奇数核,越大越糊)。 - 可与 --blur-max 高斯模糊叠加;Stage2 建议:--motion-prob 0.5~0.7 --motion-kernel-max 35 --blur-max 1.2 +透视有两种模式(见 --perspective-mode): + - corner:四角独立随机抖动(旧版),组合后易出现「不物理」的夸张梯形;--perspective 为 jitter 强度。 + - planar:把靶当作平面,按 yaw/pitch/roll(度)随机旋转再投影,便于限定在例如「偏航 ±45°、俯仰 ±30°」内, + 更接近 5m 外拍 40cm 靶等场景(焦距用 --planar-focal-frac 控制视场)。 + 依赖:OpenCV + NumPy(PC 上跑即可;Maix 上若内存够也可试)。 示例: @@ -266,6 +271,74 @@ def _paste_fg_on_bg(bg_bgr, x, y, fg_scaled_bgra): roi_bg[:] = blended.astype(np.uint8) +def _rotation_matrix_ypr(yaw_deg: float, pitch_deg: float, roll_deg: float, np) -> np.ndarray: + """靶平面绕相机坐标原点旋转:先 yaw(Y),再 pitch(X),再 roll(Z)。单位:度。""" + y, p, r = np.radians([yaw_deg, pitch_deg, roll_deg]) + cy, sy = np.cos(y), np.sin(y) + cp, sp = np.cos(p), np.sin(p) 
+ cr, sr = np.cos(r), np.sin(r) + Ry = np.array([[cy, 0.0, sy], [0.0, 1.0, 0.0], [-sy, 0.0, cy]], dtype=np.float64) + Rx = np.array([[1.0, 0.0, 0.0], [0.0, cp, -sp], [0.0, sp, cp]], dtype=np.float64) + Rz = np.array([[cr, -sr, 0.0], [sr, cr, 0.0], [0.0, 0.0, 1.0]], dtype=np.float64) + return Rz @ Rx @ Ry + + +def _perspective_warp_planar_rgba( + img_bgra, + yaw_deg: float, + pitch_deg: float, + roll_deg: float, + focal_frac: float, + np, + cv2, +): + """ + 平面靶透视:将 frontal 图像视作 z=1 平面上的针孔投影,再旋转三维射线方向,等价于靶相对相机倾斜。 + focal_frac:焦距 fx=fy=max(w,h)*focal_frac,越大视场越窄、透视越温和。 + 返回 (warped BGRA, M);退化时返回 (copy, None)。 + """ + h, w = img_bgra.shape[:2] + if min(w, h) < 16: + return img_bgra.copy(), None + fx = fy = float(max(w, h) * max(0.25, focal_frac)) + cx, cy = w * 0.5, h * 0.5 + R = _rotation_matrix_ypr(yaw_deg, pitch_deg, roll_deg, np) + + uv_dst: list[list[float]] = [] + z_min = 0.08 + for u, v in ((0.0, 0.0), (float(w), 0.0), (float(w), float(h)), (0.0, float(h))): + x = (u - cx) / fx + y = (v - cy) / fy + z = 1.0 + vec = R @ np.array([x, y, z], dtype=np.float64) + if float(vec[2]) < z_min: + return img_bgra.copy(), None + uu = fx * (vec[0] / vec[2]) + cx + vv = fy * (vec[1] / vec[2]) + cy + uv_dst.append([uu, vv]) + + pts_dst = np.float32(uv_dst) + xmin = float(pts_dst[:, 0].min()) + ymin = float(pts_dst[:, 1].min()) + pts_shift = pts_dst.copy() + pts_shift[:, 0] -= xmin + pts_shift[:, 1] -= ymin + out_w = max(4, int(np.ceil(float(pts_shift[:, 0].max()))) + 2) + out_h = max(4, int(np.ceil(float(pts_shift[:, 1].max()))) + 2) + + pts_src = np.float32([[0, 0], [w, 0], [w, h], [0, h]]) + M = cv2.getPerspectiveTransform(pts_src, pts_shift) + warped = cv2.warpPerspective( + img_bgra, + M, + (out_w, out_h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=(0, 0, 0, 0), + ) + return warped, M + + def _perspective_warp_rgba(img_bgra, jitter_frac: float, rng: random.Random, np, cv2): """ 对前景做轻微透视(四角微移),返回 (warped BGRA, M)。 @@ -311,6 +384,26 @@ def 
_perspective_warp_rgba(img_bgra, jitter_frac: float, rng: random.Random, np, return warped, M +def _apply_perspective_warp_fg(fg_bgra, rng: random.Random, np, cv2, args) -> tuple: + """在已由 perspective_prob 抽中的样本上做一次透视。返回 (fg_out, M|None)。""" + mode = getattr(args, "perspective_mode", "corner") + if mode == "planar": + + x = rng.betavariate(2,2) + yaw = -float(args.yaw_max_deg) + x * (float(args.yaw_max_deg) +float(args.yaw_max_deg)) + x = rng.betavariate(2,2) + pitch = -float(args.pitch_max_deg) + x * (float(args.pitch_max_deg) +float(args.pitch_max_deg)) + x = rng.betavariate(2,2) + roll = -float(args.roll_max_deg) + x * (float(args.roll_max_deg) +float(args.roll_max_deg)) + + focal_frac = float(getattr(args, "planar_focal_frac", 1.25)) + return _perspective_warp_planar_rgba(fg_bgra, yaw, pitch, roll, focal_frac, np, cv2) + jitter = float(getattr(args, "perspective", 0.0)) + if jitter <= 0: + return fg_bgra.copy(), None + return _perspective_warp_rgba(fg_bgra, jitter, rng, np, cv2) + + def _color_jitter_bgr(comp_bgr, strength: float, rng: random.Random, np, cv2): """整图 HSV 抖动:strength∈[0,1] 越大越强。""" if strength <= 1e-6: @@ -519,11 +612,17 @@ def main(): help="运动模糊 streak 长度上限,越大越像长曝光/手抖", ) ap.add_argument("--jpeg-quality", type=int, default=92) + ap.add_argument( + "--perspective-mode", + choices=("corner", "planar"), + default="corner", + help="corner=四角随机抖动(易夸张);planar=yaw/pitch/roll 平面投影(易限定视姿)", + ) ap.add_argument( "--perspective", type=float, default=0.0, - help="轻微透视:四角扰动约为 min(靶宽,靶高)×该系数,0 关闭(建议 0.02~0.06)", + help="仅 perspective-mode=corner:四角扰动≈min(靶宽,靶高)×该系数,0 关闭(温和可用 0.015~0.03)", ) ap.add_argument( "--perspective-prob", @@ -531,6 +630,30 @@ def main(): default=0.75, help="每张图应用透视的概率 0~1", ) + ap.add_argument( + "--yaw-max-deg", + type=float, + default=45.0, + help="planar:偏航角均匀采样上界(度),实际 ∈[-max,max]", + ) + ap.add_argument( + "--pitch-max-deg", + type=float, + default=30.0, + help="planar:俯仰角均匀采样上界(度)", + ) + ap.add_argument( + "--roll-max-deg", + 
type=float, + default=8.0, + help="planar:滚转角均匀采样上界(度),手持可略大", + ) + ap.add_argument( + "--planar-focal-frac", + type=float, + default=1.25, + help="planar:fx=fy=max(w,h)×该值,越大透视越温和(建议 1.1~1.8)", + ) ap.add_argument( "--color-jitter", type=float, @@ -662,8 +785,15 @@ def main(): fg_s = cv2.resize(fg_crop, (new_w, new_h), interpolation=cv2.INTER_AREA) persp_M = None - if args.perspective > 0 and rng.random() < args.perspective_prob: - fg_s, persp_M = _perspective_warp_rgba(fg_s, args.perspective, rng, np, cv2) + want_p = ( + rng.random() < args.perspective_prob + and ( + args.perspective_mode == "planar" + or (args.perspective_mode == "corner" and args.perspective > 0) + ) + ) + if want_p: + fg_s, persp_M = _apply_perspective_warp_fg(fg_s, rng, np, cv2, args) fw2, fh2 = fg_s.shape[1], fg_s.shape[0] tx0, ty0, tw, th = _fg_bbox_from_alpha(fg_s) diff --git a/test/test_stage2_black_yolo_device.py b/train_yolo/test_stage2_black_yolo_device.py similarity index 100% rename from test/test_stage2_black_yolo_device.py rename to train_yolo/test_stage2_black_yolo_device.py diff --git a/test/test_triangle_one_image.py b/train_yolo/test_triangle_one_image.py similarity index 100% rename from test/test_triangle_one_image.py rename to train_yolo/test_triangle_one_image.py diff --git a/test/test_yolo_draw_boxes.py b/train_yolo/test_yolo_draw_boxes.py similarity index 100% rename from test/test_yolo_draw_boxes.py rename to train_yolo/test_yolo_draw_boxes.py diff --git a/train_yolo/train_black_triangle_pos.py b/train_yolo/train_black_triangle_pos.py new file mode 100644 index 0000000..ee1d73f --- /dev/null +++ b/train_yolo/train_black_triangle_pos.py @@ -0,0 +1,506 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +YOLO11 关键点检测训练脚本(靶纸四角)。 + +设备优先级(--device auto):Intel XPU > NVIDIA CUDA > CPU。 +默认 imgsz=960;批大小默认 4(大图显存紧张时可再降)。 + +关于「业务像素误差」: + Ultralytics 没有在 yaml 里设定「像素阈值」的选项;反向传播仍由 pose/kobj/box 等内部 loss 驱动。 + - 监控:--pixel-metrics-every N(每 N 个 epoch 打印 mean/p95,并合并进 
runs/.../results.csv,见 pose_pixel_metrics.py)。 + - 选 best.pt / early stopping:加 --best-by-pixel,用验证集 mean 像素误差(与 pose_pixel_metrics + 同一口径)代替 mAP 合成 fitness(fitness = -mean_px,越小越好)。 + 多卡 DDP(world_size>1)时会自动退回默认 mAP fitness。 + + XPU:Ultralytics BaseTrainer._get_memory / _clear_memory 把非 MPS、非 CPU 一律当 CUDA, + 会在验证前调用 torch.cuda 而报错;本脚本在选用 XPU 时自动打补丁(见 _patch_ultralytics_trainer_for_xpu)。 + + 务必使用 pose 任务:YOLO(...) 与 model.train(...) 均指定 task='pose'。若误用默认 detect, + 会把 17 列 Pose 标注当成检测/分割解析,校验时出现「coordinates > 1」或 [2.] 等假象。 +""" + +from __future__ import annotations + +import argparse +import csv +import gc +import glob +import os +import tempfile +from copy import deepcopy +from pathlib import Path + +import torch +from ultralytics import YOLO + +from pose_pixel_metrics import eval_val_pixel_error +import warnings +warnings.filterwarnings('ignore', + message=".*scatter_add_kernel does not have a deterministic implementation.*") + + +def _clear_ultralytics_label_caches(data_yaml_path: str) -> int: + """删除 data.yaml 的 path 下 labels/*.cache。 + + Ultralytics 的校验缓存 hash 仅依赖「标签/图片路径字符串 + 各文件 size 之和」,不含文件内容; + 修正 *.txt 后若总和巧合不变,可能继续加载旧 cache 并重播旧的 corrupt 日志,训练前应删掉。""" + from ultralytics.utils import YAML + + try: + cfg = YAML.load(data_yaml_path) + except Exception: + return 0 + root = cfg.get("path") + if not root: + return 0 + root = os.path.abspath(os.path.expanduser(str(root))) + pattern = os.path.join(root, "labels", "*.cache") + n = 0 + for p in glob.glob(pattern): + try: + os.unlink(p) + n += 1 + except OSError: + pass + return n + + +def _pick_device(explicit: str | None): + """返回 ultralytics train/predict 可用的 device。""" + if explicit and explicit != "auto": + e = explicit.lower() + if e == "xpu": + if getattr(torch, "xpu", None) is None or not torch.xpu.is_available(): + raise RuntimeError("指定了 --device xpu 但当前环境不可用") + return torch.device("xpu") + if e in ("0", "cuda", "gpu"): + if not torch.cuda.is_available(): + raise RuntimeError("指定了 CUDA 但不可用") + return 0 
+ if e == "cpu": + return "cpu" + return explicit + if getattr(torch, "xpu", None) is not None and torch.xpu.is_available(): + return torch.device("xpu") + if torch.cuda.is_available(): + return 0 + return "cpu" + + +def _default_amp(device) -> bool: + if isinstance(device, torch.device) and device.type == "xpu": + return False + if device == "cpu": + return False + return True + + +def _patch_ultralytics_for_xpu(): + """为 Ultralytics 打补丁,使其能在 XPU 环境下正常训练和验证。""" + import ultralytics.engine.trainer as ut_trainer + import ultralytics.engine.validator as ut_validator + from ultralytics.utils.torch_utils import select_device as _original_select_device + + # 1. 覆盖 select_device:Trainer 初始化传入 torch.device("xpu") 会走原版早返回; + # 初始化后 args.device 会变成字符串 "xpu",中期 val 用 trainer.device,不调用 select_device; + # 训练结束 final_eval 里 Validator 会 select_device("xpu"),且 validator 在 import 时已绑定原函数, + # 只改 torch_utils 无效,必须同时修补 trainer/validator 模块内的引用。 + def _patched_select_device(device="", *args, **kwargs): + # Ultralytics 8.4.x: select_device(device="", newline=False, verbose=True) + # Older forks sometimes passed extra positional args; forward everything. + if isinstance(device, str): + d = device.strip().lower() + if d == "xpu" or d.startswith("xpu:"): + return torch.device(device.strip()) + return _original_select_device(device, *args, **kwargs) + + import ultralytics.utils.torch_utils + + ultralytics.utils.torch_utils.select_device = _patched_select_device + ut_trainer.select_device = _patched_select_device + ut_validator.select_device = _patched_select_device + + # 2. 修补 Trainer 的内存函数 + BT = ut_trainer.BaseTrainer + if not getattr(BT, "_archery_xpu_memory_patched", False): + _orig_get_memory = BT._get_memory + _orig_clear_memory = BT._clear_memory + + def _get_memory(self, fraction=False): + if self.device.type != "xpu": + return _orig_get_memory(self, fraction) + # ... (原有的 XPU 内存获取逻辑保持不变) ... 
+ memory, total = 0, 0 + try: + idx = self.device.index + if idx is None: + idx = torch.xpu.current_device() + memory = int(torch.xpu.memory_allocated(idx)) + if fraction: + total = int(torch.xpu.get_device_properties(idx).total_memory) + except Exception: + pass + return (memory / total) if fraction and total > 0 else (memory / 2**30) + + def _clear_memory(self, threshold=None): + if self.device.type != "xpu": + return _orig_clear_memory(self, threshold) + if threshold is not None: + assert 0 <= threshold <= 1, "Threshold must be between 0 and 1." + if self._get_memory(fraction=True) <= threshold: + return + gc.collect() + if hasattr(torch.xpu, "empty_cache"): + torch.xpu.empty_cache() + + BT._get_memory = _get_memory + BT._clear_memory = _clear_memory + BT._archery_xpu_memory_patched = True + + # 3. 修补 Validator 的内存函数 (关键是添加这部分) + BV = ut_validator.BaseValidator + if not getattr(BV, "_archery_xpu_memory_patched", False): + # 为 Validator 添加同样的内存处理方法 + BV._get_memory = _get_memory + BV._clear_memory = _clear_memory + BV._archery_xpu_memory_patched = True + + +def _install_best_by_pixel_validate(data_yaml: str, imgsz: int, conf: float) -> None: + """用验证集关键点像素 mean 替代 mAP fitness,驱动 best.pt 与 patience early stopping。""" + import ultralytics.engine.trainer as ut + from ultralytics.utils import RANK + + BT = ut.BaseTrainer + if getattr(BT, "_archery_best_by_pixel_installed", False): + return + + _orig_validate = BT.validate + + def validate(self): + import torch.distributed as dist + + if self.ema and self.world_size > 1: + for buffer in self.ema.ema.buffers(): + dist.broadcast(buffer, src=0) + metrics = self.validator(self) + if metrics is None: + return None, None + orig_fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) + + use_pixel = self.world_size <= 1 and RANK in {-1, 0} + mean_px: float | None = None + if use_pixel: + tmp_path: str | None = None + try: + fd, tmp_path = tempfile.mkstemp(suffix=".pt", prefix="archery_pxfit_") + os.close(fd) + 
from ultralytics.utils.torch_utils import unwrap_model + + core = unwrap_model(self.ema.ema if self.ema else self.model) + torch.save({"ema": deepcopy(core).half(), "train_args": vars(self.args)}, tmp_path) + probe = YOLO(tmp_path) + stats = eval_val_pixel_error( + probe, + data_yaml, + device=self.device, + imgsz=imgsz, + conf=conf, + ) + mean_px = stats.get("mean_px") + if mean_px is None: + raise RuntimeError("无有效 mean_px(检查 val 标签与检测是否为空)") + except Exception as exc: + print(f"\n⚠️ [best-by-pixel] 像素探针失败,本 epoch 仍用 mAP fitness: {exc}\n") + mean_px = None + finally: + if tmp_path: + try: + os.unlink(tmp_path) + except OSError: + pass + + if mean_px is not None: + fitness = -float(mean_px) + metrics["metrics/mean_px(val)"] = float(mean_px) + else: + fitness = float(orig_fitness) + + if not self.best_fitness or self.best_fitness < fitness: + self.best_fitness = fitness + return metrics, fitness + + BT.validate = validate + BT._archery_best_by_pixel_installed = True + + +def _fmt_csv_metric(v: float | int | None) -> str: + if v is None: + return "" + if isinstance(v, float): + return f"{v:.6g}" + return str(v) + + +# 写入 results.csv 的列名(与 --best-by-pixel 的 metrics/mean_px(val) 区分,避免被 last.pt 回调覆盖 EMA 行) +_PIXEL_METRIC_COLUMNS: tuple[tuple[str, str], ...] 
= ( + ("pixel_error/mean_px", "mean_px"), + ("pixel_error/median_px", "median_px"), + ("pixel_error/p95_px", "p95_px"), + ("pixel_error/max_px", "max_px"), + ("pixel_error/n_points", "n_points"), + ("pixel_error/n_images", "n_images"), + ("pixel_error/skip_no_det", "skip_no_det"), + ("pixel_error/skip_no_gt", "skip_no_gt"), + ("pixel_error/skip_kpt_mismatch", "skip_kpt_mismatch"), +) + + +def _merge_pixel_metrics_into_results_csv(save_dir: str | Path, epoch_1based: int, stats: dict) -> None: + """在 Ultralytics 写完本 epoch 行之后,把像素指标列合并进 results.csv(扩展表头、补空列)。""" + csv_path = Path(save_dir) / "results.csv" + if not csv_path.is_file(): + return + try: + with open(csv_path, newline="", encoding="utf-8") as f: + rows = list(csv.reader(f)) + except OSError: + return + if len(rows) < 2: + return + header = list(rows[0]) + for col_name, _ in _PIXEL_METRIC_COLUMNS: + if col_name not in header: + header.append(col_name) + for ri in range(1, len(rows)): + rows[ri].append("") + col_ix = {name: i for i, name in enumerate(header)} + rows[0] = header + target_ri: int | None = None + for ri in range(1, len(rows)): + row = rows[ri] + while len(row) < len(header): + row.append("") + try: + if int(float(row[0].strip())) == int(epoch_1based): + target_ri = ri + except (ValueError, IndexError): + continue + if target_ri is None: + return + row = rows[target_ri] + while len(row) < len(header): + row.append("") + for col_name, sk in _PIXEL_METRIC_COLUMNS: + row[col_ix[col_name]] = _fmt_csv_metric(stats.get(sk)) + try: + with open(csv_path, "w", newline="", encoding="utf-8") as f: + w = csv.writer(f) + w.writerows(rows) + except OSError: + pass + + +def _make_pixel_metrics_callback(data_yaml: str, every: int, imgsz: int, conf: float = 0.25): + def on_fit_epoch_end(trainer): + from ultralytics.utils import RANK + + if RANK not in {-1, 0}: + return + if every <= 0: + return + ep = int(getattr(trainer, "epoch", -1)) + if (ep + 1) % every != 0: + return + w = Path(trainer.save_dir) / "weights" 
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the pose training script."""
    ap = argparse.ArgumentParser(description="YOLO Pose 训练(XPU/CUDA/CPU)")
    ap.add_argument("--data", default="datasets/dataset_pose.yaml", help="data.yaml")
    ap.add_argument("--model", default="yolo11x-pose.pt", help="预训练权重")
    ap.add_argument("--epochs", type=int, default=100)
    ap.add_argument("--imgsz", type=int, default=960, help="训练输入边长(默认 960)")
    ap.add_argument("--batch", type=int, default=4, help="批大小;OOM 时减小")
    ap.add_argument(
        "--device",
        default="auto",
        help="auto | xpu | 0 | cuda | cpu(auto:XPU 优先)",
    )
    ap.add_argument(
        "--no-amp",
        action="store_true",
        help="关闭混合精度(默认:CUDA 开启,XPU/CPU 关闭)",
    )
    ap.add_argument("--project", default="runs/pose")
    ap.add_argument("--name", default="target_pose_train")
    ap.add_argument("--workers", type=int, default=4)
    ap.add_argument(
        "--pixel-metrics-every",
        type=int,
        default=0,
        help="每 N 个 epoch 在 val 上打印像素误差并写入 results.csv 对应 epoch 行(0=关闭);需 labels 与 data.yaml 布局一致",
    )
    ap.add_argument(
        "--pixel-metrics-conf",
        type=float,
        default=0.25,
        help="--pixel-metrics-every 时 predict 置信度阈值(默认 0.25)",
    )
    ap.add_argument(
        "--best-by-pixel",
        action="store_true",
        help="best.pt 与 early stopping 按验证集 mean 像素误差(同 pose_pixel_metrics),fitness=-mean_px;单卡有效,DDP 自动退回 mAP",
    )
    ap.add_argument(
        "--pixel-fitness-conf",
        type=float,
        default=0.25,
        help="--best-by-pixel 时 predict 置信度阈值(默认与 pixel-metrics 一致)",
    )
    ap.add_argument(
        "--export-onnx",
        action="store_true",
        help="训练结束后导出 ONNX(需再设 --onnx-imgsz)",
    )
    ap.add_argument(
        "--onnx-imgsz",
        type=int,
        nargs=2,
        metavar=("H", "W"),
        default=[224, 320],
        help="导出 ONNX 的 [高, 宽],默认 224 320(Maix 常用)",
    )
    ap.add_argument(
        "--clear-label-cache",
        action="store_true",
        help="启动训练前删除 data.yaml 中 path 下的 labels/*.cache(修正标注后仍报 corrupt 时用)",
    )
    return ap


def main():
    """Parse CLI arguments, train a YOLO pose model and optionally export ONNX.

    Steps: pick the device (XPU / CUDA / CPU) and the AMP default, apply the
    XPU patches when needed, resolve the dataset YAML relative to this script,
    optionally clear label caches, install the optional pixel-error hooks,
    train with the fixed hyper-parameters below, then print where the weights
    were written.
    """
    args = _build_arg_parser().parse_args()

    device = _pick_device(None if args.device == "auto" else args.device)
    # An XPU device given as a string is normalised to torch.device so that it
    # gets the same banner, AMP default and Ultralytics patches as auto-detected XPU.
    if isinstance(device, str):
        lowered = device.strip().lower()
        if lowered == "xpu" or lowered.startswith("xpu:"):
            device = torch.device(lowered)
    is_xpu = isinstance(device, torch.device) and device.type == "xpu"
    use_amp = False if args.no_amp else _default_amp(device)

    if is_xpu:
        print(f"✅ 使用 Intel XPU: {device}")
    elif device == "cpu":
        print("⚠️ 使用 CPU,训练会较慢")
    else:
        # Anything else ("cuda", 0, "0", "1", ...) is a CUDA selection.
        index = device if isinstance(device, int) else (int(device) if str(device).isdigit() else 0)
        if torch.cuda.is_available() and index < torch.cuda.device_count():
            cuda_name = torch.cuda.get_device_name(index)
        else:
            cuda_name = str(device)
        print(f"✅ 使用 CUDA: {cuda_name}")

    if is_xpu:
        _patch_ultralytics_for_xpu()

    # Relative dataset paths are resolved against this script's directory.
    data_yaml = args.data
    if not os.path.isabs(data_yaml):
        data_yaml = os.path.join(os.path.dirname(os.path.abspath(__file__)), data_yaml)
    if not os.path.exists(data_yaml):
        print(f"❌ 数据集配置不存在: {data_yaml}")
        return

    if args.clear_label_cache:
        n_rm = _clear_ultralytics_label_caches(data_yaml)
        print(f"🗑️ 已删除标签目录缓存 {n_rm} 个(labels/*.cache),将强制重新扫描标注。")

    print(f"📦 加载模型: {args.model}(固定 task=pose)")
    model = YOLO(args.model, task="pose")

    if args.best_by_pixel:
        _install_best_by_pixel_validate(data_yaml, args.imgsz, args.pixel_fitness_conf)
        print(
            "📌 已启用 --best-by-pixel:best.pt / patience 按验证集 mean 像素误差(fitness=-mean_px);"
            "反向传播仍为 Ultralytics 默认 pose/box loss。"
        )

    if args.pixel_metrics_every > 0:
        model.add_callback(
            "on_fit_epoch_end",
            _make_pixel_metrics_callback(
                data_yaml, args.pixel_metrics_every, args.imgsz, conf=args.pixel_metrics_conf
            ),
        )

    model.train(
        task="pose",
        data=data_yaml,
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        name=args.name,
        project=args.project,
        exist_ok=True,
        save=True,
        save_period=5,
        device=device,
        workers=args.workers,
        lr0=0.0001,
        lrf=0.01,
        optimizer="AdamW",
        momentum=0.937,
        weight_decay=0.001,
        warmup_epochs=0,
        warmup_momentum=0.8,
        warmup_bias_lr=0.1,
        hsv_h=0.015,
        hsv_s=0.7,
        hsv_v=0.4,
        degrees=5.0,
        translate=0.0,
        scale=0.2,
        shear=0.0,
        perspective=0.0000,
        flipud=0.0,
        fliplr=0.5,
        mosaic=0.0,
        mixup=0.0,
        copy_paste=0.0,
        box=6,
        cls=0.5,
        dfl=1.5,
        pose=18.0,
        kobj=0.5,
        freeze=0,
        seed=42,
        verbose=True,
        amp=use_amp,
        patience=100,
        cos_lr=True,
    )

    # Prefer the directory the trainer actually wrote to; fall back to the
    # project/name guess when the trainer does not expose one.
    trainer = getattr(model, "trainer", None)
    save_dir = Path(getattr(trainer, "save_dir", None) or Path(args.project) / args.name)
    best_pt = save_dir / "weights" / "best.pt"
    last_pt = save_dir / "weights" / "last.pt"

    print("\n✅ 训练完成!")
    print(f"📁 best: {best_pt}")
    print(f"📁 last: {last_pt}")
    print(
        "📊 仅看像素误差可运行: python pose_pixel_metrics.py "
        f"--model {best_pt} --data {data_yaml} --imgsz {args.imgsz}"
    )

    if args.export_onnx:
        h, w = args.onnx_imgsz
        print(f"📦 导出 ONNX imgsz=[{h}, {w}] ...")
        model.export(format="onnx", imgsz=[h, w], simplify=True, opset=17, dynamic=False)
        print("✅ ONNX 完成")


if __name__ == "__main__":
    main()